现在的位置: 首页 > 综合 > 正文

poj3237 树链剖分

2018年02月23日 ⁄ 综合 ⁄ 共 3551字 ⁄ 字号 评论关闭

比上一篇那题多了一种取相反数的操作,直接在那题代码上改了,这里线段树上的每个节点同时记录最大值和最小值,遇到相反操作则最大值的相反数变成最小值,最小值的相反数变成最大值,中间利用好lazy和pushdown,然后就差不多了。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#define maxn 1<<29
using namespace std;
struct edge
{
    int from,to,val;
};
struct node
{
    int l,r,mx,mn,ne;
} t[41111];
vector<edge>edges;
vector<int>g[11111];
int n,a,b,c,sz;
int f[11111],top[11111],w[11111],s[11111];
int dep[11111],son[11111];
char str[21];
void init()
{
    sz=0;
    edges.clear();
    for(int i=1; i<=n; i++)g[i].clear();
}
void add(int from,int to,int val)
{
    edges.push_back((edge)
    {
        from,to,val
    });
    g[from].push_back(edges.size()-1);
}
void dfs1(int u,int ff)
{
    f[u]=ff;
    s[u]=1;
    son[u]=0;
    int size=g[u].size();
    for(int i=0; i<size; i++)
    {
        edge e=edges[g[u][i]];
        if(e.to==ff)continue;
        dep[e.to]=dep[u]+1;
        dfs1(e.to,u);
        s[u]+=s[e.to];
        if(!son[u]||s[e.to]>s[son[u]])son[u]=e.to;
    }
}
void dfs2(int u,int ff,int r)
{
    top[u]=ff;
    w[u]=++sz;
    if(son[u])dfs2(son[u],ff,u);
    int size=g[u].size();
    for(int i=0; i<size; i++)
    {
        edge e=edges[g[u][i]];
        if(e.to==r||e.to==son[u])continue;
        dfs2(e.to,e.to,u);
    }
}
void build(int ll,int rr,int rot)
{
    t[rot].l=ll;
    t[rot].r=rr;
    t[rot].mx=-maxn;
    t[rot].mn=maxn;
    t[rot].ne=0;
    if(ll==rr)return;
    int mid=(ll+rr)/2;
    build(ll,mid,rot<<1);
    build(mid+1,rr,rot<<1|1);
}
void pushdown(int rot)
{
    t[rot].ne^=1;
    t[rot<<1].ne^=1;
    t[rot<<1|1].ne^=1;
    int u,v;
    u=t[rot<<1].mx;
    v=t[rot<<1].mn;
    t[rot<<1].mx=-v;
    t[rot<<1].mn=-u;
    u=t[rot<<1|1].mx;
    v=t[rot<<1|1].mn;
    t[rot<<1|1].mx=-v;
    t[rot<<1|1].mn=-u;
}
void update(int x,int vv,int rot)
{
    if(t[rot].l==x&&t[rot].r==x)
    {
        t[rot].mx=vv;
        t[rot].mn=vv;
        return;
    }
    if(t[rot].ne)pushdown(rot);
    int mid=(t[rot].l+t[rot].r)/2;
    if(mid>=x)update(x,vv,rot<<1);
    else if(x>mid)update(x,vv,rot<<1|1);
    t[rot].mx=max(t[rot<<1].mx,t[rot<<1|1].mx);
    t[rot].mn=min(t[rot<<1].mn,t[rot<<1|1].mn);
}
int query(int ll,int rr,int rot)
{
    if(ll>rr)return -maxn;
    if(t[rot].l==ll&&t[rot].r==rr)return t[rot].mx;
    if(t[rot].ne)pushdown(rot);
    int mid=(t[rot].l+t[rot].r)/2;
    if(rr<=mid)return query(ll,rr,rot<<1);
    else if(ll>mid)return query(ll,rr,rot<<1|1);
    else return max(query(ll,mid,rot<<1),query(mid+1,rr,rot<<1|1));
}
int solve(int u,int v)
{
    int ans=-maxn;
    int uu=top[u];
    int vv=top[v];
    while(uu!=vv)
    {
        if(dep[uu]>dep[vv])
        {
            swap(u,v);
            swap(uu,vv);
        }
        ans=max(ans,query(w[vv],w[v],1));
        v=f[vv];
        vv=top[v];
    }
    if(dep[u]>dep[v])swap(u,v);
    ans=max(ans,query(w[son[u]],w[v],1));
    return ans;
}
void update2(int ll,int rr,int rot)
{
    if(ll>rr)return;
    if(t[rot].l==ll&&t[rot].r==rr)
    {
        t[rot].ne^=1;
        int u,v;
        u=t[rot].mx;
        v=t[rot].mn;
        t[rot].mx=-v;
        t[rot].mn=-u;
        return;
    }
    if(t[rot].ne)pushdown(rot);
    int mid=(t[rot].l+t[rot].r)/2;
    if(rr<=mid)update2(ll,rr,rot<<1);
    else if(ll>mid)update2(ll,rr,rot<<1|1);
    else
    {
        update2(ll,mid,rot<<1);
        update2(mid+1,rr,rot<<1|1);
    }
    t[rot].mx=max(t[rot<<1].mx,t[rot<<1|1].mx);
    t[rot].mn=min(t[rot<<1].mn,t[rot<<1|1].mn);
}
void negete(int u,int v)
{
    int uu=top[u];
    int vv=top[v];
    while(uu!=vv)
    {
        if(dep[uu]>dep[vv])
        {
            swap(u,v);
            swap(uu,vv);
        }
        update2(w[vv],w[v],1);
        //ans=max(ans,query(w[vv],w[v],1));
        v=f[vv];
        vv=top[v];
    }
    if(dep[u]>dep[v])swap(u,v);
    update2(w[son[u]],w[v],1);
    //ans=max(ans,query(w[son[u]],w[v],1));
}
int main()
{
    int T;
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d",&n);
        init();
        for(int i=1; i<n; i++)
        {
            scanf("%d%d%d",&a,&b,&c);
            add(a,b,c);
            add(b,a,c);
        }
        dep[1]=0;
        dfs1(1,1);
        dfs2(1,1,1);
        build(1,sz,1);
        int size=edges.size();
        for(int i=0; i<size; i+=2)
        {
            edge e=edges[i];
            if(dep[e.from]>dep[e.to])e=edges[i^1];
            update(w[e.to],e.val,1);
        }
        while(scanf("%s",str))
        {
            if(str[0]=='D')break;
            else if(str[0]=='C')
            {
                scanf("%d%d",&a,&c);
                int en=(a-1)*2;
                edge e=edges[en];
                if(dep[e.from]>dep[e.to])e=edges[en^1];
                update(w[e.to],c,1);
            }
            else if(str[0]=='N')
            {
                scanf("%d%d",&a,&b);
                negete(a,b);
            }
            else
            {
                scanf("%d%d",&a,&b);
                printf("%d\n",solve(a,b));
            }
        }
    }
    return 0;
}
/*
1
3
1 2 1
1 3 2
QUERY 1 2
QUERY 1 3
QUERY 2 3
NEGATE 2 3
QUERY 2 3
DONE
*/

抱歉!评论已关闭.