码迷,mamicode.com
首页 > 其他好文 > 详细

Count on a tree SPOJ 主席树+LCA(树链剖分实现)(两种存图方式)

时间:2019-08-29 23:56:56      阅读:225      评论:0      收藏:0      [点我收藏+]

标签:去重   ext   就会   query   size   max   display   root   n+1   

Count on a tree SPOJ 主席树+LCA(树链剖分实现)(两种存图方式)

题外话,这是我第40篇随笔,纪念一下。<( ̄︶ ̄)↗[GO!]

题意

是说有棵树,每个节点上都有一个值,然后让你求从一个节点到另一个节点的最短路上第k小的值是多少。

解题思路

看到这个题一想以为是树链剖分+主席树,后来写着写着发现不对,因为树链剖分我们分成了一小段一小段,这些小段不能合并起来求第k小,所以这个想法不对。奈何不会做,查了查题解,需要用LCA(最近公共祖先),然后根据主席树具有区间加减的性质,我们查询一段区间的状态可以从LCA的角度去看问题,找到LCA(x, y)然后,我们只要一个LCA节点,然后求出区间X到根节点,以及Y到根节点的关系式来推这个关系,但是千万不要去减两倍LCA的关系,因为那样就会少掉一个节点了,于是,就dfs()往下建树,就是寻找到最后的答案。
\[ t[t[x].l].sum + t[t[y].l].sum - t[t[lca].l].sum - t[t[gra].l].sum \]
注意:这里我求LCA的方法是用的树链剖分的方法,求LCA的方法有很多,但是我就会这一种??

如果没有学过树链剖分,我这里有一些学习资料推荐,点我

如果没有学过主席树,别急,我这也有好的视频和博客推荐,点我

上面两个都是我学习过程中遇到的好的博客文章的收集,节省再次查找的时间

代码实现(图用vector存版,用链式向前星版)

//vector版,方便但是稍微慢一些
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
const int maxn=1e5+100;
struct node{
    int l, r, sum;
}t[maxn*40];
vector<int> g[maxn];
vector<int> v;
int tot, root[maxn], w[maxn];

int dep[maxn], f[maxn], size[maxn], son[maxn]; 
int top[maxn];
int n, m, nn; //nn是实际去重后的个数
int read() //快读函数,这里已经实验过了,用和不用都行,用了会快一些。
{
    int f=1,x=0;
    char ss=getchar();
    while(ss<'0'||ss>'9'){if(ss=='-')f=-1;ss=getchar();}
    while(ss>='0'&&ss<='9'){x=x*10+ss-'0';ss=getchar();}
    return f*x;
}
int getid(int x) //离散化之后求坐标
{
    return lower_bound(v.begin() , v.end() , x)- v.begin() +1;
}
void dfs1(int u, int fa, int depth)
{
    size[u]=1;
    dep[u]=depth;
    f[u]=fa;
    int len=g[u].size();
    for(int i=0; i<len; i++)
    {
        int v=g[u][i];
        if(v==fa) 
            continue;
        dfs1(v, u, depth+1);
        size[u]+=size[v];
        if(size[v] > size[son[u]])
            son[u]=v;
    }
}
void dfs2(int u, int tp)
{
    top[u]=tp;
    if(!son[u])
        return ;
    dfs2(son[u], tp);
    int len=g[u].size();
    for(int i=0; i<len; i++)
    {
        int v=g[u][i];
        if(v==son[u] || v==f[u])
            continue;
        dfs2(v, v);
    }
}
int lca(int x,int y)
{
    int fx=top[x], fy=top[y];
    while(fx!=fy)
    {
        if(dep[fx] < dep[fy]) 
        {
            swap(x, y);
            swap(fx, fy);
        }
        x=f[fx]; //这里右边是fx,千万别写错了,我就是这犯了错,wa了几十下。。。
        fx=top[x];
    }
    return dep[x] < dep[y] ? x : y ;
}
void update(int l, int r, int pre, int &now, int pos)
{
    t[++tot]=t[pre];
    t[tot].sum++;
    now=tot;
    if(l==r) return ;
    int mid=(l+r)>>1;
    if(pos<=mid)
        update(l, mid, t[pre].l, t[now].l, pos);
    else 
        update(mid+1, r, t[pre].r, t[now].r, pos);
}
int query(int l, int r, int x, int y, int lca, int gra, int k)
{
    if(l==r)
        return l;
    int mid=(l+r)>>1;
    int sum=t[t[x].l].sum + t[t[y].l].sum - t[t[lca].l].sum - t[t[gra].l].sum ;
    if(k<=sum)
        return query(l, mid, t[x].l, t[y].l, t[lca].l, t[gra].l, k);
    else 
        return query(mid+1, r, t[x].r, t[y].r, t[lca].r, t[gra].r, k-sum); 
}
void dfs(int u)
{
    int pos=getid(w[u]);
    update(1, nn, root[f[u]], root[u], pos);
    int len=g[u].size();
    for(int i=0; i<len; i++)
    {
        int v=g[u][i];
        if(v==f[u])
            continue;
        dfs(v);
    }
}
int main()
{
    scanf("%d%d", &n, &m); 
    for(int i=1; i<=n; i++)
    {
        w[i]=read();
        v.push_back(w[i]);
    }
    sort(v.begin() , v.end() );
    v.erase( unique( v.begin() , v.end() ), v.end());
    nn=v.size();
    int x, y;
    for(int i=1; i<n; i++)
    {
        scanf("%d%d", &x, &y);
        g[x].push_back(y);
        g[y].push_back(x);
    }
    dfs1(1, 0, 1);
    dfs2(1, 1);
    dfs(1);
    int k, la;
    for(int i=1; i<=m; i++)
    {
        scanf("%d%d%d", &x, &y, &k);
        la=lca(x, y);
        printf("%d\n", v[ query(1, nn, root[x], root[y], root[la], root[ f[la] ], k) -1 ] );
    }
    return 0;
}
// 链式向前星版存图
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<iostream>
using namespace std;
const int maxn=1e5+10000;
struct node{
    int l, r, sum;
}t[maxn*40];
int tot; //tot是主席树点的个数 

struct edge{
    int to, next;
}e[maxn<<1];
int head[maxn], cnt; //cnt是边的个数 

int root[maxn], w[maxn], id[maxn];

int dep[maxn], f[maxn], size[maxn], son[maxn];
int top[maxn];
int n, m, nn;
inline int read()
{
    int f=1,x=0;
    char ss=getchar();
    while(ss<'0'||ss>'9'){if(ss=='-')f=-1;ss=getchar();}
    while(ss>='0'&&ss<='9'){x=x*10+ss-'0';ss=getchar();}
    return f*x;
}
inline void add(int u, int v)
{
    e[++cnt].next=head[u];
    e[cnt].to=v;
    head[u]=cnt;
}

int getid(int x)
{
    return (lower_bound(&id[1] , &id[nn+1] , x)-id);
}

void update(int l, int r, int pre, int &now, int pos)
{
    t[++tot]=t[pre];
    t[tot].sum++;
    now=tot;
    if(l==r) return ;
    int mid=(l+r)>>1;
    if(pos<=mid)
        update(l, mid, t[pre].l, t[now].l, pos);
    else
        update(mid+1, r, t[pre].r, t[now].r, pos);
}

int query(int l, int r, int x, int y, int lca, int gra, int k)
{
    if(l==r)
        return l;
    int mid=(l+r)>>1;
    int sum=t[t[x].l].sum + t[t[y].l].sum - t[t[lca].l].sum - t[t[gra].l].sum ;
    if(k<=sum)
        return query(l, mid, t[x].l, t[y].l, t[lca].l, t[gra].l, k);
    else
        return query(mid+1, r, t[x].r, t[y].r, t[lca].r, t[gra].r, k-sum);
}
void dfs1(int u, int fa, int depth)
{
    size[u]=1;
    dep[u]=depth;
    f[u]=fa;
    for(int i=head[u]; i; i=e[i].next)
    {
        int v=e[i].to;
        if(v==fa)
            continue;
        dfs1(v, u, depth+1);
        size[u]+=size[v];
        if(size[v] > size[son[u]])
            son[u]=v;
    }
}
void dfs2(int u, int tp)
{
    top[u]=tp;
    if(!son[u])
        return ;
    dfs2(son[u], tp);
    for(int i=head[u]; i; i=e[i].next)
    {
        int v=e[i].to;
        if(v==son[u] || v==f[u])
            continue;
        dfs2(v, v);
    }
}

int lca(int x,int y)
{
    int fx=top[x], fy=top[y];
    while(fx!=fy)
    {
        if(dep[fx] < dep[fy])
        {
            swap(x, y);
            swap(fx, fy);
        }
        x=f[fx];
        fx=top[x];
    }
    return dep[x] < dep[y] ? x : y ;
}
void dfs(int u)
{
    int pos=getid(w[u]);
    update(1, nn, root[f[u]], root[u], pos);
    for(int i=head[u]; i; i=e[i].next)
    {
        int v=e[i].to;
        if(v==f[u])
            continue;
        dfs(v);
    }
}
int main()
{
    scanf("%d%d", &n, &m);
    for(int i=1; i<=n; i++)
    {
        scanf("%d", &w[i]);
        id[i]=w[i];
    }
    sort(&id[1], &id[n+1] );
    nn = (unique( &id[1], &id[n+1])-id-1) ;
    int x, y;
    for(int i=1; i<n; i++)
    {
        scanf("%d%d", &x, &y);
        add(x, y);
        add(y, x);
    }
    dfs1(1, 0, 1);
    dfs2(1, 1);
    dfs(1);
    int k, la;
    for(int i=1; i<=m; i++)
    {
        scanf("%d%d%d", &x, &y, &k);
        la=lca(x, y);
        printf("%d\n", id[ query(1, nn, root[x], root[y], root[la], root[ f[la] ], k) ] );
    }
    return 0;
}

Count on a tree SPOJ 主席树+LCA(树链剖分实现)(两种存图方式)

标签:去重   ext   就会   query   size   max   display   root   n+1   

原文地址:https://www.cnblogs.com/alking1001/p/11432494.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!