现在的位置: 首页 > 综合 > 正文

【树套树】【bzoj 2141】: 排队

2017年04月24日 ⁄ 综合 ⁄ 共 11433字 ⁄ 字号 评论关闭

http://www.lydsy.com/JudgeOnline/problem.php?id=2141

明明是个水水哒树套树,可我就是狂T

难道B站会专门卡我的树套树不成。。。。?

找了两份AC代码,一个分块,一个树套树

怎么拍我都比那树套树快!!!

可只有我的会T。。。。。。。。。。。。。。。。。。。。。。。。。。

为什么B站看我树套树那么不顺眼捏。?

m扩大10倍:

my code:(TLE)

//#define _TEST _TEST
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <cmath>
#include <algorithm>
#include <map>
using namespace std;
/************************************************
Code By willinglive    Blog:http://willinglive.cf
************************************************/
#define rep(i,l,r) for(int i=l,___t=(r);i<=___t;i++)
#define per(i,r,l) for(int i=r,___t=(l);i>=___t;i--)
#define MS(arr,x) memset(arr,x,sizeof(arr))
#define LL long long
#define INE(i,u,e) for(int i=head[u];~i;i=e[i].next)
inline const int getint()
{
    int r=0,k=1;char c=getchar();
    for(;c<'0'||c>'9';c=getchar())if(c=='-')k=-1;
    for(;c>='0'&&c<='9';c=getchar())r=r*10+c-'0';
    return k*r;
}
/////////////////////////////////////////////////
int n,m;
int h[20010],cnt;
map<int,int>M;
/////////////////////////////////////////////////
int rnd(){return rand()<<16|rand();}
namespace Treap
{//////////////////////////////
#define LS T[o].l
#define RS T[o].r
int root,sz;
struct data{int l,r,s,rnd,w;}T[10000010];
void update(int o){T[o].s=T[LS].s+T[RS].s+1;}
void l_rot(int &o){int t=RS;RS=T[t].l;T[t].l=o;T[t].s=T[o].s;update(o);o=t;}
void r_rot(int &o){int t=LS;LS=T[t].r;T[t].r=o;T[t].s=T[o].s;update(o);o=t;}
void insert(int &o,int x)
{
	if(!o)
	{
		o=++sz; T[o].s=1; T[o].rnd=rnd(); T[o].w=x;
		return;
	}
	T[o].s++;
	if(x<T[o].w)
	{
		insert(LS,x);
		if(T[LS].rnd<T[o].rnd) r_rot(o);
	}
	else
	{
		insert(RS,x);
		if(T[RS].rnd<T[o].rnd) l_rot(o);
	}
}
void del(int &o,int x)
{
	if(!o) return;
	if(x==T[o].w)
	{
		if(LS==0||RS==0) o=LS+RS;
		else if(T[LS].rnd<T[RS].rnd) r_rot(o),del(o,x);
		else l_rot(o),del(o,x);
		return;
	}
	T[o].s--;
	if(x<T[o].w) del(LS,x);
	else del(RS,x);
}
int getgreater(int o,int x)
{
	if(!o) return 0;
	if(T[o].w==x) return T[RS].s+1;
	if(x<T[o].w) return T[RS].s+1+getgreater(LS,x);
	else return getgreater(RS,x);
}
int getlower(int o,int x)
{
	if(!o) return 0;
	if(T[o].w==x) return T[LS].s+1;
	if(x<T[o].w) return getlower(LS,x);
	else return T[LS].s+1+getlower(RS,x);
}
}//////////////////////////////
namespace BIT
{//////////////////////////////
int root[20010];
void insert(int o,int x)
{
	using namespace Treap;
	for(;o<=n;o+=o&-o)
	   Treap::insert(root[o],x);
}
void del(int o,int x)
{
	using namespace Treap;
	for(;o<=n;o+=o&-o)
	    Treap::del(root[o],x);
}
int getgreater(int o,int x)
{
	using namespace Treap;
	int s=0;
	for(;o;o-=o&-o)
	    s+=Treap::getgreater(root[o],x);
    return s;
}
int getlower(int o,int x)
{
	using namespace Treap;
	int s=0;
	for(;o>0;o-=o&-o)
	    s+=Treap::getlower(root[o],x);
    return s;
}
}//////////////////////////////
/////////////////////////////////////////////////
void input()
{
	srand(1313131);
    n=getint();
    rep(i,1,n) h[i]=getint(),M[h[i]]=i;
    for(map<int,int>::iterator it=M.begin();it!=M.end();it++)
        h[it->second]=++cnt;
    m=getint();
}
void solve()
{
	using namespace BIT;
	int ans=0;
	int l,r;
	rep(i,1,n)
	{
		insert(i,h[i]);
		ans+=getgreater(i-1,h[i]);
	}
	printf("%d\n",ans);
    while(m--)
    {
    	l=getint(); r=getint();
    	//两句特判!!!!!!!!!!!!!!!!!!!!! 
    	if(l>r) swap(l,r);
    	if(l==r){printf("%d\n",ans);continue;}
    	ans+=2*(getgreater(r-1,h[l]) - getgreater(l,h[l]))+l-r+1;
    	ans+=2*(getlower(r-1,h[r]) - getlower(l,h[r]))+l-r+1;
    	if(h[l]>h[r]) ans--;
    	else ans++;
    	del(l,h[l]); insert(r,h[l]);
    	del(r,h[r]); insert(l,h[r]);
    	swap(h[l],h[r]);
    	printf("%d\n",ans);
    }
}
/////////////////////////////////////////////////
int main()
{
    #ifndef _TEST
    freopen("std.in","r",stdin); freopen("std.out","w",stdout);
    #endif
    input(),
    solve();
    return 0;
}

hzwer:比我的快一点

注意一点:

分块的大小应该为sqrt(n*logn)这样时间复杂度会减小O(sqrt(logn))

#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#define ll long long
#define inf 1000000000
using namespace std;
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
int n,m,block,cnt,ans;
int l[605],r[605];
int a[20005],b[20005],t[20005],disc[20005],belong[20005];
inline int lowbit(int x){return x&(-x);}
inline void update(int x,int val)
{
	for(int i=x;i<=n;i+=lowbit(i))
		t[i]+=val;
}
inline int query(int x)
{
	int sum=0;
	for(int i=x;i;i-=lowbit(i))
		sum+=t[i];
	return sum;
}
inline int disc_find(int x)
{
	int l=1,r=n;
	while(l<=r)
	{
		int mid=(l+r)>>1;
		if(disc[mid]==x)return mid;
		else if(disc[mid]<x)l=mid+1;
		else r=mid-1;
	}
}
int finddown(int l,int r,int x)
{
	int ans=l-1,t=l;
	while(l<=r)
	{
		int mid=(l+r)>>1;
		if(a[mid]<x)ans=mid,l=mid+1;
		else r=mid-1;
	}
	return ans-t+1;
}
int findup(int l,int r,int x)
{
	int ans=r+1,t=r;
	while(l<=r)
	{
		int mid=(l+r)>>1;
		if(a[mid]>x)ans=mid,r=mid-1;
		else l=mid+1;
	}
	return t-ans+1;
}
void rebuild(int x)
{
	for(int i=l[x];i<=r[x];i++)
		a[i]=b[i];
	sort(a+l[x],a+r[x]+1);
}
void pre()
{
	for(int i=n;i;i--)
	{
		ans+=query(b[i]-1);
		update(b[i],1);
	}
	for(int i=1;i<=cnt;i++)
		rebuild(i);
}
void solve(int x,int y)
{
	if(x==y)return;
	int L=r[belong[x]],R=l[belong[y]];
	if(b[x]<b[y])ans++;
	if(b[x]>b[y])ans--;
	if(belong[x]==belong[y])
	{
		for(int i=x+1;i<y;i++)
		{
			if(b[i]>b[x])ans++;
			if(b[i]<b[x])ans--;
			if(b[i]>b[y])ans--;
			if(b[i]<b[y])ans++;
		}
	}
	else 
	{
		for(int i=x+1;i<=L;i++)
		{
			if(b[i]>b[x])ans++;
			if(b[i]<b[x])ans--;
			if(b[i]>b[y])ans--;
			if(b[i]<b[y])ans++;
		}
		for(int i=R;i<y;i++)
		{
			if(b[i]>b[x])ans++;
			if(b[i]<b[x])ans--;
			if(b[i]>b[y])ans--;
			if(b[i]<b[y])ans++;
		}
		for(int i=belong[x]+1;i<belong[y];i++)
		{
			ans-=finddown(l[i],r[i],b[x]);
			ans+=finddown(l[i],r[i],b[y]);
			ans+=findup(l[i],r[i],b[x]);
			ans-=findup(l[i],r[i],b[y]);
		}
	}
	swap(b[x],b[y]);
	rebuild(belong[x]);rebuild(belong[y]);
}
int main()
{
    freopen("std.in","r",stdin); freopen("hzwer.out","w",stdout);
	n=read();
	block=sqrt(n);
	if(n%block)cnt=n/block+1;
	else cnt=n/block;
	for(int i=1;i<=cnt;i++)
		l[i]=(i-1)*block+1,r[i]=i*block;
	r[cnt]=n;
	for(int i=1;i<=n;i++)
		belong[i]=(i-1)/block+1;
	for(int i=1;i<=n;i++)
		disc[i]=a[i]=read();
	sort(disc+1,disc+n+1);
	for(int i=1;i<=n;i++)
		a[i]=b[i]=disc_find(a[i]);
	pre();
	printf("%d\n",ans);
	m=read();
	for(int i=1;i<=m;i++)
	{
		int x=read(),y=read();
		if(x>y)swap(x,y);
		solve(x,y);
		printf("%d\n",ans);
	}
	return 0;
}

明显慢啊啊啊

#include <iostream>
#include <cstdio>
#include <string.h>
#include <algorithm>
#include <cmath>
#include <vector>
#include <queue>
#include <set>
#include <stack>
#include <string>
#include <map>
   
   
#define max(x,y) ((x)>(y)?(x):(y))
#define min(x,y) ((x)<(y)?(x):(y))
#define abs(x) ((x)>=0?(x):-(x))
#define i64 long long
#define u32 unsigned int
#define u64 unsigned long long
#define clr(x,y) memset(x,y,sizeof(x))
#define CLR(x) x.clear()
#define ph(x) push(x)
#define pb(x) push_back(x)
#define Len(x) x.length()
#define SZ(x) x.size()
#define PI acos(-1.0)
#define sqr(x) ((x)*(x))
#define MP(x,y) make_pair(x,y)
#define EPS 1e-10
   
   
#define FOR0(i,x) for(i=0;i<x;i++)
#define FOR1(i,x) for(i=1;i<=x;i++)
#define FOR(i,a,b) for(i=a;i<=b;i++)
#define FORL0(i,a) for(i=a;i>=0;i--)
#define FORL1(i,a) for(i=a;i>=1;i--)
#define FORL(i,a,b)for(i=a;i>=b;i--)
   
   
#define rush() int CC;for(scanf("%d",&CC);CC--;)
#define Rush(n)  while(scanf("%d",&n)!=-1)
using namespace std;
   
   
void RD(int &x){scanf("%d",&x);}
void RD(i64 &x){scanf("%lld",&x);}
void RD(u64 &x){scanf("%llu",&x);}
void RD(u32 &x){scanf("%u",&x);}
void RD(double &x){scanf("%lf",&x);}
void RD(int &x,int &y){scanf("%d%d",&x,&y);}
void RD(i64 &x,i64 &y){scanf("%lld%lld",&x,&y);}
void RD(u32 &x,u32 &y){scanf("%u%u",&x,&y);}
void RD(double &x,double &y){scanf("%lf%lf",&x,&y);}
void RD(int &x,int &y,int &z){scanf("%d%d%d",&x,&y,&z);}
void RD(i64 &x,i64 &y,i64 &z){scanf("%lld%lld%lld",&x,&y,&z);}
void RD(u32 &x,u32 &y,u32 &z){scanf("%u%u%u",&x,&y,&z);}
void RD(double &x,double &y,double &z){scanf("%lf%lf%lf",&x,&y,&z);}
void RD(char &x){x=getchar();}
void RD(char *s){scanf("%s",s);}
void RD(string &s){cin>>s;}
   
   
void PR(int x) {printf("%d\n",x);}
void PR(int x,int y) {printf("%d %d\n",x,y);}
void PR(i64 x) {printf("%lld\n",x);}
void PR(u32 x) {printf("%u\n",x);}
void PR(u64 x) {printf("%llu\n",x);}
void PR(double x) {printf("%.2lf\n",x);}
void PR(char x) {printf("%c\n",x);}
void PR(char *x) {printf("%s\n",x);}
void PR(string x) {cout<<x<<endl;}
   
void upMin(int &x,int y) {if(x>y) x=y;}
void upMin(i64 &x,i64 y) {if(x>y) x=y;}
void upMin(double &x,double y) {if(x>y) x=y;}
void upMax(int &x,int y) {if(x<y) x=y;}
void upMax(i64 &x,i64 y) {if(x<y) x=y;}
void upMax(double &x,double y) {if(x<y) x=y;}
   
const int mod=20101009;
const i64 inf=((i64)1)<<60;
const double dinf=1000000000000000000.0;
const int INF=2147483647;
const int N=20005;
   
struct node
{
    int val,size,pri,L,R,cnt;
};
  
  
node a[N*300];
int e;
  
 
int newNode(int val)
{
    int x=++e;;
    a[x].val=val;
    a[x].size=1;
    a[x].cnt=1;
    a[x].L=a[x].R=0;
    a[x].pri=rand();
    return x;
}
 
void pushUp(int x)
{
    if(x==0) return;
    a[x].size=a[x].cnt+a[a[x].L].size+a[a[x].R].size;
}
 
void rotL(int &x)
{
    int y=a[x].R;
    a[x].R=a[y].L;
    a[y].L=x;
      
    pushUp(x);
    pushUp(y);
    x=y;
}
  
void rotR(int &x)
{
    int y=a[x].L;
    a[x].L=a[y].R;
    a[y].R=x;
      
    pushUp(x);
    pushUp(y);
    x=y;
}
  
void insert(int &k,int val)
{
    if(k==0) k=newNode(val);
    else if(val<a[k].val) 
    {
        insert(a[k].L,val);
        if(a[a[k].L].pri>a[k].pri) rotR(k);
    }
    else if(val>a[k].val)
    {
        insert(a[k].R,val);
        if(a[a[k].R].pri>a[k].pri) rotL(k);
    }
    else a[k].cnt++;
    pushUp(k);
}
  
void del(int val,int &k)
{
    if(k==0) return;
    else if(val<a[k].val) del(val,a[k].L);
    else if(val>a[k].val) del(val,a[k].R);
    else
    {
        a[k].cnt--;
        if(a[k].cnt<=0)
        {
            if(a[k].L==0&&a[k].R==0) k=0;
            else if(a[k].L==0) k=a[k].R;
            else if(a[k].R==0) k=a[k].L;
            else
            {
                if(a[a[k].L].pri<a[a[k].R].pri) rotL(k),del(val,a[k].L);
                else rotR(k),del(val,a[k].R);
            }
        }
    }
    pushUp(k);
}
  
int d[N];
  
int getCnt(int t,int x)
{
    if(t==0) return 0;
    if(a[t].val==x) 
    {
        return a[a[t].L].size+a[t].cnt;
    }
    if(a[t].val>x) return getCnt(a[t].L,x);
    return a[t].cnt+a[a[t].L].size+getCnt(a[t].R,x);
}
 
 
int n,m;
int A[N];
  
void Set(int x,int k)
{
    while(x<=n) 
    {
        insert(A[x],k);
        x+=x&-x;
    }
}
 
void erase(int x,int k)
{
    while(x<=n) 
    {
        del(k,A[x]);
        x+=x&-x;
    }
}
 
int get(int x,int k)
{
    int ans=0;
    while(x) 
    {
        ans+=getCnt(A[x],k);
        x-=x&-x;
    }
    return ans;
}
  
pair<int,int> cal(int L,int R,int x)
{
    i64 a=get(R,x)-get(L-1,x);
    i64 b=get(R,x-1)-get(L-1,x-1);
    return MP(a-b,b);
} 
 
int p[N];
 
int find(int low,int high,int x)
{
    int M;
    while(low<=high)
    {
        M=(low+high)>>1;
        if(p[M]==x) return M;
        if(p[M]>x) high=M-1;
        else low=M+1;
    }
}
  
void getInt(int &x)
{
    char c=getchar();
    while(!isdigit(c)) c=getchar();
    x=0;
    while(isdigit(c)) x=x*10+c-'0',c=getchar();
}
  
int main()
{
    freopen("std.in","r",stdin); freopen("ac.out","w",stdout);
    getInt(n);
    int i;
    FOR1(i,n) getInt(d[i]),p[i]=d[i];
    sort(p+1,p+n+1);
    int M=unique(p+1,p+n+1)-(p+1);
    FOR1(i,n) d[i]=find(1,M,d[i]);
    i64 sum=0;
    pair<int,int> temp;
    FOR1(i,n) 
    {
        Set(i,d[i]);
        temp=cal(1,i-1,d[i]);
        sum+=i-1-temp.first-temp.second;
    }
     
    PR(sum);
    RD(m);
    int x,y;
    while(m--)
    {
        getInt(x);
        getInt(y);
        if(x>y) swap(x,y);
        if(d[x]==d[y])
        {
            PR(sum);
            continue;
        }
         
        if(d[x]>d[y]) sum--;
        else sum++;
         
        if(y-x>1) 
        {
            temp=cal(x+1,y-1,d[x]);
            sum-=temp.second;
            sum+=y-x-1-temp.first-temp.second;
              
            temp=cal(x+1,y-1,d[y]);
            sum+=temp.second;
            sum-=y-x-1-temp.first-temp.second;
        }
         
        erase(x,d[x]);
        erase(y,d[y]);
        swap(d[x],d[y]);
        Set(x,d[x]);
        Set(y,d[y]);
         
        PR(sum);
    }
}

附赠makedata.cpp,求查错。。。。。。。。。。。。。。。

//#define _TEST _TEST
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <cmath>
#include <algorithm>
#include <time.h>
#include <map>
using namespace std;
/************************************************
Code By willinglive    Blog:http://willinglive.cf
************************************************/
#define rep(i,l,r) for(int i=l,___t=(r);i<=___t;i++)
#define per(i,r,l) for(int i=r,___t=(l);i>=___t;i--)
#define MS(arr,x) memset(arr,x,sizeof(arr))
#define LL long long
#define INE(i,u,e) for(int i=head[u];~i;i=e[i].next)
inline const int getint()
{
    int r=0,k=1;char c=getchar();
    for(;c<'0'||c>'9';c=getchar())if(c=='-')k=-1;
    for(;c>='0'&&c<='9';c=getchar())r=r*10+c-'0';
    return k*r;
}
/////////////////////////////////////////////////
int n,m;
map<int,int>M;
/////////////////////////////////////////////////
int rnd(){return rand()<<16|rand();}
/////////////////////////////////////////////////
void input()
{
    srand(time(0));
}
void solve()
{
    n=20000; m=2000;
    printf("%d\n",n);
    rep(i,1,n)
    {
    	int x=rnd()%1000000000+1;
    	while(M[x]) x=rnd()%1000000000+1;
    	M[x]=1;
    	printf("%d ",x);
    }
    printf("\n%d\n",m);
    rep(i,1,m)
    {
    	int l=rnd()%n+1;
    	int r=rnd()%n+1;
    	printf("%d %d\n",l,r);
    }
}
/////////////////////////////////////////////////
int main()
{
    #ifndef _TEST
    freopen("std.in","r",stdin); freopen("std.in","w",stdout);
    #endif
    input(),
    solve();
    return 0;
}

抱歉!评论已关闭.