跳转至

250423 D4 模拟赛 T3 题解

题意

给定一个长度为 \(n\) 的序列 \(a_1,a_2,\cdots,a_n\) 和一个正整数 \(C\)。你需要选择 \(a\) 的一个子序列 \(S\),最大化

\[ \left(\sum_{1\le l\le r\le n}[\forall i\in [l,r],\ i\in S]\right)C-\sum_{i\in S}a_i \]

\(m\) 次询问,每次询问给定 \(x,y\),表示将 \(a_x\) 修改为 \(y\) 之后的答案。询问之间相互独立。

题解

对于子问题,我们设 \(f_i\) 表示前缀 \(i\) 且钦定不选 \(i\) 的答案,显然可以使用斜率优化 \(O(n)\) 求出答案。考虑如何带修。

我们考虑 \(a_x=y\) 时最优解的两种情况:\(a_x\in S\)\(a_x\notin S\)。如果我们可以对每个 \(x\) 都求出:在初始情况下,钦定选择 \(x\) 的最优解和钦定不选 \(x\) 的最优解,由于其他元素的权值没有发生变化,可以立即得出询问的答案。

技巧

对于询问之间相互独立的情况,我们可以对每个元素 \(x\) 都求出钦定选择和钦定不选时的最优解,然后我们可以直接分讨修改操作之后是否选择 \(x\) 元素。

考虑钦定不选 \(x\) 的情况。我们可以再反着跑一遍斜率优化 dp,得到后缀的答案 \(g_x\);那么 \(f_x+g_x\) 就是钦定不选 \(x\) 的答案。

考虑钦定选择 \(x\) 的情况,记其答案为 \(h_x\)。我们有如下转移:

\[ h_{i+1\sim j-1}\leftarrow f_i+g_j+\frac{c}{2}(j-i-1)(j-i)-(s[j-1]-s[i]) \]

可以直接 \(O(n^2)\) 实现。由于我们固定了区间的一个端点 \(i\),再枚举另一个端点 \(j\) 就会产生很高的时间消耗,因为对于靠近 \(j\) 一侧的 \(h_x\),我们不能保证 \(i\) 总是最优的,导致被重复更新了很多次。

但是如果我们去掉 \(i\) 的限制,虽然保证了 \(i,j\) 转移的最优性,但是不能涵盖全部的 \(h_x\)

考虑分治。在本层递归中,我们考虑 \(i\in [l,mid],\ j\in [mid+1,r]\) 的区间。我们可以在推进右端点 \(j\) 的同时保证 \(i\) 总是 \([l,mid]\) 中最优的一个。然而,这一次我们只能更新 \(h_{mid+1\sim r-1}\),左半区间的最优性和全面性无法保证。因此我们再反着跑一遍,更新左半区间的 \(h\) 即可。

时间复杂度 \(O(n\log n)\)

注意函数冲突问题

有多个斜率优化时,注意区分不同的 YK。注意它们共同使用的资源是否冲突。

AC 代码
#include<iostream>
#include<algorithm>
#include<cassert>
#define ld long double
#define int long long
using namespace std;
const int N = 3e5 + 10;
const int INF = 0x003f3f3f3f3f3f3f;
const ld eps = 1e-7;

int n, m, c;
int a[N], s[N];

int f[N], g[N], h[N];
int sta[N], top;

namespace w1 {

    inline int X(int i) { return i; }
    inline ld Y1(int i) { return f[i] + s[i] + (ld)c / 2 * (i * i + i); }

    inline ld K1(int i1, int i2) { return (ld)(Y1(i2) - Y1(i1)) / (X(i2) - X(i1)); }

    void calc1() {
        for(int i = 1; i <= n + 1; i++) s[i] = s[i - 1] + a[i];
        top = 0;
        sta[++top] = 0;
        for(int i = 1; i <= n + 1; i++) {
            while(top > 1 && K1(sta[top - 1], sta[top]) <= c * i) --top;
            int j = sta[top];
            f[i] = f[j] - (s[i - 1] - s[j]) + (i - j - 1) * (i - j) / 2 * c;
            while(top > 1 && K1(sta[top - 1], sta[top]) <= K1(sta[top], i)) --top;
            sta[++top] = i;
        }
    }

    inline ld Y2(int i) { return g[i] + s[i] + (ld)c / 2 * (i * i + i); }
    inline ld K2(int i1, int i2) { return (ld)(Y2(i2) - Y2(i1)) / (X(i2) - X(i1)); }

    void calc2() {
        reverse(a + 1, a + 1 + n);
        for(int i = 1; i <= n + 1; i++) s[i] = s[i - 1] + a[i];
        top = 0;
        sta[++top] = 0;
        for(int i = 1; i <= n + 1; i++) {
            while(top > 1 && K2(sta[top - 1], sta[top]) <= c * i) --top;
            int j = sta[top];
            g[i] = g[j] - (s[i - 1] - s[j]) + (i - j - 1) * (i - j) / 2 * c;
            while(top > 1 && K2(sta[top - 1], sta[top]) <= K2(sta[top], i)) --top;
            sta[++top] = i;
        }
        reverse(g, g + 2 + n);
        reverse(a + 1, a + 1 + n);
    }

}

int t[N];

inline int X(int i) { return i; }

inline ld Y1(int i) { return g[i] - s[i - 1] + ((ld)c / 2) * (i * i - i); }
inline ld K1(int i1, int i2) { return (ld)(Y1(i2) - Y1(i1)) / (X(i2) - X(i1)); }

inline ld Y2(int i) { return f[i] + s[i] + ((ld)c / 2) * (i * i + i); }
inline ld K2(int i1, int i2) { return (ld)(Y2(i2) - Y2(i1)) / (X(i2) - X(i1)); }

void solve(int l, int r) {
    if(r - l <= 1) {
        return;
    }
    int mid = (l + r) >> 1;
    top = 0;
    sta[++top] = r;
    for(int i = r - 1; i >= mid + 1; i--) {
        while(top > 1 && K1(i, sta[top]) <= K1(sta[top], sta[top - 1])) --top;
        sta[++top] = i;
    }
    for(int i = mid - 1; i >= l; i--) {
        while(top > 1 && K1(sta[top], sta[top - 1]) >= c * i) --top;
        int j = sta[top];
        t[i + 1] = f[i] + g[j] - (s[j - 1] - s[i]) + (j - i - 1) * (j - i) / 2 * c;
    }
    for(int i = l + 1; i <= mid; i++) {
        t[i] = max(t[i], t[i - 1]);
        h[i] = max(h[i], t[i]);
    }
    for(int i = l + 1; i <= mid; i++) t[i] = -INF;
    top = 0;
    sta[++top] = l;
    for(int i = l + 1; i <= mid; i++) {
        while(top > 1 && K2(sta[top - 1], sta[top]) <= K2(sta[top], i)) --top;
        sta[++top] = i;
    }
    for(int i = mid + 2; i <= r; i++) {
        while(top > 1 && K2(sta[top - 1], sta[top]) <= c * i) --top;
        int j = sta[top];
        t[i - 1] = f[j] + g[i] - (s[i - 1] - s[j]) + (i - j - 1) * (i - j) / 2 * c;
    }
    for(int i = r - 1; i >= mid + 1; i--) {
        t[i] = max(t[i], t[i + 1]);
        h[i] = max(h[i], t[i]);
    }
    for(int i = r - 1; i >= mid + 1; i--) t[i] = -INF;
    solve(l, mid);
    solve(mid + 1, r);
}

// #define O_J

signed main() {

    #ifdef O_J
    freopen("sequence.in", "r", stdin);
    freopen("sequence.out", "w", stdout);
    #endif

    cin >> n >> m >> c;
    for(int i = 1; i <= n; i++) cin >> a[i];

    for(int i = 0; i <= n + 1; i++) h[i] = -INF;
    for(int i = 0; i <= n + 1; i++) t[i] = -INF;

    w1::calc1();
    w1::calc2();

    cout << f[n + 1] << '\n';

    for(int i = 1; i <= n; i++) s[i] = s[i - 1] + a[i];

    solve(0, n + 1);

    for(int i = 1; i <= m; i++) {
        int x, y;
        cin >> x >> y;
        cout << max(f[x] + g[x], h[x] + a[x] - y) << '\n';
    }

    return 0;
}