现在的位置: 首页 > 综合 > 正文

splay 伸展树 代码实现

2013年09月01日 ⁄ 综合 ⁄ 共 10307字 ⁄ 字号 评论关闭

Splay 概念文章: http://blog.csdn.net/naivebaby/article/details/1357734

叉姐 数组实现: https://github.com/ftiasch/mithril/blob/master/2012-10-24/I.cpp#L43

Vani 指针实现: https://github.com/Azure-Vani/acm-icpc/blob/master/spoj/SEQ2.cpp

hdu 1890 写法: http://blog.csdn.net/fp_hzq/article/details/8087431

HH splay写法: http://www.notonlysuccess.com/index.php/splay-tree/

poj 3468 HH写法

View Code 
 /*

http://acm.pku.edu.cn/JudgeOnline/problem?id=3468

 区间更新,区间求和
 */
 #include <cstdio>
 #define keyTree (ch[ ch[root][1] ][0])
 const int maxn = 222222;
 /*
  * Sequence splay tree (array based) specialised for POJ 3468:
  * add a constant to a range, query the sum of a range.
  *
  *  - Index 0 is the null sentinel: sz[0] == 0 and sum[0] == 0.
  *  - build()/init() put one boundary node (value -1) on each end of the
  *    sequence, so after RotateTo(l - 1, 0) and RotateTo(r + 1, root) the
  *    1-based interval [l, r] is exactly the subtree keyTree.
  *  - Lazy tag: add[x] is pending for the whole subtree of x. sum[x] already
  *    contains it; val[x] and the children's fields do not yet.
  */
 struct SplayTree{
     int sz[maxn];              // subtree size
     int ch[maxn][2];           // left / right child, 0 = none
     int pre[maxn];             // parent, 0 = none
     int root , top1 , top2;    // root; last fresh index; size of recycle stack
     int ss[maxn] , que[maxn];  // recycle stack; BFS queue used by erase()
 
     // Rotate x above its parent y. f == 1 when x is y's left child
     // (right rotation), f == 0 when x is y's right child (left rotation).
     inline void Rotate(int x,int f) {
         int y = pre[x];
         push_down(y);
         push_down(x);
         ch[y][!f] = ch[x][f];
         if (ch[x][f]) pre[ ch[x][f] ] = y;   // never write into the sentinel
         pre[x] = pre[y];
         if(pre[x]) ch[ pre[y] ][ ch[pre[y]][1] == y ] = x;
         ch[x][f] = y;
         pre[y] = x;
         push_up(y);                          // x is refreshed by Splay()
     }
     // Splay x upwards until its parent is goal (goal == 0 makes x the root).
     inline void Splay(int x,int goal) {
         push_down(x);
         while(pre[x] != goal) {
             if(pre[pre[x]] == goal) {
                 Rotate(x , ch[pre[x]][0] == x);
             } else {
                 int y = pre[x] , z = pre[y];
                 int f = (ch[z][0] == y);
                 if(ch[y][f] == x) {
                     Rotate(x , !f) , Rotate(x , f);   // zig-zag
                 } else {
                     Rotate(y , f) , Rotate(x , f);    // zig-zig
                 }
             }
         }
         push_up(x);
         if(goal == 0) root = x;
     }
     // Splay the node with 0-based in-order index k (the left boundary node
     // has index 0) until its parent is goal.
     inline void RotateTo(int k,int goal) {
         int x = root;
         push_down(x);
         while(sz[ ch[x][0] ] != k) {
             if(k < sz[ ch[x][0] ]) {
                 x = ch[x][0];
             } else {
                 k -= (sz[ ch[x][0] ] + 1);
                 x = ch[x][1];
             }
             push_down(x);
         }
         Splay(x,goal);
     }
     // Detach the subtree rooted at x and push all of its nodes onto the
     // recycle stack so NewNode() can reuse them. All ancestors of x are
     // recomputed so sz/sum stay exact.
     inline void erase(int x) {
         if (!x) return;
         int father = pre[x];
         int head = 0 , tail = 0;
         for (que[tail++] = x ; head < tail ; head ++) {
             ss[top2 ++] = que[head];
             if(ch[ que[head] ][0]) que[tail++] = ch[ que[head] ][0];
             if(ch[ que[head] ][1]) que[tail++] = ch[ que[head] ][1];
         }
         if (father == 0) {          // x was the root: the tree is now empty
             root = 0;
             return;
         }
         ch[ father ][ ch[father][1] == x ] = 0;
         for (int y = father ; y ; y = pre[y]) push_up(y);
     }
     //The functions above rarely need changes//////////////////////////////////
     void debug() {printf("%d\n",root);Treaval(root);}
     // In-order dump of every node's links, size and value.
     void Treaval(int x) {
         if(x) {
             Treaval(ch[x][0]);
             printf("结点%2d:左儿子 %2d 右儿子 %2d 父结点 %2d size = %2d ,val = %2lld\n",x,ch[x][0],ch[x][1],pre[x],sz[x],val[x]);
             Treaval(ch[x][1]);
         }
     }
     //Debug helpers above
 
 
     //Problem-specific functions below:
     // Allocate a node (recycled if possible) holding value c.
     inline void NewNode(int &x,int c) {
         if (top2) x = ss[--top2];   // reuse a node freed by erase()
         else x = ++top1;
         ch[x][0] = ch[x][1] = pre[x] = 0;
         sz[x] = 1;
 
         val[x] = sum[x] = c;
         add[x] = 0;
     }
 
     // Apply x's pending add to x itself and hand it to the real children.
     inline void push_down(int x) {
         if(add[x]) {
             val[x] += add[x];
             for (int d = 0 ; d < 2 ; d ++) {
                 int c = ch[x][d];
                 if (!c) continue;       // keep the sentinel untouched
                 add[c] += add[x];
                 sum[c] += add[x] * sz[c];
             }
             add[x] = 0;
         }
     }
     // Recompute sz[x] and sum[x] from the children. Correct even while
     // add[x] is pending: the pending amount is counted for all sz[x] nodes.
     inline void push_up(int x) {
         sz[x] = 1 + sz[ ch[x][0] ] + sz[ ch[x][1] ];
         sum[x] = val[x] + sum[ ch[x][0] ] + sum[ ch[x][1] ] + add[x] * sz[x];
     }
 
     // Build a balanced subtree from num[l..r] into x, with parent f.
     inline void makeTree(int &x,int l,int r,int f) {
         if(l > r) return ;
         int m = (l + r)>>1;
         NewNode(x , num[m]);
         makeTree(ch[x][0] , l , m - 1 , x);
         makeTree(ch[x][1] , m + 1 , r , x);
         pre[x] = f;
         push_up(x);
     }
     // Reset everything (including the recycle stack) and build the tree
     // from num[0 .. n-1] wrapped by the two boundary nodes.
     inline void build(int n) {
         ch[0][0] = ch[0][1] = pre[0] = sz[0] = 0;
         add[0] = sum[0] = val[0] = 0;
 
         root = top1 = top2 = 0;
         // Boundary nodes on both ends make interval extraction uniform.
         NewNode(root , -1);
         NewNode(ch[root][1] , -1);
         pre[ ch[root][1] ] = root;
 
         makeTree(keyTree , 0 , n-1 , ch[root][1]);
         push_up(ch[root][1]);
         push_up(root);
     }
     // Read n integers from stdin and build the tree from them.
     inline void init(int n) {
         for (int i = 0 ; i < n ; i ++) scanf("%d",&num[i]);
         build(n);
     }
     // Add c to every element at 1-based positions l..r.
     inline void add_range(int l, int r, long long c) {
         RotateTo(l-1,0);
         RotateTo(r+1,root);
         add[ keyTree ] += c;
         sum[ keyTree ] += c * sz[ keyTree ];
         // Refresh the two ancestors so every stored sum stays exact.
         push_up(ch[root][1]);
         push_up(root);
     }
     // Sum of the elements at 1-based positions l..r.
     inline long long sum_range(int l, int r) {
         RotateTo(l-1 , 0);
         RotateTo(r+1 , root);
         return sum[keyTree];
     }
     // Read "l r c" from stdin and add c to [l, r].
     inline void update( ) {
         int l , r , c;
         scanf("%d%d%d",&l,&r,&c);
         add_range(l , r , c);
     }
     // Read "l r" from stdin and print the sum of [l, r].
     inline void query() {
         int l , r;
         scanf("%d%d",&l,&r);
         printf("%lld\n",sum_range(l , r));
     }
 
 
     /*Problem-specific storage*/
     int num[maxn];          // input values consumed by build()
     long long val[maxn];    // node's own value (without its pending add)
     long long add[maxn];    // pending lazy add for the subtree
     long long sum[maxn];    // subtree sum, including this node's pending add
 }spt;
  
  
 // Driver: reads n and m, then the n initial values, then m commands.
 // "Q l r" prints the sum of [l, r]; any other command reads "l r c" and
 // adds c to [l, r].
 int main() {
     int n , m;
     if (scanf("%d%d",&n,&m) != 2) return 0;
     spt.init(n);
     while(m --) {
         char op[16];
         // Bounded read: an unexpectedly long token can no longer overflow op.
         if (scanf("%15s",op) != 1) break;
         if(op[0] == 'Q') {
             spt.query();
         } else {
             spt.update();
         }
     }
     return 0;
 }

叉姐

View Code 
 #include <cstdio>
 #include <cstring>
 #include <vector>
 #include <climits>
 #include <algorithm>
 using namespace std;
 
 // Capacities. dfs() creates two splay nodes per tree edge, hence M = 2N + 1.
 const int N = 200000;
 const int M = 2 * N + 1;
 // Index of the shared "no node" sentinel (last slot of every node array).
 const int EMPTY = M - 1;
 
 // Modulus applied to weights in modify().
 const int MOD = 99990001;
 
 // Splay-node storage, indexed by node id.
 int nodeCount;        // nodes allocated so far
 int type[M];          // 0 / 1: left / right child of parent, 2: tree root
 int parent[M];
 int children[M][2];
 int id[M];            // label aggregated into minimum[]
 
 int scale[M];         // pending multiplicative tag
 int delta[M];         // pending additive tag
 int weight[M];        // node value
 int size[M];          // subtree size
 int minimum[M];       // smallest id[] in the subtree
 
 void update(int x) {
     size[x] = size[children[x][0]] + 1 + size[children[x][1]];
     minimum[x] = min(min(minimum[children[x][0]], minimum[children[x][1]]), id[x]);
 }
 
 // Apply the affine map w -> k * w + b (mod MOD) to node x: its own weight is
 // updated and the map is composed into its pending (scale, delta) tag.
 void modify(int x, int k, int b) {
     const long long kk = k;
     delta[x] = (kk * delta[x] + b) % MOD;
     scale[x] = kk * scale[x] % MOD;
     weight[x] = (kk * weight[x] + b) % MOD;
 }
 
 // Hand x's pending affine tag to each existing child, then reset the tag
 // to the identity (scale 1, delta 0).
 void pushDown(int x) {
     const int left = children[x][0];
     const int right = children[x][1];
     if (left != EMPTY) modify(left, scale[x], delta[x]);
     if (right != EMPTY) modify(right, scale[x], delta[x]);
     scale[x] = 1;
     delta[x] = 0;
 }
 
 // Rotate x above its parent y, keeping the in-order sequence unchanged.
 // type[v] is v's slot in its parent (0 = left, 1 = right) or 2 when v is
 // the root of its tree. Only y (which moves down) is recomputed here;
 // splay() recomputes the splayed node once at the end. No lazy tags are
 // pushed here: splay() pushes them along the whole path beforehand.
 void rotate(int x) {
     int t = type[x];
     int y = parent[x];
     int z = children[x][1 ^ t];      // x's inner child; it moves under y
     type[x] = type[y];
     parent[x] = parent[y];
     if (type[x] != 2) {
         children[parent[x]][type[x]] = x;   // x takes y's slot in the grandparent
     }
     type[y] = 1 ^ t;
     parent[y] = x;
     children[x][1 ^ t] = y;
     if (z != EMPTY) {
         type[z] = t;
         parent[z] = y;
     }
     children[y][t] = z;
     update(y);
 }
 
 void splay(int x) {
     if (x == EMPTY) {
         return;
     }
     vector <int> stack(1, x);
     for (int i = x; type[i] != 2; i = parent[i]) {
         stack.push_back(parent[i]);
     }
     while (!stack.empty()) {
         pushDown(stack.back());
         stack.pop_back();
     }
     while (type[x] != 2) {
         int y = parent[x];
         if (type[x] == type[y]) {
             rotate(y);
         } else {
             rotate(x);
         }
         if (type[x] == 2) {
             break;
         }
         rotate(x);
     }
     update(x);
 }
 
 // Leftmost node of the subtree rooted at x (lazy tags are not pushed).
 int goLeft(int x) {
     int cur = x;
     for (; children[cur][0] != EMPTY; cur = children[cur][0]) {
     }
     return cur;
 }
 
 // Concatenate the sequences stored in x's tree and y's tree (x first) and
 // return the new root. Either side may be EMPTY. Presumably x and y belong
 // to two different trees — verify at call sites.
 int join(int x, int y) {
     if (x == EMPTY) return y;
     if (y == EMPTY) return x;
     const int head = goLeft(y);   // first element of y's sequence
     splay(head);
     splay(x);
     children[head][0] = x;        // head has no left child: it is leftmost
     parent[x] = head;
     type[x] = 0;
     update(head);
     return head;
 }
 
 // Splay x to the root and cut off both of its subtrees. Returns
 // (left part, right part) as independent roots, EMPTY where absent;
 // x keeps no children afterwards.
 pair <int, int> split(int x) {
     splay(x);
     int part[2];
     for (int side = 0; side < 2; ++side) {
         part[side] = children[x][side];
         children[x][side] = EMPTY;
         if (part[side] != EMPTY) {
             type[part[side]] = 2;
             parent[part[side]] = EMPTY;
         }
     }
     return make_pair(part[0], part[1]);
 }
 
 // Allocate the next node index as a single-node tree with weight `init`,
 // label `vid` and an identity tag; returns its index.
 int newNode(int init, int vid) {
     const int x = nodeCount;
     ++nodeCount;
     type[x] = 2;
     parent[x] = EMPTY;
     children[x][0] = EMPTY;
     children[x][1] = EMPTY;
     id[x] = vid;
     weight[x] = init;
     scale[x] = 1;
     delta[x] = 0;
     update(x);
     return x;
 }
 
 // Input tree as an adjacency list. addEdge() is called in pairs, so edge
 // e and e ^ 1 are the two directions of input edge e >> 1.
 int n;
 int edgeCount;
 int firstEdge[N];
 int to[M];
 int nextEdge[M];
 int initWeight[N];   // weight of input edge i
 int position[M];     // splay node created by dfs() for directed edge e
 
 // Root of the splay tree that dfs() fills with the edge sequence.
 int root;
 
 // Prepend the directed edge u -> v to u's adjacency list.
 void addEdge(int u, int v) {
     const int e = edgeCount++;
     to[e] = v;
     nextEdge[e] = firstEdge[u];
     firstEdge[u] = e;
 }
 
 // Walk the tree from u (parent p). Every tree edge appends one node on the
 // way down and one on the way back up to the global sequence `root`, and
 // records each node's index in position[] for the matching edge direction.
 void dfs(int p, int u) {
     for (int e = firstEdge[u]; e != -1; e = nextEdge[e]) {
         const int v = to[e];
         if (v == p) continue;
         const int w = initWeight[e >> 1];
         const int label = min(u, v);
         position[e] = nodeCount;
         root = join(root, newNode(w, label));
         dfs(u, v);
         position[e ^ 1] = nodeCount;
         root = join(root, newNode(w, label));
     }
 }
 
 int getRank(int x) { // 1-based
     splay(x);
     return size[children[x][0]] + 1;
 }
 
 // Debug helper: print the subtree in-order as nested "[ left id right ]".
 void print(int root) {
     if (root == EMPTY) {
         return;
     }
     printf("[ ");
     print(children[root][0]);
     printf(" %d ", root);
     print(children[root][1]);
     printf(" ]");
 }
 
 int main() {
     size[EMPTY] = 0;
     minimum[EMPTY] = INT_MAX;
     parent[EMPTY] = 2;
     scanf("%d", &n);
     edgeCount = 0;
     memset(firstEdge, -1, sizeof(firstEdge));
     for (int i = 0; i < n - 1; ++ i) {
         int a, b;
         scanf("%d%d%d", &a, &b, initWeight + i);
         a --;
         b --;
         addEdge(a, b);
         addEdge(b, a);
     }
     nodeCount = 0;
     root = EMPTY;
     dfs(-1, 0);
     for (int i = 0; i < n - 1; ++ i) {
         int id;
         scanf("%d", &id);
         id --;
 
         int a = position[id << 1];
         int b = position[(id << 1) ^ 1];
         if (getRank(a) > getRank(b)) {
             swap(a, b);
         }
         splay(a);
 
         int output = weight[a];
         printf("%d\n", output);
         fflush(stdout);
 
         pair <int, int> ret1 = split(a);
         pair <int, int> ret2 = split(b);
         int x = ret1.first;
         int y = ret2.first;
         int z = ret2.second;
         x = join(z, x);
         splay(x);
         splay(y);
         if (size[x] > size[y]) {
             swap(x, y);
         }
         if (size[x] == size[y] && minimum[x] > minimum[y]) {
             swap(x, y);
         }
         modify(x, output, 0);
         modify(y, 1, output);
     }
     return 0;
 }

spoj SEQ2

Vani

View Code 
 #include <cstdio>
 #include <cctype>
 #include <algorithm>
 #include <cstring>
 
 using namespace std;
 
 namespace Solve {
     const int MAXN = 500010;
     const int inf = 500000000;   // -inf marks "none" (null node, no assign tag)
 
     // The whole input is read into BUF once; pos is the parse cursor.
     char BUF[50000000], *pos = BUF;
     // Parse the next decimal integer (optional leading '-') from BUF.
     // Returns 0 once the terminating NUL (end of input) is reached instead of
     // scanning past the buffer.
     inline int ScanInt(void) {
         int r = 0, d = 0;
         while (*pos && !std::isdigit(static_cast<unsigned char>(*pos)) && *pos != '-') pos++;
         if (!*pos) return 0;
         if (*pos != '-') r = *pos - 48; else d = 1;
         pos++;
         while (std::isdigit(static_cast<unsigned char>(*pos))) r = r * 10 + *pos++ - 48;
         return d ? -r : r;
     }
     // Copy the next token made of upper-case letters and '-' into st
     // (st must hold the longest command name plus the terminator).
     inline void ScanStr(char *st) {
         int l = 0;
         while (*pos && !(std::isupper(static_cast<unsigned char>(*pos)) || *pos == '-')) pos++;
         while (std::isupper(static_cast<unsigned char>(*pos)) || *pos == '-') st[l++] = *pos++;
         st[l] = 0;
     }
 
     // Splay-tree node for a sequence. The shared sentinel `null` is
     // recognised by v == -inf (set in solve()), which turns Rev()/Same()
     // into no-ops on it.
     //   lmax / rmax: best prefix / suffix sum, empty allowed (so >= 0)
     //   m: best non-empty contiguous sum; same: pending assign (-inf = none)
     //   rev: pending reverse flag; sum / size: subtree totals
     struct Node {
         Node *ch[2], *p;
         int v, lmax, rmax, m, same, rev, sum, size;
         // Is this node its parent's right child?
         inline bool dir(void) {return this == p->ch[1];}
         // Attach x as child d (0 = left, 1 = right) and set its parent.
         inline void SetC(Node *x, bool d) {ch[d] = x, x->p = this;}
         // Recompute size, sum and the max-subarray fields from the children.
         inline void Update(void) {
             Node *L = ch[0], *R = ch[1];
             size = L->size + R->size + 1;
             m = std::max(L->m, R->m);
             m = std::max(m, L->rmax + v + R->lmax);
             lmax = std::max(L->lmax, L->sum + v + R->lmax);
             rmax = std::max(R->rmax, R->sum + v + L->rmax);
             sum = L->sum + R->sum + v;
         }
         // Reverse this subtree: swap now, defer the children via rev.
         inline void Rev(void) {
             if (v == -inf) return;
             rev ^= 1;
             std::swap(ch[0], ch[1]);
             std::swap(lmax, rmax);
         }
         // Assign u to every element of this subtree; defer via same.
         inline void Same(int u) {
             if (v == -inf) return;
             same = u;
             sum = u * size;
             if (sum > 0) lmax = rmax = m = sum; else lmax = 0, rmax = 0, m = u;
             v = u;
         }
         // Push pending reverse / assign tags to both children.
         inline void Down(void) {
             if (rev) {
                 ch[0]->Rev(), ch[1]->Rev();
                 rev = 0;
             }
             if (same != -inf) {
                 ch[0]->Same(same), ch[1]->Same(same);
                 same = -inf;
             }
         }
     } Tnull, *null = &Tnull;
 
     class Splay {public:
         Node *root;
         // Rotate x above its parent (tags of both are pushed first).
         inline void rotate(Node *x) {
             Node *p = x->p; bool d = x->dir();
             p->Down(); x->Down();
             p->p->SetC(x, p->dir());
             p->SetC(x->ch[!d], d);
             x->SetC(p, !d);
             p->Update();
         }
         // Splay x until its parent is G (G == null makes x the root).
         inline void splay(Node *x, Node *G) {
             if (G == null) root = x;
             while (x->p != G) {
                 if (x->p->p == G) {rotate(x); break;}
                 else {if (x->dir() == x->p->dir()) rotate(x->p), rotate(x); else rotate(x), rotate(x);}
             }
             x->Update();
         }
         // Find the k-th node (1-based, counting the left sentinel) and splay it to the root.
         inline Node *Select(int k) {
             Node *t = root;
             while (t->Down(), t->ch[0]->size + 1 != k) {
                 if (k > t->ch[0]->size + 1) k -= t->ch[0]->size + 1, t = t->ch[1];
                 else t = t->ch[0];
             }
             splay(t, null);
             return t;
         }
         // Arrange the tree so the returned node's left subtree holds exactly
         // elements l..r (1-based, sentinels excluded).
         inline Node *getInterval(int l, int r) {
             Node *L = Select(l), *R = Select(r + 2);
             splay(L, null); splay(R, L);
             L->Down(); R->Down();
             return R;
         }
         // Insert the tree x right after the pos-th element.
         inline void Insert(int pos, Node *x) {
             Node *now = getInterval(pos + 1, pos);
             now->SetC(x, 0);
             now->Update(); root->Update();
         }
         // Remove elements l..r. NOTE(review): the detached nodes are never freed.
         inline void Delete(int l, int r) {
             Node *now = getInterval(l, r);
             now->ch[0] = null;
             now->Update(); root->Update();
         }
         // Set elements l..r to c.
         inline void Make(int l, int r, int c) {
             Node *now = getInterval(l, r);
             now->ch[0]->Same(c);
             now->Update(); root->Update();
         }
         // Reverse elements l..r.
         inline void Reverse(int l, int r) {
             Node *now = getInterval(l, r);
             now->ch[0]->Rev();
             now->Update(); root->Update();
         }
         // Sum of elements l..r.
         inline int Sum(int l, int r) {
             Node *now = getInterval(l, r);
             root->Down(); now->Down();
             return now->ch[0]->sum;
         }
         // Maximum non-empty contiguous sum inside elements l..r.
         inline int maxSum(int l, int r) {
             Node *now = getInterval(l, r);
             root->Down(); now->Down();
             return now->ch[0]->m;
         }
         // Allocate a single node holding c.
         inline Node* Renew(int c) {
             // Value-initialise: with plain `new Node` the fields rev and v are
             // indeterminate, so Same() read garbage and Down() could apply a
             // random reversal.
             Node *ret = new Node();
             ret->ch[0] = ret->ch[1] = ret->p = null; ret->size = 1;
             ret->Same(c); ret->same = -inf;
             return ret;
         }
         // Build a balanced tree from a[l..r]; returns null for an empty range.
         inline Node* Build(int l, int r, int *a) {
             if (l > r) return null;
             int mid = (l + r) >> 1;
             Node *ret = Renew(a[mid]);
             ret->ch[0] = Build(l, mid - 1, a);
             ret->ch[1] = Build(mid + 1, r, a);
             ret->ch[0]->p = ret->ch[1]->p = ret;
             ret->Update();
             return ret;
         }
         // Debug: print the subtree in order.
         inline void P(Node *t) {
             if (t == null) return;
             t->Down(); t->Update();
             P(t->ch[0]);
             printf("%d ", t->v);
             P(t->ch[1]);
         }
     }T;
 
 
     int a[MAXN]; char ch[10];
 
     // Read all test cases from stdin and answer the commands.
     inline void solve(void) {
         // Keep one byte free so BUF is always NUL-terminated for the scanners.
         std::fread(BUF, 1, sizeof(BUF) - 1, stdin);
         null->same = null->m = null->v = -inf;
         int kase = ScanInt();
         while (kase--) {
             int n = ScanInt(), m = ScanInt();
             for (int i = 1; i <= n; i++) a[i] = ScanInt();
             // a[0] and a[n + 1] become the two sentinel nodes.
             T.root = T.Build(0, n + 1, a);
             for (int i = 1; i <= m; i++) {
                 ScanStr(ch);
                 if (std::strcmp(ch, "INSERT") == 0) {
                     int pos = ScanInt(), t = ScanInt();
                     for (int j = 1; j <= t; j++) a[j] = ScanInt();
                     Node *tmp = T.Build(1, t, a);
                     T.Insert(pos, tmp);
                 }
                 if (std::strcmp(ch, "DELETE") == 0) {
                     int l = ScanInt(), r = ScanInt(); r = l + r - 1;
                     T.Delete(l, r);
                 }
                 if (std::strcmp(ch, "MAKE-SAME") == 0) {
                     int l = ScanInt(), r = ScanInt(), c = ScanInt(); r = l + r - 1;
                     T.Make(l, r, c);
                 }
                 if (std::strcmp(ch, "REVERSE") == 0) {
                     int l = ScanInt(), r = ScanInt(); r = l + r - 1;
                     T.Reverse(l, r);
                 }
                 if (std::strcmp(ch, "GET-SUM") == 0) {
                     int l = ScanInt(), r = ScanInt(); r = l + r - 1;
                     int ret = T.Sum(l, r);
                     printf("%d\n", ret);
                 }
                 if (std::strcmp(ch, "MAX-SUM") == 0) {
                     int ret = T.maxSum(1, T.root->size - 2);
                     printf("%d\n", ret);
                 }
             }
         }
     }
 }
 
 // Entry point: solve the problem reading from stdin.
 int main(void) {
     // Redirecting stdin to the local file "in" is a debugging aid only;
     // doing it unconditionally makes the program ignore the judge's input.
 #ifdef LOCAL
     freopen("in", "r", stdin);
 #endif
     Solve::solve();
     return 0;
 }

【上篇】
【下篇】

抱歉!评论已关闭.