标签:\n mem limits sum int end string ring long
$\newcommand{align}[1]{\begin{align*}#1\end{align*}}$题意:给出$f(x)=\prod\limits_{i=1}^n(a_ix+1)$和$g(x)=\prod\limits_{j=1}^m(b_jx+1)$的各项系数,求$h(x)=\prod\limits_{i=1}^n\prod\limits_{j=1}^m(a_ib_jx+1)$的前$k$项系数
乘积的形式不好处理,取个$\ln$就可以化为和的形式
$\align{\ln(a_ix+1)=\int\frac{a_i}{a_ix+1}\mathrm dx=\sum\limits_{k\geq1}\dfrac{(-1)^{k+1}}ka_i^kx^k}$
$\align{\ln(b_jx+1)=\sum\limits_{k\geq1}\dfrac{(-1)^{k+1}}kb_j^kx^k}$
$\align{\ln(a_ib_jx+1)=\sum\limits_{k\geq1}\dfrac{(-1)^{k+1}}ka_i^kb_j^kx^k}$
容易发现我们把前两个式子的系数做点积,再对$\forall k\geq1$把第$k$位乘上$(-1)^{-k+1}k$就得到了第三个式子的系数
所以我们把$\ln f(x)$和$\ln g(x)$的系数如此处理后就得到了$\ln h(x)$的系数,$\exp$回去即可
#include<stdio.h> #include<string.h> typedef long long ll; const int mod=998244353,maxn=262144; void swap(int&a,int&b){a^=b^=a^=b;} int mul(int a,int b){return a*(ll)b%mod;} int ad(int a,int b){return(a+b)%mod;} int de(int a,int b){return(a-b)%mod;} int pow(int a,int b){ int s=1; while(b){ if(b&1)s=mul(s,a); a=mul(a,a); b>>=1; } return s; } int rev[maxn],N,iN; void pre(int n){ int i,k; for(N=1,k=0;N<n;N<<=1)k++; for(i=0;i<N;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1)); iN=pow(N,mod-2); } void ntt(int*a,int on){ int i,j,k,t,w,wn; for(i=0;i<N;i++){ if(i<rev[i])swap(a[i],a[rev[i]]); } for(i=2;i<=N;i<<=1){ wn=pow(3,on==1?(mod-1)/i:(mod-1-(mod-1)/i)); for(j=0;j<N;j+=i){ w=1; for(k=0;k<i>>1;k++){ t=mul(w,a[i/2+j+k]); a[i/2+j+k]=de(a[j+k],t); a[j+k]=ad(a[j+k],t); w=mul(w,wn); } } } if(on==-1){ for(i=0;i<N;i++)a[i]=mul(a[i],iN); } } int t0[maxn]; void getinv(int*a,int*b,int n){ if(n==1){ b[0]=pow(a[0],mod-2); return; } int i; getinv(a,b,n>>1); pre(n<<1); memset(t0,0,N<<2); memcpy(t0,a,n<<2); ntt(t0,1); ntt(b,1); for(i=0;i<N;i++)b[i]=mul(b[i],2-mul(b[i],t0[i])); ntt(b,-1); for(i=n;i<N;i++)b[i]=0; } int t1[maxn],inv[maxn]; void getln(int*a,int*b,int n){ int i; memset(t1,0,n<<3); getinv(a,t1,n); for(i=1;i<n;i++)b[i-1]=mul(i,a[i]); ntt(b,1); ntt(t1,1); for(i=0;i<N;i++)b[i]=mul(b[i],t1[i]); ntt(b,-1); for(i=n-1;i>0;i--)b[i]=mul(b[i-1],inv[i]); b[0]=0; for(i=n;i<N;i++)b[i]=0; } int t2[maxn]; void exp(int*a,int*b,int n){ if(n==1){ b[0]=1; return; } int i; exp(a,b,n>>1); memset(t2,0,n<<3); getln(b,t2,n); for(i=0;i<n;i++)t2[i]=de(a[i],t2[i]); t2[0]++; ntt(b,1); ntt(t2,1); for(i=0;i<N;i++)b[i]=mul(b[i],t2[i]); ntt(b,-1); for(i=n;i<N;i++)b[i]=0; } int f[maxn],g[maxn],lf[maxn],lg[maxn],lh[maxn],h[maxn]; int main(){ int N,n,m,k,i; scanf("%d%d%d",&n,&m,&k); for(i=0;i<=n;i++)scanf("%d",f+i); for(i=0;i<=m;i++)scanf("%d",g+i); for(N=1;N<n+1||N<m+1||N<k+1;N<<=1); inv[1]=1; for(i=2;i<N;i++)inv[i]=-mul(mod/i,inv[mod%i]); getln(f,lf,N); getln(g,lg,N); for(i=0;i<N;i++)lh[i]=mul(mul(lf[i],lg[i]),i&1?i:-i); exp(lh,h,N); for(i=0;i<k;i++)printf("%d ",ad(h[i],mod)); }
标签:\n mem limits sum int end string ring long
原文地址:https://www.cnblogs.com/jefflyy/p/9205525.html