现在的位置: 首页 > 综合 > 正文

poj1741——树上的分治,树dp

2013年11月06日 ⁄ 综合 ⁄ 共 1639字 ⁄ 字号 评论关闭

题意:求树上所有满足路径长度小于等于k的两点的对数。

思路还是很好理解的,对每一棵子树,以重心为根节点统计在不同子树且相互距离小于等于k的对数。就是实现的细节处理很多,主要都是递归处理的比较多。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <map>
#include <vector>
using namespace std;

const int maxn = 50000 + 10;
struct edge
{
    int to, dist, next;
}edges[maxn];
int edgehead[maxn], tot;
int n, k, vis[maxn], ans, root, num;
int mx[maxn], size[maxn], dis[maxn];

void addedge(int from, int to, int dist)
{
    edges[tot].dist = dist;
    edges[tot].to = to;
    edges[tot].next = edgehead[from];
    edgehead[from] = tot++;
}

void dfssize(int u, int fa)
{
    size[u] = 1;
    mx[u] = 0;
    for(int i = edgehead[u]; i != -1; i = edges[i].next)
    {
        int v = edges[i].to;
        if(v != fa && !vis[v])
        {
            dfssize(v, u);
            size[u] += size[v];
            mx[u] = max(mx[u], size[v]);
        }
    }
}

void dfsroot(int r, int u, int fa)
{
    mx[u] = max(mx[u], size[r] - size[u]);
    if(mx[u] < mx[root]) root = u;
    for(int i = edgehead[u]; i != -1; i = edges[i].next)
    {
        int v = edges[i].to;
        if(v != fa && !vis[v]) dfsroot(r, v, u);
    }
}

void dfsdis(int u, int d, int fa)
{
    dis[num++] = d;
    for(int i = edgehead[u]; i != -1; i = edges[i].next)
    {
        int v = edges[i].to;
        if(v != fa && !vis[v])  dfsdis(v, d + edges[i].dist, u);
    }
}

int calc(int u, int d)
{
    int ret = 0;
    num = 0;
    dfsdis(u, d, 0);
    sort(dis, dis + num);
    int i = 0, j = num - 1;
    while(i < j)
    {
        while(dis[i] + dis[j] > k && i < j) j--;
        ret += j - i;
        i++;
    }
    return ret;
}

void dfs(int u)
{
    root = u;
    dfssize(u, 0);
    dfsroot(u, u, 0);
    ans += calc(root, 0);
    vis[root] = 1;
    for(int i = edgehead[root]; i != -1; i = edges[i].next)
    {
        int v = edges[i].to;
        if(!vis[v])
        {
            ans -= calc(v, edges[i].dist);
            dfs(v);
        }
    }
}

void prework()
{
    memset(vis, 0, sizeof(vis));
    memset(edgehead, 0xff, sizeof(edgehead)); tot = 0;
    for(int i = 0; i < n - 1; ++i)
    {
        int a, b, c;
        scanf("%d %d %d", &a, &b, &c);
        addedge(a, b, c);
        addedge(b, a, c);
    }
}

void solve()
{
    ans = 0;
    dfs(1);
    printf("%d\n", ans);
}

int main()
{
    while(~scanf("%d %d", &n, &k))
    {
        if(!n && !k) break;
        prework();
        solve();
    }
    return 0;
}

抱歉!评论已关闭.