跳转至

250217 模拟赛 T2 连通块计数 题解

题意

有一个树有 \(n\) 个节点,其 \(\operatorname{dfn}\) 序由以下规则求出:

  1. 设置时间戳 \(\operatorname T = 0\),将树转为以 \(1\) 为根的有根树。并将每个节点的儿子按照节点标号从小到大排序。
  2. 访问节点 \(1\)
  3. 当前在访问节点 \(u\),将 \(\operatorname T\leftarrow \operatorname T+1\),然后将 \(dfn[u]\) 设置为 \(\operatorname T\)
  4. 按照节点编号顺序访问 \(u\) 的儿子,重复操作 \(3\)

求出 \(\operatorname{dfn}\) 序之后,有 \(q\) 个询问形如 \(k,l_1,r_1,l_2,r_2\cdots\),记点集 \(S=\{x\mid \exists\ i,dfn[x]\in [l_i,r_i]\}\),也就是 \(\operatorname{dfn}\) 序在任意一个区间里的节点。

保证 \(l_i,r_i\) 不交,且端点递增。

你需要求出 \(S\) 点集分为多少个连通块,输出连通块数 \(+1\)

部分测试点强制在线。

\(n,m,\sum k\le 5\times 10^5\)

题解

考虑经典 trick:\(nxt\) 数组:\(nxt[u]=dfn[u]+sz[u]\),即 \(\operatorname{dfn}\) 序向后第一个不在 \(u\) 子树内的节点。

容易发现节点 \(u\)\(\operatorname{dfn}\) 序上向后到达的第一个不连通的节点就是 \(nxt[u]\)。每在 \(nxt\) 上跳一次就会额外产生一个连通块。我们关注 \(u\)\(nxt\) 上跳了多少次(倍增解决),以及每次的连通块有没有和之前的区间连边。

通过画图发现,一段连续的 dfn 区间是一些挂在左链右边的连续子树。只有左链上的节点才会和当前区间连边。左链从上到下 dfn 序单增,右边挂的子树从上往下 dfn 序单减。我们依次处理每个区间 \([l_i,r_i]\),可以不断记录落在当前左链上的所有区间 \([l_j,r_j]\),使用倍增判断出每一个区间都覆盖了多少棵 \([l_i,r_i]\) 挂在下面的子树,就可以知道它们使连通块数量减少了多少。

那么如何记录这些可能出现在左链上的区间呢?注意到区间对应的左链是从左向右转动的。如果当前区间 \([l_i,r_i]\) 把之前的某一个区间 \([l_j,r_j]\) 完全盖在了左边,那么 \([l_j,r_j]\) 就不可能对后面的区间产生贡献了。注意到每次盖住的区间是当前有效的所有区间的一个后缀,可以使用单调栈维护这些有效的区间。每次倍增跳 \(nxt\),数出 \([l_i,r_i]\) 有多少个子树被覆盖在了栈顶区间内。若完全覆盖栈顶区间,则弹栈。

AC 代码

注意细节。

#include<iostream>
#include<vector>
#include<algorithm>
#define cint const int&
using namespace std;
const int N = 5E5 + 10;
const int LOGN = 21;
const int INF = (int)0x3f3f3f3f3f3f3f3f;

struct Range {
    int l, r;
};

int n, q, b;
int fa[N], id[N], dfn[N], sz[N], dt;
int nxt[N][LOGN];
int k, l[N], r[N];

vector<int> adj[N];

Range sta[N];
int top;

void dfs(int u, int nxt) {
    dfn[u] = ++dt;
    id[dt] = u;
    fa[u] = nxt;
    sz[u] = 1;
    for(int v : adj[u]) {
        if(v == nxt) continue;
        dfs(v, u);
        sz[u] += sz[v];
    }
}

int main() {

    freopen("lu.in", "r", stdin);
    freopen("lu.out", "w", stdout);

    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    cin >> n >> q >> b;
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    for(int i = 1; i <= n; i++) {
        sort(adj[i].begin(), adj[i].end());
    }

    dfs(1, 0);

    nxt[n + 1][0] = n + 1;
    dfn[n + 1] = n + 1;
    id[n + 1] = n + 1;
    for(int i = 1; i <= n; i++) {
        nxt[i][0] = id[dfn[i] + sz[i]];
    }
    for(int j = 1; j < LOGN; j++) {
        for(int i = 1; i <= n + 1; i++) {
            nxt[i][j] = nxt[ nxt[i][j - 1] ][j - 1];
        }
    }

    int ans = 0, cur;
    while(q--) {
        cin >> k;
        ans *= b;
        for(int i = 1; i <= k; i++) {
            cin >> l[i] >> r[i];
            l[i] ^= ans, r[i] ^= ans;
        }
        ans = 0;
        top = 0;
        sta[++top] = {-1, -1};
        for(int now = 1; now <= k; now++) {
            int rf = id[l[now]];
            ++ans;
            for(int i = LOGN - 1; i >= 0; i--) {
                if(dfn[nxt[rf][i]] <= r[now]) {
                    rf = nxt[rf][i];
                    ans += (1 << i);
                }
            }
            cur = id[l[now]];
            while(top && dfn[cur] <= dfn[rf] && sta[top].r >= dfn[fa[rf]]) {
                Range tp = sta[top];
                if(dfn[fa[cur]] > tp.r) {
                    for(int i = LOGN - 1; i >= 0; i--) {
                        if(dfn[nxt[cur][i]] <= dfn[rf] && dfn[fa[nxt[cur][i]]] > tp.r) {
                            cur = nxt[cur][i];
                        }
                    }
                    cur = nxt[cur][0];
                }
                if(dfn[fa[cur]] >= tp.l) {
                    --ans;
                    for(int i = LOGN - 1; i >= 0; i--) {
                        if(dfn[nxt[cur][i]] <= dfn[rf] && tp.l <= dfn[fa[nxt[cur][i]]]) {
                            cur = nxt[cur][i];
                            ans -= (1 << i);
                        }
                    }
                    cur = nxt[cur][0]; 
                }
                if(sta[top].l > dfn[fa[rf]]) --top;
                else break;
            }
            sta[++top] = {l[now], r[now]};
        }
        cout << --ans << endl;
    }

    return 0;
}

/*
6 2 0
1 3
1 2
2 4
2 5
3 6
2
1 3 5 6
1
4 5


7 1 0
1 2
2 3
3 4
4 5
2 6
6 7
2
2 5 6 7

*/