现在的位置: 首页 > 综合 > 正文

【树dp】 缩地(d步内最大权值)

2018年04月13日 ⁄ 综合 ⁄ 共 5573字 ⁄ 字号 评论关闭

这个题目的递推方程有两种:

 第一种:

  •    V[0][root][0]=value[root]  //0步内,权值为当前节点value[root]
  •    V[1][root][0]=value[root]  //0步内,权值为当前节点value[root]
  •    for each son:
    • for len in V[0/1][root][len]:

      • V[1][root][len] => tmp[1][root][len]   //不经过son
      • V[0][root][len] + V[1][son][k-dist]  => tmp[1][root][len+k]
      • V[1][root][len] + V[0][son][k-2*dist] => tmp[1][root][len+k]
    • tmp[1][root]=>unique=>V[1][root]
    • for  len in V[0][root][len]: 

      • V[0][root][len] => tmp[0][root][len]   //不经过son
      • V[0][root][len]+V[0][son][k-2*dist] => tmp[0][root][len+k] //经过son
    • tmp[0][root]=>unique=>V[0][root]
注意问题:
   1  修正dp函数原型,防止父子不分(添加fa父节点编号到dp): dp(tree,value,maxV,fa,idx)
#include <iostream>
#include <vector>
#include <map>
#include <algorithm>
using namespace std;
   
typedef pair<int,int> Pair;
bool cmp(const Pair &x,const Pair &y){//if x is before y?
    if(x.first!=y.first) return x.first<y.first;
    return x.second>=y.second;
}
void myUnique(vector<Pair> &p){
    vector<Pair> tmp;
    int lst=-1,n=p.size();
    for(int i=0;i<n;i++){
        if(p[i].second<=lst) continue;
        else {
            lst=p[i].second;//@error: remove pair<x1,y1> where y1<=y0 with <x0,y0> exists
            tmp.push_back(p[i]);
        }
    }
    /*for(int i=0;i<tmp.size();i++){
      Pair &p=tmp[i];
      cout<<p.first<<":"<<p.second<<" ";
    }cout<<endl;
    */
    p=tmp;
}
int dp(vector<vector<Pair> > &tree,vector<int> &value, 
    vector<vector<vector<Pair> > > &maxV, int fa, int idx, int d){//@error:fa should be added into para-list
    for(int i=0;i<tree[idx].size();i++){
        int v=tree[idx][i].first;
        if(v==fa) continue;
        dp(tree,value,maxV,idx,v,d);
    }
    maxV[0][idx].push_back(Pair(0,value[idx]));
    maxV[1][idx].push_back(Pair(0,value[idx])); //@error
    for(int i=0;i<tree[idx].size();i++){
        vector<Pair> tmp=maxV[1][idx];
        int v=tree[idx][i].first,dist=tree[idx][i].second;
        if(v==fa) continue;
        
        for(int j=0;j<maxV[1][idx].size();j++){//@error: i<maxV...
            Pair &p=maxV[1][idx][j];
            for(int k=0;k<maxV[0][v].size();k++){
                Pair &q=maxV[0][v][k];
                if(p.first+q.first+2*dist>d) break;
                tmp.push_back(Pair(p.first+q.first+2*dist,
                    p.second+q.second));
            }
        }
        for(int j=0;j<maxV[0][idx].size();j++){
            Pair &p=maxV[0][idx][j];
            for(int k=0;k<maxV[1][v].size();k++){
               Pair &q=maxV[1][v][k];
               if(p.first+q.first+dist>d) break;
               tmp.push_back(Pair(p.first+q.first+dist,
                    p.second+q.second));
            }
        }
        sort(tmp.begin(),tmp.end());
        myUnique(tmp);
        maxV[1][idx]=tmp;
    //}//@error
    //for(int i=0;i<tree[idx].size();i++){//@error
        tmp=maxV[0][idx];
        for(int j=0;j<maxV[0][idx].size();j++){
            Pair &p=maxV[0][idx][j];
            for(int k=0;k<maxV[0][v].size();k++){
               Pair &q=maxV[0][v][k];
               if(p.first+q.first+2*dist>d) break;
               tmp.push_back(Pair(p.first+q.first+2*dist,
                    p.second+q.second));
            }
        }
        sort(tmp.begin(),tmp.end());
        myUnique(tmp);
        maxV[0][idx]=tmp;
    }
}

int main(){
    int n;cin>>n;
    vector<int> v(n,0);
    vector<vector<Pair> > tree(n,vector<Pair>());
    for(int i=0;i<n;i++) cin>>v[i];
    for(int i=0;i<n-1;i++){
        int a,b,w;cin>>a>>b>>w;
        a--,b--;
        tree[a].push_back(Pair(b,w));
        tree[b].push_back(Pair(a,w));
    }
    int m;cin>>m;
  
    vector<int> q(m,0); int mm=0;
    for(int i=0;i<m;i++){
        cin>>q[i];mm=max(mm,q[i]);
    }
    vector<vector<vector<Pair> > > maxV(2,vector<vector<Pair> >  (n,vector<Pair>()));
    dp(tree,v,maxV,-1,0,mm);

    for(int i=0;i<m;i++){
        vector<Pair> &s=maxV[1][0];//error:maxV[1][0] not maxV[1][n-1]
        vector<Pair>::iterator it=upper_bound(s.begin(),s.end(),
            Pair(q[i],-1),cmp);//@error: Pair(qi,-1) rather than pair(qi,0)
        it--;
       //cout<<q[i]<<":"<<endl;
       int res=it->second;
       cout<<res<<endl;//max value at d=p[i]
    }
    return 0;
}

第二种递推方法:

  • num[0][root][value]=0 //到达权值value只需要0步
  • num[1][root][value]=0 
  • for each son:
  •       for value=totalValue(current) to 0:
  • for s=1 to value-1: //value-s in [1,totalValue(son) ]
  • if value-s > totalValue(son) : continue
  • cmin( num[1][root][value], num[1][root][s]+num[0][son][value-s]+2*dist )
  • cmin( num[1][root][value], num[0][root][s]+num[1][son][value-s]+1*dist )
  • cmin( num[0][root][value], num[0][root][s]+num[0][son][value-s]+2*dist )
缺点: num[][root][value]的value下标存在冗余
#include<stdio.h>
#include<algorithm>
#include<string.h>
#include<string>
#include<set>
#include<vector>
#include<map>
using namespace std;
#define AA first
#define BB second
#define MP make_pair
#define PB push_back
#define cmax(x,y) x=max(x,y)
#define cmin(x,y) x=min(x,y)
typedef long long LL;
typedef pair<int,int> PII;
const LL MOD=1000000007;

const int N=105;

int n;
LL dp[N][2*N][2];
int top[N],vv[N];
struct edgenode
{
    int u,v,c,next;
}edge[N<<1|1];
int last[N],tot;
int ans[1000010];

void init()
{
    tot=0;
    memset(last,-1,sizeof last);
}
void addedge(int u,int v,int c)
{
    edge[tot].u=u;
    edge[tot].v=v;
    edge[tot].c=c;
    edge[tot].next=last[u];
    last[u]=tot++;
}

int getdp(int u,int fa)
{
    int i,j,k,v;
    for(i=0;i<=n*2;i++)
        dp[u][i][0]=dp[u][i][1]=MOD;
    dp[u][0][0]=0;
    dp[u][0][1]=0;
    dp[u][vv[u]][0]=0;
    dp[u][vv[u]][1]=0;
    top[u]=vv[u];
    for(j=last[u];j!=-1;j=edge[j].next)
    {
        v=edge[j].v;
        if(v==fa)continue;
        getdp(v,u);
        for(i=top[u]+top[v];i>=0;i--)
        {
            for(k=0;k<=top[v] && k<=i;k++)
            {
                if(i-k>top[u])continue;
                cmin(dp[u][i][0],dp[u][i-k][0]+2*edge[j].c+dp[v][k][0]);
                cmin(dp[u][i][1],dp[u][i-k][0]+edge[j].c+dp[v][k][1]);
                cmin(dp[u][i][1],dp[u][i-k][1]+2*edge[j].c+dp[v][k][0]);
            }
        }
        top[u]+=top[v];
    }
//    printf("u=%d\n",u);
//    for(i=0;i<=top[u];i++)
//    {
//        if(dp[u][i][0]>=MOD)printf("%3d:",-1);
//        else printf("%3d:",dp[u][i][0]);
//        if(dp[u][i][1]>=MOD)printf("%-3d\t",-1);
//        else printf("%-3d\t",dp[u][i][1]);
//    }
//    printf("\n");
}

int main()
{
//    freopen("A.txt","r",stdin);
//    freopen("Amy.txt","w",stdout);
    int i,j,q;
    scanf("%d",&n);
    for(i=1;i<=n;i++)
        scanf("%d",&vv[i]);
    init();
    for(i=1;i<n;i++)
    {
        int s,t,c;
        scanf("%d%d%d",&s,&t,&c);
        addedge(s,t,c);
        addedge(t,s,c);
    }
    getdp(1,-1);
    for(i=1;i<=top[1];i++)
        cmin(dp[1][i][1],dp[1][i][0]);
    for(i=0;i<=1000000;i++)
        ans[i]=0;
    ans[0]=vv[1];
    for(i=0;i<=top[1];i++)
        if(dp[1][i][1]<=1000000)
        {
            cmax(ans[dp[1][i][1]],i);
        }
    for(i=1;i<=1000000;i++)
        cmax(ans[i],ans[i-1]);
    scanf("%d",&q);
    for(j=1;j<=q;j++)
    {
        int dis;
//        int ans=0;
        scanf("%d",&dis);
        printf("%d\n",ans[dis]);
//        for(i=1;i<=top[1];i++)///
//        {
//            int tp=min(dp[1][i][0],dp[1][i][1]);
//            if(tp<=dis)ans=i;
//        }
//        printf("%d\n",ans);
    }
    return 0;
}
/*
3
0 1 1
1 2 5
1 3 3
3
3
10
11

7
2 1 1 1 1 2 0
1 2 1
1 3 1
1 4 1
1 5 1
1 6 1
1 7 1
*/

抱歉!评论已关闭.