跳转至

虚树

有时,题目会在初始时给定一棵树,然后多次询问,每次询问给定 \(k\) 个关键点 \(p_1,p_2\cdots p_k\),要求回答相关信息,同时保证 \(\sum_{k}\)。我们显然不能 \(O(n)\) 的回答每次询问。因此,我们可以使用虚树来将单次询问的时间降低到 \(O(k)\)

虚树通过只保留树上的关键点以及维持树形态的必要节点(总共 \(O(k)\) 个)来将实际的计算量减小到只和 \(k\) 有关。

建立

我们先给出结论:当且仅当一个节点存在三棵不同的包含关键点的子树,它应当被保留在虚树中。这样,虚树中的一条边就代表了原树中的一条链和挂在链上的所有子树;同时要注意虚树中的一个点还包含了被省略掉的子树。

为了更方便建立虚树,我们放宽上面的条件,加入所有 \(\operatorname{lca}(p_i,p_j)\)。不难发现放宽条件后最多比放宽前多一个节点。

到这里,我们已经能写出一种方法建出一种虚树:

  • 将所有关键点按 dfn 序排序;
  • 单调栈清空,初始加入节点 \(1\)(如果没有);
  • 依次将关键点加入单调栈,设当前节点为 \(u\),如果栈顶节点是 \(u\) 的祖先,那么直接将 \(u\) 入栈;
    • 否则不断弹出单调栈中的元素,直到满足栈顶第二个元素是 \(u\) 的祖先;
    • 将栈顶和 \(u\)\(lca\) 记录下来,连边 $lca\to $ 栈顶,弹出栈顶,压入 \(lca\)
    • 压入 \(u\)
  • 结束后,弹出栈内的剩余元素;
  • 每次正常弹栈之前从栈顶第二个元素连一条边到栈顶;

然后在虚树上求解答案即可。

性质

  • 虚树的节点数量和关键点数量同阶;
  • 虚树上所有节点的相对祖先关系和原树相同;
  • 虚树上任意两个节点在原树上的 \(\operatorname{lca}\) 还在虚树上;

P4103 [HEOI2014] 大工程

题意

给定一棵 \(n\) 个点的树,有若干次询问,每次询问给定 \(k\) 个关键点 \(p_1,p_2,\cdots p_k\),表示询问 \((p_i,p_j)\) 二元组(\(i\ne j\))中 \(p_i,p_j\) 距离的总和,距离的最小值和最大值。

\(n\le 10^6,\ \sum k\le 2n\)

先建出虚树,然后在虚树上跑一遍 dfs,记录子树内离当前点最近、最远的关键点,和关键点的总数,然后在 lca 处统计贡献即可。

参考代码

本代码采用模拟调用栈的方法,不建议使用这种写法。

#include<bits/stl_algobase.h>
#include<ctype.h>
#include<cstdio>
#include<cassert>
#include<algorithm>
#define ll long long
using namespace std;
const int N = 1e6 + 10;
const int LOGN = 21;
const ll INF = 0x3f3f3f3f3f3f3f3f;

struct istream {
    char ch;
    template<typename _Tp>
    inline istream &operator>>(_Tp &x) {
        while(!isdigit(ch = getchar_unlocked()));
        x = ch - '0';
        while(isdigit(ch = getchar_unlocked())) x = x * 10 + ch - '0';
        return *this;
    }
} cin;

struct ostream {
    char buf[60], top;
    inline ostream() { top = 0; }
    inline ostream &operator<<(char c) {
        putchar_unlocked(c);
        return *this;
    }
    template<typename _Tp>
    inline ostream &operator<<(_Tp x) {
        do buf[++top] = x % 10, x /= 10; while(x);
        while(top) putchar_unlocked(buf[top--] + '0');
        return *this;
    }
} cout;

struct Edge {
    int v, next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v) {
    pool[++ne] = {v, head[u]};
    head[u] = ne;
}

int n, q, k;
int dfn[N], id[N], dt;
int dep[N], mn[LOGN][N], lg[N];

void dfs(int u, int fa) {
    dfn[u] = ++dt;
    id[dt] = u;
    dep[u] = dep[fa] + 1;
    mn[0][dfn[u]] = fa;
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(v == fa) continue;
        dfs(v, u);
    }
}

int getLCA(int x, int y) {
    if(x == y) return x;
    x = dfn[x], y = dfn[y];
    if(x > y) swap(x, y);
    ++x;
    int d = lg[y - x + 1];
    int t1 = mn[d][x], t2 = mn[d][y - (1 << d) + 1];
    return dep[t1] < dep[t2] ? t1 : t2;
}

int getDis(int x, int y) {
    return dep[x] + dep[y] - 2 * dep[getLCA(x, y)];
}

int p[2 * N], cnt;
int sta[N], top;
bool cmp_dfn(int a, int b) {
    return dfn[a] < dfn[b];
}

int imp[N];

ll ans1, ans2, ans3;
ll sz[N], mxd[N], mnd[N];

void pop_stack() {
    int u = sta[top - 1], v = sta[top];
    if(u) {
        int w = getDis(u, v);
        ans1 += (ll)w * sz[v] * (k - sz[v]);
        sz[u] += sz[v];
        ans2 = min(ans2, mnd[u] + w + mnd[v]);
        ans3 = max(ans3, mxd[u] + w + mxd[v]);
        mnd[u] = min(mnd[u], w + mnd[v]);
        mxd[u] = max(mxd[u], w + mxd[v]);
    }
    --top;
}

void push_stack(int u) {
    sz[u] = imp[u];
    mxd[u] = imp[u] ? 0 : -INF;
    mnd[u] = imp[u] ? 0 : INF;
    sta[++top] = u;
}

int main() {

    for(int i = 2; i < N; i++) lg[i] = lg[i >> 1] + 1;

    cin >> n;
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        addEdge(u, v);
        addEdge(v, u);
    }

    dfs(1, 0);

    for(int k = 1; k < LOGN; k++) {
        for(int i = 1; i + (1 << k) - 1 <= n; i++) {
            int t1 = mn[k - 1][i], t2 = mn[k - 1][i + (1 << (k - 1))];
            mn[k][i] = dep[t1] < dep[t2] ? t1 : t2;
        }
    }

    cin >> q;
    while(q--) {
        ans1 = 0, ans2 = INF, ans3 = 0, top = 0;
        cin >> k;
        for(int i = 1; i <= k; i++) cin >> p[i];
        for(int i = 1; i <= k; i++) imp[p[i]] = 1;
        sort(p + 1, p + 1 + k, cmp_dfn);
        cnt = k;
        for(int i = 1; i < k; i++) p[++cnt] = getLCA(p[i], p[i + 1]);
        sort(p + 1, p + 1 + cnt, cmp_dfn);
        cnt = unique(p + 1, p + 1 + cnt) - (p + 1);
        push_stack(p[1]);
        for(int i = 2; i <= cnt; i++) {
            int lca = getLCA(sta[top], p[i]);
            while(lca != sta[top]) {
                pop_stack();
            }
            push_stack(p[i]);
        }
        while(top) pop_stack();
        for(int i = 1; i <= cnt; i++) imp[p[i]] = 0;
        cout << ans1 << ' ' << ans2 << ' ' << ans3 << '\n';
    }

    return 0;
}

P3233 [HNOI2014] 世界树

题意

给定一棵 \(n\) 个点的树,有若干次询问,每次询问给定 \(k\) 个关键点 \(p_1,p_2,\cdots p_k\);对于树上的每个点,它都会被离它最近的关键点控制,若距离相同则被标号小者控制;问每个关键点控制多少个节点。

\(n\le 3\times 10^5,\ \sum k\le 3\times 10^5\)

先建出虚树 \(T'\)(边有边权),然后求出虚树上每个节点会被哪个关键点控制(记为 \(s_i\)),以及控制的距离是多少。

考察每一条边 \((u,v)\)(不妨设 \(u=fa[v]\)),\(u,v\) 被同一关键点控制的情况是平凡的;若不同,我们用倍增求出 \((u,v)\) 在原树上对应的链的分界点 \(t\)(即 \(t\) 是链上最深的被 \(s_v\) 控制的节点),记节点 \(s\) 是在链上 \(u\) 下面的一个节点,那么 \(s_u\) 会控制这条链上 \(sz[s]-sz[t]\) 个节点,\(s_v\) 会控制链上 \(sz[t]-sz[v]\) 个节点,累加到 \(s_u,s_v\) 的贡献上即可。

考察每一个节点 \(u\),除去链上的贡献,它还会对 \(s_u\) 产生 \(sz[u]-\sum_{v\in son[u]\wedge v\in T'}{sz[v]}\) 的贡献,将它们也累加到 ans 数组即可。

代码
#include<iostream>
#include<vector>
#include<algorithm>
using namespace std;
const int N = 3e5 + 10;
const int LOGN = 20;
const int INF = 0x3f3f3f3f;

struct Edge {
    int v, next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v) {
    pool[++ne] = {v, head[u]};
    head[u] = ne;
}

int n, T;
int k;

int p[N];
int anc[LOGN][N];

int mnl[LOGN][N], lg[N];

int dep[N], sz[N];
int dfn[N], dt;
void dfs1(int u, int fa) {
    dfn[u] = ++dt;
    dep[u] = dep[fa] + 1;
    anc[0][u] = fa;
    sz[u] = 1;
    mnl[0][dfn[u]] = fa;
    for(int i = 1; i < LOGN; i++) anc[i][u] = anc[i - 1][ anc[i - 1][u] ];
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(v == fa) continue;
        dfs1(v, u);
        sz[u] += sz[v];
    }
}

void init_st() {
    for(int i = 2; i <= n; i++) lg[i] = lg[i >> 1] + 1;
    for(int k = 1; k < LOGN; k++) {
        for(int i = 1; i + (1 << k) - 1 <= n; i++) {
            int t1 = mnl[k - 1][i], t2 = mnl[k - 1][i + (1 << (k - 1))];
            mnl[k][i] = dep[t1] < dep[t2] ? t1 : t2;
        }
    }
}

int getLCA(int x, int y) {
    if(x == y) return x;
    x = dfn[x], y = dfn[y];
    if(x > y) { swap(x, y); } ++x;
    int d = lg[y - x + 1];
    return dep[mnl[d][x]] < dep[mnl[d][y - (1 << d) + 1]] ? mnl[d][x] : mnl[d][y - (1 << d) + 1];
}

int get_son(int x, int y) {
    int x1 = y;
    for(int i = LOGN - 1; i >= 0; i--) {
        if(dfn[anc[i][x1]] > dfn[x]) x1 = anc[i][x1];
    }
    return x1;
}

inline bool cmp_dfn(int a, int b) { return dfn[a] < dfn[b]; }

int iskey[N];
int sta[N], top;

int ans[N];
vector<int> qr; // 询问的点
vector<int> kp; // 所有虚树点

namespace VT {

    struct Edge {
        int v, son, next;
    } pool[2 * N];
    int ne, head[N];

    void addEdge(int u, int v) {
        pool[++ne] = {v, get_son(u, v), head[u]};
        head[u] = ne;
    }

    struct myPair {
        int p, d;
        inline myPair operator+(int w) const {
            return {p, d + w};
        }
        inline bool operator<(const myPair &b) const {
            if(d != b.d) return d < b.d;
            return p < b.p;
        }
    } mn[N];

    int val[N];

    void dfs1(int u) {
        val[u] = sz[u];
        if(iskey[u]) {
            mn[u] = {u, 0};
        } else mn[u] = {0, INF};
        for(int e = head[u]; e; e = pool[e].next) {
            int v = pool[e].v, w = dep[v] - dep[u];
            val[u] -= sz[pool[e].son];
            dfs1(v);
            if(mn[v] + w < mn[u]) mn[u] = mn[v] + w;
        }
    }

    void calc(int x, int y, int s) {
        int p = y;
        for(int i = LOGN - 1; i >= 0; i--) {
            int p1 = anc[i][p];
            if(dfn[p1] <= dfn[x]) continue;
            int w1 = dep[p1] - dep[x];
            int w2 = dep[y] - dep[p1];
            if(mn[y] + w2 < mn[x] + w1) p = p1;
        }
        ans[mn[y].p] += sz[p] - sz[y];
        ans[mn[x].p] += sz[s] - sz[p];
    }

    void dfs2(int u) {
        ans[mn[u].p] += val[u];
        for(int e = head[u]; e; e = pool[e].next) {
            int v = pool[e].v, w = dep[v] - dep[u];
            if(mn[u] + w < mn[v]) mn[v] = mn[u] + w;
            calc(u, v, pool[e].son);
            dfs2(v);
        }
    }

    void work() {
        dfs1(1);
        dfs2(1);
    }

    void clear(int u) {
        for(int e = head[u]; e; e = pool[e].next) {
            int v = pool[e].v;
            clear(v);
        }
        head[u] = 0;
    }

    // 清边
    void clear() {
        clear(1);
        ne = 0;
    }

}

int main() {

    cin >> n;
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        addEdge(u, v);
        addEdge(v, u);
    }

    qr.reserve(n);
    kp.reserve(n);

    dfs1(1, 0);
    init_st();

    cin >> T;
    while(T--) {
        cin >> k;
        qr.clear(); kp.clear(); top = 0;
        for(int i = 1; i <= k; i++) {
            cin >> p[i];
            ans[p[i]] = 0;
            qr.push_back(p[i]);
            kp.push_back(p[i]);
        }
        sort(p + 1, p + 1 + k, cmp_dfn);
        sta[++top] = 1;
        for(int i = 1 + (p[1] == 1); i <= k; i++) {
            int lca = getLCA(p[i], sta[top]);
            if(lca == sta[top]) {
                sta[++top] = p[i];
                continue;
            }
            while(dfn[sta[top - 1]] > dfn[lca]) {
                VT::addEdge(sta[top - 1], sta[top]);
                --top;
            }
            if(sta[top - 1] == lca) {
                VT::addEdge(sta[top - 1], sta[top]);
                --top;
            } else {
                VT::addEdge(lca, sta[top]);
                --top;
                sta[++top] = lca;
                kp.push_back(lca);
            }
            sta[++top] = p[i];
        }
        while(top >= 2) {
            VT::addEdge(sta[top - 1], sta[top]);
            --top;
        }
        for(int i = 1; i <= k; i++) iskey[p[i]] = 1;
        VT::work();
        for(int i : qr) cout << ans[i] << ' ';
        cout << '\n';
        for(int i = 1; i <= k; i++) iskey[p[i]] = 0;
        VT::clear();
    }

    return 0;
}