码迷,mamicode.com
首页 > 其他好文 > 详细

LCA

时间:2015-10-07 22:56:59      阅读:265      评论:0      收藏:0      [点我收藏+]

标签:

LCA,全称为Lowest Common Ancestor, 即最近公共祖先。这是对于有根树而言的,两个节点u, v的公共祖先中距离最近的那个被称为最近公共祖先(这解释。。真通俗。。。)我们来看个图:

 技术分享

4和7的LCA是2,5和6的LCA是1,2和5的LCA是2。

最笨的实现方法就是;

对于同一深度的点,一个个往上走知道走到相同的点为止;不同深度的点转化为同一深度的点再往上走!

复杂度为O(dep[u] + dep[v] - 2 * dep[c])(u, v表示要查询的点,c为u, v的LCA)

代码如下:

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
const int N = 100000 + 5;

vector<int> G[N];

int rt;//表示根,即root

int dsu[N], dep[N];

//dsu数组表示父亲,dep表示深度 

void dfs(int v, int p, int d){
    dsu[v] = p;
    dep[v] = d;
    for(int i = 0; i < G[v].size(); i ++)
        if(G[v][i] != p) dfs(G[v][i], v, d + 1);
    //G[v][i] != p的原因是读入的时候是无向,所以存在反向边,例在处理节点v的时候,假设dsu[v] = u, 则存在i,使得G[v][i] = u,这样会导致在两个点中一直循环,导致陷入死循环。
}

void init(){
    dfs(rt, -1, 0);
}

int lca(int u, int v){
    if(dep[u] > dep[v]) swap(u, v);
    while(dep[v] > dep[u]) v = dsu[v];
    //让u, v处于同一深度。
    while(u != v){
        u = dsu[u];
        v = dsu[v];
    }
    return u;
}

int n, q, a, b;

int main(){
    //我们以n个点n - 1条边,q次询问为例
    while(scanf("%d%d", &n, &q) == 2){
        for(int i = 1; i <= n; i ++)G[i].clear();
        for(int i = 1; i < n; i ++){
            scanf("%d%d", &a, &b);
            G[a].push_back(b);
            G[b].push_back(a);
        }
        rt = 1;//以1为根
        init();
        while(q --){
            scanf("%d%d", &a, &b);
            printf("%d\n", lca(a, b));
        }
    }
    return 0;
}

我们发现如果有n个点最坏的情况下有O(n)的复杂度,如果多次查询复杂度肯定会爆掉,所以我们必须要有高效的算法。

实现LCA的高效算法有二种,分别是倍增法和RMQ法。

一.倍增法

 我们首先这们想如果相同深度的两个节点u, v当往上走k步的时候走到同一节点,那么往上走k + 1步还是同一节点,k + 2步也是, k + k即2k步也是。我们把上面的dsu数组变成2维数组dsu[k][v]表示节点v往上走2k步所走到的节点。那么dsu[k + 1][v] = dsu[k][dsu[k][v]];这样我们就可以通过二分来查找他们的LCA了,每次查询复杂度为O(logn), 预处理为O(nlogn)。

#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;
const int N = 10000 + 5;

vector<int> G[N];


int rt, a, b, c, q, n, m, T;

int dsu[40][N], dep[N];

void dfs(int v, int p, int d){
    dsu[0][v] = p;
    dep[v] = d;
    for(int i = 0; i < G[v].size(); i ++)
        if(G[v][i] != p) dfs(G[v][i], v, d + 1);
}

void init(){
    dfs(rt, -1, 0);
    for(int k = 0; k + 1 < 32; k ++){
        for(int v = 1; v <= n; v ++){
            if(dsu[k][v] < 0)dsu[k + 1][v] = -1;//根节点的父亲设为-1
            else dsu[k + 1][v] = dsu[k][dsu[k][v]];
        }
    }
}

int lca(int u, int v){
    if(dep[u] > dep[v]) swap(u, v);
    for(int k = 0; k < 32; k ++){
        if((dep[v] - dep[u]) >> k & 1)
            v = dsu[k][v];
    }
    if(u == v)return u;
    for(int k = 31; k >= 0; k --){
        if(dsu[k][u] != dsu[k][v]){
            u = dsu[k][u];
            v = dsu[k][v];
        }
    }
    return dsu[0][u];
}


int main(){
    while(scanf("%d%d", &n, &q) == 2){
        for(int i = 1; i <= N; i ++)G[i].clear();
        for(int i = 1; i < n; i ++){
            scanf("%d%d", &a, &b);
            G[a].push_back(b);
            G[b].push_back(a);
        }
        rt = 1;
        init();
        while(q --){
            scanf("%d%d", &a, &b);
            printf("%d\n", lca(a, b));
        }
    }
    return 0;
}

二.RMQ法

其实这里还涉及到另外一个东西,叫做dfs序。是指你用dfs遍历一棵树时,每个节点会按照遍历到的先后顺序得到一个序号。然后你用这些序号,可以把整个遍历过程表示出来。如下图:

技术分享

 如上图所示,则整个遍历过程为1 2 4 2 5 7 5 8 5 2 1 3 6 3 1

我们将他保存在一个vs数组,并开个id数组记录第一次在vs中出现的下标,例如id[1] = 1, id[4] = 3;

并用dep数组储存vs数组中每个数的深度,例如dep[2] = dep[4] = 1(vs数组中第2个和第4个都是2,2的深度为2)。

而LCA(u, v)就是第一次访问u之后到第一次访问v之前所经过顶点中离根最近的那个。假设id[u] <= id[v],那么LCA(u, v) = vs[t] t为id[u]与id[v]中dep最小的那一个。

而这个不就相当于求区间的RMQ吗?

附上代码:

#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;
const int N = 100000 + 5;

vector<int>G[N];

int rt, n, m, a, b, c, q;

int vs[N * 2 - 1], dep[N * 2 - 1], id[N], sum[N];

int dp[N][40];

struct RMQ{
    int log2[N];
    void init(int n){
        log2[0] = -1;
        for(int i = 1; i <= n; i ++)log2[i] = log2[i >> 1] + 1;
        for(int i = 1; i <= n; i ++)dp[i][0] = vs[i];
        for(int j = 1; j <= log2[n]; j ++){
            for(int i = 1; i + (1 << j) <= n + 1; i ++){
                int ca = dp[i][j - 1];
                int cb = dp[i + (1 << j)][j - 1];
                if(vs[ca] < vs[cb]) dp[i][j] = ca;
                else dp[i][j] = cb;
            }
        }
    }
    int query(int ls, int rs){
        int k = log2[rs - ls + 1];
        int ca = dp[ls][k];
        int cb = dp[rs - (1 << k) + 1][k];
        if(vs[ca] < vs[cb]) return ca;
        else return cb;
    }
}rmq;

void dfs(int v, int p, int d, int& k){
    //printf("v = %d\n", v);
    id[v] = k;
    vs[k] = v;
    dep[k ++] = d;
    for(int i = 0; i < G[v].size(); i ++){
        if(G[v][i] != p){
            dfs(G[v][i], v, d + 1, k);
            vs[k] = v;
            dep[k ++] = d;
        }
    }
}

void init(int V){
    int k = 1;
    dfs(rt, -1, 0, k);
    rmq.init(V * 2 - 1);
}

int lca(int u, int v){
    return vs[rmq.query(min(id[u], id[v]), max(id[u], id[v]))];
}

void print(){
    for(int i = 0; i < 2 * n; i ++) printf("vs[%d] = %d\n", i, vs[i]);
    for(int i = 1; i <= n; i ++) printf("id[%d] = %d\n", i, id[i]);
    for(int i = 1; i <= n; i ++) printf("dep[%d] = %d\n", i, dep[i]);
    for(int i = 1; i <= n; i ++) printf("sum[%d] = %d\n", i, sum[i]);
}

int main(){
    while(scanf("%d%d", &n, &q) == 2){
        for(int i = 1; i <= n; i ++)G[i].clear();
        for(int i = 1; i < n; i ++){
            scanf("%d%d", &a, &b);
            G[a].push_back(b);
            G[b].push_back(a);
        }
        rt = 1;
        init(n);
        //print();
        while(q --){
            scanf("%d%d", &a, &b);
            printf("%d\n", lca(a, b));
        }
    }
    return 0;
}

 LCA的应用

LCA可以用来求树上的两个顶点之间的权值和,让任意一个点作为根节点,设sum[u]为顶点rt到u的权值和,那么u到v的权值和就是sum[u] - sum[lca(u, v)] + sum[v] - sum[lca(u, v)]。

来看一道题: 传送门

技术分享

 

附上两种方法的代码:

1.倍增法:

#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;
const int N = 40000 + 5;

vector<int> G[N];

vector<int> E[N];

int rt, a, b, c, q, n, m, T;

int dsu[40][N], dep[N], sum[N];

void dfs(int v, int p, int d){
    dsu[0][v] = p;
    dep[v] = d;
    for(int i = 0; i < G[v].size(); i ++)
        if(G[v][i] != p){
            sum[G[v][i]] = sum[v] + E[v][i];
            dfs(G[v][i], v, d + 1);
        }
}

void init(){
    dfs(rt, -1, 0);
    for(int k = 0; k + 1 < 32; k ++){
        for(int v = 1; v <= n; v ++){
            if(dsu[k][v] < 0)dsu[k + 1][v] = -1;
            else dsu[k + 1][v] = dsu[k][dsu[k][v]];
        }
    }
}

int lca(int u, int v){
    if(dep[u] > dep[v]) swap(u, v);
    for(int k = 0; k < 32; k ++){
        if((dep[v] - dep[u]) >> k & 1)
            v = dsu[k][v];
    }
    if(u == v)return u;
    for(int k = 31; k >= 0; k --){
        if(dsu[k][u] != dsu[k][v]){
            u = dsu[k][u];
            v = dsu[k][v];
        }
    }
    return dsu[0][u];
}


int main(){
    scanf("%d", &T);
    while(T--){
        scanf("%d%d", &n, &q);
        for(int i = 1; i <= N; i ++)G[i].clear(), E[i].clear();
        for(int i = 1; i < n; i ++){
            scanf("%d%d%d", &a, &b, &c);
            G[a].push_back(b);
            G[b].push_back(a);
            E[a].push_back(c);
            E[b].push_back(c);
        }
        rt = 1;
        init();
        while(q --){
            scanf("%d%d", &a, &b);
            c = lca(a, b);
            printf("%d\n", sum[a] + sum[b] - 2 * sum[c]);
        }
    }
    return 0;
}

2RMQ + dfs序:

#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;
const int N = 100000 + 5;

vector<int>G[N];

vector<int>E[N];

int rt, n, m, a, b, c, q;

int vs[N * 2 - 1], dep[N * 2 - 1], id[N], sum[N];

int dp[N][40];

struct RMQ{
    int log2[N];
    void init(int n){
        log2[0] = -1;
        for(int i = 1; i <= n; i ++)log2[i] = log2[i >> 1] + 1;
        for(int i = 1; i <= n; i ++)dp[i][0] = vs[i];
        for(int j = 1; j <= log2[n]; j ++){
            for(int i = 1; i + (1 << j) <= n + 1; i ++){
                int ca = dp[i][j - 1];
                int cb = dp[i + (1 << j)][j - 1];
                if(vs[ca] < vs[cb]) dp[i][j] = ca;
                else dp[i][j] = cb;
            }
        }
    }
    int query(int ls, int rs){
        int k = log2[rs - ls + 1];
        int ca = dp[ls][k];
        int cb = dp[rs - (1 << k) + 1][k];
        if(vs[ca] < vs[cb]) return ca;
        else return cb;
    }
}rmq;

void dfs(int v, int p, int d, int& k){
    //printf("v = %d\n", v);
    id[v] = k;
    vs[k] = v;
    dep[k ++] = d;
    for(int i = 0; i < G[v].size(); i ++){
        if(G[v][i] != p){
            sum[G[v][i]] = sum[v] + E[v][i];
            dfs(G[v][i], v, d + 1, k);
            vs[k] = v;
            dep[k ++] = d;
        }
    }
}

void init(int V){
    int k = 1;
    sum[0] = sum[1] = 0;
    dfs(rt, -1, 0, k);
    rmq.init(V * 2 - 1);
}

int lca(int u, int v){
    return vs[rmq.query(min(id[u], id[v]), max(id[u], id[v]))];
}

void print(){
    for(int i = 0; i < 2 * n; i ++) printf("vs[%d] = %d\n", i, vs[i]);
    for(int i = 1; i <= n; i ++) printf("id[%d] = %d\n", i, id[i]);
    for(int i = 1; i <= n; i ++) printf("dep[%d] = %d\n", i, dep[i]);
    for(int i = 1; i <= n; i ++) printf("sum[%d] = %d\n", i, sum[i]);
}

int main(){
    int T;
    scanf("%d", &T);
    while(T--){
        scanf("%d%d", &n, &q);
        for(int i = 1; i <= n; i ++)G[i].clear(), E[i].clear();
        for(int i = 1; i < n; i ++){
            scanf("%d%d%d", &a, &b, &c);
            G[a].push_back(b);
            G[b].push_back(a);
            E[a].push_back(c);
            E[b].push_back(c);
        }
        rt = 1;
        init(n);
        //print();
        while(q --){
            scanf("%d%d", &a, &b);
            c = lca(a, b);
            //printf("a = %d, b = %d, c = %d\n", a, b, c);
            printf("%d\n", sum[a] + sum[b] - 2 * sum[c]);
        }
    }
    return 0;
}

LCA

标签:

原文地址:http://www.cnblogs.com/zyf0163/p/4859281.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!