现在的位置: 首页 > 综合 > 正文

hdu4747 Mex 线段树 (2013网络赛)

2014年07月18日 ⁄ 综合 ⁄ 共 1963字 ⁄ 字号 评论关闭

题意:给你一个序列,让你求出对于所有区间<i, j>的mex和,mex表示该区间没有出现过的最小的整数。

思路:从时限和点数就可以看出是线段树,并且我们可以枚举左端点i, 然后求出所有左端点为i的区间内mex值的和。

先把数插满,然后先询问后删除当前最左边的断点i。而且显然线段树里面保存的是mex值,而且这个序列是非递减的。

分析:我们先预处理出对于右端点为i的所有<1,i>的mex,分别插入线段树的i位置。然后每次删除最左边的左端点i

,假如当前我们要删除a[i] ,我们找到它之后第一个位置j满足a[i] == a[j],  那么区间i------j-1里面的所有mex都要更新,取线段树内的值和a[i]的最小值。 实际操作我们只要找到第一个比a[i]大的位置l,
 r = j-1,  更新<l,r>之间的mex为a[i]即可。

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
#define lson l, m, rt<<1
#define rson m+1, r, rt<<1|1
#define ls rt<<1
#define rs rt<<1|1
#define Mid int m = l+r>>1
const int maxn = 2000006;
int next[maxn], pre[maxn], n;
int a[maxn], mex;
bool vis[maxn];
ll sum[maxn<<2];
int mx[maxn<<2], col[maxn<<2];
void build(int l=1, int r=n, int rt=1) {
    col[rt] = -1;
    sum[rt] = 0;
    mx[rt] = 0;
    if(l == r) return;
    Mid;
    build(lson);
    build(rson);
}
inline void down(int l, int r, int rt) {
    if(~col[rt]) {
        col[ls] = col[rs] = col[rt];
        Mid;
        sum[ls] = (ll)(m-l+1)*col[rt];
        mx[ls] = mx[rs] = col[rt];
        sum[rs] = (ll)(r-m)*col[rt];
        col[rt] = -1;
    }
}
inline void up(int rt) {
    sum[rt] = sum[ls] + sum[rs];
    mx[rt] = max(mx[ls], mx[rs]);
}
void update(int L, int R, int v, int l=1, int r=n, int rt=1) {
    if(L <= l && r <= R) {
        col[rt] = mx[rt] = v;
        sum[rt] = (ll)(r-l+1)*v;
        return;
    }
    Mid; down(l, r, rt);
    if(L <= m) update(L, R, v, lson);
    if(R > m) update(L, R, v, rson);
    up(rt);
}
ll query(int L, int R, int l=1, int r=n, int rt=1) {
    if(L <= l && r <= R)
        return sum[rt];
    Mid; down(l, r, rt);
    ll ret = 0;
    if(L <= m) ret += query(L, R, lson);
    if(R > m) ret += query(L, R, rson);
    up(rt);
    return ret;
}
int find(int v, int l=1, int r=n, int rt=1) {
    if(mx[rt] <= v) return n+1;
    if(l == r) return l;
    Mid; down(l, r, rt);
    if(mx[ls] > v) return find(v, lson);
    else return find(v, rson);
}
int main() {
    int i, j;
    while(~scanf("%d", &n) && n) {
        for(i = 1; i <= n; i++) {
            scanf("%d", &a[i]);
            pre[i] = vis[i] = 0;
            next[i] = n+1;
        }
        pre[0] = vis[0] = 0;
        for(i = 1; i <= n; i++)
            if(a[i] <= n) {
                if(pre[a[i]])
                    next[pre[a[i]]] = i;
                pre[a[i]] = i;
            }
        build();
        mex = 0;
        for(i = 1; i <= n; i++) {
            if(a[i] <= n){
                vis[a[i]] = 1;
                while(vis[mex]) mex++;
            }
            update(i, i, mex);
        }
        ll ans = 0;
        for(i = 1; i <= n; i++) {
            ans += query(i, n);
            if(a[i] <= mex) {
                int l = max(find(a[i]), i);
                int r = next[i]-1;
                if(l <= r) update(l, r, a[i]);
            }
        }
        printf("%I64d\n", ans);
    }
    return 0;
}
/*
3
0 10000 20000
*/

抱歉!评论已关闭.