现在的位置: 首页 > 综合 > 正文

POJ 1741 Tree 树形DP(分治)

2018年04月24日 ⁄ 综合 ⁄ 共 2825字 ⁄ 字号 评论关闭

链接:http://poj.org/problem?id=1741

题意:给出一棵树,节点数为N(N<=10000),给出N-1条边的两点和权值,给出数值k,问树上两点最短距离小于k的点对有多少个。

思路:拿到题的第一反应是LCA问题,不过细一想询问次数极限情况可以达到10000*5000次,即使用Tarjan也是超时没商量的。

2009年国家队论文提供了树的分治思想,对于本题就是树的分治的点分治的应用。每次找到能使含节点最多的子树的节点最少的根分而治之,同样方式分别处理它的所有子树,知道处理到单独的节点。这样可以使复杂度最低化。(具体找根的方式和上一题思想类似,传送门:http://blog.csdn.net/ooooooooe/article/details/38981129 )

对于每个根,记录其他点到根的距离,我要找出它的两个子节点分别处于它的不同子树并且距离小于k的情况数,不记录在同一子树的情况是因为对于同一子树的两个节点,它们的最短距离并不是它们到根的距离之和,而且如果对于每个根都记录处于相同子树的节点,那么会记重复情况。具体处理的方式是先记录到根距离之和小于等于k的点对的数量,然后对于每个子树分别除去在子树中到根距离之和小于等于k的点对的数量。

资料:http://wenku.baidu.com/view/e087065f804d2b160b4ec0b5.html###

代码:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <ctype.h>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <string>
#include <vector>
#define eps 1e-8
#define INF 1000000000
#define maxn 10005
#define PI acos(-1.0)
#define seed 31//131,1313
typedef long long LL;
typedef unsigned long long ULL;
using namespace std;
int dp[maxn][2],from[maxn],head[maxn],all[maxn],top,tot,ans=0;
int T,k,x,y,z;
bool vis[maxn];
void init()
{
    memset(head,-1,sizeof(head));
    memset(dp,0,sizeof(dp));
    memset(vis,0,sizeof(vis));
    ans=0;
    top=0;
}
struct Edge
{
    int v,w;
    int next;
} edge[maxn*2];
void add_edge(int u,int v,int w)
{
    edge[top].v=v;
    edge[top].w=w;
    edge[top].next=head[u];
    head[u]=top++;
}
void dfs_first(int u,int f)
{
    from[u]=u;
    dp[u][0]=dp[u][1]=0;
    for(int i=head[u]; i!=-1; i=edge[i].next)
    {
        int v=edge[i].v,w=edge[i].w;
        if(v==f||vis[v])
            continue;
        dfs_first(v,u);
        if(dp[v][0]+w>dp[u][0])
        {
            from[u]=v;
            dp[u][1]=dp[u][0];
            dp[u][0]=dp[v][0]+w;
        }
        else if(dp[v][0]+w>dp[u][1])
            dp[u][1]=dp[v][0]+w;
    }
}
void dfs_second(int u,int f,int k,int &root,int &deep)
{
    if(u!=f)
        if(from[f]!=u)
        {
            if(dp[f][0]+k>dp[u][0])
            {
                from[u]=f;
                dp[u][1]=dp[u][0];
                dp[u][0]=dp[f][0]+k;
            }
            else if(dp[f][0]+k>dp[u][1])
                dp[u][1]=dp[f][0]+k;
        }
        else
        {
            if(dp[f][1]+k>dp[u][0])
            {
                from[u]=f;
                dp[u][1]=dp[u][0];
                dp[u][0]=dp[f][1]+k;
            }
            else if(dp[f][1]+k>dp[u][1])
                dp[u][1]=dp[f][1]+k;
        }
    if(dp[u][0]<deep)
    {
        deep=dp[u][0];
        root=u;
    }
    for(int i=head[u]; i!=-1; i=edge[i].next)
    {
        int v=edge[i].v,w=edge[i].w;
        if(v==f||vis[v])
            continue;
        dfs_second(v,u,w,root,deep);
    }
}
void dfs_third(int u,int f,int val)
{
    all[tot++]=val;
    for(int i=head[u]; i!=-1; i=edge[i].next)
    {
        int v=edge[i].v,w=edge[i].w;
        if(!vis[v]&&v!=f)
            dfs_third(v,u,val+w);
    }
}
void dfs(int u)
{
    int root=-1,deep=INF;
    memset(dp,0,sizeof(dp));
    dfs_first(u,u);
    dfs_second(u,u,0,root,deep);
    tot=0;
    dfs_third(root,root,0);
    sort(all,all+tot);
    int a=0,b=tot-1;
    while(a<b)
    {
        while(all[a]+all[b]>k&&b>a)
            b--;
        ans+=b-a;
        a++;
    }
    vis[root]=1;
    for(int i=head[root]; i!=-1; i=edge[i].next)
    {
        tot=0;
        int v=edge[i].v,w=edge[i].w;
        if(!vis[v])
        {
            dfs_third(v,0,w);
            sort(all,all+tot);
            a=0,b=tot-1;
            while(a<b)
            {
                while(all[a]+all[b]>k&&b>a)
                    b--;
                ans-=b-a;
                a++;
            }
            dfs(v);
        }
    }
}
int main()
{
    while(scanf("%d%d",&T,&k))
    {
        if(!T&&!k)
        break;
        init();
        for(int i=0; i<T-1; i++)
        {
            scanf("%d%d%d",&x,&y,&z);
            add_edge(x,y,z);
            add_edge(y,x,z);
        }
        dfs(1);
        printf("%d\n",ans);
    }
    return 0;
}

抱歉!评论已关闭.