现在的位置: 首页 > 综合 > 正文

HDOJ 4638: Group

2014年10月10日 ⁄ 综合 ⁄ 共 1533字 ⁄ 字号 评论关闭

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4638

题目大意:

给定一个序列,是1~n的一个排列。

每次查询一个区间,问这个区间中含有多少个子集,自己中的序号是连续递增的(顺序可打乱)。

例如区间1,5,3,2可以分成1,2,3和5。

算法

不难看出一个点只能影响两个点,也只能被两个点影响。

我们先算出每个点前面有多少个点的数值跟它相邻,保存在一个数组tr[]里。

首先假设区间中的点都是间断的,那么段数等于区间长度。

当某个数与跟它相邻的数同处一个区间,那么两个段就会连接起来,段数减一。

然后把区间按左端点排列。

我们从序列的左端依次扫过去,

每处理一个点,就处理以它为做节点的所有询问。

显然我只要求出tr[]中区间右端点之前的所有值之和,就知道了段数要减少多少。

当然,一个点已经被处理过之后,剩下的询问,一定都不再覆盖它了,它的信息就是无用的,要消去它的影响。

也就是找到位置在它之后的与它相邻的点,把相应的tr[]减去1。

当然求和肯定不能是朴素求,是用树状数组。

代码如下:

#include <iostream>
#include <cstring>
#include <stdio.h>
#include <math.h>
#include <queue>
#include <vector>
#include <algorithm>
#include <stack>
#include <map>
using namespace std;

#define ll long long
#define inf 2e9
#define pii pair<int,int>
#define st first
#define nd second

const int MAXN=110000;
int tr[MAXN],a[MAXN],pos[MAXN],ans[MAXN];
vector<pair<int,int> > mm[MAXN];
int n,m;

void add(int x, int c)
{
    while(x<=n)
    {
        tr[x]+=c;
        x+=x&(-x);
    }
}

int sum(int x)
{
    int ret=0;
    while(x)
    {
        ret+=tr[x];
        x-=x&(-x);
    }
    return ret;
}

int main()
{
    int cas;
    scanf("%d",&cas);
    for(int T=1; T<=cas; T++)
    {
        scanf("%d%d",&n,&m);
        memset(tr,0,sizeof(tr));
        for(int i=1; i<=n; i++)
        {
            mm[i].clear();
            scanf("%d",&a[i]);
            pos[a[i]]=i;
        }
        for(int i=1; i<=m; i++)
        {
            int u,v;
            scanf("%d %d",&u,&v);
            mm[u].push_back(make_pair(i,v));
        }
        for(int i=1; i<=n; i++)
        {
            if(i>1)
            {
                if(pos[i-1]>pos[i])
                {
                    add(pos[i-1],1);
                }
            }
            if(i<n)
            {
                if(pos[i+1]>pos[i])
                {
                    add(pos[i+1],1);
                }
            }
        }
        for(int u=1; u<=n; u++)
        {
            for(int i=0; i<mm[u].size(); i++)
            {
                int v=mm[u][i].nd;
                ans[mm[u][i].st]=v-u+1-sum(v);
            }
            if(a[u]>1)
            {
                if(pos[a[u]-1]>u)
                {
                    add(pos[a[u]-1],-1);
                }
            }
            if(a[u]<n)
            {
                if(pos[a[u]+1]>u)
                {
                    add(pos[a[u]+1],-1);
                }
            }
        }
        for(int i=1; i<=m; i++)
        {
            printf("%d\n",ans[i]);
        }
    }
    return 0;
}

抱歉!评论已关闭.