跳转至

动态 DP

前置知识:线段树维护矩阵乘法线段树优化 DP重链剖分

P4719 【模板】动态 DP

题意

给定一棵 \(n\) 个点的树,点有点权。有 \(m\) 次操作,每次操作给定 \(x,y\),表示修改点 \(x\) 的权值为 \(y\)。你需要在每次操作之后求出这棵树的最大权独立集的权值大小。

\(n,m\le 10^5\)

如果修改权值之后暴力 DP 求出答案,则时间复杂度为 \(O(nm)\),不能通过本题。

容易发现修改 \(x\) 节点的权值只会改变 \(x\) 到根节点路径上所有节点的 dp 值,其它位置则保持不变。然而返根链的长度之和仍然是 \(O(nm)\) 量级的,考虑优化。

我们知道,重链剖分通过将树剖分为若干重链,使得任意一个节点到根节点都只经过不超过 \(\log_2n\) 条不同的重链。由于一条重链上所有节点的 \(dfn\) 序是连续的,可以使用线段树维护每条重链的信息,实现 \(O(\log^2 n)\) 修改和查询路径信息。

这启示我们可以通过将树剖分为若干重链,以此减少操作返根链的时间复杂度。但是一个节点的 dp 值还受其轻儿子影响,应该如何使用线段树维护重链信息呢?

我们写出朴素的 dp 转移式:

\[ \begin{align*} f_{u,0}&=\left(\sum_{v\in to[u]}{\max(f_{v,0},f_{v,1})}\right)\\ f_{u,1}&=\left(\sum_{v\in to[u]}{f_{v,0}}\right)+w[u] \end{align*} \]

为了方便书写,我们定义 \(0\) 号节点为一个虚点,是所有叶子节点的唯一“儿子”。其 dp 值 \(f_{u,0}=0\)\(f_{u,1}=-\infty\)

\(v_0\) 表示 \(u\) 的重儿子;分离重儿子和轻儿子:

\[ \begin{align*} f_{u,0}&=\left(\sum_{v\in to[u]/\{v_0\}}{\max(f_{v,0},f_{v,1})}\right)+\max(f_{v_0,0},f_{v_0,1})\\ f_{u,1}&=\left(\sum_{v\in to[u]/\{v_0\}}{f_{v,0}}\right)+w[u]+f_{v_0,0} \end{align*} \]

树上动态 DP 的核心思想就是把所有轻儿子的信息浓缩为一个矩阵,将这个矩阵与重儿子的 dp 向量相乘,就能得到当前节点的 dp 向量。根据转移方程,我们定义此处的矩阵乘法为 \(\left<\max,+\right>\) 的广义矩阵乘法。

\[ \begin{align*} g_{u,0}&=\sum_{v\in to[u]/\{v_0\}}{f_{v,0}}\\ g_{u,1}&=\sum_{v\in to[u]/\{v_0\}}{\max(f_{v,0},f_{v,1})} \end{align*} \]

则转移式可以写成

\[ \begin{align*} f_{u,0}&=g_{u,1}+\max(f_{v_0,0},f_{v_0,1})\\ f_{u,1}&=g_{u,0}+w[u]+f_{v_0,0} \end{align*} \]

写成矩阵乘法的形式:

\[ \left[ \begin{matrix} f_{u,0}\\ f_{u,1} \end{matrix} \right]= \left[ \begin{matrix} g_{u,1}& g_{u,1}\\ g_{u,0}+w[u]& -\infty \end{matrix} \right] \left[ \begin{matrix} f_{v_0,0}\\ f_{v_0,1} \end{matrix} \right] \]

由于 \(g\) 只和轻儿子有关,因此每个节点的转移矩阵也就和重儿子无关。使用线段树维护每一条重链上转移矩阵的乘积。对于节点 \(u\),记 \(\operatorname{bot}[u]\) 为其链底节点,我们用线段树查询 \([u,\operatorname{bot}[u]]\) 中节点的转移矩阵的乘积 \(op\),将 \(op\) 左乘 \(0\) 号节点的 dp 向量,即可得到 \(u\) 节点的 dp 向量。

修改时,先修改 \(x\) 节点的权值 \(w[x]\) 和其转移矩阵;然后从下往上考虑 \(x\) 到根节点经过的所有轻边 \(fa[x']\rightarrow x'\),使用线段树查询 \(x'\) 的 dp 向量,然后更新 \(fa[x']\) 的转移矩阵即可。

注意矩阵乘法的左右顺序

矩阵乘法没有交换律,使用线段树维护时一定要注意乘法顺序。可以根据

\[ (AB)^{\operatorname{T}}=B^{\operatorname{T}}A^{\operatorname{T}} \]

转置矩阵并交换顺序。

代码
#include<iostream>
using namespace std;
const int N = 1E5 + 10;
const int INF = (int)0x3f3f3f3f3f3f3f3f;

struct Edge {
    int v, next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v) {
    pool[++ne] = {v, head[u]};
    head[u] = ne;
}

struct Matrix {
    int a[2][2];
    inline Matrix() { a[0][0] = a[0][1] = a[1][0] = a[1][1] = 0; }
    inline int* operator[](int index) { return a[index]; }
    inline const int* operator[](int index) const { return a[index]; }
    inline Matrix operator*(const Matrix& b) const {
        Matrix res;
        res[0][0] = max(a[0][0] + b[0][0], a[0][1] + b[1][0]);
        res[0][1] = max(a[0][0] + b[0][1], a[0][1] + b[1][1]);
        res[1][0] = max(a[1][0] + b[0][0], a[1][1] + b[1][0]);
        res[1][1] = max(a[1][0] + b[0][1], a[1][1] + b[1][1]);
        return res;
    }
};

Matrix a[N];

namespace SegT {
    Matrix tr[4 * N];
    inline int lc(int x) { return x << 1; }
    inline int rc(int x) { return x << 1 | 1; }
    inline void push_up(int p) {
        tr[p] = tr[lc(p)] * tr[rc(p)];
    }
    void build(int p, int l, int r) {
        if(l == r) {
            tr[p] = a[l];
            return;
        }
        int mid = (l + r) >> 1;
        build(lc(p), l, mid);
        build(rc(p), mid + 1, r);
        push_up(p);
    }
    void update(int p, int l, int r, int q) {
        if(l == r) {
            tr[p] = a[l];
            return;
        }
        int mid = (l + r) >> 1;
        if(q <= mid) update(lc(p), l, mid, q);
        else update(rc(p), mid + 1, r, q);
        push_up(p);
    }
    Matrix query(int p, int l, int r, int ql, int qr) {
        if(ql <= l && r <= qr) {
            return tr[p];
        }
        int mid = (l + r) >> 1;
        if(mid >= qr) return query(lc(p), l, mid, ql, qr);
        if(mid < ql) return query(rc(p), mid + 1, r, ql, qr);
        return query(lc(p), l, mid, ql, qr) * query(rc(p), mid + 1, r, ql, qr);
    }
}

int n, m;
int w[N];

int sz[N], son[N], fa[N];
int dfn[N], id[N], top[N], bot[N], dt;
int f[N][2], g[N][2];

void get_sz(int u, int a0) {
    sz[u] = 1;
    fa[u] = a0;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == a0) continue;
        get_sz(v, u);
        sz[u] += sz[v];
        if(sz[v] > sz[son[u]]) son[u] = v;
    }
}

void init(int u, int a0, int tp) {
    dfn[u] = ++dt;
    id[dt] = u;
    top[u] = tp;
    if(!son[u]) {
        f[u][0] = 0;
        f[u][1] = w[u];
        Matrix& cur = a[dfn[u]];
        cur[1][0] = w[u];
        cur[1][1] = -INF;
        bot[u] = u;
        return;
    }
    init(son[u], u, tp);
    bot[u] = bot[son[u]];
    f[u][0] += max(f[son[u]][0], f[son[u]][1]);
    f[u][1] += f[son[u]][0] + w[u];
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == a0 || v == son[u]) continue;
        init(v, u, v);
        g[u][0] += f[v][0];
        g[u][1] += max(f[v][0], f[v][1]);
    }
    f[u][0] += g[u][1];
    f[u][1] += g[u][0];
    Matrix& cur = a[dfn[u]];
    cur[0][0] = cur[0][1] = g[u][1];
    cur[1][0] = g[u][0] + w[u];
    cur[1][1] = -INF;
}

void modify(int x, int y) {
    a[dfn[x]][1][0] += y - w[x];
    w[x] = y;
    SegT::update(1, 1, n, dfn[x]);
    x = top[x];
    while(x != 1) {
        Matrix cur = SegT::query(1, 1, n, dfn[x], dfn[bot[x]]);
        a[dfn[fa[x]]][0][0] -= max(f[x][0], f[x][1]);
        a[dfn[fa[x]]][1][0] -= f[x][0];
        f[x][0] = cur[0][0];
        f[x][1] = cur[1][0];
        a[dfn[fa[x]]][0][0] += max(f[x][0], f[x][1]);
        a[dfn[fa[x]]][0][1] = a[dfn[fa[x]]][0][0];
        a[dfn[fa[x]]][1][0] += f[x][0];
        SegT::update(1, 1, n, dfn[fa[x]]);
        x = top[fa[x]];
    }
}

int query() {
    Matrix rt = SegT::query(1, 1, n, dfn[1], dfn[bot[1]]);
    return max(rt[0][0], rt[1][0]);
}

int main() {

    cin >> n >> m;
    for(int i = 1; i <= n; i++) cin >> w[i];
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        addEdge(u, v);
        addEdge(v, u);
    }

    get_sz(1, 0);
    init(1, 0, 1);
    SegT::build(1, 1, n);

    while(m--) {
        int x, y;
        cin >> x >> y;
        modify(x, y);
        cout << query() << '\n';
    }

    return 0;
}

P5024 [NOIP 2018 提高组] 保卫王国

题意

给定一棵 \(n\) 个点的树,点有点权,有 \(m\) 次询问;每次询问给定 \((a,x,b,y)\),保证 \(x,y\in \{0,1\}\),表示钦定节点 \(a,b\) 选或不选,树的最小权点覆盖。

\(n,m\le 10^5\)

注意到,对于一次询问,我们只需要修改 \(a,b\) 的权值为 \(+\infty\)\(-\infty\),然后查询全局最小点覆盖即可。

代码
#include<cstdio>
#include<ctype.h>
#define ll long long
using namespace std;
const int N = 1e5 + 10;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const ll INF2 = 0x3f3f3f3f3f3f3f;

template<typename _Tp>
inline const _Tp& min(const _Tp& a, const _Tp& b) {
    if(a < b) return a;
    return b;
}

template<typename _Tp>
inline const _Tp& max(const _Tp& a, const _Tp& b) {
    if(a < b) return b;
    return a;
}

struct my_istream {
    template<typename _Tp>
    my_istream& operator>>(_Tp& x) {
        char ch;
        while(!isdigit(ch = getchar_unlocked()));
        x = ch - 48;
        while(isdigit(ch = getchar_unlocked())) x = x * 10 + ch - 48;
        return *this;
    }
} cin;

struct my_ostream {
    char buf[60], nb;
    inline my_ostream() { nb = 0; }
    my_ostream& operator<<(int x) {
        while(x) buf[++nb] = x % 10, x /= 10;
        while(nb) putchar(buf[nb--] + 48);
        return *this;
    }
    my_ostream& operator<<(ll x) {
        while(x) buf[++nb] = x % 10, x /= 10;
        while(nb) putchar(buf[nb--] + 48);
        return *this;
    }
    my_ostream& operator<<(const char* s) {
        while(*s) putchar(*(s++));
        return *this;
    }
} cout;

struct Edge {
    int v, next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v) {
    pool[++ne] = {v, head[u]};
    head[u] = ne;
}

struct Matrix {
    ll a[2][2];
    inline ll* operator[](int index) { return a[index]; }
    inline const ll* operator[](int index) const { return a[index]; }
    inline Matrix() { a[0][0] = a[0][1] = a[1][0] = a[1][1] = 0; }
    inline Matrix operator*(const Matrix& b) const {
        Matrix res;
        res[0][0] = min(a[0][0] + b[0][0], a[0][1] + b[1][0]);
        res[0][1] = min(a[0][0] + b[0][1], a[0][1] + b[1][1]);
        res[1][0] = min(a[1][0] + b[0][0], a[1][1] + b[1][0]);
        res[1][1] = min(a[1][0] + b[0][1], a[1][1] + b[1][1]);
        return res;
    }
};

int n, m;
ll w[N];

int sz[N], son[N], top[N], bot[N], fa[N];
int dfn[N], dt;

ll f[N][2], g[N][2];

Matrix a[N];

namespace SegT {
    Matrix tr[4 * N];
    inline int lc(int x) { return x << 1; }
    inline int rc(int x) { return x << 1 | 1; }
    inline void push_up(int p) {
        tr[p] = tr[lc(p)] * tr[rc(p)];
    }
    void build(int p, int l, int r) {
        if(l == r) {
            tr[p] = a[l];
            return;
        }
        int mid = (l + r) >> 1;
        build(lc(p), l, mid);
        build(rc(p), mid + 1, r);
        push_up(p);
    }
    void update(int p, int l, int r, int q) {
        if(l == r) {
            tr[p] = a[l];
            return;
        }
        int mid = (l + r) >> 1;
        if(mid >= q) update(lc(p), l, mid, q);
        else update(rc(p), mid + 1, r, q);
        push_up(p);
    }
    void query(int p, int l, int r, int ql, int qr, Matrix& res) {
        if(ql <= l && r <= qr) {
            res = res * tr[p];
            return;
        }
        int mid = (l + r) >> 1;
        if(mid >= ql) query(lc(p), l, mid, ql, qr, res);
        if(mid < qr) query(rc(p), mid + 1, r, ql, qr, res);
    }
}

void init1(int u, int a0) {
    sz[u] = 1;
    fa[u] = a0;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == a0) continue;
        init1(v, u);
        if(sz[v] > sz[son[u]]) son[u] = v;
        sz[u] += sz[v];
    }
}

void init2(int u, int a0, int tp) {
    dfn[u] = ++dt;
    top[u] = tp;
    if(son[u]) init2(son[u], u, tp), bot[u] = bot[son[u]];
    else bot[u] = u;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == a0 || v == son[u]) continue;
        init2(v, u, v);
        g[u][0] += min(f[v][0], f[v][1]);
        g[u][1] += f[v][1];
    }
    f[u][0] = g[u][1] + f[son[u]][1];
    f[u][1] = g[u][0] + min(f[son[u]][0], f[son[u]][1]) + w[u];
    Matrix &cur = a[dfn[u]];
    cur[0][0] = INF;
    cur[0][1] = g[u][1];
    cur[1][0] = g[u][0] + w[u];
    cur[1][1] = g[u][0] + w[u];
}

void update(int p) {
    {   
        Matrix &cur = a[dfn[p]];
        cur[1][0] = g[p][0] + w[p];
        cur[1][1] = g[p][0] + w[p];
        SegT::update(1, 1, n, dfn[p]);
    }
    p = top[p];
    while(p != 1) {
        int ff = fa[p];
        Matrix cur;
        cur[0][1] = cur[1][0] = INF;
        SegT::query(1, 1, n, dfn[p], dfn[bot[p]], cur);
        g[ff][0] -= min(f[p][0], f[p][1]);
        g[ff][1] -= f[p][1];
        f[p][0] = cur[0][1];
        f[p][1] = cur[1][1];
        g[ff][0] += min(f[p][0], f[p][1]);
        g[ff][1] += f[p][1];
        Matrix &fm = a[dfn[ff]];
        fm[0][1] = g[ff][1];
        fm[1][0] = g[ff][0] + w[ff];
        fm[1][1] = g[ff][0] + w[ff];
        SegT::update(1, 1, n, dfn[ff]);
        p = top[fa[p]];
    }
}

ll query() {
    Matrix cur;
    SegT::query(1, 1, n, dfn[1], dfn[bot[1]], cur);
    return cur[1][1];
}

int main() {

    cin >> n >> m;
    while(getchar() != '\n');
    for(int i = 1; i <= n; i++) cin >> w[i];
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        addEdge(u, v);
        addEdge(v, u);
    }

    init1(1, 0);
    init2(1, 0, 1);

    SegT::build(1, 1, n);

    while(m--) {
        int a, x, b, y;
        ll ta, tb, c = 0;
        cin >> a >> x >> b >> y;
        ta = w[a], tb = w[b];
        if(x == 0) w[a] = INF2;
        else c += w[a] + INF2, w[a] = -INF2;
        if(y == 0) w[b] = INF2;
        else c += w[b] + INF2, w[b] = -INF2;
        update(a);
        update(b);
        ll res = query() + c;
        if(res >= INF2) {
            cout << "-1\n";
        } else cout << res << "\n";
        w[a] = ta, w[b] = tb;
        update(a);
        update(b);
    }

    return 0;
}

/*
5 3 
2 4 1 3 9 
1 5 
5 2 
5 3 
3 4 
2 1 3 1 
1 0 3 0 
1 0 5 0

*/