以下内容摘自《算法导论》思考题30-3
我们可以将一维离散傅里叶变换推广到d维上,这时输入是一个d维的数组A = ( a j 1 , j 2 , … , j d ) A=(a_{j_1,j_2,\dots,j_d}) A = ( a j 1 , j 2 , … , j d ) ,维数分别为n 1 , n 2 , … , n d n_1,n_2,\dots,n_d n 1 , n 2 , … , n d ,其中n 1 n 2 … n d = n n_1n_2\dots n_d=n n 1 n 2 … n d = n 。定义d维离散傅里叶变换如下:
y k 1 , k 2 , … , k d = ∑ j 1 = 0 n 1 − 1 ∑ j 2 = 0 n 2 − 1 ⋯ ∑ j d = 0 n d − 1 a j 1 , j 2 , … , j d ω n 1 j 1 k 1 ω n 2 j 2 k 2 … ω n d j d k d y_{k_1,k_2,\dots,k_d}=\sum_{j_1=0}^{n_1-1}\sum_{j_2=0}^{n_2-1}\dots\sum_{j_d=0}^{n_d-1}a_{j_1,j_2,\dots,j_d}\omega _{n_1}^{j_1k_1}\omega _{n_2}^{j_2k_2}\dots\omega _{n_d}^{j_dk_d}
y k 1 , k 2 , … , k d = j 1 = 0 ∑ n 1 − 1 j 2 = 0 ∑ n 2 − 1 ⋯ j d = 0 ∑ n d − 1 a j 1 , j 2 , … , j d ω n 1 j 1 k 1 ω n 2 j 2 k 2 … ω n d j d k d
其中0 ≤ k 1 < n 1 , 0 ≤ k 2 < n 2 , … , 0 ≤ k d < n d 0\le k_1<n_1,0\le k_2<n_2,\dots,0\le k_d<n_d 0 ≤ k 1 < n 1 , 0 ≤ k 2 < n 2 , … , 0 ≤ k d < n d
a. 证明:我们可以依次在每个维度上计算一维的DFT来计算一个d维的DFT。也就是说,首先沿着第1维计算n / n 1 n/n_1 n / n 1 个独立的一维DFT。然后,把沿着第一维的DFT的结果作为输入,我们计算沿着第2维的n / n 2 n/n_2 n / n 2 个独立的一维DFT。利用这个结果作为输入,我们计算沿着第3维的n / n 3 n/n_3 n / n 3 个独立的一维DFT,如此下去,直到第d维。
b. 证明:维度的次序并无影响,于是可以通过在d个维度的任意顺序中计算一维DFT来计算一个d为的DFT。
c. 证明:如果采用计算快速傅里叶变换计算每个一维的DFT,那么计算一个d维的DFT的总时间是O ( n l o g n ) O(nlogn) O ( n l o g n ) ,与d无关。
练习题目
Five Dimensional Discrete Fourier Transform
2017年ICPC南宁赛区的G题
直接做多维DFT,复杂度O ( T N 6 ) O(TN^6) O ( T N 6 ) ,需要卡卡常数。
代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 #include <bits/stdc++.h> using namespace std;typedef double db;const db pi=acos (-1.0 );struct Complex { db x,y; Complex (db x_=0 ,db y_=0 ) { x=x_; y=y_; } Complex operator -(const Complex &t)const { return Complex (x-t.x,y-t.y); } Complex operator +(const Complex &t)const { return Complex (x+t.x,y+t.y); } Complex operator *(const Complex &t)const { return Complex (x*t.x-y*t.y,x*t.y+y*t.x); } Complex operator *(const db &t)const { return Complex (x*t,y*t); } } a[10 ][10 ][10 ][10 ][10 ],A[10 ][10 ][10 ][10 ][10 ],w[20 ],x[20 ],y[20 ]; Complex e (db t) { return Complex (cos (t),sin (t)); } int n[10 ];db alpha; void dft (Complex *x,Complex *y,int n) { for (int j=0 ;j<n;++j) w[j]=e (-2 *pi/n*j); for (int i=0 ;i<n;++i) { y[i]=Complex (0 ,0 ); for (int j=0 ;j<n;++j) y[i]=y[i]+x[j]*w[i*j%n]; } } db calc () { db ans=0 ; for (int i0=0 ;i0<n[0 ];++i0) for (int i1=0 ;i1<n[1 ];++i1) for (int i2=0 ;i2<n[2 ];++i2) for (int i3=0 ;i3<n[3 ];++i3) for (int i4=0 ;i4<n[4 ];++i4) ans+=abs (A[i0][i1][i2][i3][i4].x); int nn=n[0 ]*n[1 ]*n[2 ]*n[3 ]*n[4 ]; return ans/sqrt (1.0 *nn*nn*nn); } void baoli () { for (int i0=0 ;i0<n[0 ];++i0) for (int i1=0 ;i1<n[1 ];++i1) for (int i2=0 ;i2<n[2 ];++i2) for (int i3=0 ;i3<n[3 ];++i3) for (int i4=0 ;i4<n[4 ];++i4) { A[i0][i1][i2][i3][i4]=Complex (0 ,0 ); for (int j0=0 ;j0<n[0 ];++j0) for (int j1=0 ;j1<n[1 ];++j1) for (int j2=0 ;j2<n[2 ];++j2) for (int j3=0 ;j3<n[3 ];++j3) for (int j4=0 ;j4<n[4 ];++j4) A[i0][i1][i2][i3][i4]=A[i0][i1][i2][i3][i4]+a[j0][j1][j2][j3][j4]*e (2 *pi*(1.0 *i0*j0/n[0 ]+1.0 *i1*j1/n[1 ]+1.0 *i2*j2/n[2 ]+1.0 *i3*j3/n[3 ]+1.0 *i4*j4/n[4 ])); } } void dft () { for (int i1=0 ;i1<n[1 ];++i1) for (int i2=0 ;i2<n[2 ];++i2) for (int i3=0 ;i3<n[3 ];++i3) for (int i4=0 ;i4<n[4 ];++i4) { for (int i0=0 ;i0<n[0 ];++i0) x[i0]=a[i0][i1][i2][i3][i4]; dft (x,y,n[0 ]); for (int i0=0 ;i0<n[0 ];++i0) a[i0][i1][i2][i3][i4]=y[i0]; } for (int i0=0 ;i0<n[0 ];++i0) for (int i2=0 ;i2<n[2 ];++i2) for (int i3=0 ;i3<n[3 ];++i3) for (int i4=0 ;i4<n[4 ];++i4) { for (int i1=0 ;i1<n[1 ];++i1) x[i1]=a[i0][i1][i2][i3][i4]; dft (x,y,n[1 ]); for (int i1=0 ;i1<n[1 ];++i1) a[i0][i1][i2][i3][i4]=y[i1]; } for (int i0=0 ;i0<n[0 ];++i0) for (int i1=0 ;i1<n[1 ];++i1) for (int i3=0 ;i3<n[3 ];++i3) for (int i4=0 ;i4<n[4 ];++i4) { for (int i2=0 ;i2<n[2 ];++i2) x[i2]=a[i0][i1][i2][i3][i4]; dft (x,y,n[2 ]); for (int i2=0 ;i2<n[2 ];++i2) a[i0][i1][i2][i3][i4]=y[i2]; } for (int i0=0 ;i0<n[0 ];++i0) for (int i1=0 ;i1<n[1 ];++i1) for (int i2=0 ;i2<n[2 ];++i2) for (int i4=0 ;i4<n[4 ];++i4) { for (int i3=0 ;i3<n[3 ];++i3) x[i3]=a[i0][i1][i2][i3][i4]; dft (x,y,n[3 ]); for (int i3=0 ;i3<n[3 ];++i3) a[i0][i1][i2][i3][i4]=y[i3]; } for (int i0=0 ;i0<n[0 ];++i0) for (int i1=0 ;i1<n[1 ];++i1) for (int i2=0 ;i2<n[2 ];++i2) for (int i3=0 ;i3<n[3 ];++i3) { for (int i4=0 ;i4<n[4 ];++i4) x[i4]=a[i0][i1][i2][i3][i4]; dft (x,y,n[4 ]); for (int i4=0 ;i4<n[4 ];++i4) a[i0][i1][i2][i3][i4]=y[i4]; } for (int i0=0 ;i0<n[0 ];++i0) for (int i1=0 ;i1<n[1 ];++i1) for (int i2=0 ;i2<n[2 ];++i2) for (int i3=0 ;i3<n[3 ];++i3) for (int i4=0 ;i4<n[4 ];++i4) A[i0][i1][i2][i3][i4]=a[i0][i1][i2][i3][i4]; } int main () { int T; scanf ("%d" ,&T); while (T--) { for (int i=0 ;i<5 ;++i) scanf ("%d" ,&n[i]); scanf ("%lf" ,&alpha); for (int i0=0 ;i0<n[0 ];++i0) for (int i1=0 ;i1<n[1 ];++i1) for (int i2=0 ;i2<n[2 ];++i2) for (int i3=0 ;i3<n[3 ];++i3) for (int i4=0 ;i4<n[4 ];++i4) a[i0][i1][i2][i3][i4]=e ((i0-i1+i2-i3+i4)*alpha)*(i0^i1^i2^i3^i4); dft (); printf ("%.6f\n" ,calc ()); } }