多项式乘法快速算法:FFT、NTT与任意模数MTT实现
FFT(快速傅里叶变换)
FFT利用单位根的性质,将多项式乘法从O(n²)优化到O(n log n)。核心思想是将系数表示转换为点值表示,在点值域进行乘法后再逆变换回来。
迭代实现采用位逆序置换(bit-reversal permutation)和蝴蝶操作,避免了递归开销。
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
// Capacity of the work arrays: the transform length `lim` (the smallest power
// of two > n + m) must not exceed this, i.e. n + m must stay below N.
constexpr int N = 1 << 21;
// pi evaluated once at start-up as acos(-1).
const double PI = std::acos(-1.0);
// Minimal double-precision complex number used by the FFT below.
// The constructor is intentionally implicit so a plain double converts to
// (value, 0).
struct Complex {
    double re, im;  // real and imaginary parts
    Complex(double r = 0, double i = 0) : re(r), im(i) {}
    // Component-wise sum.
    Complex operator+(const Complex& rhs) const {
        Complex sum = *this;
        sum.re += rhs.re;
        sum.im += rhs.im;
        return sum;
    }
    // Component-wise difference.
    Complex operator-(const Complex& rhs) const {
        Complex diff = *this;
        diff.re -= rhs.re;
        diff.im -= rhs.im;
        return diff;
    }
    // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    Complex operator*(const Complex& rhs) const {
        const double real = re * rhs.re - im * rhs.im;
        const double imag = re * rhs.im + im * rhs.re;
        return Complex(real, imag);
    }
};
// Bit-reversal permutation table consumed by fft().
int rev[N];
// Fill rev[0..n) so that rev[i] is i with its low `bit` bits reversed.
// The caller in this file passes n == 1 << bit.
// The contribution of the lowest bit is precomputed as `top` and is 0 when
// bit == 0 (n == 1, i.e. both inputs are constants); evaluating
// `(i & 1) << (bit - 1)` there would be a shift by -1, which is undefined
// behaviour even though the shifted value is 0.
void init_rev(int n, int bit) {
    const int top = bit > 0 ? 1 << (bit - 1) : 0;
    for (int i = 0; i < n; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? top : 0);
}
// In-place iterative radix-2 FFT of length n.
// Preconditions: n is a power of two and rev[] already holds the matching
// bit-reversal permutation (see init_rev).
// inv == 1  : forward transform (twiddle factor e^{+i*pi/mid}).
// inv == -1 : inverse transform (twiddle e^{-i*pi/mid}) including the 1/n
//             normalisation of BOTH components, so a forward + inverse round
//             trip returns the original complex sequence.
void fft(Complex* a, int n, int inv) {
    // Reorder into bit-reversed positions; i < rev[i] swaps each pair once.
    for (int i = 0; i < n; i++)
        if (i < rev[i]) swap(a[i], a[rev[i]]);
    for (int mid = 1; mid < n; mid <<= 1) {
        // Primitive (2*mid)-th root of unity; the sign of inv picks the direction.
        Complex wn(cos(PI / mid), inv * sin(PI / mid));
        for (int i = 0; i < n; i += mid << 1) {
            Complex w(1, 0);
            for (int j = 0; j < mid; j++, w = w * wn) {
                Complex x = a[i + j], y = w * a[i + j + mid];
                a[i + j] = x + y;
                a[i + j + mid] = x - y;
            }
        }
    }
    // Scale the imaginary part too: it is only ~0 when the spectrum happens
    // to be conjugate-symmetric (real-valued result), not in general.
    if (inv == -1)
        for (int i = 0; i < n; i++) a[i].re /= n, a[i].im /= n;
}
Complex ta[N], tb[N];
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n, m;
cin >> n >> m;
for (int i = 0; i <= n; i++) cin >> ta[i].re;
for (int i = 0; i <= m; i++) cin >> tb[i].re;
int lim = 1, bit = 0;
while (lim <= n + m) lim <<= 1, bit++;
init_rev(lim, bit);
fft(ta, lim, 1);
fft(tb, lim, 1);
for (int i = 0; i < lim; i++) ta[i] = ta[i] * tb[i];
fft(ta, lim, -1);
for (int i = 0; i <= n + m; i++)
cout << (int)(ta[i].re + 0.5) << " \n"[i == n + m];
return 0;
}NTT(数论变换)
NTT将FFT中的复数单位根替换为由原根生成的模意义下的单位根,全程只做整数运算,因此没有浮点误差,结果也直接是对模数取余后的值。常用模数998244353 = 119 × 2²³ + 1,其原根为3。
关键性质:对于形如c·2ᵏ+1的质数p,若g是p的原根,则g^((p-1)/2ᵏ)是模p意义下的2ᵏ次本原单位根;因此变换长度最多可取2ᵏ(对998244353即2²³)。
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1, with primitive root 3.
constexpr int MOD = 998244353;
constexpr int G = 3;
// Capacity of the work arrays (maximum transform length).
constexpr int N = 1 << 21;
int power(int a, int b) {
int res = 1;
while (b) {
if (b & 1) res = (ll)res * a % MOD;
a = (ll)a * a % MOD;
b >>= 1;
}
return res;
}
int rev[N], invG = power(G, MOD - 2);
void ntt(int* a, int n, int type) {
for (int i = 0; i < n; i++)
if (i < rev[i]) swap(a[i], a[rev[i]]);
for (int mid = 1; mid < n; mid <<= 1) {
int wn = power(type == 1 ? G : invG, (MOD - 1) / (mid << 1));
for (int i = 0; i < n; i += mid << 1) {
int w = 1;
for (int j = 0; j < mid; j++, w = (ll)w * wn % MOD) {
int x = a[i + j], y = (ll)w * a[i + j + mid] % MOD;
a[i + j] = (x + y) % MOD;
a[i + j + mid] = (x - y + MOD) % MOD;
}
}
}
if (type == -1) {
int invn = power(n, MOD - 2);
for (int i = 0; i < n; i++) a[i] = (ll)a[i] * invn % MOD;
}
}
int pa[N], pb[N];
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n, m;
cin >> n >> m;
for (int i = 0; i <= n; i++) cin >> pa[i];
for (int i = 0; i <= m; i++) cin >> pb[i];
int lim = 1, bit = 0;
while (lim <= n + m) lim <<= 1, bit++;
for (int i = 0; i < lim; i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
ntt(pa, lim, 1);
ntt(pb, lim, 1);
for (int i = 0; i < lim; i++) pa[i] = (ll)pa[i] * pb[i] % MOD;
ntt(pa, lim, -1);
for (int i = 0; i <= n + m; i++)
cout << pa[i] << " \n"[i == n + m];
return 0;
}MTT(任意模数多项式乘法)
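当模数不满足NTT要求(如10⁹+7,其p-1只含一个因子2,无法提供足够长的2的幂次单位根)时,可采用MTT(任意模数多项式乘法,这里用"拆系数FFT"实现)。先把每个系数对模数取余;对10⁹+7这类小于2³⁰的模数,取余后的系数小于2³⁰,可拆分为高15位和低15位。再利用复数的实部和虚部同时存放两个实序列,共用四次FFT(两次正变换、两次逆变换)算出结果。中间值可达约2⁴⁸量级,代码中使用long double来控制舍入误差,使最后的取整正确。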
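设a = a₁·2¹⁵ + a₀,b = b₁·2¹⁵ + b₀,则a·b = a₀b₀ + (a₀b₁+a₁b₀)·2¹⁵ + a₁b₁·2³⁰。通过复数打包技巧减少FFT次数:令F = A₀ + i·A₁。由于A₀、A₁都是实系数多项式,它们的变换可以从F̂中分离出来:Â₀[k] = (F̂[k] + conj(F̂[n−k]))/2,Â₁[k] = (F̂[k] − conj(F̂[n−k]))/(2i)。因此一次正变换就得到两个多项式的频谱。同理,把A₀B₀与A₀B₁分别放在一个序列的实部和虚部,把A₁B₀与A₁B₁放在另一个序列中,两次逆变换即可取回全部四个乘积。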
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using ld = long double;
// Capacity of the work arrays (maximum transform length).
constexpr int N = 1 << 18;
// pi in long double precision.
const ld PI = std::acos(-1.0L);
// Each coefficient is split into a low SPLIT-bit half (x & MASK) and the
// remaining high bits (x >> SPLIT).
constexpr int SPLIT = 15;
constexpr int MASK = (1 << SPLIT) - 1;
struct Complex {
ld x, y;
Complex(ld r = 0, ld i = 0) : x(r), y(i) {}
Complex operator+(const Complex& o) const { return Complex(x + o.x, y + o.y); }
Complex operator-(const Complex& o) const { return Complex(x - o.x, y - o.y); }
Complex operator*(const Complex& o) const {
return Complex(x * o.x - y * o.y, x * o.y + y * o.x);
}
Complex conj() const { return Complex(x, -y); }
};
// Bit-reversal permutation table consumed by fft().
int rev[N];
// Fill rev[0..n) so that rev[i] is i with its low `bit` bits reversed.
// multiply() passes n == 1 << bit.
// The lowest bit's contribution is precomputed as `top`, which is 0 when
// bit == 0 (n == 1, both inputs constant); the original
// `(i & 1) << (bit - 1)` would shift by -1 there, which is undefined
// behaviour even though the shifted value is 0.
void init(int n, int bit) {
    const int top = bit > 0 ? 1 << (bit - 1) : 0;
    for (int i = 0; i < n; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? top : 0);
}
// In-place iterative radix-2 FFT over long double complex numbers.
// n must be a power of two and rev[] must hold the bit-reversal permutation
// for n. inv == 1 runs the forward transform; inv == -1 runs the inverse and
// divides both the real and imaginary parts by n.
void fft(Complex* a, int n, int inv) {
    // Bit-reversal reordering; the i < r test swaps each pair exactly once.
    for (int i = 0; i < n; ++i) {
        const int r = rev[i];
        if (i < r) swap(a[i], a[r]);
    }
    // Butterfly passes; `half` is half the length of the blocks being merged.
    for (int half = 1; half < n; half <<= 1) {
        // Primitive (2*half)-th root of unity; inv picks the rotation direction.
        const Complex step(cos(PI / half), inv * sin(PI / half));
        const int span = half << 1;
        for (int base = 0; base < n; base += span) {
            Complex w(1, 0);
            Complex* lo = a + base;
            Complex* hi = lo + half;
            for (int k = 0; k < half; ++k) {
                const Complex u = lo[k];
                const Complex v = w * hi[k];
                lo[k] = u + v;
                hi[k] = u - v;
                w = w * step;
            }
        }
    }
    if (inv != -1) return;
    for (int i = 0; i < n; ++i) {
        a[i].x /= n;
        a[i].y /= n;
    }
}
Complex f[N], g[N], p[N], q[N];
void multiply(int* a, int* b, int* c, int n, int m, int mod) {
int lim = 1, bit = 0;
while (lim <= n + m) lim <<= 1, bit++;
init(lim, bit);
for (int i = 0; i < lim; i++) {
f[i] = (i <= n) ? Complex(a[i] & MASK, a[i] >> SPLIT) : Complex(0, 0);
g[i] = (i <= m) ? Complex(b[i] & MASK, b[i] >> SPLIT) : Complex(0, 0);
}
fft(f, lim, 1); fft(g, lim, 1);
for (int i = 0; i < lim; i++) {
int j = (lim - i) & (lim - 1);
Complex fa = (f[i] + f[j].conj()) * Complex(0.5, 0);
Complex fb = (f[i] - f[j].conj()) * Complex(0, -0.5);
Complex ga = (g[i] + g[j].conj()) * Complex(0.5, 0);
Complex gb = (g[i] - g[j].conj()) * Complex(0, -0.5);
p[i] = fa * ga + fa * gb * Complex(0, 1);
q[i] = fb * ga + fb * gb * Complex(0, 1);
}
fft(p, lim, -1); fft(q, lim, -1);
for (int i = 0; i <= n + m; i++) {
ll v0 = (ll)(p[i].x + 0.5L) % mod;
ll v1 = ((ll)(p[i].y + 0.5L) + (ll)(q[i].x + 0.5L)) % mod;
ll v2 = (ll)(q[i].y + 0.5L) % mod;
c[i] = (v0 + (v1 << SPLIT) + (v2 << (SPLIT << 1))) % mod;
}
}
int A[N], B[N], C[N];
// Reads n, m and the modulus, then the n + 1 coefficients of A and the
// m + 1 coefficients of B; prints the n + m + 1 coefficients of A*B mod
// `mod`, space separated, followed by a newline.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, m, mod;
    cin >> n >> m >> mod;
    for (int i = 0; i <= n; ++i) cin >> A[i];
    for (int i = 0; i <= m; ++i) cin >> B[i];
    multiply(A, B, C, n, m, mod);
    const int last = n + m;
    for (int i = 0; i <= last; ++i) {
        cout << C[i];
        cout << (i == last ? '\n' : ' ');
    }
    return 0;
}