题目大意:求区间不同类型的数字之和。
离线:将询问按 r 值排序,拿到一个询问的时候,我们将比位置在r之前的数字全部加入树状数组,加入数字的规则如下:
假设加入的数字为x,如果x之前没有出现过,直接加入,否则删除前一个x,再加入x,这样保证了树状数组里面一类数只有一个,并且都是最靠右的数,
这样显然是最优并且正确的。
#include<bits/stdc++.h> #define fi first #define se second #define mk make_pair #define pii pair<int,int> #define read(x) scanf("%d",&x) #define sread(x) scanf("%s",x) #define dread(x) scanf("%lf",&x) #define lread(x) scanf("%lld",&x) using namespace std; typedef long long ll; const int inf=0x3f3f3f3f; const int INF=0x3f3f3f3f3f3f3f3f; const int N=3e5+7; const int M=1e5+7; int n,a[N],pre[N]; ll ans[N]; struct Qus { int l,r,id; bool operator < (const Qus &rhs)const { if(r==rhs.r) return l<rhs.l; return r<rhs.r; } }qus[M]; struct BIT { ll a[N]; void init(){memset(a,0,sizeof(a));} void modify(int pos,int v) { for(int i=pos;i<=n;i+=i&-i) a[i]+=v; } ll sum(int pos) { ll ans=0; for(int i=pos;i;i-=i&-i) ans+=a[i]; return ans; } }bit; int main() { int T; read(T); while(T--) { map<int,int> mp; bit.init(); read(n); for(int i=1;i<=n;i++) { read(a[i]); pre[i]=mp[a[i]]; mp[a[i]]=i; } int q; read(q); for(int i=1;i<=q;i++) read(qus[i].l),read(qus[i].r),qus[i].id=i; sort(qus+1,qus+q+1); int now=1; for(int i=1;i<=q;i++) { while(now<=n && now<=qus[i].r) { if(pre[now]) bit.modify(pre[now],-a[now]); bit.modify(now,a[now]); now++; } ans[qus[i].id]=bit.sum(qus[i].r)-bit.sum(qus[i].l-1); } for(int i=1;i<=q;i++) printf("%lld\n",ans[i]); } return 0; } /* */
在线:主席树,root [r] 表示,第加入第r个数之后的版本的根,root [r]为根的树保存的是每类数最靠右的那一个,思路和离线差不多。
#include<bits/stdc++.h> #define fi first #define se second #define mk make_pair #define pii pair<int,int> #define read(x) scanf("%d",&x) #define sread(x) scanf("%s",x) #define dread(x) scanf("%lf",&x) #define lread(x) scanf("%lld",&x) using namespace std; typedef long long ll; const int inf=0x3f3f3f3f; const int INF=0x3f3f3f3f3f3f3f3f; const int N=3e5+7; const int M=51; int n,pre[N],root[N]; map<int,int> mp; struct Chairman_tree { int cnt; struct node{ int l,r; ll sum; }a[20*N]; void update(int l,int r,int &x,int y,int pos,int v) { a[++cnt]=a[y]; x=cnt; a[x].sum+=v; if(l==r) return; int mid=(l+r)>>1; if(pos<=mid) update(l,mid,a[x].l,a[y].l,pos,v); else update(mid+1,r,a[x].r,a[y].r,pos,v); } ll query(int l,int r,int pos,int x) { if(l>=pos) return a[x].sum; if(r<pos) return 0; int mid=(l+r)>>1; return query(l,mid,pos,a[x].l)+query(mid+1,r,pos,a[x].r); } }seg; int main() { int T; read(T); while(T--) { map<int,int> mp; seg.cnt=0; read(n); for(int i=1;i<=n;i++) { int x; read(x); if(mp[x]) { seg.update(1,n,root[n+1],root[i-1],mp[x],-x); seg.update(1,n,root[i],root[n+1],i,x); } else seg.update(1,n,root[i],root[i-1],i,x); mp[x]=i; } int q; read(q); while(q--) { int l,r; read(l); read(r); printf("%lld\n",seg.query(1,n,l,root[r])); } } return 0; }