标签:clu 答案 des printf void 最坏情况 == str 滑动
给定 \(n\) 个点的无根树。\(m\) 种颜色,每种颜色权值为 \(c_i\)。
定义树上路径权值为路径颜色序列,将其分为每一段极大的相同颜色序列,每一段颜色设为 \(i\),权值即 \(\sum c_i\)。
求边数在 \([l, r]\) 范围的简单路径中路径权值最大值。
一般来说边数在 \([l, r]\) 的一些树上信息很容易想到就是点分治了。
设 \(c(x)\) 从根到 \(x\) 的路径上第一条边的颜色,\(w(x)\) 为从根到 \(x\) 这条路径权值, \(d(x)\) 为根到 \(x\) 经过的边数。
考虑合并两个到根的链的点 \(x, y\):
首先先来解决第一种情况 \(c(x) \not= c(y)\):
显然两个点 \(x, y\),若 \(d(x) = d(y)\) 且 \(w(x) > w(y)\),那么 \(y\) 就没用了。因为把 \(y\) 换成 \(x\) 会更好。所以不妨存个桶,\(b_i\) 表示经过边数为 \(i\) 的点的最大 \(w\)。
那么对于 \(d(x) = i\) 来说,他寻求拼合的另外一条链的边数在 \([l - i, r - i]\),即我们要求这个区间的 \(\max(b_i)\)。发现当 \(i\) 从小到大循环的过程中,这个区间左右端点都是递减的,即一个滑动窗口问题。
然后考虑加入第二种情况:这个东西解决起来不难想,就是把根所连接的所有子树联通快按他们之间连的那条边排一下序,这样同一个颜色就肯定在一个区间了。所以不同颜色求一遍答案,相同颜色求一遍答案再整体减去重复的颜色即可。注意当遍历到新的颜色的时候,要把之前的相同颜色合并到不同颜色里。
滑动窗口这个东西可以用线段树 \(/\) 单调队列来做,但是你发现第一次查询是 \(O(r - l)\) 复杂度的,最坏情况下暴力扫可能会被卡成 \(O(n^2)\) 的,所以需要一种特殊技巧(传说叫单调队列按秩合并)。
由于这是一道码农题,所以我锻炼一下写一下两种做法。
维护两个权值线段树,一个为异色联通块,一个为同色联通块。
对于每个联通块,先查询。查询完整体塞到同色联通块里。
碰到新的颜色,即把同色的合并到异色里,线段树合并即可。
\(O(n \log ^ 2 n)\)
#include <iostream>
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 200005, INF = 2e9;
int n, m, L, R, c[N], ans = -INF;
int maxPart, rt, now, S, sz[N], d[N], maxDep;
int head[N], numE = 0, tot;
bool vis[N];
struct E{
int next, v, w;
} e[N << 1];
struct Son{
int x, c;
bool operator < (const Son &b) const {
return c < b.c;
}
} sons[N];
struct T{
int l, r, dat;
};
struct SegTree{
int rt0, rt1, dat[N], idx;
T t[N * 20];
void change(int &p, int l, int r, int x, int c) {
if (!p) t[p = ++idx] = (T) { 0, 0, -INF };
t[p].dat = max(t[p].dat, c);
if (l == r) return;
int mid = (l + r) >> 1;
if (x <= mid) change(t[p].l, l, mid, x, c);
else change(t[p].r, mid + 1, r, x, c);
}
int query(int p, int l, int r, int x, int y) {
if (x > y) return -INF;
if (!p) return -INF;
if (x <= l && r <= y) return t[p].dat;
int mid = (l + r) >> 1, res = -INF;
if (x <= mid) res = max(res, query(t[p].l, l, mid, x, y));
if (mid < y) res = max(res, query(t[p].r, mid + 1, r, x, y));
return res;
}
// 把 q 合并到 p 上
void merge(int &p, int &q, int l, int r) {
if (!p) { p = q; return; }
if (!q) return;
t[p].dat = max(t[p].dat, t[q].dat);
if (l == r) return;
int mid = (l + r) >> 1;
merge(t[p].l, t[q].l, l, mid);
merge(t[p].r, t[q].r, mid + 1, r);
}
} t;
void inline add(int u, int v, int w) {
e[++numE] = (E) { head[u], v, w };
head[u] = numE;
}
void getRoot(int u, int last) {
sz[u] = 1;
int s = 0;
for (int i = head[u]; i; i = e[i].next) {
int v = e[i].v;
if (vis[v] || v == last) continue;
getRoot(v, u);
sz[u] += sz[v];
s = max(s, sz[v]);
}
s = max(s, S - sz[u]);
if (s < maxPart) maxPart = s, rt = u;
}
void dfs(int u, int last, int col, int w, int dep) {
maxDep = max(maxDep, dep);
if (L <= dep && dep <= R) ans = max(ans, w);
ans = max(ans, w + t.query(t.rt0, 1, n, max(L - dep, 1), min(R - dep, n)));
ans = max(ans, w + t.query(t.rt1, 1, n, max(L - dep, 1), min(R - dep, n)) - now);
d[dep] = max(d[dep], w);
for (int i = head[u]; i; i = e[i].next) {
int v = e[i].v;
if (v == last || vis[v]) continue;
dfs(v, u, e[i].w, w + (col == e[i].w ? 0 : c[e[i].w]), dep + 1);
}
}
void solve(int x) {
if (S == 1) return;
maxPart = 2e9, getRoot(x, 0), vis[rt] = true;
tot = 0; t.idx = t.rt0 = t.rt1 = 0;
for (int i = head[rt]; i; i = e[i].next) {
int v = e[i].v;
if (vis[v]) continue;
sons[++tot] = (Son) { v, e[i].w };
}
sort(sons + 1, sons + 1 + tot);
for (int i = 1; i <= tot; i++) {
maxDep = 0, now = c[sons[i].c];
dfs(sons[i].x, rt, sons[i].c, c[sons[i].c], 1);
for (int j = 1; j <= maxDep; j++) {
t.change(t.rt1, 1, n, j, d[j]);
d[j] = -INF;
}
if (i < tot && sons[i].c != sons[i + 1].c) {
t.merge(t.rt0, t.rt1, 1, n);
t.rt1 = 0;
}
}
for (int i = head[rt]; i; i = e[i].next) {
int v = e[i].v;
if (vis[v]) continue;
S = sz[v], solve(v);
}
}
int main() {
scanf("%d%d%d%d", &n, &m, &L, &R);
for (int i = 1; i <= n; i++) d[i] = -2e9;
for (int i = 1; i <= m; i++) scanf("%d", c + i);
for (int i = 1, u, v, w; i < n; i++) {
scanf("%d%d%d", &u, &v, &w);
add(u, v, w); add(v, u, w);
}
S = n;
solve(1);
printf("%d\n", ans);
return 0;
}
这个按秩合并非常神奇。具体的排序顺序是这样的:
然后查询的时候,分类讨论:
由于有 \(sort\),还是 \(O(n \log^2 n)\)
#include <iostream>
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 200005, INF = 2e9;
int n, m, L, R, c[N], ans = -INF;
int maxPart, rt, now, S, sz[N], d[N], maxDep, mxDep[N];
int head[N], numE = 0, tot, val[N], nowVal[N], zDep, q[N];
bool vis[N];
struct E {
int next, v, w;
} e[N << 1];
struct Son {
int x, c, d;
bool operator<(const Son &b) const {
if (c != b.c)
return mxDep[c] < mxDep[b.c];
else
return d < b.d;
}
} sons[N];
struct T {
int l, r, dat;
};
void inline add(int u, int v, int w) {
e[++numE] = (E){ head[u], v, w };
head[u] = numE;
}
void getRoot(int u, int last) {
sz[u] = 1;
int s = 0;
for (int i = head[u]; i; i = e[i].next) {
int v = e[i].v;
if (vis[v] || v == last)
continue;
getRoot(v, u);
sz[u] += sz[v];
s = max(s, sz[v]);
}
s = max(s, S - sz[u]);
if (s < maxPart)
maxPart = s, rt = u;
}
void dfs(int u, int last, int col, int w, int dep) {
maxDep = max(maxDep, dep);
if (L <= dep && dep <= R)
ans = max(ans, w);
d[dep] = max(d[dep], w);
for (int i = head[u]; i; i = e[i].next) {
int v = e[i].v;
if (v == last || vis[v])
continue;
dfs(v, u, e[i].w, w + (col == e[i].w ? 0 : c[e[i].w]), dep + 1);
}
}
void dfs0(int u, int last, int dep) {
maxDep = max(maxDep, dep);
for (int i = head[u]; i; i = e[i].next) {
int v = e[i].v;
if (v == last || vis[v])
continue;
dfs0(v, u, dep + 1);
}
}
int inline work(int a[], int len1, int b[], int len2) {
int res = -INF;
int hh = 0, tt = -1;
int l = max(1, L - 1), r = R - 1;
if (l > r)
return res;
len1 = min(len1, R - 1);
for (int i = min(r, len2); i >= l; i--) {
while (hh <= tt && b[q[tt]] < b[i]) tt--;
q[++tt] = i;
}
if (hh <= tt)
res = max(res, a[1] + b[q[hh]]);
for (int i = 2; i <= len1; i++) {
if (q[hh] == r)
hh++;
r--;
if (l > 1) {
--l;
while (hh <= tt && b[q[tt]] < b[l]) tt--;
q[++tt] = l;
}
if (hh <= tt)
res = max(res, a[i] + b[q[hh]]);
}
return res;
}
void solve(int x) {
if (S == 1)
return;
maxPart = 2e9, getRoot(x, 0), vis[rt] = true;
tot = 0;
for (int i = head[rt]; i; i = e[i].next) {
int v = e[i].v;
if (vis[v])
continue;
maxDep = 0;
dfs0(v, rt, 1);
mxDep[e[i].w] = max(mxDep[e[i].w], maxDep);
sons[++tot] = (Son){ v, e[i].w, maxDep };
}
sort(sons + 1, sons + 1 + tot);
zDep = 0;
int nowDep = 0;
for (int i = 1; i <= tot; i++) {
maxDep = 0, now = c[sons[i].c];
mxDep[e[i].w] = 0;
dfs(sons[i].x, rt, sons[i].c, c[sons[i].c], 1);
zDep = max(zDep, maxDep);
nowDep = max(nowDep, maxDep);
ans = max(ans, work(d, maxDep, nowVal, nowDep) - c[sons[i].c]);
for (int j = 1; j <= maxDep; j++) {
nowVal[j] = max(nowVal[j], d[j]);
d[j] = -INF;
}
if (i == tot || sons[i].c != sons[i + 1].c) {
ans = max(ans, work(nowVal, nowDep, val, zDep));
for (int j = 1; j <= nowDep; j++) {
val[j] = max(val[j], nowVal[j]);
nowVal[j] = -INF;
}
nowDep = 0;
}
}
for (int i = 1; i <= zDep; i++) val[i] = -INF;
for (int i = head[rt]; i; i = e[i].next) {
int v = e[i].v;
if (vis[v])
continue;
S = sz[v], solve(v);
}
}
int main() {
scanf("%d%d%d%d", &n, &m, &L, &R);
for (int i = 1; i <= n; i++) d[i] = nowVal[i] = val[i] = -2e9;
for (int i = 1; i <= m; i++) scanf("%d", c + i);
for (int i = 1, u, v, w; i < n; i++) {
scanf("%d%d%d", &u, &v, &w);
add(u, v, w);
add(v, u, w);
}
S = n;
solve(1);
printf("%d\n", ans);
return 0;
}
标签:clu 答案 des printf void 最坏情况 == str 滑动
原文地址:https://www.cnblogs.com/dmoransky/p/12670423.html