码迷,mamicode.com
首页 > 其他好文 > 详细

[Avito Code Challenge 2018 G] Magic multisets(线段树)

时间:2018-10-09 10:21:42      阅读:226      评论:0      收藏:0      [点我收藏+]

标签:space   font   线段   namespace   tchar   ios   define   void   链接   

题目链接:http://codeforces.com/contest/981/problem/G

题目大意:

有n个初始为空的‘魔法’可重集,向一个‘可重集’加入元素时,若该元素未出现过,则将其加入;否则该可重集中所有元素的个数都会翻倍。

例如将$2$加入${1,3}$会得到${1,2,3}$,将$2$加入${1,2,3,3}$会得到${1,1,2,2,3,3,3,3}$.

$q$次操作,每次操作要么向一个区间内的所有可重集加入某个元素,要么询问一个区间内可重集的大小之和。

$n,q ≤ 2×10^5$

题解:

发现对于出现过该元素的区间就是区间乘,没有出现过的就是区间加

操作1我们先把$[l,r]$全部乘上2,再把之前已经出现过当前元素的区间乘上2的逆元再+1,最后合并一下左右区间就好了。合并大概就是保证这个操作时间复杂度的关键了,只是我也不知道怎么算

发现这样我们要维护每个元素出现的区间,同时支持方便的合并,我们开n个set就好了

正是做完这题我发现我竟然不会写区间乘法线段树

#include<algorithm>
#include<cstring>
#include<cstdio>
#include<iostream>
#include<set>
#define pa pair<int,int>
#define mid ((l+r)>>1)
using namespace std;
typedef long long ll;

const int N=2e5+15;
const int mod=998244353;
ll n,q,inv;
ll sum[N<<2],add[N<<2],mul[N<<2];
set <pa> st[N];  
inline ll read()
{
    char ch=getchar();
    ll s=0,f=1;
    while (ch<0||ch>9) {if (ch==-) f=-1;ch=getchar();}
    while (ch>=0&&ch<=9) {s=(s<<3)+(s<<1)+ch-0;ch=getchar();}
    return s*f;
}
ll qpow(ll a,ll b)
{
    ll re=1;
    for (;b;b>>=1,a=a*a%mod) if (b&1) re=re*a%mod;
    return re;
}
void build(int o,int l,int r)
{
    mul[o]=1;add[o]=sum[o]=0;
    if (l==r) return;
    build(o<<1,l,mid);
    build(o<<1|1,mid+1,r);
}
void pushup(int o,int l,int r)
{
    //sum[o]=sum[o<<1]+sum[o<<1|1];
    sum[o]=(sum[o<<1]*mul[o<<1]+add[o<<1]*(mid-l+1))%mod;
    sum[o]=(sum[o]+sum[o<<1|1]*mul[o<<1|1]+add[o<<1|1]*(r-mid))%mod;
}
void pushdown(int o,int l,int r)
{
    if (mul[o]!=1)
    {
        ll p=mul[o];
        mul[o]=1;
        //(sum[o<<1]*=p)%=mod;
        //(sum[o<<1|1]*=p)%=mod;
        (add[o<<1]*=p)%=mod;
        (add[o<<1|1]*=p)%=mod;
        (mul[o<<1]*=p)%=mod;
        (mul[o<<1|1]*=p)%=mod;
    }
    if (add[o]!=0)
    {
        ll p=add[o];
        add[o]=0;
    //  (sum[o<<1]+=p*(mid-l+1))%=mod;
    //    (sum[o<<1|1]+=p*(r-mid))%=mod;
        (add[o<<1]+=p)%=mod;
        (add[o<<1|1]+=p)%=mod;
    }
}
void update(int o,int l,int r,int x,int y,ll z,int flag)
{
    if (l>=x&&r<=y)
    {
        if (flag==1)
        {
            (mul[o]*=z)%=mod;
            (add[o]*=z)%=mod;
        //    (sum[o]*=z)%=mod;
        }
        if (flag==2)
        {
            (add[o]+=z)%=mod;
        //    (sum[o]+=(r-l+1)*z)%=mod;
        }
        return;
    }
    pushdown(o,l,r);
    if (x<=mid) update(o<<1,l,mid,x,y,z,flag);
    if (y>mid) update(o<<1|1,mid+1,r,x,y,z,flag);
    pushup(o,l,r);
} 
void merge(int x,int L,int R)
{
    set<pa>::iterator it;
    it=st[x].lower_bound(pa(L,L));
    for (;it!=st[x].end();it++)
    {
        set<pa>::iterator lst=it;lst--;
        int l=(*lst).second+1;
        int r=(*it).first-1;
        int upl=max(L,l);
        int upr=min(R,r);
        if (upr>=upl) 
        {
            update(1,1,n,upl,upr,inv,1);
            update(1,1,n,upl,upr,1,2);
        }
        if ((*it).first>=R) break;
    }
    int mergeL=L,mergeR=R;
    it=st[x].upper_bound(pa(L,L));it--;
    if ((*it).second>=mergeL) mergeL=(*it).first;
    it=st[x].upper_bound(pa(R,R));it--;
    if ((*it).second>=mergeR) mergeR=(*it).second;
    vector <pa> er;
    it=st[x].lower_bound(pa(mergeL,mergeL));
    for (;it!=st[x].end();it++)
    {
        pa e=*it;
        if (e.first>=mergeL&&e.second<=mergeR) er.push_back(e);
        else break;
    }
    for (int i=0;i<er.size();i++) st[x].erase(er[i]);
    st[x].insert(pa(mergeL,mergeR));
}
ll query(int o,int l,int r,int x,int y)
{
    if (l>=x&&r<=y) return (sum[o]*mul[o]+add[o]*(r-l+1))%mod;
//    if (l>=x&&r<=y) return sum[o]%mod;
    pushdown(o,l,r);
    ll re=0;
    if (x<=mid) (re+=query(o<<1,l,mid,x,y))%=mod;
    if (y>mid) (re+=query(o<<1|1,mid+1,r,x,y))%=mod;
    pushup(o,l,r);
    return re; 
}
int main()
{
    inv=qpow(2,mod-2);
    //inv=(mod+1)/2;
    n=read();q=read();
    for (int i=0;i<=n;i++) 
    {
        st[i].insert(pa(0,0));
        st[i].insert(pa(n+1,n+1));
    }
    build(1,1,n);
    while (q--)
    {
        int opt=read();
        if (opt==1)
        {
            int l=read(),r=read(),z=read();
            update(1,1,n,l,r,2,1);
            merge(z,l,r);
        }
        if (opt==2)
        {
            int l=read(),r=read();
            printf("%lld\n",query(1,1,n,l,r));
        }
    }
    return 0;
}

[Avito Code Challenge 2018 G] Magic multisets(线段树)

标签:space   font   线段   namespace   tchar   ios   define   void   链接   

原文地址:https://www.cnblogs.com/xxzh/p/9758635.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!