标签:lan main ret min test vector .com std void
题目是求一棵n节点树中对于C(n,k)颗子树,每棵子树为在n个节点中选不同的k个节点作为树的边界点,这样的所有子树共包含多少条边。
问题可以转化一下,对每一条边,不同的子树中可能包含可能不包含这条边,显然,只有子树那k个节点在该边的两侧均有分布时该边才被包含在子树中。所有边的被包含次数的和,即为answer。对于一条边的被包含次数,设该边两侧分别有a,b个节点,那么,该边被包含的次数为C(a+b,k)-C(a,k)-C(b,k)(也可以借助母函数函数求C(a,i)*C(b,k-i),i从1到min{a,b,k-1},结果一样)。
//dfs写的太搓了,调了半天才好。。。
题目链接: https://www.51nod.com/contest/problem.html#!problemId=1677
1 #include<bits/stdc++.h> 2 using namespace std; 3 4 typedef long long LL; 5 const LL mod=1e9+7; 6 const LL M=1e5+3; 7 8 LL fac[100005]; //阶乘 9 LL inv_of_fac[100005]; //阶乘的逆元 10 11 LL qpow(LL x,LL n) 12 { 13 LL ret=1; 14 for(; n; n>>=1) 15 { 16 if(n&1) ret=ret*x%mod; 17 x=x*x%mod; 18 } 19 return ret; 20 } 21 void init() 22 { 23 fac[1]=1; 24 for(int i=2; i<=M; i++) 25 fac[i]=fac[i-1]*i%mod; 26 inv_of_fac[M]=qpow(fac[M],mod-2); 27 for(int i=M-1; i>=0; i--) 28 inv_of_fac[i]=inv_of_fac[i+1]*(i+1)%mod; 29 } 30 LL C(LL a,LL b) 31 { 32 if(b>a) return 0; 33 if(b==0) return 1; 34 return fac[a]*inv_of_fac[b]%mod*inv_of_fac[a-b]%mod; 35 } 36 ///////////////////////////////////////////////////////////// 37 vector<int> adj[M]; 38 int vis[M]; 39 LL n,k,ans,du[M],hh; 40 void init1() 41 { 42 ans=0; 43 memset(vis,0,sizeof(vis)); 44 memset(du,0,sizeof(du)); 45 du[1]=n; 46 hh=C(n,k); 47 for(int i=1; i<=n; i++) 48 adj[i].clear(); 49 } 50 LL dfs(int s) 51 { 52 if(adj[s].size()==1&&s!=1) return du[s]=1; 53 if(du[s]&&s!=1) return du[s]; 54 vis[s]=1; 55 LL ret,cnt=0; 56 for(int i=0; i<adj[s].size(); i++) 57 { 58 if(!vis[adj[s][i]]) 59 { 60 // printf("%d -> %d\n",s,adj[s][i]); 61 cnt+=dfs(adj[s][i]); 62 ans=(ans+hh-C(dfs(adj[s][i]),k)-C(n-dfs(adj[s][i]),k))%mod; 63 } 64 } 65 return du[s]=cnt+1; 66 } 67 68 int main() 69 { 70 init(); 71 while(~scanf("%lld%lld",&n,&k)) 72 { 73 init1(); 74 for(int i=1; i<n; i++) 75 { 76 LL u,v; 77 scanf("%d%d",&u,&v); 78 adj[u].push_back(v); 79 adj[v].push_back(u); 80 } 81 dfs(1); 82 // for(int i=1; i<=n; i++) 83 // printf("%d:%lld=========\n",i,du[i]); 84 // for(int i=1; i<=n; i++) 85 // { 86 // printf("i=%d:\n",i); 87 // for(int j=0; j<adj[i].size(); j++) 88 // printf("%d ",adj[i][j]); 89 // puts(""); 90 // } 91 printf("%lld\n",(ans+mod)%mod); 92 } 93 }
标签:lan main ret min test vector .com std void
原文地址:http://www.cnblogs.com/Just--Do--It/p/6103326.html