跳转至

P6773 [NOI2020] 命运 题解

题意

给定一棵 \(n\)\(n\le 5\times 10^5\))个节点的有根树和 \(m\)\(m \le 5\times 10^5\))条返祖链。现在要求用黑白两种颜色给所有边染色,使得每条返祖链都有至少一条边是白色的。问染色方案数对 \(998244353\) 取模的结果。

题解

记返祖链的底端为其起点,顶端为其末端。

容易发现,从同一个节点出发的多条返祖链,只有末端最深的一个是有效的。我们给每个节点维护一个数组 \(w[u]\) 表示从该起点出发的所有返祖链中,末端的最大深度。若没有从该节点出发的返祖链,则 \(w[u]=0\)

因为我们可以在根节点上方挂一条虚边,钦定这条虚边始终为白色,再从所有节点都向这个虚边上方连一条返祖链,就可以将有、无剩余返祖链的情况简单的统一。

1
2
3
4
5
for(int i = 1; i <= m; i++) {
    int u, v;
    cin >> u >> v;
    w[v] = max(w[v], dep[u]);
}

考虑一种朴素的 dp,设 \(dp_{u,j}\) 表示节点 \(u\) 的子树中,整个都在子树内的限制(返祖链)都被满足,而没有被满足的所有返祖链的末端节点中最深的一个深度为 \(j\),方案数是多少。(若所有都被满足,则 \(j=0\)

容易写出状态转移方程:

\[ dp'_{u,j}=dp_{u,j}\times (\sum_{k=0}^{j}{dp_{v,k}}+\sum_{k=0}^{dep_u}{dp_{v,k}})+(\sum_{k=0}^{j-1}{dp_{u,k}})\times dp_{v,j} \]

其中,第 \(1\)\(2\) 项表示新增的子树 \(v\) 没有改变最深的末端深度;第 \(3\) 项表示 \(v\) 产生了一条末端更深的未处理的返祖链。

这种式子可以直接使用线段树合并优化。

  • 对于 \(\sum\limits_{k=0}^{dep_u}{dp_{v,k}}\) 的常数,可以在调用 merge() 之前对 \(v\) 的线段树调用 query() 得到,记为 \(c\)
  • 对于 \(dp_{v,k}\)\(dp_{u,k}\) 的前缀和,因为线段树合并的递归顺序是一种前序遍历的 dfs 序,可以在 dfs 的时候动态维护两个变量 \(sum_1\)\(sum_2\),分别记录两个前缀和。
int c, sum1, sum2;
/*
 * p1 线段树对应 dp[u] 数组
 * p2 线段树对应 dp[v] 数组
*/
int merge(int p1, int p2, int l, int r) {
    if(p1 == 0 && p2 == 0) return 0;
/*
 * p2 == 0 说明 dp[v][j] 为 0,则转移方程的第三项为 0。
 * 且 sum1 没有变化。
 * 我们直接在原先 p1(dp[u][j])的基础上乘以 (sum1 + c) 即可(调用 move_tag 就是这个效果)。
*/
    if(p2 == 0) {
        sum2 = sum2 + sum[p1];
        move_tag(p1, sum1 + c);
        return p1;
    }
/*
 * p1 == 0 说明 dp[u][j] 为 0,则方程第 1、2 项为 0。
 * 且 sum2 没有变化。
 * 可以直接将 sum2 乘在 p2 树上面,并将其 copy 到 p1 上返回。
*/
    if(p1 == 0) {
        sum1 = sum1 + sum[p2];
        move_tag(p2, sum2);
        return p2;
    }
/*
 * 单节点处,直接根据公式处理即可。
 * 注意,sum2 的求和上界是 j-1,并且求和的元素是 原先 的 dp[u][1](区分更新后的 dp[u][1])
 * 因此要先开一个临时变量记录 sum2 更新后的结果,再去合并 p1 和 p2。
*/
    if(l == r) {
        int tmp;
        sum1 = sum1 + sum[p2];
        tmp = sum2 + sum[p1];
        sum[p1] = sum[p1] * (sum1 + c) + sum[p2] * sum2;
        sum2 = tmp;
        return p1;
    }
/*
 * 线段树合并板子
*/
    push_down(p1);
    push_down(p2);
    int mid = (l + r) >> 1;
    lc[p1] = merge(lc[p1], lc[p2], l, mid);
    rc[p1] = merge(rc[p1], rc[p2], mid + 1, r);
    push_up(p1);
    return p1;
}

调用处:

void dfs(int u, int f) {
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == f) continue;
        dfs(v, u);
        sum1 = sum2 = 0;
        c = query(rt[v], 0, n, 0, dep[u]);
        rt[u] = merge(rt[u], rt[v], 0, n);
    }
}
完整代码
#include<iostream>
#define int long long
using namespace std;
const int N = 5E5 + 10;
const int MOD = 998244353;

struct Edge {
    int v;
    int next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v) {
    pool[++ne] = {v, head[u]};
    head[u] = ne;
}

int n, m;
int dep[N], w[N];

void dfs1(int u, int f) {
    dep[u] = dep[f] + 1;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == f) continue;
        dfs1(v, u);
    }
}

int nn;
int rt[N];
int sum[32 * N], lc[32 * N], rc[32 * N], mul[32 * N];

void push_up(int p) {
    sum[p] = ((lc[p] ? sum[lc[p]] : 0) + (rc[p] ? sum[rc[p]] : 0)) % MOD;
}

void move_tag(int p, int tg) {
    if(p == 0) return;
    sum[p] = sum[p] * tg % MOD;
    mul[p] = mul[p] * tg % MOD;
}

void push_down(int p) {
    move_tag(lc[p], mul[p]);
    move_tag(rc[p], mul[p]);
    mul[p] = 1;
}

void insert(int &p, int l, int r, int q, int v) {
    if(p == 0) {
        p = ++nn;
        mul[p] = 1;
    }
    if(l == r) {
        sum[p] = (sum[p] + v) % MOD;
        return;
    }
    push_down(p);
    int mid = (l + r) >> 1;
    if(q <= mid) insert(lc[p], l, mid, q, v);
    else insert(rc[p], mid + 1, r, q, v);
    push_up(p);
}

int query(int p, int l, int r, int ql, int qr) {
    if(p == 0) return 0;
    if(ql <= l && r <= qr) {
        return sum[p];
    }
    push_down(p);
    int mid = (l + r) >> 1, res = 0;
    if(mid >= ql) res = query(lc[p], l, mid, ql, qr);
    if(mid < qr) res = (res + query(rc[p], mid + 1, r, ql, qr)) % MOD;
    return res;
}

int c, sum1, sum2;

int merge(int p1, int p2, int l, int r) {
    if(p1 == 0 && p2 == 0) return 0;
    if(p2 == 0) {
        sum2 = (sum2 + sum[p1]) % MOD;
        move_tag(p1, (sum1 + c) % MOD);
        return p1;
    }
    if(p1 == 0) {
        sum1 = (sum1 + sum[p2]) % MOD;
        move_tag(p2, sum2);
        return p2;
    }
    if(l == r) {
        int tmp;
        sum1 = (sum1 + sum[p2]) % MOD;
        tmp = (sum2 + sum[p1]) % MOD;
        sum[p1] = (sum[p1] * ((sum1 + c) % MOD) % MOD + sum[p2] * sum2 % MOD) % MOD;
        sum2 = tmp;
        return p1;
    }
    push_down(p1);
    push_down(p2);
    int mid = (l + r) >> 1;
    lc[p1] = merge(lc[p1], lc[p2], l, mid);
    rc[p1] = merge(rc[p1], rc[p2], mid + 1, r);
    push_up(p1);
    return p1;
}

void dfs(int u, int f) {
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == f) continue;
        dfs(v, u);
        sum1 = sum2 = 0;
        c = query(rt[v], 0, n, 0, dep[u]);
        rt[u] = merge(rt[u], rt[v], 0, n);
    }
}

signed main() {

    cin >> n;
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        addEdge(u, v);
        addEdge(v, u);
    }

    dfs1(1, 0);

    cin >> m;
    for(int i = 1; i <= m; i++) {
        int u, v;
        cin >> u >> v;
        w[v] = max(w[v], dep[u]);
    }

    for(int i = 1; i <= n; i++) {
        insert(rt[i], 0, n, w[i], 1);
    }

    dfs(1, 0);

    cout << query(rt[1], 0, n, 0, 0) << endl;

    return 0;
}