现在的位置: 首页 > 综合 > 正文

传说中的斜率优化

2013年10月12日 ⁄ 综合 ⁄ 共 4761字 ⁄ 字号 评论关闭

对于形如f[i]由其之前若干f[j]决定的dp来说,优化大致有单调队列,线段树,四边形不等式和斜率优化
单调队列一般是累加关系,可以提前直接预先计算后面的dp值,从而决策。而四边形不等式则多用于第i个目标支配k到j这一段。至于乘除关系的则可以使用斜率优化。

当我们想知道两个决策点哪一个更优时,可以将两个决策点所产生的决策表达式写出来。设为f[j1]和f[j2]。令j1<j2,则如果f[j1]<f[j2],那么可以得到关于j1,j2和i的表达式。

由于对于f[i]的最优解为k时,所有j<k满足slope[j,k]<=s[i],又因为s[i+1]为递增的,所以slope[j,k]<s[i+1]。因此f[i+1]的决策k1>=k。我们可以看出这类dp的决策序列一般是递增的。

那么我们看维护的决策序列。f[i]在决策之时选取了j1,则slope[j1,j2]>s[i].之后对f[i+1]进行决策时,最优决策点向后移动,那么就有slope[j1,j2]<=s[i+1].但是对于f[i+1]的最优取值集合来说,slope[j2,j3]>s[i+1]。显然,slope[j2,j3]>slope[j1,j2],因此从这个角度也可以说明应该维护斜率递增序列。

可用于斜率优化的表达式整理出后左边必须形如(g1[j1]-g1[j2])/(g2[j1]-g2[j2])的形式。而且左边必须可以不含带i的项(否则无法维护队列)。由于式子本身具有单调性,分母一般为负。如果为正,则后面维护队列的时候取值标准也应该刚好相反。

由于我们想要找到(g1[j1]-g1[j2])/(g2[j1]-g2[j2])>S(i)的项,故维护队列的时候满足斜率递增的顺序。前面的项如果(g1[j1]-g1[j2])/(g2[j1]-g2[j2])<=S(i)则离开队列,因为这表示j2比j1优。而之所以维护斜率上升,是为了找到一个突变点,在此点处,(g1[j1]-g1[j2])/(g2[j1]-g2[j2])>S(i)。说明j1优于j2。这时候j1为抛物线的最低点,也就是最优值的决策点。

斜率优化用的时候要自己推一下表达式,比傻瓜式地直接套用数据结构优化要好些。呵呵。

下面有两题,一个是HNOI2008的题,这个是裸的,比较基础。另一个是pku3709,北大月赛的题目,维护队列的时候有个小技巧。

因为在斜率还没有超过S(i)的时候,最优决策值一定是取最后面的一个。而pku3709在维护队列的时候,不能算完了直接加进去。因为对于i+1来说,i是不能拿来更新的,而i的加入则可能会将本来能用来决策的j给去掉。

所以说,先不要放进去,需要它的时候再放。什么时候需要呢?就是在k个位置之后了。所以说计算i之前先把i-k放进队列,再进行相关计算和维护。

/*HNOI2008Toy*/

#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int maxn = 50010;
int n;
long long l;
long long f[maxn],c[maxn];
int q[maxn];
inline bool slide(int x,int y,int z)
{
     long long sx = f[x]+(x+c[x])*(x+c[x]);
     long long sy = f[y]+(y+c[y])*(y+c[y]);
     long long sz = f[z]+(z+c[z])*(z+c[z]);
     if((sx-sy)*(y+c[y]-z-c[z])<(sy-sz)*(x+c[x]-y-c[y]))return true;
     return false;
}
int main()
{
   int i,j;
  // freopen("test2.txt","r",stdin);
  // freopen("out1.txt","w",stdout);
   while(scanf("%d%I64d",&n,&l)!=EOF)
   {
      c[0] = 0;
      for(i = 1;i<=n;i++)
        scanf("%I64d",&c[i]),c[i]+=c[i-1];
      int head = 0,tail = 0;
      q[tail++] = 0;      //0这个决策一定要放进去!!
      f[0] = 0;
      f[1] = (c[1]-l)*(c[1]-l);
      q[tail++] = 1;
      for(i = 2;i<=n;i++)
      {
         long long T = i+c[i]-l-1;
         while(1)
         {
            if(head+1>=tail)break;
            long long s = f[q[head]]+(q[head]+c[q[head]])*(q[head]+c[q[head]]);
            long long t = f[q[head+1]]+(q[head+1]+c[q[head+1]])*(q[head+1]+c[q[head+1]]);
            long long x = q[head]+c[q[head]]-q[head+1]-c[q[head+1]];
            if(s-t<2*T*x)break;
            head++;
         }
         if(head+1<tail)
         {
            f[i] = f[q[head]]+(i-q[head]-1+c[i]-c[q[head]]-l)*(i-q[head]-1+c[i]-c[q[head]]-l);
            while(head+1<tail&&!slide(q[tail-2],q[tail-1],i))tail--;
            q[tail++] = i;
         }
         else
         {
            if(f[q[head]]+(i-q[head]-1+c[i]-c[q[head]]-l)*(i-q[head]-1+c[i]-c[q[head]]-l)<=(c[i]+i-1-l)*(c[i]+i-1-l))
              f[i] = f[q[head]]+(i-q[head]-1+c[i]-c[q[head]]-l)*(i-q[head]-1+c[i]-c[q[head]]-l);
            else
            {
               head++;
               f[i] = (c[i]+i-1-l)*(c[i]+i-1-l);
            }
            q[tail++] = i;
         }
      }
      printf("%I64d\n",f[n]);
   }
   return 0;
}

/*pku3709*/

#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int maxn = 500010;
int T,n,k;
long long f[maxn],a[maxn],sum[maxn];
int q[maxn];
inline bool slide(int x,int y,int z)
{
   long long xv = f[x]-sum[x]+x*a[x+1];
   long long yv = f[y]-sum[y]+y*a[y+1];
   long long zv = f[z]-sum[z]+z*a[z+1];
   if((xv-yv)*(a[y+1]-a[z+1])<(yv-zv)*(a[x+1]-a[y+1]))return true;
   return false;
}
int main()
{
    int i,j;
  //  freopen("test2.txt","r",stdin);
   // freopen("out1.txt","w",stdout);
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d%d",&n,&k);
        sum[0] = 0;
        for(i = 1;i<=n;i++)
         scanf("%I64d",&a[i]),sum[i] = sum[i-1]+a[i];
        int head = 0,tail = 0;
        f[0] = 0;
        q[tail++] = 0;
        for(i = k;i<2*k;i++)
          f[i] =  sum[i]-i*a[1];
        for(i = 2*k;i<=n;i++)
        {
           while(head+1<tail&&!slide(q[tail-2],q[tail-1],i-k))tail--;
           q[tail++] = i-k;
        //   printf("i=%d tail = %d q[tail] = %d\n",i,tail-1,q[tail-1]);
           while(1)
           {
              if(head+1>=tail)break;
              long long x = f[q[head]]-sum[q[head]]+q[head]*a[q[head]+1];
              long long y = f[q[head+1]]-sum[q[head+1]]+q[head+1]*a[q[head+1]+1];
              long long z = a[q[head]+1]-a[q[head+1]+1];
              if(x-y<i*z)break;
              head++;
           }
       //    printf("q[head]=%d\n",q[head+1]);
           if(head+1<tail)
           {
              f[i] = f[q[head]]+sum[i]-sum[q[head]]-(i-q[head])*a[q[head]+1];
           }
           else
           {
              if(f[q[head]]+sum[i]-sum[q[head]]-(i-q[head])*a[q[head]+1]<sum[i]-i*a[1])
                f[i] = f[q[head]]+sum[i]-sum[q[head]]-(i-q[head])*a[q[head]+1];
              else
              {
                head++;
                f[i] = sum[i]-i*a[1];
              }
           }
        }
   /*     for(i = 1;i<=n;i++)
         printf("f[%d] = %lld\n",i,f[i]);*/
        printf("%I64d\n",f[n]);
    }
    return 0;
}

转载自:http://blog.163.com/myq_952/blog/static/863906320112711750378/

抱歉!评论已关闭.