跳转至

边分治

我们可以将一条边选为树分治的分治中心,这样形成的边分树就是一棵二叉树了。二叉树只有两个儿子,贡献关系极大的简化了,因此可以处理很多复杂的问题。

三度化

边分治可以直接被菊花图卡成 \(O(n^2)\)。注意到二叉树一定不会出现这种情况,因此我们需要先把原树 \(T_1\) 转化为等价的另一棵树 \(T_2\),满足节点之间两两的距离不变。

具体的,对于包含 \(>2\) 个儿子的节点,我们将它的第一个儿子接到它的左儿子位置上,然后新建一个虚点接到它的右儿子位置上,再把剩下的所有儿子都交给这个虚点处理。注意到这样处理之后节点数量仍然是 \(O(n)\) 的。时间复杂度也是 \(O(n)\)

求重心边

在边分治中,重心边的定义为:左右两棵子树较大者最小的边。这也容易使用两遍 dfs 求出。

P3806 【模板】点分治 1

代码
#include<iostream>
#include<algorithm>
using namespace std;
const int N = 1e4 + 10;
const int V = 1e7;

struct Edge {
    int v, w, next;
};
Edge pool[2 * N];
int ne = 1, head[N];

void addEdge(int u, int v, int w) {
    pool[++ne] = {v, w, head[u]};
    head[u] = ne;
}

int n, m;
int q[N], ans[N];

namespace T2 {

    Edge pool[4 * N];
    int ne = 1, head[2 * N];
    void addEdge(int u, int v, int w) {
        pool[++ne] = {v, w, head[u]};
        head[u] = ne;
    }

    int nn;

    int vis[4 * N];
    int sz[2 * N];

    void get_sz(int u, int fa) {
        sz[u] = 1;
        for(int e = head[u]; e; e = pool[e].next) {
            int v = pool[e].v;
            if(v == fa || vis[e]) continue;
            get_sz(v, u);
            sz[u] += sz[v];
        }
    }

    void get_rt(int u, int fa, int tot, int &rt, int &re) {
        for(int e = head[u]; e; e = pool[e].next) {
            int v = pool[e].v;
            if(v == fa || vis[e]) continue;
            get_rt(v, u, tot, rt, re);
            if(rt == v) re = e;
        }
        if(max(sz[u], tot - sz[u]) < max(sz[rt], tot - sz[rt])) rt = u;
    }

    int buf1[2 * N], buf2[2 * N], top1, top2;

    void get_dis(int u, int fa, int buf[], int &top, int sum) {
        if(sum > V) return;
        if(u <= n) buf[++top] = sum;
        for(int e = head[u]; e; e = pool[e].next) {
            int v = pool[e].v, w = pool[e].w;
            if(v == fa || vis[e]) continue;
            get_dis(v, u, buf, top, sum + w);
        }
    }

    void solve(int u) {
        int rt = 0, re = 0;
        get_sz(u, 0);
        get_rt(u, 0, sz[u], rt, re);
        vis[re] = vis[re ^ 1] = 1;
        top1 = top2 = 0;
        u = pool[re].v;
        int v = pool[re ^ 1].v;
        get_dis(u, 0, buf1, top1, 0);
        get_dis(v, 0, buf2, top2, 0);
        sort(buf1 + 1, buf1 + 1 + top1);
        sort(buf2 + 1, buf2 + 1 + top2);
        for(int t = 1; t <= m; t++) {
            if(ans[t]) continue;
            int len = q[t];
            for(int i = 1, j = top2; i <= top1; i++) {
                while(j > 1 && buf1[i] + buf2[j] + pool[re].w > len) --j;
                if(buf1[i] + buf2[j] + pool[re].w == len) { ans[t] = 1; break; }
            }
        }
        for(int e = head[u]; e; e = pool[e].next) {
            int vv = pool[e].v;
            if(vis[e]) continue;
            solve(vv);
        }
        for(int e = head[v]; e; e = pool[e].next) {
            int vv = pool[e].v;
            if(vis[e]) continue;
            solve(vv);
        }
    }

    void work() {
        solve(1);
    }

}

void build(int u, int fa) {
    int last = u;
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v, w = pool[e].w;
        if(v == fa) continue;
        T2::addEdge(last, v, w);
        T2::addEdge(v, last, w);
        T2::addEdge(last, ++T2::nn, 0);
        T2::addEdge(T2::nn, last, 0);
        last = T2::nn;
    }
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(v == fa) continue;
        build(v, u);
    }
}

int main() {

    cin >> n >> m;
    for(int i = 1; i <= n - 1; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        addEdge(u, v, w);
        addEdge(v, u, w);
    }

    T2::nn = n;
    build(1, 0);

    for(int i = 1; i <= m; i++) cin >> q[i];
    T2::work();
    for(int i = 1; i <= m; i++) cout << (ans[i] ? "AYE\n" : "NAY\n");

    return 0;
}

P4220 [WC2018] 通道

给定包含 \(n\) 个节点的 \(3\) 棵树 \(T_1,T_2,T_3\),边有边权,请你求出 \(\max_{(x,y)}\bigl\{dis_1(x,y)+dis_2(x,y)+dis_3(x,y)\big\}\)

\(n\le 10^5,\ w\le 10^{12}\)

我们似乎无法将 \(3\) 棵树简单的合并,因为两个节点的 \(dis\) 还有它们 \(\operatorname{lca}\) 的贡献参与。

我们可以枚举 \((x,y)\) 在一棵树 \(T_2\) 上的 \(\operatorname{lca}\),记为节点 \(t\),这样这棵树的 \(dis\) 就只和 \(x,y\) 单独有关了。然后,根据直径的点集合并性,我们可以求出 \(t\) 的每棵子树对应的点集在另一棵树 \(T_3\) 上的直径,然后再在 \(t\) 处进行合并即可。

考虑如何处理剩下的一棵树 \(T_1\)。两棵树的直径显然不能直接叠加,因此仍然考虑枚举 \(\operatorname{lca}\)。我们显然不能 \(O(n^2)\) 的枚举 \(T_1,T_2\) 两棵树的 \(\operatorname{lca}\)。注意到我们枚举 \(T_1\) 上的 \(lca\) 节点 \(p\) 之后,\(x,y\) 只能在 \(p\) 的子树内,这也进一步减小了 \(x,y\) 在其他两棵树上的范围。

考虑此时 \(x,y\)\(T_2\) 上的 \(\operatorname{lca}\) 节点 \(t\) 有什么性质。注意到它一定是在 \(T_1\)\(p\) 子树形成的点集 \(S\) 中,任取两个点,在 \(T_2\) 上可能的 \(\operatorname{lca}\),而根据虚树的理论,这样的节点数量是 \(O(|S|)\) 级别的。这也启发我们使用虚树。

由于虚树的复杂度和点集 \(S\) 的大小直接挂钩,因此我们要想办法减少 \(\sum_{p}|S|\)。这启发我们使用树分治,因为它的 \(\sum_{p}|S|\) 有保证。

考虑点分治,然而这要求 \(x,y\) 必须分布在 \(p\) 的不同子树内。这等价于给每个点又赋予了一个颜色,因此后两棵树的时间复杂度会乘以 \(son[p]^2\)

为了减少 \(son[p]\),考虑边分治,这样每个点的颜色只能是黑白两种中的一种。对于后两棵树,我们只需记录黑点内部的直径、白点内部的直径,然后进行合并即可得到跨越黑白点的直径。

代码
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
#include<ctype.h>
#include<cstdio>
#include<string>
#include<vector>
#include<algorithm>
#include<cassert>
#define int long long
using namespace std;
const int N = 1e5 + 10;
const int LOGN = 19;
const int INF = 0x3f3f3f3f;

namespace io {
    char endl = '\n';
    struct istream {
        char ch;
        bool flag;
        inline istream &operator>>(int &x) {
            flag = 0;
            while(!isdigit(ch = getchar())) (ch == '-') && (flag = 1);
            x = ch - '0';
            while(isdigit(ch = getchar())) x = x * 10 + ch - '0';
            flag && (x = -x);
            return *this;
        }
    } cin;
    struct ostream {
        char buf[60];
        int top;
        inline ostream() : top(0) {}
        inline ostream &operator<<(int x) {
            do buf[++top] = x % 10 + '0', x /= 10; while(x);
            while(top) putchar(buf[top--]);
            return *this;
        }
        inline ostream &operator<<(char c) {
            putchar(c);
            return *this;
        }
        inline ostream &operator<<(const char *s) {
            for(int i = 0; s[i]; i++) putchar(s[i]);
            return *this;
        }
    } cout;
}

using io::cin;
using io::cout;
using io::endl;

struct Edge {
    int v, w, next;
};

int n, ans;
int c[N], d1[N];

namespace T3 {

    Edge pool[2 * N];
    int ne, head[N];
    void addEdge(int u, int v, int w) {
        pool[++ne] = {v, w, head[u]};
        head[u] = ne;
    }

    struct myPair {
        int x1, y1, x2, y2;
        int a1, a2;
        inline myPair() : x1(0), y1(0), x2(0), y2(0), a1(0), a2(0) {}
        inline myPair(int _p1, int _p2, int _p3, int _p4, int _p5, int _p6) : x1(_p1), y1(_p2), x2(_p3), y2(_p4), a1(_p5), a2(_p6) {}
    };

    int dfn[N], dt, dis[N];
    int mn[LOGN][N], lg[N], dep[N];

    void dfs(int u, int fa) {
        dfn[u] = ++dt;
        mn[0][dfn[u]] = fa;
        dep[u] = dep[fa] + 1;
        for(int e = head[u]; e; e = pool[e].next) {
            int v = pool[e].v, w = pool[e].w;
            if(v == fa) continue;
            dis[v] = dis[u] + w;
            dfs(v, u);
        }
    }

    inline int getLCA(int x, int y) {
        if(x == y) return x;
        x = dfn[x], y = dfn[y];
        if(x > y) swap(x, y);
        ++x;
        int d = lg[y - x + 1];
        return dep[mn[d][x]] < dep[mn[d][y - (1 << d) + 1]] ? mn[d][x] : mn[d][y - (1 << d) + 1];
    }

    inline int getDis(int x, int y) {
        return dis[x] + dis[y] - (dis[getLCA(x, y)] << 1);
    }

    void init() {
        dfs(1, 0);
        for(int i = 2; i < N; i++) lg[i] = lg[i >> 1] + 1;
        for(int k = 1; k < LOGN; k++) {
            for(int i = 1; i + (1 << k) - 1 <= n; i++) {
                int t1 = mn[k - 1][i], t2 = mn[k - 1][i + (1 << (k - 1))];
                mn[k][i] = (dep[t1] < dep[t2]) ? t1 : t2;
            }
        }
    }

}

namespace T2 {

    namespace T2_0 {

        Edge pool[2 * N];
        int ne, head[N];
        void addEdge(int u, int v, int w) {
            pool[++ne] = {v, w, head[u]};
            head[u] = ne;
        }

    }

    using T3::myPair;

    int dfn[N], dis[N], dt;
    int mn[LOGN][N], lg[N], dep[N];

    namespace VT {

        Edge pool[2 * N];
        int ne, head[N];
        void addEdge(int u, int v) {
            pool[++ne] = {v, abs(dis[u] - dis[v]), head[u]};
            head[u] = ne;
        }

        void clear_Edge(int k) {
            for(int i = 0; i <= k + 5; i++) head[i] = 0;
            ne = 0;
        }

    }

    void dfs(int u, int fa) {
        using namespace T2_0;
        assert(!dfn[u]);
        assert(dt <= n);
        dfn[u] = ++dt;
        mn[0][dfn[u]] = fa;
        dep[u] = dep[fa] + 1;
        for(int e = head[u]; e; e = pool[e].next) {
            int v = pool[e].v, w = pool[e].w;
            if(v == fa) continue;
            dis[v] = dis[u] + w;
            dfs(v, u);
        }
    }

    int getLCA(int x, int y) {
        if(x == y) return x;
        x = dfn[x], y = dfn[y];
        if(x > y) swap(x, y);
        ++x;
        int d = lg[y - x + 1];
        return dep[mn[d][x]] < dep[mn[d][y - (1 << d) + 1]] ? mn[d][x] : mn[d][y - (1 << d) + 1];
    }

    void init() {
        dfs(1, 0);
        for(int i = 2; i < N; i++) lg[i] = lg[i >> 1] + 1;
        for(int k = 1; k < LOGN; k++) {
            for(int i = 1; i + (1 << k) - 1 <= n; i++) {
                int t1 = mn[k - 1][i], t2 = mn[k - 1][i + (1 << (k - 1))];
                mn[k][i] = dep[t1] < dep[t2] ? t1 : t2;
            }
        }
    }

    void build_VT(vector<int> &pt) {
        using namespace VT;
        static int sta[N], top;
        static int buf[N], top2;

        while(top2) { head[buf[top2--]] = 0; }
        ne = 0;

        sort(pt.begin(), pt.end(), [](int a, int b) { return dfn[a] < dfn[b]; } );
        top = 0;
        sta[++top] = 1;
        buf[++top2] = 1;
        for(int &i : pt) {
            if(i == 1) continue;
            int lca = getLCA(sta[top], i);
            if(lca == sta[top]) { sta[++top] = i; buf[++top2] = i; continue; }
            while(dfn[sta[top - 1]] > dfn[lca]) {
                addEdge(sta[top - 1], sta[top]);
                --top;
            }
            if(sta[top] != lca) {
                addEdge(lca, sta[top]);
            }
            --top;
            if(sta[top] != lca) {
                sta[++top] = lca;
                buf[++top2] = lca;
            }
            sta[++top] = i;
            buf[++top2] = i;
        }
        while(top >= 2) addEdge(sta[top - 1], sta[top]), --top;
    }

    // void test() {
    //     int k;
    //     vector<int> pt;
    //     cin >> k;
    //     for(int i = 1; i <= k; i++) {
    //         int x;
    //         cin >> x;
    //         pt.push_back(x);
    //     }
    //     build_VT(pt);
    // }

    inline int calc(int a, int b) {
        if(!a || !b) return 0;
        return d1[a] + d1[b] + dis[a] + dis[b] + T3::getDis(a, b);
    }

    inline int calc(const myPair &a, const myPair &b) {
        int res = 0;
        res = max(res, calc(a.x1, b.x2));
        res = max(res, calc(a.x1, b.y2));
        res = max(res, calc(a.y1, b.x2));
        res = max(res, calc(a.y1, b.y2));
        res = max(res, calc(a.x2, b.x1));
        res = max(res, calc(a.x2, b.y1));
        res = max(res, calc(a.y2, b.x1));
        res = max(res, calc(a.y2, b.y1));
        return res;
    }

    inline myPair merge(const myPair &a, const myPair &b) {
        myPair res = a;
        int tmp = 0;
        if(res.a1 < (tmp = calc(b.x1, b.y1))) res.x1 = b.x1, res.y1 = b.y1, res.a1 = tmp;
        if(res.a2 < (tmp = calc(b.x2, b.y2))) res.x2 = b.x2, res.y2 = b.y2, res.a2 = tmp;

        if(res.a1 < (tmp = calc(a.x1, b.x1))) res.x1 = a.x1, res.y1 = b.x1, res.a1 = tmp;
        if(res.a1 < (tmp = calc(a.x1, b.y1))) res.x1 = a.x1, res.y1 = b.y1, res.a1 = tmp;
        if(res.a1 < (tmp = calc(a.y1, b.x1))) res.x1 = a.y1, res.y1 = b.x1, res.a1 = tmp;
        if(res.a1 < (tmp = calc(a.y1, b.y1))) res.x1 = a.y1, res.y1 = b.y1, res.a1 = tmp;

        if(res.a2 < (tmp = calc(a.x2, b.x2))) res.x2 = a.x2, res.y2 = b.x2, res.a2 = tmp;
        if(res.a2 < (tmp = calc(a.x2, b.y2))) res.x2 = a.x2, res.y2 = b.y2, res.a2 = tmp;
        if(res.a2 < (tmp = calc(a.y2, b.x2))) res.x2 = a.y2, res.y2 = b.x2, res.a2 = tmp;
        if(res.a2 < (tmp = calc(a.y2, b.y2))) res.x2 = a.y2, res.y2 = b.y2, res.a2 = tmp;
        return res;
    }

    myPair solve(int u, int fa) {
        using namespace VT;
        myPair res;
        if(c[u] == 1) res = (myPair){u, u, 0, 0, 0, 0};
        if(c[u] == 2) res = (myPair){0, 0, u, u, 0, 0};
        for(int e = head[u]; e; e = pool[e].next) {
            int v = pool[e].v;
            myPair tmp = solve(v, u);
            ans = max(ans, calc(res, tmp) - 2 * dis[u]);
            res = merge(res, tmp);
        }
        return res;
    }

    void work() {
        solve(1, 0);
    }

}

namespace T1_1 {

    Edge pool[4 * N];
    int ne = 1, head[2 * N];
    void addEdge(int u, int v, int w) {
        pool[++ne] = {v, w, head[u]};
        head[u] = ne;
    }

    int sz[2 * N];
    int vis[4 * N];

    void get_sz(int u, int fa) {
        sz[u] = 1;
        for(int e = head[u]; e; e = pool[e].next) {
            int v = pool[e].v;
            if(v == fa || vis[e]) continue;
            get_sz(v, u);
            sz[u] += sz[v];
        }
    }

    void get_rt(int u, int fa, int tot, int &rt, int &re) {
        for(int e = head[u]; e; e = pool[e].next) {
            int v = pool[e].v;
            if(v == fa || vis[e]) continue;
            get_rt(v, u, tot, rt, re);
            if(rt == v) re = e;
        }
        if(max(sz[u], tot - sz[u]) < max(sz[rt], tot - sz[rt])) rt = u;
    }

    void get_dis(int u, int fa, int col, int dis, vector<int> &pt) {
        if(u <= n) {
            d1[u] = dis;
            c[u] = col;
            pt.push_back(u);
        }
        for(int e = head[u]; e; e = pool[e].next) {
            int v = pool[e].v, w = pool[e].w;
            if(v == fa || vis[e]) continue;
            get_dis(v, u, col, dis + w, pt);
        }
    }

    void clear(int u, int fa) {
        if(u <= n) d1[u] = c[u] = 0;
        for(int e = head[u]; e; e = pool[e].next) {
            int v = pool[e].v;
            if(v == fa || vis[e]) continue;
            clear(v, u);
        }
    }

    void calc(int u, int v, int w) {
        static vector<int> pt;

        get_dis(u, 0, 1, 0, pt);
        get_dis(v, 0, 2, w, pt);

        T2::build_VT(pt);
        T2::work();

        clear(u, 0);
        clear(v, 0);
        pt.clear();
    }

    void solve(int p) {
        int rt = 0, re;
        get_sz(p, 0);
        get_rt(p, 0, sz[p], rt, re);
        assert(!vis[re]);
        vis[re] = vis[re ^ 1] = 1;
        int u = pool[re ^ 1].v, v = pool[re].v, w = pool[re].w;
        calc(u, v, w);
        for(int e = head[u]; e; e = pool[e].next) {
            if(vis[e]) continue;
            int vv = pool[e].v;
            solve(vv);
        }
        for(int e = head[v]; e; e = pool[e].next) {
            if(vis[e]) continue;
            int vv = pool[e].v;
            solve(vv);
        }
    }

}

namespace T1_0 {

    Edge pool[2 * N];
    int ne, head[N];
    void addEdge(int u, int v, int w) {
        pool[++ne] = {v, w, head[u]};
        head[u] = ne;
    }

    int nn;

    void build(int u, int fa) {
        int last = u;
        for(int e = head[u]; e; e = pool[e].next) {
            int v = pool[e].v, w = pool[e].w;
            if(v == fa) continue;
            T1_1::addEdge(last, v, w);
            T1_1::addEdge(v, last, w);
            T1_1::addEdge(last, ++nn, 0);
            T1_1::addEdge(nn, last, 0);
            last = nn;
        }
        for(int e = head[u]; e; e = pool[e].next) {
            int v = pool[e].v;
            if(v == fa) continue;
            build(v, u);
        }
    }

}

signed main() {

    cin >> n;
    for(int i = 1; i <= n - 1; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        T1_0::addEdge(u, v, w);
        T1_0::addEdge(v, u, w);
    }
    for(int i = 1; i <= n - 1; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        T2::T2_0::addEdge(u, v, w);
        T2::T2_0::addEdge(v, u, w);
    }
    for(int i = 1; i <= n - 1; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        T3::addEdge(u, v, w);
        T3::addEdge(v, u, w);
    }

    T1_0::nn = n;
    T1_0::build(1, 0); // 三度化

    T2::init(); // 求出 dfn, dis
    T3::init(); // 求出 dfn, dis

    T1_1::solve(1);

    cout << ans << '\n';

    return 0;
}