点分治

原理

见《进阶指南》第230页。

模板题

AcWing 252. 树

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#define size SiZe
using namespace std;

// Capacity of every per-vertex array below.
const int N = 10010;

// One directed arc of the array-backed adjacency list.
struct edge {
    int to;    // vertex this arc points to
    int next;  // index in e[] of the next arc leaving the same vertex; -1 ends the list
    int w;     // weight of the arc
};

edge e[2 * N];  // arc pool; main stores every input edge as two opposite arcs
int idx;        // next free slot in e[]
int head[N];    // head[u]: index in e[] of the first arc leaving u, -1 if none

int n;    // vertex count of the current test case
int k;    // distance bound read from input
int res;  // running answer for the current test case

int size[N];   // subtree sizes written by find()
int min_size;  // smallest "largest remaining piece" seen by find() so far
int root;      // vertex that achieved min_size
bool del[N];   // del[v] is true once v has been removed as a split vertex

int ptr;     // number of vertices currently stored in a[1..ptr]
int a[N];    // vertices collected for the current split vertex
int b[N];    // b[v]: label of the branch (neighbour of the split vertex) containing v
int cnt[N];  // cnt[x]: how many vertices carrying label x are still inside the scan window
int d[N];    // d[v]: distance from the current split vertex to v

void add_edge (int u, int v, int w) {
    e[idx].w = w;
    e[idx].to = v;
    e[idx].next = head[u];
    head[u] = idx ++;
}
void find (int cur, int fa, int tot) {
    size[cur] = 1;
    int max_size = 0;
    for (int i = head[cur]; i != -1; i = e[i].next) {
        int to = e[i].to;
        if (to == fa || del[to]) continue;
        find (to, cur, tot);
        size[cur] += size[to];
        max_size = max(max_size, size[to]);
    }
    max_size = max(max_size, tot - size[cur]);
    if (max_size < min_size) {
        min_size = max_size;
        root = cur;
    }
}
void dfs (int cur, int fa) {
    for (int i = head[cur]; i != -1; i = e[i].next) {
        int to = e[i].to, w = e[i].w;
        if (to == fa || del[to]) continue;
        a[++ ptr] = to;
        b[to] = b[cur];
        ++ cnt[b[to]];
        d[to] = d[cur] + w;
        dfs(to, cur);
    }
}
// Strict ordering of two vertices by their recorded distance d[].
bool cmp (const int& lhs, const int& rhs) {
    const int left_dist = d[lhs];
    const int right_dist = d[rhs];
    return left_dist < right_dist;
}
// Processes the component of undeleted vertices that contains v.
//   v: any undeleted vertex of the component
//   s: number of undeleted vertices in that component
// Picks the split vertex reported by find(), adds to res the number of vertex
// pairs from different branches (or involving the split vertex itself) whose
// distances to the split vertex sum to at most k, then deletes the split vertex
// and recurses into each remaining branch.
void solve (int v, int s) {
    // Centroid of the component.
    min_size = s;
    find(v, -1, s);
    // root is a global that the recursive calls below overwrite; keep our own copy.
    const int centroid = root;

    // Gather all vertices of the component with their distance and branch label.
    // Only the entries that are read later are reset (the centroid and its live
    // neighbours); clearing the whole cnt[]/d[] arrays on every call would cost
    // O(N) per call.
    ptr = 0;
    a[++ ptr] = centroid;
    b[centroid] = centroid;
    cnt[centroid] = 1;
    d[centroid] = 0;
    for (int i = head[centroid]; i != -1; i = e[i].next) {
        int to = e[i].to, w = e[i].w;
        if (del[to]) continue;
        a[++ ptr] = to;
        b[to] = to;
        cnt[to] = 1;  // dfs() adds the rest of this branch on top
        d[to] = w;
        dfs(to, centroid);
    }

    // Two-pointer scan over the vertices sorted by distance. cnt[x] always holds
    // the number of vertices with label x inside the window a[i+1..j].
    std::sort(a + 1, a + ptr + 1, cmp);
    for (int i = 1, j = ptr; i <= ptr; ++ i) {
        -- cnt[b[a[i]]];
        while (j > i && d[a[i]] + d[a[j]] > k) {
            -- cnt[b[a[j]]];
            -- j;
        }
        if (i != j) { // here d[a[i]] + d[a[j]] <= k is guaranteed
            // every partner in the window except those sharing a[i]'s branch
            res += j - i - cnt[b[a[i]]];
        } else {
            break;
        }
    }

    del[centroid] = true; // remove the centroid
    for (int i = head[centroid]; i != -1; i = e[i].next) {
        int to = e[i].to;
        if (del[to]) continue;
        // size[] was computed by the walk that started at v. A neighbour with a
        // larger size than the centroid is the centroid's parent in that walk,
        // so its component is everything outside the centroid's subtree.
        const int part = size[to] > size[centroid] ? s - size[centroid] : size[to];
        solve(to, part);
    }
}

int main () {
    while (true) {
        // ---------- 重置 ----------
        idx = 0, res = 0;
        memset(head, -1, sizeof(head));
        memset(del, 0, sizeof(del));

        scanf("%d%d", &n, &k);
        if (n == 0 && k == 0) break;
        for (int i = 1, x, y, z; i <= n - 1; ++ i) {
            scanf("%d%d%d", &x, &y, &z);
            add_edge(x, y, z);
            add_edge(y, x, z);
        }

        solve(1, n);
        printf("%d\n", res);
    }
    return 0;
}
最后修改于: