线段树总结
导引
有时候我们经常需要对一组序列进行操作,修改或者查询一段区间的信息,朴素的算法是暴力修改暴力查询,但如果数据范围到\(10^4\)以上的话就可能会超时,所以,我们就需要用一些数据结构去维护这个序列,线段树就是其中之一。
思想
对于一个序列,我们用二分的思想,将他分为两个长度相等的区间,对每个区间内的所有信息分别维护,比如区间和,区间最大值,区间积等等,这个我们可以直接很快将两个区间信息合并,就成了整个区间的信息,而如果区间内只有一个元素的时候,它的信息就可以直接通过这个值得出,然后一步步合并,我们就把整个序列维护成了一颗完全二叉树,所以树高是log级别的,并且我们知道完全二叉树的节点编号的是有规律的,一个节点的编号位N,那么左儿子是N×2和右儿子是N×2+1,这样构建就非常方便,而这个构建过程往往使用递归实现。
以区间求和为例:
void build(int now,int l,int r){
if(l==r){
tr[now]=a[l];
return;
}
int mid=(l+r)>>1;
build(now*2,l,mid);
build(now*2+1,mid+1,r);
tr[now]=tr[now*2]+tr[now*2+1];
}
修改
一般题目都不会仅仅只有查询操作,往往都会有修改操作。修改一个区间我们往往不一一修改包括这个区间的节点,而是在节点上做一个标记,每当访问到一个节点时,就将这个节点的信息下放到它的儿子,于是时间复杂度就变成了O(\(\log n\))。
void update(int now,int l,int r,int ll,int rr,int k)
{
if(l>rr||r<ll)return;
if(l>=ll&&r<=rr)
{
lz[now]+=k;
tr[now]+=k*(r-l+1);
return;
}
int mid=(l+r)>>1;
update(now*2,l,mid,ll,rr,k);
update(now*2+1,mid+1,r,ll,rr,k);
tr[now]=tr[now*2]+tr[now*2+1]+lz[now]*(r-l+1);
}
查询
对于查询一个区间的信息往往都是查询深度最小的在这个区间内的节点,其实就是每当访问到一个在区间内的节点时就将信息合并,然后立刻返回,若访问到与这个区间没有交点的节点时,直接返回,这样时间复杂度也是O(\(\log n\)).
LL sum1(int now,int l,int r,int ll,int rr,LL s)
{
if(l>rr)return 0;
if(r<ll)return 0;
if(l>=ll&&r<=rr)return tr[now]+(r-l+1)*s;
int mid=(l+r)>>1;
return sum1(now*2,l,mid,ll,rr,s+lz[now])+sum1(now*2+1,mid+1,r,ll,rr,s+lz[now]);
}
完整代码(题目来源 洛谷P3372【模板】线段树1)
这是好久以前打的,特别丑,其实不想贴的
#include<iostream>
#include<cstdio>
#include<cstring>
#define LL long long
using namespace std;
LL a[500001];
LL tr[10000001];
LL lz[10000001];
void bt(int now,int l,int r)
{
if(l==r)
{
tr[now]=a[l];
return;
}
int mid=(l+r)>>1;
bt(now*2,l,mid);
bt(now*2+1,mid+1,r);
tr[now]=tr[now*2]+tr[now*2+1];
}
void update(int now,int l,int r,int ll,int rr,int k)
{
if(l>rr||r<ll)return;
if(l>=ll&&r<=rr)
{
lz[now]+=k;
tr[now]+=k*(r-l+1);
return;
}
int mid=(l+r)>>1;
update(now*2,l,mid,ll,rr,k);
update(now*2+1,mid+1,r,ll,rr,k);
tr[now]=tr[now*2]+tr[now*2+1]+lz[now]*(r-l+1);
}
LL sum1(int now,int l,int r,int ll,int rr,LL s)
{
if(l>rr)return 0;
if(r<ll)return 0;
if(l>=ll&&r<=rr)return tr[now]+(r-l+1)*s;
int mid=(l+r)>>1;
return sum1(now*2,l,mid,ll,rr,s+lz[now])+sum1(now*2+1,mid+1,r,ll,rr,s+lz[now]);
}
int main()
{
int n,m;
scanf("%d%d",&n,&m);
for(int i=1;i<=n;++i)
{
scanf("%d",&a[i]);
}
bt(1,1,n);
while(m--)
{
LL op,x,y;
scanf("%lld%lld%lld",&op,&x,&y);
if(op==1)
{
LL k;
scanf("%lld",&k);
update(1,1,n,x,y,k);
}
if(op==2)
{
printf("%lld\n",sum1(1,1,n,x,y,0));
}
}
return 0;
}