Educational Codeforces Round 36 (Rated for Div. 2)
F. Imbalance Value of a Tree
You are given a tree T consisting of n vertices. A number is written on each vertex; the number written on vertex i is $a_i$. Let's denote the function $I(x, y)$ as the difference between the maximum and minimum value of $a_i$ on the simple path connecting vertices x and y.
Your task is to calculate $\sum_{i=1}^{n} \sum_{j=i}^{n} I(i, j)$.
答案 = 所有路径的最大值之和 − 所有路径的最小值之和:分别算出这两个和,相减即可(x = y 的路径贡献为 0,可以忽略)。
如何计算所有路径的最大值之和?
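把点按点权升序排序,从小到大依次处理,并用并查集维护已处理的点构成的连通块,这样可以保证当前处理的点 u 的点权不小于所有已处理点的点权。处理 u 时,依次把它与已处理的相邻点所在的连通块合并:若合并前 u 所在连通块大小为 s1、对方连通块大小为 s2,则新增 s1·s2 条路径,这些路径都经过 u,最大值都是 a_u,所以贡献为 s1·s2·a_u。每条路径只在其两个端点第一次连通时被计算一次,从而避免了重复计算。最小值之和同理,按点权降序处理即可。具体细节见代码。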
// Educational Codeforces Round 36 (Div. 2), problem F "Imbalance Value of a Tree".
//
// Answer = sum over all vertex pairs x <= y of (max - min) of a_i on the x-y path
//        = (sum of path maxima) - (sum of path minima); pairs with x == y add 0.
// Each sum is computed offline with a disjoint-set union: vertices are activated
// in increasing (for maxima) or decreasing (for minima) order of value. When the
// current vertex u joins two components of sizes s1 and s2, exactly s1 * s2 new
// paths appear; all of them pass through u, and their extreme value is a_u.
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <numeric>
#include <utility>
#include <vector>

namespace {

// Returns the next byte of stdin (read in 128 KiB chunks) as an unsigned-char
// value, or EOF once the input is exhausted.
int next_char() {
    static char buf[1 << 17];
    static const char *cur = buf;
    static const char *end = buf;
    if (cur == end) {
        cur = buf;
        end = buf + std::fread(buf, 1, sizeof buf, stdin);
        if (cur == end) return EOF;
    }
    return static_cast<unsigned char>(*cur++);
}

// Reads the next non-negative decimal integer into x, skipping non-digit
// separators. If the input ends before a digit is found, x is left as 0
// (the original reader spun forever in that case).
void read_int(int &x) {
    x = 0;
    int c = next_char();
    while (c != EOF && !std::isdigit(c)) c = next_char();
    while (c != EOF && std::isdigit(c)) {
        x = x * 10 + (c - '0');
        c = next_char();
    }
}

// Disjoint-set union over vertices 0..n with union by size and path halving.
struct Dsu {
    std::vector<int> parent;
    std::vector<long long> size;

    explicit Dsu(int n) : parent(n + 1), size(n + 1, 1) {
        std::iota(parent.begin(), parent.end(), 0);
    }

    // Returns the root of x's set, halving the path as it walks up.
    int find(int x) {
        while (parent[x] != x) {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        return x;
    }

    // Merges two distinct roots; the root of the larger set survives.
    void unite_roots(int ra, int rb) {
        if (size[ra] < size[rb]) std::swap(ra, rb);
        parent[rb] = ra;
        size[ra] += size[rb];
    }
};

// Returns the sum, over all unordered pairs of distinct vertices, of the largest
// (want_max == true) or smallest (want_max == false) value w[i] on their path.
// w is 1-indexed (w[0] unused); adj[v] lists the tree neighbours of v.
long long sum_path_extremes(const std::vector<int> &w,
                            const std::vector<std::vector<int>> &adj,
                            bool want_max) {
    const int n = static_cast<int>(w.size()) - 1;
    std::vector<int> order(n);
    std::iota(order.begin(), order.end(), 1);
    std::sort(order.begin(), order.end(), [&](int a, int b) {
        return want_max ? w[a] < w[b] : w[a] > w[b];
    });

    // Only already-processed vertices are ever merged, so every vertex in a
    // component is on the "dominated" side of w[u]. Each new path created while
    // processing u goes through u, so its extreme value is exactly w[u].
    std::vector<char> processed(n + 1, 0);
    Dsu dsu(n);
    long long total = 0;
    for (const int u : order) {
        processed[u] = 1;
        for (const int v : adj[u]) {
            if (!processed[v]) continue;
            const int ru = dsu.find(u);
            const int rv = dsu.find(v);
            if (ru == rv) continue;  // cannot happen in a tree; guards malformed input
            // Sizes are long long, so the product cannot overflow int
            // (n, a_i <= 1e6 keeps the total below ~5e17).
            total += dsu.size[ru] * dsu.size[rv] * w[u];
            dsu.unite_roots(ru, rv);
        }
    }
    return total;
}

}  // namespace

int main() {
    int n = 0;
    read_int(n);
    std::vector<int> w(n + 1);
    for (int i = 1; i <= n; ++i) read_int(w[i]);

    std::vector<std::vector<int>> adj(n + 1);
    for (int i = 1; i < n; ++i) {
        int u = 0;
        int v = 0;
        read_int(u);
        read_int(v);
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    const long long answer =
        sum_path_extremes(w, adj, true) - sum_path_extremes(w, adj, false);
    std::printf("%lld\n", answer);
    return 0;
}