跳转至

树分治

点分治

点分治的适用范围

点分治用于 O(nlogn)O(n\log n) 统计树上特定结构(例如路径、连通块等)的信息(大小、数量、最值、颜色数等)。“特定结构”需要满足以下条件:

  • 由一些点组成;
  • 可以由其包含的任意一个点统计贡献(对于其内部的每个点,统计出的贡献都是相同的);
  • 任选一个节点作为根,这类结构中的任意一个都满足:要么包含根节点,要么完全包含于根节点的一个子树内。

算法过程

P3806 【模板】点分治 1

给定一棵有 nn 个点的树,询问树上距离为 kk 的点对是否存在。

n104n\le 10^4,时限 200msn2n^2 过不去。

即统计长为 kk 的路径数量。

所有的分治都有一个相同的思路:将当前问题 AA 划分为多个规模较小的子问题 B1,B2,B3B_1,B_2,B_3\cdots,且满足 A=B1+B2+B3+A=B_1+B_2+B_3+\cdots,然后 O(n)O(n) 处理跨过不同子问题之间的贡献 (i,j)w(Bi,Bj)\sum_{(i,j)}{w(B_i,B_j)},然后递归解决子问题 B1,B2,B3B_1,B_2,B_3\cdots

在树上,若要统计一个连通块内的答案,我们可以选择一个节点作为根节点 rtrtO(n)O(n) 遍历整个连通块,统计跨过根节点 rtrt 的路径,然后将 rtrt 标记为已访问(这样子树内的递归就不会跨过 rtrt),最后递归处理完全包含于子树内的路径。

处理贡献

分治的时间复杂度 O(nlogn)O(n\log n) 得益于每次划分出的子问题的规模不超过 n2\frac{n}{2}(主定理)。因此我们希望能够在树上合适的选择 rtrt 节点,使得每棵子树的大小都不超过 n2\frac{n}{2}

注意到这正是重心的定义。因此我们对于 rtrt 的每棵子树,都进行两遍 dfs,找到子树的重心并从重心处递归。而在主函数中则先找到整棵树的重心,然后从重心处调用分治函数 solve()

递归顺序

两遍 dfs 寻找重心

寻找重心的依据是重心的定义:重心是最大子树最小的一个节点。第一遍 dfs 求出所有子树大小,第二遍 dfs 求出每个节点的最大子树(包括祖先方向的“子树”),取最小值即是重心。

int rt;
int sz[N], mxs[N] = {INF};

// 获取子树 sz
void get_sz(int u, int f) {
    sz[u] = 1;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == f || vis[v]) continue; // 注意避开已访问的节点
        get_sz(v, u);
        sz[u] += sz[v];
    }
}

// 找到连通块重心
// tot 表示当前连通块的大小
void get_root(int u, int f, int tot) {
    mxs[u] = 0;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == f || vis[v]) continue;
        get_root(v, u, tot);
        // 统计子树最大值
        mxs[u] = max(mxs[u], sz[v]);
    }
    // 计算祖先方向的“子树”大小
    mxs[u] = max(mxs[u], tot - sz[u]);
    if(mxs[u] < mxs[rt]) rt = u;
}
分治过程
void solve(int u) {
    calc(u); // 统计经过 u 节点的路径产生的贡献
    vis[u] = 1; // 标记为已访问
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(vis[v]) continue; // 避开已访问的节点
        rt = 0; // 找到重心
        get_sz(v, u);
        get_root(v, u, sz[v]);
        solve(rt); // 从重心处调用子树 solve
    }
}

至于如何在 calc() 中处理经过 rtrt 节点的路径,我们只需要逐个考虑每棵子树,用桶数组 f 统计之前所有子树对当前子树的贡献,然后在遍历结束后把子树的 dis 信息更新到 f 上即可。注意更新 f 的操作要在子树 vv 遍历结束之后执行,因此需要一个 buf 数组暂存修改。

代码
#include<iostream>
#include<algorithm>
using namespace std;
const int N = 1E4 + 10;
const int V = 1E7 + 10;
const int INF = (int)0x3f3f3f3f3f3f3f3f;

struct Edge {
    int v, w;
    int next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v, int w) {
    pool[++ne] = {v, w, head[u]};
    head[u] = ne;
}

int n, m;
int qr[N], ans[N];

int vis[N];

int rt;
int sz[N], mxs[N] = {INF};

// This does what you think it does
void get_sz(int u, int f) {
    sz[u] = 1;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == f || vis[v]) continue;
        get_sz(v, u);
        sz[u] += sz[v];
    }
}

// This does what you think it does
void get_root(int u, int f, int tot) {
    mxs[u] = 0;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == f || vis[v]) continue;
        get_root(v, u, tot);
        mxs[u] = max(mxs[u], sz[v]);
    }
    mxs[u] = max(mxs[u], tot - sz[u]);
    if(mxs[u] < mxs[rt]) rt = u;
}

// 搭配 calc() 合并子树
int f[V];
int buf1[N], buf2[N], nn1, nn2;
void get_dis(int u, int fa, int sum) {
    if(sum > 1e7) return;
    for(int i = 1; i <= m; i++) {
        if(sum <= qr[i] && f[qr[i] - sum]) ans[i] = 1;
    }
    buf1[++nn1] = sum;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v, w = pool[i].w;
        if(v == fa || vis[v]) continue;
        get_dis(v, u, sum + w);
    }
}

// 处理经过 u 节点路径的贡献
void calc(int u) {
    f[0] = 1;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v, w = pool[i].w;
        if(vis[v]) continue;
        get_dis(v, u, w);
        // 清空 buffer
        while(nn1) {
            f[buf1[nn1]] = 1;
            buf2[++nn2] = buf1[nn1];
            --nn1;
        }
    }
    // 撤销
    while(nn2) {
        f[buf2[nn2--]] = 0;
    }
}

// 点分治
void solve(int u) {
    calc(u);
    vis[u] = 1;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(vis[v]) continue;
        rt = 0;
        get_sz(v, u);
        get_root(v, u, sz[v]);
        solve(rt);
    }
}

int main() {

    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    cin >> n >> m;
    for(int i = 1; i <= n - 1; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        addEdge(u, v, w);
        addEdge(v, u, w);
    }

    for(int i = 1; i <= m; i++) {
        cin >> qr[i];
    }

    rt = 0;
    get_root(1, 0, n);
    solve(rt);

    for(int i = 1; i <= m; i++) {
        cout << (ans[i] ? "AYE" : "NAY") << '\n';
    }

    return 0;
}

P4149 [IOI 2011] Race

题目大意

给一棵树,每条边有权。求一条简单路径,权值和等于 kk,且边的数量最小。

路径长度最值满足点分治的使用条件。考虑 calc(u) 应该如何统计经过节点 uu 的路径。我们沿用模板题的第一种思路,使用 get_dis(v) 函数合并当前子树 vv 和之前的所有子树。使用一个桶数组 f 记录uu 距离为 ii 的所有节点距离 uu 的最少边数。记 sumsum 表示 uu 到当前节点 xx 的路径边权和,lenlen 表示 uu 走到当前节点经过的边的数量,统计子树时时使用 f[k - sum] + len 更新答案即可。

注意边权之和 sumsum 是没有限制的。因此我们需要在 get_dis() 中剪掉 sum>ksum>k 的节点,否则无法存储在桶数组里。

代码
#include<iostream>
using namespace std;
const int N = 2E5 + 10;
const int V = 1E6 + 10;
const int INF = (int)0x3f3f3f3f3f3f3f3f;

struct Edge {
    int v, w, next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v, int w) {
    pool[++ne] = {v, w, head[u]};
    head[u] = ne;
}

int n, k, ans = INF;
int vis[N], sz[N], mxs[N] = {INF};

void get_sz(int u, int f) {
    sz[u] = 1;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == f || vis[v]) continue;
        get_sz(v, u);
        sz[u] += sz[v];
    }
}

void get_rt(int u, int f, int tot, int &rt) {
    mxs[u] = 0;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == f || vis[v]) continue;
        get_rt(v, u, tot, rt);
        mxs[u] = max(mxs[u], sz[v]);
    }
    mxs[u] = max(mxs[u], tot - sz[u]);
    if(mxs[u] < mxs[rt]) rt = u;
}

int f[V];
int buf1[N][2], buf2[N], n1, n2;

void get_dis(int u, int fa, int sw, int se) {
    if(sw > k) return;
    ans = min(ans, f[k - sw] + se);
    buf1[++n1][0] = sw;
    buf1[n1][1] = se;
    buf2[++n2] = sw;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v, w = pool[i].w;
        if(v == fa || vis[v]) continue;
        get_dis(v, u, sw + w, se + 1);
    }
}

void calc(int u) {
    n2 = 0;
    f[0] = 0;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v, w = pool[i].w;
        if(vis[v]) continue;
        n1 = 0;
        get_dis(v, u, w, 1);
        for(int j = 1; j <= n1; j++) {
            f[buf1[j][0]] = min(f[buf1[j][0]], buf1[j][1]);
        }
    }
    for(int i = 1; i <= n2; i++) {
        f[buf2[i]] = INF;
    }
}

void solve(int u) {
    vis[u] = 1;
    calc(u);
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(vis[v]) continue;
        int rt = 0;
        get_sz(v, u);
        get_rt(v, u, sz[v], rt);
        solve(rt);
    }
}

int main() {

    cin >> n >> k;
    for(int i = 1; i <= n - 1; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        ++u, ++v;
        addEdge(u, v, w);
        addEdge(v, u, w);
    }

    for(int i = 0; i < V; i++) f[i] = INF;

    int rt = 0;
    get_sz(1, 0);
    get_rt(1, 0, sz[1], rt);
    solve(rt);

    if(ans != INF) {
        cout << ans << endl;
    } else cout << -1 << endl;

    return 0;
}

P5311 [Ynoi2011] 成都七中

题目大意

给你一棵 nn 个节点的树,每个节点有一种颜色,有 mm 次查询操作。

查询操作给定参数 L,R,xL,R,x,需输出:

将树中编号在 [L,R][L,R] 内的所有节点保留,xx 所在连通块中颜色种类数。

每次查询操作独立。

详见我的题解