标签:多项式 i+1 static 乘法 const cpp length ini stat
NTT, mul
inverse
integral, 微分 deriv
ln
exp
power (常数项为\(1\))
sqrt (常数项为\(1\))
// Note: power() and sqrt() below expect the input's constant term to be 1.
// N      : capacity of every coefficient / scratch array; a transform may be at most N long.
// G, iG  : primitive root 3 of p and its inverse (3 * 332748118 == p + 1).
// p      : NTT-friendly prime 998244353 = 119 * 2^23 + 1.
// maxlen : size of the root-of-unity tables filled by pre().
const int N = 262144, G = 3, iG = 332748118, p = 998244353, maxlen = N;
// Arithmetic mod p; arguments are expected to already lie in [0, p).
inline int add(int x, int y){return (x+y) % p;}
inline int sub(int x, int y){return (x-y+p) % p;}
inline int mul(int x, int y){return 1LL * x * y % p;}
// last     : transform length the rev[] table currently describes (see init()).
// rev      : bit-reversal permutation for that length.
// w / invw : successive powers of the forward / inverse root of unity.
// pinv[k]  : modular inverse of k (filled by pre()).
int last, rev[N], w[N], invw[N], pinv[N];
// Binary exponentiation: returns x^y mod p, for y >= 0.
inline int qp(int x, int y){
    int result = 1;
    int base = x;
    while(y){
        if(y & 1) result = mul(result, base);
        base = mul(base, base);
        y >>= 1;
    }
    return result;
}
// One-time table setup; must run before any transform or series routine.
//   w[k]    = r^k and invw[k] = r^(-k), where r = G^((p-1)/maxlen);
//   pinv[k] = k^(-1) mod p for 1 <= k < maxlen (pinv[0] is left untouched).
void pre(){
    const int root  = qp(G,  (p-1) / maxlen);
    const int iroot = qp(iG, (p-1) / maxlen);
    w[0] = invw[0] = 1;
    for(int k = 1; k < maxlen; ++k){
        w[k]    = mul(w[k-1], root);
        invw[k] = mul(invw[k-1], iroot);
    }
    // Linear-time inverses: k^-1 = -(p / k) * (p mod k)^-1  (mod p).
    pinv[1] = 1;
    for(int k = 2; k < maxlen; ++k)
        pinv[k] = mul(p - p / k, pinv[p % k]);
}
// Builds the bit-reversal permutation rev[0..n) for a length-n transform.
// n must be a power of two no larger than N. The result is cached in `last`,
// so calling again with the same n does nothing.
void init(int n){
    if(last == n) return ;
    last = n;
    rev[0] = 0;
    // Lengths 0 and 1 need no permutation. Returning early also avoids
    // shifting by lg == -1 (undefined behaviour) and the endless halving
    // loop that n == 0 used to cause.
    if(n <= 1) return ;
    int lg = -1;                       // ends as log2(n) - 1: position of the top index bit
    for(int t = n; t > 1; t >>= 1) ++lg;
    for(int i=1; i<n; i++)
        rev[i] = (rev[i>>1] >> 1) | ((i & 1) << lg);
}
// In-place iterative number-theoretic transform of f[0..n), n a power of two
// with n <= maxlen. Pass the global table w for the forward transform or invw
// for the inverse one; the inverse result is NOT divided by n here.
void NTT(int *f, int n, int *w){
    init(n);
    for(int i = 0; i < n; ++i)
        if(i < rev[i]) swap(f[i], f[rev[i]]);
    for(int half = 1; half < n; half <<= 1){
        const int len = half << 1;
        const int stride = maxlen / len;          // step through the length-maxlen root table
        for(int start = 0; start < n; start += len){
            int *lo = f + start;
            int *hi = f + start + half;
            for(int k = 0; k < half; ++k){
                const int u = lo[k];
                const int v = mul(hi[k], w[k * stride]);
                lo[k] = add(u, v);
                hi[k] = sub(u, v);
            }
        }
    }
}
// Cyclic convolution of length n: f <- f * g mod (x^n - 1), n a power of two.
// For an ordinary product, zero-pad both inputs so that deg f + deg g < n.
// g is clobbered: it is left holding its own transform.
void mul(int *f, int *g, int n){
    NTT(f, n, w);
    NTT(g, n, w);
    for(int *pf = f, *pg = g; pf != f + n; ++pf, ++pg)
        *pf = mul(*pf, *pg);
    NTT(f, n, invw);
    const int invN = qp(n, p-2);      // the inverse transform is unscaled
    for(int i = 0; i < n; ++i)
        f[i] = mul(f[i], invN);
}
// Power-series inverse: f <- 1/f mod x^n, via the Newton step g <- g * (2 - f*g).
// n is the (power-of-two) length of f, with 2n <= N; f[0] must be nonzero mod p.
// Reads and overwrites f[0..n) only. The static scratch buffers are cleared
// before returning.
void inverse(int *f, int n){
    static int g[N], tp[N];
    g[0] = qp(f[0], p-2);
    for(int len = 1; len < n; len <<= 1){      // g currently holds len correct terms
        const int target = len << 1;           // terms known after this step
        const int sz = len << 2;               // transform size: large enough that nothing wraps
        copy(f, f + target, tp);
        NTT(tp, sz, w);
        NTT(g, sz, w);
        for(int j = 0; j < sz; ++j){
            const int fg = mul(tp[j], g[j]);
            g[j] = mul(g[j], sub(2, fg));
        }
        NTT(g, sz, invw);
        const int invSz = qp(sz, p-2);
        for(int j = 0; j < sz; ++j) g[j] = mul(g[j], invSz);
        fill(g + target, g + sz, 0);           // truncate to mod x^target
    }
    copy(g, g + n, f);
    fill(g, g + n, 0);
    fill(tp, tp + (n << 1), 0);
}
// Formal derivative in place: a <- a', keeping n terms (a[n-1] becomes 0).
void deriv(int *a, int n){
    if(n <= 0) return ;                // empty input: there is no a[n-1] to clear
    for(int i=1; i<n; i++) a[i-1] = mul(a[i], i);
    a[n-1] = 0;
}
// Formal integral in place with zero constant term: a <- ∫ a dx, truncated
// to n terms. The old a[n-1] would land at degree n, so it is dropped.
// Needs pre() to have filled pinv[1..n).
void integral(int *a, int n){
    if(n <= 0) return ;
    // Stop at i == 1: the old loop also ran i == 0 and read a[-1] (out of bounds).
    for(int i = n - 1; i > 0; i--) a[i] = mul(a[i-1], pinv[i]);
    a[0] = 0;
}
// Power-series logarithm: a <- ∫ (a' / a) dx mod x^n.
// The result always has constant term 0, so it equals ln(a / a[0]);
// callers normally pass a[0] == 1.
// n must be a power of two with 2n <= N, and pre() must have run.
// Reads a[0..n); uses a[0..2n) as workspace and leaves a[n..2n) zeroed.
void ln(int *a, int n){
    static int tp[N];
    for(int i=0; i<n-1; i++) tp[i] = mul(a[i+1], i+1);   // tp = a'
    tp[n-1] = 0;
    inverse(a, n);                                         // a = 1/a mod x^n
    // The length-2n product needs a zero tail. This range is clobbered on
    // exit anyway, so clear it now rather than trusting the caller.
    fill(a+n, a+(n<<1), 0);
    mul(a, tp, n<<1);
    fill(a+n, a+(n<<1), 0);
    integral(a, n);
    fill(tp, tp+(n<<1), 0);
}
void exp(int *a, int n){
static int f[N], lnf[N], f0[N];
f[0] = 1;
for(int i=1; i<n; i<<=1){
copy(f, f + i, lnf); copy(f, f + i, f0);
ln(lnf, i << 1);
for(int j=0; j<(i << 1); j++) f[j] = sub(a[j], lnf[j]);
++f[0];
mul(f, f0, i << 2);
fill(f + (i << 1), f + (i << 2), 0);
fill(f0 + (i << 1), f0 + (i << 2), 0);
}
copy(f, f+n, a);
fill(f, f + n, 0);
fill(lnf, lnf + n, 0);
fill(f0, f0 + n, 0);
}
void power(int *a, int n, int p){
if(p == 1) return ;
ln(a, n);
for(int i=1; i<n; i++) a[i] = mul(a[i], p);
exp(a, n);
}
// Power-series square root: a <- sqrt(a) mod x^n, taking the root whose
// constant term is 1. Uses the Newton step f <- f - (f^2 - a) / (2f).
// Requires a[0] == 1, n a power of two with 2n <= N, and pre() to have run
// (pinv[2] supplies the 1/2). Reads and overwrites a[0..n) only.
void sqrt(int *a, int n){
    static int f[N], f0[N], invf[N], tp[N];
    f[0] = 1;
    for(int i=1; i<n; i<<=1){                  // f holds i correct terms
        copy(a, a + (i << 1), tp);
        copy(f, f + i, invf);
        inverse(invf, i<<1);                   // invf = 1/f mod x^{2i}
        copy(f, f + i, f0);                    // keep the current approximation
        NTT(tp, i << 2, w);
        NTT(f, i << 2, w);
        NTT(invf, i << 2, w);
        for(int j=0; j<(i<<2); j++) f[j] = mul(sub(mul(f[j], f[j]), tp[j]), mul(pinv[2], invf[j]));
        NTT(f, i << 2, invw);
        int invLen = qp(i << 2, p - 2);
        for(int j=0; j<(i<<1); j++) f[j] = sub(f0[j], mul(f[j], invLen));
        fill(f + (i << 1), f + (i << 2), 0);
        fill(invf + (i << 1), invf + (i << 2), 0);
    }
    copy(f, f+n, a);
    // Clear every scratch buffer this call touched. The loop relies on
    // f0[i..2i) and invf[i..4i) starting at zero. Before this fix, f0 and invf
    // were never cleared, so any later call picked up stale coefficients and
    // returned a wrong root.
    fill(f, f + n, 0);
    fill(f0, f0 + n, 0);
    fill(invf, invf + (n << 1), 0);
    fill(tp, tp + (n << 1), 0);
}
标签:多项式 i+1 static 乘法 const cpp length ini stat
原文地址:https://www.cnblogs.com/RiverHamster/p/polynomial-template.html