树上路径统计的点分治算法解析
概述
点分治是一种针对带权树结构中简单路径统计的高效分治算法,通过分治策略将原始O(n²)复杂度优化至O(nlogn)
核心问题: 在带权树中统计所有路径长度小于等于k的路径数量
该问题对应经典例题POJ1741。传统暴力解法需对每个节点进行DFS遍历,导致O(n²)时间复杂度。如何优化?
算法核心思想:
- 选取树中关键节点t,将问题拆分为两类:
- 经过节点t的路径
- 不经过节点t的路径
对于第二类问题,通过删除节点t后形成的多个子树递归处理。针对第一类问题,通过DFS收集各节点到t的距离,利用双指针技术快速统计满足条件的路径数。
注意:需排除同一子树内节点的路径,例如当k=7时,R→X→R→Y的路径实际长度为9,不符合简单路径定义。
优化策略:
- 每次选择树的重心作为分治节点,确保每次分割后的子树规模不超过原树的一半
- 通过分治层数控制在O(logn)级别,总时间复杂度为O(nlog²n)
以下是优化后的代码实现:
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int MAXN = 20000 + 10;
struct Edge {
int to, weight;
};
vector<Edge> tree[MAXN];
int size[MAXN], depth[MAXN], dist[MAXN], result, total, center;
bool visited[MAXN];
void findCentroid(int u, int parent) {
size[u] = 1;
for (auto& e : tree[u]) {
if (e.to != parent && !visited[e.to]) {
findCentroid(e.to, u);
size[u] += size[e.to];
}
}
}
void calculateDistances(int u, int parent, int currentDist) {
dist[++total] = currentDist;
for (auto& e : tree[u]) {
if (e.to != parent && !visited[e.to]) {
calculateDistances(e.to, u, currentDist + e.weight);
}
}
}
int countValidPaths(int k) {
total = 0;
calculateDistances(center, -1, 0);
sort(dist + 1, dist + total + 1);
int left = 1, right = total, count = 0;
while (left <= right) {
while (right && dist[left] + dist[right] > k) right--;
if (left > right) break;
count += right - left + 1;
left++;
}
return count;
}
void divideAndConquer(int u) {
findCentroid(u, -1);
visited[u] = true;
result += countValidPaths(k);
for (auto& e : tree[u]) {
if (!visited[e.to]) {
result -= countValidPaths(k - e.weight);
int subSize = size[e.to];
findCentroid(e.to, -1);
divideAndConquer(center);
}
}
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
int n, m;
while (cin >> n >> m && n != 0) {
for (int i = 1; i <= n; ++i) tree[i].clear();
for (int i = 0; i < n - 1; ++i) {
int a, b, w; cin >> a >> b >> w;
tree[a].push_back({b, w});
tree[b].push_back({a, w});
}
result = 0;
findCentroid(1, -1);
divideAndConquer(center);
cout << result << '\n';
}
return 0;
}
例题实现:
树路径存在性检测
问题描述 给定n个节点的树,判断是否存在距离为k的点对
输入格式 第一行两个数n,m 接下来n-1行描述树结构 m行查询k值
优化方案 采用点分治算法,通过预处理距离数组并使用二分查找提高查询效率。注意避免重复计算同一子树内的路径。
代码实现:
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 20000 + 10;
struct Edge {
int to, weight;
};
vector<Edge> tree[MAXN];
int size[MAXN], depth[MAXN], dist[MAXN], result, total, center;
bool visited[MAXN];
int queries[MAXN], answer[MAXN];
void findCentroid(int u, int parent) {
size[u] = 1;
for (auto& e : tree[u]) {
if (e.to != parent && !visited[e.to]) {
findCentroid(e.to, u);
size[u] += size[e.to];
}
}
}
void calculateDistances(int u, int parent, int currentDist) {
dist[++total] = currentDist;
for (auto& e : tree[u]) {
if (e.to != parent && !visited[e.to]) {
calculateDistances(e.to, u, currentDist + e.weight);
}
}
}
void processQueries(int k, int& count) {
total = 0;
calculateDistances(center, -1, 0);
sort(dist + 1, dist + total + 1);
for (int i = 1; i <= total; ++i) {
int target = k - dist[i];
auto it = lower_bound(dist + i + 1, dist + total + 1, target);
if (it != dist + total + 1 && *it == target) {
count++;
}
}
}
void divideAndConquer(int u) {
findCentroid(u, -1);
visited[u] = true;
for (int i = 0; i < m; ++i) {
if (queries[i] >= 0) {
int cnt = 0;
processQueries(queries[i], cnt);
answer[i] += cnt;
}
}
for (auto& e : tree[u]) {
if (!visited[e.to]) {
int subSize = size[e.to];
findCentroid(e.to, -1);
divideAndConquer(center);
}
}
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
int n, m;
cin >> n >> m;
for (int i = 0; i < n - 1; ++i) {
int a, b, w; cin >> a >> b >> w;
tree[a].push_back({b, w});
tree[b].push_back({a, w});
}
for (int i = 0; i < m; ++i) cin >> queries[i];
findCentroid(1, -1);
divideAndConquer(center);
for (int i = 0; i < m; ++i) {
cout << (answer[i] ? "AYE" : "NAY") << '\n';
}
return 0;
}
树距离总和计算
问题描述 计算树中所有相同颜色节点对的最短路径长度总和
解决方案 采用点分治算法,维护颜色对应的距离总和和数量统计。在遍历过程中动态更新全局统计信息。
代码实现:
#include
#define int long long
using namespace std;
const int MAXN = 200000 + 10;
vector tree[MAXN];
int size[MAXN], depth[MAXN], dist[MAXN], result, total, center;
bool visited[MAXN];
int color[MAXN], sumDist[MAXN], countColor[MAXN];
void findCentroid(int u, int parent) {
size[u] = 1;
for (auto v : tree[u]) {
if (v != parent && !visited[v]) {
findCentroid(v, u);
size[u] += size[v];
}
}
}
void calculateDistances(int u, int parent, int currentDist) {
dist[++total] = currentDist;
for (auto v : tree[u]) {
if (v != parent && !visited[v]) {
calculateDistances(v, u, currentDist + 1);
}
}
}
void processNodes(int u) {
total = 0;
calculateDistances(u, -1, 0);
for (int i = 1; i <= total; ++i) {
int c = color[dist[i]];
result += sumDist[c] + depth[dist[i]] * countColor[c];
}
}
void divideAndConquer(int u) {
findCentroid(u, -1);
visited[u] = true;
processNodes(u);
for (auto v : tree[u]) {
if (!visited[v]) {
divideAndConquer(v);
}
}
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
int n;
cin >> n;
for (int i = 0; i < n - 1; ++i) {
int a, b; cin >> a >> b;
tree[a].push_back(b);
tree[b].push_back(a);
}
for (int i = 1; i <= n; ++i) cin >> color[i];
findCentroid(1, -1);
divideAndConquer(center);
cout << result << '\n';
return 0;
}