第一次写虚树,感觉好厉害呀~首先,这道题目的树形dp是非常显然的,要控制一个点&其子树所有点,要么在子树内部割边,要么直接切点该点与父亲的连边。所以dp[u]表示控制u点所需的最小代价。只是,注意到这样dp的复杂度是O(nm)的,十分不可接受,妥妥的TLE。不过,题目给出的条件中还有一条:Σki<=500000;说明虽然总共的点很多,但实际上每一次对答案可能有影响的点很少。
再一次想到之前dp的时候就应该发现的性质:一条链上,只需要在意链首的点(控制了链首也就控制了整棵子树),这样我们就可以想到:要是这一棵树很小,能够把对答案没有贡献的点尽量都去掉就可以以很小的复杂度完成每一轮询问的dp了。
怎样构建一棵虚树呢?首先,将整颗树dfs一遍,保存每一个点的dfs序号(记为dfn[i])。对于一次询问中的点:a[i]而言,将其按照dfs序从小到大排序,之后两两求出lca,排除那些在同一条链上的点(只保留链首)。之后,我们将1号点放入栈中。这个栈是一个单调栈,保证在任何时候栈中的元素都是一条链,且栈顶元素深度最大。我们记栈顶元素与当前点(之前保留下来的点)的lca为lca,之后的操作就十分显然了,我们要不断的将栈顶元素退栈(在退栈的同时连边构造虚树),直到退回lca与当前元素的那一条链上。注意如果lca不是最后的栈顶元素,lca也要进栈(在虚树上必须保留的一个点,记录了分叉的情况)。最后不要忘记将剩下的元素也用边连起来。
在虚树上面跑dp,一共也没几个点,自然就跑得很快啦。
#include <bits/stdc++.h> using namespace std; #define INF 99999999999LL #define maxn 280000 #define ll long long int timer, cnp = 1, n, m, head[maxn], dfn[maxn], dep[maxn], gra[maxn][25]; int a[maxn], s[maxn], top, tot; ll val[maxn], dp[maxn]; struct edge { int to, last; ll co; }E[maxn * 2]; int read() { int x = 0, k = 1; char c; c = getchar(); while(c < ‘0‘ || c > ‘9‘) { if(c == ‘-‘) k = -1; c = getchar(); } while(c >= ‘0‘ && c <= ‘9‘) x = x * 10 + c - ‘0‘, c = getchar(); return x * k; } bool cmp(int a, int b) { return dfn[a] < dfn[b]; } void add(int x, int y, int co = 0) { if(x == y) return; E[cnp].to = y, E[cnp].co = co; E[cnp].last = head[x], head[x] = cnp ++; } int LCA(int x, int y) { if(dep[x] < dep[y]) swap(x, y); for(int i = 24; ~i; i --) if(dep[gra[x][i]] >= dep[y]) x = gra[x][i]; for(int i = 24; ~i; i --) if(gra[x][i] != gra[y][i]) x = gra[x][i], y = gra[y][i]; return (x == y) ? x : gra[x][0]; } void dfs(int u, int fa) { gra[u][0] = fa, dfn[u] = ++ timer, dep[u] = dep[fa] + 1; for(int i = 1; i <= 24; i ++) gra[u][i] = gra[gra[u][i - 1]][i - 1]; for(int i = head[u]; i; i = E[i].last) { int v = E[i].to; if(v == fa) continue; val[v] = min(E[i].co, val[u]); dfs(v, u); } } void DP(int u, int fa) { ll c = 0; dp[u] = val[u]; for(int i = head[u]; i; i = E[i].last) { int v = E[i].to; if(v == fa) continue; DP(v, u); c += dp[v]; } head[u] = 0; if(c && c < val[u]) dp[u] = c; } void solve() { int k = read(); for(int i = 1; i <= k; i ++) a[i] = read(); sort(a + 1, a + 1 + k, cmp); cnp = 1, s[1] = 1; top = 1, tot = 1; for(int i = 2; i <= k; i ++) if(LCA(a[i], a[tot]) != a[tot]) a[++ tot] = a[i]; for(int i = 1; i <= tot; i ++) { int lca = LCA(s[top], a[i]); while(23333) { if(dep[lca] >= dep[s[top - 1]]) { add(lca, s[top]); top --; if(lca != s[top]) s[++ top] = lca; break; } add(s[top - 1], s[top]), top --; } s[++ top] = a[i]; } while(top > 1) add(s[top - 1], s[top]), top --; DP(1, 0); printf("%lld\n", dp[1]); } int main() { n = read(); val[1] = INF; for(int i = 1; i <= n - 1; i ++) { int x = read(), y = read(), z = read(); add(x, y, z), add(y, x, z); } dfs(1, 0); memset(head, 0, sizeof(head)); m = read(); for(int i = 1; i <= m; i ++) solve(); return 0; }