现在的位置: 首页 > 综合 > 正文

POJ 3321 Apple Tree(树状数组+DFS)

2014年04月05日 ⁄ 综合 ⁄ 共 1878字 ⁄ 字号 评论关闭

http://poj.org/problem?id=3321

一道搁置了好多天的树状数组

题意:有一颗苹果树,主干是1,下面会有分支,每个分支一个编号一直到n。每个分支只能有一个苹果或没有苹果。初始状态是每个分支一个苹果。示例:

   5

   1 2

   1 3

   3 5

   3 4

它形成的树是这样的:


现在每个分支的苹果都是1。

下面是两种操作,Q 和C

C   j  的意思是如果 j 这个枝子上面有苹果就摘下来,如果没有,那么就会长出新的一个

Q  j  就是问 j 这个叉及其下面的苹果总数。如 Q 3,那么答案是3,因为3及其下面的分支共有三个苹果。


思路:更新某个节点的值,询问区间的和。这是一道树状数组的题目,但是树状数组对应的是一维数组。那么应该先把这棵树转化为一维数组。用DFS遍历数的同时记录每个节点的起始和结束的编号。相当于时间戳。比如上图遍历的顺序是 1 2 3 5 4,那么其对应的编号是 1  2 3 4 5。用start[]和end[]记录每个节点的起始时间和结束时间。

start[1] = 1,end[1] = 5(代表1上的树枝是1~5),同理start[2] = 2,end[2] = 2,start[3] = 3,end[3] = 5,start[5] = 4,end[5] = 4,start[4] = 5,end[4] = 5.这就转化为一维数组了。

对于 C j:先判断j分支上是否有苹果,就是利用线段树,只需计算sum( start[j] ) - sum( start[j] -1)是否等于1,是1说明有苹果,更新该点,即减去一个苹果,否则就加上一个苹果,属于单点更新问题,注意是对start[j]更新。

对于 Q j:询问j和j的子树上的苹果数,就是区间和.即 sum( end[j] ) - sum( start[j] - 1)。

最后不能用vector,超时。手写结构体数组。

#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <vector>
using namespace std;

const int maxn = 100010;

struct node
{
	int data;
	struct node *next;
};
struct node edge[maxn];

int start[maxn],end[maxn];
int n,m,dep;
int c[maxn];

int Lowbit(int x)
{
	return x&(-x);
}

void dfs(int u)//dfs找每个节点的时间戳,起始时间和结束时间
{
	start[u] = ++dep;
	struct node * tmp = edge[u].next;
	while(tmp)
	{
		if(start[tmp->data] == 0)
			dfs(tmp->data);
		tmp = tmp->next;
	}
	end[u] = dep;
}

int sum(int end)
{
	int s = 0;
	while(end > 0)
	{
		s += c[end];
		end -= Lowbit(end);
	}
	return s;
}

void update(int pos, int num)
{
	while(pos <= n)
	{
		c[pos] += num;
		pos += Lowbit(pos);
	}
}

int main()
{
	int u,v;
	scanf("%d",&n);

	for(int i = 0; i < n-1; i++)
	{
		scanf("%d %d",&u,&v);

		struct node *P = new struct node;
		P->data = v;
		P->next = edge[u].next;
		edge[u].next = P;

		struct node *Q = new struct node;
		Q->data = u;
		Q->next = edge[v].next;
		edge[v].next = Q;
	}

	memset(c,0,sizeof(c));
	memset(start,0,sizeof(start));
	memset(end,0,sizeof(end));
	dep = 0;
	dfs(1);

	for(int i = 1; i <= n; i++)
	{
		update(i,1);
	}

	scanf("%d",&m);
	char str[2];
	int x,res1,res2,res3;
	while(m--)
	{
		scanf("%s %d",str,&x);
		res1 = sum(start[x]);
		res2 = sum(start[x]-1);
		res3 = sum(end[x]);

		if(str[0] == 'C')
		{
			if(res1 - res2 == 1)
				update(start[x],-1);
			else update(start[x],1);
		}

		else
		{
			printf("%d\n",res3-res2);
		}
	}
	return 0;

}


抱歉!评论已关闭.