250423 D4 模拟赛 T3 题解
题意
给定一个长度为 \(n\) 的序列 \(a_1,a_2,\cdots,a_n\) 和一个正整数 \(C\)。你需要选择 \(a\) 的一个子序列 \(S\),最大化
\[
\left(\sum_{1\le l\le r\le n}[\forall i\in [l,r],\ i\in S]\right)C-\sum_{i\in S}a_i
\]
有 \(m\) 次询问,每次询问给定 \(x,y\),表示将 \(a_x\) 修改为 \(y\) 之后的答案。询问之间相互独立。
题解
对于子问题,我们设 \(f_i\) 表示前缀 \(i\) 且钦定不选 \(i\) 的答案,显然可以使用斜率优化 \(O(n)\) 求出答案。考虑如何带修。
我们考虑 \(a_x=y\) 时最优解的两种情况:\(a_x\in S\) 或 \(a_x\notin S\)。如果我们可以对每个 \(x\) 都求出:在初始情况下,钦定选择 \(x\) 的最优解和钦定不选 \(x\) 的最优解,由于其他元素的权值没有发生变化,可以立即得出询问的答案。
技巧
对于询问之间相互独立的情况,我们可以对每个元素 \(x\) 都求出钦定选择和钦定不选时的最优解,然后我们可以直接分讨修改操作之后是否选择 \(x\) 元素。
考虑钦定不选 \(x\) 的情况。我们可以再反着跑一遍斜率优化 dp,得到后缀的答案 \(g_x\);那么 \(f_x+g_x\) 就是钦定不选 \(x\) 的答案。
考虑钦定选择 \(x\) 的情况,记其答案为 \(h_x\)。我们有如下转移:
\[
h_{i+1\sim j-1}\leftarrow f_i+g_j+\frac{c}{2}(j-i-1)(j-i)-(s[j-1]-s[i])
\]
可以直接 \(O(n^2)\) 实现。由于我们固定了区间的一个端点 \(i\),再枚举另一个端点 \(j\) 就会产生很高的时间消耗,因为对于靠近 \(j\) 一侧的 \(h_x\),我们不能保证 \(i\) 总是最优的,导致被重复更新了很多次。
但是如果我们去掉 \(i\) 的限制,虽然保证了 \(i,j\) 转移的最优性,但是不能涵盖全部的 \(h_x\)。
考虑分治。在本层递归中,我们考虑 \(i\in [l,mid],\ j\in [mid+1,r]\) 的区间。我们可以在推进右端点 \(j\) 的同时保证 \(i\) 总是 \([l,mid]\) 中最优的一个。然而,这一次我们只能更新 \(h_{mid+1\sim r-1}\),左半区间的最优性和全面性无法保证。因此我们再反着跑一遍,更新左半区间的 \(h\) 即可。
时间复杂度 \(O(n\log n)\)。
注意函数冲突问题
有多个斜率优化时,注意区分不同的 Y
和 K
。注意它们共同使用的资源是否冲突。
AC 代码
| #include<iostream>
#include<algorithm>
#include<cassert>
#define ld long double
#define int long long
using namespace std;
const int N = 3e5 + 10;
const int INF = 0x003f3f3f3f3f3f3f;
const ld eps = 1e-7;
int n, m, c;
int a[N], s[N];
int f[N], g[N], h[N];
int sta[N], top;
namespace w1 {
inline int X(int i) { return i; }
inline ld Y1(int i) { return f[i] + s[i] + (ld)c / 2 * (i * i + i); }
inline ld K1(int i1, int i2) { return (ld)(Y1(i2) - Y1(i1)) / (X(i2) - X(i1)); }
void calc1() {
for(int i = 1; i <= n + 1; i++) s[i] = s[i - 1] + a[i];
top = 0;
sta[++top] = 0;
for(int i = 1; i <= n + 1; i++) {
while(top > 1 && K1(sta[top - 1], sta[top]) <= c * i) --top;
int j = sta[top];
f[i] = f[j] - (s[i - 1] - s[j]) + (i - j - 1) * (i - j) / 2 * c;
while(top > 1 && K1(sta[top - 1], sta[top]) <= K1(sta[top], i)) --top;
sta[++top] = i;
}
}
inline ld Y2(int i) { return g[i] + s[i] + (ld)c / 2 * (i * i + i); }
inline ld K2(int i1, int i2) { return (ld)(Y2(i2) - Y2(i1)) / (X(i2) - X(i1)); }
void calc2() {
reverse(a + 1, a + 1 + n);
for(int i = 1; i <= n + 1; i++) s[i] = s[i - 1] + a[i];
top = 0;
sta[++top] = 0;
for(int i = 1; i <= n + 1; i++) {
while(top > 1 && K2(sta[top - 1], sta[top]) <= c * i) --top;
int j = sta[top];
g[i] = g[j] - (s[i - 1] - s[j]) + (i - j - 1) * (i - j) / 2 * c;
while(top > 1 && K2(sta[top - 1], sta[top]) <= K2(sta[top], i)) --top;
sta[++top] = i;
}
reverse(g, g + 2 + n);
reverse(a + 1, a + 1 + n);
}
}
int t[N];
inline int X(int i) { return i; }
inline ld Y1(int i) { return g[i] - s[i - 1] + ((ld)c / 2) * (i * i - i); }
inline ld K1(int i1, int i2) { return (ld)(Y1(i2) - Y1(i1)) / (X(i2) - X(i1)); }
inline ld Y2(int i) { return f[i] + s[i] + ((ld)c / 2) * (i * i + i); }
inline ld K2(int i1, int i2) { return (ld)(Y2(i2) - Y2(i1)) / (X(i2) - X(i1)); }
void solve(int l, int r) {
if(r - l <= 1) {
return;
}
int mid = (l + r) >> 1;
top = 0;
sta[++top] = r;
for(int i = r - 1; i >= mid + 1; i--) {
while(top > 1 && K1(i, sta[top]) <= K1(sta[top], sta[top - 1])) --top;
sta[++top] = i;
}
for(int i = mid - 1; i >= l; i--) {
while(top > 1 && K1(sta[top], sta[top - 1]) >= c * i) --top;
int j = sta[top];
t[i + 1] = f[i] + g[j] - (s[j - 1] - s[i]) + (j - i - 1) * (j - i) / 2 * c;
}
for(int i = l + 1; i <= mid; i++) {
t[i] = max(t[i], t[i - 1]);
h[i] = max(h[i], t[i]);
}
for(int i = l + 1; i <= mid; i++) t[i] = -INF;
top = 0;
sta[++top] = l;
for(int i = l + 1; i <= mid; i++) {
while(top > 1 && K2(sta[top - 1], sta[top]) <= K2(sta[top], i)) --top;
sta[++top] = i;
}
for(int i = mid + 2; i <= r; i++) {
while(top > 1 && K2(sta[top - 1], sta[top]) <= c * i) --top;
int j = sta[top];
t[i - 1] = f[j] + g[i] - (s[i - 1] - s[j]) + (i - j - 1) * (i - j) / 2 * c;
}
for(int i = r - 1; i >= mid + 1; i--) {
t[i] = max(t[i], t[i + 1]);
h[i] = max(h[i], t[i]);
}
for(int i = r - 1; i >= mid + 1; i--) t[i] = -INF;
solve(l, mid);
solve(mid + 1, r);
}
// #define O_J
signed main() {
#ifdef O_J
freopen("sequence.in", "r", stdin);
freopen("sequence.out", "w", stdout);
#endif
cin >> n >> m >> c;
for(int i = 1; i <= n; i++) cin >> a[i];
for(int i = 0; i <= n + 1; i++) h[i] = -INF;
for(int i = 0; i <= n + 1; i++) t[i] = -INF;
w1::calc1();
w1::calc2();
cout << f[n + 1] << '\n';
for(int i = 1; i <= n; i++) s[i] = s[i - 1] + a[i];
solve(0, n + 1);
for(int i = 1; i <= m; i++) {
int x, y;
cin >> x >> y;
cout << max(f[x] + g[x], h[x] + a[x] - y) << '\n';
}
return 0;
}
|