现在的位置: 首页 > 综合 > 正文

POJ 1741 Tree 树的分治(点分治)

2018年01月19日 ⁄ 综合 ⁄ 共 1907字 ⁄ 字号 评论关闭

题目大意:给出一颗无根树和每条边的权值,求出树上两个点之间距离<=k的点的对数。

思路:树的点分治。利用递归和求树的重心来解决这类问题。因为满足题意的点对一共只有两种:

1.在以该节点的子树中且不经过该节点。

2.路径经过该节点。

对于第一种点,我们递归处理;第二种点,我们可以将所有子树的节点到这个子树的根节点的距离处理出来,然后排序处理出满足要求的点对的个数。

按照正常的树的结构来分割子树,这样的做法的时间复杂度肯定是不好看的,为了让子树大小尽量相同,我们每次处理这个子树前找到这个子树的重心,把这个重心当为根,然后在分割子树,这样时间复杂度最坏会降到O(nlog^2n)。

CODE:

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define MAX 20010
#define INF 0x3f3f3f3f
using namespace std;

int points,edges,k;
int head[MAX],total;
int next[MAX << 1],aim[MAX << 1],length[MAX << 1];

int cnt[MAX],c;  			//每个子树中经过根节点的满足条件的对数
int size[MAX],_size,dis[MAX],p;
int _total;
bool v[MAX];

inline void Initialize();
inline void Add(int x,int y,int len);
void Work(int x);
void GetRoot(int x,int last);
inline int Count(int x,int len);
void GetDis(int x,int last,int len);

int main()
{
	while(scanf("%d%d",&points,&k),points + k) {
		Initialize();
		for(int x,y,z,i = 1;i < points; ++i) {
			scanf("%d%d%d",&x,&y,&z);
			Add(x,y,z),Add(y,x,z);
		}
		Work(1);
		int ans = 0;
		for(int i = 1;i <= points; ++i)
			ans += cnt[i];
		printf("%d\n",ans);
	}
	return 0;
}

inline void Initialize()
{
	total = 0;
	memset(head,0,sizeof(head));
	memset(v,false,sizeof(v));
}

inline void Add(int x,int y,int len)
{
	next[++total] = head[x];
	aim[total] = y;
	length[total] = len;
	head[x] = total;
}

void Work(int x)
{
	_size = INF;
	_total = size[x] ? size[x]:points;
	GetRoot(x,0);
	x = c;
	v[x] = true;
	cnt[x] = Count(x,0);
	for(int i = head[x];i;i = next[i]) {
		if(v[aim[i]])	continue;
		cnt[x] -= Count(aim[i],length[i]);
		Work(aim[i]);
	}
}

void GetRoot(int x,int last)
{
	size[x] = 1;
	int max_size = 0;
	for(int i = head[x];i;i = next[i]) {
		if(v[aim[i]] || aim[i] == last)	continue;
		GetRoot(aim[i],x);
		size[x] += size[aim[i]];
		max_size = max(max_size,size[aim[i]]);
	}
	max_size = max(max_size,_total - size[x]);
	if(max_size < _size)
		_size = max_size,c = x;
}

inline int Count(int x,int len)
{
	int re = 0;
	p = 0;
	GetDis(x,0,len);
	sort(dis,dis + p);
	int l = 0,r = p - 1;
	while(l < r) {
		if(dis[l] + dis[r] <= k)
			re += (r - l),l++;
		else	r--;
	}
	return re;
}

void GetDis(int x,int last,int len)
{
	dis[p++] = len;
	for(int i = head[x];i;i = next[i]) {
		if(aim[i] == last || v[aim[i]])	continue;
		GetDis(aim[i],x,len + length[i]);
	}
}

抱歉!评论已关闭.