码迷,mamicode.com
首页 > 其他好文 > 详细

主席树入门

时间:2018-08-27 21:21:40      阅读:152      评论:0      收藏:0      [点我收藏+]

标签:时间   pos   algorithm   离散   log   入门   for   个数   efi   

/*
主席树入门

从一个题目切入吧
HDU6230 Palindrome 
最后转化成求区间[l,r]里面有几个数比x小 
一开始就想无脑主席树
之前只会了一个板子 很不灵活 只会查第k小
然后二分i是第几小,套上主席树 这时候会多一个二分logn
然后跑的比较慢在超时的边缘试探 
然后看到了划分树这个东西 还蛮好理解就写了一发
依旧在超时的边缘
其实查有几个比x小的数不需要套那一层二分
只要理解好主席树 就可以扔掉二分 
下面简单说一下我理解的这种比较神奇的数据结构

也是根据题目来说 板子题 poj2104
我们开一颗线段树 叶子节点就表示i出现的次数
准确的说 是前缀出现的次数 那就有n棵线段树
查询的时候利用前缀和的思想来做
下面优化空间 
我们考虑从i到i+1发生了什么变化
只加进来一个数 之后从根节点到他的这条链发生了改变
因此我们只记录这个东西
这样就很好的解决了上面的问题
注意存的是每个数的个数 下标是数的大小 如果数字比较大要离散化


然后这个是poj2104的代码 
*/
#include<cstdio>
#include<algorithm>
#define maxn 60010
#define mid (l+r)/2
using namespace std;
int n,m,num,a[maxn],A[maxn],s[maxn],lc[maxn],rc[maxn],r[maxn],cnt;
int init(){
    int x=0,f=1;char s=getchar();
    while(s<0||s>9){if(s==-)f=-1;s=getchar();}
    while(s>=0&&s<=9){x=x*10+s-0;s=getchar();}
    return x*f;
}
int Build(int S,int L,int R){
    cnt++;s[cnt]=S;
    lc[cnt]=L;rc[cnt]=R;
    return cnt;
}
void Insert(int &now,int pre,int l,int r,int k){
    now=Build(s[pre]+1,lc[pre],rc[pre]);
    if(l==r)return;
    if(k<=mid)Insert(lc[now],lc[pre],l,mid,k);
    else Insert(rc[now],rc[pre],mid+1,r,k);
}
int Query(int L,int R,int l,int r,int k){
    if(l==r)return l;
    int sum=s[lc[R]]-s[lc[L]];
    if(sum>=k)return Query(lc[L],lc[R],l,mid,k);
    else return Query(rc[L],rc[R],mid+1,r,k-sum);
}
int main()
{
    n=init();m=init();
    for(int i=1;i<=n;i++){
        a[i]=init();A[i]=a[i];
    }
    sort(A+1,A+1+n);
    num=unique(A+1,A+1+n)-A-1;
    for(int i=1;i<=n;i++){
        int pos=lower_bound(A+1,A+1+num,a[i])-A;
        Insert(r[i],r[i-1],1,num,pos);
    }
    for(int i=1;i<=m;i++){
        int L=init(),R=init();
        int pos=(R-L)/2+1;
        pos=Query(r[L-1],r[R],1,num,pos);
        printf("%d\n",A[pos]);
    }
    return 0;
}
/*
 然后我们回到hdu6230 考虑快速查询区间比x小的数的个数
 先看下施展套一个二分的无脑主席树 2800ms+ (考虑到数字和下标一样 就没有离散化) 
*/

#include<iostream>
#include<cstdio>
#include<cstring>
#define mid (l+r)/2
#define maxn 500010
#define ll long long
using namespace std;
int T,n,len[maxn],a[maxn],s[maxn*30],lc[maxn*30],rc[maxn*30],r[maxn*30],cnt; 
ll ans;
char c[maxn];
void Clear(){
    memset(len,0,sizeof(len));
    memset(s,0,sizeof(s));
    memset(c,0,sizeof(c));
    ans=0;cnt=0; 
}
void Ready(){
    c[0]=#;c[n+1]=#;
}
void Mar(){
    int mx=-1,id=-1;
    for(int i=1;i<=n;i++){
        if(i<=id+mx)len[i]=min(len[2*id-i],id+mx-i);
        while(i-len[i]-1>=1&&i+len[i]+1<=n&&c[i-len[i]-1]==c[i+len[i]+1])len[i]++;
        if(i+len[i]>id+mx)id=i,mx=len[i];
    }
}
int Build(int S,int L,int R){
    cnt++;s[cnt]=S;
    lc[cnt]=L;rc[cnt]=R;
    return cnt;
}
void Insert(int &now,int pre,int l,int r,int k){
    now=Build(s[pre]+1,lc[pre],rc[pre]); 
    if(l==r)return;
    if(k<=mid)Insert(lc[now],lc[pre],l,mid,k);
    else Insert(rc[now],rc[pre],mid+1,r,k);
}
int Query(int L,int R,int l,int r,int k){
    if(l==r)return l;
    int sum=s[lc[R]]-s[lc[L]];
    if(sum>=k)return Query(lc[L],lc[R],l,mid,k);
    else return Query(rc[L],rc[R],mid+1,r,k-sum);
}
void Solve(){
    for(int i=1;i<=n;i++)
        a[i]=i-len[i];
    for(int i=1;i<=n;i++){
        Insert(r[i],r[i-1],1,n,a[i]);
    }
    for(int i=1;i<=n;i++){
        int L=i+1,R=i+len[i];
        int Li=1,Ri=R-L+1;
        while(Li<=Ri){
            int Mid=(Li+Ri)/2;
            int pos=Query(r[L-1],r[R],1,n,Mid);
            if(pos<=i)Li=Mid+1;
            else Ri=Mid-1;
        }
        ans+=Li-1;
    }
}
int main(){
    scanf("%d",&T);
    while(T--){
        Clear();scanf("%s",c+1);n=strlen(c+1);
        Ready();Mar();Solve();printf("%lld\n",ans);
    }
    return 0;
}

/*这个是套的划分树  时间差不多*/
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define mid (l+r)/2
#define maxn 500010
#define ll long long
using namespace std;
int T,n,len[maxn],a[maxn],val[20][maxn],num[20][maxn];
ll ans;
char ss[maxn],c[maxn];
void Clear(){
    memset(len,0,sizeof(len));
    memset(a,0,sizeof(a));ans=0;
    memset(num[0],0,sizeof(num[0]));
}
void Ready(){
    c[0]=#;c[n+1]=#;
}
void Mar(){
    int mx=-1,id=-1;
    for(int i=1;i<=n;i++){
        if(i<=id+mx)len[i]=min(len[2*id-i],id+mx-i);
        while(i-len[i]-1>=1&&i+len[i]+1<=n&&c[i-len[i]-1]==c[i+len[i]+1])len[i]++;
        if(i+len[i]>id+mx)id=i,mx=len[i];
    }
}
void Build(int l,int r,int c){
    if(l==r)return;int isame=mid-l+1;//isame保存有多少和sorted[mid]一样大的数进入左孩子
    for(int i=l;i<=r;i++)if(val[c][i]<a[mid])isame--;
    int ln=l,rn=mid+1;//本结点两个孩子结点的开头,ln左
    for(int i=l;i<=r;i++){
        if(i==l)num[c][i]=0;
        else num[c][i]=num[c][i-1];
        if(val[c][i]<a[mid]||(val[c][i]==a[mid]&&isame>0)){
            val[c+1][ln++]=val[c][i];num[c][i]++;
            if(val[c][i]==a[mid])isame--;
        }
        else val[c+1][rn++]=val[c][i];
    }
    Build(l,mid,c+1);Build(mid+1,r,c+1);
}
int Query(int c,int sl,int sr,int l,int r,int k){
    if(sl==sr)return val[c][sl];
    int ly;if(l==sl)ly=0;else ly=num[c][l-1];//ly 表示l 前面有多少元素进入左孩子
    int tolef=num[c][r]-ly;  //这一层l到r之间进入左子树的有tolef个
       if(tolef>=k){
        return Query(c+1,sl,(sl+sr)/2,sl+ly,sl+num[c][r]-1,k);
       }
       else{
         // l-sl 表示l前面有多少数,再减ly 表示这些数中去右子树的有多少个
         int lr = (sl+sr)/2 + 1 + (l-sl-ly);  //l-r 去右边的开头位置
         // r-l+1 表示l到r有多少数,减去去左边的,剩下是去右边的,去右边1个,下标就是lr,所以减1
         return Query(c+1,(sl+sr)/2+1,sr,lr,lr+r-l+1-tolef-1,k-tolef);
       }
}
void Solve(){
    
    for(int i=1;i<=n;i++){
        a[i]=i-len[i];val[0][i]=a[i];
    }
    sort(a+1,a+1+n);Build(1,n,0);
    for(int i=1;i<=n;i++){
        int L=i+1,R=i+len[i];
        int Li=1,Ri=R-L+1;
        while(Li<=Ri){
            int Mid=(Li+Ri)/2;
            if(Query(0,1,n,L,R,Mid)<=i)Li=Mid+1;
            else Ri=Mid-1;
        }
        ans+=Li-1;
    }
}
int main(){
    scanf("%d",&T);
    while(T--){
        Clear();scanf("%s",c+1);n=strlen(c+1);
        Ready();Mar();Solve();printf("%lld\n",ans);
    }
    return 0;
}
/*

很慢啊这样子施展 随便加几组数据就要GG了 
因为是按下标存的 找<=x的数的个数 实际上就是1-x的区间和
我们退化回线段树的思想  就简单的区间查询就好了 去掉了二分
1200+ms 

*/

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define mid (l+r)/2
#define maxn 500010
#define ll long long
using namespace std;
int T,n,len[maxn],a[maxn],s[maxn*30],lc[maxn*30],rc[maxn*30],r[maxn*30],cnt;
ll ans;
char c[maxn];
void Clear(){
    memset(len,0,sizeof(len));
    memset(c,0,sizeof(c));
    ans=0;cnt=0;
}
void Ready(){
    c[0]=#;c[n+1]=#;
}
void Mar(){
    int mx=-1,id=-1;
    for(int i=1;i<=n;i++){
        if(i<=id+mx)len[i]=min(len[2*id-i],id+mx-i);
        while(i-len[i]-1>=1&&i+len[i]+1<=n&&c[i-len[i]-1]==c[i+len[i]+1])len[i]++;
        if(i+len[i]>id+mx)id=i,mx=len[i];
    }
}
int Build(int S,int L,int R){
    cnt++;s[cnt]=S;
    lc[cnt]=L;rc[cnt]=R;
    return cnt;
}
void Insert(int &now,int pre,int l,int r,int k){
    now=Build(s[pre]+1,lc[pre],rc[pre]);
    if(l==r)return;
    if(k<=mid)Insert(lc[now],lc[pre],l,mid,k);
    else Insert(rc[now],rc[pre],mid+1,r,k);
}
int Query(int L,int R,int x,int y,int l,int r){
    if(x<=l&&y>=r)return s[R]-s[L];
    int res=0;
    if(x<=mid)res+=Query(lc[L],lc[R],x,y,l,mid);
    if(y>mid)res+=Query(rc[L],rc[R],x,y,mid+1,r);
    return res;
}
void Solve(){
    for(int i=1;i<=n;i++)
        a[i]=i-len[i];
    for(int i=1;i<=n;i++)
        Insert(r[i],r[i-1],1,n,a[i]);
    for(int i=1;i<=n;i++){
        int L=i+1,R=i+len[i];
        if(L>R)continue;
        ans+=Query(r[L-1],r[R],1,i,1,n);
    }
}
int main(){
    scanf("%d",&T);
    while(T--){
        Clear();scanf("%s",c+1);n=strlen(c+1);
        Ready();Mar();Solve();printf("%lld\n",ans);
    }
    return 0;
}

 

主席树入门

标签:时间   pos   algorithm   离散   log   入门   for   个数   efi   

原文地址:https://www.cnblogs.com/yanlifneg/p/9544175.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!