线段树优化 DP
线段树优化转移
有时,DP 的转移区间不具有单调性,或者历史 DP 值对现在的 DP值的贡献发生的变化较大,无法使用单调队列分离无关项,就可以使用线段树优化。
当一些产生贡献的项可以被刻画为区间修改时,则可以使用线段树优化。
题目大意
在 \(n\) 天里,你可以选择一些天进行跑步打卡,不得连续跑步超过 \(k\) 天,每次跑步消耗 \(d\) 的能量值。
此外,有 \(m\) 个任务。对于第 \(i\) 个任务,若在第 \(x_i\) 时,你已经连续打卡了 \(y_i\) 天,就会获得 \(v_i\) 的能量值,问 \(n\) 天以后能量值最高是多少。
其中 \(n\le 10^9\),\(m\le 10^5\)。
容易发现:开始 / 停止跑步的时间节点,一定是任务开始或是结束的时间点。这样,我们对跑步的天数进行离散化,时间复杂度中就不包含 \(n\) 了。
考虑 DP,设 \(f_i\) 表示从第一天开始,到第 \(i\) 个时间节点为止,最多可以获得多少能量。我们可以枚举:到第 \(i\) 个时间节点为止,已经连续跑步了多长时间。容易写出状态转移方程:
\[
f_i=\max_{num[j]\ge num[i]-k+1}\{g_j-(num[i]-num[j]+1)\times d+\sum_{p=1}^{m}{\big[[l_p,r_p]\subseteq[j,i]\big]\times v_p}\}
\]
第 \(i\) 个时间节点不跑步的情况:
\[
f_i=f_{i-1}
\]
考虑到时间节点相邻的情况(即:在第 \(j\) 天的前一天不能跑步;如果这样,连续跑步的时间可能超过 \(k\)),记 \(g_j\) 表示满足 \(num[k]<num[j]-1\) 的最大的 \(f_k\)。显然,由于 \(f_i\) 单调递增,
\[
g_j=
\begin{cases}
f_{j-2},&num[j]=num[j-1]+1,\\
f_{j-1},&num[j]>num[j-1]+1
\end{cases}
\]
上面提到的这种暴力转移的时间复杂度达到了 \(O(n^3)\),需要优化。我们注意到 \([l_p,r_p]\subseteq [j,i]\) 貌似属于一种二维数点问题。我们借用扫描线的思想,将所有区间按 \(r_p\) 排序,每遍历到一个 \(i\),就把所有 \(r_p=i\) 的区间在数据结构的 \(l_p\) 处加上 \(v_p\) 的权值。
这样,时间复杂度就被优化为 \(O(n^2\log n)\),仍需进一步优化。
注意到,状态转移方程可以分解为和 \(i\) 有关的部分(\(-num[i]\times d\))以及和 \(j\) 有关的部分(\(g_j+(nun[j]-1)\times d+\sum_p v_p\))。其中后者的 $
\sum_p v_p$ 不易维护。但我们注意到,每次扫描线处理一个区间 \([l_p,r_p]\) 只对 \(j\le l_p\) 有贡献,可以被刻画为一种区间修改。区修+区查最大值 考虑线段树。
我们使用线段树维护 \(g_j+(num[j]-1)\times d+\sum_p v_p\) 的区间最值:
- 新遍历到一个 \(i\),需要处理 \(\max\{\}\) 中新产生的的一项。我们分讨求出 \(g_i\),并将 \(g_i+(num[i]-1)\times d\) 单点修改到线段树的 \(i\) 位置。
- 每次处理一个区间 \([l_p,r_p]\) 就在线段树上给区间 \([1,l_p]\) 加 \(v_p\)。
- 查询时,先用
lower_bound
找到第一个满足 \(num[j]\ge num[i]-k+1\) 的 \(j\),在线段树上查询 \([j,i]\) 的最大值,并加上 \(-num[i]\times d\) 去更新 \(f_i\)。
技巧
- 可以考虑能贡献到查询操作 的 修改操作 需满足什么条件,或能被修改操作 贡献到的 查询操作 需要满足什么条件。
- 区间的包含关系通常可以被刻画为二维数点问题。
代码
| #include<iostream>
#include<algorithm>
#define int long long
using namespace std;
const int N = 1E5 + 10;
const int INF = (int)0x3f3f3f3f3f3f3f3f;
struct Range {
int l, r, v;
inline bool operator<(const Range &other) const {
return r < other.r;
}
};
int tp, T;
int n, m, k, d;
int f[2 * N];
int num[2 * N], nn;
Range a[N];
namespace Seg_T {
inline int lc(int x) { return x << 1; }
inline int rc(int x) { return x << 1 | 1; }
int mx[8 * N], tag[8 * N];
inline void push_up(int p) {
mx[p] = max(mx[lc(p)], mx[rc(p)]);
}
inline void move_tag(int p, int tg) {
mx[p] += tg;
tag[p] += tg;
}
inline void push_down(int p) {
if(!tag[p]) return;
move_tag(lc(p), tag[p]);
move_tag(rc(p), tag[p]);
tag[p] = 0;
}
void build(int p, int l, int r) {
if(l == r) {
mx[p] = (num[l] - 1) * d;
return;
}
int mid = (l + r) >> 1;
tag[p] = mx[p] = 0;
build(lc(p), l, mid);
build(rc(p), mid + 1, r);
push_up(p);
}
void add(int p, int l, int r, int ql, int qr, int v) {
if(ql <= l && r <= qr) {
move_tag(p, v);
return;
}
push_down(p);
int mid = (l + r) >> 1;
if(mid >= ql) add(lc(p), l, mid, ql, qr, v);
if(mid < qr) add(rc(p), mid + 1, r, ql, qr, v);
push_up(p);
}
void modify(int p, int l, int r, int q, int v) {
if(l == r) {
mx[p] = max(mx[p], v);
return;
}
push_down(p);
int mid = (l + r) >> 1;
if(mid >= q) modify(lc(p), l, mid, q, v);
else modify(rc(p), mid + 1, r, q, v);
push_up(p);
}
int query(int p, int l, int r, int ql, int qr) {
if(ql <= l && r <= qr) {
return mx[p];
}
push_down(p);
int mid = (l + r) >> 1;
if(mid >= qr) return query(lc(p), l, mid, ql, qr);
if(mid < ql) return query(rc(p), mid + 1, r, ql, qr);
return max(query(lc(p), l, mid, ql, qr), query(rc(p), mid + 1, r, ql, qr));
}
}
signed main() {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin >> tp >> T;
while(T--) {
cin >> n >> m >> k >> d;
nn = 0;
for(int i = 1; i <= m; i++) {
int x, y, v;
cin >> x >> y >> v;
a[i] = {x - y + 1, x, v};
num[++nn] = a[i].l;
num[++nn] = a[i].r;
}
sort(a + 1, a + 1 + m);
sort(num + 1, num + 1 + nn);
nn = unique(num + 1, num + 1 + nn) - (num + 1);
for(int i = 1; i <= m; i++) {
a[i].l = lower_bound(num + 1, num + 1 + nn, a[i].l) - num;
a[i].r = lower_bound(num + 1, num + 1 + nn, a[i].r) - num;
}
f[0] = 0;
Seg_T::build(1, 1, nn + 5);
for(int i = 1, j = 1; i <= nn; i++) {
// 处理新的一项
if(i != 1) {
if(num[i] == num[i - 1] + 1) {
if(i > 2) Seg_T::modify(1, 1, nn + 5, i, f[i - 2] + (num[i] - 1) * d);
} else {
Seg_T::modify(1, 1, nn + 5, i, f[i - 1] + (num[i] - 1) * d);
}
}
// 扫描线
while(j <= m && a[j].r == i) {
Seg_T::add(1, 1, nn + 5, 1, a[j].l, a[j].v);
j++;
}
// 更新 f[i]
int pre = lower_bound(num + 1, num + 1 + nn, num[i] - k + 1) - num;
f[i] = max(f[i - 1], Seg_T::query(1, 1, nn + 5, pre, i) - num[i] * d);
}
cout << f[nn] << '\n';
}
return 0;
}
|
和天天爱打卡很相似,只是需要给 dp 数组多加一个维度 \(k\)。在本题中必须要把 \(k\) 放到外层循环,才能在线段树上把这一维压掉。否则就需要开 \(k\) 棵线段树,会 MLE。
代码
| #include<iostream>
#include<cstring>
#include<algorithm>
#define int long long
using namespace std;
const int N = 2E4 + 10;
const int INF = (int)0x3f3f3f3f3f3f3f3f;
struct Range {
int l, r, v;
inline bool operator<(const Range &other) const {
return r < other.r;
}
} a[N];
namespace SegT {
int mx[4 * N], tag[4 * N];
inline int lc(int x) { return x << 1; }
inline int rc(int x) { return x << 1 | 1; }
inline void push_up(int p) {
mx[p] = max(mx[lc(p)], mx[rc(p)]);
}
inline void move_tag(int p, int tg) {
mx[p] += tg;
tag[p] += tg;
}
inline void push_down(int p) {
if(!tag[p]) return;
move_tag(lc(p), tag[p]);
move_tag(rc(p), tag[p]);
tag[p] = 0;
}
void add(int p, int l, int r, int ql, int qr, int v) {
if(ql <= l && r <= qr) {
move_tag(p, v);
return;
}
push_down(p);
int mid = (l + r) >> 1;
if(mid >= ql) add(lc(p), l, mid, ql, qr, v);
if(mid < qr) add(rc(p), mid + 1, r, ql, qr, v);
push_up(p);
}
void modify(int p, int l, int r, int q, int v) {
if(l == r) {
mx[p] = v;
return;
}
push_down(p);
int mid = (l + r) >> 1;
if(mid >= q) modify(lc(p), l, mid, q, v);
else modify(rc(p), mid + 1, r, q, v);
push_up(p);
}
int query(int p, int l, int r, int ql, int qr) {
if(ql <= l && r <= qr) {
return mx[p];
}
push_down(p);
int mid = (l + r) >> 1;
if(mid >= qr) return query(lc(p), l, mid, ql, qr);
if(mid < ql) return query(rc(p), mid + 1, r, ql, qr);
return max(query(lc(p), l, mid, ql, qr), query(rc(p), mid + 1, r, ql, qr));
}
}
int n, k;
int p[N], c[N], s[N], w[N], sc[N];
int f[N];
signed main() {
cin >> n >> k;
for(int i = 2; i <= n; i++) {
cin >> p[i];
}
for(int i = 1; i <= n; i++) {
cin >> c[i];
sc[i] = sc[i - 1] + c[i];
}
for(int i = 1; i <= n; i++) {
cin >> s[i];
}
for(int i = 1; i <= n; i++) {
cin >> w[i];
}
for(int i = 1; i <= n; i++) {
a[i].l = lower_bound(p + 1, p + 1 + n, p[i] - s[i]) - p;
a[i].r = upper_bound(p + 1, p + 1 + n, p[i] + s[i]) - p - 1;
a[i].v = w[i];
}
sort(a + 1, a + 1 + n);
f[0] = 0;
SegT::modify(1, 1, n, 2, -sc[1]);
for(int i = 1, j = 1; i <= n; i++) {
while(j <= n && a[j].r == i) {
SegT::add(1, 1, n, 1, a[j].l, -a[j].v);
j++;
}
f[i] = max(f[i - 1], SegT::query(1, 1, n, 1, i) + sc[i]);
if(i < n - 1) SegT::modify(1, 1, n, i + 2, f[i] - sc[i + 1]);
}
cout << sc[n] - f[n] << endl;
return 0;
}
|
线段树维护矩阵乘法
如果 DP 的转移过程可以被刻画为(普通 / 广义)矩阵乘法,得益于矩阵乘法具有结合律,可以使用线段树维护区间矩阵乘法的结果,从而做到 \(O(\log n)\) 查询区间答案。
题目大意
定义一个数字串满足性质 nice
当且仅当:该串包含子序列 \(2017\),且不包含子序列 \(2016\)。
定义一个数字串的 ugliness
为:该串至少删去几个字符,可以使得剩余串满足性质 nice
;如果该串没有满足性质 nice
的子序列,则该串的 ugliness
是 -1
。
给定一个长度为 \(n\) 的数字串 \(t\),和 \(q\) 次询问,每次询问给定一个区间 \([l,r]\),你需要回答 ugliness(t[l,r])
。
\(1\le n,q\le 2\times 10^5\)
考虑一个朴素的 DP。设 \(f_{0/1/2/3/4}\) 表示已经匹配出了 \(\emptyset/2/20/201/2017\),且不包含 \(2016\),至少需要删去几个字符。朴素的转移:
\[
\begin{align*}
f'_0&=f_0+[t_i=2]\\
f'_1&=\min(f_1+[t_i=0],&f_0+[t_i\ne 2]\times \infty)\\
f'_2&=\min(f_2+[t_i=1],&f_1+[t_i\ne 0]\times \infty)\\
f'_3&=\min(f_3+[t_i=7]+[t_i=6],&f_2+[t_i\ne 1]\times \infty)\\
f'_4&=\min(f_4+[t_i=6],&f_3+[t_i\ne 7]\times \infty)
\end{align*}
\]
我们希望求出区间的答案。因此考虑将转移刻画为矩阵乘法,然后使用线段树解决。我们记状态矩阵
\[
\begin{bmatrix}
f_0&f_1&f_2&f_3&f_4
\end{bmatrix}
\]
容易写出转移矩阵:
\[
\left[
\begin{array}{lllll}
[t_i=2]& [t_i\ne 2]\times \infty& & & &\\
& [t_i=0]& [t_i\ne 0]\times \infty& & &\\
& & [t_i=1]& [t_i\ne 1]\times \infty& &\\
& & & [t_i=7]+[t_i=6]& [t_i\ne 7]\times \infty& &\\
& & & & [t_i=6]
\end{array}
\right]
\]
注意,因为转移的过程主要使用加法和 \(\min\),因此此处的矩阵乘法是指 加法-\(\min\) 的广义矩阵乘法。空白部分均为 \(+\infty\)。
我们预处理出数字串每个位置所对应的矩阵,建立一棵线段树维护区间矩阵乘法的结果:
Matrix 结构体
| struct Matrix {
int a[5][5];
inline Matrix() {
memset(a, 0x3f, sizeof(a));
}
inline Matrix(int x) {
memset(a, 0x3f, sizeof(a));
a[0][0] = (x == 2);
a[0][1] = (x != 2) * INF;
a[1][1] = (x == 0);
a[1][2] = (x != 0) * INF;
a[2][2] = (x == 1);
a[2][3] = (x != 1) * INF;
a[3][3] = (x == 6) + (x == 7);
a[3][4] = (x != 7) * INF;
a[4][4] = (x == 6);
}
inline int* operator[](int index) {
return a[index];
}
inline const int* operator[](int index) const {
return a[index];
}
};
|
建树
| inline void mul(const Matrix &a, const Matrix &b, Matrix &res) {
for(int i = 0; i < 5; i++) {
for(int j = 0; j < 5; j++) {
res[i][j] = INF;
for(int k = 0; k < 5; k++) {
res[i][j] = min(res[i][j], a[i][k] + b[k][j]);
}
}
}
}
int a[N];
Matrix tr[4 * N];
void build(int p, int l, int r) {
if(l == r) {
tr[p] = (Matrix){a[l]};
return;
}
int mid = (l + r) >> 1;
build(lc(p), l, mid);
build(rc(p), mid + 1, r);
mul(tr[lc(p)], tr[rc(p)], tr[p]);
}
|
建树时间复杂度 \(O(5^3n)\)。
查询时,我们不必让线段树返回整个区间对应的转移矩阵。我们可以传入一个初始的状态矩阵:
\[
res=
\begin{bmatrix}
0& \infty& \infty& \infty& \infty
\end{bmatrix}
\]
然后让线段树把每段对应的转移矩阵 \(op_i\) 按顺序乘到 \(res\) 上(\(res=res\times op_i\)),最后返回 \(res\)。这样做可以避免两个 \(5\times 5\) 的矩阵直接相乘,而是让 \(1\times 5\) 和 \(5\times 5\) 的矩阵相乘,从而将单次查询的时间复杂度降低到 \(O(5^2\log n)\)。
查询
| inline vector<int> mul(const vector<int> &a, const Matrix &b) {
vector<int> res(5, INF);
for(int i = 0; i < 5; i++) {
for(int j = 0; j < 5; j++) {
res[i] = min(res[i], a[j] + b[j][i]);
}
}
return res;
}
void query(int p, int l, int r, int ql, int qr, vector<int> &res) {
if(ql <= l && r <= qr) {
res = mul(res, tr[p]);
return;
}
int mid = (l + r) >> 1;
if(mid >= ql) query(lc(p), l, mid, ql, qr, res);
if(mid < qr) query(rc(p), mid + 1, r, ql, qr, res);
}
|
查询操作的技巧
维护矩阵乘法时,查询操作不返回一个完整的转移矩阵,而是返回一个一维的向量,这时一种很常见的优化。有些场景的时间复杂度高度依赖于这种优化。
代码
| #include<iostream>
#include<cstring>
#include<vector>
using namespace std;
const int N = 2E5 + 10;
const int INF = (int)0x3f3f3f3f3f3f3f3f;
inline int lc(int x) { return x << 1; };
inline int rc(int x) { return x << 1 | 1; }
struct Matrix {
int a[5][5];
inline Matrix() {
memset(a, 0x3f, sizeof(a));
}
inline Matrix(int x) {
memset(a, 0x3f, sizeof(a));
a[0][0] = (x == 2);
a[0][1] = (x != 2) * INF;
a[1][1] = (x == 0);
a[1][2] = (x != 0) * INF;
a[2][2] = (x == 1);
a[2][3] = (x != 1) * INF;
a[3][3] = (x == 6) + (x == 7);
a[3][4] = (x != 7) * INF;
a[4][4] = (x == 6);
}
inline int* operator[](int index) {
return a[index];
}
inline const int* operator[](int index) const {
return a[index];
}
};
inline void mul(const Matrix &a, const Matrix &b, Matrix &res) {
for(int i = 0; i < 5; i++) {
for(int j = 0; j < 5; j++) {
res[i][j] = INF;
for(int k = 0; k < 5; k++) {
res[i][j] = min(res[i][j], a[i][k] + b[k][j]);
}
}
}
}
inline vector<int> mul(const vector<int> &a, const Matrix &b) {
vector<int> res(5, INF);
for(int i = 0; i < 5; i++) {
for(int j = 0; j < 5; j++) {
res[i] = min(res[i], a[j] + b[j][i]);
}
}
return res;
}
int n, q;
int a[N];
Matrix tr[4 * N];
void build(int p, int l, int r) {
if(l == r) {
tr[p] = (Matrix){a[l]};
return;
}
int mid = (l + r) >> 1;
build(lc(p), l, mid);
build(rc(p), mid + 1, r);
mul(tr[lc(p)], tr[rc(p)], tr[p]);
}
void query(int p, int l, int r, int ql, int qr, vector<int> &res) {
if(ql <= l && r <= qr) {
res = mul(res, tr[p]);
return;
}
int mid = (l + r) >> 1;
if(mid >= ql) query(lc(p), l, mid, ql, qr, res);
if(mid < qr) query(rc(p), mid + 1, r, ql, qr, res);
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin >> n >> q;
for(int i = 1; i <= n; i++) {
char c;
cin >> c;
a[i] = c - '0';
}
build(1, 1, n);
while(q--) {
int l, r;
cin >> l >> r;
vector<int> res({0, INF, INF, INF, INF});
query(1, 1, n, l, r, res);
if(res[4] > N) cout << -1 << '\n';
else cout << res[4] << '\n';
}
return 0;
}
|