题意:给定一棵树,查询时给定两个点,求出两个点的距离。
暴力做肯定超时的。我的做法是采用lca(最近公共祖先)的离线算法,即tarjan算法(据说Tarjan提出了很多算法,可能还有很多tarjan算法),算法里用到了并查集。在输入完所有查询之后,在求出答案。tarjan算法的做法是:一开始vis数组初始化为0,从树根开始递归往下对点进行染色,刚到一个点的时候将vis取为-1,在继续递归;遍历完子节点返回之后vis变为1。在vis变为1之前,检索一下当前节点的所有查询,设查询中的另外一个节点为To,如果vis[To]==0,就continue,因为To还没有处理,不知道它的信息;如果vis[To]==-1,说明To被访问了一次,但是还没有返回到,这意味着To是当前节点的祖先,因此To就是当前节点的最近公共祖先;如果vis[To]==1,说明To已经处理完了,这时候并查集就派上用场了。在递归时,当一个节点处理完返回到父亲那里时,就把父亲变成其所在集合的代表元素。在刚才讨论到vis[To]==1的情况中,可以知道find(To)(即To所在集合的代表元素)就是To和当前节点的最近公共祖先了(这个可以画图演算一下)。在这道题中,我们一开始可以用一个简单的递归算出每个点到根节点的距离dis[i]。那么对于一个查询的两个点fir和sec,它们的距离就是dis[fir]-dis[lca]+dis[sec]-dis[lca],lca是fir和sec的最近公共祖先。
#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #include<string> #include<cmath> #include<set> #include<climits> #include<queue> #include<vector> #include<map> using namespace std; struct node { int to,id; node(int t,int i) { to=t; id=i; } node(){} }; const int maxn=50005; vector<node>vec[maxn]; vector<pair<int,int>>query; int father[maxn],fir[maxn<<1],nxt[maxn<<1],vv[maxn<<1],val[maxn<<1],dis[maxn],ans[75005],e; int vis[maxn];//0 means it's white,-1 means it's grey, 1 means it's black int findn(int n) { if(n!=father[n]) father[n]=findn(father[n]); return father[n]; } void add(int a,int b,int c,int i) { vv[e]=b; val[e]=c; nxt[e]=fir[a]; fir[a]=e++; } void get_height(int sroot,int dist) { vis[sroot]=1; dis[sroot]=dist; for(int i=fir[sroot];i!=-1;i=nxt[i]) { int v=vv[i]; if(!vis[v]) { get_height(v,dist+val[i]); } } } void dfs(int cur,int fa) { vis[cur]=-1; for(int i=fir[cur];i!=-1;i=nxt[i]) { int v=vv[i]; if(!vis[v]) { dfs(v,cur); father[v]=cur; } } int size=vec[cur].size(); for(int i=0;i<size;i++) { node nxt=vec[cur][i]; if(!vis[nxt.to]) continue; if(-1==vis[nxt.to]) { ans[nxt.id]=nxt.to; } else if(1==vis[nxt.to]) { ans[nxt.id]=findn(nxt.to); } } vis[cur]=1; } int main() { #pragma comment(linker, "/STACK:102400000,102400000")//此代码需要扩栈,可能在递归时耗的内存有点大 int n; while(scanf("%d",&n)!=EOF) { for(int i=0;i<=n;i++) { father[i]=i; fir[i]=-1; vis[i]=0; vec[i].clear(); } e=0;//important int a,b,c; for(int i=0;i<n-1;i++) { scanf("%d%d%d",&a,&b,&c); add(a,b,c,i); add(b,a,c,i); } get_height(0,0); int q; scanf("%d",&q); for(int i=0;i<q;i++) { scanf("%d%d",&a,&b); vec[a].push_back(node(b,i)); vec[b].push_back(node(a,i)); query.push_back(make_pair<int,int>(a,b)); } for(int i=0;i<=n;i++) vis[i]=0; dfs(0,0); int size=query.size(); for(int i=0;i<size;i++) { int fir=query[i].first; int sec=query[i].second; int lca=ans[i]; int distance=abs(dis[lca]-dis[fir])+abs(dis[lca]-dis[sec]); printf("%d\n",distance); } } }
原文地址:http://blog.csdn.net/u014088857/article/details/43063717