码迷,mamicode.com
首页 > 其他好文 > 详细

多项式板子(待补充)

时间:2020-01-21 18:16:11      阅读:86      评论:0      收藏:0      [点我收藏+]

标签:多项式   i+1   static   乘法   const   cpp   length   ini   stat   

已经实现的操作

  • 乘法NTT, mul
  • 求逆inverse
  • 积分integral,微分deriv
  • 对数ln
  • 指数exp
  • 快速幂power(常数项为\(1\)
  • 开根sqrt(常数项为\(1\)

说明

  • 所有操作都是原地操作,以后有空改(大量copy还是比较麻烦、耗时)
  • 常数还可以

Code

const int N = 262144, G = 3, iG = 332748118, p = 998244353, maxlen = N;
inline int add(int x, int y){return (x+y) % p;}
inline int sub(int x, int y){return (x-y+p) % p;}
inline int mul(int x, int y){return 1LL * x * y % p;}
int last, rev[N], w[N], invw[N], pinv[N];
inline int qp(int x, int y){
    int res = 1;
    for(; y; y>>=1, x = mul(x, x)) if(y & 1) res = mul(res, x);
    return res;
}

void pre(){
    w[1] = qp(G, (p-1) / maxlen);
    invw[1] = qp(iG, (p-1) / maxlen);
    w[0] = invw[0] = 1;
    for(int i=2; i<maxlen; i++) w[i] = mul(w[i-1], w[1]), invw[i] = mul(invw[i-1], invw[1]);
    pinv[1] = 1;
    for(int i=2; i<maxlen; i++) pinv[i] = mul(p - p / i, pinv[p % i]);
}

void init(int n){
    if(last == n) return ;
    last = n;
    int tp = n, lg = -1;
    while(tp != 1) tp >>= 1, ++lg;
    for(int i=0; i<n; i++)
        rev[i] = (rev[i>>1] >> 1) | ((i & 1) << lg);
}
void NTT(int *f, int n, int *w){
    init(n);
    for(int i=0; i<n; i++) if(rev[i] > i) swap(f[i], f[rev[i]]);
    for(int l=1; l<n; l<<=1){
        int step = maxlen / (l << 1);
        for(int i=0; i<n; i += (l << 1)){
            for(int j=i, Wj=0; j<i+l; j++, Wj += step){
                int x = f[j], y = mul(f[j+l], w[Wj]);
                f[j] = add(x, y); f[j+l] = sub(x, y);
            }
        }
    }
}
void mul(int *f, int *g, int n){ // will make g unavailable
    NTT(f, n, w); NTT(g, n, w);
    for(int i=0; i<n; i++) f[i] = mul(f[i], g[i]);
    NTT(f, n, invw);
    int in = qp(n, p-2);
    for(int i=0; i<n; i++) f[i] = mul(f[i], in);
}

void inverse(int *f, int n){ // n : pow2 length of f
    static int g[N], tp[N];
    g[0] = qp(f[0], p-2);
    for(int i=1; i<n; i<<=1){ // i : last length of g
        copy(f, f+(i<<1), tp);
        NTT(tp, i<<2, w); NTT(g, i<<2, w);
        for(int j=0; j<(i<<2); j++) g[j] = mul(g[j], sub(2, mul(tp[j], g[j])));
        NTT(g, i<<2, invw);
        int invLen = qp((i<<2), p-2);
        for(int j=0; j<(i<<2); j++) g[j] = mul(g[j], invLen);
        fill(g + (i << 1), g + (i << 2), 0);
    }
    copy(g, g+n, f);
    fill(g, g+n, 0);
    fill(tp, tp+(n<<1), 0);
}

void deriv(int *a, int n){
    for(int i=1; i<n; i++) a[i-1] = mul(a[i], i);
    a[n-1] = 0;
}
void integral(int *a, int n){
    for(int i=n-1; i>=0; i--) a[i] = mul(a[i-1], pinv[i]);
    a[0] = 0;
}
void ln(int *a, int n){
    static int tp[N];
    for(int i=0; i<n-1; i++) tp[i] = mul(a[i+1], i+1);
    tp[n-1] = 0;
    inverse(a, n);
    mul(a, tp, n<<1);
    fill(a+n, a+(n<<1), 0);
    integral(a, n);
    fill(tp, tp+(n<<1), 0);
}
void exp(int *a, int n){
    static int f[N], lnf[N], f0[N];
    f[0] = 1;
    for(int i=1; i<n; i<<=1){
        copy(f, f + i, lnf); copy(f, f + i, f0);
        ln(lnf, i << 1);
        for(int j=0; j<(i << 1); j++) f[j] = sub(a[j], lnf[j]);
        ++f[0];
        mul(f, f0, i << 2);
        fill(f + (i << 1), f + (i << 2), 0);
        fill(f0 + (i << 1), f0 + (i << 2), 0);
    }
    copy(f, f+n, a);
    fill(f, f + n, 0);
    fill(lnf, lnf + n, 0);
    fill(f0, f0 + n, 0);
}
void power(int *a, int n, int p){
    if(p == 1) return ;
    ln(a, n);
    for(int i=1; i<n; i++) a[i] = mul(a[i], p);
    exp(a, n);
}
void sqrt(int *a, int n){
    static int f[N], f0[N], invf[N], tp[N];
    f[0] = 1;
    for(int i=1; i<n; i<<=1){
        copy(a, a + (i << 1), tp);
        copy(f, f + i, invf);
        inverse(invf, i<<1);
        copy(f, f + i, f0);
        NTT(tp, i << 2, w);
        NTT(f, i << 2, w);
        NTT(invf, i << 2, w);
        for(int j=0; j<(i<<2); j++) f[j] = mul(sub(mul(f[j], f[j]), tp[j]), mul(pinv[2], invf[j]));
        NTT(f, i << 2, invw);
        int invLen = qp(i << 2, p - 2);
        for(int j=0; j<(i<<1); j++) f[j] = sub(f0[j], mul(f[j], invLen));
        fill(f + (i << 1), f + (i << 2), 0);
        fill(invf + (i << 1), invf + (i << 2), 0);
    }
    copy(f, f+n, a);
    fill(f, f + n, 0);
    fill(tp, tp + (n << 1), 0);
}

多项式板子(待补充)

标签:多项式   i+1   static   乘法   const   cpp   length   ini   stat   

原文地址:https://www.cnblogs.com/RiverHamster/p/polynomial-template.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!