我的第一道题解就写它吧。
维护区间和+花式修改。用线段树或者树状数组可以解决,但是我没怎么写过树状数组。
维护和的操作直接把左右子树和加起来。
重点是修改。
去年刚学完线段树刷完数列操作a,b,c后看道这题就弃了。现在知道了关键就是推式子,跟HAOI2012高速公路是一个套路的。
线段树就是用于维护区间的,而且因为延迟标记的存在,所以我们先考虑区间增量。
对于区间 [l,r] 增量为 Σri=l(i-L)*x。L是总的修改范围,l,r,是线段树中节点的范围,第一次就是因为没注意这两者的关系,导致公式错误连样例都过不了。
提公因式x,Σ里是等差数列直接求和之后相乘就算出增量了,这时候要考虑如何打lazy标记。
思想也是提公因式。
Σri=l ( i-L ) * x = ( Σri=l i ) * x - ( Σri=l1 ) * L * x
设 A = Σri=l i,B = Σri=l1 ,可以发现对于线段树中的每段区间A和B的值是固定的,这样一来我们只需要累计每次修改的x以及L*x的值就可以顺利下传延迟修改了。
这道题就顺利解决了。
// q.c #include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #include<cmath> using namespace std; typedef long long LL; const int M=300000+10; const int mod=(int)1e9+7; struct Node { int l,r,sum,s1,s2; Node():l(0),r(0),sum(0),s1(0),s2(0) {} }; struct SegmentTree { int root; Node nd[M<<2]; SegmentTree():root(1) {} void update(int o) { nd[o].sum=((LL)nd[o<<1].sum+nd[o<<1^1].sum)%mod; } void pushdown(int o) { Node &p=nd[o],&lc=nd[o<<1],&rc=nd[o<<1^1]; lc.sum=(lc.sum+(LL)p.s1*(lc.r+lc.l)*(lc.r-lc.l+1)/2)%mod; lc.sum=(lc.sum-(LL)p.s2*(lc.r-lc.l+1)%mod+mod)%mod; lc.s1=(lc.s1+p.s1)%mod; lc.s2=(lc.s2+p.s2)%mod; rc.sum=(rc.sum+(LL)p.s1*(rc.r+rc.l)*(rc.r-rc.l+1)/2)%mod; rc.sum=(rc.sum-(LL)p.s2*(rc.r-rc.l+1)%mod+mod)%mod; rc.s1=(rc.s1+p.s1)%mod; rc.s2=(rc.s2+p.s2)%mod; p.s1=p.s2=0; } void build(int o,int l,int r) { nd[o].l=l,nd[o].r=r; if(l!=r) { int mid=(l+r)>>1; build(o<<1,l,mid); build(o<<1^1,mid+1,r); } } void add(int o,int l,int r,int x) { Node &p=nd[o]; if(l<=p.l&&p.r<=r) { p.sum=(p.sum+(LL)x*(p.r-l+p.l-l)*(p.r-p.l+1)/2)%mod; p.s1=(p.s1+x)%mod; p.s2=(p.s2+(LL)l*x)%mod; } else { if(p.s1||p.s1) pushdown(o); int mid=(p.l+p.r)>>1; if(l<=mid) add(o<<1,l,r,x); if(r>mid) add(o<<1^1,l,r,x); update(o); } } int query(int o,int l,int r) { Node p=nd[o]; if(l<=p.l&&p.r<=r) return p.sum; else { if(p.s1||p.s2) pushdown(o); int mid=(p.l+p.r)>>1,ans=0; if(l<=mid) ans=((LL)ans+query(o<<1,l,r))%mod; if(r>mid) ans=((LL)ans+query(o<<1^1,l,r))%mod; return ans; } } }t; int n,m; int main() { freopen("segment.in","r",stdin); freopen("segment.out","w",stdout); scanf("%d%d",&n,&m); t.build(t.root,1,n); int opt,l,r,x; for(int i=1;i<=m;i++) { scanf("%d%d%d",&opt,&l,&r); if(opt) scanf("%d",&x),t.add(t.root,l,r,x); else printf("%d\n",t.query(t.root,l,r)); } return 0; }
公式真的好难打啊。