一、写在前面
最近数字图像处理课正在学快速傅里叶变换,发现自己对此理解的还不是很到位。于是借此机会,对照着《算法导论》,对这部分内容啃一啃。
两个n次多项式相加的最直接方法所需的时间是O(n),但是相乘的最直接方法所需的时间为O(n2)。用快速傅里叶变换(Fast Fourier Transform,FFT)可以使多项式相乘的时间复杂度降低为O(nlogn)。
需要的一些前置技能:复数、多项式、线性代数。
二、多项式
一个以x为变量的多项式定义在一个代数域F上,将函数A(x)表示为形式和:
A(x)=j=0∑n−1ajxj
我们称a0,a1,…,an−1为如上多项式的系数,所有系数都属于域F,典型的情形是复数集合C。
如果一个多项式A(x)的最高次的非零系数是ak,则称A(x)的次数是k,记degree(A)=k。任何严格大于一个多项式次数的整数都是该多项式的次数界,因此,对于次数界为n的多项式,其次数可以是0∼n−1之间的任何整数。
多项式加法
如果A(x)和B(x)是次数界为n的多项式,那么它们的和也是一个次数界为n的多项式C(x),对所有属于定义域的x,都有C(x)=A(x)+B(x)。也就是说,
若
A(x)=j=0∑n−1ajxj
B(x)=j=0∑n−1bjxj
则
C(x)=j=0∑n−1cjxj(cj=aj+bj)
例如,如果有多项式A(x)=6x3+7x2−10x+9和B(x)=−2x3+4x−5,那么C(x)=4x3+7x2−6x+4。
多项式乘法
如果A(x)和B(x)是次数界为n的多项式,那么它们的乘积C(x)是一个次数界为2n−1的多项式C(x),对所有属于定义域的x,都有C(x)=A(x)B(x)。方法类似还是用上一个例子,那么得到
C(x)=−12x6−14x5+44x4−20x3−75x2+86x−45
形式化的式子有
C(x)=j=0∑2n−2cjxj
其中
cj=k=0∑jakbj−k
此时
degree(C)=degree(A)+degree(B)
多项式的表示
从某种意义上,多项式的系数表达与点值表达式等价的。
系数表达
对一个次数界为n的多项式A(x)=∑j=0n−1ajxj而言,其系数表达是一个由系数组成的(列)向量a=(a0,a1,…,an−1)。对于多项式乘法,系数向量c成为输入向量a和b的卷积,表示成c=a⊗b。
点值表达
一个次数界为n的多项式A(x)的点值表达就是一个由n个点值对组成的集合
{(x0,y0),(x1,y1),…,(xn−1,yn−1)}
使得对k=0,1,…,n−1,所有xk各不相同,且yk=A(xk)。
一个多项式可以有很多不同的点值表达。如果采用的点都相同的话,用点值表达多项式做乘法只需O(n)的时间。
求值与插值
从一个多项式的系数表达转化为点值表达的过程是求值,其逆运算称为插值。
定理(插值多项式的唯一性):对于任意n个点值对组成的集合{(x0,y0),(x1,y1),…,(xn−1,yn−1)},其中所有的xk都不同,那么存在唯一的次数界为n的多项式A(x),满足yk=A(xk)。
证明列出矩阵方程,然后结合范德蒙德矩阵的性质。
简单的求值和插值(拉格朗日插值)的时间复杂度都是O(n2)的。
我们之后就要通过巧妙选取点来加速这两个过程,使其运行时间变为O(nlogn)。
三、单位复数根
n次单位复数根是满足ωn=1的复数ω。
n次单位复数根恰好有n个:
ωn0,ωn1,…,ωnn−1
其中主n次单位复数根为
ωn=e2πi/n=cos(2π/n)+isin(2π/n)
其他n次单位复数根都是ωn的幂次。
消去引理: 对于任何整数n≥0,k≥0,d>0,有ωdndk=ωnk
推论: 对于任意偶数n>0,有ωnn/2=ω2=−1
折半引理: 如果n>0为偶数,那么n个n次单位复数根的平方的集合就是n/2个n/2次单位复数根的集合
求和引理: 对任意整数n≥1和不能被n整除的非负整数k,有∑j=0n−1(ωnk)j=0
四、快速傅里叶变换
DFT
现在我们希望计算次数界n的多项式
A(x)=j=0∑n−1ajxj
在ωnk处的值,记为yk
yk=A(ωnk)=j=0∑n−1ajωnkj
向量y=(y0,y1,…,yn−1)就是系数向量a=(a0,a1,…,an−1)的离散傅里叶变换(DFT),记为y=DFTn(a)。
FFT
快速傅里叶变换(FFT) 利用复数单位根的特殊性质,可以在O(nlogn)时间内计算出DFTn(a)。首先通篇假设n恰好是2的整数幂。
FFT利用了分治策略,采用A(x)中偶数下标的系数与奇数下标的系数,分别定义两个新的次数界为n/2的多项式A[0](x)和A[1](x):
A[0](x)=a0+a2x+a4x2+⋯+an−2xn/2−1
A[1](x)=a1+a3x+a5x2+⋯+an−1xn/2−1
于是有
A(x)=A[0](x2)+xA[1](x2)
所以,求A(x)在ωn0,ωn1,…,ωnn−1处的值转换为求次数界为n/2的多项式A[0](x)和A[1](x)在点(ωn0)2,(ωn1)2,…,(ωnn−1)2的值。可以发现其实是n/2个n/2次单位复数根,且每个根恰好出现两次。
IDFT
将点值表达的多项式转换回系数表达,是相似的过程。
我们把DFT写成矩阵乘积y=Vna。
其中Vn是一个范德蒙德矩阵,在(k,j)处的元素为ωnkj。
对于逆运算a=DFTn−1(y),我们把y乘以Vn的逆矩阵来处理。
定理: 对j,k=0,1,…,n−1,Vn−1在(j,k)元素为ωn−kj/n。
证明Vn−1Vn=In时用求和引理即可,注意使用条件。
所以可以推导出DFTn−1(y):
aj=n1k=0∑n−1ykωn−kj
可以看出只需将单位根取倒数,做一次FFT,最后将结果都除以 n,就做完逆变换了。
五、代码实现
首先是手写复数类,也可以用 std::complex<T>
。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| struct Complex { double x,y; Complex(double x_=0,double y_=0) { x=x_; y=y_; } Complex operator -(const Complex &t)const { return Complex(x-t.x,y-t.y); } Complex operator +(const Complex &t)const { return Complex(x+t.x,y+t.y); } Complex operator *(const Complex &t)const { return Complex(x*t.x-y*t.y,x*t.y+y*t.x); } };
|
递归实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
| void fft(Complex y[],int n) { if(n==1) return; static Complex c[MAXN]; int m=n/2; for(int i=0;i<m;++i) { c[i]=y[i*2]; c[i+m]=y[i*2+1]; } copy(c,c+n,y); Complex *a0=y,*a1=y+m; fft(a0,m); fft(a1,m); for(int i=0;i<m;++i) { Complex w(cos(-2*PI/n*i),sin(-2*PI/n*i)); c[i]=a0[i]+w*a1[i]; c[i+m]=a0[i]-w*a1[i]; } copy(c,c+n,y); }
|
合并过程的推导:
对于0≤k<n/2
yk=A(ωnk)=A[0](ωn2k)+ωnkA[1](ωn2k)=A[0](ωn/2k)+ωnkA[1](ωn/2k)=yk[0]+ωnkyk[1]
前半段没什么问题,再来看后半段
yk+(n/2)=A(ωnk+(n/2))=A[0](ωn2k+n)+ωnk+(n/2)A[1](ωn2k+n)=A[0](ωn2k)−ωnkA[1](ωn2k)=A[0](ωn/2k)−ωnkA[1](ωn/2k)=yk[0]−ωnkyk[1]
迭代实现
递归实际运行起来常数很大,我们需要更高效的实现方法。
先来观察一下递归过程中输入向量的下标变化,以n=8举例,可以将这个过程自行脑补成一个完全二叉树的样子:
1 2 3 4 5 6 7
| 0 1 2 3 4 5 6 7
0 2 4 6 - 1 3 5 7
0 4 - 2 6 - 1 5 - 3 7
0 - 4 - 2 - 6 - 1 - 5 - 3 - 7
|
如果观察二进制的话会发现对应的下标是反转二进制位得到的,比如“011”变成“110”,即下标3变成了6。
代码实现举例两种
- 直接求出对应位置反转二进制位后的数,然后交换,时间复杂度O(nlogn)
1 2 3 4 5 6 7 8 9 10 11 12 13
| void change(Complex y[],int len) { int k=0; while((1<<k)<len) ++k; for(int i=0;i<len;++i) { int t=0; for(int j=0;j<k;++j) if(i>>j&1) t|=1<<(k-j-1); if(i<t) swap(y[i],y[t]); } }
|
- 从高位模拟二进制加一,用经典的摊还分析可以证明复杂度是O(n)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| void change(Complex y[],int len) { int i,j,k; for(i=1,j=len/2;i<len-1;i++) { if(i<j) swap(y[i],y[j]); k=len/2; while(j>=k) { j-=k; k/=2; } if(j<k) j+=k; } }
|
之后我们再考虑自底向上的合并,在之前的递归版本中,有一个公用子表达式ωnkyk[1]计算了两次,我们可以只计算一次乘积,存放在临时变量t里,然后从yk[0]中增加或者减去t,这一系列操作称为一个蝴蝶操作。
代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
| void fft(Complex y[],int len,int on) { change(y,len); for(int h=2;h<=len;h<<=1) { Complex wn(cos(-on*2*PI/h),sin(-on*2*PI/h)); for(int j=0;j<len;j+=h) { Complex w(1,0); for(int k=j;k<j+h/2;k++) { Complex u=y[k]; Complex t=w*y[k+h/2]; y[k]=u+t; y[k+h/2]=u-t; w=w*wn; } } } if(on==-1) for(int i=0;i<len;i++) y[i].x/=len; }
|
on
取值1或-1,on
为-1代表逆变换。
实际上,预处理单位根代替每次旋转精度会更好。
六、模板
多项式乘法
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
| #include<bits/stdc++.h> using namespace std; typedef long long LL; const double PI=acos(-1.0); const int MAXN=1<<18; struct Complex { double x,y; Complex(double x_=0,double y_=0) { x=x_; y=y_; } Complex operator -(const Complex &t)const { return Complex(x-t.x,y-t.y); } Complex operator +(const Complex &t)const { return Complex(x+t.x,y+t.y); } Complex operator *(const Complex &t)const { return Complex(x*t.x-y*t.y,x*t.y+y*t.x); } } x1[MAXN+5],x2[MAXN+5],wn[MAXN+5]; void init() { for(int i=0;i<=MAXN;++i) wn[i]=Complex(cos(-2*PI*i/MAXN),sin(-2*PI*i/MAXN)); } void change(Complex y[],int len) { int i,j,k; for(i=1,j=len/2;i<len-1;i++) { if(i<j) swap(y[i],y[j]); k=len/2; while(j>=k) { j-=k; k/=2; } if(j<k) j+=k; } } void fft(Complex y[],int len,int on) { change(y,len); for(int h=2;h<=len;h<<=1) { int st=MAXN/h; for(int j=0;j<len;j+=h) { int ptr=0; for(int k=j;k<j+h/2;k++) { Complex w=wn[on==1?ptr:MAXN-ptr]; Complex u=y[k],t=w*y[k+h/2]; y[k]=u+t; y[k+h/2]=u-t; ptr+=st; } } } if(on==-1) for(int i=0;i<len;i++) y[i].x/=len; } int n,m; int main() { init(); scanf("%d%d",&n,&m); ++n;++m; int len=1; while(len<(n<<1)||len<(m<<1)) len<<=1; for(int i=0;i<n;++i) { int x; scanf("%d",&x); x1[i].x=x; } for(int i=0;i<m;++i) { int x; scanf("%d",&x); x2[i].x=x; } fft(x1,len,1); fft(x2,len,1); for(int i=0;i<len;++i) x1[i]=x1[i]*x2[i]; fft(x1,len,-1); for(int i=0;i<n+m-1;++i) printf("%d%c",(int)(x1[i].x+0.5)," \n"[i==n+m-2]); return 0; }
|
七、结语
啊,终于弄完快速傅里叶变换了,撒花!
实际上,关于FFT还有很多东西没有讨论到,先到这里吧。