标签:res push 翻转 表示 lan code 使用 接下来 using
给定一棵根节点为 \(1\) 的二叉树 \(T\),你需要先保留一个包含 \(1\) 号节点的连通块,然后给每个点确定一个权值 \(a_i\),使得对于每个点 \(u\) 都有其权值 \(a_u\) 大于等于其所有儿子的权值和 \(\sum a_v[(u,v)\in T]\)。
最后,你需要使得根节点权值为 \(m\),求方案数,答案对 \(998244353\) 取模。
\(n\le 10^5,m\le 10^{18}\)
假设在 \(i\) 处放一个权值,那么其所有祖先都要放一个权值。
于是我们假设确定了每个点的 "额外权值" 那么合法等价于 \(m\) 大于等于额外权值和。
我们可以这样考虑,假设以 \(1\) 为根的连通块有 \(L\) 个点,那么此时的方案数显然就是 \(\frac{1}{(1-x)^{L}}[x^m]\) 即 \(\binom{L+m-1}{m}\)
于是我们只需要知道以 \(1\) 为根的,大小为 \(i\) 的连通块有多少个,使用生成函数来刻画答案,那么转移形如:
暴力背包,\(\mathcal O(n^2)\)(树背包复杂度)
接下来考虑优化,我们先找到一条重链,然后对忽略此重链的树递归,对于一条重链,假设每个点都有另一个儿子(没有将多项式设为 \(0\))此时我们相当于计算,给定一个序列和多项式 \(F_2(x),F_3(x)...\)(注意 \(F_1(x)\) 为空)(方便起见给所有多项式先加 \(1\),然后给 \(F_2(x)\) 乘以 \((x+1)\)),求:
方便起见令 \(G_i(x)=F_i(x)x\),那么就有:
不难发现这个算式相当于将 \(G\) 翻转后统计 \(\sum_i \prod_{j\le i}G_j(x)\)
使用分治 NTT 加速即可,复杂度为 \(\mathcal O(n\log^3 n)\)
复杂度分析:
设 \(T(n)\) 表示复杂度,则最后一次合并的复杂度为 \(\mathcal O(n\log^2 n)\)
接下来,对于每棵子树,由于重链剖分,问题规模至少缩小了一半,且问题规模和仍然是 \(n\),于是递归层数至多为 \(\log n\),每层复杂度仍为 \(\mathcal O(\sum (\textrm{size})\log^2 (\textrm{size}))=\mathcal O(n\log^2 n)\) 总复杂度为 \(\mathcal O(n\log^3 n)\)
\(Code:\)
#include<bits/stdc++.h>
using namespace std ;
#define Next( i, x ) for( register int i = head[x]; i; i = e[i].next )
#define rep( i, s, t ) for( register int i = (s); i <= (t); ++ i )
#define Rep(i, s, t) for(register int i = (s); i < (t); ++ i)
#define drep( i, s, t ) for( register int i = (t); i >= (s); -- i )
#define re register
#define mp make_pair
#define pi pair<int, int>
#define pb push_back
#define int long long
#define vi vector<int>
int gi() {
char cc = getchar() ; int cn = 0, flus = 1 ;
while( cc < ‘0‘ || cc > ‘9‘ ) { if( cc == ‘-‘ ) flus = - flus ; cc = getchar() ; }
while( cc >= ‘0‘ && cc <= ‘9‘ ) cn = cn * 10 + cc - ‘0‘, cc = getchar() ;
return cn * flus ;
}
const int N = 4e5 + 5 ;
const int P = 998244353 ;
const int G = 3 ;
const int Gi = 332748118 ;
int fpow(int x, int k) {
int ans = 1, base = x ;
while(k) {
if(k & 1) ans = 1ll * ans * base % P ;
base = 1ll * base * base % P, k >>= 1 ;
} return ans ;
}
int n, X, ch[N], sz[N], deg[N], R[N], fa[N], fac[N], inv[N], ind[N], L, limit, Inv ;
vector<int> F[N], E[N] ;
void init(int x) {
limit = 1, L = 0 ; while( limit < x ) limit <<= 1, ++ L ;
Rep(i, 0, limit) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (L - 1)) ;
Inv = fpow(limit, P - 2) ;
}
void NTT(vi& a, int type) {
Rep(i, 0, limit) if(R[i] > i) swap( a[i], a[R[i]] ) ;
for(re int k = 1; k < limit; k <<= 1) {
int d = fpow( (type) ? G : Gi, (P - 1) / (k << 1) ) ;
for(re int i = 0; i < limit; i += (k << 1))
for(re int j = i, g = 1; j < i + k; ++ j, g = g * d % P) {
int nx = a[j], ny = a[j + k] * g % P ;
a[j] = (nx + ny) % P, a[j + k] = (nx - ny + P) % P ;
}
}
if( !type ) Rep(i, 0, limit) a[i] = a[i] * Inv % P ;
}
void dfs1(int x, int ff) {
sz[x] = 1, fa[x] = ff ;
for(int v : E[x]) if(v ^ ff) {
dfs1(v, x), sz[x] += sz[v], ++ deg[x] ;
if( sz[v] >= sz[ch[x]] ) ch[x] = v ;
}
}
vector<int> p[N], f[N], st[N] ; int top ;
vi operator + (vi a, vi b) {
vi ans ; int cnt = max(a.size(), b.size()) ;
ans.resize(cnt), a.resize(cnt), b.resize(cnt) ;
Rep(i, 0, cnt) ans[i] = a[i] + b[i] ;
return ans ;
}
vi operator * (vi a, vi b) {
vi ans ; init(a.size() + b.size() + 2) ; int cnt = a.size() + b.size() - 1 ;
ans.resize(limit), a.resize(limit), b.resize(limit) ;
NTT(a, 1), NTT(b, 1) ;
Rep( i, 0, limit ) ans[i] = a[i] * b[i] % P ;
NTT(ans, 0), ans.resize(cnt) ;
return ans ;
}
void Solve(int l, int r) {
if(l == r) { p[l] = st[l], f[l] = st[l], ++ f[l][0] ; return ; }
int mid = (l + r) >> 1 ;
Solve(l, mid), Solve(mid + 1, r) ;
vi fl = f[l], pr = p[mid + 1], fr = f[mid + 1] ;
-- fl[0], f[l] = (fl * pr + fr), p[l] = p[l] * p[mid + 1] ;
}
void count(int x) {
Solve(1, top), F[x] = f[1] ; top = 0 ;
}
void solve(int x, int ff) {
if( deg[x] <= 1 ) F[x].resize(1), F[x][0] = 1 ;
for(int v : E[x]) {
if(v == fa[x] || v == ch[x]) continue ;
solve(v, v), F[x] = F[v] ;
}
if( ch[x] ) solve(ch[x], x) ;
int cnt = F[x].size() ; F[x].resize(cnt + 1) ;
drep( i, 1, cnt ) F[x][i] = F[x][i - 1] ;
F[x][0] = 0, st[++ top] = F[x] ;
if( x == ff ) count(x) ;
cnt = F[x].size() ;
}
signed main()
{
n = gi(), X = gi() ; int x, y ;
rep( i, 2, n ) x = gi(), y = gi(), E[x].pb(y), E[y].pb(x) ;
dfs1(1, 1), solve(1, 1) ;
int Ans = 0 ; fac[0] = inv[0] = ind[0] = 1 ;
rep( i, 1, n ) fac[i] = fac[i - 1] * i % P ;
rep( i, 1, n ) inv[i] = fpow( fac[i], P - 2 ) ;
rep( i, 1, n ) ind[i] = ind[i - 1] * ((X + i) % P) % P ;
rep( i, 0, n - 1 ) Ans = (Ans + ind[i] * inv[i] % P * F[1][i + 1] % P) % P ;
cout << Ans << endl ;
return 0 ;
}
标签:res push 翻转 表示 lan code 使用 接下来 using
原文地址:https://www.cnblogs.com/Soulist/p/14027920.html