码迷,mamicode.com
首页 > 其他好文 > 详细

[xsy2781]快速插值

时间:2018-06-20 14:33:34      阅读:93      评论:0      收藏:0      [点我收藏+]

标签:==   ...   string   ++   a*   多点   --   limit   bst   

被神秘力量驱使去学这个东西...

由$(x_i,y_i)(0\leq i\leq n)$构造多项式$L(x)=\sum\limits_{i=0}^ny_i\prod\limits_{\substack{0\leq j\leq n\\i\ne j}}\frac{x-x_j}{x_i-x_j}$

观察$L(x_k)$里的$i,j$,如果$i=k$,那么右边的大pi每一项的分子分母都相同,值为$1$,否则右边的大pi在$j=k$时存在一个分子为$0$,这个多项式对$\forall0\leq i\leq n$满足$L(x_i)=y_i$

大pi中,分母是常数,所以我们要对每个$i$求出$\prod\limits_{\substack{0\leq j\leq n\\i\ne j}}(x_i-x_j)$,直接对$\frac{\mathrm d}{\mathrm dx}\prod\limits_{i=0}^n(x-x_i)$在所有$x_i$处多点求值即可

分子用分治求就可以了,合并的时候(左边的答案乘右边的$\prod(x-x_i)$)加上(左边的$\prod(x-x_i)$乘右边的答案)即可,递归到底层返回预处理好的$\dfrac{y_i}{\prod\limits_{\substack{0\leq j\leq n\\i\ne j}}(x_i-x_j)}$就可以了

总时间复杂度$O(n\log_2^2n)$,空间复杂度$O(n\log_2n)$,常数巨大,我写得太丑了,$n=50000$要跑$3$秒多

#include<stdio.h>
#include<string.h>
typedef long long ll;
const int mod=998244353,maxn=131072;
void swap(int&a,int&b){
	int c=a;
	a=b;
	b=c;
}
int max(int a,int b){return a>b?a:b;}
int mul(int a,int b){return a*(ll)b%mod;}
int ad(int a,int b){return(a+b)%mod;}
int de(int a,int b){return(a-b)%mod;}
int pow(int a,int b){
	int s=1;
	while(b){
		if(b&1)s=mul(s,a);
		a=mul(a,a);
		b>>=1;
	}
	return s;
}
int rev[maxn],N,iN;
void pre(int n){
	int i,k;
	for(N=1,k=0;N<n;N<<=1)k++;
	for(i=0;i<N;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
	iN=pow(N,mod-2);
}
void ntt(int*a,int on){
	int i,j,k,t,w,wn;
	for(i=0;i<N;i++){
		if(i<rev[i])swap(a[i],a[rev[i]]);
	}
	for(i=2;i<=N;i<<=1){
		wn=pow(3,(on==1)?(mod-1)/i:(mod-1-(mod-1)/i));
		for(j=0;j<N;j+=i){
			w=1;
			for(k=0;k<i>>1;k++){
				t=mul(w,a[i/2+j+k]);
				a[i/2+j+k]=de(a[j+k],t);
				a[j+k]=ad(a[j+k],t);
				w=mul(w,wn);
			}
		}
	}
	if(on==-1){
		for(i=0;i<N;i++)a[i]=mul(a[i],iN);
	}
}
int t0[maxn];
void getinv(int*a,int*b,int n){
	if(n==1){
		b[0]=pow(a[0],mod-2);
		return;
	}
	int i;
	getinv(a,b,n>>1);
	pre(n<<1);
	memset(t0,0,N<<2);
	memcpy(t0,a,n<<2);
	ntt(t0,1);
	ntt(b,1);
	for(i=0;i<N;i++)b[i]=mul(b[i],2-mul(t0[i],b[i]));
	ntt(b,-1);
	for(i=n;i<N;i++)b[i]=0;
}
int ta[maxn],tb[maxn],tc[maxn];
#define clr while(k!=0&&tc[k]==0)k--;			memcpy(c,tc,(k+1)<<2);
void add(int*a,int n,int*b,int m,int*c,int&k){
	k=max(n,m);
	memset(ta,0,(k+1)<<2);
	memcpy(ta,a,(n+1)<<2);
	memset(tb,0,(k+1)<<2);
	memcpy(tb,b,(m+1)<<2);
	for(int i=0;i<=k;i++)tc[i]=ad(ta[i],tb[i]);
	clr
}
void dec(int*a,int n,int*b,int m,int*c,int&k){
	k=max(n,m);
	memset(ta,0,(k+1)<<2);
	memcpy(ta,a,(n+1)<<2);
	memset(tb,0,(k+1)<<2);
	memcpy(tb,b,(m+1)<<2);
	for(int i=0;i<=k;i++)tc[i]=de(ta[i],tb[i]);
	clr
}
void dif(int*a,int n,int*c,int&k){
	k=n-1;
	for(int i=1;i<=n;i++)c[i-1]=mul(i,a[i]);
}
void reverse(int*a,int n){
	for(int i=0;i<=n>>1;i++)swap(a[i],a[n-i]);
}
void mul(int*a,int n,int*b,int m,int*c,int&k){
	int i;
	k=n+m;
	pre(k+1);
	memset(ta,0,N<<2);
	memcpy(ta,a,(n+1)<<2);
	memset(tb,0,N<<2);
	memcpy(tb,b,(m+1)<<2);
	ntt(ta,1);
	ntt(tb,1);
	for(i=0;i<N;i++)tc[i]=mul(ta[i],tb[i]);
	ntt(tc,-1);
	clr
}
int t1[maxn];
void div(int*a,int n,int*b,int m,int*c,int&k){
	if(n<m){
		k=0;
		return;
	}
	int i,rn;
	for(rn=1;rn<n-m+1;rn<<=1);
	memset(ta,0,rn<<3);
	memcpy(ta,a,(n+1)<<2);
	memset(tb,0,rn<<3);
	memcpy(tb,b,(m+1)<<2);
	reverse(tb,m);
	for(i=rn;i<=m;i++)tb[i]=0;
	memset(t1,0,rn<<3);
	getinv(tb,t1,rn);
	pre(rn<<1);
	reverse(ta,n);
	for(i=rn;i<=n;i++)ta[i]=0;
	ntt(ta,1);
	ntt(t1,1);
	for(i=0;i<N;i++)tc[i]=mul(ta[i],t1[i]);
	ntt(tc,-1);
	k=n-m;
	reverse(tc,k);
	clr
}
void modulo(int*a,int n,int*b,int m,int*c,int&k){
	if(n<m){
		k=n;
		memcpy(c,a,(n+1)<<2);
		return;
	}
	div(a,n,b,m,t1,k);
	mul(t1,k,b,m,t1,k);
	dec(a,n,t1,k,c,k);
}
int X[50010],Y[50010],*tr[200010],go;
void build(int l,int r,int x){
	if(l==r){
		tr[x]=new int[2];
		tr[x][1]=1;
		tr[x][0]=-X[l];
		return;
	}
	int mid=(l+r)>>1;
	build(l,mid,x<<1);
	build(mid+1,r,x<<1|1);
	tr[x]=new int[r-l+2];
	mul(tr[x<<1],mid-l+1,tr[x<<1|1],r-mid,tr[x],go);
}
void solve(int*f,int n,int l,int r,int x,int*ans){
	int mid=(l+r)>>1,*now;
	now=new int[r-l+1];
	modulo(f,n,tr[x],r-l+1,now,n);
	if(l==r){
		ans[l]=now[0];
		return;
	}
	solve(now,n,l,mid,x<<1,ans);
	solve(now,n,mid+1,r,x<<1|1,ans);
}
int di[50010],res[50010];
int*solve(int l,int r,int x){
	int mid=(l+r)>>1,*res,n1,n2,*t1,*t2;
	res=new int[r-l+1];
	if(l==r){
		res[0]=Y[l];
		return res;
	}
	t1=new int[r-mid+1];
	mul(solve(l,mid,x<<1),mid-l,tr[x<<1|1],r-mid,t1,n1);
	t2=new int[r-mid+1];
	mul(tr[x<<1],mid-l+1,solve(mid+1,r,x<<1|1),r-mid-1,t2,n2);
	add(t1,n1,t2,n2,res,go);
	return res;
}
int main(){
	int n,i,*ans;
	scanf("%d",&n);
	for(i=0;i<=n;i++)scanf("%d%d",X+i,Y+i);
	build(0,n,1);
	dif(tr[1],n+1,di,i);
	solve(di,n,0,n,1,res);
	for(i=0;i<=n;i++)Y[i]=mul(Y[i],pow(res[i],mod-2));
	ans=solve(0,n,1);
	for(i=0;i<=n;i++)printf("%d ",ad(ans[i],mod));
}

[xsy2781]快速插值

标签:==   ...   string   ++   a*   多点   --   limit   bst   

原文地址:https://www.cnblogs.com/jefflyy/p/9203230.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!