码迷,mamicode.com
首页 > 其他好文 > 详细

多项式模板

时间:2020-03-10 15:44:48      阅读:57      评论:0      收藏:0      [点我收藏+]

标签:long   name   div   max   std   deb   style   include   void   

#include<bits/stdc++.h>
using namespace std;
#define mo 998244353
#define N 500010
#define ll unsigned long long
#define int long long
#define pl vector<int>
int qp(int x,int y){
    int r=1;
    for(;y;y>>=1,x=1ll*x*x%mo)
        if(y&1)r=1ll*r*x%mo;
    return r;
}
int n,m,rev[N],v,le,w[N],p[N],ans[N];
void deb(pl x){
    for(int i:x)cout<<i<< ;
    puts("");
}
void init(int n){
    v=1;
    le=0;
    while(v<n)le++,v*=2;
    for(signed i=0;i<v;i++)
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(le-1));
    int g=qp(3,(mo-1)/v);
    w[v/2]=1;
    for(int i=v/2+1;i<v;i++)
        w[i]=1ull*w[i-1]*g%mo;
    for(signed i=v/2-1;~i;i--)
        w[i]=w[i*2];
}
void fft(int v,pl &a,int t){
    static unsigned long long b[N];
    int s=le-__builtin_ctz(v);
       for(int i=0;i<v;i++)
           b[rev[i]>>s]=a[i];
    int c=0;
    w[0]=1;
    for(signed i=1;i<v;i*=2,c++)
        for(signed r=i*2,j=0;j<v;j+=r)
            for(signed k=0;k<i;k++){
                   int tx=b[j+i+k]*w[k+i]%mo;
                b[j+i+k]=b[j+k]+mo-tx;
                b[j+k]+=tx;
            }
    for(int i=0;i<v;i++)
        a[i]=b[i]%mo;
    if(t==0)return;
    int iv=qp(v,mo-2);
    for(signed i=0;i<v;i++)
        a[i]=1ull*a[i]*iv%mo;
    a.resize(v);
    reverse(a.begin()+1,a.end());
}
pl operator *(pl x,pl y){
    int s=x.size()+y.size()-1;
    if(x.size()<=20||y.size()<=20){
        pl r;
        r.resize(s);
        for(int i=0;i<x.size();i++)
            for(int j=0;j<y.size();j++)
                r[i+j]=(r[i+j]+x[i]*y[j])%mo;
        return r;
    }
    init(s);
    x.resize(v);
    y.resize(v);
    fft(v,x,0);
    fft(v,y,0);
    //deb(x);
    //deb(y);
    for(int i=0;i<v;i++)
        x[i]=x[i]*y[i]%mo;
    fft(v,x,1);
    x.resize(s);
    return x;
}
void inv(int n,pl &b,pl &a){
    if(n==1){
        b[0]=qp(a[0],mo-2);
        return;
    }
    inv((n+1)/2,b,a);
    static pl c;
    init(n*2);
    c.resize(v);
    b.resize(v);
    for(int i=0;i<n;i++)
        c[i]=a[i];
    fft(v,c,0);
    //deb(c);
    fft(v,b,0);
    //deb(b);
    for(int i=0;i<v;i++)
        b[i]=1ll*(2ll-1ll*c[i]*b[i]%mo+mo)%mo*b[i]%mo;
    //deb(b);
    fft(v,b,1);  
       b.resize(n);
       //deb(b);
}
void ad(pl &x,pl y,int l){
    x.resize(max((int)x.size(),(int)y.size()+l));
    for(int i=0;i<y.size();i++)
        x[i+l]=(x[i+l]+y[i])%mo;
}
pl operator +(pl x,pl y){
    ad(x,y,0);
    return x;
}
pl iv(pl x){
    pl y;
    int n=x.size();
    y.resize(n);
    inv(n,y,x);
    y.resize(n);
    return y;
}
pl operator /(pl a,pl y){
    int n=a.size()-1,m=y.size()-1;
    pl x,b,t;
    x.resize(n+1);
    b.resize(m+1);
    for(int i=0;i<=n;i++)
        x[n-i]=a[i];
    for(int i=0;i<=m;i++)
        b[m-i]=y[i];
    for(int i=n-m+2;i<=m;i++)
        b[i]=0;
    b.resize(n-m+1);
    t=iv(b);
    //deb(t);
    //deb(x);
    //deb(t);
    x=x*t;
    //deb(x);
    x.resize(n-m+1);
    reverse(x.begin(),x.end());
    return x;
}
pl operator -(pl x,pl y){
    int s=max(x.size(),y.size());
    x.resize(s);
    y.resize(s);
    for(int i=0;i<s;i++)
        x[i]=(x[i]-y[i]+mo)%mo;
    return x;
}
pl operator %(pl x,pl y){
    int n=(int)x.size()-1,m=(int)y.size()-1;
    if(x.size()<y.size())return x;
    if(!m){
        pl a;
        a.resize(1);
        return a;
    }
    x=x-(x/y)*y;
    x.resize(m);
    return x;
}
pl qd(pl x){
    pl y;
    int n=x.size();
    y.resize(n-1);
    //deb(x);
    for(int i=0;i<n-1;i++)
        y[i]=x[i+1]*(i+1)%mo;
    //deb(y);
    return y;
}
pl jf(pl x){
    int n=x.size();
    pl y;
    y.resize(n+1);
    for(int i=1;i<=n;i++)
        y[i]=x[i-1]*qp(i,mo-2)%mo;
    return y;
}
pl ln(pl x){
    int n=x.size();
    pl y=qd(x),z=iv(x);
    y=y*z;
    y=jf(y);
    return y;
}
inline char nc(){
    static char buf[500000],*p1=buf,*p2=buf;
    return p1==p2&&(p2=(p1=buf)+fread(buf,1,500000,stdin),p1==p2)?EOF:*p1++;
}
inline int rd(){
    char ch=nc();int sum=0;
    while(!(ch>=0&&ch<=9))ch=nc();
    while(ch>=0&&ch<=9)sum=sum*10+ch-48,ch=nc();
    return sum;
}
char bf[100];
void wr(int x){
    if(!x){
        putchar(0);
        putchar( );
        return;
    }
    int ct=0;
    while(x){
        bf[++ct]=x%10;
        x/=10;
    }
    for(int i=ct;i;i--)
        putchar(bf[i]+0);
    putchar( );
}
namespace qz{
    ll b[N],ans[N];
    pl t;
    void fz(int o,int l,int r,pl &p,pl *a){
        if(l==r){
            a[o].resize(2);
            a[o][0]=(mo-p[l])%mo;
            a[o][1]=1;
            return;
        }
        int md=(l+r)/2;
        fz(o*2,l,md,p,a);
        fz(o*2+1,md+1,r,p,a);
        a[o]=a[o*2]*a[o*2+1];
        //deb(a[o]);
    }
    void ga(int o,int l,int r,pl &ans,pl *a,pl *c){
        if(l==r){
            ans[l]=c[o][0];
            return;
        }
        int md=(l+r)/2;
        c[o*2]=c[o]%a[o*2];
        c[o*2+1]=c[o]%a[o*2+1];
        ga(o*2,l,md,ans,a,c);
        ga(o*2+1,md+1,r,ans,a,c);
    }
    void gt(pl t,pl &ans){
        int n=ans.size();
        static pl a[N],b[N];
        fz(1,0,n-1,ans,a);
        if(n>=m)b[1]=t%a[1];
        ga(1,0,n-1,ans,a,b);
    }
    void d2(int o,int l,int r,pl &y,pl *a,pl *b){
        if(l==r){
            b[o].resize(1);
            b[o][0]=y[l];
            return;
        }
        int md=(l+r)/2;
        d2(o*2,l,md,y,a,b);
        d2(o*2+1,md+1,r,y,a,b);
        b[o]=b[o*2]*a[o*2+1]+b[o*2+1]*a[o*2];
    }
};
pl cz(int n,pl &x,pl &y){
    static pl a[N],b[N];
    qz::fz(1,0,n-1,x,a);
    qz::gt(qd(a[1]),x);
    //deb(x);
    for(int i=0;i<n;i++)
        y[i]=y[i]*qp(x[i],mo-2)%mo;
    //deb(y);
    qz::d2(1,0,n-1,y,a,b);
    return b[1];
}
signed main(){
    
}

 

多项式模板

标签:long   name   div   max   std   deb   style   include   void   

原文地址:https://www.cnblogs.com/cszmc2004/p/12455988.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!