最近公共祖先

洛谷-P3379-【模板】最近公共祖先(LCA)

树链剖分

代码

#include <iostream>
#include <cstring>
#define size SiZe
using namespace std;

// Number of entries in each per-vertex array; the edge pool e[] holds 2 * N.
const int N = 500010;
// One directed edge, stored as a node of a singly linked adjacency list.
struct edge {
    int to;    // vertex this edge points at
    int next;  // index in e[] of the next edge leaving the same vertex (-1 ends the list)
};
// Pool of directed edges. main() calls add_edge twice per input edge, hence 2 * N.
edge e[2 * N];
// idx: next unused slot of e[].
// head[v]: index in e[] of the first edge leaving v; main() fills the array
// with -1, which the traversals read as "no more edges".
int idx, head[N];
// Read first by main(): n - 1 edges and m queries follow, and both DFS passes
// start at root.
int n, m, root;
// Written by dfs1: dep[v] is the depth of v (the root keeps 0), fa[v] its
// parent, size[v] the number of vertices in the subtree of v.
int dep[N], fa[N], size[N];
// heavy_son[v]: child of v with the largest subtree, 0 when v has no child.
// top[v]: topmost vertex of the heavy chain that contains v.
int heavy_son[N], top[N];

// Prepends the directed edge u -> v to the adjacency list of u.
// The edge takes the next free slot of the pool e[], links to the previous
// list head, and becomes the new head; idx then moves to the following slot.
void add_edge (int u, int v) {
    const int slot = idx;
    e[slot].next = head[u];
    e[slot].to = v;
    head[u] = slot;
    idx = slot + 1;
}
// First pass of the heavy-path decomposition.
// Walks the subtree of cur, entered from father, and records for every vertex
// reached its depth, its parent and its subtree size. For cur and each
// descendant it also selects the heavy son: the child with the largest subtree.
// The recursion goes one call deeper per tree level.
void dfs1 (int cur, int father) {
    size[cur] = 1;  // cur counts itself
    int i = head[cur];
    while (i != -1) {
        const int child = e[i].to;
        i = e[i].next;
        if (child != father) {
            dep[child] = dep[cur] + 1;
            fa[child] = cur;
            dfs1(child, cur);
            size[cur] += size[child];
            // Strict comparison: of equally large subtrees, the first one met stays.
            if (size[heavy_son[cur]] < size[child])
                heavy_son[cur] = child;
        }
    }
}
// Second pass of the heavy-path decomposition.
// Assigns top[] for the subtree of cur: cur belongs to the chain headed by
// top_node, its heavy son continues that chain, and every other child starts
// a new chain headed by itself.
void dfs2 (int cur, int top_node) {
    top[cur] = top_node;
    const int heavy = heavy_son[cur];
    if (heavy != 0) dfs2(heavy, top_node);  // 0 means cur has no child
    int i = head[cur];
    while (i != -1) {
        const int child = e[i].to;
        i = e[i].next;
        if (child != fa[cur] && child != heavy)
            dfs2(child, child);
    }
}
// Returns the lowest common ancestor of x and y.
// Relies on dep[], fa[] and top[] as filled in by dfs1 and dfs2.
int lca (int x, int y) {
    for (;;) {
        if (top[x] == top[y]) break;
        // Always lift the vertex whose chain head lies deeper, so the jump
        // can never pass over the common ancestor.
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        x = fa[top[x]];
    }
    // Both vertices now sit on one chain: the shallower one is the answer.
    return dep[x] > dep[y] ? y : x;
}

int main () {
    memset(head, -1, sizeof(head));
    scanf("%d%d%d", &n, &m, &root);
    for (int i = 1, u, v; i <= n - 1; ++ i) {
        scanf("%d%d", &u, &v);
        add_edge(u, v);
        add_edge(v, u);
    }

    dfs1(root, -1);
    dfs2(root, root);
    for (int i = 1, x, y; i <= m; ++ i) {
        scanf("%d%d", &x, &y);
        printf("%d\n", lca(x, y));
    }
    return 0;
}

树上倍增

原理

见《进阶指南》第375页。

代码

#include <iostream>
#include <cstring>
#include <queue>
#include <cmath>
using namespace std;

// Number of entries in each per-vertex array; the edge pool e[] holds 2 * N.
const int N = 500010;
// One directed edge, stored as a node of a singly linked adjacency list.
struct edge {
    int to;    // vertex this edge points at
    int next;  // index in e[] of the next edge leaving the same vertex (-1 ends the list)
};
// Pool of directed edges. main() calls add_edge twice per input edge, hence 2 * N.
edge e[2 * N];
// idx: next unused slot of e[].
// head[v]: index in e[] of the first edge leaving v; main() fills the array
// with -1, which bfs reads as "no more edges".
int idx, head[N];
// Read first by main(): n - 1 edges and m queries follow; bfs starts at root.
int n, m, root;
// k: largest exponent j for which fa[v][j] is filled in and tried by lca.
// dep[v]: depth of v with the root at 1; 0 means bfs has not reached v.
// fa[v][j]: the ancestor 2^j levels above v, 0 when there is none.
int k, dep[N], fa[N][20];

// Prepends the directed edge u -> v to the adjacency list of u.
// The edge takes the next free slot of the pool e[], links to the previous
// list head, and becomes the new head; idx then moves to the following slot.
void add_edge (int u, int v) {
    const int slot = idx;
    e[slot].next = head[u];
    e[slot].to = v;
    head[u] = slot;
    idx = slot + 1;
}
void bfs () {
    dep[root] = 1;
    queue<int> q;
    q.push(root);
    while (!q.empty()) {
        int cur = q.front();
        q.pop();
        for (int i = head[cur]; i != -1; i = e[i].next) {
            int to = e[i].to;
            if (dep[to] == 0) {
                dep[to] = dep[cur] + 1;
                fa[to][0] = cur;
                for (int j = 1; j <= k; ++ j)
                    fa[to][j] = fa[fa[to][j - 1]][j - 1];
                q.push(to);
            }
        }
    }
}
int lca (int x, int y) {
    if (dep[x] > dep[y]) swap(x, y);
    for (int i = k; i >= 0; -- i)
        if (dep[fa[y][i]] >= dep[x])
            y = fa[y][i];
    if (x == y) return x;
    for (int i = k; i >= 0; -- i)
        if (fa[x][i] != fa[y][i])
            x = fa[x][i], y = fa[y][i];
    return fa[x][0];
}

int main () {
    memset(head, -1, sizeof(head));
    scanf("%d%d%d", &n, &m, &root);
    for (int i = 1, u, v; i <= n - 1; ++ i) {
        scanf("%d%d", &u, &v);
        add_edge(u, v);
        add_edge(v, u);
    }

    k = log(n) / log(2) + 1;
    bfs();
    for (int i = 1, x, y; i <= m; ++ i) {
        scanf("%d%d", &x, &y);
        printf("%d\n", lca(x, y));
    }
    return 0;
}
最后修改于: