Tree
Description
Give a tree with n vertices,each edge has a length(positive integer less than 1001).
Define dist(u,v)=The min distance between node u and v. Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k. Write a program that will count how many pairs which are valid for a given tree. Input
The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l.
The last test case is followed by two zeros. Output
For each test case output the answer on a single line.
Sample Input 5 4 1 2 3 1 3 1 1 4 2 3 5 1 0 0 Sample Output 8 Source |
我们知道一条路径要么过根结点,要么在一棵子树中,这启发了我们可以使用分治算法。
路径在子树中的情况只需递归处理即可,下面我们来分析如何处理路径过根结点的情况。
记Depth(i)表示点i到根结点的路径长度,Belong(i) =X ( X 为
根结点的某个儿子,且结点i 在以X 为根的子树内)。那么我们要统计
的就是:
满足Depth(i) +Depth( j) <=K 且Belong(i) !=Belong( j) 的(i, j) 个
数
= 满足Depth(i) + Depth( j) <=K的(i, j)个数
– 满足Depth(i) + Depth( j) <=K且Belong(i) =Belong( j)的(i, j)个数
而对于这两个部分,都是要求出满足Ai+Aj <=k的(i, j)的对数。
将A排序后利用单调性我们很容易得出一个O(N)的算法,所以我们
可以用O(N log N)的时间来解决这个问题。
综上,此题使用树的分治算法时间复杂度为( log ) 2 O N N 。
引自与国家集训队集训队论文
思路:
最容易想到的算法是:从每个点出发遍历整棵树,统计数对个数。
由于时间复杂度O(N^2),明显是无法满足要求的。
对于一棵有根树, 树中满足要求的一个数对所对应的一条路径,必然是以下两种情况之一:
1、经过根节点
2、不经过根节点,也就是说在根节点的一棵子树中
对于情况2,可以递归求解,下面主要来考虑情况1。
设点i的深度为Depth[i],父亲为Parent[i]。
若i为根,则Belong[i]=-1,若Parent[i]为根,则Belong[i]=i,否则Belong[i]=Belong[Parent[i]]。
这三个量都可以通过一次BFS求得。
我们的目标是要统计:有多少对(i,j)满足i<j,Depth[i]+Depth[j]<=K且Belong[i]<>Belong[j]
如果这样考虑问题会变得比较麻烦,我们可以考虑换一种角度:
设X为满足i<j且Depth[i]+Depth[j]<=K的数对(i,j)的个数
设Y为满足i<j,Depth[i]+Depth[j]<=K且Belong[i]=Belong[j]数对(i,j)的个数
那么我们要统计的量便等于X-Y
求X、Y的过程均可以转化为以下问题:
已知A[1],A[2],...A[m],求满足i<j且A[i]+A[j]<=K的数对(i,j)的个数
对于这个问题,我们先将A从小到大排序。
设B[i]表示满足A[i]+A[p]<=K的最大的p(若不存在则为0)。我们的任务便转化为求出A所对应的B数组。那么,若B[i]>i,那么i对答案的贡献为B[i]-i。
显然,随着i的增大,B[i]的值是不会增大的。利用这个性质,我们可以在线性的时间内求出B数组,从而得到答案。
综上,设递归最大层数为L,因为每一层的时间复杂度均为“瓶颈”——排序的时间复杂度O(NlogN),所以总的时间复杂度为O(L*NlogN)
然而,如果遇到极端情况——这棵树是一根链,那么随意分割势必会导致层数达到O(N)级别,对于N=10000的数据是无法承受的。因此,我们在每一棵子树中选择“最优”的点分割。所谓“最优”,是指删除这个点后最大的子树尽量小。这个点可以通过树形DP在O(N)时间内求出,不会增加时间复杂度。这样一来,即使是遇到一根链的情况时,L的值也仅仅是O(logN)的。
个人总结出算法步骤:
第一步找到重心所以结点维护以该树为根的结点总数。和去掉该结点后儿子的最大结点数。
用一个dfs计算出所需数据。
找到的根因该满足去掉后。儿子最大结点数因最小。防止退化。
第二步。从重心开始遍历该子树。算出各结点到根的距离。
第三部。进行两次计数。第一次计算该子树所有满足距离小于等于k的点对数。
第二次计算在同一子树重复的点对数。在减去重复的即得答案
#include <stdio.h> #include <string.h> #include <cstdlib> #include <algorithm> #include <iostream> using namespace std; struct node1//边结构。 { int to;//记录边的终点 int weight;//记录边的长度 node1 *next;//记录相同起点的另一条边 } edge[20010],*head[10010];//head[i]存以i为起点边的链表头 struct node2//点结构 { int sum;//记录子树的总结点数 int mson;//记录去掉该结点后所有子树的最大规模(规模即子树包含多少结点) } point[10010]; int vis[10010],pos[10010],msons[10010],dis[10010];//vis记录结点是否访问过。 int ptr,cnt,n,k,ans,root;//见下 int ma(int a,int b){ return a>b?a:b; } void adde(int f,int s,int w)//添加边 { edge[ptr].to=s;//记录终点 edge[ptr].weight=w;//记录两点间距离 edge[ptr].next=head[f];//建立边链表 head[f]=&edge[ptr++]; } void dfs(int f,int s)//f父结点。s子结点。计算找重心需要数据 { point[s].sum=1;//子树总结点数 point[s].mson=0;//子树最大规模 node1 *p=head[s]; while(p!=NULL)//遍历以s为起点的边 { if(p->to!=f&&!vis[p->to])//只计算子树防止访问父结点。结点要属于子树 { dfs(s,p->to);//递归计算 point[s].sum+=point[p->to].sum; point[s].mson=ma(point[s].mson,point[p->to].sum); } p=p->next; } pos[cnt]=s;//把算的结果存在数组里。pos记录标号。msons记录最大规模 msons[cnt++]=point[s].mson; } int froot(int s)//计算子树s重心 { int tsum,minson,minp,i;//tsum记录子树总规模。minson记录最小的子树最大规模。minp记录下标 cnt=0; dfs(0,s); tsum=point[s].sum; minson=0x3f3f3f3f;//初始为无穷大 for(i=0; i<cnt; i++) { msons[i]=ma(msons[i],tsum-point[pos[i]].sum); if(msons[i]<minson) { minson=msons[i]; minp=pos[i]; } } return minp; } void getdist(int f,int s,int d)//计算以f为父结点的子树各结点到根的距离 { node1 *p=head[s]; dis[cnt++]=d;//把距离记录在dis数组中 while(p!=NULL) { if(p->to!=f&&!vis[p->to]&&d+p->weight<=k) getdist(s,p->to,d+p->weight);//还是递归计算 p=p->next; } } void count1()//计算以重心为根所有满足条件的点对。 { int l,r; sort(dis,dis+cnt);//排序方便快速计算 l=0,r=cnt-1; while(l<r) { if(dis[l]+dis[r]<=k)//以l为起点l+1到r的点距离都满足 ans+=r-l,l++; else r--; } } void count2(int s)//排除在同一子树中的点对。 { int l,r; vis[s]=1;//去除该点 node1 *p=head[s]; while(p!=NULL) { if(!vis[p->to])//遍历边 { cnt=0; getdist(s,p->to,p->weight);//算出以s为根的距离 sort(dis,dis+cnt); l=0,r=cnt-1; while(l<r) { if(dis[l]+dis[r]<=k) ans-=r-l,l++; else r--; } } p=p->next; } } void solve(int f,int s) { root=froot(s);//先找到重心 cnt=0; getdist(0,root,0);//见算法步骤 count1(); count2(root); node1 *p=head[root]; while(p!=NULL) { if(p->to!=f&&!vis[p->to]) solve(root,p->to);//递归计算 p=p->next; } } int main() { int i,u,v,l; while(scanf("%d%d",&n,&k),n||k) { memset(head,0,sizeof head); memset(vis,0,sizeof vis); ptr=0; ans=0; for(i=1;i<n;i++) { scanf("%d%d",&u,&v); scanf("%d",&l); adde(u,v,l); adde(v,u,l); } solve(0,1); printf("%d\n",ans); } return 0; }