现在的位置: 首页 > 综合 > 正文

【伸展树】【NOI2005】维护数列

2014年02月26日 ⁄ 综合 ⁄ 共 9039字 ⁄ 字号 评论关闭

Input
输入文件的第1行包含两个数N和M,N表示初始时数列中数的个数,M表示要进行的操作数目。第2行包含N个数字,描述初始时的数列。以下M行,每行一条命令,格式参见问题描述中的表格。
Output
对于输入数据中的GET-SUM和MAX-SUM操作,向输出文件依次打印结果,每个答案(数字)占一行。
Sample Input
9 8
2 -6 3 5 1 -5 -3 6 3
GET-SUM 5 4
MAX-SUM
INSERT 8 3 -5 7 2
DELETE 12 1
MAKE-SAME 3 3 2
REVERSE 3 6
GET-SUM 5 4
MAX-SUM
Sample Output
-1
10
1
10
HINT

不得不说这是一道很牛的题。
我们利用伸展树可以将询问处旋转至根的性质并结合线段树的一些方法可以将此题完美解决。

为了方便处理,需首先在伸展树中加入两个权值为负无穷的虚拟结点(分别位于最左边和最右边),保证插入删除等操作不在空树上进行。(并且空结点也需要虚拟,否则会访问出错。)

1) 插入操作:
    假设需要在pos, pos + 1之间插入若干个数。那么先将pos + 1位置(包含最左边的虚拟结点,下同)的结点旋转至根,再将根的右子树的最左边的结点旋转至根的右子结点,然后将要插入的数串成一条链,接到根的右子树的左子树上,最后将链的末端旋转至根即可(其实这样做一方面是为了减小树的高度,另一方面是将途中的结点全部更新一次)。

2) 删除操作:
    假设需要删除从pos位置开始的tot个数。那么先将pos位置的结点旋转至根,再将根的右子树的tot + 1位置的结点旋转至根的右子结点,然后将根的右子树的左子树置空(相当于一次性全部删除),最后将根的右子结点旋转至根即可(旋转的同时将其更新)。

3) 修改操作:
    给伸展树的每个结点附加上一个懒标记sm(即该子树中是否所有数都相同),旋转操作的时候将标记向下传。
    假设需要将从pos位置开始的tot个数全部改成c,那么先将pos位置的结点旋转至根,再将根的右子树的tot + 1位置的结点旋转至根,然后将根的右子树的左子节点的sm懒标记置为1,最后将它旋转至根即可(旋转的同时将其更新)。

4) 翻转操作:
    给伸展树的每个结点附加上一个懒标记rev(即该子树是否被翻转过)。旋转操作的时候若遇到被翻转过的点则交换其左右子树并将其标记向下传。
    假设需要将从pos位置开始的tot个数全部翻转,那么先将pos位置的结点旋转至根,再将根的右子树的tot + 1位置的结点旋转至右子结点,然后修改根的右子树的左子节点的rev标记,最后将它旋转至根即可(旋转的同时将其更新)。

5) 求和:
    这个比较简单,用一个sum标记维护每棵子树中所有元素的和(特殊地,几个虚拟结点的sum标记都为0)。
    假设需要求出从pos位置开始的tot个数的和,那么先将pos位置的结点旋转至根,再将根的右子树的tot + 1位置的结点旋转至右子结点,再直接返回根的右子树的左子结点的sum标记即可。

6) 求和的最大子列:
    仿照线段树中的一些做法,维护三个标记:用mls和mrs表示当前子树从最左边和最右边开始的最大和,用Max表示当前子树中的最大和。那么,有:
    mls = max{lc.mls, lc.sum + key, lc.sum + key + rc.mls};
    mrs = max{rc.mrs, rc.sum + key, rc.sum + key + lc.mrs};
    Max = max{lc.Max, rc.Max, key, lc.sum + key, rc.sum + key, lc.mrs + key + rc.mls}.
    (注意到,这里用lc表示左子树,rc表示右子树,key表示当前结点的值。)
    那么,只需要一直动态维护好,直接返回根的Max标记即可。
代码:

/*****************************\
 * @prob: NOI2005 sequence   *
 * @auth: Wang Junji         *
 * @stat: Accepted.          *
 * @date: June. 4th, 2012    *
 * @memo: 伸展树代替块状链表   *
\*****************************/
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cstring>
#include <string>

using std::max; const int maxN = 3000010, INF = 0x07070707; //

template <typename _Tp> inline void gmax(_Tp &a, _Tp b) {if (b > a) a = b;}

class SplayTree
{
private:
    struct Node
    {
        int key, sum, Max, mls, mrs, sz; bool rev, sm; Node *lc, *rc, *F; Node() {}
        Node(int key): key(key), sz(1), rev(0), sm(0) {Max = sum = mls = mrs = key;}
    } NIL[maxN], *tot, *T, *head, *tail;
    Node *NewNode(int key)
    {Node *T = new (++tot) Node(key); T -> lc = T -> rc = T -> F = NIL; return T;}
    void pushdn(Node *T)
    {
        if (T -> rev)
        {
            std::swap(T -> lc, T -> rc); std::swap(T -> mls, T -> mrs);
            T -> lc -> rev ^= 1, T -> rc -> rev ^= 1; T -> rev = 0;
        }
        if (T -> sm)
        {
            T -> sm = 0; T -> lc -> sm = T -> rc -> sm = 1;
            T -> lc -> key = T -> rc -> key = T -> key;
            T -> mls = T -> mrs = T -> sum = T -> Max = T -> key * T -> sz;
            if (T -> key < 0) T -> mls = T -> mrs = T -> Max = T -> key;
        }
        return;
    }
    void update(Node *T)
    {
        T -> sz = T -> lc -> sz + T -> rc -> sz + 1;
        T -> sum = T -> lc -> sum + T -> rc -> sum + T -> key;
        T -> mls = max(T -> lc -> mls, T -> lc -> sum + T -> key + max(0, T -> rc -> mls));
        T -> mrs = max(T -> rc -> mrs, T -> rc -> sum + T -> key + max(0, T -> lc -> mrs));
        T -> Max = max(T -> lc -> Max, T -> rc -> Max); gmax(T -> Max, T -> key);
        gmax(T -> Max, max(T -> lc -> mrs, T -> rc -> mls) + T -> key);
        gmax(T -> Max, T -> lc -> mrs + T -> key + T -> rc -> mls); return;
    } //
    void Zig(Node *T)
    {
        Node *P = T -> F, *tmp = T -> rc; pushdn(P -> rc); pushdn(T -> lc); pushdn(T -> rc);
        //旋转之前标记先向下传。
        if (P == this -> T) this -> T = T; //
        else (P -> F -> lc == P) ? (P -> F -> lc = T) : (P -> F -> rc = T);
        T -> F = P -> F; P -> F = T; P -> lc = tmp; T -> rc = tmp -> F = P;
        update(P); return;
    }
    void Zag(Node *T)
    {
        Node *P = T -> F, *tmp = T -> lc; pushdn(P -> lc); pushdn(T -> lc); pushdn(T -> rc);
        //旋转之前标记先向下传。
        if (P == this -> T) this -> T = T; //
        else (P -> F -> lc == P) ? (P -> F -> lc = T) : (P -> F -> rc = T);
        T -> F = P -> F; P -> F = T; P -> rc = tmp; T -> lc = tmp -> F = P;
        update(P); return;
    }
    void Splay(Node *&T, Node *t)
    {
        for (pushdn(t); T != t;)
        {
            Node *P = t -> F;
            if (P == T) (P -> lc == t) ? Zig(t) : Zag(t);
            else
            {
                if (P -> F -> lc == P) (P -> lc == t) ? Zig(P) : Zag(t), Zig(t);
                else (P -> lc == t) ? Zig(t) : Zag(P), Zag(t);
            }
        }
        update(t); return; //这里最后才对t进行更新。
    }
    void K_th(Node *&T, int k)
    {
        for (Node *t = T; pushdn(t), t;) //选择之前先标记向下传。
        {
            if (k == t -> lc -> sz + 1) {Splay(T, t); return;}
            if (k <= t -> lc -> sz) t = t -> lc;
            else k -= t -> lc -> sz + 1, t = t -> rc;
        }
        return;
    }
public:
    SplayTree()
    {
        tot = NIL; NIL -> key = NIL -> Max = NIL -> mls = NIL -> mrs = ~INF;
        NIL -> sum = NIL -> sz = 0; NIL -> lc = NIL -> rc = NIL -> F = NIL;
        //虚拟一个空结点出来。
        head = new (++tot) Node(~INF); tail = new (++tot) Node(~INF);
        head -> lc = head -> rc = head -> F
            = tail -> lc = tail -> rc = tail -> F = NIL;
        head -> sum = tail -> sum = 0;
        head -> rc = tail; tail -> F = head; ++head -> sz; T = head;
    }
    void Ins(int pos, int *fir, int *la)
    {
        K_th(T, pos + 1); K_th(T -> rc, 1);
        Node *t = NewNode(*fir++), *p = t, *q = t;
        while (fir != la) t = NewNode(*fir++), t -> F = p, p = p -> rc = t;
        T -> rc -> lc = q; q -> F = T -> rc; Splay(T, p); return;
    }
    void Del(int pos, int tot)
    {
        K_th(T, pos); K_th(T -> rc, tot + 1);
        T -> rc -> lc = NIL; Splay(T, T -> rc); return;
    }
    void Rev(int pos, int tot)
    {
        K_th(T, pos); K_th(T -> rc, tot + 1);
        T -> rc -> lc -> rev ^= 1; Splay(T, T -> rc -> lc); return;
    }
    void mksm(int pos, int tot, int key)
    {
        K_th(T, pos); K_th(T -> rc, tot + 1); T -> rc -> lc -> key = key;
        T -> rc -> lc -> sm = 1; Splay(T, T -> rc -> lc); return;
    }
    int gsum(int pos, int tot)
    {K_th(T, pos); K_th(T -> rc, tot + 1); return T -> rc -> lc -> sum;}
    int max_sum() {return T -> Max;}
} Tr; char str[20]; int a[maxN], n, m, pos, tot, c;

int main()
{
    freopen("sequence.in", "r", stdin);
    freopen("sequence.out", "w", stdout);
    scanf("%d%d", &n, &m);
    for (int i = 0; i < n; ++i) scanf("%d", a + i); Tr.Ins(0, a, a + n);
    while (m--)
    {
        scanf("%s", str);
        if (!strcmp(str, "MAX-SUM")) {printf("%d\n", Tr.max_sum()); continue;}
        scanf("%d%d", &pos, &tot);
        if (!strcmp(str, "INSERT"))
        {
            for (int i = 0; i < tot; ++i) scanf("%d", a + i);
            Tr.Ins(pos, a, a + tot);
        }
        else if (!strcmp(str, "DELETE")) Tr.Del(pos, tot);
        else if (!strcmp(str, "MAKE-SAME")) scanf("%d", &c), Tr.mksm(pos, tot, c);
        else if (!strcmp(str, "REVERSE")) Tr.Rev(pos, tot);
        else if (!strcmp(str, "GET-SUM")) printf("%d\n", Tr.gsum(pos, tot));
    }
    return 0;
}

还有一个用伪链表写的不能过样例但能过题的程序……

/******************************\
 * @prob: NOI2005 sequence    *
 * @auth: Wang Junji          *
 * @stat: ???                 *
 * @date: June. 4th, 2012     *
 * @memo: 伸展树代替块状链表    *
\******************************/
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cstring>
#include <string>

using std::max; const int maxN = 4000010, INF = 0x07070707;

template <typename _Tp> inline void gmax(_Tp &a, _Tp b) {if (b > a) a = b;}

class SplayTree
{
private:
    struct Node
    {
        int sum, mls, mrs, Max, num; bool sm, rev; Node() {}
        Node(int num): num(num), sm(0), rev(0) {sum = mls = mrs = Max = num;}
    } key[maxN]; int lc[maxN], rc[maxN], F[maxN], sz[maxN], T, tot;
    int NewNode(int num)
    {key[++tot] = Node(num); sz[tot] = 1; lc[tot] = rc[tot] = F[tot] = 0; return tot;}
    void update(int T)
    {
        sz[T] = sz[lc[T]] + sz[rc[T]] + 1;
        key[T].sum = key[lc[T]].sum + key[rc[T]].sum + key[T].num;
        key[T].mls = max(key[lc[T]].mls, key[lc[T]].sum + key[T].num + max(0, key[rc[T]].mls));
        key[T].mrs = max(key[rc[T]].mrs, key[rc[T]].sum + key[T].num + max(0, key[lc[T]].mrs));
        key[T].Max = max(key[lc[T]].Max, key[rc[T]].Max); gmax(key[T].Max, key[T].num);
        gmax(key[T].Max, max(key[lc[T]].mrs, key[rc[T]].mls) + key[T].num);
        gmax(key[T].Max, key[lc[T]].mrs + key[T].num + key[rc[T]].mls); return;
    }
    void pushdn(int T)
    {
        if (key[T].rev)
        {
            std::swap(lc[T], rc[T]); std::swap(key[T].mls, key[T].mrs);
            key[lc[T]].rev ^= 1, key[rc[T]].rev ^= 1; key[T].rev = 0;
        }
        if (key[T].sm)
        {
            key[T].sm = 0; key[lc[T]].sm = key[rc[T]].sm = 1;
            key[lc[T]].num = key[rc[T]].num = key[T].num;
//            sz[T] = sz[lc[T]] + sz[rc[T]] + 1;
            key[T].mls = key[T].mrs = key[T].Max =
                max(key[T].num, key[T].sum = key[T].num * sz[T]); //
        }
        return;
    }
    void Zig(int T)
    {
        int P = F[T], tmp = rc[T]; pushdn(rc[P]); pushdn(lc[T]); pushdn(rc[T]);
        if (P == this -> T) this -> T = T; else (lc[F[P]] == P) ? (lc[F[P]] = T) : (rc[F[P]] = T);
        F[T] = F[P]; F[P] = T; lc[P] = tmp; update(F[tmp] = rc[T] = P); return;
    }
    void Zag(int T)
    {
        int P = F[T], tmp = lc[T]; pushdn(lc[P]); pushdn(lc[T]); pushdn(rc[T]);
        if (P == this -> T) this -> T = T; else (lc[F[P]] == P) ? (lc[F[P]] = T) : (rc[F[P]] = T);
        F[T] = F[P]; F[P] = T; rc[P] = tmp; update(F[tmp] = lc[T] = P); return;
    }
    void Splay(int T, int t)
    {
        pushdn(t); int tmp = F[T];
        while (F[t] - tmp)
        {
            int P = F[t];
            if (F[P] == tmp) (lc[P] == t) ? Zig(t) : Zag(t);
            else
            {
                if (lc[F[P]] == P) (lc[P] == t) ? Zig(P) : Zag(t), Zig(t);
                else (lc[P] == t) ? Zig(t) : Zag(P), Zag(t);
            }
        }
        update(t); return; //
    }
    void K_th(int T, int k)
    {
        for (int t = T; pushdn(t), t;)
        {
            if (k == sz[lc[t]] + 1) {Splay(T, t); return;}
            if (k <= sz[lc[t]]) t = lc[t]; else k -= sz[lc[t]] + 1, t = rc[t];
        }
        return;
    }
    void Build(int &T, int *fir, int *la)
    {
        int *Mid = fir + ((la - fir) >> 1); T = NewNode(*Mid);
        if (fir - Mid < 0) Build(lc[T], fir, Mid), F[lc[T]] = T;
        if (Mid + 1 - la < 0) Build(rc[T], Mid + 1, la), F[rc[T]] = T;
        update(T); return;
    }
public:
    SplayTree(): T(1), tot(2)
    {
        key[0].num = key[0].mls = key[0].mrs = key[0].Max =
        key[1].num = key[1].mls = key[1].mrs = key[1].Max =
        key[2].num = key[2].mls = key[2].mrs = key[2].Max = ~INF;
        rc[1] = 2; F[2] = 1; sz[1] = 2, sz[2] = 1;
    }
    void Build(int *fir, int *la)
    {
        if (fir - la >= 0) return; Build(lc[rc[T]], fir, la);
        F[lc[rc[T]]] = rc[T]; Splay(T, lc[rc[T]]); return;
    }
    void Ins(int pos, int *fir, int *la)
    {
        if (fir - la >= 0) return;
        K_th(T, pos + 1); K_th(rc[T], 1); Build(lc[rc[T]], fir, la);
        F[lc[rc[T]]] = rc[T]; Splay(T, lc[rc[T]]); return;
    }
    void Del(int pos, int tot)
    {
        K_th(T, pos); K_th(rc[T], tot + 1); int t = rc[T];
        lc[t] = 0; Splay(T, t); return;
    }
    void Rev(int pos, int tot)
    {
        K_th(T, pos); K_th(rc[T], tot + 1); int t = lc[rc[T]];
        key[t].rev ^= 1; Splay(T, t); return;
    }
    void mksm(int pos, int tot, int num)
    {
        K_th(T, pos); K_th(rc[T], tot + 1); int t = lc[rc[T]];
        key[t].num = num; key[t].sm = 1; Splay(T, t); return;
    }
    int gsum(int pos, int tot)
    {K_th(T, pos); K_th(rc[T], tot + 1); return key[lc[rc[T]]].sum;}
    int max_sum() {return key[T].Max;}
} Tr; char str[20]; int a[maxN], n, m, pos, tot, c;

int main()
{
    freopen("sequence.in", "r", stdin);
    freopen("sequence.out", "w", stdout);
    scanf("%d%d", &n, &m);
    for (int i = 0; i < n; ++i) scanf("%d", a + i);
    Tr.Build(a, a + n);
    while (m--)
    {
        scanf("%s", str);
        if (!strcmp(str, "MAX-SUM")) {printf("%d\n", Tr.max_sum()); continue;}
        scanf("%d%d", &pos, &tot);
        if (!strcmp(str, "INSERT"))
        {
            for (int i = 0; i < tot; ++i) scanf("%d", a + i);
            Tr.Ins(pos, a, a + tot);
        }
        else if (!strcmp(str, "DELETE")) Tr.Del(pos, tot);
        else if (!strcmp(str, "MAKE-SAME")) scanf("%d", &c), Tr.mksm(pos, tot, c);
        else if (!strcmp(str, "REVERSE")) Tr.Rev(pos, tot);
        else if (!strcmp(str, "GET-SUM")) printf("%d\n", Tr.gsum(pos, tot));
    }
    return 0;
}

抱歉!评论已关闭.