现在的位置: 首页 > 综合 > 正文

poj 2486 (树形dp (好题))

2014年02月07日 ⁄ 综合 ⁄ 共 1705字 ⁄ 字号 评论关闭

题目链接:http://poj.org/problem?id=2486

最近做了些树形dp的题目, 感觉对这种将每棵子树都当做一个物品处理的树形背包题目理解还不够深刻, 果然这题就被破的很惨, 但好在想了很长时间想清楚了, 这题就是求从一颗带点权树的根节点出发, 走过k条边能获得的最大权值。 考虑某一个节点u, 肯定有一维状态表示从该节点开始还可以走多少条边, 但最优决策有可能是从u的某个子树v1下去再上来再到另一棵子树v2中,
所以我们可以定义状态dp[u][j][0]表示从点u出发最多走j条边并回到该点能获得的最大权值, dp[u][j][1]则表示可以不回到该点的。。。

我们不难得到下面两个转移方程:

1. dp[u][j][0] = max(dp[u][j][0], dp[v][j - k - 2][0] + dp[u][k][0]); (k <= j && v是u的一个子节点)

2. dp[u][j][1] = max(dp[u][j][1], dp[v][j - k - 1][[1] + dp[u][k][0]);

我一开始就是这样做的结果WA了而且测了discuss中的数据也过了, 后来看到discuss中有人说少了一个转移就WA, 我就开始想是不是有问题, 后来成功造出了一组数据, 就发现问题了, 其实我们在最外层枚举u的子节点v, 转移时的dp[u]***都是前i个子树的状态, 而且最后一维必须是0因为我们在考虑是否在当前这棵子树继续向下时必须使得在遍历之前的子树后回到点u, 但这样就有可能丢失了最优决策,
因为最优决策有可能就是在之前已经转移过的子树中所以我们有第三个转移


3.dp[u][j][1] = max(dp[u][j][1], dp[u][k][1] + dp[v][j - k - 2][0]);

另附一组数据

6 6
3 2 1 4 5 6
1 6
1 3
1 5
3 4
3 2

  ans: 19

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cstdlib>
#include <vector>
#include <cmath>

using namespace std;

const int N = 105;
const int M = N << 1;

int head[N], next[M], to[M];
int W[N];
int dp[N][N << 1][2];
int tot, n, m;

void init() {
	for (int i = 1; i <= n; i++) {
		head[i] = -1;
	}
	tot = 0;	
}

void add(int u, int v) {
	to[tot] = v, next[tot] = head[u], head[u] = tot++;
	to[tot] = u, next[tot] = head[v], head[v] = tot++;
}

void dfs(int u, int fa) {

	for (int i = m; i >= 0; i--)
		dp[u][i][0] = dp[u][i][1] = W[u];

	for (int i = head[u]; i != -1; i = next[i]) {
		int v = to[i];
		if (v == fa) continue;
		dfs(v, u);
		for (int j = m; j >= 0; j--) {
			for (int k = 0; k <= j; k++) {
				if (j - k >= 2)
					dp[u][j][0] = max(dp[u][j][0], dp[u][k][0] + dp[v][j - k - 2][0]);
				if (j - k >= 1)
					dp[u][j][1] = max(dp[u][j][1], dp[u][k][0] + dp[v][j - k - 1][1]);
				if (j - k >= 2)
					dp[u][j][1] = max(dp[u][j][1], dp[u][k][1] + dp[v][j - k - 2][0]);
			}
		}
	}
}

int main() {
	int u, v;
	while (~scanf("%d%d", &n, &m)) {
		init();

		for (int i = 1; i <= n; i++)
			scanf("%d", W + i);

		for (int i = 0; i < n - 1; i++) {
			scanf("%d%d", &u, &v);
			add(u, v);
		}

		dfs(1, 0);
		
		printf("%d\n", dp[1][m][1]);
	}	
	return 0;
}

抱歉!评论已关闭.