现在的位置: 首页 > 综合 > 正文

hdu 4742 Pinball Game 3D 分治+树状数组

2013年10月28日 ⁄ 综合 ⁄ 共 2129字 ⁄ 字号 评论关闭

离散化x然后用树状数组解决,排序y然后分治解决,z在分治的时候排序解决。

具体:先对y排序,solve(l,r)分成solve(l,mid),solve(mid+1,r), 然后因为是按照y排序,所以l,mid区间内的y值都小于mid+1,r。现在再对z排序,按照顺序以x做关键字插入到树状数组中,那么就可以一起解决l,mid对mid+1,r的影响。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn=1e5+9,mod=1<<30;
int trsum[maxn],trmax[maxn];
int n;
struct P
{
    int x,y,z,id;
}point[maxn],now[maxn];
struct A
{
    int max,sum;
}ans[maxn],tr[maxn];
bool cmpx(const P a,const P b)
{
    return a.x<b.x;
}
bool cmpy(const P a,const P b)
{
    return a.y<b.y;
}
bool cmpz(const P a,const P b)
{
    return a.z<b.z;
}

int lowbit(int x)
{
    return (x&-x);
}

void insert(int x,A tmp)
{
    for(int i=x;i<=n;i+=lowbit(i))
    {
        if(tr[i].max==tmp.max)
        {
            tr[i].sum+=tmp.sum;
            tr[i].sum%=mod;
        }
        else if(tr[i].max<tmp.max)
        {
            tr[i].sum=tmp.sum;
            tr[i].max=tmp.max;
        }
    }
}

A getsum(int x)
{
    A ret;
    ret.max=-1;
    for(int i=x;i>=1;i-=lowbit(i))
    {
        if(tr[i].max>ret.max)
        {
            ret.max=tr[i].max;
            ret.sum=tr[i].sum;
        }
        else if(tr[i].max==ret.max)
        {
            ret.sum+=tr[i].sum;
            ret.sum%=mod;
        }
    }
    return ret;
}

void clear(int x)
{
    for(int i=x;i<=n;i+=lowbit(i))
    {
        tr[i].max=0;
        tr[i].sum=0;
    }
}

void solve(int l,int r)
{
    if(l==r) return ;
    int mid=l+r>>1;
    solve(l,mid);
    for(int i=mid+1;i<=r;i++)
    now[i]=point[i];
    sort(point+l,point+mid+1,cmpz);
    sort(point+mid+1,point+r+1,cmpz);
    for(int i=mid+1,top=l;i<=r;i++)
    {
        while(top<=mid&&point[top].z<=point[i].z)
        {
            insert(point[top].x,ans[point[top].id]);
            top++;
        }
        A ret=getsum(point[i].x);
        ret.max++;
        if(ret.max==ans[point[i].id].max)
        {
            ans[point[i].id].sum+=ret.sum;
            ans[point[i].id].sum%=mod;
        }
        else if(ret.max>ans[point[i].id].max)
        {
            ans[point[i].id]=ret;
        }
    }
    for(int i=l;i<=mid;i++) clear(point[i].x);
    for(int i=mid+1;i<=r;i++)
    point[i]=now[i];
    solve(mid+1,r);
}

int main()
{
//    freopen("in.txt","r",stdin);
    int T;
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d",&n);
        for(int i=1;i<=n;i++)
        {
            scanf("%d %d %d",&point[i].x,&point[i].y,&point[i].z);
            point[i].id=i;
        }
        sort(point+1,point+1+n,cmpx);
        for(int i=1,xx=point[1].x-1,num=0;i<=n;i++)
        {
            if(point[i].x!=xx) num++,xx=point[i].x;
            point[i].x=num;
        }
        sort(point+1,point+1+n,cmpy);

        for(int i=1;i<=n;i++)
        {
            ans[i].max=1;
            ans[i].sum=1;
        }
        solve(1,n);
        A ret;
        ret.max=-1;
        for(int i=1;i<=n;i++)
        {
            if(ret.max==ans[i].max)
            {
                ret.sum+=ans[i].sum;
                ret.sum%=mod;
            }
            else if(ret.max<ans[i].max)
            {
                ret=ans[i];
            }
        }
        printf("%d %d\n",ret.max,ret.sum);
    }
    return 0;
}

抱歉!评论已关闭.