本文最后更新于:2020年8月24日 下午

一、写在前面

最近数字图像处理课正在学快速傅里叶变换,发现自己对此理解的还不是很到位。于是借此机会,对照着《算法导论》,对这部分内容啃一啃。

两个nn次多项式相加的最直接方法所需的时间是O(n)O(n),但是相乘的最直接方法所需的时间为O(n2)O(n^2)。用快速傅里叶变换(Fast Fourier Transform,FFT)可以使多项式相乘的时间复杂度降低为O(nlogn)O(nlogn)

需要的一些前置技能:复数、多项式、线性代数。

二、多项式

一个以xx为变量的多项式定义在一个代数域FF上,将函数A(x)A(x)表示为形式和:

A(x)=j=0n1ajxjA(x)=\sum_{j=0}^{n-1}a_jx^j

我们称a0,a1,,an1a_0,a_1,\dots,a_{n-1}为如上多项式的系数,所有系数都属于域FF,典型的情形是复数集合CC

如果一个多项式A(x)A(x)的最高次的非零系数是aka_k,则称A(x)A(x)次数kk,记degree(A)=kdegree(A)=k。任何严格大于一个多项式次数的整数都是该多项式的次数界,因此,对于次数界为nn的多项式,其次数可以是0n10\sim n-1之间的任何整数。

多项式加法

如果A(x)A(x)B(x)B(x)是次数界为nn的多项式,那么它们的和也是一个次数界为nn的多项式C(x)C(x),对所有属于定义域的xx,都有C(x)=A(x)+B(x)C(x)=A(x)+B(x)。也就是说,

A(x)=j=0n1ajxjA(x)=\sum_{j=0}^{n-1}a_jx^j

B(x)=j=0n1bjxjB(x)=\sum_{j=0}^{n-1}b_jx^j

C(x)=j=0n1cjxj(cj=aj+bj)C(x)=\sum_{j=0}^{n-1}c_jx^j(c_j=a_j+b_j)

例如,如果有多项式A(x)=6x3+7x210x+9A(x)=6x^3+7x^2-10x+9B(x)=2x3+4x5B(x)=-2x^3+4x-5,那么C(x)=4x3+7x26x+4C(x)=4x^3+7x^2-6x+4

多项式乘法

如果A(x)A(x)B(x)B(x)是次数界为nn的多项式,那么它们的乘积C(x)C(x)是一个次数界为2n12n-1的多项式C(x)C(x),对所有属于定义域的xx,都有C(x)=A(x)B(x)C(x)=A(x)B(x)。方法类似还是用上一个例子,那么得到

C(x)=12x614x5+44x420x375x2+86x45C(x)=-12x^6-14x^5+44x^4-20x^3-75x^2+86x-45

形式化的式子有

C(x)=j=02n2cjxjC(x)=\sum_{j=0}^{2n-2}c_jx^j

其中

cj=k=0jakbjkc_j=\sum_{k=0}^{j}a_{k}b_{j-k}

此时

degree(C)=degree(A)+degree(B)degree(C)=degree(A)+degree(B)

多项式的表示

从某种意义上,多项式的系数表达与点值表达式等价的。

系数表达

对一个次数界为nn的多项式A(x)=j=0n1ajxjA(x)=\sum_{j=0}^{n-1}a_jx^j而言,其系数表达是一个由系数组成的(列)向量a=(a0,a1,,an1)a=(a_0,a_1,\dots,a_{n-1})。对于多项式乘法,系数向量cc成为输入向量aabb的卷积,表示成c=abc=a\otimes b

点值表达

一个次数界为nn的多项式A(x)A(x)点值表达就是一个由nn个点值对组成的集合

{(x0,y0),(x1,y1),,(xn1,yn1)}\{(x_0,y_0),(x_1,y_1),\dots,(x_{n-1},y_{n-1})\}

使得对k=0,1,,n1k=0,1,\dots,n-1,所有xkx_k各不相同,且yk=A(xk)y_k=A(x_k)

一个多项式可以有很多不同的点值表达。如果采用的点都相同的话,用点值表达多项式做乘法只需O(n)O(n)的时间。

求值与插值

从一个多项式的系数表达转化为点值表达的过程是求值,其逆运算称为插值。

定理(插值多项式的唯一性):对于任意n个点值对组成的集合{(x0,y0),(x1,y1),,(xn1,yn1)}\{(x_0,y_0),(x_1,y_1),\dots,(x_{n-1},y_{n-1})\},其中所有的xkx_k都不同,那么存在唯一的次数界为n的多项式A(x)A(x),满足yk=A(xk)y_k=A(x_k)

证明列出矩阵方程,然后结合范德蒙德矩阵的性质。

简单的求值和插值(拉格朗日插值)的时间复杂度都是O(n2)O(n^2)的。

我们之后就要通过巧妙选取点来加速这两个过程,使其运行时间变为O(nlogn)O(nlogn)


三、单位复数根

nn次单位复数根是满足ωn=1\omega ^n=1的复数ω\omega

nn次单位复数根恰好有nn个:

ωn0,ωn1,,ωnn1\omega _{n}^{0},\omega _{n}^{1},\dots,\omega _{n}^{n-1}

其中主nn次单位复数根为

ωn=e2πi/n=cos(2π/n)+isin(2π/n)\omega _n=e^{2\pi i/n}=\cos(2\pi/n)+i\sin(2\pi/n)

其他nn次单位复数根都是ωn\omega _n的幂次。

消去引理: 对于任何整数n0,k0,d>0n\ge 0,k\ge 0,d>0,有ωdndk=ωnk\omega _{dn}^{dk}=\omega _{n}^{k}

推论: 对于任意偶数n>0n>0,有ωnn/2=ω2=1\omega _{n}^{n/2}=\omega _{2}=-1

折半引理: 如果n>0n>0为偶数,那么nnnn次单位复数根的平方的集合就是n/2n/2n/2n/2次单位复数根的集合

求和引理: 对任意整数n1n\geq 1和不能被nn整除的非负整数kk,有j=0n1(ωnk)j=0\sum_{j=0}^{n-1}(\omega _n^k)^j=0


四、快速傅里叶变换

DFT

现在我们希望计算次数界nn的多项式

A(x)=j=0n1ajxjA(x)=\sum_{j=0}^{n-1}a_jx^j

ωnk\omega_{n}^{k}处的值,记为yky_k

yk=A(ωnk)=j=0n1ajωnkjy_k=A(\omega_{n}^{k})=\sum_{j=0}^{n-1}a_j\omega_{n}^{kj}

向量y=(y0,y1,,yn1)y=(y_0,y_1,\dots,y_{n-1})就是系数向量a=(a0,a1,,an1)a=(a_0,a_1,\dots,a_{n-1})离散傅里叶变换(DFT),记为y=DFTn(a)y=DFT_n(a)

FFT

快速傅里叶变换(FFT) 利用复数单位根的特殊性质,可以在O(nlogn)O(nlogn)时间内计算出DFTn(a)DFT_n(a)。首先通篇假设nn恰好是22的整数幂。

FFT利用了分治策略,采用A(x)A(x)中偶数下标的系数与奇数下标的系数,分别定义两个新的次数界为n/2n/2的多项式A[0](x)A^{[0]}(x)A[1](x)A^{[1]}(x):

A[0](x)=a0+a2x+a4x2++an2xn/21A^{[0]}(x)=a_{0}+a_{2}x+a_{4}x^2+\dots+a_{n-2}x^{n/2-1}

A[1](x)=a1+a3x+a5x2++an1xn/21A^{[1]}(x)=a_{1}+a_{3}x+a_{5}x^2+\dots+a_{n-1}x^{n/2-1}

于是有

A(x)=A[0](x2)+xA[1](x2)A(x)=A^{[0]}(x^2)+xA^{[1]}(x^2)

所以,求A(x)A(x)ωn0,ωn1,,ωnn1\omega _{n}^{0},\omega _{n}^{1},\dots,\omega _{n}^{n-1}处的值转换为求次数界为n/2n/2的多项式A[0](x)A^{[0]}(x)A[1](x)A^{[1]}(x)在点(ωn0)2,(ωn1)2,,(ωnn1)2(\omega _{n}^{0})^2,(\omega _{n}^{1})^2,\dots,(\omega _{n}^{n-1})^2的值。可以发现其实是n/2n/2n/2n/2次单位复数根,且每个根恰好出现两次。

IDFT

将点值表达的多项式转换回系数表达,是相似的过程。

我们把DFT写成矩阵乘积y=Vnay=V_{n}a

其中VnV_{n}是一个范德蒙德矩阵,在(k,j)(k,j)处的元素为ωnkj\omega _{n}^{kj}

对于逆运算a=DFTn1(y)a=DFT_{n}^{-1}(y),我们把yy乘以VnV_{n}的逆矩阵来处理。

定理:j,k=0,1,,n1j,k=0,1,\dots,n-1Vn1V_{n}^{-1}(j,k)(j,k)元素为ωnkj/n\omega _{n}^{-kj}/n

证明Vn1Vn=InV_{n}^{-1}V_{n}=I_n时用求和引理即可,注意使用条件。

所以可以推导出DFTn1(y)DFT_{n}^{-1}(y)

aj=1nk=0n1ykωnkja_j=\frac{1}{n}\sum_{k=0}^{n-1}y_{k}\omega_n^{-kj}

可以看出只需将单位根取倒数,做一次FFT,最后将结果都除以 nn,就做完逆变换了。


五、代码实现

首先是手写复数类,也可以用 std::complex<T>

struct Complex
{
    double x,y;
    Complex(double x_=0,double y_=0)
    {
        x=x_;
        y=y_;
    }
    Complex operator -(const Complex &t)const
    {
        return Complex(x-t.x,y-t.y);
    }
    Complex operator +(const Complex &t)const
    {
        return Complex(x+t.x,y+t.y);
    }
    Complex operator *(const Complex &t)const
    {
        return Complex(x*t.x-y*t.y,x*t.y+y*t.x);
    }
};

递归实现

void fft(Complex y[],int n)
{
    if(n==1) return;
    static Complex c[MAXN];
    int m=n/2;
    for(int i=0;i<m;++i)
    {
        c[i]=y[i*2];
        c[i+m]=y[i*2+1];
    }
    copy(c,c+n,y);
    Complex *a0=y,*a1=y+m;
    fft(a0,m);
    fft(a1,m);
    for(int i=0;i<m;++i)
    {
        Complex w(cos(-2*PI/n*i),sin(-2*PI/n*i));
        c[i]=a0[i]+w*a1[i];
        c[i+m]=a0[i]-w*a1[i];
    }
    copy(c,c+n,y);
}

合并过程的推导:

对于0k<n/20 \le k < n/2

yk=A(ωnk)=A[0](ωn2k)+ωnkA[1](ωn2k)=A[0](ωn/2k)+ωnkA[1](ωn/2k)=yk[0]+ωnkyk[1]\begin{aligned} y _k & =A(\omega _{n}^{k}) \\\\ & =A^{[0]}(\omega _{n}^{2k})+\omega _{n}^{k}A^{[1]}(\omega _{n}^{2k}) \\\\ & =A^{[0]}(\omega _{n/2}^{k})+\omega _{n}^{k}A^{[1]}(\omega _{n/2}^{k}) \\\\ & =y_k^{[0]}+\omega _{n}^{k}y_k^{[1]} \end{aligned}

前半段没什么问题,再来看后半段

yk+(n/2)=A(ωnk+(n/2))=A[0](ωn2k+n)+ωnk+(n/2)A[1](ωn2k+n)=A[0](ωn2k)ωnkA[1](ωn2k)=A[0](ωn/2k)ωnkA[1](ωn/2k)=yk[0]ωnkyk[1]\begin{aligned} y_{k+(n/2)} & =A(\omega _{n}^{k+(n/2)}) \\\\ & =A^{[0]}(\omega _{n}^{2k+n})+\omega _{n}^{k+(n/2)}A^{[1]}(\omega _{n}^{2k+n}) \\\\ & =A^{[0]}(\omega _{n}^{2k})-\omega _{n}^{k}A^{[1]}(\omega _{n}^{2k}) \\\\ & =A^{[0]}(\omega _{n/2}^{k})-\omega _{n}^{k}A^{[1]}(\omega _{n/2}^{k}) \\\\ & =y_k^{[0]}-\omega _{n}^{k}y_k^{[1]}\end{aligned}

迭代实现

递归实际运行起来常数很大,我们需要更高效的实现方法。

先来观察一下递归过程中输入向量的下标变化,以n=8n=8举例,可以将这个过程自行脑补成一个完全二叉树的样子:

0   1   2   3   4   5   6   7 

0   2   4   6 - 1   3   5   7

0   4 - 2   6 - 1   5 - 3   7

0 - 4 - 2 - 6 - 1 - 5 - 3 - 7

如果观察二进制的话会发现对应的下标是反转二进制位得到的,比如“011”变成“110”,即下标3变成了6。

代码实现举例两种

  1. 直接求出对应位置反转二进制位后的数,然后交换,时间复杂度O(nlogn)O(nlogn)
void change(Complex y[],int len)
{
    int k=0;
    while((1<<k)<len) ++k;
    for(int i=0;i<len;++i)
    {
        int t=0;
        for(int j=0;j<k;++j)
            if(i>>j&1)
                t|=1<<(k-j-1);
        if(i<t) swap(y[i],y[t]);
    }
}
  1. 从高位模拟二进制加一,用经典的摊还分析可以证明复杂度是O(n)O(n)
void change(Complex y[],int len)
{
    int i,j,k;
    for(i=1,j=len/2;i<len-1;i++)
    {
        if(i<j) swap(y[i],y[j]);
        k=len/2;
        while(j>=k)
        {
            j-=k;
            k/=2;
        }
        if(j<k) j+=k;
    }
}

之后我们再考虑自底向上的合并,在之前的递归版本中,有一个公用子表达式ωnkyk[1]\omega _{n}^{k}y_k^{[1]}计算了两次,我们可以只计算一次乘积,存放在临时变量tt里,然后从yk[0]y_k^{[0]}中增加或者减去tt,这一系列操作称为一个蝴蝶操作

代码:

void fft(Complex y[],int len,int on)
{
    change(y,len);
    for(int h=2;h<=len;h<<=1)
    {
        Complex wn(cos(-on*2*PI/h),sin(-on*2*PI/h));
        for(int j=0;j<len;j+=h)
        {
            Complex w(1,0);
            for(int k=j;k<j+h/2;k++)
            {
                Complex u=y[k];
                Complex t=w*y[k+h/2];
                y[k]=u+t;
                y[k+h/2]=u-t;
                w=w*wn;
            }
        }
    }
    if(on==-1)
        for(int i=0;i<len;i++)
            y[i].x/=len;
}

on 取值1或-1,on 为-1代表逆变换。

实际上,预处理单位根代替每次旋转精度会更好。


六、模板

多项式乘法

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const double PI=acos(-1.0);
const int MAXN=1<<18;
struct Complex
{
    double x,y;
    Complex(double x_=0,double y_=0)
    {
        x=x_;
        y=y_;
    }
    Complex operator -(const Complex &t)const
    {
        return Complex(x-t.x,y-t.y);
    }
    Complex operator +(const Complex &t)const
    {
        return Complex(x+t.x,y+t.y);
    }
    Complex operator *(const Complex &t)const
    {
        return Complex(x*t.x-y*t.y,x*t.y+y*t.x);
    }
} x1[MAXN+5],x2[MAXN+5],wn[MAXN+5];
void init()
{
    for(int i=0;i<=MAXN;++i)
        wn[i]=Complex(cos(-2*PI*i/MAXN),sin(-2*PI*i/MAXN));
}
void change(Complex y[],int len)
{
    int i,j,k;
    for(i=1,j=len/2;i<len-1;i++)
    {
        if(i<j) swap(y[i],y[j]);
        k=len/2;
        while(j>=k)
        {
            j-=k;
            k/=2;
        }
        if(j<k) j+=k;
    }
}
void fft(Complex y[],int len,int on)
{
    change(y,len);
    for(int h=2;h<=len;h<<=1)
    {
        int st=MAXN/h;
        for(int j=0;j<len;j+=h)
        {
            int ptr=0;
            for(int k=j;k<j+h/2;k++)
            {
                Complex w=wn[on==1?ptr:MAXN-ptr];
                Complex u=y[k],t=w*y[k+h/2];
                y[k]=u+t;
                y[k+h/2]=u-t;
                ptr+=st;
            }
        }
    }
    if(on==-1)
        for(int i=0;i<len;i++)
            y[i].x/=len;
}
int n,m;
int main()
{
    init();
    scanf("%d%d",&n,&m);
    ++n;++m;
    int len=1;
    while(len<(n<<1)||len<(m<<1)) len<<=1;
    for(int i=0;i<n;++i)
    {
        int x;
        scanf("%d",&x);
        x1[i].x=x;
    }
    for(int i=0;i<m;++i)
    {
        int x;
        scanf("%d",&x);
        x2[i].x=x;
    }
    fft(x1,len,1);
    fft(x2,len,1);
    for(int i=0;i<len;++i) x1[i]=x1[i]*x2[i];
    fft(x1,len,-1);
    for(int i=0;i<n+m-1;++i)
        printf("%d%c",(int)(x1[i].x+0.5)," \n"[i==n+m-2]);
    return 0;
}

七、结语

啊,终于弄完快速傅里叶变换了,撒花!

实际上,关于FFT还有很多东西没有讨论到,先到这里吧。


本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!

多维快速傅里叶变换 上一篇
莫队算法总结 下一篇