现在的位置: 首页 > 综合 > 正文

POJ 3666 Making the Grade(左偏树)

2019年02月11日 ⁄ 综合 ⁄ 共 1777字 ⁄ 字号 评论关闭

题意:给出一串数,改变当个数字的大小的代价是改动的绝对值(比如2变成5,代价是3),求让这个数串变成非递减(或非递增)数列的最小代价。


下面只讨论非递减的情况(非递增类似):

做法非常神奇,是黄源河前辈论文中的例题:左偏树的特点及其应用


对于一个非递减序列,最小代价是把每个数都变成这个序列的中位数。

而有增有减的序列,可以分段,划分成阶梯状。

至于为什么可以扩展到有增有减,论文里面的解释更为想详细!^ ^

大体思想是:用左边树保存每一段部分的中位数。

把每个数单独建一颗左偏树。因为只有一个数,当然中位数是自己。

树里面只保存(len+1)/2个节点,len为这课左偏树所管理的长度(即影响范围)。

然后从左往右扫,一旦扫到后面的中位数比前面的中位数要小,就把这两棵树合并。

这时候一个很关键的操作,弹去树根:这是保证左偏树保存的是中位数的重要操作。


原来,左偏树只保存了len长度范围内,前(len+1)/2的元素(按大小排序)。

一个影响范围为lena的左偏树,与一个影响范围为lenb的左偏树,合并之后,只会保存(lena+lenb+1)/2个节点;

前者已经存了(lena+1)/2个节点,后者为(lenb+1)/2。

直接合并,会得到一个节点总数为(lena+1)/2+(lenb+1)/2的树。

如果(lena+lenb+1)/2小于(lena+1)/2+(lenb+1)/2,则弹去树根(最大值)。

易得(lena+1)/2+(lenb+1)/2最多只比(lena+lenb+1)/2大1,所以只需删去一个节点,即最大的那个节点

#include <algorithm>
#include <iostream>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
using namespace std;
#define MAXN 2010
#define LL long long
#define min(a, b) (a < b ? a : b)

struct Node
{
    int v, l, r, dis;
    Node() {}
    Node(int _v, int _l, int _r, int _d):
        v(_v), l(_l), r(_r), dis(_d) {}
}nn[2][MAXN];

int merge(Node n[], int x, int y)
{
    if(!x) return y;
    if(!y) return x;
    if(n[x].v < n[y].v) swap(x, y);
    n[x].r = merge(n, n[x].r, y);
    if(n[n[x].l].dis < n[n[x].r].dis) swap(n[x].l, n[x].r);
    n[x].dis = n[n[x].r].dis + 1;
    return x;
}

int N, v[MAXN], len[MAXN], stk[MAXN];
LL ans[2];

void solve(Node n[], int t)
{
    int top = 0;
    for(int i = 0; i < N; i++)
    {
        int ct = 1; int id = i;
        while(top > 0 && n[stk[top - 1]].v > n[id].v)
        {
            top--;
            id = merge(n, stk[top], id);
            if((len[top] + 1) / 2 + (ct + 1) / 2 > (len[top] + ct + 1) / 2)
                id = merge(n, n[id].l, n[id].r);
            ct += len[top];
        }
        len[top] = ct;
        stk[top++] = id;
    }

    for(int i = 0, j = 0; i < top; i++)
    {
        int k = n[stk[i]].v;
        while(len[i]--) ans[t] += abs(v[j++] - k);
    }

}

int main()
{
//    freopen("H.in", "r", stdin);

    while(~scanf("%d", &N))
    {
        memset(len, 0, sizeof(len));
        memset(nn, 0, sizeof(nn));
        for(int i = 0; i < N; i++)
        {
            scanf("%d", &v[i]);
            nn[0][i] = nn[1][N - i + 1] = Node(v[i], 0, 0, 0);
        }
        ans[0] = ans[1] = 0;
        for(int i = 0; i < 2; i++) solve(nn[i], i);
        printf("%I64d\n", min(ans[0], ans[1]));
    }

    return 0;
}

抱歉!评论已关闭.