标签:查询 nim 线段树 name def names 利用 include ++
略。
首先设\(f_{x, c}\)表示以\(x\)为根的子树内,最终取到了\(c\)的概率。可以列出转移方程(假设有两个孩子\(u, v\))
\[
\begin{aligned}
f_{x, c}
= & f_{u, c} * (p * v子树中最终权值小于c的概率 + (1 - p) * v子树中最终权值大于c的概率) \+ & f_{v, c} * (p * u子树中最终权值小于c的概率 + (1 - p) * u子树中最终权值大于c的概率) \\end{aligned}
\]
然后就不会了……
考虑线段树合并(我真的不知道这东西复杂度是对的),每个节点开一个权值线段树。
合并的时候就考虑第一个加数就好了(第二个类似),而第一个加数相当于先继承过来,再乘上一些东西。
这些东西本来是可以在线段树上\(\mathcal O(\log n)\)分治查询的,但是在合并的时候只要用类似cdq的技巧,边合并,边统计即可。
具体来说,就是,现在要合并两个树上节点\(x, y\)的相同线段树区间节点,然后可以个区间划分成左右两块,节点\(x\)的叫做\(x_l\)和\(x_r\),节点\(y\)的叫做\(y_l\)和\(y_r\)。
那么,\(x_l\)对\(y_r\),\(y_r\)对\(x_l\),\(x_r\)对\(y_l\),\(y_l\)对\(x_r\)的影响都是确定的(即严格大于或小于的关系),然后累加这个影响就行了。(具体见实现)
线段树节点上只要维护前缀积的区间和即可(前缀积就相当于继承的意思),并且利用动态开点线段树就可以保证正确的复杂度。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 300005, mod = 998244353, inv = 796898467;
int n, m, p, tot, ans, v[N], a[N], sz[N], ch[N][2];
int rt[N], lc[N * 20], rc[N * 20], s[N * 20], tag[N * 20];
void add (int x, int y) {
ch[x][1] = y, swap(ch[x][0], ch[x][1]);
}
void mul (int x, int t) {
s[x] = 1ll * s[x] * t % mod;
tag[x] = 1ll * tag[x] * t % mod;
}
void pushdown (int x) {
if (tag[x] != 1) {
mul(lc[x], tag[x]);
mul(rc[x], tag[x]);
tag[x] = 1;
}
}
void insert (int &o, int l, int r, int k) {
if (!o) {
o = ++tot;
}
s[o] = tag[o] = 1;
if (l == r) {
return;
}
int mid = (l + r) >> 1;
if (k <= mid) {
insert(lc[o], l, mid, k);
} else {
insert(rc[o], mid + 1, r, k);
}
}
int merge (int x, int y, ll sumx = 0, ll sumy = 0) {
if (!x || !y) {
x ? mul(x, sumy) : mul(y, sumx);
return x | y;
}
pushdown(x), pushdown(y);
int x0 = s[lc[x]], x1 = s[rc[x]], y0 = s[lc[y]], y1 = s[rc[y]];
lc[x] = merge(lc[x], lc[y], ((1ll + mod - p) * x1 + sumx) % mod, ((1ll + mod - p) * y1 + sumy) % mod);
rc[x] = merge(rc[x], rc[y], (1ll * p * x0 + sumx) % mod, (1ll * p * y0 + sumy) % mod);
s[x] = (s[lc[x]] + s[rc[x]]) % mod;
return x;
}
int dfs (int x) {
if (!ch[x][0]) {
insert(rt[x], 1, m, lower_bound(a + 1, a + m + 1, v[x]) - a);
return rt[x];
}
int rl = dfs(ch[x][0]);
if (!ch[x][1]) {
return rl;
}
int rr = dfs(ch[x][1]);
p = v[x];
return merge(rl, rr);
}
int calc (int o, int l, int r) {
if (l == r) {
return 1ll * l * a[l] % mod * s[o] % mod * s[o] % mod;
}
pushdown(o);
int mid = (l + r) >> 1;
return (calc(lc[o], l, mid) + calc(rc[o], mid + 1, r)) % mod;
}
int main () {
scanf("%d", &n);
for (int i = 1, x; i <= n; ++i) {
scanf("%d", &x);
if (x) {
add(x, i);
}
}
for (int i = 1; i <= n; ++i) {
scanf("%d", &v[i]);
if (!ch[i][0]) {
a[++m] = v[i];
} else {
v[i] = 1ll * v[i] * inv % mod;
}
}
sort(a + 1, a + m + 1);
printf("%lld\n",calc(dfs(1), 1, m));
return 0;
}
标签:查询 nim 线段树 name def names 利用 include ++
原文地址:https://www.cnblogs.com/psimonw/p/11445087.html