http://www.lydsy.com/JudgeOnline/problem.php?id=3572
明显需要构造虚树
点属于谁管理分三种情况:
1、属于虚树的点
2、在虚树上的边上的点
3、既不属于虚树的点,又不属于虚树上的边的点
第一种情况:
先做一遍树形dp,得到子树中距离它最近的点
再dfs一遍,看看父节点那一块 是否有比它现在的点更近的点
第二种情况:
一条边u-->v 如果u和v属于同一点x管理,那么这条边所代表的所有点也都属于x管理
否则的话,二分一个点tmp,tmp以上的点归管理u的点管理,tmp及tmp以下的点归管理v的点管理
第三种情况:
归这个种树的根 的父节点 管理
第三种情况可以合并到第一种情况中,即用siz[x]表示虚树中一个点x在原树代表多少个点
开始siz[x]=原树中x的子树大小
虚树中每加一条边x-->y,若y属于x的子节点中k的子树,就在siz[x]中减去 原树中k的子树大小
#include<cmath> #include<cstdio> #include<iostream> #include<algorithm> #define N 300001 typedef long long LL; #define min(x,y) ((x)<(y) ? (x) : (y)) int n,lim; int num,id[N]; int fa[N][19],SIZ[N],prefix[N],dep[N]; void read(int &x) { x=0; char c=getchar(); while(!isdigit(c)) c=getchar(); while(isdigit(c)) { x=x*10+c-‘0‘;c=getchar(); } } namespace Original { int front[N],nxt[N<<1],to[N<<1]; int tot; void add(int u,int v) { to[++tot]=v; nxt[tot]=front[u]; front[u]=tot; to[++tot]=u; nxt[tot]=front[v]; front[v]=tot; } void dfs(int x) { id[x]=++num; SIZ[x]=1; int t; for(int i=front[x];i;i=nxt[i]) { t=to[i]; if(t!=fa[x][0]) { fa[t][0]=x; dep[t]=dep[x]+1; dfs(t); SIZ[x]+=SIZ[t]; } } } void multiplication() { lim=log(n)/log(2); for(int i=1;i<=lim;++i) for(int j=1;j<=n;++j) fa[j][i]=fa[fa[j][i-1]][i-1]; } void Prefix_dfs(int x) { int t; for(int i=front[x];i;i=nxt[i]) { t=to[i]; if(t!=fa[x][0]) { prefix[t]=prefix[x]+SIZ[x]-SIZ[t]; Prefix_dfs(t); } } } void main() { int u,v; read(n); for(int i=1;i<n;++i) { read(u); read(v); add(u,v); } dfs(1); multiplication(); Prefix_dfs(1); return; } } namespace Imaginary { int cnt,use[N]; int st[N],top; int tot; int front[N],to[N],nxt[N],from[N],val[N]; int bin[N],bin_cnt; int siz[N]; int mi[N],bl[N]; int dy[N],ans[N]; bool cmp(int p,int q) { return id[p]<id[q]; } int find_ancestor(int x,int y) { for(int i=lim;i>=0;--i) if(y>=(1<<i)) { x=fa[x][i]; y-=(1<<i); } return x; } void add(int u,int v,int w) { to[++tot]=v; nxt[tot]=front[u]; front[u]=tot; from[tot]=u; val[tot]=w; int s=find_ancestor(v,dep[v]-dep[u]-1); siz[u]-=SIZ[s]; } int get_lca(int x,int y) { if(id[x]<id[y]) std::swap(x,y); for(int i=lim;i>=0;--i) if(id[fa[x][i]]>id[y]) x=fa[x][i]; return fa[x][0]; } int get_dis(int u,int v) { int lca=get_lca(u,v); return dep[u]+dep[v]-dep[lca]*2; } void build() { std::sort(use+1,use+cnt+1,cmp); tot=0; st[top=1]=1; bin[bin_cnt=1]=1; siz[1]=SIZ[1]; int i=1; if(use[1]==1) i=2; int x,lca; for(;i<=cnt;++i) { x=use[i]; lca=get_lca(x,st[top]); while(id[lca]<id[st[top]]) { if(id[lca]>=id[st[top-1]]) { add(lca,st[top],dep[st[top]]-dep[lca]); if(lca!=st[--top]) { st[++top]=lca; siz[lca]+=SIZ[lca]; bin[++bin_cnt]=lca; } break; } add(st[top-1],st[top],dep[st[top]]-dep[st[top-1]]); top--; } st[++top]=x; siz[x]+=SIZ[x]; bin[++bin_cnt]=x; } while(top>1) { add(st[top-1],st[top],dep[st[top]]-dep[st[top-1]]); top--; } } int dfs1(int x) { int p,d; mi[x]=0; if(dy[x]) { for(int i=front[x];i;i=nxt[i]) dfs1(to[i]); bl[x]=x; return x; } for(int i=front[x];i;i=nxt[i]) { p=dfs1(to[i]); d=dep[p]-dep[x]; if(!mi[x] || d<mi[x]) { mi[x]=d; bl[x]=p; } else if(d==mi[x] && p<bl[x]) bl[x]=p; } return bl[x]; } int dfs2(int x) { int t; for(int i=front[x];i;i=nxt[i]) { t=to[i]; if(!dy[t]) if(bl[x]!=bl[t]) { if(mi[x]+val[i]<mi[t]) { mi[t]=mi[x]+val[i]; bl[t]=bl[x]; } else if(mi[x]+val[i]==mi[t] && bl[x]<bl[t]) bl[t]=bl[x]; } dfs2(t); } ans[dy[bl[x]]]+=siz[x]; } void belong() { dfs1(1); dfs2(1); } void get_ans() { int f,s; int l,r,mid,tmp,tmp_son; int u,v; bool equal; int up,down; for(int i=1;i<=tot;++i) { u=from[i]; v=to[i]; r=dep[v]-dep[u]-1; if(!r) continue; s=find_ancestor(v,r); if(bl[u]==bl[v]) ans[dy[bl[u]]]+=prefix[v]-prefix[s]; else { tmp=v; l=1; equal=false; while(l<=r) { mid=l+r>>1; f=find_ancestor(v,mid); down=get_dis(f,bl[v]); up=get_dis(f,bl[u]); if(down<up) tmp=f,l=mid+1; else if(down==up) { tmp=f; equal=true; tmp_son=find_ancestor(v,mid-1); break; } else r=mid-1; } if(!equal) { ans[dy[bl[v]]]+=prefix[v]-prefix[tmp]; ans[dy[bl[u]]]+=prefix[tmp]-prefix[s]; } else { ans[dy[bl[v]]]+=prefix[v]-prefix[tmp_son]; ans[dy[bl[u]]]+=prefix[tmp]-prefix[s]; ans[dy[min(bl[v],bl[u])]]+=prefix[tmp_son]-prefix[tmp]; } } } for(int i=1;i<=cnt;++i) printf("%d ",ans[i]); printf("\n"); } void clear() { for(int i=1;i<=bin_cnt;++i) { front[bin[i]]=0; ans[i]=0; dy[bin[i]]=0; siz[bin[i]]=0; } tot=0; } void main() { int m; read(m); while(m--) { read(cnt); for(int i=1;i<=cnt;++i) { read(use[i]); dy[use[i]]=i; } build(); belong(); get_ans(); clear(); } return; } } int main() { Original::main(); Imaginary::main(); return 0; }