跳转至

动态 DP

前置知识:线段树维护矩阵乘法线段树优化 DP重链剖分

P4719 【模板】动态 DP

给定一棵 \(n\) 个点的树,点有点权。有 \(m\) 次操作,每次操作给定 \(x,y\),表示修改点 \(x\) 的权值为 \(y\)。你需要在每次操作之后求出这棵树的最大权独立集的权值大小。

\(n,m\le 10^5\)

如果修改权值之后暴力 DP 求出答案,则时间复杂度为 \(O(nm)\),不能通过本题。

容易发现修改 \(x\) 节点的权值只会改变 \(x\) 到根节点路径上所有节点的 dp 值,其它位置则保持不变。然而返根链的长度之和仍然是 \(O(nm)\) 量级的,考虑优化。

我们知道,重链剖分通过将树剖分为若干重链,使得任意一个节点到根节点都只经过不超过 \(\log_2n\) 条不同的重链。由于一条重链上所有节点的 \(dfn\) 序是连续的,可以使用线段树维护每条重链的信息,实现 \(O(\log^2 n)\) 修改和查询路径信息。

这启示我们可以通过将树剖分为若干重链,以此减少操作返根链的时间复杂度。但是一个节点的 dp 值还受其轻儿子影响,应该如何使用线段树维护重链信息呢?

我们写出朴素的 dp 转移式:

\[ \begin{align*} f_{u,0}&=\left(\sum_{v\in to[u]}{\max(f_{v,0},f_{v,1})}\right)\\ f_{u,1}&=\left(\sum_{v\in to[u]}{f_{v,0}}\right)+w[u] \end{align*} \]

为了方便书写,我们定义 \(0\) 号节点为一个虚点,是所有叶子节点的唯一“儿子”。其 dp 值 \(f_{u,0}=0\)\(f_{u,1}=-\infty\)

\(v_0\) 表示 \(u\) 的重儿子;分离重儿子和轻儿子:

\[ \begin{align*} f_{u,0}&=\left(\sum_{v\in to[u]/\{v_0\}}{\max(f_{v,0},f_{v,1})}\right)+\max(f_{v_0,0},f_{v_0,1})\\ f_{u,1}&=\left(\sum_{v\in to[u]/\{v_0\}}{f_{v,0}}\right)+w[u]+f_{v_0,0} \end{align*} \]

树上动态 DP 的核心思想就是把所有轻儿子的信息浓缩为一个矩阵,将这个矩阵与重儿子的 dp 向量相乘,就能得到当前节点的 dp 向量。根据转移方程,我们定义此处的矩阵乘法为 \(\left<\max,+\right>\) 的广义矩阵乘法。

\[ \begin{align*} g_{u,0}&=\sum_{v\in to[u]/\{v_0\}}{f_{v,0}}\\ g_{u,1}&=\sum_{v\in to[u]/\{v_0\}}{\max(f_{v,0},f_{v,1})} \end{align*} \]

则转移式可以写成

\[ \begin{align*} f_{u,0}&=g_{u,1}+\max(f_{v_0,0},f_{v_0,1})\\ f_{u,1}&=g_{u,0}+w[u]+f_{v_0,0} \end{align*} \]

写成矩阵乘法的形式:

\[ \left[ \begin{matrix} f_{u,0}\\ f_{u,1} \end{matrix} \right]= \left[ \begin{matrix} g_{u,1}& g_{u,1}\\ g_{u,0}+w[u]& -\infty \end{matrix} \right] \left[ \begin{matrix} f_{v_0,0}\\ f_{v_0,1} \end{matrix} \right] \]

由于 \(g\) 只和轻儿子有关,因此每个节点的转移矩阵也就和重儿子无关。使用线段树维护每一条重链上转移矩阵的乘积。对于节点 \(u\),记 \(\operatorname{bot}[u]\) 为其链底节点,我们用线段树查询 \([u,\operatorname{bot}[u]]\) 中节点的转移矩阵的乘积 \(op\),将 \(op\) 左乘 \(0\) 号节点的 dp 向量,即可得到 \(u\) 节点的 dp 向量。

修改时,先修改 \(x\) 节点的权值 \(w[x]\) 和其转移矩阵;然后从下往上考虑 \(x\) 到根节点经过的所有轻边 \(fa[x']\rightarrow x'\),使用线段树查询 \(x'\) 的 dp 向量,然后更新 \(fa[x']\) 的转移矩阵即可。

注意矩阵乘法的左右顺序

矩阵乘法没有交换律,使用线段树维护时一定要注意乘法顺序。可以根据

\[ (AB)^{\operatorname{T}}=B^{\operatorname{T}}A^{\operatorname{T}} \]

转置矩阵并交换顺序。

代码
#include<iostream>
using namespace std;
const int N = 1E5 + 10;
const int INF = (int)0x3f3f3f3f3f3f3f3f;

struct Edge {
    int v, next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v) {
    pool[++ne] = {v, head[u]};
    head[u] = ne;
}

struct Matrix {
    int a[2][2];
    inline Matrix() { a[0][0] = a[0][1] = a[1][0] = a[1][1] = 0; }
    inline int* operator[](int index) { return a[index]; }
    inline const int* operator[](int index) const { return a[index]; }
    inline Matrix operator*(const Matrix& b) const {
        Matrix res;
        res[0][0] = max(a[0][0] + b[0][0], a[0][1] + b[1][0]);
        res[0][1] = max(a[0][0] + b[0][1], a[0][1] + b[1][1]);
        res[1][0] = max(a[1][0] + b[0][0], a[1][1] + b[1][0]);
        res[1][1] = max(a[1][0] + b[0][1], a[1][1] + b[1][1]);
        return res;
    }
};

Matrix a[N];

namespace SegT {
    Matrix tr[4 * N];
    inline int lc(int x) { return x << 1; }
    inline int rc(int x) { return x << 1 | 1; }
    inline void push_up(int p) {
        tr[p] = tr[lc(p)] * tr[rc(p)];
    }
    void build(int p, int l, int r) {
        if(l == r) {
            tr[p] = a[l];
            return;
        }
        int mid = (l + r) >> 1;
        build(lc(p), l, mid);
        build(rc(p), mid + 1, r);
        push_up(p);
    }
    void update(int p, int l, int r, int q) {
        if(l == r) {
            tr[p] = a[l];
            return;
        }
        int mid = (l + r) >> 1;
        if(q <= mid) update(lc(p), l, mid, q);
        else update(rc(p), mid + 1, r, q);
        push_up(p);
    }
    Matrix query(int p, int l, int r, int ql, int qr) {
        if(ql <= l && r <= qr) {
            return tr[p];
        }
        int mid = (l + r) >> 1;
        if(mid >= qr) return query(lc(p), l, mid, ql, qr);
        if(mid < ql) return query(rc(p), mid + 1, r, ql, qr);
        return query(lc(p), l, mid, ql, qr) * query(rc(p), mid + 1, r, ql, qr);
    }
}

int n, m;
int w[N];

int sz[N], son[N], fa[N];
int dfn[N], id[N], top[N], bot[N], dt;
int f[N][2], g[N][2];

void get_sz(int u, int a0) {
    sz[u] = 1;
    fa[u] = a0;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == a0) continue;
        get_sz(v, u);
        sz[u] += sz[v];
        if(sz[v] > sz[son[u]]) son[u] = v;
    }
}

void init(int u, int a0, int tp) {
    dfn[u] = ++dt;
    id[dt] = u;
    top[u] = tp;
    if(!son[u]) {
        f[u][0] = 0;
        f[u][1] = w[u];
        Matrix& cur = a[dfn[u]];
        cur[1][0] = w[u];
        cur[1][1] = -INF;
        bot[u] = u;
        return;
    }
    init(son[u], u, tp);
    bot[u] = bot[son[u]];
    f[u][0] += max(f[son[u]][0], f[son[u]][1]);
    f[u][1] += f[son[u]][0] + w[u];
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == a0 || v == son[u]) continue;
        init(v, u, v);
        g[u][0] += f[v][0];
        g[u][1] += max(f[v][0], f[v][1]);
    }
    f[u][0] += g[u][1];
    f[u][1] += g[u][0];
    Matrix& cur = a[dfn[u]];
    cur[0][0] = cur[0][1] = g[u][1];
    cur[1][0] = g[u][0] + w[u];
    cur[1][1] = -INF;
}

void modify(int x, int y) {
    a[dfn[x]][1][0] += y - w[x];
    w[x] = y;
    SegT::update(1, 1, n, dfn[x]);
    x = top[x];
    while(x != 1) {
        Matrix cur = SegT::query(1, 1, n, dfn[x], dfn[bot[x]]);
        a[dfn[fa[x]]][0][0] -= max(f[x][0], f[x][1]);
        a[dfn[fa[x]]][1][0] -= f[x][0];
        f[x][0] = cur[0][0];
        f[x][1] = cur[1][0];
        a[dfn[fa[x]]][0][0] += max(f[x][0], f[x][1]);
        a[dfn[fa[x]]][0][1] = a[dfn[fa[x]]][0][0];
        a[dfn[fa[x]]][1][0] += f[x][0];
        SegT::update(1, 1, n, dfn[fa[x]]);
        x = top[fa[x]];
    }
}

int query() {
    Matrix rt = SegT::query(1, 1, n, dfn[1], dfn[bot[1]]);
    return max(rt[0][0], rt[1][0]);
}

int main() {

    cin >> n >> m;
    for(int i = 1; i <= n; i++) cin >> w[i];
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        addEdge(u, v);
        addEdge(v, u);
    }

    get_sz(1, 0);
    init(1, 0, 1);
    SegT::build(1, 1, n);

    while(m--) {
        int x, y;
        cin >> x >> y;
        modify(x, y);
        cout << query() << '\n';
    }

    return 0;
}