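/*
 * The answer splits into two disjoint cases:
 *   (1) b is a proper ancestor of a;
 *   (2) b lies between a and c (b is a proper descendant of a and an ancestor of c).
 * Case 1 contributes (number of ancestors of a within distance k of a) multiplied
 * by (size of a's subtree minus one), i.e. min(dep[a] - 1, k) * (sz[a] - 1).
 * Case 2 contributes the sum of (subtree size of b minus one) over every vertex b
 * in a's subtree, other than a itself, whose distance to a is at most k.
 *
 * 根据题意,答案分为两部分:b 为 a 的祖先,以及 b 在 a 与 c 之间。
 * 前一部分答案等于(到 a 距离不大于 k 的祖先数)×(a 的子树大小减一);
 * 后一部分等于 a 的子树中除 a 以外、到 a 距离不大于 k 的所有点 b 的(子树大小减一)之和。
 */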
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <iostream>
// Midpoint of the current segment-tree interval [L, R]. It expands using the
// caller's local L and R, so it is only meaningful inside build()/query().
#define mid ((L + R) >> 1)
// Iterate over the adjacency list of vertex x. i is the edge index into e[];
// index 0 terminates the list (edge indices start at 1).
#define travel(x, i) for (int i = fir[x]; i; i = e[i].nxt)
using namespace std;
typedef long long LL;
// Capacity of the vertex-indexed arrays (presumably n <= 3e5 per the problem
// statement -- NOTE(review): confirm against the actual constraints).
const int N = 3e5 + 5;
// One directed half of an undirected tree edge, stored as a linked
// adjacency list:
//   nxt -- index of the next edge leaving the same vertex (0 ends the list)
//   to  -- vertex this edge leads to
struct edge {
    int nxt;
    int to;
} e[N << 1];  // every undirected edge is stored twice (once per direction)
// fir[v]: index of the first edge in v's adjacency list (0 = no edges).
// cnt: number of directed edges stored so far (edges occupy e[1..cnt]).
int fir[N], cnt = 0;
// Filled by dfs(): sz[v] subtree size, dep[v] depth (root has depth 1),
// le[v] DFS entry timestamp, ri[v] largest timestamp inside v's subtree.
// clo is the running timestamp counter; D is the maximum depth seen.
int sz[N], dep[N], le[N], ri[N], clo, D = 0;
// Push the directed edge x -> y onto the front of x's adjacency list.
inline void addedge(int x, int y) {
    ++cnt;
    e[cnt].nxt = fir[x];   // link to the previous head
    e[cnt].to = y;
    fir[x] = cnt;          // new edge becomes the head
}
struct node {
LL key;
node *lc, *rc;
node() {key = 0; lc = rc = NULL;}
} *rt[N];
// Persistently add `val` at position `pos` of the segment tree covering
// [L, R], starting from version `p` (NULL = empty tree).
// Only the nodes on the root-to-leaf path are freshly allocated; the untouched
// half at every level is shared with `p`.
// Returns the root of the new version; `p` itself is not modified.
inline node* build(node *p, int L, int R, int pos, int val) {
    node *fresh = new node;            // key = 0, children NULL
    node *oldL = NULL, *oldR = NULL;
    if (p != NULL) {
        fresh->key = p->key;
        oldL = p->lc;
        oldR = p->rc;
    }
    fresh->key += val;
    if (L == R) return fresh;          // leaves never get children
    if (pos <= mid) {
        fresh->rc = oldR;              // right half unchanged: share it
        fresh->lc = build(oldL, L, mid, pos, val);
    } else {
        fresh->lc = oldL;              // left half unchanged: share it
        fresh->rc = build(oldR, mid + 1, R, pos, val);
    }
    return fresh;
}
inline void dfs(int x, int pa) {
le[x] = ++ clo;
dep[x] = dep[pa] + 1;
D = max(D, dep[x]);
sz[x] = 1;
travel(x, i)
if (e[i].to != pa) {
dfs(e[i].to, x);
sz[x] += sz[e[i].to];
}
ri[x] = clo;
}
inline void dfs2(int x, int pa) {
rt[le[x]] = build(rt[le[x] - 1], 1, D, dep[x], sz[x] - 1);
travel(x, i)
if (e[i].to != pa) dfs2(e[i].to, x);
}
// Sum of the values at positions [l, r] in the segment tree rooted at p,
// which covers [L, R]. A NULL subtree contributes 0.
// Callers pass ranges with L <= l and r <= R.
inline LL query(node *p, int L, int R, int l, int r) {
    if (p == NULL) return 0;               // empty subtree
    if (L == l && R == r) return p->key;   // interval matches exactly
    if (r <= mid) return query(p->lc, L, mid, l, r);
    if (l > mid) return query(p->rc, mid + 1, R, l, r);
    return query(p->lc, L, mid, l, mid) + query(p->rc, mid + 1, R, mid + 1, r);
}
// Read the next non-negative decimal integer from stdin into x.
// Any non-digit characters before it (spaces, newlines, even a '-') are
// skipped, so signs are not supported. The character right after the number
// is consumed.
// If input ends before a digit is found, x is set to 0 instead of looping
// forever.
inline void read(int &x) {
    int ch = getchar();   // int, not char: must hold EOF; isdigit needs a non-negative value
    while (ch != EOF && !isdigit(ch)) ch = getchar();
    x = 0;
    while (ch != EOF && isdigit(ch)) {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
}
// Print x in decimal to stdout (no trailing separator).
// The digits are produced from the unsigned magnitude into a small buffer, so
// every long long prints correctly: 0, negatives (with a leading '-') and
// LLONG_MIN/LLONG_MAX. The old power-of-ten approach overflowed for x >= 1e18
// and printed garbage for negatives.
// `long long` is the same type as the file's LL typedef.
inline void write(long long x) {
    char buf[20];   // 2^64 - 1 has 20 decimal digits
    int len = 0;
    unsigned long long u = static_cast<unsigned long long>(x);
    if (x < 0) {
        putchar('-');
        u = 0ULL - u;   // well-defined modular negation, also for LLONG_MIN
    }
    do {
        buf[len++] = static_cast<char>('0' + u % 10);
        u /= 10;
    } while (u != 0);
    while (len > 0) putchar(buf[--len]);
}
// Reads n, q and the n-1 tree edges, roots the tree at vertex 1 and builds the
// persistent depth-indexed trees. Then, for each query (a, k), prints the
// count described in the header comment.
int main() {
    int n, q;
    read(n);
    read(q);
    for (int i = 1; i < n; ++i) {
        int x, y;
        read(x);
        read(y);
        addedge(x, y);
        addedge(y, x);
    }
    dfs(1, 0);    // root at vertex 1 (depth 1)
    dfs2(1, 0);
    while (q--) {
        int a, k;
        read(a);
        read(k);
        // Case 1: b is a proper ancestor of a within distance k, giving
        // min(dep[a] - 1, k) choices, each paired with any of a's sz[a] - 1
        // proper descendants.
        LL ans = 1LL * min(dep[a] - 1, k) * (sz[a] - 1);
        // Case 2: b is a proper descendant of a with depth in
        // (dep[a], dep[a] + k]. Each such b adds sz[b] - 1.
        // rt[ri[a]] - rt[le[a]] holds exactly a's proper descendants.
        if (le[a] != ri[a]) {
            const int lo = dep[a] + 1;
            // Equals min(dep[a] + k, D) but cannot overflow for large k.
            const int hi = dep[a] + min(k, D - dep[a]);
            ans += query(rt[ri[a]], 1, D, lo, hi) - query(rt[le[a]], 1, D, lo, hi);
        }
        write(ans);
        putchar('\n');
    }
    return 0;
}