标签:== ... string ++ a* 多点 -- limit bst
被神秘力量驱使去学这个东西...
由$(x_i,y_i)(0\leq i\leq n)$构造多项式$L(x)=\sum\limits_{i=0}^ny_i\prod\limits_{\substack{0\leq j\leq n\\i\ne j}}\frac{x-x_j}{x_i-x_j}$
观察$L(x_k)$中的各项:若$i=k$,右边连乘式每一项的分子与分母相同,乘积为$1$,该项贡献$y_k$;若$i\ne k$,连乘式在$j=k$时有一项分子为$0$,乘积为$0$,该项贡献$0$。因此这个多项式对$\forall0\leq i\leq n$都满足$L(x_i)=y_i$
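连乘式中的分母是常数,所以我们要对每个$i$求出$\prod\limits_{\substack{0\leq j\leq n\\i\ne j}}(x_i-x_j)$。记$g(x)=\prod\limits_{j=0}^n(x-x_j)$,由乘积的求导法则$g'(x)=\sum\limits_{k=0}^n\prod\limits_{j\ne k}(x-x_j)$,代入$x=x_i$时$k\ne i$的各项都含因子$x_i-x_i=0$,只剩下$g'(x_i)=\prod\limits_{\substack{0\leq j\leq n\\i\ne j}}(x_i-x_j)$。所以直接对$g'(x)=\frac{\mathrm d}{\mathrm dx}\prod\limits_{i=0}^n(x-x_i)$在所有$x_i$处多点求值即可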
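分子部分用分治求:记$c_i=\dfrac{y_i}{\prod\limits_{\substack{0\leq j\leq n\\i\ne j}}(x_i-x_j)}$,对区间$[l,r]$要求的是$\sum\limits_{i=l}^rc_i\prod\limits_{\substack{l\leq j\leq r\\j\ne i}}(x-x_j)$。合并的时候用(左边的答案乘右边的$\prod(x-x_i)$)加上(左边的$\prod(x-x_i)$乘右边的答案)即可,递归到底层时直接返回预处理好的$c_i$就可以了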
总时间复杂度$O(n\log_2^2n)$,空间复杂度$O(n\log_2n)$,常数巨大,我写得太丑了,$n=50000$要跑$3$秒多
#include <stdio.h>
#include <string.h>

// Fast Lagrange interpolation over Z/998244353 in O(n log^2 n).
// Input : n, then n+1 pairs "x y" (the x values must be distinct mod p,
//         otherwise some weight prod_{j!=i}(x_i - x_j) is 0 and has no
//         inverse).
// Output: the n+1 coefficients of the interpolating polynomial, constant
//         term first, each followed by a space.
//
// Residues are kept in (-mod, mod): C++ '%' keeps the sign of the dividend,
// so intermediate values may be negative. They are normalised to [0, mod)
// only when printed.

typedef long long ll;
const int mod = 998244353, maxn = 131072;

void swap(int &a, int &b) {
    int c = a;
    a = b;
    b = c;
}
int max(int a, int b) { return a > b ? a : b; }
int mul(int a, int b) { return a * (ll)b % mod; }
int ad(int a, int b) { return (a + b) % mod; }
int de(int a, int b) { return (a - b) % mod; }

// a^b mod p by binary exponentiation.
int pow(int a, int b) {
    int s = 1;
    while (b) {
        if (b & 1) s = mul(s, a);
        a = mul(a, a);
        b >>= 1;
    }
    return s;
}

// Transform length, its inverse and the bit-reversal table, shared by every
// ntt() call and (re)initialised by pre().
int rev[maxn], N, iN;

// Sets N to the smallest power of two >= n, fills rev[0..N) and iN = N^{-1}.
void pre(int n) {
    int i, k;
    for (N = 1, k = 0; N < n; N <<= 1) k++;
    for (i = 0; i < N; i++)
        // k == 0 means N == 1; shifting by k-1 == -1 would be undefined.
        rev[i] = k ? ((rev[i >> 1] >> 1) | ((i & 1) << (k - 1))) : 0;
    iN = pow(N, mod - 2);
}

// In-place iterative NTT of length N with primitive root 3.
// on == 1: forward transform; on == -1: inverse transform (scaled by 1/N).
// pre() must have been called for the desired length.
void ntt(int *a, int on) {
    int i, j, k, t, w, wn;
    for (i = 0; i < N; i++)
        if (i < rev[i]) swap(a[i], a[rev[i]]);
    for (i = 2; i <= N; i <<= 1) {
        wn = pow(3, (on == 1) ? (mod - 1) / i : (mod - 1 - (mod - 1) / i));
        for (j = 0; j < N; j += i) {
            w = 1;
            for (k = 0; k < i >> 1; k++) {
                t = mul(w, a[i / 2 + j + k]);
                a[i / 2 + j + k] = de(a[j + k], t);
                a[j + k] = ad(a[j + k], t);
                w = mul(w, wn);
            }
        }
    }
    if (on == -1)
        for (i = 0; i < N; i++) a[i] = mul(a[i], iN);
}

int t0[maxn];

// b = a^{-1} mod x^n by Newton iteration (b <- b * (2 - a*b)).
// n must be a power of two and a[0] must be invertible. b is expected to be
// zero on [0, 2n) on entry (div() arranges this).
void getinv(int *a, int *b, int n) {
    if (n == 1) {
        b[0] = pow(a[0], mod - 2);
        return;
    }
    int i;
    getinv(a, b, n >> 1);
    pre(n << 1);
    memset(t0, 0, N << 2);
    memcpy(t0, a, n << 2);
    ntt(t0, 1);
    ntt(b, 1);
    for (i = 0; i < N; i++) b[i] = mul(b[i], 2 - mul(t0[i], b[i]));
    ntt(b, -1);
    for (i = n; i < N; i++) b[i] = 0;  // truncate to mod x^n
}

// Scratch buffers shared by the polynomial routines below.
int ta[maxn], tb[maxn], tc[maxn];

// Drops leading zero coefficients of tc[0..k] (keeping at least one) and
// copies the result into c. Entries of c above the final k are not written.
void store(int *c, int &k) {
    while (k != 0 && tc[k] == 0) k--;
    memcpy(c, tc, (k + 1) << 2);
}

// c = a + b, with deg a = n and deg b = m; k receives deg c.
void add(int *a, int n, int *b, int m, int *c, int &k) {
    k = max(n, m);
    memset(ta, 0, (k + 1) << 2);
    memcpy(ta, a, (n + 1) << 2);
    memset(tb, 0, (k + 1) << 2);
    memcpy(tb, b, (m + 1) << 2);
    for (int i = 0; i <= k; i++) tc[i] = ad(ta[i], tb[i]);
    store(c, k);
}

// c = a - b, with deg a = n and deg b = m; k receives deg c.
void dec(int *a, int n, int *b, int m, int *c, int &k) {
    k = max(n, m);
    memset(ta, 0, (k + 1) << 2);
    memcpy(ta, a, (n + 1) << 2);
    memset(tb, 0, (k + 1) << 2);
    memcpy(tb, b, (m + 1) << 2);
    for (int i = 0; i <= k; i++) tc[i] = de(ta[i], tb[i]);
    store(c, k);
}

// c = a' for deg a = n; k receives n-1.
void dif(int *a, int n, int *c, int &k) {
    k = n - 1;
    for (int i = 1; i <= n; i++) c[i - 1] = mul(i, a[i]);
}

// Reverses a[0..n] in place.
void reverse(int *a, int n) {
    for (int i = 0; i <= n >> 1; i++) swap(a[i], a[n - i]);
}

// c = a * b via NTT, with deg a = n and deg b = m; k receives deg c.
// c may alias a or b (inputs are copied into scratch buffers first).
void mul(int *a, int n, int *b, int m, int *c, int &k) {
    int i;
    k = n + m;
    pre(k + 1);
    memset(ta, 0, N << 2);
    memcpy(ta, a, (n + 1) << 2);
    memset(tb, 0, N << 2);
    memcpy(tb, b, (m + 1) << 2);
    ntt(ta, 1);
    ntt(tb, 1);
    for (i = 0; i < N; i++) tc[i] = mul(ta[i], tb[i]);
    ntt(tc, -1);
    store(c, k);
}

int t1[maxn];

// Quotient of a (degree n) by b (degree m), using
// rev(q) = rev(a) * rev(b)^{-1} mod x^{n-m+1}. The quotient goes to c and its
// degree to k. If n < m, only k is set (to 0) and c is left untouched.
void div(int *a, int n, int *b, int m, int *c, int &k) {
    if (n < m) {
        k = 0;
        return;
    }
    int i, rn;
    for (rn = 1; rn < n - m + 1; rn <<= 1) {
    }
    memset(ta, 0, rn << 3);
    memcpy(ta, a, (n + 1) << 2);
    memset(tb, 0, rn << 3);
    memcpy(tb, b, (m + 1) << 2);
    reverse(tb, m);
    for (i = rn; i <= m; i++) tb[i] = 0;
    memset(t1, 0, rn << 3);
    getinv(tb, t1, rn);
    pre(rn << 1);
    reverse(ta, n);
    for (i = rn; i <= n; i++) ta[i] = 0;
    ntt(ta, 1);
    ntt(t1, 1);
    for (i = 0; i < N; i++) tc[i] = mul(ta[i], t1[i]);
    ntt(tc, -1);
    k = n - m;
    reverse(tc, k);
    store(c, k);
}

// c = a mod b, with deg a = n and deg b = m; k receives deg c.
void modulo(int *a, int n, int *b, int m, int *c, int &k) {
    if (n < m) {
        k = n;
        memcpy(c, a, (n + 1) << 2);
        return;
    }
    div(a, n, b, m, t1, k);
    mul(t1, k, b, m, t1, k);
    dec(a, n, t1, k, c, k);
}

int X[50010], Y[50010], *tr[200010], go;

// Segment tree over point indices l..r: tr[x] holds prod_{i=l}^{r} (x - X[i])
// with r-l+2 coefficients, constant term first, leading coefficient 1.
void build(int l, int r, int x) {
    if (l == r) {
        tr[x] = new int[2];
        tr[x][1] = 1;
        tr[x][0] = -X[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, x << 1);
    build(mid + 1, r, x << 1 | 1);
    tr[x] = new int[r - l + 2];
    mul(tr[x << 1], mid - l + 1, tr[x << 1 | 1], r - mid, tr[x], go);
}

// Multipoint evaluation: f has degree n; stores f(X[i]) in ans[i] for
// l <= i <= r by reducing f modulo tr[x] on the way down the tree.
void solve(int *f, int n, int l, int r, int x, int *ans) {
    int mid = (l + r) >> 1, *now;
    now = new int[r - l + 1];  // remainder mod tr[x] has degree <= r-l
    modulo(f, n, tr[x], r - l + 1, now, n);
    if (l == r) {
        ans[l] = now[0];
        delete[] now;
        return;
    }
    solve(now, n, l, mid, x << 1, ans);
    solve(now, n, mid + 1, r, x << 1 | 1, ans);
    delete[] now;
}

int di[50010], res[50010];

// Returns sum_{i=l}^{r} Y[i] * prod_{j in [l,r], j != i} (x - X[j]) as a
// new[]-allocated array of exactly r-l+1 coefficients (constant term first).
// The caller owns the result and must delete[] it.
int *solve(int l, int r, int x) {
    int mid = (l + r) >> 1, n1, n2, *out, *cross1, *cross2, *lhs, *rhs;
    // Value-initialised: add() only writes up to the trimmed degree, while
    // callers read all r-l+1 coefficients, so the top ones must be 0.
    out = new int[r - l + 1]();
    if (l == r) {
        out[0] = Y[l];
        return out;
    }
    // Both cross products have degree r-l, i.e. up to r-l+1 coefficients.
    cross1 = new int[r - l + 1];
    lhs = solve(l, mid, x << 1);
    mul(lhs, mid - l, tr[x << 1 | 1], r - mid, cross1, n1);
    delete[] lhs;
    cross2 = new int[r - l + 1];
    rhs = solve(mid + 1, r, x << 1 | 1);
    mul(tr[x << 1], mid - l + 1, rhs, r - mid - 1, cross2, n2);
    delete[] rhs;
    add(cross1, n1, cross2, n2, out, go);
    delete[] cross1;
    delete[] cross2;
    return out;
}

int main() {
    int n, i, *ans;
    // n+1 points must fit in X[] / Y[]; a negative n would recurse forever.
    if (scanf("%d", &n) != 1 || n < 0 ||
        n >= (int)(sizeof X / sizeof X[0]))
        return 1;
    for (i = 0; i <= n; i++) {
        if (scanf("%d%d", X + i, Y + i) != 2) return 1;
        // Keep inputs inside (-mod, mod) so ad()/de() cannot overflow int.
        X[i] %= mod;
        Y[i] %= mod;
    }
    build(0, n, 1);
    // g(x) = prod (x - X[j]); g'(X[i]) = prod_{j != i} (X[i] - X[j]).
    dif(tr[1], n + 1, di, i);
    solve(di, n, 0, n, 1, res);
    for (i = 0; i <= n; i++) Y[i] = mul(Y[i], pow(res[i], mod - 2));
    ans = solve(0, n, 1);
    for (i = 0; i <= n; i++) printf("%d ", ad(ans[i], mod));
    delete[] ans;
    return 0;
}
标签:== ... string ++ a* 多点 -- limit bst
原文地址:https://www.cnblogs.com/jefflyy/p/9203230.html