跳转至

点分治

点分治的适用范围

点分治(starch)用于 \(O(n\log n)\) 统计树上特定结构(例如路径、连通块等)的信息(大小、数量、最值、颜色数等)。“特定结构”需要满足以下条件:

  • 由一些点组成;
  • 可以由其包含的任意一个点统计贡献(对于其内部的每个点,统计出的贡献都是相同的);
  • 任选一个节点作为根,这类结构中的任意一个都满足:要么包含根节点,要么完全包含于根节点的一个子树内。

一般,使用点分治统计路径和连通块是较为常见的。

点分治的算法流程如下:

  • 收到一个连通块,需要统计完全包含于当前连通块的贡献
  • get_sz get_rt:两遍 dfs 找到连通块的重心 \(rt\)
  • \(rt\) 标记为“禁止经过”;
  • calc:统计经过 \(rt\) 的路径(或连通块)的贡献;
  • 枚举 \(rt\) 的出边 \(rt\to v\),递归 solve(v);由于 \(rt\) 被标记为禁止经过,因此子树 \(v\) 不会跨过 \(u\) 去处理子树外面的部分;

注意到路径(或连通块)要么跨过 \(rt\) 节点,被本轮 calc 统计,要么完全包含于 \(rt\) 的一个子树,被子树递归统计。

\(size[u]\) 表示 solve(u)\(u\) 所在的连通块的大小。容易证明:\(\sum size[u]\)\(O(n\log n)\) 级别的。

因为每分治一层,每个节点所在的连通块大小就至少减少一半,因此最多分治 \(\log n\) 层,每个节点最多被统计 \(\log n\) 次。因此 \(\sum size[u]\)\(O(n\log n)\) 级别。

得益于这个结论,calc 的时间复杂度可以为 \(O(n)\) 甚至 \(O(n\log n)\)\(n\) 指当前连通块大小)。也就是说,我们可以在 calc 中遍历整个连通块,算出 \(rt\) 的贡献。

分治过程

在代码中,我们一般将求重心的步骤放到父节点处理。这里传入的 \(u\) 就是重心。

void solve(int u) {
    calc(u); // 统计经过 u 节点的路径产生的贡献
    vis[u] = 1; // 标记为已访问
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(vis[v]) continue; // 避开已访问的节点
        rt = 0; // 找到重心
        get_sz(v, u);
        get_root(v, u, sz[v]);
        solve(rt); // 从重心处调用子树 solve
    }
}
两遍 dfs 寻找重心

寻找重心的依据是重心的定义:重心是最大子树最小的一个节点。第一遍 dfs 求出所有子树大小,第二遍 dfs 求出每个节点的最大子树(包括祖先方向的“子树”),取最小值即是重心。

int rt;
int sz[N], mxs[N] = {INF};

// 获取子树 sz
void get_sz(int u, int f) {
    sz[u] = 1;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == f || vis[v]) continue; // 注意避开已访问的节点
        get_sz(v, u);
        sz[u] += sz[v];
    }
}

// 找到连通块重心
// tot 表示当前连通块的大小
void get_root(int u, int f, int tot) {
    mxs[u] = 0;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == f || vis[v]) continue;
        get_root(v, u, tot);
        // 统计子树最大值
        mxs[u] = max(mxs[u], sz[v]);
    }
    // 计算祖先方向的“子树”大小
    mxs[u] = max(mxs[u], tot - sz[u]);
    if(mxs[u] < mxs[rt]) rt = u;
}

P3806 【模板】点分治 1

题意

给定一棵有 \(n\) 个点的树,询问树上距离为 \(k\) 的点对是否存在。

\(n\le 10^4\),时限 200ms\(n^2\) 过不去。

即统计长为 \(k\) 的路径数量。

考虑如何在 calc() 中处理经过 \(rt\) 节点的路径。我们只需要逐一考虑每棵子树,用桶数组 f 记录之前子树的信息,然后在遍历结束后把当前子树的信息和 f 合并考虑,算出贡献;最后将当前子树的信息也加入 f 数组。

代码
#include<iostream>
#include<algorithm>
using namespace std;
const int N = 1E4 + 10;
const int V = 1E7 + 10;
const int INF = (int)0x3f3f3f3f3f3f3f3f;

struct Edge {
    int v, w;
    int next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v, int w) {
    pool[++ne] = {v, w, head[u]};
    head[u] = ne;
}

int n, m;
int qr[N], ans[N];

int vis[N];

int rt;
int sz[N], mxs[N] = {INF};

void get_sz(int u, int f) {
    sz[u] = 1;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == f || vis[v]) continue;
        get_sz(v, u);
        sz[u] += sz[v];
    }
}

void get_root(int u, int f, int tot) {
    mxs[u] = 0;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == f || vis[v]) continue;
        get_root(v, u, tot);
        mxs[u] = max(mxs[u], sz[v]);
    }
    mxs[u] = max(mxs[u], tot - sz[u]);
    if(mxs[u] < mxs[rt]) rt = u;
}

// 搭配 calc() 合并子树
int f[V];
int buf1[N], buf2[N], nn1, nn2;
void get_dis(int u, int fa, int sum) {
    if(sum > 1e7) return;
    for(int i = 1; i <= m; i++) {
        if(sum <= qr[i] && f[qr[i] - sum]) ans[i] = 1;
    }
    buf1[++nn1] = sum;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v, w = pool[i].w;
        if(v == fa || vis[v]) continue;
        get_dis(v, u, sum + w);
    }
}

// 处理经过 u 节点路径的贡献
void calc(int u) {
    f[0] = 1;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v, w = pool[i].w;
        if(vis[v]) continue;
        get_dis(v, u, w);
        // 清空 buffer
        while(nn1) {
            f[buf1[nn1]] = 1;
            buf2[++nn2] = buf1[nn1];
            --nn1;
        }
    }
    // 撤销
    while(nn2) {
        f[buf2[nn2--]] = 0;
    }
}

// 点分治
void solve(int u) {
    calc(u);
    vis[u] = 1;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(vis[v]) continue;
        rt = 0;
        get_sz(v, u);
        get_root(v, u, sz[v]);
        solve(rt);
    }
}

int main() {

    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    cin >> n >> m;
    for(int i = 1; i <= n - 1; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        addEdge(u, v, w);
        addEdge(v, u, w);
    }

    for(int i = 1; i <= m; i++) {
        cin >> qr[i];
    }

    rt = 0;
    get_root(1, 0, n);
    solve(rt);

    for(int i = 1; i <= m; i++) {
        cout << (ans[i] ? "AYE" : "NAY") << '\n';
    }

    return 0;
}

P4149 [IOI 2011] Race

题目大意

给一棵树,每条边有权。求一条简单路径,权值和等于 \(k\),且边的数量最小。

路径长度最值满足点分治的使用条件。考虑 calc(u) 应该如何统计经过节点 \(u\) 的路径。我们沿用模板题的第一种思路,使用 get_dis(v) 函数合并当前子树 \(v\) 和之前的所有子树。使用一个桶数组 f 记录\(u\) 距离为 \(i\) 的所有节点距离 \(u\) 的最少边数。记 \(sum\) 表示 \(u\) 到当前节点 \(x\) 的路径边权和,\(len\) 表示 \(u\) 走到当前节点经过的边的数量,统计子树时时使用 f[k - sum] + len 更新答案即可。

注意边权之和 \(sum\) 是没有限制的。因此我们需要在 get_dis() 中剪掉 \(sum>k\) 的节点,否则无法存储在桶数组里。

代码
#include<iostream>
using namespace std;
const int N = 2E5 + 10;
const int V = 1E6 + 10;
const int INF = (int)0x3f3f3f3f3f3f3f3f;

struct Edge {
    int v, w, next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v, int w) {
    pool[++ne] = {v, w, head[u]};
    head[u] = ne;
}

int n, k, ans = INF;
int vis[N], sz[N], mxs[N] = {INF};

void get_sz(int u, int f) {
    sz[u] = 1;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == f || vis[v]) continue;
        get_sz(v, u);
        sz[u] += sz[v];
    }
}

void get_rt(int u, int f, int tot, int &rt) {
    mxs[u] = 0;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == f || vis[v]) continue;
        get_rt(v, u, tot, rt);
        mxs[u] = max(mxs[u], sz[v]);
    }
    mxs[u] = max(mxs[u], tot - sz[u]);
    if(mxs[u] < mxs[rt]) rt = u;
}

int f[V];
int buf1[N][2], buf2[N], n1, n2;

void get_dis(int u, int fa, int sw, int se) {
    if(sw > k) return;
    ans = min(ans, f[k - sw] + se);
    buf1[++n1][0] = sw;
    buf1[n1][1] = se;
    buf2[++n2] = sw;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v, w = pool[i].w;
        if(v == fa || vis[v]) continue;
        get_dis(v, u, sw + w, se + 1);
    }
}

void calc(int u) {
    n2 = 0;
    f[0] = 0;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v, w = pool[i].w;
        if(vis[v]) continue;
        n1 = 0;
        get_dis(v, u, w, 1);
        for(int j = 1; j <= n1; j++) {
            f[buf1[j][0]] = min(f[buf1[j][0]], buf1[j][1]);
        }
    }
    for(int i = 1; i <= n2; i++) {
        f[buf2[i]] = INF;
    }
}

void solve(int u) {
    vis[u] = 1;
    calc(u);
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(vis[v]) continue;
        int rt = 0;
        get_sz(v, u);
        get_rt(v, u, sz[v], rt);
        solve(rt);
    }
}

int main() {

    cin >> n >> k;
    for(int i = 1; i <= n - 1; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        ++u, ++v;
        addEdge(u, v, w);
        addEdge(v, u, w);
    }

    for(int i = 0; i < V; i++) f[i] = INF;

    int rt = 0;
    get_sz(1, 0);
    get_rt(1, 0, sz[1], rt);
    solve(rt);

    if(ans != INF) {
        cout << ans << endl;
    } else cout << -1 << endl;

    return 0;
}

P5311 [Ynoi2011] 成都七中

题目大意

给你一棵 \(n\) 个节点的树,每个节点有一种颜色,有 \(m\) 次查询操作。

查询操作给定参数 \(L,R,x\),需输出:

将树中编号在 \([L,R]\) 内的所有节点保留,\(x\) 所在连通块中颜色种类数。

每次查询操作独立。

我们发现一个点 \(u\) 能对 \(x\) 处的查询产生贡献,当且仅当 \(u\rightarrow x\) 的路径上节点编号的最小值 \(mn\) 满足 \(L\le mn\),最大值 \(mx\) 满足 \(mx\le R\)。因此我们可以从连通块的任意一个节点出发,统计连通块的大小。

我们简化问题,先考虑 \(x\) 固定的情况:以 \(x\) 为根,一遍 dfs 求出 \(x\) 到所有节点的路径的 \(mx\)\(mn\)。然后直接二维数颜色即可。

接下来考虑查询位置 \(x\) 不固定的情况。如果使用换根则无法维护路径最值。这里我们注意到,如果一个询问 \(x\) 和当前的根节点 \(rt\) “连通”(只保留询问的 \([L,R]\) 的节点情况下,两点处于同一连通块),则该询问等价于在当前的根节点上询问。因此我们遍历一遍整棵树,筛选出和根节点连通的询问,然后对它们进行处理。

否则这个询问所在的子树就和根节点以及其他子树分隔开了,成为一个较小的子问题,可以递归处理。考虑使用点分治,每次将根节点 \(rt\) 选为连通块的重心。总时间复杂度为 \(O(n\log n+q\log^2 n)\)

代码
#include<iostream>
#include<vector>
#include<algorithm>
using namespace std;
const int N = 1e5 + 10;
const int INF = 0x3f3f3f3f;

struct Edge {
    int v, next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v) {
    pool[++ne] = {v, head[u]};
    head[u] = ne;
}

struct Qr {
    int id, l, r;
};

struct Op {
    bool tp;
    int mn, mx, z;
    inline bool operator<(const Op &b) const {
        if(mx != b.mx) return mx < b.mx;
        return tp < b.tp;
    }
};

int n, m;
int col[N];

vector<Qr> q[N];
int ans[N];

vector<Op> op;

int sz[N], mxp[N] = {INF};
int vis[N];

void dfs(int u, int fa, int mn, int mx) {
    mn = min(mn, u);
    mx = max(mx, u);
    op.push_back({0, mn, mx, col[u]});
    for(Qr &o : q[u]) {
        if(~o.id && o.l <= mn && o.r >= mx) op.push_back({1, o.l, o.r, o.id}), o.id = -1;
    }
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(v == fa || vis[v]) continue;
        dfs(v, u, mn, mx);
    }
}

namespace BIT {
    int n = 1e5;
    int sum[N];
    inline int lowbit(int x) { return x & -x; }
    inline void add(int p, int v) {
        for(int i = n - p + 1; i <= n; i += lowbit(i)) sum[i] += v;
    }
    inline void clear(int p) {
        for(int i = n - p + 1; i <= n; i += lowbit(i)) sum[i] = 0;
    }
    inline int query(int p) {
        int res = 0;
        for(int i = n - p + 1; i > 0; i -= lowbit(i)) res += sum[i];
        return res;
    }
    inline int query(int l, int r) {
        return query(r) - query(l - 1);
    }
}

int pos[N];

void calc(int u) {
    op.clear();
    dfs(u, 0, INF, 0);
    sort(op.begin(), op.end());
    for(Op &o : op) {
        if(!o.tp) {
            if(!pos[o.z]) {
                BIT::add(o.mn, 1);
                pos[o.z] = o.mn;
            } else if(o.mn > pos[o.z]) {
                BIT::add(pos[o.z], -1);
                BIT::add(o.mn, 1);
                pos[o.z] = o.mn;
            }
        } else {
            ans[o.z] = BIT::query(o.mn);
        }
    }
    for(Op &o : op) {
        if(!o.tp) {
            BIT::clear(o.mn);
            pos[o.z] = 0;
        }
    }
}

void get_sz(int u, int fa) {
    sz[u] = 1;
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(v == fa || vis[v]) continue;
        get_sz(v, u);
        sz[u] += sz[v];
    }
}

void get_rt(int u, int fa, int tot, int &rt) {
    mxp[u] = 0;
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(v == fa || vis[v]) continue;
        get_rt(v, u, tot, rt);
        mxp[u] = max(mxp[u], sz[v]);
    }
    mxp[u] = max(mxp[u], tot - sz[u]);
    if(mxp[u] < mxp[rt]) rt = u;
}

void solve(int u) {
    vis[u] = 1;
    calc(u);
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(vis[v]) continue;
        int rt = 0;
        get_sz(v, u);
        get_rt(v, u, sz[v], rt);
        solve(rt);
    }
}

int main() {

    ios::sync_with_stdio(0);
    cin.tie(0);

    cin >> n >> m;
    for(int i = 1; i <= n; i++) cin >> col[i];
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        addEdge(u, v);
        addEdge(v, u);
    }
    for(int i = 1; i <= m; i++) {
        int l, r, x;
        cin >> l >> r >> x;
        q[x].push_back({i, l, r});
    }

    int rt = 0;
    get_sz(1, 0);
    get_rt(1, 0, sz[1], rt);
    solve(rt);

    for(int i = 1; i <= m; i++) cout << ans[i] << '\n';

    return 0;
}

P3714 [BJOI2017] 树的难题

题意

给定一棵 \(n\) 个点的无根树,边有颜色,颜色有 \(m\) 种,第 \(i\) 种颜色的权值为 \(c_i\)

对于一条树上的简单路径,路径上经过的所有边按顺序组成一个颜色序列,序列可以划分成若干个相同颜色段。定义路径权值为颜色序列上每个同颜色段的颜色权值之和。

请你计算,经过边数在 \(l\)\(r\) 之间的所有简单路径中,路径权值的最大值。

\(n,m\le 2\times 10^5,\ |c_i|\le 10^4\)

将每个节点的出边按照颜色排好序,相同颜色和不同颜色的贡献分开统计,用树状数组维护路径长度即可。

时间复杂度 \(O(n\log^2 n)\)

代码
#include<iostream>
#include<vector>
#include<algorithm>
#include<cstring>
#include<cassert>
#define int long long
using namespace std;
const int N = 2e5 + 10;
const int INF = 0x3f3f3f3f3f3f3f3f;

struct Edge {
    int v, w;
    int next;
    inline bool operator<(const Edge &b) const {
        return w < b.w;
    }
} pool[2 * N];
int ne, head[N];

vector<Edge> adj[N];

void addEdge(int u, int v, int w) {
    pool[++ne] = {v, w, head[u]};
    head[u] = ne;
}

int n, m, ql, qr, ans = -INF;
int c[N];

int vis[N];
int sz[N], mxs[N] = {INF};
void get_sz(int u, int fa) {
    sz[u] = 1;
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(v == fa || vis[v]) continue;
        get_sz(v, u);
        sz[u] += sz[v];
    }
}

void get_rt(int u, int fa, int tot, int &rt) {
    mxs[u] = 0;
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(v == fa || vis[v]) continue;
        get_rt(v, u, tot, rt);
        mxs[u] = max(mxs[u], sz[v]);
    }
    mxs[u] = max(mxs[u], tot - sz[u]);
    if(mxs[u] < mxs[rt]) rt = u;
}

namespace SegT1 {
    int mx[4 * N];
    int vis[4 * N], buf[4 * N], top;
    inline int lc(int x) { return x << 1; }
    inline int rc(int x) { return x << 1 | 1; }
    inline void push_up(int p) { mx[p] = max(mx[lc(p)], mx[rc(p)]); }
    void insert(int p, int l, int r, int q, int v) {
        if(!vis[p]) { vis[p] = 1; buf[++top] = p; }
        if(l == r) {
            mx[p] = max(mx[p], v);
            return;
        }
        int mid = (l + r) >> 1;
        if(q <= mid) insert(lc(p), l, mid, q, v);
        else insert(rc(p), mid + 1, r, q, v);
        push_up(p);
    }
    int query(int p, int l, int r, int ql, int qr) {
        if(ql <= l && r <= qr) return mx[p];
        int mid = (l + r) >> 1;
        if(qr <= mid) return query(lc(p), l, mid, ql, qr);
        if(mid < ql) return query(rc(p), mid + 1, r, ql, qr);
        return max(query(lc(p), l, mid, ql, qr), query(rc(p), mid + 1, r, ql, qr));
    }
    void clear() {
        while(top) {
            vis[buf[top]] = 0;
            mx[buf[top]] = -INF;
            --top;
        }
    }
    void init() {
        for(int i = 0; i < 4 * N; i++) mx[i] = -INF;
    }
}

namespace SegT2 {
    int mx[4 * N];
    int vis[4 * N], buf[4 * N], top;
    inline int lc(int x) { return x << 1; }
    inline int rc(int x) { return x << 1 | 1; }
    inline void push_up(int p) { mx[p] = max(mx[lc(p)], mx[rc(p)]); }
    void insert(int p, int l, int r, int q, int v) {
        if(l == r) {
            mx[p] = max(mx[p], v);
            if(!vis[p]) { vis[p] = 1; buf[++top] = p; }
            return;
        }
        int mid = (l + r) >> 1;
        if(q <= mid) insert(lc(p), l, mid, q, v);
        else insert(rc(p), mid + 1, r, q, v);
        push_up(p);
        if(!vis[p]) { vis[p] = 1; buf[++top] = p; }
    }
    int query(int p, int l, int r, int ql, int qr) {
        if(ql <= l && r <= qr) return mx[p];
        int mid = (l + r) >> 1;
        if(qr <= mid) return query(lc(p), l, mid, ql, qr);
        if(mid < ql) return query(rc(p), mid + 1, r, ql, qr);
        return max(query(lc(p), l, mid, ql, qr), query(rc(p), mid + 1, r, ql, qr));
    }
    void clear() {
        while(top) {
            vis[buf[top]] = 0;
            mx[buf[top]] = -INF;
            --top;
        }
    }
    void init() {
        for(int i = 0; i < 4 * N; i++) mx[i] = -INF;
    }
}

int buf1[N], nb1;
int buf2[N], nb2;

int dep[N], f[N], col[N];
void dfs(int u, int fa) {
    dep[u] = dep[fa] + 1;
    if(dep[u] > qr) return;
    buf1[++nb1] = u;
    buf2[++nb2] = u;
    f[u] = f[fa] + (col[u] != col[fa] ? c[col[u]] : 0);
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v, w = pool[e].w;
        if(v == fa || vis[v]) continue;
        col[v] = w;
        dfs(v, u);
    }
}

void calc(int u) {
    SegT1::clear();
    SegT2::clear();
    SegT1::insert(1, 0, n, 0, 0);
    nb1 = nb2 = 0;
    int nc = 0;
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v, w = pool[e].w;
        if(vis[v]) continue;
        f[u] = dep[u] = 0;
        col[u] = 0;
        col[v] = w;
        if(w != nc) {
            for(int j = 1; j <= nb1; j++) {
                int i = buf1[j];
                SegT1::insert(1, 0, n, dep[i], f[i]);
            }
            nb1 = 0;
            SegT2::clear();
            nc = w;
        }
        dfs(v, u);
        for(int j = 1; j <= nb2; j++) {
            int i = buf2[j];
            ans = max(ans, SegT1::query(1, 0, n, max(0ll, ql - dep[i]), qr - dep[i]) + f[i]);
            ans = max(ans, SegT2::query(1, 0, n, max(0ll, ql - dep[i]), qr - dep[i]) + f[i] - c[w]);
        }
        for(int j = 1; j <= nb2; j++) {
            int i = buf2[j];
            SegT2::insert(1, 0, n, dep[i], f[i]);
        }
        nb2 = 0;
    }
}

void solve(int u) {
    vis[u] = 1;
    calc(u);
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(vis[v]) continue;
        int rt = 0;
        get_sz(v, u);
        get_rt(v, u, sz[v], rt);
        solve(rt);
    }
}

signed main() {

    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    cin >> n >> m >> ql >> qr;
    for(int i = 1; i <= m; i++) cin >> c[i];
    for(int i = 1; i <= n - 1; i++) {
        int u, v, c;
        cin >> u >> v >> c;
        adj[u].push_back({v, c});
        adj[v].push_back({u, c});
    }

    for(int i = 1; i <= n; i++) sort(adj[i].begin(), adj[i].end());

    for(int i = 1; i <= n; i++) {
        for(Edge &e : adj[i]) {
            addEdge(i, e.v, e.w);
        }
    }

    SegT1::init();
    SegT2::init();

    int rt = 0;
    get_sz(1, 0);
    get_rt(1, 0, sz[1], rt);
    solve(rt);

    if(ans == -INF) cout << "-1\n";
    else cout << ans << endl;

    return 0;
}

P4886 快递员

其实不完全是点分治,只是利用重心的 Trick 来减小复杂度。

题意

给你一个包含 \(n\) 个节点的无根树,有 \(m\) 个点对 \((x_i,y_i)\),请你找出一个节点 \(s\),满足 \(\max\{dis(s,x_i)+dis(s,y_i)\}\) 最小。输出最小值。

\(n,m\le 10^5\)

考虑二分答案,发现没法 check,因此做不了。

尝试挖掘性质,考虑一个调整的过程,设当前节点为 \(s\),考虑 \(d(x_i)+d(y_i)\) 取到最大值那些的 \((x_i,y_i)\)

  • \(s\) 位于 \(x_i\to y_i\) 的路径上,那么无论如何答案也不可能更优了;
  • 若这些 \((x_i,y_i)\) 都位于 \(s\) 的同一棵子树内,那么将 \(s\) 向这棵子树内调整,才可能更优;
  • 若这些 \((x_i,y_i)\) 不都位于同一棵子树内,那么无论如何答案也不可能更优了。

然而,如果我们暴力调整,时间复杂度将达到 \(O(n^2)\)。但我们注意到,可能成为最优解的点的集合总是形如一个连通块,每次出现第二种情况,相当于把 \(s\) 的一堆子树都排除了,范围缩小到了一个更小的子树。这启发我们将 \(s\) 选在重心。发现这样做只需要 \(\log\) 次就能将范围缩小到一个点。

时间复杂度 \(O(m\log n)\)

代码
#include<iostream>
#include<vector>
using namespace std;
const int N = 1e5 + 10;
const int INF = 0x3f3f3f3f;

struct Edge {
    int v, w, next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v, int w) {
    pool[++ne] = {v, w, head[u]};
    head[u] = ne;
}

int n, m, ans = INF;
vector<int> a[N];

int d[N][2], s[N][2], now[N];

void clear() {
    for(int i = 1; i <= m; i++) now[i] = 0;
}

void get_dis(int u, int fa, int dis, int rt) {
    for(int i : a[u]) {
        d[i][now[i]] = dis;
        s[i][now[i]++] = rt;
    }
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v, w = pool[e].w;
        if(v == fa) continue;
        get_dis(v, u, dis + w, rt);
    }
}

int calc(int u) {
    clear();
    for(int i : a[u]) {
        d[i][now[i]] = 0;
        s[i][now[i]++] = u;
    }
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v, w = pool[e].w;
        get_dis(v, u, w, v);
    }
    int res = 0;
    for(int i = 1; i <= m; i++) {
        res = max(res, d[i][0] + d[i][1]);
    }
    ans = min(ans, res);
    int to = 0;
    for(int i = 1; i <= m; i++) {
        if(d[i][0] + d[i][1] != res) continue;
        if(s[i][0] == s[i][1]) {
            if(to) return 0;
            to = s[i][0];
        } else {
            return 0;
        }
    }
    return to;
}

int vis[N];
int sz[N], mxp[N] = {INF};

void get_sz(int u, int fa) {
    sz[u] = 1;
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(v == fa || vis[v]) continue;
        get_sz(v, u);
        sz[u] += sz[v];
    }
}
void get_rt(int u, int fa, int tot, int &rt) {
    mxp[u] = 0;
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(v == fa || vis[v]) continue;
        get_rt(v, u, tot, rt);
        mxp[u] = max(mxp[u], sz[v]);
    }
    mxp[u] = max(mxp[u], tot - sz[u]);
    if(mxp[u] < mxp[rt]) rt = u;
}

void solve(int u) {
    vis[u] = 1;
    int v = calc(u);
    if(!v || vis[v]) return;
    int rt = 0;
    get_sz(v, 0);
    get_rt(v, 0, sz[v], rt);
    solve(rt);
}

int main() {

    cin >> n >> m;
    for(int i = 1; i <= n - 1; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        addEdge(u, v, w);
        addEdge(v, u, w);
    }
    for(int i = 1; i <= m; i++) {
        int x, y;
        cin >> x >> y;
        a[x].push_back(i);
        a[y].push_back(i);
    }

    int rt = 0;
    get_sz(1, 0);
    get_rt(1, 0, n, rt);
    solve(rt);

    cout << ans << '\n';

    return 0;
}

P7215 [JOISC 2020] 首都

题意

给定一棵树,点有颜色 \(col\)。你可以进行若干次操作,每次操作你可以将一个颜色的所有点的颜色都替换为另一种颜色。问最少替换多少次能产生一种颜色,使得这个颜色的所有节点连通。

\(n\le 2\times 10^5\)

我们有一种暴力的方法:枚举每个颜色为最终颜色,将所有影响它连通性的颜色都吞并,最终就是答案。考虑答案的连通块,它由若干个颜色对应的节点组合起来,不难发现由其中任何一种颜色开始,得到的答案都是相同的。这启发我们使用点分治。

考虑点分治,假设我们当前拿到了一个连通块,分治中心为 \(u\),我们需要求出以 \(col[u]\) 开始,达到连通的代价。首先注意到,如果 \(col[u]\) 存在位于连通块之外的部分,那么达到连通一定需要经过父亲,而父亲的贡献不应被考虑,因此遇到这种情况直接舍去。

否则,我们枚举 \(col[u]\) 出现的所有位置(肯定在连通块内),将它们到 \(u\) 路径上的所有颜色也都加入考虑集合,像 \(u\) 一样处理;如果考虑集合内的某个颜色也在连通块之外出现了,那么一样直接舍去该方案。

具体的,我们可以用一个队列维护考虑集合内颜色出现的位置;每次弹出队首,并向 \(u\) 方向(父亲方向)扩展一步,再将父亲颜色的所有位置也加入队列即可。

代码
#include<iostream>
#include<vector>
using namespace std;
const int N = 2e5 + 10;
const int INF = 0x3f3f3f3f;

struct Edge {
    int v, next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v) {
    pool[++ne] = {v, head[u]};
    head[u] = ne;
}

int n, k, ans;
int c[N];

vector<int> pos[N];

int vis[N];
int sz[N], mxp[N] = {INF};

void get_sz(int u, int fa) {
    sz[u] = 1;
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(v == fa || vis[v]) continue;
        get_sz(v, u);
        sz[u] += sz[v];
    }
}

void get_rt(int u, int fa, int tot, int &rt) {
    mxp[u] = 0;
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(v == fa || vis[v]) continue;
        get_rt(v, u, tot, rt);
        mxp[u] = max(mxp[u], sz[v]);
    }
    mxp[u] = max(mxp[u], tot - sz[u]);
    if(mxp[u] < mxp[rt]) rt = u;
}

int col[N], nc;

int que[N], hd, tl;
int vb[N], fa[N];

void fill(int u, int a0 = 0) {
    fa[u] = a0;
    col[u] = nc;
    vb[u] = 0;
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(vis[v] || v == a0) continue;
        fill(v, u);
    }
}
void calc(int p) {
    hd = 1, tl = 0;
    for(int &i : pos[c[p]]) {
        if(col[i] != nc) return;
        que[++tl] = i;
        vb[i] = 1;
    }
    int res = 1;
    while(hd <= tl) {
        int u = que[hd++];
        int v = fa[u];
        if(!v || vb[v]) continue;
        ++res;
        for(int &i : pos[c[v]]) {
            if(col[i] != nc) return;
            que[++tl] = i;
            vb[i] = 1;
        }
    }
    ans = min(ans, res);
}

void solve(int u) {
    ++nc; fill(u);
    vis[u] = 1;
    calc(u);
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(vis[v]) continue;
        int rt = 0;
        get_sz(v, u);
        get_rt(v, u, sz[v], rt);
        solve(rt);
    }
}

int main() {

    cin >> n >> k;
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        addEdge(u, v);
        addEdge(v, u);
    }
    for(int i = 1; i <= n; i++) cin >> c[i];
    for(int i = 1; i <= n; i++) pos[c[i]].push_back(i);

    ans = INF;
    int rt = 0;
    get_sz(1, 0);
    get_rt(1, 0, sz[1], rt);
    solve(rt);

    cout << ans - 1 << endl;

    return 0;
}