跳转至

线段树优化 DP

线段树优化转移

有时,DP 的转移区间不具有单调性,或者历史 DP 值对现在的 DP值的贡献发生的变化较大,无法使用单调队列分离无关项,就可以使用线段树优化。

当一些产生贡献的项可以被刻画为区间修改时,则可以使用线段树优化。

例题

P9871 [NOIP2023] 天天爱打卡

请参考典题题解

P2605 [ZJOI2010] 基站选址

和 天天爱打卡 很相似,只是需要给 dp 数组多加一个维度 \(k\)。在本题中必须要把 \(k\) 放到外层循环,才能在线段树上把这一维压掉。否则就需要开 \(k\) 棵线段树,会 MLE。

代码
#include<iostream>
#include<cstring>
#include<algorithm>
#define int long long
using namespace std;
const int N = 2E4 + 10;
const int INF = (int)0x3f3f3f3f3f3f3f3f;

struct Range {
    int l, r, v;
    inline bool operator<(const Range &other) const {
        return r < other.r;
    }
} a[N];

namespace SegT {
    int mx[4 * N], tag[4 * N];
    inline int lc(int x) { return x << 1; }
    inline int rc(int x) { return x << 1 | 1; }
    inline void push_up(int p) {
        mx[p] = max(mx[lc(p)], mx[rc(p)]);
    }
    inline void move_tag(int p, int tg) {
        mx[p] += tg;
        tag[p] += tg;
    }
    inline void push_down(int p) {
        if(!tag[p]) return;
        move_tag(lc(p), tag[p]);
        move_tag(rc(p), tag[p]);
        tag[p] = 0;
    }
    void add(int p, int l, int r, int ql, int qr, int v) {
        if(ql <= l && r <= qr) {
            move_tag(p, v);
            return;
        }
        push_down(p);
        int mid = (l + r) >> 1;
        if(mid >= ql) add(lc(p), l, mid, ql, qr, v);
        if(mid < qr) add(rc(p), mid + 1, r, ql, qr, v);
        push_up(p);
    }
    void modify(int p, int l, int r, int q, int v) {
        if(l == r) {
            mx[p] = v;
            return;
        }
        push_down(p);
        int mid = (l + r) >> 1;
        if(mid >= q) modify(lc(p), l, mid, q, v);
        else modify(rc(p), mid + 1, r, q, v);
        push_up(p);
    }
    int query(int p, int l, int r, int ql, int qr) {
        if(ql <= l && r <= qr) {
            return mx[p];
        }
        push_down(p);
        int mid = (l + r) >> 1;
        if(mid >= qr) return query(lc(p), l, mid, ql, qr);
        if(mid < ql) return query(rc(p), mid + 1, r, ql, qr);
        return max(query(lc(p), l, mid, ql, qr), query(rc(p), mid + 1, r, ql, qr));
    }
}

int n, k;
int p[N], c[N], s[N], w[N], sc[N];
int f[N];

signed main() {

    cin >> n >> k;
    for(int i = 2; i <= n; i++) {
        cin >> p[i];
    }

    for(int i = 1; i <= n; i++) {
        cin >> c[i];
        sc[i] = sc[i - 1] + c[i];
    }
    for(int i = 1; i <= n; i++) {
        cin >> s[i];
    }
    for(int i = 1; i <= n; i++) {
        cin >> w[i];
    }

    for(int i = 1; i <= n; i++) {
        a[i].l = lower_bound(p + 1, p + 1 + n, p[i] - s[i]) - p;
        a[i].r = upper_bound(p + 1, p + 1 + n, p[i] + s[i]) - p - 1;
        a[i].v = w[i];
    }

    sort(a + 1, a + 1 + n);

    f[0] = 0;
    SegT::modify(1, 1, n, 2, -sc[1]);
    for(int i = 1, j = 1; i <= n; i++) {
        while(j <= n && a[j].r == i) {
            SegT::add(1, 1, n, 1, a[j].l, -a[j].v);
            j++;
        }
        f[i] = max(f[i - 1], SegT::query(1, 1, n, 1, i) + sc[i]);
        if(i < n - 1) SegT::modify(1, 1, n, i + 2, f[i] - sc[i + 1]);
    }

    cout << sc[n] - f[n] << endl;

    return 0;
}

线段树维护矩阵乘法

如果 DP 的转移过程可以被刻画为(普通 / 广义)矩阵乘法,得益于矩阵乘法具有结合律,可以使用线段树维护区间矩阵乘法的结果,从而做到 \(O(\log n)\) 查询区间答案

例题

CF750E New Year and Old Subsequence

题目大意

定义一个数字串满足性质 nice 当且仅当:该串包含子序列 \(2017\),且不包含子序列 \(2016\)

定义一个数字串的 ugliness 为:该串至少删去几个字符,可以使得剩余串满足性质 nice;如果该串没有满足性质 nice 的子序列,则该串的 ugliness-1

给定一个长度为 \(n\) 的数字串 \(t\),和 \(q\) 次询问,每次询问给定一个区间 \([l,r]\),你需要回答 ugliness(t[l,r])

\(1\le n,q\le 2\times 10^5\)

考虑一个朴素的 DP。设 \(f_{0/1/2/3/4}\) 表示已经匹配出了 \(\emptyset/2/20/201/2017\),且不包含 \(2016\),至少需要删去几个字符。朴素的转移:

\[ \begin{align*} f'_0&=f_0+[t_i=2]\\ f'_1&=\min(f_1+[t_i=0],&f_0+[t_i\ne 2]\times \infty)\\ f'_2&=\min(f_2+[t_i=1],&f_1+[t_i\ne 0]\times \infty)\\ f'_3&=\min(f_3+[t_i=7]+[t_i=6],&f_2+[t_i\ne 1]\times \infty)\\ f'_4&=\min(f_4+[t_i=6],&f_3+[t_i\ne 7]\times \infty) \end{align*} \]

我们希望求出区间的答案。因此考虑将转移刻画为矩阵乘法,然后使用线段树解决。我们记状态矩阵

\[ \begin{bmatrix} f_0&f_1&f_2&f_3&f_4 \end{bmatrix} \]

容易写出转移矩阵:

\[ \left[ \begin{array}{lllll} [t_i=2]& [t_i\ne 2]\times \infty& & & &\\ & [t_i=0]& [t_i\ne 0]\times \infty& & &\\ & & [t_i=1]& [t_i\ne 1]\times \infty& &\\ & & & [t_i=7]+[t_i=6]& [t_i\ne 7]\times \infty& &\\ & & & & [t_i=6] \end{array} \right] \]

注意,因为转移的过程主要使用加法和 \(\min\),因此此处的矩阵乘法是指 加法-\(\min\)广义矩阵乘法。空白部分均为 \(+\infty\)

我们预处理出数字串每个位置所对应的矩阵,建立一棵线段树维护区间矩阵乘法的结果:

Matrix 结构体
struct Matrix {
    int a[5][5];
    inline Matrix() {
        memset(a, 0x3f, sizeof(a));
    }
    inline Matrix(int x) {
        memset(a, 0x3f, sizeof(a));
        a[0][0] = (x == 2);
        a[0][1] = (x != 2) * INF;
        a[1][1] = (x == 0);
        a[1][2] = (x != 0) * INF;
        a[2][2] = (x == 1);
        a[2][3] = (x != 1) * INF;
        a[3][3] = (x == 6) + (x == 7);
        a[3][4] = (x != 7) * INF;
        a[4][4] = (x == 6);
    }
    inline int* operator[](int index) {
        return a[index];
    }
    inline const int* operator[](int index) const {
        return a[index];
    }
};
建树
inline void mul(const Matrix &a, const Matrix &b, Matrix &res) {
    for(int i = 0; i < 5; i++) {
        for(int j = 0; j < 5; j++) {
            res[i][j] = INF;
            for(int k = 0; k < 5; k++) {
                res[i][j] = min(res[i][j], a[i][k] + b[k][j]);
            }
        }
    }
}

int a[N];
Matrix tr[4 * N];

void build(int p, int l, int r) {
    if(l == r) {
        tr[p] = (Matrix){a[l]};
        return;
    }
    int mid = (l + r) >> 1;
    build(lc(p), l, mid);
    build(rc(p), mid + 1, r);
    mul(tr[lc(p)], tr[rc(p)], tr[p]);
}

建树时间复杂度 \(O(5^3n)\)

查询时,我们不必让线段树返回整个区间对应的转移矩阵。我们可以传入一个初始的状态矩阵:

\[ res= \begin{bmatrix} 0& \infty& \infty& \infty& \infty \end{bmatrix} \]

然后让线段树把每段对应的转移矩阵 \(op_i\) 按顺序乘到 \(res\) 上(\(res=res\times op_i\)),最后返回 \(res\)。这样做可以避免两个 \(5\times 5\) 的矩阵直接相乘,而是让 \(1\times 5\)\(5\times 5\) 的矩阵相乘,从而将单次查询的时间复杂度降低到 \(O(5^2\log n)\)

查询
inline vector<int> mul(const vector<int> &a, const Matrix &b) {
    vector<int> res(5, INF);
    for(int i = 0; i < 5; i++) {
        for(int j = 0; j < 5; j++) {
            res[i] = min(res[i], a[j] + b[j][i]);
        }
    }
    return res;
}

void query(int p, int l, int r, int ql, int qr, vector<int> &res) {
    if(ql <= l && r <= qr) {
        res = mul(res, tr[p]);
        return;
    }
    int mid = (l + r) >> 1;
    if(mid >= ql) query(lc(p), l, mid, ql, qr, res);
    if(mid < qr) query(rc(p), mid + 1, r, ql, qr, res);
}

查询操作的技巧

维护矩阵乘法时,查询操作不返回一个完整的转移矩阵,而是返回一个一维的向量,这时一种很常见的优化。有些场景的时间复杂度高度依赖于这种优化。

代码
#include<iostream>
#include<cstring>
#include<vector>
using namespace std;
const int N = 2E5 + 10;
const int INF = (int)0x3f3f3f3f3f3f3f3f;

inline int lc(int x) { return x << 1; };
inline int rc(int x) { return x << 1 | 1; }

struct Matrix {
    int a[5][5];
    inline Matrix() {
        memset(a, 0x3f, sizeof(a));
    }
    inline Matrix(int x) {
        memset(a, 0x3f, sizeof(a));
        a[0][0] = (x == 2);
        a[0][1] = (x != 2) * INF;
        a[1][1] = (x == 0);
        a[1][2] = (x != 0) * INF;
        a[2][2] = (x == 1);
        a[2][3] = (x != 1) * INF;
        a[3][3] = (x == 6) + (x == 7);
        a[3][4] = (x != 7) * INF;
        a[4][4] = (x == 6);
    }
    inline int* operator[](int index) {
        return a[index];
    }
    inline const int* operator[](int index) const {
        return a[index];
    }
};

inline void mul(const Matrix &a, const Matrix &b, Matrix &res) {
    for(int i = 0; i < 5; i++) {
        for(int j = 0; j < 5; j++) {
            res[i][j] = INF;
            for(int k = 0; k < 5; k++) {
                res[i][j] = min(res[i][j], a[i][k] + b[k][j]);
            }
        }
    }
}

inline vector<int> mul(const vector<int> &a, const Matrix &b) {
    vector<int> res(5, INF);
    for(int i = 0; i < 5; i++) {
        for(int j = 0; j < 5; j++) {
            res[i] = min(res[i], a[j] + b[j][i]);
        }
    }
    return res;
}

int n, q;
int a[N];
Matrix tr[4 * N];

void build(int p, int l, int r) {
    if(l == r) {
        tr[p] = (Matrix){a[l]};
        return;
    }
    int mid = (l + r) >> 1;
    build(lc(p), l, mid);
    build(rc(p), mid + 1, r);
    mul(tr[lc(p)], tr[rc(p)], tr[p]);
}

void query(int p, int l, int r, int ql, int qr, vector<int> &res) {
    if(ql <= l && r <= qr) {
        res = mul(res, tr[p]);
        return;
    }
    int mid = (l + r) >> 1;
    if(mid >= ql) query(lc(p), l, mid, ql, qr, res);
    if(mid < qr) query(rc(p), mid + 1, r, ql, qr, res);
}

int main() {

    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    cin >> n >> q;
    for(int i = 1; i <= n; i++) {
        char c;
        cin >> c;
        a[i] = c - '0';
    }

    build(1, 1, n);

    while(q--) {
        int l, r;
        cin >> l >> r;
        vector<int> res({0, INF, INF, INF, INF});
        query(1, 1, n, l, r, res);
        if(res[4] > N) cout << -1 << '\n';
        else cout << res[4] << '\n';
    }

    return 0;
}