标签:swap lse 要求 get ini else swa wap 复杂
题面
题意: 本题包含3个Task:
其中n为点数,\(n,bas\)都是给定的,\(n\leq 10^5\),答案对\(998244353\)取模。
直接模拟即可。
#include<bits/stdc++.h>
using namespace std;
#define N 200007
#define ll long long
const ll mod=998244353;
struct str
{
int x,y;
};
bool operator <(str a,str b)
{
return a.x<b.x||a.x==b.x&&a.y<b.y;
}
set<str> S;
ll p2(ll x){return x*x%mod;}
ll pw(ll x,ll p)
{
return p?p2(pw(x,p/2))*(p&1?x:1)%mod:1;
}
int main()
{
int x,y,n,op;
ll p;
scanf("%d%lld%d",&n,&p,&op);
for(int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
if(x>y)swap(x,y);
S.insert({x,y});
}
int cnt=0;
for(int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
if(x>y)swap(x,y);
if(S.find({x,y})!=S.end())cnt++;
}
ll ans=pw(p,n-cnt);
printf("%lld\n",ans);
return 0;
}
要求\(ans=\sum_{T}bas^{n-|S\bigcap T|}\)
我们设\(ans'=\sum_{T}bas^{-|S\bigcap T|}\),那么\(ans=bas^n*ans'\)
然后我们考虑怎么处理\(|S\bigcap T|\),可以考虑先枚举它,然后再枚举S和T。
那么我们就需要算出一个\(F(E)\),表示交集恰好为E的\(<S,T>\)的数量。
但是这东西不好算,那么我们再设一个\(G(E)\),表示交集至少为E的\(<S,T>\)的数量。
那么根据容斥原理我们有:
\[F(E)=\sum_{E\subseteq V}(-1)^{|V|-|E|}G(V)\]
然后我们有
\[ans'=\sum_{E\subseteq S}F(E)bas^{-|E|}\]
\[=\sum_{E\subseteq S}\sum_{E\subseteq V\subseteq S}(-1)^{|V|-|E|}G(V)bas^{-|E|}\]
\[=\sum_{V\subseteq S}(-1)^{|V|}G(V)\sum_{E\subseteq V}(-bas)^{-|E|}\]
\[=\sum_{V\subseteq S}(-1)^{|V|}G(V)\sum_{i=0}^{|V|}C_{|V|}^i(-bas^{-1})^{i}\]
\[=\sum_{V\subseteq S}(-1)^{|V|}G(V)(1-bas^{-1})^{|V|}\]
我们设\(P=(bas^{-1}-1)\),那么\(ans'=\sum_{V\subseteq S}G(V)P^{|V|}\)
然后我们考虑\(G(V)\)怎么算
至少包含\(V\),那么可以看成先将\(V\)中的边全部连上,形成\(n-|V|\)个连通块,然后把这些连通块连起来的方案数。
我们设将\(V\)中的边连上后形成的连通块大小分别为\(a_1,a_2...,a_{n-|V|}\),那么我们有结论:
\[G(V)=n^{n-|V|-2}\prod _{i=1}^{n-|V|}a_i\]
证明:
我们考虑如果我们将所有连通块连成一棵树,那么对应的边有多少种方案。
一条连接连通块i和连通块j的边的方案数为\(a_i*a_j\)
那么将整棵树连出来的方案数为\(\prod_{i=1}^{n-|V|}a_i^{deg_i}\),其中\(deg_i\)为连通块i在这棵树上的度数。
我们惊奇的发现\(deg_i\)就是这棵树对应的purfer序列中i出现的次数加1
那么我们就枚举所有的purfer序列,设其为P,数字i在P中的出现次数为\(times_i\)
\[G(V)=\sum_{P}\prod_{i=1}^{n-|V|}a_i^{times_i+1}\]
\[=(\prod_{i=1}^{n-|V|}a_i)\sum_{P}\prod_{i=1}^{n-|V|}a_i^{times_i}\]
发现后面这部分它就等于\((a_1+a_2...+a_{n-|V|})^{n-|V|-2}\),即\(n^{n-|V|-2}\)
于是就有
\[G(V)=n^{n-|V|-2}\prod _{i=1}^{n-|V|}a_i\]
证毕。
?
回到我们的问题,有
\[ans'=\sum_{V\subseteq S}n^{n-|V|-2}P^{|V|}\prod _{i=1}^{n-|V|}a_i\]
为了后续操作的方便,我们将\(n,P\)放到连乘里面
\[ans'=n^{-2}P^{n}\sum_{V\subseteq S}\prod _{i=1}^{n-|V|}a_inP^{-1}\]
然后我们设\(K=nP^{-1},ans''=\sum_{V\subseteq S}\prod _{i=1}^{n-|V|}a_iK\),那么\(ans'=ans''*n^{-2}P^n\)
我们考虑\(ans''\)怎么算
容易想到用树上dp解决,设\(dp[v][i]\)为在v子树内,与v相连的连通块大小为i的所有方案的贡献和(注意这里的贡献不包含与v相连的那个连通块)
另外我们规定\(dp[v][0]\)即为v子树内的答案。
那么有转移方程
\[dp[v][i]=\sum_{j=1}^idp[v][j]*dp[u][i-j]\]
\[dp[v][0]=\sum_{i=1}^ndp[v][i]*i*K\]
初始状态为\(dp[v][1]=1\)
最后的答案就是\(dp[1][1]\)
但是这样复杂度是\(O(n^2)\)的,考虑怎么优化
我们其实没有必要将每一个\(dp[v][i]\)都算出来,我们只需要知道它们整体的一个值就可以。
于是我们设\(f[v]=\sum_{i=0}^{n}dp[v][i]\),\(g[v]=dp[v][0]=\sum_{i=1}^ndp[v][i]*i*K\)
那么在转移过程中有
\[g[v]=\sum_{i=1}^n\sum_{j=1}^idp[v][j]*dp[u][i-j]*i*K\]
\[=\sum_{i=1}^n\sum_{j=0}^n(i+j)K*dp[v][i]*dp[u][j]\]
\[=\sum_{i=1}^n\sum_{j=0}^niK*dp[v][i]*dp[u][j]+\sum_{i=1}^n\sum_{j=0}^njK*dp[v][i]*dp[u][j]\]
\[=g[v]*f[u]+g[u]*f[v]\]
?
\[f[v]=\sum_{i=1}^n\sum_{j=1}^idp[v][j]*dp[u][i-j]\]
\[=(\sum_{i=1}^{n}dp[v][i])(\sum_{j=0}^{n}dp[u][j])\]
\[=f[v]*f[u]\]
最后还要\(f[v]+=g[v]\)
初始状态为\(f[v]=1,g[v]=K\)
这样就可以\(O(n)\)转移了
最后的答案为\(bas^nn^{-2}P^{n}g[1]\)
#include<bits/stdc++.h>
using namespace std;
#define N 200007
#define M 400007
#define ll long long
const ll mod=998244353;
int f[N],g[N],n,P,K,sz[N];
int hd[N],pre[M],to[M],num;
void adde(int x,int y)
{
num++;pre[num]=hd[x];hd[x]=num;to[num]=y;
}
void dfs(int v,int fa)
{
f[v]=1,g[v]=K;
for(int i=hd[v];i;i=pre[i])
{
int u=to[i];
if(u==fa)continue;
dfs(u,v);
g[v]=(1ll*g[v]*f[u]+1ll*g[u]*f[v])%mod;
f[v]=1ll*f[v]*f[u]%mod;
}
f[v]=(f[v]+g[v])%mod;
}
ll p2(ll x){return x*x%mod;}
ll pw(ll x,ll p)
{
return p?p2(pw(x,p/2))*(p&1?x:1)%mod:1;
}
int main()
{
//freopen("data.in","r",stdin);
int x,y,op,p;
scanf("%d%d%d",&n,&p,&op);
for(int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
adde(x,y),adde(y,x);
}
if(p==1)
{
printf("%lld\n",pw(n,n-2));
return 0;
}
P=pw(pw(p,mod-2)-1,mod-2),K=1ll*n*P%mod;
dfs(1,0);
ll ans = g[1] * p2(pw(n,mod-2)) % mod * pw(pw(P,mod-2),n) % mod;
printf("%lld\n",ans*pw(p,n)%mod);
return 0;
}
类似Task1,我们同样可以得到
\[ans'=\sum_{V}G(V)P^{|V|}\]
其中\(G(V)=(n^{n-|V|-2}\prod_{i=1}^{n-|V|}a_i)^2\),V为任意一棵n个点的森林的边集
带入化简得
\[ans'=n^{-4}P^n\sum_{V}\prod_{i=1}^{n-|V|}(a_i^2n^2P^{-1})\]
设\(K=n^2P^{-1},ans''=\sum_{V}\prod_{i=1}^{n-|V|}(a_i^2K)\)
则\(ans'=n^{-4}P^nans''\)
考虑\(ans''\),它相当于将n个点划分为若干连通块,每一个连通块内部构成一棵树,则每一种划分方案的贡献为\(\prod_{i=1}^{n-|V|}a_i^{a_i-2}(a_i^2K)\),其中每一个大小为i的连通块的贡献为\(i^{i-2}(i^2K)=i^iK\)
根据生成函数的一些性质,如果我们设\(G(x)=\sum_{i=1}^\infty \frac{i^iK}{i!}x^i\),那么\(e^{G(x)}\)就是\(ans''\)的指数型生成函数,取它的第\(x^n\)项再乘以\(n!\)就可以得到\(ans''\)。
多项式\(exp\)即可。
最后的答案为\(bas^nn^{-4}P^nans''\)
?
完整的代码(注意要特判\(bas=1\)的情况):
#include<bits/stdc++.h>
using namespace std;
#define N 600007
#define ll long long
const ll mod=998244353;
const int lim=2e5;
int tp;
ll n,bas;
ll p2(ll x){return x*x%mod;}
ll pw(ll x,ll p)
{
return p?p2(pw(x,p/2))*(p&1?x:1)%mod:1;
}
namespace tp0
{
struct edge
{
int x,y;
};
bool operator <(edge a,edge b)
{
return a.x<b.x||a.x==b.x&&a.y<b.y;
}
set<edge> S;
void work()
{
if(bas==1)
{
printf("%d\n",1);
return ;
}
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
if(x>y)swap(x,y);
S.insert({x,y});
}
int cnt=0;
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
if(x>y)swap(x,y);
if(S.find({x,y})!=S.end())cnt++;
}
printf("%lld\n",pw(bas,n-cnt));
}
}
namespace tp1
{
int hd[N],pre[N],to[N],num;
ll f[N],g[N],K,P,Q;
void adde(int x,int y)
{
num++;pre[num]=hd[x];hd[x]=num;to[num]=y;
}
void dfs(int v,int fa)
{
f[v]=1,g[v]=K;
for(int i=hd[v];i;i=pre[i])
{
int u=to[i];
if(u==fa)continue;
dfs(u,v);
g[v]=(g[v]*f[u]+f[v]*g[u])%mod;
f[v]=f[v]*f[u]%mod;
}
f[v]=(f[v]+g[v])%mod;
}
void work()
{
if(bas==1)
{
printf("%lld\n",pw(n,n-2));
return ;
}
P=(pw(bas,mod-2)-1+mod)%mod;
K=n*pw(P,mod-2)%mod;
Q=pw(n,2*(mod-2))*pw(P,n)%mod;
int x,y;
for(int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
adde(x,y),adde(y,x);
}
dfs(1,0);
ll ans=g[1];
ans=ans*Q%mod;
ans=ans*pw(bas,n)%mod;
printf("%lld\n",ans);
}
}
namespace tp2
{
ll inv[N],fac[N],ifac[N];
int rev[N],len;
void getlen(int n)
{
for(len=1;len<=n;len<<=1);
for(int i=0;i<len;i++)
rev[i]=rev[i>>1]>>1|(i&1?len>>1:0);
}
void NTT(ll *a,int op)
{
for(int i=0;i<len;i++)
if(rev[i]<i)swap(a[rev[i]],a[i]);
for(int i=1;i<len;i<<=1)
{
ll nw=pw(3,(mod-1)/(i<<1));
for(int j=0;j<len;j+=i<<1)
{
ll w=1;
for(int k=j;k<j+i;k++)
{
ll x=a[k],y=a[k+i]*w%mod;
a[k]=(x+y)%mod,a[k+i]=(x-y+mod)%mod;
w=w*nw%mod;
}
}
}
if(op<0)
{
reverse(a+1,a+len);
ll Inv=pw(len,mod-2);
for(int i=0;i<len;i++)
a[i]=a[i]*Inv%mod;
}
}
void copy(ll *a,ll *b,int n=len)
{
for(int i=0;i<n;i++)a[i]=b[i];
for(int i=n;i<len;i++)a[i]=0;
}
ll mul_c[N],mul_d[N];
void mul(ll *t,ll *a,ll *b)
{
ll *c=mul_c,*d=mul_d;
copy(c,a),copy(d,b);
NTT(c,1),NTT(d,1);
for(int i=0;i<len;i++)c[i]=c[i]*d[i]%mod;
NTT(c,-1);
copy(t,c);
}
ll inv_c[N];
void getinv(int p,ll *a,ll *b)
{
if(p==1)return a[0]=pw(b[0],mod-2),(void)1;
getinv((p+1)/2,a,b);
getlen(2*p);
ll *c=inv_c;
copy(c,b,p);
NTT(a,1),NTT(c,1);
for(int i=0;i<len;i++)a[i]=a[i]*(2-a[i]*c[i]%mod+mod)%mod;
NTT(a,-1);
for(int i=p;i<len;i++)a[i]=0;
}
void devir(ll *a,int n)
{
for(int i=1;i<=n;i++)a[i-1]=a[i]*i%mod;
a[n]=0;
}
void inter(ll *a,int n)
{
for(int i=n;i>=0;i--)a[i+1]=a[i]*inv[i+1]%mod;
a[0]=0;
}
ll ln_c[N];
void getln(int n,ll *a,ll *b)
{
ll *c=ln_c;
getlen(2*n);
copy(c,b,n);
getinv(n,a,c);
devir(c,n);
getlen(2*n);
mul(a,a,c);
inter(a,n);
for(int i=n;i<len;i++)a[i]=0;
}
ll exp_c[N];
void getexp(int p,ll *a,ll *b)
{
if(p==1)return a[0]=1,(void)1;
getexp((p+1)/2,a,b);
getlen(2*p);
ll *c=exp_c;
copy(c,a,0);
getln(p,c,a);
getlen(2*p);
for(int i=0;i<p;i++)c[i]=(b[i]-c[i]+mod)%mod;
c[0]=(c[0]+1)%mod;
mul(a,a,c);
for(int i=p;i<len;i++)a[i]=0;
}
void Init()
{
fac[0]=1;
for(int i=1;i<=lim;i++)fac[i]=fac[i-1]*i%mod;
ifac[lim]=pw(fac[lim],mod-2);
for(int i=lim;i>=1;i--)ifac[i-1]=ifac[i]*i%mod;
inv[1]=1;
for(int i=2;i<=lim;i++)
inv[i]=mod-mod/i*inv[mod%i]%mod;
}
ll f[N],g[N];
void work()
{
Init();
ll P,K;
if(bas==1)
{
printf("%lld\n",pw(n,2*(n-2)));
return ;
}
P=(pw(bas,mod-2)-1+mod)%mod;
K=n*n%mod*pw(P,mod-2)%mod;
for(int i=1;i<=n;i++)
f[i]=pw(i,i)*K%mod*ifac[i]%mod;
getexp(n+1,g,f);
ll ans=g[n]*fac[n]%mod;
ans=ans*pw(n,4*(mod-2))%mod*pw(P,n)%mod;
ans=ans*pw(bas,n)%mod;
printf("%lld\n",ans);
}
}
int main()
{
scanf("%lld%lld%d",&n,&bas,&tp);
if(tp==0)tp0::work();
else if(tp==1)tp1::work();
else tp2::work();
return 0;
}
标签:swap lse 要求 get ini else swa wap 复杂
原文地址:https://www.cnblogs.com/lishuyu2003/p/12146391.html