题目大意
? 给你一棵树,求有多少个组点满足\(x\neq y,x\neq z,y\neq z,dist_{x,y}=dist_{x,z}=dist_{y,z}\)
? \(1\leq n\leq 100000\)
题解
? 问题转换为有多少个组点满足\(dist_{i,x}=dist_{i,y}=dist_{i,z}\)
? 我们考虑树形DP
? \(f_{i,j}=\)以\(i\)为根的子树中与\(i\)的距离为\(j\)的节点数
? \(g_{i,j}=\)以\(i\)为根的子树外选择一个点\(s\)满足\(s\)到\(i\)的距离为\(j\),能新增的的方案数
? 若\(v\)是\(u\)的重儿子,则:\(f_{u,j}+=f_{v,j-1},g_{u,j}+=g_{v,j+1}\),这样就可以由\(u\)的重儿子转移到\(u\)
? 否则:\(g_{u,j}+=g_{v,{j+1}}+f_{v,j-1}\times f_{u,j},f_{u,j}+=f_{v,j-1}\)
? 答案为\(\sum f_{x,j}\times g_{y,j}\),其中\(x\)是\(y\)的兄弟
? 可以用长链剖分辅助转移
? 时间复杂度:\(O(n)\)
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
struct list
{
int v[200010];
int t[200010];
int h[100010];
int n;
void clear()
{
n=0;
memset(h,0,sizeof h);
}
void add(int x,int y)
{
n++;
v[n]=y;
t[n]=h[x];
h[x]=n;
}
};
list l;
ll ans;
ll f[100010];
ll g[200010];
int d[100010];
int bg[100010];
int ed[100010];
int ch[100010];
int t[100010];
int w[100010];
int ti;
void dfs(int x,int fa)
{
d[x]=1;
ch[x]=0;
int i;
for(i=l.h[x];i;i=l.t[i])
if(l.v[i]!=fa)
{
dfs(l.v[i],x);
if(d[l.v[i]]+1>d[x])
{
d[x]=d[l.v[i]]+1;
ch[x]=l.v[i];
}
}
}
void dfs2(int x,int fa,int top)
{
t[x]=top;
w[x]=++ti;
if(x==top)
bg[top]=ti;
ed[top]=ti;
if(ch[x])
dfs2(ch[x],x,top);
int i;
for(i=l.h[x];i;i=l.t[i])
if(l.v[i]!=ch[x]&&l.v[i]!=fa)
dfs2(l.v[i],x,l.v[i]);
}
ll& getf(int x,int y)
{
return f[w[x]+y];
}
ll& getg(int x,int y)
{
return g[2*(w[t[x]]-1)+2*d[t[x]]-d[x]+1-y];
}
void solve(int x,int fa)
{
if(ch[x])
solve(ch[x],x);
int i,j;
for(i=l.h[x];i;i=l.t[i])
if(l.v[i]!=fa&&l.v[i]!=ch[x])
{
int v=l.v[i];
solve(v,x);
for(j=0;j<d[v];j++)
ans+=getf(v,j)*getg(x,j+1);
for(j=1;j<d[v];j++)
ans+=getg(v,j)*getf(x,j-1);
for(j=0;j<d[v];j++)
getg(x,j+1)+=getf(v,j)*getf(x,j+1);
for(j=1;j<d[v];j++)
getg(x,j-1)+=getg(v,j);
for(j=0;j<d[v];j++)
getf(x,j+1)+=getf(v,j);
}
ans+=getg(x,0);
getf(x,0)++;
}
int main()
{
int n;
scanf("%d",&n);
l.clear();
memset(bg,0,sizeof bg);
memset(ed,0,sizeof ed);
memset(f,0,sizeof f);
memset(g,0,sizeof g);
memset(d,0,sizeof d);
memset(ch,0,sizeof ch);
memset(t,0,sizeof t);
memset(w,0,sizeof w);
ans=0;
ti=0;
int i,x,y;
for(i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
l.add(x,y);
l.add(y,x);
}
dfs(1,0);
dfs2(1,0,1);
solve(1,0);
printf("%lld\n",ans);
return 0;
}