题目链接:Click here~~
题意:
给一个序列 {an},有 4 种操作。
1、将一段区间的数全部加 c。
2、将一段区间的数全部乘 c。
3、将一段区间的数全部等于 c。
4、询问一段区间的和(和、平方和、立方和)。
解题思路:
很明显的一道线段树题,好久没刷线段树了,又手生了,真是弱。
自己写的代码太挫了,写完调了一天还是不对,后来看到了 zjut_DD 大牛的博客,代码写的真是赞,仿照写了一发。
扯回到这道题上,先将问题简化,假设不会询问区间的平方和与立方和。
首先每段区间需要做两个标记,mul 和 add,表示对这段区间进行的乘法和加法(操作3可以看做乘0加c)。
对于这种多标记的问题,一定要考虑先后顺序对问题产生的影响。
不难发现,先乘后加对问题没有影响,而先加后乘会使之前的加法改变。
然后只要考虑对于乘法和加法,如何在 O(1) 的时间更新区间的 sum 就行了,很好想的,不写了。
p.s.学到一个新的姿势,就是将 change 函数写在结构体里,并在这个函数里将 lazy 改变,很直观方便。
#include <stdio.h> #include <string.h> #include <algorithm> using namespace std; const int N = 1e5 + 5; const int mod = 10007; #define lson u<<1 #define rson u<<1|1 inline int sqr2(int x){ return x * x % mod; } inline int sqr3(int x){ return sqr2(x) * x % mod; } struct SegTree { int l,r; int sum[4]; int mul,add; inline int mid(){ return l + r >> 1; } inline int len(){ return r - l; } void flag_init(){ add = 0, mul = 1; } void to_mul(int m){ (sum[1] *= m) %= mod; (sum[2] *= sqr2(m)) %= mod; (sum[3] *= sqr3(m)) %= mod; (mul *= m) %= mod; (add *= m) %= mod; } void to_add(int a){ (sum[3] += sqr3(a) * len()) %= mod; (sum[3] += 3 * a * sum[2]) %= mod; (sum[3] += 3 * sqr2(a) * sum[1]) %= mod; (sum[2] += sqr2(a) * len()) %= mod; (sum[2] += 2 * a * sum[1]) %= mod; (sum[1] += a * len()) %= mod; (add += a) %= mod; } }T[N<<2]; void build(int u,int l,int r) { T[u].l = l , T[u].r = r; memset(T[u].sum,0,sizeof(T[u].sum)); T[u].flag_init(); if(l == r-1) return ; int m = T[u].mid(); build(lson,l,m); build(rson,m,r); } int op; void push_down(int u) { T[lson].to_mul(T[u].mul); T[rson].to_mul(T[u].mul); T[lson].to_add(T[u].add); T[rson].to_add(T[u].add); T[u].flag_init(); } void push_up(int u) { for(int i=1;i<=3;i++) T[u].sum[i] = (T[lson].sum[i] + T[rson].sum[i]) % mod; } void updata(int u,int l,int r,int mul,int add) { if(T[u].l == l && T[u].r == r) { T[u].to_mul(mul); T[u].to_add(add); return ; } push_down(u); int m = T[u].mid(); if(r <= m) updata(lson,l,r,mul,add); else if(l >= m) updata(rson,l,r,mul,add); else updata(lson,l,m,mul,add), updata(rson,m,r,mul,add); push_up(u); } int query(int u,int l,int r) { if(T[u].l == l && T[u].r == r) return T[u].sum[op]; push_down(u); int m = T[u].mid(); if(r <= m) return query(lson,l,r); else if(l >= m) return query(rson,l,r); else return (query(lson,l,m)+query(rson,m,r)) % mod; } int main() { //freopen("in.ads","r",stdin); int n,m,a,b,c; while(scanf("%d%d",&n,&m),n||m) { build(1,1,n+1); while(m--) { scanf("%d",&op); if(op == 4) { scanf("%d%d%d",&a,&b,&op); printf("%d\n",query(1,a,b+1)); } else { scanf("%d%d%d",&a,&b,&c); if(op == 1) updata(1,a,b+1,1,c); else if(op == 2) updata(1,a,b+1,c,0); else updata(1,a,b+1,0,c); } } } return 0; }