标签:for inline return 质数 lang 转化 而且 rod 算法
快速傅里叶变换(Fast Fourier Transform,FFT)
是一种可在 \(O(n \log n)\) 时间内完成的离散傅里叶变换(Discrete Fourier Transform,DFT)
的算法,用来实现将信号从原始域(通常是时间或空间)到频域的互相转化。
FFT 在算法竞赛中主要用来加速多项式乘法(循环卷积)。
形如
的式子称为 \(x\) 的 \(n - 1\) 次多项式,其中 \(a_0, a_1, \dots, a_{n - 1}\) 称为多项式系数,\(n-1\) 称为多项式的次数,记为 \(\deg A(x)\) 或 \(\deg A\)。
\(n - 1\) 次多项式 \(A(x)\) 在 \(x = m\) 处的点值
记 \(A(x)\times B(x)\) 表示多项式 \(A(x), B(x)\) 做多项式乘法,可以简写为 \(A(x)\cdot B(x)\) 或 \(A(x)B(x)\)。
多项式乘法
用系数关系可以表示为
其中 \(\deg C = \deg A + \deg B\)。
易证它们的点值满足如下关系
记 \(\operatorname{conv}(A, B, n)\) 表示多项式 \(A(x), B(x)\) 做长度为 \(n\) 的循环卷积。
循环卷积
系数关系表示为
其中 \(\deg C = n - 1\)。
容易发现,当 \(n > \deg A + \deg B\) 时,该运算等价于多项式乘法。
离散傅里叶变换(Discrete Fourier Transform, DFT)
将多项式 \(A(x)=\sum_{k=0}^{n-1}a_kx^k\) 转换为一些特殊的点值。
记 \(n\) 次单位复根
\(DFT(A)\) 就是要计算点值 \(A(\omega_n^k), k = 0, 1, 2, \dots, n-1\)。
单位根自带的循环特性使得循环卷积 \(C(x) = \operatorname{conv}(A, B, n)\) 的点值也满足:
IDFT 是 DFT 的逆变换。
首先,用等比数列求和易证:
考虑循环卷积 \(C(x) = \operatorname{conv}(A, B, n)\) 的系数表示
设多项式
只要计算 \(DFT(C‘)\) 即可得到 \(C(x)\) 的系数,于是我们用 DFT 完成了逆变换 IDFT。
用两次 DFT 和一次 IDFT就可以计算 \(\operatorname{conv}(A, B, n)\)。
暴力的复杂度是 \(O(n^2)\),此处不赘述。
现在尝试将 DFT 问题分解以优化时间复杂度。
本部分认为 \(n = \deg A + 1\) 为 \(2\) 的整数次幂。对于更一般的情况,暂不考虑。
将序列 \(a_i\) 分成左右两半。
进一步,将 \(A(\omega_{n}^r)\) 按奇偶分类:
设
我们只需要求出 \(P(\omega_{n/2}^r)\) 和 \(Q(\omega_{n/2}^r)\) ,即求解规模为原来一半的两个子问题 \(DFT(P), DFT(Q)\),就能在 \(O(n)\) 时间内计算出 \(DFT(A)\)。
在算法竞赛中这种方法更常见。
注意到在 DIF
中我们最后将 \(A(\omega_n^r)\) 奇偶分类求解,那不妨思考将序列 \(a_k\) 按奇偶分类。
设
则
所以
将 \(A(\omega_n^k)\) 再细分为左右两半,这里运用了等式 \(\omega_{n/2}^k = \omega_{n/2}^{k + n/2}\) 和 \(\omega_n^k+\omega_n{k+n/2} = 0\) :
我们只需要求出 \(A_0(\omega_{n/2}^k)\) 和 \(A_1(\omega_{n/2}^k)\) ,即求解规模为原来一半的两个子问题 \(DFT(A_0), DFT(A_1)\),就能在 \(O(n)\) 时间内计算出 \(DFT(A)\)。
设次数为 \(n - 1\) 的多项式做 DFT 的时间复杂度为 \(T(n)\),则
根据主定理
上述两种计算方式均可以使用递归实现,这里直接给出代码,不再赘述。
DIF
const double PI = acos(-1.0);
void dft(std::vector<Complex> &a) {
int n = a.size(), m = n >> 1;
if (n == 1) return;
std::vector<Complex> p(m), q(m);
for (int i = 0; i < m; i++) {
p[i] = a[i] + a[i + m];
q[i] = (a[i] - a[i + m]) * Complex(cos(2 * PI * i / n), sin(2 * PI * i / n));
}
dft(p), dft(q);
for (int i = 0; i < m; i++)
a[i << 1] = p[i], a[i << 1 | 1] = q[i];
}
void idft(std::vector<Complex> &a) {
dft(a);
for (auto &v: a) v.a /= a.size(), v.b /= a.size();
std::reverse(a.begin() + 1, a.end());
}
DIT
const double PI = acos(-1.0);
void dft(std::vector<Complex> &a) {
int n = a.size(), m = n >> 1;
if (n == 1) return;
std::vector<Complex> p(m), q(m);
for (int i = 0; i < m; i++) {
p[i] = a[i << 1];
q[i] = a[i << 1 | 1];
}
dft(p), dft(q);
for (int i = 0; i < m; i++) {
Complex &u = p[i], v = Complex(cos(2 * PI * i / n), sin(2 * PI * i / n)) * q[i];
a[i] = u + v, a[i + m] = u - v;
}
}
void idft(std::vector<Complex> &a) {
dft(a);
for (auto &v: a) v.a /= a.size(), v.b /= a.size();
std::reverse(a.begin() + 1, a.end());
}
下面探讨以 非递归方式 实现 DIF
与 DIT
。
由于 DIT
更易于理解(其实只是资料多),先讲这个。
考虑递归的过程:
发现 \(0 \rightarrow 3\) 只是在重新安排数据位置,并没有修改数据,如果我们能把映射关系找到,那就可以一步到位,直接从 \(3\) 开始。
设一个数在第 \(i\) 个阶段 \((0 \leq i \leq \log_2n)\) 的位置为 \(p_i\),相对位置为 \(p_i‘\)(相对位置
指它在括号里的位置,例如上面第 \(1\) 阶段 \(a_1\) 的相对位置为 \(0\))。
容易发现
如果将 \(p_0\) 写成二进制 \(\overline{b_4b_3b_2b_1b_0}\)(这里以 \(n = 32\) 为例),那么单次变化的过程相当于把二进制的后几位向右 rotate
一位,总的变化过程可以描述为:
可以发现,整个过程实际上是在做 reverse
操作!
至此我们找到了映射关系,成功把前面的步骤都砍掉了,只剩回溯,可以改成循环。
void dft(std::vector<Complex> &a) {
int n = a.size();
for (int i = 0, j = 0; i < n; i++) {
if (i > j) std::swap(a[i], a[j]);
for (int k = n >> 1; (j ^= k) < k; k >>= 1);
}
for (int k = 1; k < n; k <<= 1) {
for (int i = 0; i < n; i += k << 1) {
for (int j = 0; j < k; j++) {
auto t = a[i + j + k] * Complex(cos(PI * j / k), sin(PI * j / k));
a[i + j + k] = a[i + j] - t;
a[i + j] = a[i + j] + t;
}
}
}
}
这里先将递归版DIF的过程简单复述:
第一步:将序列 \(a\) 对半分
第二步:递归计算 \(DFT(p), DFT(q)\)
第三步:重新安排数据位置
发现回溯的过程(即第三步)实际上也只是在重新安排数据存储的位置,而且是上面 \(DIT\) 第一步的逆过程,所以就是位翻转的逆过程,所以还是位翻转。
所以最后安排数据位置可以一步搞定,只剩递归压栈的过程,可以改成循环。
void dft(vector<Complex> &a) {
int n = a.size();
for (int k = n >> 1; k; k >>= 1) {
for (int i = 0; i < n; i += k << 1) {
for (int j = 0; j < k; j++) {
auto t = a[i + j + k];
a[i + j + k] = (a[i + j] - t) * Complex(cos(PI * j / k), sin(PI * j / k));
a[i + j] = a[i + j] + t;
}
}
}
for (int i = 0, j = 0; i < n; i++) {
if (i > j) std::swap(a[i], a[j]);
for (int k = n >> 1; (j ^= k) < k; k >>= 1);
}
}
发现 \(DIF\) 的最后一步和 \(DIT\) 的第一步都是位翻转,所以先 \(DIF\) 再 \(DIT\),就可以省略位翻转。
完整代码
#include <bits/stdc++.h>
template <class T>
inline void readInt(T &w) {
char c, p = 0;
while (!isdigit(c = getchar())) p = c == ‘-‘;
for (w = c & 15; isdigit(c = getchar());) w = w * 10 + (c & 15);
if (p) w = -w;
}
struct Complex {
double a, b; // a + bi
Complex(double a = 0, double b = 0): a(a), b(b) {}
};
inline Complex operator+(const Complex &p, const Complex &q) {
return Complex(p.a + q.a, p.b + q.b);
}
inline Complex operator-(const Complex &p, const Complex &q) {
return Complex(p.a - q.a, p.b - q.b);
}
inline Complex operator*(const Complex &p, const Complex &q) {
return Complex(p.a * q.a - p.b * q.b, p.a * q.b + p.b * q.a);
}
const double PI = acos(-1.0);
void dft(std::vector<Complex> &a) {
int n = a.size();
for (int k = n >> 1; k; k >>= 1) {
for (int i = 0; i < n; i += k << 1) {
for (int j = 0; j < k; j++) {
auto t = a[i + j + k];
a[i + j + k] = (a[i + j] - t) * Complex(cos(PI * j / k), sin(PI * j / k));
a[i + j] = a[i + j] + t;
}
}
}
}
void idft(std::vector<Complex> &a) {
int n = a.size();
for (int k = 1; k < n; k <<= 1) {
for (int i = 0; i < n; i += k << 1) {
for (int j = 0; j < k; j++) {
auto t = a[i + j + k] * Complex(cos(PI * j / k), sin(PI * j / k));
a[i + j + k] = a[i + j] - t;
a[i + j] = a[i + j] + t;
}
}
}
for (auto &v: a) v.a /= a.size(), v.b /= a.size();
std::reverse(a.begin() + 1, a.end());
}
int main() {
int n, m, k;
readInt(n), readInt(m);
k = 1 << std::__lg(n + m) + 1;
std::vector<Complex> a(k), b(k), c(k);
for (int i = 0; i <= n; i++) readInt(a[i].a);
for (int i = 0; i <= m; i++) readInt(b[i].a);
dft(a), dft(b);
for (int i = 0; i < k; i++) c[i] = a[i] * b[i];
idft(c);
for (int i = 0; i <= n + m; i++) printf("%d ", (int)(c[i].a + 0.5));
return 0;
}
如果一个质数存在 \(2^n\) 次单位根(其中 \(n\) 最大时的单位根称为原根),那么在这个质数的剩余系下上面的结论依旧成立,可以使用FFT,多称这种FFT为快速数论变换(Number Theory Transform, NTT)
。
常见的质数是 \(P = 998244353\),它的原根 \(g = 3\)。
代码(预处理单位根,较好地平衡了代码复杂度和常数且有一定的封装度):
#include <bits/stdc++.h>
template <class T>
inline void readInt(T &w) {
char c, p = 0;
while (!isdigit(c = getchar())) p = c == ‘-‘;
for (w = c & 15; isdigit(c = getchar());) w = w * 10 + (c & 15);
if (p) w = -w;
}
template <class T, class... U>
inline void readInt(T &w, U &... a) { readInt(w), readInt(a...); }
constexpr int P(998244353), G(3);
inline void inc(int &x, int y) { (x += y) >= P ? x -= P : 0; }
inline int sum(int x, int y) { return x + y >= P ? x + y - P : x + y; }
inline int sub(int x, int y) { return x - y < 0 ? x - y + P : x - y; }
inline int fpow(int x, int k = P - 2) {
int r = 1;
for (; k; k >>= 1, x = 1LL * x * x % P)
if (k & 1) r = 1LL * r * x % P;
return r;
}
namespace Polynomial {
using Polynom = std::vector<int>;
int n;
std::vector<int> w;
void getOmega(int k) {
w.resize(k);
w[0] = 1;
int base = fpow(G, (P - 1) / (k << 1));
for (int i = 1; i < k; i++) w[i] = 1LL * w[i - 1] * base % P;
}
void dft(Polynom &a) {
for (int k = n >> 1; k; k >>= 1) {
getOmega(k);
for (int i = 0; i < n; i += k << 1) {
for (int j = 0; j < k; j++) {
int y = a[i + j + k];
a[i + j + k] = (1LL * a[i + j] - y + P) * w[j] % P;
inc(a[i + j], y);
}
}
}
}
void idft(Polynom &a) {
for (int k = 1; k < n; k <<= 1) {
getOmega(k);
for (int i = 0; i < n; i += k << 1) {
for (int j = 0; j < k; j++) {
int x = a[i + j], y = 1LL * a[i + j + k] * w[j] % P;
a[i + j] = sum(x, y);
a[i + j + k] = sub(x, y);
}
}
}
int inv = fpow(n);
for (int i = 0; i < n; i++) a[i] = 1LL * a[i] * inv % P;
std::reverse(a.begin() + 1, a.end());
}
} // namespace Polynom
using Polynomial::dft;
using Polynomial::idft;
void poly_multiply(unsigned *A, int n, unsigned *B, int m, unsigned *C) {
int k = Polynomial::n = 1 << std::__lg(n + m) + 1;
std::vector<int> a(k), b(k);
for (int i = 0; i <= n; i++) a[i] = A[i];
for (int i = 0; i <= m; i++) b[i] = B[i];
dft(a), dft(b);
for (int i = 0; i < k; i++) a[i] = 1LL * a[i] * b[i] % P;
idft(a);
for (int i = 0; i <= n + m; i++) C[i] = a[i];
}
int main() {
int n, m, k;
readInt(n, m);
Polynomial::n = k = 1 << std::__lg(n + m) + 1;
std::vector<int> a(k), b(k);
for (int i = 0; i <= n; i++) readInt(a[i]);
for (int i = 0; i <= m; i++) readInt(b[i]);
dft(a), dft(b);
for (int i = 0; i < k; i++) a[i] = 1LL * a[i] * b[i] % P;
idft(a);
for (int i = 0; i <= n + m; i++) printf("%d ", a[i]);
return 0;
}
上面提到的 FFT 算法虽然限制了 \(n\) 为 2 的次幂,但在大多数情况下已经足够解决问题。
对于更一般的 \(n\) 需要用到 Bluestein’s Algorithm
,可以参考2016年国家集训队论文《再探快速傅里叶变换——毛啸》。
后续可能会填这个坑。
标签:for inline return 质数 lang 转化 而且 rod 算法
原文地址:https://www.cnblogs.com/HolyK/p/13991949.html