标签:amp math 合并 inline 代码 公式 put 一个 get
在一个有\(n\)个点的树上,找到一条长度为\(k\)的链(覆盖\(k+1\)个点),使得所有点到这条链的距离最小。
- 首先定义\(siz[u]\)表示以\(u\)为根节点的子树大小,\(g[u]\)表示\(u\)子树内所有点到\(u\)的距离和,\(h[u]\)表示\(u\)子树外节点到\(u\)的距离之和。
- \(siz\)比较好求解,\(g[u]\)其实就是\(\sum_{v\in son[u]}siz[v]+g[v]\),其意义就是其中的每一个点到\(v\)的点再加上\(u->v\)中\(1\)的贡献,\(h[u]\)的求解类似于\(g\),我们考虑从父亲节点\(h[fa]\)来转移。
- 其公式为\(h[u]=h[fa]+g[fa]-g[u]-siz[u]+n-siz[u]\),具体推导的过程就分成两个部分,先把非子树内的点到\(fa\)的距离算出来之后加上\(fa->u\)的距离。
- 接下来考虑树形\(DP\)。
- 我们定义\(dp[u][i]\)表示链的一段在\(u\),并且这条链覆盖了\(i\)个点,这个子树内其他点到这条链的最小距离。
- 状态转移方程:\(dp[u][i]=min(dp[u][i],dp[v][i-1]+g[u]-g[v]-siz[v])\)就是其余子树的点到\(u\)的距离一定比到\(v\)优。
- 那么整体的答案就是\(min(dp[u][k+1])\)这个是其中一部分答案。
- 还有一部分答案就是,两棵子树合并起来得到的答案。
- 也就是\(min(dp[u][i],dp[v][k+1-i])\)
- 时间复杂度 \(\mathcal O(nk)\)
#include <bits/stdc++.h>
#define int long long
#define FI first
#define SE second
#define REP(i, s, t) for (int i = s; i <= t; i++)
#define PER(i, s, t) for (int i = s; i >= t; i--)
#define pb push_back
#define mp make_pair
#define inf 0x3f3f3f3f
#define INF 0x3f3f3f3f3f3f3f3f
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
template <class T> void chkmax(T& x, T y) { x = max(x, y); }
template <class T> void chkmin(T& x, T y) { x = min(x, y); }
char gc() {
static char buf[1 << 25], *p1 = buf, *p2 = buf;
return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 25, stdin), p1 == p2) ? EOF : *p1++;
}
template <class T>
void re(T& x) {
x = 0; char ch = 0; int f = 1;
for (; !isdigit(ch); ch = getchar()) if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
x *= f;
}
template <class T>
void pr(T x) {
if (!x) { putchar('0'); return; }
if (x < 0) x = -x, putchar('-');
static int stk[25], top; top = 0;
while (x) { stk[++top] = x % 10; x /= 10; }
while (top) putchar(stk[top--] + 48);
}
const int N = 1e4 + 5;
vector<int> G[N];
int n, k, ans;
int siz[N], g[N], h[N];
int dp[N][105];
void dfs1(int u, int fff) {
siz[u] = 1, g[u] = 0;
for (auto v : G[u]) {
if (v == fff)
continue;
dfs1(v, u);
siz[u] += siz[v];
g[u] += siz[v] + g[v];
}
}
void dfs2(int u, int fff) {
dp[u][1] = g[u];
if (fff)
h[u] = h[fff] + g[fff] - g[u] - siz[u] + n - siz[u];
for (auto v : G[u]) {
if (v == fff)
continue;
dfs2(v, u);
for (int i = 0; i <= k + 1; i++)
chkmin(ans, dp[u][k + 1 - i] + dp[v][i] + h[u] - g[v] - siz[v]);
for (int i = 1; i <= k + 1; i++)
chkmin(dp[u][i], dp[v][i - 1] + g[u] - g[v] - siz[v]);
}
chkmin(ans, dp[u][k + 1] + h[u]);
}
void init() {
ans = inf;
for (int i = 1; i <= n; i++)
G[i].clear(), siz[i] = 0, g[i] = 0, h[i] = 0;
memset(dp, 0x3f, sizeof dp);
}
signed main() {
while (1) {
re(n), re(k);
if (n == 0 && k == 0)
return 0;
init();
for (int i = 1; i < n; i++) {
int u, v; re(u), re(v);
u++, v++;
G[u].pb(v);
G[v].pb(u);
}
dfs1(1, 0);
dfs2(1, 0);
pr(ans), puts("");
}
return 0;
}
标签:amp math 合并 inline 代码 公式 put 一个 get
原文地址:https://www.cnblogs.com/chhokmah/p/11838471.html