跳转至

树上莫队

树上莫队是一种将莫队算法应用于树结构的方法,主要用于解决静态树上的路径查询问题。通过将树展开成括号序,然后将区间内出现两次的节点抵消,从而实现路径 \(\rightarrow\) 区间的转化。

例题

P4074 [WC2013] 糖果公园

题意

给定一棵 \(n\) 个节点的树,点有颜色 \(c_i\)。定义一条路径 \(x\rightarrow y\) 的权值为 \(\sum_{i\in V}{a[i]b[cnt[i]]}\)。你需要实现两种操作:

  • 修改一个点的颜色;
  • 查询一条路径的权值;

\(n,m\le 10^5\)

树的括号序由如下方法得到:

  • 递归到一个节点 \(u\),将 \(u\) 加入括号序;
  • 递归处理 \(u\) 的所有儿子;
  • 结束时,再将 \(u\) 加入括号序;

记括号序为 \(bfn\)

容易发现,对于路径 \(x\rightarrow y\),其一定在括号序上对应一个区间 \([l,r]\),满足 \(x\rightarrow y\) 上的所有节点都在 \([l,r]\) 中出现仅 \(1\) 次(暂不考虑 \(lca\)),不在路径上的节点都出现 \(0\) 次或 \(2\) 次。

显然,\(bfn[l]\)\(bfn[r]\) 分别对应 \(x\)\(y\)

我们可以使用一个桶数组记录每个节点是否出现过。如果发现是第二次加入,则将其删除;否则正常加入。

注意:如果 \(x\)\(y\) 不具有祖先关系,则 \(lca\) 不会出现在 \([l,r]\) 中。此时需要特判加入 \(lca\) 的贡献;

if(!alr[q[i].fa]) ans[q[i].id] += b[cnt[c[q[i].fa]] + 1] * a[c[q[i].fa]];

注意:因为 \(x\)\(y\) 都在括号序中出现了 \(2\) 次,如果不当的选择 \(l\)\(r\),会将 \(x\)\(y\) 排除在外;

int tmp = lca::get(x, y);
if(pos1[x] > pos1[y]) swap(x, y);
if(pos2[x] < pos1[y]) x = pos2[x];
else x = pos1[x];
y = pos1[y];
if(x > y) swap(x, y);
++nq;
q[nq] = {x, y, tmp, nq, tt};

其它都是带修莫队板子。

代码
#include<iostream>
#include<algorithm>
#include<cmath>
#define ll long long
#define ld long double
using namespace std;
const int N = 1e5 + 10;
const int LOGN = 20;

struct Edge {
    int v, next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v) {
    pool[++ne] = {v, head[u]};
    head[u] = ne;
}

namespace lca {
    int anc[N][LOGN], dep[N];
    void init(int u, int fa) {
        anc[u][0] = fa;
        dep[u] = dep[fa] + 1;
        for(int i = 1; i < LOGN; i++) {
            anc[u][i] = anc[ anc[u][i - 1] ][i - 1];
        }
        for(int i = head[u]; i; i = pool[i].next) {
            int v = pool[i].v;
            if(v == fa) continue;
            init(v, u);
        }
    }
    int get(int x, int y) {
        if(dep[x] < dep[y]) swap(x, y);
        for(int i = LOGN - 1; i >= 0; i--) {
            if(dep[anc[x][i]] >= dep[y]) x = anc[x][i];
        }
        if(x == y) return x;
        for(int i = LOGN - 1; i >= 0; i--) {
            if(anc[x][i] != anc[y][i]) {
                x = anc[x][i];
                y = anc[y][i];
            }
        }
        return anc[x][0];
    }
}

struct Op {
    int p, v1, v2;
} p[N];
int tt;

int blo[2 * N], blen;
struct Query {
    int l, r, fa, id, t;
    inline bool operator<(const Query& other) const {
        if(blo[l] != blo[other.l]) return (blo[l] < blo[other.l]);
        if(blo[r] != blo[other.r]) return (blo[l] & 1) ? (blo[r] < blo[other.r]) : (blo[r] > blo[other.r]);
        return ((blo[l] + blo[r]) & 1) ? (t < other.t) : (t > other.t);
    }
} q[N];
int nq;

int n, m, T;
ll a[N], b[N], ans[N];
int c[N], c1[N];

int fa[N], pos1[N], pos2[N];
int bfn[2 * N], nn;
void dfs(int u, int a0) {
    bfn[++nn] = u;
    pos1[u] = nn;
    fa[u] = a0;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == a0) continue;
        dfs(v, u);
    }
    bfn[++nn] = u;
    pos2[u] = nn;
}

int c_l, c_r, c_t;
ll c_ans;
ll cnt[N], alr[N];

void update_p(int pos) {
    pos = bfn[pos];
    alr[pos] ^= 1;
    int w = alr[pos] ? 1 : -1;
    if(w > 0) ++cnt[c[pos]];
    c_ans += w * b[cnt[c[pos]]] * a[c[pos]];
    if(w < 0) --cnt[c[pos]];
}

void update_t(int tt, int w) {
    int v1 = p[tt].v1, v2 = p[tt].v2, pos = p[tt].p;
    if(w < 0) swap(v1, v2);
    if(alr[pos]) {
        c_ans -= b[cnt[v1]] * a[v1];
        c_ans += b[cnt[v2] + 1] * a[v2];
        --cnt[v1];
        ++cnt[v2];
    }
    c[pos] = v2;
}

signed main() {

    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    cin >> n >> m >> T;
    blen = max(1ll, (ll)pow((ld)n * n * 2, (ld)1 / 3));
    for(int i = 1; i <= 2 * n; i++) blo[i] = (i + blen - 1) / blen;
    for(int i = 1; i <= m; i++) cin >> a[i];
    for(int i = 1; i <= n; i++) cin >> b[i];
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        addEdge(u, v);
        addEdge(v, u);
    }
    for(int i = 1; i <= n; i++) cin >> c[i];
    for(int i = 1; i <= n; i++) c1[i] = c[i];

    dfs(1, 0);
    lca::init(1, 0);

    for(int i = 1; i <= T; i++) {
        int op, x, y;
        cin >> op >> x >> y;
        if(op == 0) {
            if(c1[x] == y) continue;
            p[++tt] = {x, c1[x], y};
            c1[x] = y;
        } else {
            int tmp = lca::get(x, y);
            if(pos1[x] > pos1[y]) swap(x, y);
            if(pos2[x] < pos1[y]) x = pos2[x];
            else x = pos1[x];
            y = pos1[y];
            if(x > y) swap(x, y);
            ++nq;
            q[nq] = {x, y, tmp, nq, tt};
        }
    }

    sort(q + 1, q + 1 + nq);

    c_l = 1;
    c_r = 0;
    c_t = 0;
    for(int i = 1; i <= nq; i++) {
        int l = q[i].l, r = q[i].r, t = q[i].t;
        while(c_t < t) update_t(++c_t, 1);
        while(c_t > t) update_t(c_t--, -1);
        while(c_r < r) update_p(++c_r);
        while(c_l > l) update_p(--c_l);
        while(c_r > r) update_p(c_r--);
        while(c_l < l) update_p(c_l++);
        ans[q[i].id] = c_ans;
        if(!alr[q[i].fa]) ans[q[i].id] += b[cnt[c[q[i].fa]] + 1] * a[c[q[i].fa]];
    }

    for(int i = 1; i <= nq; i++) {
        cout << ans[i] << '\n';
    }

    return 0;
}