标签:lan next += 16px struct == define top while
一道很有意思的题……树上差分系列。
这个题一开始就想着怎么用树剖+线段树求和了,然后一想线段树建树是O(nlogn),然后再乘以一个k的50,好像T了……?
后来发现这个题其实没有任何修改不需要线段树……完全是可以用差分求解的。
这样的话其实预处理就变成了O(nk)的,我们用pre[i][j]表示从根(1)到节点i上所经过的点的j次方和,很容易得到下式:
pre[i][j] = pre[fa[i]][j] + qpow(dep[i],j).这个显然是在树剖的第二次dfs处理好的。
然后我们就差分就行了,用树剖顺便把LCA求出来,然后答案就是:pre[x][k] + pre[y][k] - pre[lca(x,y)][k] - pre[fa[lca(x,y)]][k],然后注意这个有可能是负数,你???加一个mod可能还是负的……所以多加几个,然后取个模即可。
看一下代码:
#include<cstdio> #include<algorithm> #include<cstring> #include<iostream> #include<cmath> #include<set> #include<queue> #define rep(i,a,n) for(int i = a;i <= n;i++) #define per(i,n,a) for(int i = n;i >= a;i--) #define enter putchar(‘\n‘) using namespace std; typedef long long ll; const int M = 300005; const int INF = 1000000009; const ll mod = 998244353; ll read() { ll ans = 0,op = 1; char ch = getchar(); while(ch < ‘0‘ || ch > ‘9‘) { if(ch == ‘-‘) op = -1; ch = getchar(); } while(ch >= ‘0‘ && ch <= ‘9‘) { ans *= 10; ans += ch - ‘0‘; ch = getchar(); } return ans * op; } struct edge { int next,to; }e[M<<1]; int n,x,y,ecnt,head[M],fa[M],k,dep[M],top[M],size[M],hson[M],m; ll pre[M][55]; ll qpow(ll a,ll b) { ll p = 1; while(b) { if(b&1) p *= a,p %= mod; a *= a,a %= mod; b >>= 1; } return p; } void add(int x,int y) { e[++ecnt].to = y; e[ecnt].next = head[x]; head[x] = ecnt; } void dfs1(int x,int f,int depth) { size[x] = 1,fa[x] = f,dep[x] = depth; int maxson = -1; for(int i = head[x];i;i = e[i].next) { if(e[i].to == f) continue; dfs1(e[i].to,x,depth+1); size[x] += size[e[i].to]; if(size[e[i].to] > maxson) maxson = size[e[i].to],hson[x] = e[i].to; } } void dfs2(int x,int t) { top[x] = t; rep(i,1,50) pre[x][i] = pre[fa[x]][i] + qpow(dep[x],i),pre[x][i] %= mod; if(!hson[x]) return; dfs2(hson[x],t); for(int i = head[x];i;i = e[i].next) { if(e[i].to == fa[x] || e[i].to == hson[x]) continue; dfs2(e[i].to,e[i].to); } } int lca(int x,int y) { while(top[x] != top[y]) { if(dep[top[x]] < dep[top[y]]) swap(x,y); x = fa[top[x]]; } if(dep[x] > dep[y]) swap(x,y); return x; } int main() { n = read(); rep(i,1,n-1) x = read(),y = read(),add(x,y),add(y,x); dfs1(1,0,0),dfs2(1,1); m = read(); rep(i,1,m) { x = read(),y = read(),k = read(); ll f = lca(x,y); ll g = (pre[x][k] + pre[y][k] - pre[f][k] - pre[fa[f]][k]); g = (g + 10 * mod) % mod; printf("%lld\n",g); } return 0; }
标签:lan next += 16px struct == define top while
原文地址:https://www.cnblogs.com/captain1/p/9757519.html