跳转至

线段树优化 DP

线段树优化转移

有时,DP 的转移区间不具有单调性,或者历史 DP 值对现在的 DP值的贡献发生的变化较大,无法使用单调队列分离无关项,就可以使用线段树优化。

当一些产生贡献的项可以被刻画为区间修改时,则可以使用线段树优化。

P9871 [NOIP2023] 天天爱打卡

题目大意

\(n\) 天里,你可以选择一些天进行跑步打卡,不得连续跑步超过 \(k\) 天,每次跑步消耗 \(d\) 的能量值。

此外,有 \(m\) 个任务。对于第 \(i\) 个任务,若在第 \(x_i\) 时,你已经连续打卡了 \(y_i\) 天,就会获得 \(v_i\) 的能量值,问 \(n\) 天以后能量值最高是多少。

其中 \(n\le 10^9\)\(m\le 10^5\)

容易发现:开始 / 停止跑步的时间节点,一定是任务开始或是结束的时间点。这样,我们对跑步的天数进行离散化,时间复杂度中就不包含 \(n\) 了。

考虑 DP,设 \(f_i\) 表示从第一天开始,到第 \(i\) 个时间节点为止,最多可以获得多少能量。我们可以枚举:到第 \(i\) 个时间节点为止,已经连续跑步了多长时间。容易写出状态转移方程:

\[ f_i=\max_{num[j]\ge num[i]-k+1}\{g_j-(num[i]-num[j]+1)\times d+\sum_{p=1}^{m}{\big[[l_p,r_p]\subseteq[j,i]\big]\times v_p}\} \]

\(i\) 个时间节点不跑步的情况:

\[ f_i=f_{i-1} \]

考虑到时间节点相邻的情况(即:在第 \(j\) 天的前一天不能跑步;如果这样,连续跑步的时间可能超过 \(k\)),记 \(g_j\) 表示满足 \(num[k]<num[j]-1\) 的最大的 \(f_k\)。显然,由于 \(f_i\) 单调递增,

\[ g_j= \begin{cases} f_{j-2},&num[j]=num[j-1]+1,\\ f_{j-1},&num[j]>num[j-1]+1 \end{cases} \]

上面提到的这种暴力转移的时间复杂度达到了 \(O(n^3)\),需要优化。我们注意到 \([l_p,r_p]\subseteq [j,i]\) 貌似属于一种二维数点问题。我们借用扫描线的思想,将所有区间按 \(r_p\) 排序,每遍历到一个 \(i\),就把所有 \(r_p=i\) 的区间在数据结构的 \(l_p\) 处加上 \(v_p\) 的权值。

这样,时间复杂度就被优化为 \(O(n^2\log n)\),仍需进一步优化。

注意到,状态转移方程可以分解为和 \(i\) 有关的部分(\(-num[i]\times d\))以及和 \(j\) 有关的部分(\(g_j+(nun[j]-1)\times d+\sum_p v_p\))。其中后者的 $ \sum_p v_p$ 不易维护。但我们注意到,每次扫描线处理一个区间 \([l_p,r_p]\) 只对 \(j\le l_p\) 有贡献,可以被刻画为一种区间修改。区修+区查最大值 考虑线段树。

我们使用线段树维护 \(g_j+(num[j]-1)\times d+\sum_p v_p\) 的区间最值:

  • 新遍历到一个 \(i\),需要处理 \(\max\{\}\) 中新产生的的一项。我们分讨求出 \(g_i\),并将 \(g_i+(num[i]-1)\times d\) 单点修改到线段树的 \(i\) 位置。
  • 每次处理一个区间 \([l_p,r_p]\) 就在线段树上给区间 \([1,l_p]\)\(v_p\)
  • 查询时,先用 lower_bound 找到第一个满足 \(num[j]\ge num[i]-k+1\)\(j\),在线段树上查询 \([j,i]\) 的最大值,并加上 \(-num[i]\times d\) 去更新 \(f_i\)

技巧

  • 可以考虑能贡献到查询操作 的 修改操作 需满足什么条件,或能被修改操作 贡献到的 查询操作 需要满足什么条件
  • 区间的包含关系通常可以被刻画为二维数点问题。
代码
#include<iostream>
#include<algorithm>
#define int long long
using namespace std;
const int N = 1E5 + 10;
const int INF = (int)0x3f3f3f3f3f3f3f3f;

struct Range {
    int l, r, v;
    inline bool operator<(const Range &other) const {
        return r < other.r;
    }
};

int tp, T;
int n, m, k, d;
int f[2 * N];
int num[2 * N], nn;
Range a[N];

namespace Seg_T {

    inline int lc(int x) { return x << 1; }
    inline int rc(int x) { return x << 1 | 1; }

    int mx[8 * N], tag[8 * N];

    inline void push_up(int p) {
        mx[p] = max(mx[lc(p)], mx[rc(p)]);
    }
    inline void move_tag(int p, int tg) {
        mx[p] += tg;
        tag[p] += tg;
    }
    inline void push_down(int p) {
        if(!tag[p]) return;
        move_tag(lc(p), tag[p]);
        move_tag(rc(p), tag[p]);
        tag[p] = 0;
    }

    void build(int p, int l, int r) {
        if(l == r) {
            mx[p] = (num[l] - 1) * d;
            return;
        }
        int mid = (l + r) >> 1;
        tag[p] = mx[p] = 0;
        build(lc(p), l, mid);
        build(rc(p), mid + 1, r);
        push_up(p);
    }
    void add(int p, int l, int r, int ql, int qr, int v) {
        if(ql <= l && r <= qr) {
            move_tag(p, v);
            return;
        }
        push_down(p);
        int mid = (l + r) >> 1;
        if(mid >= ql) add(lc(p), l, mid, ql, qr, v);
        if(mid < qr) add(rc(p), mid + 1, r, ql, qr, v);
        push_up(p);
    }
    void modify(int p, int l, int r, int q, int v) {
        if(l == r) {
            mx[p] = max(mx[p], v);
            return;
        }
        push_down(p);
        int mid = (l + r) >> 1;
        if(mid >= q) modify(lc(p), l, mid, q, v);
        else modify(rc(p), mid + 1, r, q, v);
        push_up(p);
    }
    int query(int p, int l, int r, int ql, int qr) {
        if(ql <= l && r <= qr) {
            return mx[p];
        }
        push_down(p);
        int mid = (l + r) >> 1;
        if(mid >= qr) return query(lc(p), l, mid, ql, qr);
        if(mid < ql) return query(rc(p), mid + 1, r, ql, qr);
        return max(query(lc(p), l, mid, ql, qr), query(rc(p), mid + 1, r, ql, qr));
    }

}

signed main() {

    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    cin >> tp >> T;
    while(T--) {
        cin >> n >> m >> k >> d;
        nn = 0;
        for(int i = 1; i <= m; i++) {
            int x, y, v;
            cin >> x >> y >> v;
            a[i] = {x - y + 1, x, v};
            num[++nn] = a[i].l;
            num[++nn] = a[i].r;
        }
        sort(a + 1, a + 1 + m);
        sort(num + 1, num + 1 + nn);
        nn = unique(num + 1, num + 1 + nn) - (num + 1);
        for(int i = 1; i <= m; i++) {
            a[i].l = lower_bound(num + 1, num + 1 + nn, a[i].l) - num;
            a[i].r = lower_bound(num + 1, num + 1 + nn, a[i].r) - num;
        }
        f[0] = 0;
        Seg_T::build(1, 1, nn + 5);
        for(int i = 1, j = 1; i <= nn; i++) {
            // 处理新的一项
            if(i != 1) {
                if(num[i] == num[i - 1] + 1) {
                    if(i > 2) Seg_T::modify(1, 1, nn + 5, i, f[i - 2] + (num[i] - 1) * d);
                } else {
                    Seg_T::modify(1, 1, nn + 5, i, f[i - 1] + (num[i] - 1) * d);
                }
            }
            // 扫描线
            while(j <= m && a[j].r == i) {
                Seg_T::add(1, 1, nn + 5, 1, a[j].l, a[j].v);
                j++;
            }
            // 更新 f[i]
            int pre = lower_bound(num + 1, num + 1 + nn, num[i] - k + 1) - num;
            f[i] = max(f[i - 1], Seg_T::query(1, 1, nn + 5, pre, i) - num[i] * d);
        }
        cout << f[nn] << '\n';
    }

    return 0;
}

P2605 [ZJOI2010] 基站选址

和天天爱打卡很相似,只是需要给 dp 数组多加一个维度 \(k\)。在本题中必须要把 \(k\) 放到外层循环,才能在线段树上把这一维压掉。否则就需要开 \(k\) 棵线段树,会 MLE。

代码
#include<iostream>
#include<cstring>
#include<algorithm>
#define int long long
using namespace std;
const int N = 2E4 + 10;
const int INF = (int)0x3f3f3f3f3f3f3f3f;

struct Range {
    int l, r, v;
    inline bool operator<(const Range &other) const {
        return r < other.r;
    }
} a[N];

namespace SegT {
    int mx[4 * N], tag[4 * N];
    inline int lc(int x) { return x << 1; }
    inline int rc(int x) { return x << 1 | 1; }
    inline void push_up(int p) {
        mx[p] = max(mx[lc(p)], mx[rc(p)]);
    }
    inline void move_tag(int p, int tg) {
        mx[p] += tg;
        tag[p] += tg;
    }
    inline void push_down(int p) {
        if(!tag[p]) return;
        move_tag(lc(p), tag[p]);
        move_tag(rc(p), tag[p]);
        tag[p] = 0;
    }
    void add(int p, int l, int r, int ql, int qr, int v) {
        if(ql <= l && r <= qr) {
            move_tag(p, v);
            return;
        }
        push_down(p);
        int mid = (l + r) >> 1;
        if(mid >= ql) add(lc(p), l, mid, ql, qr, v);
        if(mid < qr) add(rc(p), mid + 1, r, ql, qr, v);
        push_up(p);
    }
    void modify(int p, int l, int r, int q, int v) {
        if(l == r) {
            mx[p] = v;
            return;
        }
        push_down(p);
        int mid = (l + r) >> 1;
        if(mid >= q) modify(lc(p), l, mid, q, v);
        else modify(rc(p), mid + 1, r, q, v);
        push_up(p);
    }
    int query(int p, int l, int r, int ql, int qr) {
        if(ql <= l && r <= qr) {
            return mx[p];
        }
        push_down(p);
        int mid = (l + r) >> 1;
        if(mid >= qr) return query(lc(p), l, mid, ql, qr);
        if(mid < ql) return query(rc(p), mid + 1, r, ql, qr);
        return max(query(lc(p), l, mid, ql, qr), query(rc(p), mid + 1, r, ql, qr));
    }
}

int n, k;
int p[N], c[N], s[N], w[N], sc[N];
int f[N];

signed main() {

    cin >> n >> k;
    for(int i = 2; i <= n; i++) {
        cin >> p[i];
    }

    for(int i = 1; i <= n; i++) {
        cin >> c[i];
        sc[i] = sc[i - 1] + c[i];
    }
    for(int i = 1; i <= n; i++) {
        cin >> s[i];
    }
    for(int i = 1; i <= n; i++) {
        cin >> w[i];
    }

    for(int i = 1; i <= n; i++) {
        a[i].l = lower_bound(p + 1, p + 1 + n, p[i] - s[i]) - p;
        a[i].r = upper_bound(p + 1, p + 1 + n, p[i] + s[i]) - p - 1;
        a[i].v = w[i];
    }

    sort(a + 1, a + 1 + n);

    f[0] = 0;
    SegT::modify(1, 1, n, 2, -sc[1]);
    for(int i = 1, j = 1; i <= n; i++) {
        while(j <= n && a[j].r == i) {
            SegT::add(1, 1, n, 1, a[j].l, -a[j].v);
            j++;
        }
        f[i] = max(f[i - 1], SegT::query(1, 1, n, 1, i) + sc[i]);
        if(i < n - 1) SegT::modify(1, 1, n, i + 2, f[i] - sc[i + 1]);
    }

    cout << sc[n] - f[n] << endl;

    return 0;
}

线段树维护矩阵乘法

如果 DP 的转移过程可以被刻画为(普通 / 广义)矩阵乘法,得益于矩阵乘法具有结合律,可以使用线段树维护区间矩阵乘法的结果,从而做到 \(O(\log n)\) 查询区间答案

CF750E New Year and Old Subsequence

题目大意

定义一个数字串满足性质 nice 当且仅当:该串包含子序列 \(2017\),且不包含子序列 \(2016\)

定义一个数字串的 ugliness 为:该串至少删去几个字符,可以使得剩余串满足性质 nice;如果该串没有满足性质 nice 的子序列,则该串的 ugliness-1

给定一个长度为 \(n\) 的数字串 \(t\),和 \(q\) 次询问,每次询问给定一个区间 \([l,r]\),你需要回答 ugliness(t[l,r])

\(1\le n,q\le 2\times 10^5\)

考虑一个朴素的 DP。设 \(f_{0/1/2/3/4}\) 表示已经匹配出了 \(\emptyset/2/20/201/2017\),且不包含 \(2016\),至少需要删去几个字符。朴素的转移:

\[ \begin{align*} f'_0&=f_0+[t_i=2]\\ f'_1&=\min(f_1+[t_i=0],&f_0+[t_i\ne 2]\times \infty)\\ f'_2&=\min(f_2+[t_i=1],&f_1+[t_i\ne 0]\times \infty)\\ f'_3&=\min(f_3+[t_i=7]+[t_i=6],&f_2+[t_i\ne 1]\times \infty)\\ f'_4&=\min(f_4+[t_i=6],&f_3+[t_i\ne 7]\times \infty) \end{align*} \]

我们希望求出区间的答案。因此考虑将转移刻画为矩阵乘法,然后使用线段树解决。我们记状态矩阵

\[ \begin{bmatrix} f_0&f_1&f_2&f_3&f_4 \end{bmatrix} \]

容易写出转移矩阵:

\[ \left[ \begin{array}{lllll} [t_i=2]& [t_i\ne 2]\times \infty& & & &\\ & [t_i=0]& [t_i\ne 0]\times \infty& & &\\ & & [t_i=1]& [t_i\ne 1]\times \infty& &\\ & & & [t_i=7]+[t_i=6]& [t_i\ne 7]\times \infty& &\\ & & & & [t_i=6] \end{array} \right] \]

注意,因为转移的过程主要使用加法和 \(\min\),因此此处的矩阵乘法是指 加法-\(\min\)广义矩阵乘法。空白部分均为 \(+\infty\)

我们预处理出数字串每个位置所对应的矩阵,建立一棵线段树维护区间矩阵乘法的结果:

Matrix 结构体
struct Matrix {
    int a[5][5];
    inline Matrix() {
        memset(a, 0x3f, sizeof(a));
    }
    inline Matrix(int x) {
        memset(a, 0x3f, sizeof(a));
        a[0][0] = (x == 2);
        a[0][1] = (x != 2) * INF;
        a[1][1] = (x == 0);
        a[1][2] = (x != 0) * INF;
        a[2][2] = (x == 1);
        a[2][3] = (x != 1) * INF;
        a[3][3] = (x == 6) + (x == 7);
        a[3][4] = (x != 7) * INF;
        a[4][4] = (x == 6);
    }
    inline int* operator[](int index) {
        return a[index];
    }
    inline const int* operator[](int index) const {
        return a[index];
    }
};
建树
inline void mul(const Matrix &a, const Matrix &b, Matrix &res) {
    for(int i = 0; i < 5; i++) {
        for(int j = 0; j < 5; j++) {
            res[i][j] = INF;
            for(int k = 0; k < 5; k++) {
                res[i][j] = min(res[i][j], a[i][k] + b[k][j]);
            }
        }
    }
}

int a[N];
Matrix tr[4 * N];

void build(int p, int l, int r) {
    if(l == r) {
        tr[p] = (Matrix){a[l]};
        return;
    }
    int mid = (l + r) >> 1;
    build(lc(p), l, mid);
    build(rc(p), mid + 1, r);
    mul(tr[lc(p)], tr[rc(p)], tr[p]);
}

建树时间复杂度 \(O(5^3n)\)

查询时,我们不必让线段树返回整个区间对应的转移矩阵。我们可以传入一个初始的状态矩阵:

\[ res= \begin{bmatrix} 0& \infty& \infty& \infty& \infty \end{bmatrix} \]

然后让线段树把每段对应的转移矩阵 \(op_i\) 按顺序乘到 \(res\) 上(\(res=res\times op_i\)),最后返回 \(res\)。这样做可以避免两个 \(5\times 5\) 的矩阵直接相乘,而是让 \(1\times 5\)\(5\times 5\) 的矩阵相乘,从而将单次查询的时间复杂度降低到 \(O(5^2\log n)\)

查询
inline vector<int> mul(const vector<int> &a, const Matrix &b) {
    vector<int> res(5, INF);
    for(int i = 0; i < 5; i++) {
        for(int j = 0; j < 5; j++) {
            res[i] = min(res[i], a[j] + b[j][i]);
        }
    }
    return res;
}

void query(int p, int l, int r, int ql, int qr, vector<int> &res) {
    if(ql <= l && r <= qr) {
        res = mul(res, tr[p]);
        return;
    }
    int mid = (l + r) >> 1;
    if(mid >= ql) query(lc(p), l, mid, ql, qr, res);
    if(mid < qr) query(rc(p), mid + 1, r, ql, qr, res);
}

查询操作的技巧

维护矩阵乘法时,查询操作不返回一个完整的转移矩阵,而是返回一个一维的向量,这时一种很常见的优化。有些场景的时间复杂度高度依赖于这种优化。

代码
#include<iostream>
#include<cstring>
#include<vector>
using namespace std;
const int N = 2E5 + 10;
const int INF = (int)0x3f3f3f3f3f3f3f3f;

inline int lc(int x) { return x << 1; };
inline int rc(int x) { return x << 1 | 1; }

struct Matrix {
    int a[5][5];
    inline Matrix() {
        memset(a, 0x3f, sizeof(a));
    }
    inline Matrix(int x) {
        memset(a, 0x3f, sizeof(a));
        a[0][0] = (x == 2);
        a[0][1] = (x != 2) * INF;
        a[1][1] = (x == 0);
        a[1][2] = (x != 0) * INF;
        a[2][2] = (x == 1);
        a[2][3] = (x != 1) * INF;
        a[3][3] = (x == 6) + (x == 7);
        a[3][4] = (x != 7) * INF;
        a[4][4] = (x == 6);
    }
    inline int* operator[](int index) {
        return a[index];
    }
    inline const int* operator[](int index) const {
        return a[index];
    }
};

inline void mul(const Matrix &a, const Matrix &b, Matrix &res) {
    for(int i = 0; i < 5; i++) {
        for(int j = 0; j < 5; j++) {
            res[i][j] = INF;
            for(int k = 0; k < 5; k++) {
                res[i][j] = min(res[i][j], a[i][k] + b[k][j]);
            }
        }
    }
}

inline vector<int> mul(const vector<int> &a, const Matrix &b) {
    vector<int> res(5, INF);
    for(int i = 0; i < 5; i++) {
        for(int j = 0; j < 5; j++) {
            res[i] = min(res[i], a[j] + b[j][i]);
        }
    }
    return res;
}

int n, q;
int a[N];
Matrix tr[4 * N];

void build(int p, int l, int r) {
    if(l == r) {
        tr[p] = (Matrix){a[l]};
        return;
    }
    int mid = (l + r) >> 1;
    build(lc(p), l, mid);
    build(rc(p), mid + 1, r);
    mul(tr[lc(p)], tr[rc(p)], tr[p]);
}

void query(int p, int l, int r, int ql, int qr, vector<int> &res) {
    if(ql <= l && r <= qr) {
        res = mul(res, tr[p]);
        return;
    }
    int mid = (l + r) >> 1;
    if(mid >= ql) query(lc(p), l, mid, ql, qr, res);
    if(mid < qr) query(rc(p), mid + 1, r, ql, qr, res);
}

int main() {

    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    cin >> n >> q;
    for(int i = 1; i <= n; i++) {
        char c;
        cin >> c;
        a[i] = c - '0';
    }

    build(1, 1, n);

    while(q--) {
        int l, r;
        cin >> l >> r;
        vector<int> res({0, INF, INF, INF, INF});
        query(1, 1, n, l, r, res);
        if(res[4] > N) cout << -1 << '\n';
        else cout << res[4] << '\n';
    }

    return 0;
}