标签:tmp add lld efi space stdout swa bsp dde
Code:
#include <cstdio>
#include <algorithm>
#include <vector>
#include <cmath>
#include <map>
// Capacity of every node-, colour- and query-indexed array (edge arrays use 2N).
const int N = 100003;
typedef long long ll;
// Local-debug helper: redirect stdin/stdout to <s>.in / <s>.out.
#define setIO(s) freopen(s".in", "r" , stdin) , freopen(s".out", "w" , stdout)
using namespace std;
namespace IO
{
// Buffered stdin reader: nc() hands out one character at a time from `buf`,
// refilling it with fread 100000 bytes at a time, and yields EOF once fread
// returns nothing.
char *p1,*p2,buf[100000];
#define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)

// Reads the next non-negative decimal integer of type T.
// Every non-digit character is skipped (so a leading '-' is ignored, exactly
// as before). Returns 0 if input ends before any digit appears. The EOF check
// is required: nc() returns EOF forever at end of input, which previously
// made the skip loop spin indefinitely.
template <class T>
T readnum()
{
    T x = 0;
    int c = nc();
    while (c != EOF && (c < '0' || c > '9')) c = nc();
    while (c >= '0' && c <= '9') x = x * 10 + (c - '0'), c = nc();
    return x;
}

// Next non-negative int from stdin (0 at end of input).
int readint() { return readnum<int>(); }
// Next non-negative long long from stdin (0 at end of input).
long long readll() { return readnum<long long>(); }
}
// ---- Global state (zero-initialised; node ids are 1-based) ----
vector <int> G[N];          // children lists of the current query's virtual tree
vector <int> ty[N];         // ty[c]: all nodes whose colour is c
vector <int> node;          // declared but not used in the visible code
int n, edges, tim, toop;    // node count, edge counter, DFS clock, stack top of S
ll dis[N];                  // distance sum from a node to the target colour's nodes
ll depth[N];                // weighted distance from the root (node 1)
int col[N];                 // colour of each node
int tax[N];                 // tax[c]: number of nodes of colour c
int id[N];                  // index assigned to each heavy colour in main
int A[N];                   // scratch list of query nodes (virtual-tree build)
// Per-node count of target-colour nodes in a subtree.
// NOTE: under C++17 with `using namespace std;`, an unqualified `size` is
// ambiguous with std::size; refer to this array as `::size`.
int size[N];
int S[N];                   // virtual-tree construction stack (1-based, top = toop)
// Adjacency list: hd[u] heads a chain linked through nex[], endpoint in to[].
int hd[N], nex[N << 1], to[N << 1];
// Heavy-light decomposition data filled by dfs1/dfs2.
int top[N], dfn[N], fa[N], dep[N], son[N], siz[N];
ll val[N << 1];             // edge weights, parallel to to[]
// Orders two nodes by their DFS preorder index (assigned in dfs1).
bool cmp(int a, int b)
{
    const int ta = dfn[a];
    const int tb = dfn[b];
    return ta < tb;
}
// Appends the directed edge u -> v with weight c to the adjacency list.
inline void addedge(int u, int v, int c)
{
    ++edges;
    to[edges] = v;
    val[edges] = static_cast<long long>(c);
    nex[edges] = hd[u];
    hd[u] = edges;
}
// First heavy-light pass over the subtree of u (whose parent is ff).
// Records parent, depth, preorder index, subtree size, heavy child and
// weighted root distance for every visited node.
void dfs1(int u, int ff)
{
    fa[u] = ff;
    dep[u] = dep[ff] + 1;
    dfn[u] = ++tim;
    siz[u] = 1;
    for(int e = hd[u]; e != 0; e = nex[e])
    {
        const int child = to[e];
        if(child == ff) continue;
        depth[child] = depth[u] + val[e];
        dfs1(child, u);
        siz[u] += siz[child];
        if(siz[child] > siz[son[u]]) son[u] = child;
    }
}
// Second heavy-light pass: labels every node with the head of its chain.
// The heavy child continues the current chain; light children start new ones.
void dfs2(int u, int tp)
{
    top[u] = tp;
    // dfs1 sets son[u] for any node with a child, so 0 means u is a leaf.
    if(!son[u]) return;
    dfs2(son[u], tp);
    for(int e = hd[u]; e; e = nex[e])
    {
        const int child = to[e];
        if(child != fa[u] && child != son[u]) dfs2(child, child);
    }
}
// Lowest common ancestor via heavy-light chains: repeatedly lift whichever
// endpoint sits on the chain with the deeper head, until both share a chain.
inline int LCA(int x, int y)
{
    while(top[x] != top[y])
    {
        if(dep[top[x]] > dep[top[y]])
            x = fa[top[x]];
        else
            y = fa[top[y]];
    }
    return dep[x] < dep[y] ? x : y;
}
// Weighted tree distance between x and y.
inline ll Dis(int x, int y)
{
    const int z = LCA(x, y);
    return depth[x] + depth[y] - 2 * depth[z];
}
void solve1(int u, int ff, int cur)
{
size[u] = (col[u] == cur), dis[u] = 0;
for(int i = hd[u] ; i ; i = nex[i])
{
int v = to[i];
if(v == ff) continue;
solve1(v, u, cur),size[u] += size[v], dis[u] += (dis[v] + 1ll * size[v] * val[i]);
}
}
void solve(int u, int ff, int cur)
{
for(int i = hd[u] ; i ; i = nex[i])
{
int v = to[i];
if(v == ff) continue;
dis[v] += (dis[u] - dis[v] - 1ll * size[v] * val[i] + 1ll * (tax[cur] - size[v]) * val[i]);
solve(v, u, cur);
}
}
// Attaches v as a child of u in the current query's virtual tree.
inline void addvir(int u, int v)
{
    G[u].emplace_back(v);
}
// Pushes node x (supplied in increasing dfn order) onto the virtual-tree stack S.
// Finished nodes are popped off and linked to their virtual parent with addvir.
inline void insert(int x)
{
    // With at most one node (the root) on the stack, x can go straight on.
    if(toop <= 1)
    {
        S[++toop] = x;
        return;
    }
    const int anc = LCA(x, S[toop]);
    // x extends the current root-to-top path: nothing to pop.
    if(anc == S[toop])
    {
        S[++toop] = x;
        return;
    }
    // Pop every node that is at least as deep as anc, linking it to the node below it.
    for(; toop > 1 && dep[S[toop - 1]] >= dep[anc]; --toop)
        addvir(S[toop - 1], S[toop]);
    // anc is not on the stack yet: it adopts the remaining top and replaces it.
    if(S[toop] != anc)
    {
        addvir(anc, S[toop]);
        S[toop] = anc;
    }
    S[++toop] = x;
}
void pre(int u, int ff, int cur)
{
size[u] = (col[u] == cur), dis[u] = 0;
for(int i = 0; i < G[u].size(); ++ i)
{
int v = G[u][i];
pre(v, u, cur), size[u] += size[v], dis[u] += dis[v] + 1ll * size[v] * Dis(v, u);
}
}
void work(int u, int ff, int cur)
{
for(int i = 0; i < G[u].size() ; ++ i)
{
int v = G[u][i];
dis[v] += (dis[u] - dis[v] - 1ll * size[v] * Dis(u, v) + 1ll * (tax[cur] - size[v]) * Dis(u, v));
work(v, u, cur);
}
}
void clear(int u)
{
size[u] = dis[u] = 0;
for(int i = 0; i < G[u].size(); ++ i) clear(G[u][i]) ;
G[u].clear();
}
// One query: the colour pair (a, b), reordered in main so that tax[a] >= tax[b].
struct Node
{
    int a;
    int b;
};
Node ask[N];
vector < int > P[N];        // P[c]: partner colours queried against heavy colour c, in query order
vector < ll > answer[N];    // answer[c][k]: result for the pair (c, P[c][k])
int point[N];               // next unread position in answer[c]
int main()
{
using namespace IO;
// setIO("input");
int i , j, idx = 0, m, Q;
n = readint();
m = sqrt(n);
for(i = 1; i <= n ; ++ i) col[i] = readint(), ++tax[col[i]], ty[col[i]].push_back(i);
for(i = 1; i < n ; ++ i)
{
int a = readint(), b = readint(), c = readint();
addedge(a, b, c), addedge(b, a, c);
}
dfs1(1, 0), dfs2(1, 1);
for(i = 1; i <= n ; ++ i) if(tax[i] >= m) id[i] = ++idx;
Q = readint();
for(i = 1; i <= Q; ++ i)
{
ask[i].a = readint(), ask[i].b = readint();
if(tax[ask[i].a] < tax[ask[i].b]) swap(ask[i].a, ask[i].b);
if(tax[ask[i].a] >= m) P[ask[i].a].push_back(ask[i].b);
}
for(i = 1; i <= n ; ++ i)
{
if(tax[i] >= m)
{
solve1(1, 0, i), solve(1, 0, i);
for(j = 0 ; j < P[i].size() ; ++ j)
{
int cur = P[i][j];
ll re = 0;
for(int k = 0; k < ty[cur].size(); ++ k)
{
re += dis[ty[cur][k]];
}
answer[i].push_back(re);
}
}
}
for(int cas = 1; cas <= Q; ++ cas)
{
int a, b;
a = ask[cas].a, b = ask[cas].b;
if(tax[a] >= m) printf("%lld\n", a == b ? answer[a][point[a] ++ ] / 2 : answer[a][point[a] ++ ]);
else
{
int tmp = 0;
ll re = 0;
for(i = 0; i < ty[a].size(); ++ i) A[++ tmp] = ty[a][i];
for(i = 0; i < ty[b].size(); ++ i) A[++ tmp] = ty[b][i];
sort(A + 1, A + 1 + tmp, cmp);
tmp = unique(A + 1, A + 1 + tmp) - (A + 1);
toop = 0;
if(A[1] != 1) S[++ toop] = 1;
for(i = 1 ; i <= tmp ; ++ i) insert(A[i]);
while(toop > 1) addvir(S[toop - 1], S[toop]), --toop;
pre(1, 0, b), work(1, 0, b);
for(i = 0; i < ty[a].size(); ++ i) re += dis[ty[a][i]];
printf("%lld\n", a == b ? re / 2 : re);
}
}
return 0;
}
Nowcoder 79F 小H和圣诞树 (Little H and the Christmas Tree) — rerooting DP + square-root decomposition
标签:tmp add lld efi space stdout swa bsp dde
原文地址:https://www.cnblogs.com/guangheli/p/11367477.html