标签:
题意:求树上距离小于等于K的点对有多少个。
思路:这道题很容易想到树分治,对于当前的根节点来说,任意两个结点之间要么过根结点,要么在一棵子树中。
那么我们dfs一次求出所有点到根结点的距离,然后用o(n)的时间判定有多少节点对符合,(判断方法稍后说)
但是这样有很多在一颗子树中的节点对我们会求重复,我们需要减去在一棵子树中结点对小于等于k的数量,也就是说,我们这一步求的是在不同子树中距离小于等于k的节点对的个数。
接下来说判定方法,将每个点到根结点的距离排序,用两个指针指向队首和队尾,当结点距离和大于k时,队尾指针减一
否则更新答案并将队首指针加一。
这道题还有一个问题样例相当好,因为假如树退化成一条链时,如果我们以任意结点为根那么递归深度将达到O(n),将会tle
所以每次树分治时我们都要找到当前树的重心,并以重心为根继续分治。
#include<cstdio> #include<cstring> #include<cmath> #include<cstdlib> #include<iostream> #include<algorithm> #include<vector> #include<map> #include<queue> #include<stack> #include<string> #include<map> #include<set> #define eps 1e-6 #define LL long long #define pii (pair<int, int>) //#pragma comment(linker, "/STACK:1024000000,1024000000") using namespace std; const int maxn = 10000 + 100; const int INF = 0x3f3f3f3f; int n, root, k, cnt; //root记录重心 int des[maxn], bal[maxn], vis[maxn];//des记录以该节点为祖先的后代数,bal记录以该节点为根的最大子树的节点数 vector<int> G[maxn], L[maxn], dist; void get_root(int cur, int fa) { des[cur] = 1; bal[cur] = 0; int sz = G[cur].size(); for(int i = 0; i < sz; i++) { int u = G[cur][i]; if(vis[u] || u == fa) continue; get_root(u, cur); des[cur] += des[u]; bal[cur] = max(bal[cur], des[u]); } bal[cur] = max(bal[cur], cnt-des[cur]); if(bal[cur] < bal[root]) root = cur; } void dfs(int cur, int init, int fa) { cnt++; dist.push_back(init); int sz = G[cur].size(); for(int i = 0; i < sz; i++) { int u = G[cur][i]; if(vis[u] || u==fa) continue; dfs(u, init+L[cur][i], cur); } } int cal(int cur, int init) { int ans = 0; dist.clear(); dfs(cur, init, -1); sort(dist.begin(), dist.end()); int l = 0, r = dist.size()-1; while(l<r) { if(dist[l]+dist[r] <= k) ans += r-l, l++; else r--; } return ans; } int work(int cur) { int ans = 0; vis[cur] = 1; ans += cal(cur, 0); for(int i = 0; i < G[cur].size(); i++) { int u = G[cur][i]; if(vis[u]) continue; cnt = 0; ans -= cal(u, L[cur][i]); root = 0; get_root(u, -1); ans += work(root); } return ans; } void init() { for(int i = 1; i <= n; i++) { G[i].clear(); L[i].clear(); } memset(vis, 0, sizeof(vis)); } int main() { // freopen("input.txt", "r", stdin); while(scanf("%d%d", &n, &k) == 2 && n) { init(); for(int i = 1; i < n; i++) { int u, v, l; scanf("%d%d%d", &u, &v, &l); G[u].push_back(v); L[u].push_back(l); G[v].push_back(u); L[v].push_back(l); } cnt = n; root = 0; bal[0] = INF; get_root(1, -1); //cout << root << endl; int ans = work(root); printf("%d\n", ans); } return 0; }
版权声明:本文为博主原创文章,未经博主允许不得转载。
标签:
原文地址:http://blog.csdn.net/u014664226/article/details/47430867