裸题嘛。。
先考虑一条线段上如何查询颜色段数,只要对每个线段树节点多维护一个左颜色和右颜色,然后合并的时候sum[x]=sum[lc]+sum[rc]-(左儿子的右颜色==右儿子的左颜色)。。实在太久没写树剖结果码+调试花了两节多晚自习,,各种傻逼错误,什么反向边忘加,标记忘记下传。。。还有就是更新答案的时候,关键的一点是要保证当前的两点(也就是a,b)是没有被更新到的,否则很难搞。。
表示LCT要更好写。。不过在BZOJ上我的树链剖分6000+MS,LCT要13000+MS。。
树链剖分:
#include<iostream> #include<cstdio> #include<memory.h> #define maxn 100005 using namespace std; struct edge{ int e,next; } ed[maxn*2]; int n,m,i,cnt=0,ne=0,nd=0,s,e,ll,rr,cc,a[maxn],root[maxn],xu[maxn],belong[maxn],top[maxn],col[maxn],size[maxn],chain[maxn],last[maxn],pre[maxn],d[maxn]; int c[maxn*2][2],sum[maxn*2],l[maxn*2],r[maxn*2],lcol[maxn*2],rcol[maxn*2],tag[maxn*2]; char ch; void add(int s,int e){ed[++ne].e=e;ed[ne].next=a[s];a[s]=ne;} void dfs(int x) { int j,i=0,maxs=0,to; size[x]=1;belong[x]=0; for (j=a[x];j;j=ed[j].next) if (ed[j].e!=pre[x]) { to=ed[j].e;d[to]=d[x]+1;pre[to]=x; dfs(to);size[x]+=size[to]; if (size[to]>maxs) maxs=size[to],i=to; } for (j=a[x];j;j=ed[j].next) if (ed[j].e!=pre[x]) { to=ed[j].e; if (to==i) xu[x]=xu[to]+1,belong[x]=belong[to],last[x]=to,top[belong[to]]=x; else top[belong[to]]=to; } if (!belong[x]) belong[x]=++cnt,xu[x]=1,last[x]=-1,top[cnt]=x; } void update(int x) { lcol[x]=lcol[c[x][0]];rcol[x]=rcol[c[x][1]]; sum[x]=sum[c[x][0]]+sum[c[x][1]]-(rcol[c[x][0]]==lcol[c[x][1]]); } void build(int &x,int ll,int rr) { x=++nd;l[x]=ll;r[x]=rr;tag[x]=-1; if (ll==rr) { lcol[x]=rcol[x]=chain[ll]; sum[x]=1; return; } build(c[x][0],ll,(ll+rr)/2);build(c[x][1],(ll+rr)/2+1,rr); update(x); } void prepare() { int i,j; pre[1]=0;d[1]=1;dfs(1); for (i=1;i<=cnt;i++) { for (j=top[i];j!=-1;j=last[j]) chain[xu[j]]=col[j]; build(root[i],1,xu[top[i]]); } // for (i=1;i<=n;i++) printf("%d %d %d %d*\n",i,belong[i],root[belong[i]],d[i]); } void mark(int x,int cc) { if (!x) return; tag[x]=cc;sum[x]=1; lcol[x]=rcol[x]=cc; } void down(int x) { if (tag[x]!=-1) mark(c[x][0],tag[x]),mark(c[x][1],tag[x]),tag[x]=-1; } void ins(int x,int ll,int rr,int cc) { if (ll>r[x]||rr<l[x]) return; if (ll<=l[x]&&rr>=r[x]) {mark(x,cc);return;} down(x); ins(c[x][0],ll,rr,cc);ins(c[x][1],ll,rr,cc); update(x); } int query(int x,int ll,int rr) { if (ll>r[x]||rr<l[x]) return 0; if (ll<=l[x]&&rr>=r[x]) return sum[x]; down(x); int ans=query(c[x][0],ll,rr)+query(c[x][1],ll,rr),mid=(l[x]+r[x])/2; // printf("%d %d %d %d %d###\n",ll,rr,l[x],r[x],ans); if (ll<=mid&&rr>mid) ans-=(rcol[c[x][0]]==lcol[c[x][1]]); return ans; } int getc(int x,int w) { int mid=(l[x]+r[x])/2; down(x); if (l[x]==r[x]) return lcol[x]; return w<=mid? getc(c[x][0],w):getc(c[x][1],w); } void change(int a,int b,int cc) { int ba,bb; while (belong[a]!=belong[b]) { ba=belong[a];bb=belong[b]; if (d[top[ba]]<d[top[bb]]) swap(a,b),swap(ba,bb); ins(root[ba],xu[a],xu[top[ba]],cc); a=top[ba];if (pre[a]) a=pre[a]; } if (d[a]<d[b]) swap(a,b); ins(root[belong[a]],xu[a],xu[b],cc); } int solve(int a,int b) { int ba,bb,ans=0,t,aaa=a,bbb=b; bool f=false; if (a==b) return 1; while (belong[a]!=belong[b]) { f=true; ba=belong[a];bb=belong[b]; if (d[top[ba]]<d[top[bb]]) swap(a,b),swap(ba,bb),swap(aaa,bbb); ans+=query(root[ba],xu[a],xu[top[ba]]); a=top[ba]; if (pre[a]&&getc(root[ba],xu[a])==getc(root[belong[pre[a]]],xu[pre[a]])) ans--; if (pre[a]) a=pre[a];//printf("%d %d %d**\n",a,b,ans); } if (d[a]<d[b]) swap(a,b),swap(aaa,bbb); t=query(root[belong[a]],xu[a],xu[b]); if (!f) return t; // printf("%d %d***\n",t,ans); return ans+t; } int main() { freopen("2243.in","r",stdin); freopen("2243.out","w",stdout); scanf("%d%d",&n,&m); for (i=1;i<=n;i++) scanf("%d",&col[i]); for (i=1;i<n;i++) { scanf("%d%d",&s,&e); add(s,e);add(e,s); } prepare(); scanf("\n"); for (i=1;i<=m;i++) { scanf("%c%d%d",&ch,&ll,&rr); if (ch=='Q') printf("%d\n",solve(ll,rr)); else scanf("%d",&cc),change(ll,rr,cc); scanf("\n"); } fclose(stdin); fclose(stdout); }
#include<iostream> #include<cstdio> #include<memory.h> #define maxn 100005 using namespace std; struct edge{ int e,next; }ed[maxn*2]; int n,m,s,e,l,r,cc,ne=0,i,a[maxn],c[maxn][2],pre[maxn],sum[maxn],lcol[maxn],rcol[maxn],tag[maxn],col[maxn],d[maxn]; char ch; void add(int s,int e){ed[++ne].e=e;ed[ne].next=a[s];a[s]=ne;} void dfs(int x) { int j,i=0,to; tag[x]=-1;lcol[x]=rcol[x]=col[x]; sum[x]=1;c[x][0]=c[x][1]=0; for (j=a[x];j;j=ed[j].next) if (ed[j].e!=pre[x]) { to=ed[j].e;d[to]=d[x]+1;pre[to]=x; dfs(to); } } void update(int x) { sum[x]=sum[c[x][0]]+sum[c[x][1]]+1; if (c[x][0]) lcol[x]=lcol[c[x][0]],sum[x]-=(rcol[c[x][0]]==col[x]); else lcol[x]=col[x]; if (c[x][1]) rcol[x]=rcol[c[x][1]],sum[x]-=(lcol[c[x][1]]==col[x]); else rcol[x]=col[x]; } void mark(int x,int cc) { if (!x) return; lcol[x]=rcol[x]=col[x]=cc; sum[x]=1;tag[x]=cc; } void down(int x) { if (tag[x]!=-1) mark(c[x][0],tag[x]),mark(c[x][1],tag[x]),tag[x]=-1; } bool isroot(int x){return !pre[x]||(c[pre[x]][0]!=x&&c[pre[x]][1]!=x);} void rot(int x,int kind) { int y=pre[x],z=pre[y]; down(y);down(x); if (!isroot(y)&&z) c[z][c[z][1]==y]=x; c[y][!kind]=c[x][kind];pre[c[x][kind]]=y; c[x][kind]=y;pre[y]=x; pre[x]=z; update(y);update(x); } void splay(int x) { int y,z,kind; while (!isroot(x)) { y=pre[x]; if (isroot(y)) rot(x,c[y][0]==x); else { int z=pre[y],kind=c[z][0]==y; if (c[y][kind]==x) rot(x,!kind);else rot(y,kind); rot(x,kind); } } down(x); } void access(int x) { int u; splay(x); c[x][1]=0;update(x); while (pre[x]) { u=pre[x];splay(u); c[u][1]=x;update(u);splay(x); } } int lca(int x,int y) { access(y); int u; splay(x); c[x][1]=0; while (pre[x]) { u=pre[x];splay(u); if (pre[u]==0) return u; c[u][1]=x;update(u);splay(x); } return x; } int getpre(int x) { splay(x); if (!c[x][0]) return 0; x=c[x][0]; while (c[x][1]) x=c[x][1]; return x; } void change(int x,int y,int cc) { if (d[x]<d[y]) swap(x,y); int fa=lca(x,y),u; u=getpre(fa); if (!u) mark(fa,cc); else splay(u),mark(c[u][1],cc); access(x);splay(fa); if (!u) mark(fa,cc); else splay(u),mark(c[u][1],cc); } int query(int x,int y) { if (d[x]<d[y]) swap(x,y); int fa=lca(x,y),u,ans=0; u=getpre(fa); if (!u) ans+=sum[fa]; else splay(u),ans+=sum[c[u][1]]; access(x);splay(fa); if (!u) ans+=sum[fa]; else splay(u),ans+=sum[c[u][1]]; ans--; return ans; } int main() { freopen("2243.in","r",stdin); freopen("my.out","w",stdout); scanf("%d%d",&n,&m); memset(a,0,sizeof(a)); sum[0]=pre[0]=0; for (i=1;i<=n;i++) scanf("%d",&col[i]); for (i=1;i<n;i++) { scanf("%d%d",&s,&e); add(s,e);add(e,s); } pre[1]=0;d[1]=1;dfs(1); scanf("\n"); for (i=1;i<=m;i++) { scanf("%c%d%d",&ch,&l,&r); if (ch=='Q') printf("%d\n",query(l,r)); else { scanf("%d",&cc); change(l,r,cc); } scanf("\n"); } fclose(stdin); fclose(stdout); }
原文地址:http://blog.csdn.net/tag_king/article/details/45132169