Splay 概念文章: http://blog.csdn.net/naivebaby/article/details/1357734
叉姐 数组实现: https://github.com/ftiasch/mithril/blob/master/2012-10-24/I.cpp#L43
Vani 指针实现: https://github.com/Azure-Vani/acm-icpc/blob/master/spoj/SEQ2.cpp
HDU 1890 写法: http://blog.csdn.net/fp_hzq/article/details/8087431
HH splay写法: http://www.notonlysuccess.com/index.php/splay-tree/
// View Code
/*
 * POJ 3468 - http://acm.pku.edu.cn/JudgeOnline/problem?id=3468
 * Range update (add c to every element of [l, r]) and range sum query,
 * implemented with a splay tree keyed by position plus a lazy "add" tag.
 */
#include <cstdio>

// Root of the subtree that holds exactly [l, r] after RotateTo(l-1, 0) and
// RotateTo(r+1, root): the left child of the root's right child.
#define keyTree (ch[ ch[root][1] ][0])

const int maxn = 222222;

struct SplayTree {
    int sz[maxn];             // subtree size; node 0 is the shared "null" node with sz[0] == 0
    int ch[maxn][2];          // children; 0 means "no child"
    int pre[maxn];            // parent; 0 means "no parent"
    int root, top1, top2;     // root; last fresh index handed out; number of recycled indices in ss
    int ss[maxn], que[maxn];  // ss: stack of recycled node indices; que: BFS queue used by erase

    // Rotate x above its parent y.
    // f == 1: x is y's left child (right rotation); f == 0: x is y's right child (left rotation).
    inline void Rotate(int x, int f) {
        int y = pre[x];
        push_down(y);
        push_down(x);
        ch[y][!f] = ch[x][f];
        pre[ch[x][f]] = y;   // may write pre[0] when x has no such child
        pre[x] = pre[y];
        if (pre[x]) ch[pre[y]][ch[pre[y]][1] == y] = x;
        ch[x][f] = y;
        pre[y] = x;
        push_up(y);          // x itself is refreshed once, at the end of Splay
    }

    // Splay x upward until its parent is `goal`; goal == 0 makes x the root.
    inline void Splay(int x, int goal) {
        push_down(x);
        while (pre[x] != goal) {
            if (pre[pre[x]] == goal) {
                Rotate(x, ch[pre[x]][0] == x);          // single rotation
            } else {
                int y = pre[x], z = pre[y];
                int f = (ch[z][0] == y);
                if (ch[y][f] == x) {
                    Rotate(x, !f), Rotate(x, f);         // zig-zag
                } else {
                    Rotate(y, f), Rotate(x, f);          // zig-zig
                }
            }
        }
        push_up(x);
        if (goal == 0) root = x;
    }

    // Splay the node with 0-based rank k (the left sentinel has rank 0)
    // until it becomes a child of `goal`.
    inline void RotateTo(int k, int goal) {
        int x = root;
        push_down(x);
        while (sz[ch[x][0]] != k) {
            if (k < sz[ch[x][0]]) {
                x = ch[x][0];
            } else {
                k -= (sz[ch[x][0]] + 1);
                x = ch[x][1];
            }
            push_down(x);
        }
        Splay(x, goal);
    }

    // Detach the subtree rooted at x and push all of its node indices onto
    // the recycle stack ss so NewNode can reuse them.
    inline void erase(int x) {
        int father = pre[x];
        int head = 0, tail = 0;
        for (que[tail++] = x; head < tail; head++) {
            ss[top2++] = que[head];
            if (ch[que[head]][0]) que[tail++] = ch[que[head]][0];
            if (ch[que[head]][1]) que[tail++] = ch[que[head]][1];
        }
        if (father) {
            ch[father][ch[father][1] == x] = 0;
            push_up(father);   // the original called a non-existent `pushup`
        } else {
            root = 0;          // x was the whole tree; never push_up the null node 0
        }
    }
    // The functions above normally need no changes between problems.

    // ---- Debug helpers ----
    void debug() { printf("%d\n", root); Treaval(root); }
    void Treaval(int x) {
        if (x) {
            Treaval(ch[x][0]);
            printf("结点%2d:左儿子 %2d 右儿子 %2d 父结点 %2d size = %2d ,val = %2lld\n",
                   x, ch[x][0], ch[x][1], pre[x], sz[x], val[x]);
            Treaval(ch[x][1]);
        }
    }

    // ---- Problem-specific functions ----

    // Allocate a node holding value c, reusing a recycled index when available.
    inline void NewNode(int &x, int c) {
        if (top2) x = ss[--top2];
        else x = ++top1;
        ch[x][0] = ch[x][1] = pre[x] = 0;
        sz[x] = 1;
        val[x] = sum[x] = c;
        add[x] = 0;
    }

    // Push x's pending add into x's own value and into its children's tags/sums.
    // Invariant: sum[x] already includes add[x]; val[x] does not.
    inline void push_down(int x) {
        if (add[x]) {
            val[x] += add[x];
            for (int d = 0; d < 2; ++d) {
                int c = ch[x][d];
                if (c) {   // never tag the shared null node 0
                    add[c] += add[x];
                    sum[c] += static_cast<long long>(sz[c]) * add[x];
                }
            }
            add[x] = 0;
        }
    }

    // Recompute x's size and sum from its children.
    // A still-pending add[x] applies to all sz[x] nodes of the subtree.
    inline void push_up(int x) {
        sz[x] = 1 + sz[ch[x][0]] + sz[ch[x][1]];
        sum[x] = val[x] + sum[ch[x][0]] + sum[ch[x][1]] + add[x] * sz[x];
    }

    // Build a perfectly balanced subtree from num[l..r] under parent f.
    inline void makeTree(int &x, int l, int r, int f) {
        if (l > r) return;
        int m = (l + r) >> 1;
        NewNode(x, num[m]);   // replace num[m] with whatever weight the problem needs
        makeTree(ch[x][0], l, m - 1, x);
        makeTree(ch[x][1], m + 1, r, x);
        pre[x] = f;
        push_up(x);
    }

    // Reset the tree and read n initial values from stdin.
    inline void init(int n) {
        ch[0][0] = ch[0][1] = pre[0] = sz[0] = 0;
        add[0] = sum[0] = 0;
        root = top1 = top2 = 0;   // also reset the recycle stack
        // Two sentinel nodes so every interval [l, r] has a predecessor and a successor.
        NewNode(root, -1);
        NewNode(ch[root][1], -1);
        pre[ch[root][1]] = root;
        sz[root] = 2;
        for (int i = 0; i < n; i++) scanf("%d", &num[i]);
        makeTree(keyTree, 0, n - 1, ch[root][1]);
        push_up(ch[root][1]);
        push_up(root);
    }

    // Read "l r c" and add c to every element of [l, r] (1-based).
    inline void update() {
        int l, r, c;
        scanf("%d%d%d", &l, &r, &c);
        RotateTo(l - 1, 0);
        RotateTo(r + 1, root);
        add[keyTree] += c;
        sum[keyTree] += static_cast<long long>(c) * sz[keyTree];
        // Refresh the two ancestors so sum[] is exact everywhere, not only in
        // subtrees that a later rotation happens to rebuild.
        push_up(ch[root][1]);
        push_up(root);
    }

    // Read "l r" and print the sum of [l, r] (1-based).
    inline void query() {
        int l, r;
        scanf("%d%d", &l, &r);
        RotateTo(l - 1, 0);
        RotateTo(r + 1, root);
        printf("%lld\n", sum[keyTree]);
    }

    // ---- Problem-specific data ----
    int num[maxn];
    long long val[maxn];   // long long: initial |A_i| <= 1e9 plus accumulated adds can exceed INT_MAX
    long long add[maxn];
    long long sum[maxn];
} spt;

int main() {
    int n, m;
    if (scanf("%d%d", &n, &m) != 2) return 0;
    spt.init(n);
    while (m--) {
        char op[16];
        if (scanf("%15s", op) != 1) break;   // bounded read: the old char[2] + %s could overflow
        if (op[0] == 'Q') {
            spt.query();
        } else {
            spt.update();
        }
    }
    return 0;
}
// View Code
// Array-based splay trees maintaining an Euler tour of a weighted tree; each
// query cuts an edge and applies affine updates (mod MOD) to the two halves.
#include <cstdio>
#include <cstring>
#include <vector>
#include <climits>
#include <algorithm>
#include <utility>

const int N = 200000;
const int M = 1 + (N << 1);
const int EMPTY = M - 1;   // "no node" sentinel index
const int MOD = 99990001;

// type[x]: 0 = left child, 1 = right child, 2 = root of its splay tree.
// scale/delta: pending affine tag (w -> scale * w + delta) owed to x's children.
// subtreeSize used to be called `size`. That name is ambiguous with std::size
// (C++17) under `using namespace std;`, which no longer appears here.
int nodeCount, type[M], parent[M], children[M][2], id[M];
int scale[M], delta[M], weight[M], subtreeSize[M], minimum[M];

// Recompute x's subtree size and minimum id from its children.
void update(int x) {
    subtreeSize[x] = subtreeSize[children[x][0]] + 1 + subtreeSize[children[x][1]];
    minimum[x] = std::min(std::min(minimum[children[x][0]], minimum[children[x][1]]), id[x]);
}

// Apply w -> k * w + b (mod MOD) to x and compose it into x's pending tag.
void modify(int x, int k, int b) {
    weight[x] = (static_cast<long long>(k) * weight[x] + b) % MOD;
    scale[x] = static_cast<long long>(k) * scale[x] % MOD;
    delta[x] = (static_cast<long long>(k) * delta[x] + b) % MOD;
}

// Hand x's pending tag to its children and reset it to the identity.
void pushDown(int x) {
    for (int i = 0; i < 2; ++i) {
        if (children[x][i] != EMPTY) {
            modify(children[x][i], scale[x], delta[x]);
        }
    }
    scale[x] = 1;
    delta[x] = 0;
}

// Rotate x above its parent.
void rotate(int x) {
    int t = type[x];
    int y = parent[x];
    int z = children[x][1 ^ t];
    type[x] = type[y];
    parent[x] = parent[y];
    if (type[x] != 2) {
        children[parent[x]][type[x]] = x;
    }
    type[y] = 1 ^ t;
    parent[y] = x;
    children[x][1 ^ t] = y;
    if (z != EMPTY) {
        type[z] = t;
        parent[z] = y;
    }
    children[y][t] = z;
    update(y);
}

// Make x the root of its splay tree, first pushing down all tags on the root-to-x path.
void splay(int x) {
    if (x == EMPTY) {
        return;
    }
    static std::vector<int> path;   // reused across calls to avoid one allocation per splay
    path.clear();
    path.push_back(x);
    for (int i = x; type[i] != 2; i = parent[i]) {
        path.push_back(parent[i]);
    }
    while (!path.empty()) {
        pushDown(path.back());
        path.pop_back();
    }
    while (type[x] != 2) {
        int y = parent[x];
        if (type[x] == type[y]) {
            rotate(y);
        } else {
            rotate(x);
        }
        if (type[x] == 2) {
            break;
        }
        rotate(x);
    }
    update(x);
}

// Leftmost node of x's subtree (tags do not affect the shape, so no pushDown needed).
int goLeft(int x) {
    while (children[x][0] != EMPTY) {
        x = children[x][0];
    }
    return x;
}

// Concatenate sequence x followed by sequence y; returns the new root.
int join(int x, int y) {
    if (x == EMPTY || y == EMPTY) {
        return x != EMPTY ? x : y;
    }
    y = goLeft(y);
    splay(y);
    splay(x);
    type[x] = 0;
    parent[x] = y;
    children[y][0] = x;
    update(y);
    return y;
}

// Isolate x; returns the roots of the parts before and after it.
std::pair<int, int> split(int x) {
    splay(x);
    int a = children[x][0];
    int b = children[x][1];
    children[x][0] = children[x][1] = EMPTY;
    if (a != EMPTY) {
        type[a] = 2;
        parent[a] = EMPTY;
    }
    if (b != EMPTY) {
        type[b] = 2;
        parent[b] = EMPTY;
    }
    return std::make_pair(a, b);
}

// Allocate a detached single-node tree with the given weight and vertex id.
int newNode(int init, int vid) {
    int x = nodeCount++;
    type[x] = 2;
    parent[x] = children[x][0] = children[x][1] = EMPTY;
    id[x] = vid;
    weight[x] = init;
    scale[x] = 1;
    delta[x] = 0;
    update(x);
    return x;
}

int n;
int edgeCount, firstEdge[N], to[M], nextEdge[M], initWeight[N], position[M];
int root;

// Add directed edge u -> v; edges 2i and 2i+1 are the two directions of input edge i.
void addEdge(int u, int v) {
    to[edgeCount] = v;
    nextEdge[edgeCount] = firstEdge[u];
    firstEdge[u] = edgeCount++;
}

// One pending call of the former recursive dfs: the vertex, where it was
// entered from, and the next adjacency edge to examine.
struct DfsFrame {
    int parent, vertex, iter;
};

// Append the Euler tour of the tree rooted at u (entered from p) to `root`:
// one node when descending an edge, one when coming back up.
// Iterative so that deep trees (up to N vertices) cannot overflow the call stack;
// the visiting order is identical to the recursive version.
void dfs(int p, int u) {
    std::vector<DfsFrame> frames;
    DfsFrame start = {p, u, firstEdge[u]};
    frames.push_back(start);
    while (!frames.empty()) {
        DfsFrame &top = frames.back();
        if (top.iter == -1) {
            frames.pop_back();
            if (!frames.empty()) {
                // Close the edge through which the finished vertex was entered.
                DfsFrame &up = frames.back();
                int e = up.iter;
                position[e ^ 1] = nodeCount;
                root = join(root, newNode(initWeight[e >> 1], std::min(up.vertex, to[e])));
                up.iter = nextEdge[e];
            }
            continue;
        }
        int e = top.iter;
        int v = to[e];
        if (v == top.parent) {
            top.iter = nextEdge[e];
            continue;
        }
        position[e] = nodeCount;
        root = join(root, newNode(initWeight[e >> 1], std::min(top.vertex, v)));
        DfsFrame child = {top.vertex, v, firstEdge[v]};
        frames.push_back(child);   // invalidates `top`; it is not used afterwards
    }
}

// 1-based position of x in its sequence.
int getRank(int x) {
    splay(x);
    return subtreeSize[children[x][0]] + 1;
}

// Debug: print the shape of the splay tree rooted at x.
void print(int x) {
    if (x != EMPTY) {
        printf("[ ");
        print(children[x][0]);
        printf(" %d ", x);
        print(children[x][1]);
        printf(" ]");
    }
}

int main() {
    subtreeSize[EMPTY] = 0;
    minimum[EMPTY] = INT_MAX;
    // The original wrote `parent[EMPTY] = 2`, presumably meaning the "root" type.
    // The tree code never reads EMPTY's type or parent; this only makes the state explicit.
    type[EMPTY] = 2;
    parent[EMPTY] = EMPTY;
    if (scanf("%d", &n) != 1) {
        return 0;
    }
    edgeCount = 0;
    memset(firstEdge, -1, sizeof(firstEdge));
    for (int i = 0; i < n - 1; ++i) {
        int a, b;
        scanf("%d%d%d", &a, &b, initWeight + i);
        a--;
        b--;
        addEdge(a, b);
        addEdge(b, a);
    }
    nodeCount = 0;
    root = EMPTY;
    dfs(-1, 0);
    for (int i = 0; i < n - 1; ++i) {
        int edgeId;
        scanf("%d", &edgeId);
        edgeId--;
        int a = position[edgeId << 1];
        int b = position[(edgeId << 1) ^ 1];
        if (getRank(a) > getRank(b)) {
            std::swap(a, b);
        }
        splay(a);
        int output = weight[a];
        printf("%d\n", output);
        fflush(stdout);
        std::pair<int, int> ret1 = split(a);
        std::pair<int, int> ret2 = split(b);
        int x = ret1.first;
        int y = ret2.first;
        int z = ret2.second;
        x = join(z, x);
        splay(x);
        splay(y);
        if (subtreeSize[x] > subtreeSize[y]) {
            std::swap(x, y);
        }
        if (subtreeSize[x] == subtreeSize[y] && minimum[x] > minimum[y]) {
            std::swap(x, y);
        }
        modify(x, output, 0);
        modify(y, 1, output);
    }
    return 0;
}
SPOJ SEQ2
Vani
// View Code
// SPOJ SEQ2: sequence maintenance (INSERT / DELETE / MAKE-SAME / REVERSE /
// GET-SUM / MAX-SUM) with a pointer-based splay tree and lazy tags.
#include <cstdio>
#include <cctype>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;

namespace Solve {
const int MAXN = 500010;
const int inf = 500000000;
const int BUF_SIZE = 50000000;

// Whole input is read at once; the extra byte keeps BUF NUL-terminated even when full.
char BUF[BUF_SIZE + 1], *pos = BUF;

// Read the next (optionally negative) decimal integer; returns 0 at end of input.
inline int ScanInt(void) {
    int r = 0, d = 0;
    while (*pos && !isdigit(static_cast<unsigned char>(*pos)) && *pos != '-') pos++;
    if (!*pos) return 0;
    if (*pos != '-') r = *pos - 48; else d = 1;
    pos++;
    while (isdigit(static_cast<unsigned char>(*pos))) r = r * 10 + *pos++ - 48;
    return d ? -r : r;
}

// Read the next command word (upper-case letters and '-') into st.
inline void ScanStr(char *st) {
    int l = 0;
    while (*pos && !(isupper(static_cast<unsigned char>(*pos)) || *pos == '-')) pos++;
    if (!*pos) { st[0] = 0; return; }
    st[l++] = *pos++;
    while (isupper(static_cast<unsigned char>(*pos)) || *pos == '-') st[l++] = *pos++;
    st[l] = 0;
}

struct Node {
    Node *ch[2], *p;
    // v: value; lmax/rmax: best (possibly empty) prefix/suffix sum; m: best non-empty
    // subarray sum; same: pending assign tag (-inf = none); rev: pending reverse tag.
    int v, lmax, rmax, m, same, rev, sum, size;
    inline bool dir(void) { return this == p->ch[1]; }
    inline void SetC(Node *x, bool d) { ch[d] = x, x->p = this; }
    // Recompute aggregates from the children.
    inline void Update(void) {
        Node *L = ch[0], *R = ch[1];
        size = L->size + R->size + 1;
        m = max(L->m, R->m);
        m = max(m, L->rmax + v + R->lmax);
        lmax = max(L->lmax, L->sum + v + R->lmax);
        rmax = max(R->rmax, R->sum + v + L->rmax);
        sum = L->sum + R->sum + v;
    }
    // Reverse this subtree (no-op on the null node, whose v is -inf).
    inline void Rev(void) {
        if (v == -inf) return;
        rev ^= 1;
        swap(ch[0], ch[1]);
        swap(lmax, rmax);
    }
    // Assign u to every element of this subtree (no-op on the null node).
    inline void Same(int u) {
        if (v == -inf) return;
        same = u;
        sum = u * size;
        if (sum > 0) lmax = rmax = m = sum;
        else lmax = 0, rmax = 0, m = u;
        v = u;
    }
    // Push pending tags to the children.
    inline void Down(void) {
        if (rev) {
            ch[0]->Rev(), ch[1]->Rev();
            rev = 0;
        }
        if (same != -inf) {
            ch[0]->Same(same), ch[1]->Same(same);
            same = -inf;
        }
    }
} Tnull, *null = &Tnull;

class Splay {
public:
    Node *root;

    inline void rotate(Node *x) {
        Node *p = x->p;
        bool d = x->dir();
        p->Down();
        x->Down();
        p->p->SetC(x, p->dir());
        p->SetC(x->ch[!d], d);
        x->SetC(p, !d);
        p->Update();
    }

    // Splay x until its parent is G (G == null makes x the root).
    inline void splay(Node *x, Node *G) {
        if (G == null) root = x;
        while (x->p != G) {
            if (x->p->p == G) { rotate(x); break; }
            else {
                if (x->dir() == x->p->dir()) rotate(x->p), rotate(x);
                else rotate(x), rotate(x);
            }
        }
        x->Update();
    }

    // k-th node (1-based, counting the leading sentinel), splayed to the root.
    inline Node *Select(int k) {
        Node *t = root;
        while (t->Down(), t->ch[0]->size + 1 != k) {
            if (k > t->ch[0]->size + 1) k -= t->ch[0]->size + 1, t = t->ch[1];
            else t = t->ch[0];
        }
        splay(t, null);
        return t;
    }

    // Arrange the tree so that elements l..r form R->ch[0]; returns R.
    inline Node *getInterval(int l, int r) {
        Node *L = Select(l), *R = Select(r + 2);
        splay(L, null);
        splay(R, L);
        L->Down();
        R->Down();
        return R;
    }

    inline void Insert(int pos, Node *x) {
        Node *now = getInterval(pos + 1, pos);
        now->SetC(x, 0);
        now->Update();
        root->Update();
    }

    // Free every node of subtree t. Iterative because splay subtrees can be deep.
    inline void Release(Node *t) {
        static vector<Node *> pending;
        if (t == null) return;
        pending.push_back(t);
        while (!pending.empty()) {
            Node *cur = pending.back();
            pending.pop_back();
            if (cur->ch[0] != null) pending.push_back(cur->ch[0]);
            if (cur->ch[1] != null) pending.push_back(cur->ch[1]);
            delete cur;
        }
    }

    inline void Delete(int l, int r) {
        Node *now = getInterval(l, r);
        Node *removed = now->ch[0];
        now->ch[0] = null;
        now->Update();
        root->Update();
        Release(removed);   // previously leaked
    }

    inline void Make(int l, int r, int c) {
        Node *now = getInterval(l, r);
        now->ch[0]->Same(c);
        now->Update();
        root->Update();
    }

    inline void Reverse(int l, int r) {
        Node *now = getInterval(l, r);
        now->ch[0]->Rev();
        now->Update();
        root->Update();
    }

    inline int Sum(int l, int r) {
        Node *now = getInterval(l, r);
        root->Down();
        now->Down();
        return now->ch[0]->sum;
    }

    inline int maxSum(int l, int r) {
        Node *now = getInterval(l, r);
        root->Down();
        now->Down();
        return now->ch[0]->m;
    }

    // New single node holding c. `new Node()` value-initialises every field
    // (notably rev and v); plain `new Node` left them indeterminate.
    inline Node *Renew(int c) {
        Node *ret = new Node();
        ret->ch[0] = ret->ch[1] = ret->p = null;
        ret->size = 1;
        ret->Same(c);
        ret->same = -inf;
        return ret;
    }

    // Balanced tree over a[l..r].
    inline Node *Build(int l, int r, int *a) {
        if (l > r) return null;
        int mid = (l + r) >> 1;
        Node *ret = Renew(a[mid]);
        ret->ch[0] = Build(l, mid - 1, a);
        ret->ch[1] = Build(mid + 1, r, a);
        ret->ch[0]->p = ret->ch[1]->p = ret;
        ret->Update();
        return ret;
    }

    // Debug: in-order dump of values.
    inline void P(Node *t) {
        if (t == null) return;
        t->Down();
        t->Update();
        P(t->ch[0]);
        printf("%d ", t->v);
        P(t->ch[1]);
    }
} T;

int a[MAXN];
char ch[10];

inline void solve(void) {
    size_t len = fread(BUF, 1, BUF_SIZE, stdin);
    BUF[len] = '\0';
    null->same = null->m = null->v = -inf;
    int kase = ScanInt();
    while (kase--) {
        int n = ScanInt(), m = ScanInt();
        for (int i = 1; i <= n; i++) a[i] = ScanInt();
        if (T.root) T.Release(T.root);   // free the previous test case's tree
        // a[0] and a[n + 1] become sentinels outside every queried interval.
        T.root = T.Build(0, n + 1, a);
        for (int i = 1; i <= m; i++) {
            ScanStr(ch);
            if (strcmp(ch, "INSERT") == 0) {
                int pos = ScanInt(), t = ScanInt();
                for (int j = 1; j <= t; j++) a[j] = ScanInt();
                Node *tmp = T.Build(1, t, a);
                T.Insert(pos, tmp);
            } else if (strcmp(ch, "DELETE") == 0) {
                int l = ScanInt(), r = ScanInt();
                r = l + r - 1;
                T.Delete(l, r);
            } else if (strcmp(ch, "MAKE-SAME") == 0) {
                int l = ScanInt(), r = ScanInt(), c = ScanInt();
                r = l + r - 1;
                T.Make(l, r, c);
            } else if (strcmp(ch, "REVERSE") == 0) {
                int l = ScanInt(), r = ScanInt();
                r = l + r - 1;
                T.Reverse(l, r);
            } else if (strcmp(ch, "GET-SUM") == 0) {
                int l = ScanInt(), r = ScanInt();
                r = l + r - 1;
                int ret = T.Sum(l, r);
                printf("%d\n", ret);
            } else if (strcmp(ch, "MAX-SUM") == 0) {
                int ret = T.maxSum(1, T.root->size - 2);
                printf("%d\n", ret);
            }
        }
    }
}
}  // namespace Solve

int main(void) {
#ifdef LOCAL
    // Local testing only: the unconditional freopen redirected stdin on the judge.
    freopen("in", "r", stdin);
#endif
    Solve::solve();
    return 0;
}