现在的位置: 首页 > 综合 > 正文

【FFT-快速傅立叶变换】 FFT乘法模板

2017年09月30日 ⁄ 综合 ⁄ 共 8400字 ⁄ 字号 评论关闭

【FFT-快速傅立叶变换】

HDU-1402  A
* B Problem Plus

题意:给定两个整数,整数的长度最多能达到50000位,输出两个整数的乘积。

分析:题意非常的明了,一个惊世骇俗的想法是使用两个数组将整数保留起来,然后模拟我们平常手算时的乘法,不过这样一来时间复杂度将是O(N^2),由于N过大,因此此题因寻求更加快速的解法。

  对于任何一个N位的整数都可以看作是An*10^(n-1) + An-1*10^(n-2) + ... + A2*10^2 + A1*10 + A0。如果把10看作是一个自变量,那么任何一个整数就可以视作为一个多项式,两个整数相乘也便可以看作是两个多项式相乘。对于一个多项式,我们平时所接触到的多是其系数表示法,普通的相乘也就是建立在两个整数均采用系数表示法的基础上进行的。那么要使得计算多项式相乘的复杂度下降的另一种方式就是寻找一种新的表示多项式的方法......

  若一个多项式的最高阶位为N-1,那么取N个点对(xi, yi)就能够唯一确定这个多项式,可以想象成有N个系数需要N个方程去求解。那么在此就可以寻找点对来表示一个多项式,对于一个大的数,看作多项式之后,那么舍弃掉原来以10为自变量的取值,而选取其他值,再通过计算多项式An*xi^(n-1) + An-1*xi^(n-2) + ... + A2*xi^2 + A1*xi + A0来保存这个多项式的信息。需要选取N个xi形成N对(xi, yi)方可唯一确定原来各个项前的系数,通过选取1的N次单位复根即可,并且利用单位复根的性质,可以使得计算量下降。

  通过点值法表示多项式后,计算乘法也就是O(N)的时间了,由于两个数相乘使得项数变多,因此需要在之前尽可能多取点。FFT算法能够在O(NlogN)时间内将系数法转化为点值法,相乘后再有点值法转为系数法,该题就是使用的这个方法。

  顺便说下一FFT过程中,计算叶子DTF时采用的二进制平摊反转置换,其作用是为了避免算法的递归而实现自底向上的计算方式。回顾一下在计算原串DFT的时候,假设离散点数为0-7,那么有以下过程:

(0 1 2 3 4 5 7) = (0 2 4 6) + (1 3 5 7)                           1

(0 2 4 6) = (0 4) + (2 6)                                               2

(1 3 5 7) = (1 3) + (5 7)                                               2

(0 4) = (0) + (4)                                                          3

(2 6) = (2) + (6)                                                          3

(1 3) = (1) + (3)                                                          3

(5 7) = (5) + (7)                                                          3

分析这些分组的二进制位会发现,第1次分组是根据第0位是否为1来划分的,即奇偶性;第2次分组是根据第1位是否为1来划分的;第三次分组是根据第2位是否为1来划分的。这个特性与与一般的按照大小划分的数很类似(首先按照最高为是否为1划分,然后是次高位...),因此就可以通过一个算法来使得00...00 - 11...11这样递增的序列中的每一个数实现高位和低位的翻转,二进制平摊反转置换就是用于达到这个目的。

算法从1开始到N-2(FFT算法要求N必须是2的幂,保证每次折半之后不会出现奇数),因此0和N-1翻转后还是本身,接着维护好一个下标 j ,这个数就是与递增中的第 i 个数翻转之后对应的数,初始化 j 的下标示 N/2,这个数要是从后往前来定义二进制数的话,就会是1,例如若N=8,那么4的二进制位为100,假定只有3位2进制位组成2进制数,从右往左看,其值为1。接下来就是要找 j 的下一个数了,这个数从右往左看应该是2才能够满足要求,于是从右往左寻找,遇到1变为0,知道遇到0就跳出,并且将该位赋值为1,这个伟大的过程的作用仅仅只是给在右往左定义的二进制数
j 加了一个1。

复制代码
#include <cstdlib>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <iostream>
#include <algorithm>
using namespace std;

struct C {
    double r, i;
    C() {}
    C(double _r, double _i) : r(_r), i(_i) {}
    inline C operator + (const C & a) const {
        return C(r + a.r, i + a.i);
    }
    inline C operator - (const C & a) const {
        return C(r - a.r, i - a.i);
    }
    inline C operator * (const C & a) const {
        return C(r*a.r - i*a.i, r*a.i + i*a.r);
    }
};

typedef long long LL;
const double pi = acos(-1.0);
const int N = 50005;
C a[N<<2], b[N<<2];
char num1[N], num2[N];
LL ret[N<<2];

void brc(C *y, int L) {
    int i, j, k;
    for (i=1,j=L>>1; i<L-1; ++i) { // 二进制平摊反转置换 O(NlogN)
        if (i < j) swap(y[i], y[j]);
        k = L>>1;
        while (j >= k) {
            j -= k;
            k >>= 1;
        }
        j += k;
    }
}

void FFT(C *y, int L, int dir) {
    brc(y, L);
    for (int h = 2; h <= L; h <<= 1) { // 枚举所需计算的点数 
        C wn(cos(dir*2*pi/h), sin(dir*2*pi/h)); // h次单位复根 
        for (int j = 0; j < L; j += h) { // 原序列被分成了L/h段h长序列 
            C w(1, 0); // 旋转因子 
            for (int k = j; k < j+h/2; ++k) { // 因为折半定理,只需要计算枚举一半的长度即可 
                C u = y[k];
                C t = w*y[k+h/2];
                y[k] = u + t;
                y[k+h/2] = u - t;
                w = w * wn; // 更新旋转因子 
            }
        }
    }
    if (dir == 1) {
        for (int i = 0; i < L; ++i) {
            y[i] = y[i] * C(1.0/L, 0);
        }
    }
}

int main() {
    while (scanf("%s %s", num1, num2) != EOF) {
        memset(ret, 0, sizeof (ret));
        int len1 = strlen(num1), len2 = strlen(num2);
        int ML = len1+len2-1, L = 1;
        while (L < ML) L <<= 1;
        for (int i = len1-1, j = 0; i >= 0; --i, ++j) {
            a[j] = C(num1[i]-'0', 0);
        }
        for (int i = len2-1, j = 0; i >= 0; --i, ++j) {
            b[j] = C(num2[i]-'0', 0);
        }
        for (int i = len1; i < L; ++i) a[i] = C(0, 0);
        for (int i = len2; i < L; ++i) b[i] = C(0, 0);
        FFT(a, L, -1), FFT(b, L, -1);
        for (int i = 0; i < L; ++i) {
            a[i] = a[i] * b[i];
        }
        FFT(a, L, 1);
        for (int i = 0; i < L; ++i) {
            ret[i] = (LL)floor(a[i].r + 0.5);
        }
        for (int i = 0; i < L; ++i) {
            ret[i+1] += ret[i] / 10;
            ret[i] %= 10;
        }
        int p = L;
        while (!ret[p] && p) --p;
        while (p >= 0) printf("%d", (int)ret[p--]);
        puts("");
    }
    return 0;
} 
复制代码

 

HDU-4609 3-idiots

题意:有N条线段,问从这N条线段中选出三条能过组成三角形的概率为多大?

分析:三条边能够组成三角形则满足等式x+y < z。首先将所有的边排序,然后将每条边的长度构成多项式的指数,同一长度的边的数量为系数。然后将这个多项式自己和自己作一个乘法,这里需要使用FFT来实现,去掉自己与自己的组合已经相互的组合情况,就能够得到两两之间组合形成边长和值为某一个值得方案数。使用这个方案数除以总方案数即可。

复制代码
#include <cstdlib>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;

const int N = 300005;

typedef long long LL;
struct C {
    double r, i;
    C() {}
    C(double _r, double _i) : r(_r), i(_i) {}
    inline C operator + (const C &a) const {
        return C(r + a.r, i + a.i);
    }
    inline C operator - (const C &a) const {
        return C(r - a.r, i - a.i);
    }
    inline C operator * (const C &a) const {
        return C(r*a.r-i*a.i, r*a.i+i*a.r);
    }
}a[N], b[N];

const double pi = acos(-1.0);
int n, num[N], cnt[N];
LL res[N], sum[N];

void brc(C *y, int l) {
    int i, j, k;
    for (i=1,j=l>>1; i<l-1; ++i) {
        if (i < j) swap(y[i], y[j]);
        k = l>>1;
        while (j >= k) {
            j -= k;
            k >>= 1;
        }
        j += k;
    }
}

void FFT(C *y, int l, int on) {
    int h, i, j, k;
    C u, t;
    brc(y, l); // 得到一个自底向上的序列 
    for (h = 2; h <= l; h <<= 1) { // 控制一个O(logn)的外层复杂度 
        C wn(cos(on*2*pi/h), sin(on*2*pi/h));
        for (j = 0; j < l; j+=h) { // 两个for循环共组成O(n)的复杂度 
            C w(1, 0);
            for (k = j; k <j+h/2; ++k) {
                u = y[k];
                t = w*y[k+h/2];
                y[k] = u+t;
                y[k+h/2] = u-t;
                w = w*wn;
            }
        }
    }
    if (on == 1) {
        for (i = 0; i < l; ++i) {
            y[i]= y[i] * C(1.0/l, 0.0);
        }
    }
}

int main() {
    int T;
    scanf("%d", &T);
    while (T--) {
        scanf("%d", &n);
        int Max = 0;
        memset(cnt, 0, sizeof (cnt));
        for (int i = 0; i < n; ++i) {
            scanf("%d", &num[i]);
            Max = max(Max, num[i]);
            ++cnt[num[i]];
        }
        int L = 1;
        ++Max;
        while (L < (Max<<1)) L <<= 1;
        for (int i = 0; i < Max; ++i) {
            a[i] = C(cnt[i], 0);
        }
        for (int i = Max; i < L; ++i) {
            a[i] = C(0, 0);
        }
        FFT(a, L, -1);
        for (int i = 0; i < L; ++i) {
            a[i] = a[i] * a[i];
        }
        FFT(a, L, 1);
        for (int i = 0; i < L; ++i) {
            res[i] = (LL)floor(a[i].r + 0.5);
        }
        for (int i = 0; i < Max; ++i) {
            res[i<<1] -= cnt[i];
        }
        for (int i = 0; i < L; ++i) {
            res[i] >>= 1;
        }
        for (int i = 1; i < L; ++i) {
            sum[i] = sum[i-1] + res[i];
        }
        double ret = 0, den = 1.0*n*(n-1)*(n-2)/6.0;
        for (int i = 0; i < n; ++i) {
            ret += sum[num[i]] / den;
        }
        printf("%.7f\n", 1-ret);
    }
    return 0;
}
复制代码

FFT乘法模板

分类: FFT乘法 18人阅读 评论(0) 收藏 举报

思路:

   算法导论第30章有详细说明。此处只是简略说明其主要的步骤。

一个知识点是:

  A(x)=a0+a1x+a2x2+a3x3+……+an-1xn-1

 A[0](x)=a0+a2x+a4x2+……+an-2xn/2-1

 A[1](x)=a1+a3x+a5x2+……+an-1xn/2-1

 

 A[0](x2)+x*A[1](x2)=A(x)  

以上是 二进制平摊反转置换跟求和的主要式子。

多项式有两种表示形式:点值表示,系数表示。

快速FFT主要有以下四点:

   1. 使次数界(上界)增加一倍。A(x)、B(x)的长度扩充到2*n

   2. 求值。主要是求点值表示A(x)、B(x)的点值表示

   3. 点乘。C(x)=A(x)*B(x)

   4. 插值。对C(x)进行插值,求出其系数表示。

  1. #include <iostream>  
  2. #include <cstdio>  
  3. #include <cmath>  
  4. #include <cstring>  
  5. #define Pi acos(-1.0)//定义Pi的值  
  6. #define N 200000  
  7. using namespace std;  
  8. struct complex //定义复数结构体  
  9. {  
  10.     double re, im;  
  11.     complex ( double r = 0.0, double i = 0.0 )  
  12.     {  re = r, im = i; } //初始化  
  13.     //定义三种运算  
  14.     complex operator + ( complex o )  
  15.     { return complex ( re + o.re, im + o.im );}  
  16.     complex operator - ( complex o )  
  17.     { return complex ( re - o.re, im - o.im );}  
  18.     complex operator * ( complex o )  
  19.     { return complex ( re * o.re - im * o.im, re * o.im + im * o.re );}  
  20. } x1[N], x2[N];  
  21. char a[N / 2], b[N / 2];  
  22. int sum[N];    //存储最后的结果  
  23.   
  24. void BRC ( complex *y, int len ) //二进制反转倒置  
  25. {  
  26.     int i, j, k;  
  27.     for ( i = 1, j = len / 2; i < len - 1; i++ )  
  28.     {  
  29.         if ( i < j ) { swap ( y[i], y[j] ); } //i<j保证只交换一次  
  30.         k = len / 2;  
  31.         while ( j >= k )  
  32.         {  
  33.             j -= k; k = k / 2;  
  34.         }  
  35.         if ( j < k ) { j += k; }  
  36.     }  
  37. }  
  38. void FFT ( complex *y, int len , double on ) //on=1表示顺,-1表示逆  
  39. {  
  40.     int i, j, k, h;  
  41.     complex u, t;  
  42.     BRC ( y, len );  
  43.     for ( h = 2; h <= len; h <<= 1 ) //控制层数  
  44.     {  
  45.         //初始化单位复根  
  46.         complex wn ( cos ( on * 2 * Pi / h ), sin ( on * 2 * Pi / h ) );  
  47.         for ( j = 0; j < len; j += h ) //控制起始下标  
  48.         {  
  49.             //初始化螺旋因子  
  50.             complex w ( 1, 0 );  
  51.             for ( k = j; k < j + h / 2; k++ )  
  52.             {  
  53.                 u = y[k];  
  54.                 t = w * y[k + h / 2];  
  55.                 y[k] = u + t;  
  56.                 y[k + h / 2] = u - t;  
  57.                 w = w * wn; //更新螺旋因子  
  58.             }  
  59.         }  
  60.     }  
  61.     if ( on == -1 )  
  62.         for ( i = 0; i < len; i++ ) //逆FFT(IDFT)  
  63.         {  
  64.             y[i].re /= len;  
  65.         }  
  66.   
  67. }  
  68. int main()  
  69. {  
  70.     int len1, len2, len, i;  
  71.     while ( scanf ( "%s%s", a, b ) != EOF )  
  72.     {  
  73.         len1 = strlen ( a );  
  74.         len2 = strlen ( b );  
  75.         len = 1;  
  76. //扩充次数界至2*n  
  77.         while ( len < 2 * len1 || len < 2 * len2 ) { len <<= 1; } //右移相当于len=len*2  
  78. //倒置存储  
  79.         for ( i = 0; i < len1; i++ )  
  80.         { x1[i].re = a[len1 - 1 - i] - '0'; x1[i].im = 0.0;}  
  81.         for ( ; i < len1; i++ ) //多余次数界初始化为0  
  82.         {x1[i].re = x1[i].im = 0.0;}  
  83.         for ( i = 0; i < len2; i++ )  
  84.         { x2[i].re = b[len2 - 1 - i] - '0'; x2[i].im = 0.0;}  
  85.         for ( ; i < len2; i++ ) //多余次数界初始化为0  
  86.         {x2[i].re = x2[i].im = 0.0;}  
  87. //FFT求值  
  88.         FFT ( x1, len, 1 ); //FFT(a) 1表示顺 -1表示逆  
  89.         FFT ( x2, len, 1 ); //FFT(b)  
  90. //点乘,结果存入x1  
  91.         for ( i = 0; i < len; i++ )  
  92.         {  
  93.             x1[i] = x1[i] * x2[i];  
  94.         }  
  95. //插值,逆FFT(IDTF)  
  96.         FFT ( x1, len, -1 );  
  97.   
  98. //细节处理  
  99.         for ( i = 0; i < len; i++ )  
  100.         {  
  101.             sum[i] = x1[i].re + 0.5;    //四舍五入  
  102.         }  
  103.         for ( i = 0; i < len; i++ ) //进位  
  104.         {  
  105.             sum[i + 1] += sum[i] / 10;  
  106.             sum[i] %= 10;  
  107.         }  
  108. //输出  
  109.         len = len1 + len2 - 1;  
  110.         while ( sum[len] <= 0 && len > 0 ) { len--; } //检索最高位  
  111.         for ( i = len; i >= 0; i-- ) //倒序输出  
  112.         {  
  113.             cout << sum[i];  
  114.         }  
  115.         cout << endl;  
  116.     }  
  117.     return 0;  
  118. }
     

抱歉!评论已关闭.