跳转至

点分树(动态点分治)

点分树是指从上一级分治中心指向下一级分治中心的(虚)边组成的树。

点分树示意图

性质:对于任意两点 \(x,y\),它们在点分树上的 \(\operatorname{lca}\) 存在于它们在原树的最短路径上;

由于点分树也涵盖了所有节点,并且深度比较小(最多 \(\log n\) 层),因此我们可以枚举操作节点 \(x\) 在点分树上的祖先 \(t\),然后统计以 \(t\) 为点分树上 \(\operatorname{lca}\) 产生的贡献。

我们可以在分治中心处维护一个数据结构,维护对应连通块内的节点信息,来实现动态的修改查询。

P6329 【模板】点分树 | 震波

题意

给定一棵 \(n\) 个节点的树,点有点权 \(a_i\),你需要支持 \(q\) 次操作,每次操作是如下两种中的一种:

  • 给定 \(x,d\),求 \(\sum_{dis(x,y)\le d}a_y\)
  • 给定 \(x,y\),执行 \(a_x\gets y\)

强制在线,\(n,q\le 10^5,\ 0\le d\le n-1\)

考虑点分树,对每个分治中心,我们用数据结构维护连通块内部所有节点到它的距离信息。对于查询操作,假设查询 \((x,d)\),我们在点分树上从 \(x\) 开始往上跳,对于 \(x\) 在点分树上的祖先 \(t\),我们查询 \(t\) 对应的连通块内到 \(t\) 距离不超过 \(d-dis(x,t)\) 的点的点权和。

然而这样查询会有一个问题,在处理 \(t\) 连通块时,\(t\)\(x\) 方向的(原树)子树内的节点的距离计算有误,因为它们与 \(x\) 的(点分树)\(\operatorname{lca}\) 不是 \(t\);并且它们的贡献已经在 \(t\) 的下一级位置统计过了;因此在 \(t\) 处需要去掉它们的贡献。

因此,我们再用另一个相同的数据结构维护出每个节点的连通块向它(点分树)父亲的贡献。不难发现,一个节点 \(u\) 对应的第一个数据结构(整个连通块的信息)等于它的所有(点分树)儿子的第二个数据结构信息的并,再加上 \(a[u]\)

这样,每次在祖先 \(t\) 处查询时,我们再额外在第二个数据结构上查询 \(t\)\(x\) 方向的(点分树)儿子对应的连通块中,有多少个到 \(t\) 距离不超过 \(d-dis(x,t)\) 的节点。

上面提到的信息使用树状数组维护即可,用 vector 动态开空间,空间复杂度 \(O(n\log n)\),时间复杂度 \(O(n\log^2 n)\)

核心代码
int query(int p, int k) {
    int bg = p;
    int res = 0, pre, dis;
    res += BIT_all[p].query(0, k); // 本连通块内的贡献也要统计
    pre = p; p = fa[p];
    while(p) {
        dis = lca::dis(bg, p);
        if(dis <= k) {
            res += BIT_all[p].query(0, k - dis); // 这个是第一个数据结构
            res -= BIT_fa[pre].query(0, k - dis); // 这个是第二个数据结构
        }
        pre = p; p = fa[p];
    }
    return res;
}

void modify(int p, int v) {
    int bg = p, delta = v - a[p];
    a[p] = v;
    int pre, dis;
    BIT_all[p].add(0, delta); // 修改本连通块
    pre = p, p = fa[p];
    while(p) {
        dis = lca::dis(bg, p);
        BIT_all[p].add(dis, delta); // 修改第一个数据结构
        BIT_fa[pre].add(dis, delta); // 修改第二个数据结构,因为是儿子向父亲的贡献,所以要修改在下一级位置
        pre = p; p = fa[p];
    }
}
完整代码
#include<iostream>
#include<vector>
#include<cassert>
using namespace std;
const int N = 1e5 + 10;
const int LOGN = 17;
const int INF = 0x3f3f3f3f;

struct Edge {
    int v, next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v) {
    pool[++ne] = {v, head[u]};
    head[u] = ne;
}

int n, m;
int a[N];

namespace lca {
    const int LOGN = 18;
    int dep[N];
    int lg[N];
    int mn[LOGN][N];
    int dfn[N], dt;
    void dfs(int u, int fa) {
        dfn[u] = ++dt;
        mn[0][dfn[u]] = fa;
        dep[u] = dep[fa] + 1;
        for(int e = head[u]; e; e = pool[e].next) {
            int v = pool[e].v;
            if(v == fa) continue;
            dfs(v, u);
        }
    }
    void init() {
        dfs(1, 0);
        for(int i = 2; i <= n; i++) lg[i] = lg[i >> 1] + 1;
        for(int k = 1; k < LOGN; k++) {
            for(int i = 1; i + (1 << k) - 1 <= n; i++) {
                int t1 = mn[k - 1][i], t2 = mn[k - 1][i + (1 << (k - 1))];
                mn[k][i] = dep[t1] < dep[t2] ? t1 : t2;
            }
        }
    }
    inline int dis(int x, int y) {
        if(x == y) return 0;
        int px = dfn[x], py = dfn[y];
        if(px > py) swap(px, py);
        ++px;
        int d = lg[py - px + 1];
        int lca = dep[mn[d][px]] < dep[mn[d][py - (1 << d) + 1]] ? mn[d][px] : mn[d][py - (1 << d) + 1];
        return dep[x] + dep[y] - 2 * dep[lca];
    }
}

int vis[N];
int sz[N], mxp[N] = {INF};
int fa[N];
int dep[N];

struct BIT_Tp {
    int n;
    vector<int> s;
    inline int lowbit(int x) { return x & -x; }
    inline void init(int _n) {
        n = _n;
        s.resize(n + 3);
    }
    inline void add(int p, int v) {
        for(int i = p + 1; i <= n + 1; i += lowbit(i)) s[i] += v;
    }
    inline int query(int p) {
        if(p < 0) return 0;
        if(p > n) p = n;
        int res = 0;
        for(int i = p + 1; i > 0; i -= lowbit(i)) res += s[i];
        return res;
    }
    inline int query(int l, int r) {
        return query(r) - query(l - 1);
    }
} BIT_all[N], BIT_fa[N];

void get_sz(int u, int fa) {
    sz[u] = 1;
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(v == fa || vis[v]) continue;
        get_sz(v, u);
        sz[u] += sz[v];
    }
}

void get_rt(int u, int fa, int tot, int &rt) {
    mxp[u] = 0;
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(v == fa || vis[v]) continue;
        get_rt(v, u, tot, rt);
        mxp[u] = max(mxp[u], sz[v]);
    }
    mxp[u] = max(mxp[u], tot - sz[u]);
    if(mxp[u] < mxp[rt]) rt = u;
}

void get_dis(int u, int fa, const int &rt, const int &fr) {
    dep[u] = dep[fa] + 1;
    BIT_all[fr].add(dep[u], a[u]);
    BIT_fa[rt].add(dep[u], a[u]);
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(v == fa || vis[v]) continue;
        get_dis(v, u, rt, fr);
    }
}

void solve(int u) {
    vis[u] = 1;
    dep[u] = 0;
    BIT_all[u].init(sz[u]);
    BIT_all[u].add(0, a[u]);
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(vis[v]) continue;
        int rt = 0;
        get_sz(v, u);
        get_rt(v, u, sz[v], rt);
        BIT_fa[rt].init(sz[v]);
        fa[rt] = u;
        get_dis(v, u, rt, u);
        sz[rt] = sz[v];
        solve(rt);
    }
}

int query(int p, int k) {
    int bg = p;
    int res = 0, pre, dis;
    res += BIT_all[p].query(0, k);
    pre = p; p = fa[p];
    while(p) {
        dis = lca::dis(bg, p);
        if(dis <= k) {
            res += BIT_all[p].query(0, k - dis);
            res -= BIT_fa[pre].query(0, k - dis);
        }
        pre = p; p = fa[p];
    }
    return res;
}

void modify(int p, int v) {
    int bg = p, delta = v - a[p];
    a[p] = v;
    int pre, dis;
    BIT_all[p].add(0, delta);
    pre = p, p = fa[p];
    while(p) {
        dis = lca::dis(bg, p);
        BIT_all[p].add(dis, delta);
        BIT_fa[pre].add(dis, delta);
        pre = p; p = fa[p];
    }
}

#define FIO

int main() {

    #ifdef FIO
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    #endif

    cin >> n >> m;
    for(int i = 1; i <= n; i++) cin >> a[i];
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        addEdge(u, v);
        addEdge(v, u);
    }

    lca::init();

    int rt = 0;
    get_sz(3, 0);
    get_rt(3, 0, sz[3], rt);
    sz[rt] = sz[3];
    solve(rt);

    int lastans = 0;
    while(m--) {
        int op, x, y;
        cin >> op >> x >> y;
        x ^= lastans;
        y ^= lastans;
        if(x > n) throw;
        if(op == 0) {
            if(y > n) throw;
            cout << (lastans = query(x, y)) << '\n';
        } else {
            if(y > 10000) throw;
            modify(x, y);
        }
    }

    return 0;
}

P2056 [ZJOI2007] 捉迷藏

题意

给定一棵 \(n\) 个节点的无根树,每个节点上都有一盏灯,灯只有点亮熄灭两种状态。初始时所有灯都是熄灭的。

\(q\) 次操作,每次操作是如下两种中的一种:

  • 给定 \(x\),反转节点 \(x\) 处灯的状态;
  • 查询最远的两个熄灯的节点之间的距离,若没有输出 \(-1\),只有一个输出 \(0\)

\(n\le 10^5,\ q\le 5\times 10^5\)

考虑点分树,如何统计贡献。对于一个节点 \(t\),我们枚举它的两个子树,并分别从两个子树内选出两个距离最远的节点。那么我们如何去除重复子树的贡献呢?注意到由于我们只需要最大值,因此无需把整个连通块内的所有节点都放到分治中心 \(t\) 的数据结构里。我们只需找到每一个子树连通块内到 \(t\) 最远的关灯节点,将它放到分治中心 \(t\) 的数据结构里。这样,每棵子树最多贡献一次。

我们可以使用堆来存储每棵子树中最远距离,查询时只需要取出堆中的最大值和次大值相加即可。

为了维护出每棵子树(连通块)内部到上一层分治中心 \(t\) 最远的熄灯节点,我们给每个节点再开一个堆,维护所有熄灯节点到 \(t\) 的距离;每棵子树都会把堆顶元素贡献给分治中心 \(t\) 的第一个堆。

为了维护答案,我们再在全局开一个堆,将每个节点作为 \(\operatorname{lca}\) 时的最优答案装进这个堆里,查询时取这个堆的堆顶即可。

因为带修,所以使用可删堆。

代码
#include<iostream>
#include<queue>
#include<cassert>
using namespace std;
const int N = 1e5 + 10;
const int LOGN = 18;
const int INF = 0x3f3f3f3f;

struct Edge {
    int v, next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v) {
    pool[++ne] = {v, head[u]};
    head[u] = ne;
}

class myHeap {
    private:
        priority_queue<int> h, d;
        inline void update() {
            while(!d.empty() && h.top() == d.top()) h.pop(), d.pop();
        }
    public:
        inline myHeap() { h.push(-1); }
        inline int queryMax() {
            update();
            return h.top();
        }
        inline int queryAns() {
            if(size() < 2) return -1;
            update();
            int tmp = h.top();
            if(!~tmp) return -1;
            h.pop();
            update();
            int res = h.top();
            h.push(tmp);
            if(!~res) return -1;
            return res + tmp;
        }
        inline void insert(int x) { h.push(x); }
        inline void erase(int x) { d.push(x); }
        inline void pop() { update(); h.pop(); }
        inline size_t size() { return h.size() - d.size() - 1; }
};

int n, q;
myHeap ans;

namespace LCA {
    int dep[N], dfn[N], dt;
    int mn[LOGN][N], lg[N];
    void dfs(int u, int fa) {
        dep[u] = dep[fa] + 1;
        dfn[u] = ++dt;
        mn[0][dfn[u]] = fa;
        for(int e = head[u]; e; e = pool[e].next) {
            int v = pool[e].v;
            if(v == fa) continue;
            dfs(v, u);
        }
    }
    void init() {
        for(int i = 2; i <= n; i++) lg[i] = lg[i >> 1] + 1;
        dfs(1, 0);
        for(int k = 1; k < LOGN; k++) {
            for(int i = 1; i + (1 << k) - 1 <= n; i++) {
                int t1 = mn[k - 1][i], t2 = mn[k - 1][i + (1 << (k - 1))];
                mn[k][i] = dep[t1] < dep[t2] ? t1 : t2;
            }
        }
    }
    inline int getLCA(int x, int y) {
        if(x == y) return x;
        x = dfn[x], y = dfn[y];
        if(x > y) swap(x, y); ++x;
        int d = lg[y - x + 1];
        return dep[mn[d][x]] < dep[mn[d][y - (1 << d) + 1]] ? mn[d][x] : mn[d][y - (1 << d) + 1];
    }
    inline int getDis(int x, int y) {
        return dep[x] + dep[y] - 2 * dep[getLCA(x, y)];
    }
}

int sta[N];
int ss;

namespace starch {

    int vis[N], sz[N], mxp[N] = {INF};
    int anc[N];

    myHeap d[N];  // 每棵子树对应一个元素
    myHeap df[N]; // 当前子树内部的所有 dis

    void get_sz(int u, int fa) {
        sz[u] = 1;
        for(int e = head[u]; e; e = pool[e].next) {
            int v = pool[e].v;
            if(v == fa || vis[v]) continue;
            get_sz(v, u);
            sz[u] += sz[v];
        }
    }

    void get_rt(int u, int fa, int tot, int &rt) {
        mxp[u] = 0;
        for(int e = head[u]; e; e = pool[e].next) {
            int v = pool[e].v;
            if(v == fa || vis[v]) continue;
            get_rt(v, u, tot, rt);
            mxp[u] = max(mxp[u], sz[v]);
        }
        mxp[u] = max(mxp[u], tot - sz[u]);
        if(mxp[u] < mxp[rt]) rt = u;
    }

    void build(int u) {
        vis[u] = 1;
        for(int e = head[u]; e; e = pool[e].next) {
            int v = pool[e].v;
            if(vis[v]) continue;
            int rt = 0;
            get_sz(v, u);
            get_rt(v, u, sz[v], rt);
            anc[rt] = u;
            build(rt);
        }
    }

    void erase(int p) {
        {
            int preans = d[p].queryAns();
            d[p].erase(0);
            int cans = d[p].queryAns();
            if(cans != preans) {
                ans.erase(preans);
                ans.insert(cans);
            }
        }
        int cur = p;
        while(anc[cur]) {
            int fa = anc[cur];
            int premx = df[cur].queryMax();
            df[cur].erase(LCA::getDis(p, fa));
            int cmx = df[cur].queryMax();
            if(premx != cmx) {
                int preans = d[fa].queryAns();
                d[fa].erase(premx);
                d[fa].insert(cmx);
                int cans = d[fa].queryAns();
                if(cans != preans) {
                    ans.erase(preans);
                    ans.insert(cans);
                }
            }
            cur = anc[cur];
        }
    }

    void insert(int p) {
        int cur = p;
        {
            int preans = d[p].queryAns();
            d[p].insert(0);
            int cans = d[p].queryAns();
            if(cans != preans) {
                ans.erase(preans);
                ans.insert(cans);
            }
        }
        while(anc[cur]) {
            int fa = anc[cur];
            int premx = df[cur].queryMax();
            df[cur].insert(LCA::getDis(p, fa));
            int cmx = df[cur].queryMax();
            if(premx != cmx) {
                int preans = d[fa].queryAns();
                d[fa].erase(premx);
                d[fa].insert(cmx);
                int cans = d[fa].queryAns();
                if(cans != preans) {
                    ans.erase(preans);
                    ans.insert(cans);
                }
            }
            cur = anc[cur];
        }
    }

    void build() {
        int rt = 0;
        get_sz(1, 0);
        get_rt(1, 0, sz[1], rt);
        build(rt);
        for(int i = 1; i <= n; i++) {
            if(!anc[i]) continue;
            d[anc[i]].insert(-1);
        }
        for(int i = 1; i <= n; i++) {
            assert(d[i].queryAns() == -1);
            ans.insert(-1);
        }
        for(int i = 1; i <= n; i++) sta[i] = 1, ++ss;
        for(int i = 1; i <= n; i++) insert(i);
    }

}

void flip(int p) {
    if(sta[p]) {
        starch::erase(p);
        sta[p] = 0;
        --ss;
    } else {
        starch::insert(p);
        sta[p] = 1;
        ++ss;
    }
}

using namespace starch;

int main() {

    cin >> n;
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        addEdge(u, v);
        addEdge(v, u);
    }

    LCA::init();
    starch::build();

    cin >> q;
    while(q--) {
        char c;
        int x;
        cin >> c;
        if(c == 'C') {
            cin >> x;
            flip(x);
        } else {
            if(ss == 0) cout << "-1\n";
            else if(ss == 1) cout << "0\n";
            else cout << ans.queryMax() << '\n';
        }
    }

    return 0;
}

P3920 [WC2014] 紫荆花之恋