码迷,mamicode.com
首页 > 其他好文 > 详细

树链剖分模板

时间:2019-08-22 23:51:04      阅读:92      评论:0      收藏:0      [点我收藏+]

标签:sum   ==   最短路   line   code   algorithm   ems   更新   include   

树链剖分模板

如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和
操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z
操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和
#include<iostream>
#include<stdio.h>
#include<string.h>
#include<algorithm>
#define LL long long
using namespace std;
const int MAXN = 2e5+10;//点的个数
struct node{
  int l,r;
  int sum,laze;//线段树
}tree[MAXN<<2];
struct edge{
   int next,to;
}e[MAXN<<1];
int a[MAXN];
int head[MAXN];
int siz[MAXN];//子树的大小
int top[MAXN];//重链的顶端
int son[MAXN];//每个节点的重儿子
int d[MAXN];//每个节点的深度
int fa[MAXN];//每个节点的父亲节点
int id[MAXN];//每个节点的DFS序
int rk[MAXN];//每个DFS序对应的节点
inline int L(int x){return x<<1;};
inline int R(int x){return x<<1|1;};
inline int MID(int l,int r){return (l+r)>>1;};
int n,m,r,MOD,uu,vv;
int cnt=0;
void add(int x,int y){
   e[++cnt].next=head[x];
   e[cnt].to=y;
   head[x]=cnt;
}
void dfs1(int u,int f,int depth){
   d[u]=depth;
   fa[u]=f;
   siz[u]=1;  //这个点本身size=1
   for (int i=head[u];i;i=e[i].next){
     int v=e[i].to;
     if (v==f)
        continue;
     dfs1(v,u,depth+1); //层次深度+1
     siz[u]+=siz[v];   //子节点的size已经被处理,用它来更新父亲节点
     if (siz[v]>siz[son[u]])
        son[u]=v;      //选取size最大的作为重儿子并不断更新
   }
}
void dfs2(int u,int t){
    top[u]=t;     //标记这个节点,重链顶端
    id[u]=++cnt;  //标记DFS序列
    rk[cnt]=a[u];
    if (!son[u])//如果到根节点
        return;
    dfs2(son[u],t);
    //我们选择有限进入重儿子,让重儿子的DFS序连续
    for (int i=head[u];i;i=e[i].next)
    {
        int v=e[i].to;
        if (v!=son[u] && v!=fa[u])//如果一个点不是重儿子,并且这个节点也不是其父亲节点
            dfs2(v,v);
    }
}
void push_down(int root){
    if (tree[root].laze){
        tree[L(root)].laze+=tree[root].laze;
        tree[R(root)].laze+=tree[root].laze;
        tree[L(root)].sum+=(tree[L(root)].r-tree[L(root)].l+1)*tree[root].laze;
        tree[R(root)].sum+=(tree[R(root)].r-tree[R(root)].l+1)*tree[root].laze;
        tree[L(root)].sum%=MOD;
        tree[R(root)].sum%=MOD;
        tree[root].laze=0;
    }
}
void buildtree(int root,int l,int r){
   tree[root].l=l;
   tree[root].r=r;
   if (l==r){
      tree[root].sum=rk[l]%MOD;
      return ;
   }
   int mid=MID(l,r);
   buildtree(L(root),l,mid);
   buildtree(R(root),mid+1,r);
   tree[root].sum=(tree[L(root)].sum+tree[R(root)].sum)%MOD;
}
int query(int root,int ql,int qr){
  int l=tree[root].l;
  int r=tree[root].r;
  int res=0;
  if (ql<=l && r<=qr){
    return tree[root].sum;
  }
  push_down(root);
  int mid=MID(l,r);
  if (qr<=mid){
     res=query(L(root),ql,qr);
  }else if (ql>mid){
     res=query(R(root),ql,qr);
  }else {
     res=query(L(root),ql,mid);
     res+=query(R(root),mid+1,qr);
  }
  return res%MOD;
}
void update(int root,int ul,int ur,int w){
   int l=tree[root].l;
   int r=tree[root].r;
   if (ul<=l && r<=ur){
      tree[root].laze+=w;
      tree[root].sum+=(r-l+1)*w;
      return ;
   }
   push_down(root);
   int mid=MID(l,r);
   if (ur<=mid){
       update(L(root),ul,ur,w);
   }else if (ul>mid){
       update(R(root),ul,ur,w);
   }else{
       update(L(root),ul,mid,w);
       update(R(root),mid+1,ur,w);
   }
   tree[root].sum=(tree[L(root)].sum+tree[R(root)].sum)%MOD;
}
int qRange(int x,int y){
  int ans=0;
  while(top[x]!=top[y]){//不在一条链上
    if (d[top[x]]<d[top[y]])swap(x,y);//把x变成深的节点
    ans+=query(1,id[top[x]],id[x]);//求和
    ans%=MOD;
    x=fa[top[x]];//在跳到链的顶端的上面一个点
  }//直到两个点处于一条链上
  if (d[x]>d[y])swap(x,y);//在同一层后继续
  ans+=query(1,id[x],id[y]);
  return ans%MOD;
}
void updRange(int x,int y,int k){
  k%=MOD;
  while(top[x]!=top[y]){
    if (d[top[x]]<d[top[y]])swap(x,y);
    update(1,id[top[x]],id[x],k);
    x=fa[top[x]];
  }
  if (d[x]>d[y])swap(x,y);
  update(1,id[x],id[y],k);
}
int qson(int x){
   return query(1,id[x],id[x]+siz[x]-1)%MOD;
}
void updson(int x,int k){
//  cout<<id[x]<<" "<<siz[x]<<endl;
  update(1,id[x],id[x]+siz[x]-1,k);
}
int main(){

  while(~scanf("%d%d%d%d",&n,&m,&r,&MOD)){
    memset(head,0,sizeof(head));
    memset(id,0,sizeof(id));
    for (int i=1;i<=n;i++){
        scanf("%d",&a[i]);
    }
    for (int i=1;i<n;i++){
        scanf("%d%d",&uu,&vv);
        add(uu,vv);
        add(vv,uu);
    }
    cnt=0;
    dfs1(r,0,1);
    dfs2(r,r);
    buildtree(1,1,n);
    while(m--){
        int op,x,y,z;
        scanf("%d",&op);
        if (op==1){
            scanf("%d%d%d",&x,&y,&z);
            updRange(x,y,z);
        }else if (op==2){
          scanf("%d%d",&x,&y);
          printf("%d\n",qRange(x,y));
        }else if (op==3){
          scanf("%d%d",&x,&y);
         // cout<<x<<" "<<y<<endl;
          updson(x,y);
        }else {
          scanf("%d",&x);
          printf("%d\n",qson(x));
        }
    }
  }
  return 0;
}

 

树链剖分模板

标签:sum   ==   最短路   line   code   algorithm   ems   更新   include   

原文地址:https://www.cnblogs.com/bluefly-hrbust/p/11397445.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!