题目大意:
给你一棵带权树,和一些选定的点。一个人从$i$点出发,要开车走遍所有选定的点(不必回到起点),要你分别输出$i=1\sim n$时,这个人走的最短的方案的长度。
解题思路:
首先把虚树构建出来(找出所有在这棵虚树中的节点即可),DFS一遍即可。
然后我们先假设它要回到起点,那么对于每个在虚树上的节点,它要走过的距离就是虚树上所有边的权值和的两倍,这个是显然的。
但是它不用回到起点,那我们只要减去起点到虚树上离它最远的点的距离即可(也就是少走最长的那条路,建立虚树的时候找出即可)。
而树上一个点离它最远的节点一定是树的直径的两个端点之一,因此我们求出虚树直径的两个端点,然后每次找一个离起点最远的端点,减去它到起点的距离即为答案(两遍DFS求虚树直径的顶点)。
然后对于不是虚树上的节点,只要找到离它最近的虚树节点,求这个节点的答案加上该节点到那个虚树节点的距离即可,一遍DFS即可找到所有非虚树节点的最近的虚树节点。
最后求答案即可,算两点间的距离时计算一下LCA即可
时间复杂度$O(n\log_2 n)$。
C++ Code:
#include<bits/stdc++.h> #define N 500005 #define ll long long #define Dis(a,b) (dis[a]+dis[b]-(dis[lca(a,b)]<<1)) int n,k,cnt=0,head[N],rt,nxtnd[N],L,R,fa[N][21],dep[N]; ll sum=0,dis[N],d[N]; struct edge{ int to,dis,nxt; }e[N<<1]; inline ll max(ll a,ll b){return a<b?b:a;} inline int readint(){ char c=getchar(); for(;!isdigit(c);c=getchar()); int d=0; for(;isdigit(c);c=getchar()) d=(d<<3)+(d<<1)+(c^‘0‘); return d; } void dfs(int now,int pr,int pre){ if(nxtnd[now]==-1)pr=now; for(int i=head[now];i;i=e[i].nxt) if(e[i].to!=pre){ dis[e[i].to]=dis[now]+e[i].dis; fa[e[i].to][0]=now; dep[e[i].to]=dep[now]+1; dfs(e[i].to,pr,now); if(nxtnd[e[i].to]==-1) nxtnd[now]=-1,sum+=e[i].dis; } if(!nxtnd[now])nxtnd[now]=pr; } void dfs2(int now,int pre){ for(int i=head[now];i;i=e[i].nxt) if(e[i].to!=pre&&nxtnd[e[i].to]==-1){ d[e[i].to]=d[now]+e[i].dis; dfs2(e[i].to,now); } } int lca(int x,int y){ if(dep[x]<dep[y])x^=y^=x^=y; for(int i=20;i>=0;--i) if(dep[fa[x][i]]>=dep[y])x=fa[x][i]; if(x==y)return x; for(int i=20;i>=0;--i) if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i]; return fa[x][0]; } void dfs0(int now,int pre){ for(int i=head[now];i;i=e[i].nxt) if(e[i].to!=pre){ if(nxtnd[e[i].to]!=-1){ if(nxtnd[now]==-1)nxtnd[e[i].to]=now;else nxtnd[e[i].to]=nxtnd[now]; } dfs0(e[i].to,now); } } int main(){ n=readint(),k=readint(); memset(head,0,sizeof head); memset(dep,0,sizeof dep); for(int i=1;i<n;++i){ int x=readint(),y=readint(),z=readint(); e[++cnt]=(edge){y,z,head[x]}; head[x]=cnt; e[++cnt]=(edge){x,z,head[y]}; head[y]=cnt; } memset(nxtnd,0,sizeof nxtnd); memset(dis,0,sizeof dis); nxtnd[rt=readint()]=-1; dep[rt]=1; for(int i=1;i<k;++i)nxtnd[readint()]=-1; for(int i=head[rt];i;i=e[i].nxt){ dep[e[i].to]=2; fa[e[i].to][0]=rt; dis[e[i].to]=e[i].dis; dfs(e[i].to,rt,rt); if(nxtnd[e[i].to]==-1)sum+=e[i].dis; } dfs0(rt,0); memset(d,0,sizeof d); dfs2(rt,0); L=R=rt; for(int i=1;i<=n;++i) if(nxtnd[i]==-1&&d[i]>d[L])L=i; memset(d,0,sizeof d); dfs2(L,0); for(int i=1;i<=n;++i) if(nxtnd[i]==-1&&d[i]>d[R])R=i; for(int j=1;j<21;++j) if(1<<j<=n) for(int i=1;i<=n;++i) fa[i][j]=fa[fa[i][j-1]][j-1];else break; sum<<=1; for(int i=1;i<=n;++i){ if(nxtnd[i]==-1){ ll ans=sum-max(Dis(i,L),Dis(i,R)); printf("%lld\n",ans); }else{ ll ans=sum+Dis(i,nxtnd[i])-max(Dis(nxtnd[i],L),Dis(nxtnd[i],R)); printf("%lld\n",ans); } } return 0; }