现在的位置: 首页 > 综合 > 正文

HDU 3053 Group Travel

2013年09月07日 ⁄ 综合 ⁄ 共 1905字 ⁄ 字号 评论关闭

基础的四边形不等式优化DP,原本的方程为:
m[i][j]=min{m[i-1][l]+w[l+1][j};  (i-1<=l<=j-1).
其中,m[i][j]表示前j个人用i个下车点的最优解,w[i][j]表示i到j共用1个下车点的最小距离和(此时显然是在(i+j)/2下车最好)。
因为l要枚举从i-1到j-1复杂度达到了O(k*n*n),n和k的上限都是3000,肯定超时。

我们可以用s[i][j]记下每次m[i][j]取的断点k,这个可以用来构造四边形不等式对于s[i][j]的单调性进行讨论。
首先对于i<=i'<=j<=j',w有w[i][j]+w[i'][j']<=w[i'][j]+w[i][j']。 -------①
(我这里暂时略去这一步证明 ,可以用数学归纳法证明。)
那么我假设s[i][j]<=s[i][j+1]。
用mk[i][j]表示去k为端点所产生的状态值(不一定是最优那个)。
证明这一个不等式,只需证明对于i<=k<k'<j并且有mk[i][j]>=mk'[i][j]时,有mk[i][j+1]>=mk'[i][j+1]能够成立(数学归纳法)。
首先我们根据k+1<k'+1<=j<j+1有四边形不等式:
w[k+1][j]+w[k'+1][j+1]<=w[k'+1][j]+w[k+1][j+1](带入①即可)

这时两边同时加上m[i-1][k]+m[i-1][k']即可得到:
m[i-1][k]+w[k+1][j]+m[i-1][k']+w[k'+1][j+1]<=m[i-1][k']+w[k'+1][j]+m[i-1][k]+w[k+1][j+1]
也就是mk[i][j]+mk'[i][j+1]<=mk'[i][j]+mk[i][j+1]
移位后mk[i][j]-mk'[i][j]<=mk[i][j+1]-mk'[i][j+1]
因为mk[i][j]-mk'[i][j]>=0
所以mk[i][j+1]-mk'[i][j+1]>=0.
那么得证。

同样的,对于s[i-1][j]<=s[i][j]能类似地证明。(这里不赘述了)

那么我们有结论s[i-1][j]<=s[i][j]<=s[i][j+1],原状态转移方程可改进为:
m[i][j]=min{m[i-1][l]+w[l+1][j};  (s[i-1][j]<=l<=s[i][j+1]).复杂度O(n^2)(这个复杂度的证明比较简单,就不赘述了)
如果顺推则每次没有s[i][j+1]的值,这个时候,我们只要设出s[i][n+1]的初值逆推就好了。

这一题卡了内存,如果全开MAX*MAX的话会超内存。

考虑到m[i][j]只与m[i-1][j]有关,w[i][j]可推出简单公式,并且s[i][j]也只和s[i-1][j]和s[i][j+1]有关,可以使用滚动数组。
至此,题目完美解决。

附上渣代码:

#include <cstdio>
#include <algorithm>
#define MAX 3001
#define INF 0x3fffffff
using namespace std;
int n,k;
int d[MAX],sum[MAX];
int s[2][MAX],m[2][MAX];
int to(int i,int j)
{
    int l=(i+j)/2;
    return (l*2-i-j)*d[l]+sum[i-1]+sum[j]-sum[l-1]-sum[l];
}
int main()
{
    int cas;
    scanf("%d",&cas);
    while(cas--){
        scanf("%d%d",&n,&k);
        int i,j,l;
        for(i=1;i<=n;i++){
            scanf("%d",d+i);
        }
        sort(d+1,d+n+1);
        d[0]=0;
        sum[0]=0;
        for(i=1;i<=n;i++){
            sum[i]=sum[i-1]+d[i];
        }
        for(i=1;i<=n;i++){
            l=(i+1)>>1;
            m[0][i]=(l*2-i-1)*d[l]+sum[i]-sum[l-1]-sum[l];
            s[0][i]=1;
        }
        int t=0;
        for(i=2;i<=k;i++){
            m[t^1][i]=0;
            s[t^1][n+1]=n-1;
            for(j=n;j>=i+1;j--){
                m[t^1][j]=INF;
                for(int v=s[t][j];v<=s[t^1][j+1];v++){
                    int x=to(v+1,j);
                    if(m[t][v]+x<m[t^1][j]){
                        m[t^1][j]=m[t][v]+x;
                        s[t^1][j]=v;
                    }
                }
            }
            t^=1;
        }
        printf("%d\n",m[t][n]);
    }
    return 0;
}

抱歉!评论已关闭.