跳转至

Y52A 最小代价

题目大意

\(n\) 个数字(\(n\le 50\)),第 \(i\) 个数字为 \(w_i\),每次操作你可以选择一段连续的区间 \([l,r]\),以

\[ a+b(\max_{i\in[l,r]}\{w_i\}-\min_{i\in[l,r]}\{w_i\}) \]

的代价将此区间内的所有数字删除,删除后该区间的左右两边即变为相邻。问将所有数字都删完的最小代价是多少。

我们注意到 \(n\) 非常小,考虑 DP。我们设 \(f_{i,j}\) 表示将区间 \([i,j]\) 里的所有数字都删除的最小代价。然而我们发现,对于一段区间,我们可以先删除区间中的若干部分,再一次把剩下的都删除。这种情况无法仅通过 \(f\) 进行简单的转移。

为了处理这种情况,我们改变我们原先的状态,\([i,j]\) 删除之后仍然保留部分元素供后面的区间使用,然后使用新状态转移到 \(f\) 状态。

我们考虑区间 \([i,j-1]\)\([i,j]\) 的转移。对于 \(w_j\),它有几种情况:

  • 自己单独删除;
  • 自己和前面的区间 \([i,j-1]\) 中的一些元素一起删除(在区间内自己解决 \(j\) 的问题);
  • 先保留,以后和区间 \(j\) 右边的元素一起删除;

为了方便讨论和计算,将第一种情况归于第二种情况,即 \([i,j-1]\) 中没有元素和 \(j\) 一起删除。

我们先进行一些简单的分析:

  • 对于第二种情况,设 \([i,j-1]\) 中和 \(j\) 一起删除的所有元素中,最靠左的为 \(k\),此情况就等价于将 \([k,j]\) 区间内的所有元素全部删除(\(f_{k,j}\))。

我们现在想要知道,如何记录保留了哪些元素,以及如何对保留的元素进行转移。

我们注意到,删除区间的代价仅和区间的两个最值有关。我们可以将最大值和最小值纳入状态。具体的,我们设 \(g_{i,j,mn,mx}\) 表示区间 \([i,j]\) 经过删除之后剩下了一些数字,最小的是 \(mn\),最大的是 \(mx\)

具体的,对于第二种情况:

\[ g_{i,j,mn,mx}=\min_{k\in[i,j-1]}\{g_{i,k,mn,mx}+f_{k+1,j}\} \]

这样转移,我们就可以无需考虑 \([i,k]\)\([k+1,j]\) 内部是如何处理的,只需要将它们拼接即可。

对于第三种情况:

\[ \begin{cases} \min\{g_{i,j-1,mn,mx}\}\rightarrow g_{i,j,w_j,mx} ,&w_j \in [1,mn]\\ \min\{g_{i,j-1,mn,mx}\}\rightarrow g_{i,j,mn,w_j} ,&w_j \in [mx,nc]\\ \min\{g_{i,j-1,mn,mx}\}\rightarrow g_{i,j,mn,mx} ,&w_j \in [mn,mx] \end{cases} \]

通过这种状态的设置,可以容易地写出 \(g\)\(f\) 的转移:

\[ f_{i,j}=\max_{mn,mx\in[1,nc]}\{g_{i,j,mn,mx}\} \]
代码
#include<iostream>
#include<cstring>
#include<unordered_map>
#include<algorithm>
using namespace std;
const int N = 55;
const int INF = 0x3f3f3f3f;

int n, a, b;
int w[N];
unordered_map<int, int> mp;
int num[N], nc;

int f[N][N], g[N][N][N][N];

int main() {

    cin >> n >> a >> b;
    for(int i = 1; i <= n; i++) {
        cin >> w[i];
        num[++nc] = w[i];
    }

    sort(num + 1, num + 1 + nc);
    nc = unique(num + 1, num + 1 + nc) - (num + 1);
    for(int i = 1; i <= nc; i++) {
        mp[num[i]] = i;
    }
    for(int i = 1; i <= n; i++) {
        w[i] = mp[w[i]];
    }

    memset(f, 0x3f, sizeof(f));
    memset(g, 0x3f, sizeof(g));

    for(int i = 1; i <= n; i++) {
        for(int j = i + 1; j <= n; j++) {
            f[j][i] = 0;
        }
    }

    for(int i = 1; i <= n; i++) {
        g[i][i][w[i]][w[i]] = 0;
        f[i][i] = g[i][i][0][0] = a;
    }

    for(int len = 2; len <= n; len++) {
        for(int i = 1; i + len - 1 <= n; i++) {
            int j = i + len - 1;
            for(int mn = w[j]; mn <= nc; mn++) {
                for(int mx = mn; mx <= nc; mx++) {
                    g[i][j][w[j]][mx] = min(g[i][j][w[j]][mx], g[i][j - 1][mn][mx]);
                }
            }
            for(int mx = w[j]; mx >= 1; mx--) {
                for(int mn = mx; mn >= 1; mn--) {
                    g[i][j][mn][w[j]] = min(g[i][j][mn][w[j]], g[i][j - 1][mn][mx]);
                }
            }
            for(int mn = 1; mn <= w[j]; mn++) {
                for(int mx = w[j]; mx <= nc; mx++) {
                    g[i][j][mn][mx] = min(g[i][j][mn][mx], g[i][j - 1][mn][mx]);
                }
            }
            for(int k = i + 1; k <= j; k++) {
                for(int mn = 1; mn <= nc; mn++) {
                    for(int mx = mn; mx <= nc; mx++) {
                        g[i][j][mn][mx] = min(g[i][j][mn][mx], g[i][k - 1][mn][mx] + f[k][j]);
                    }
                }
            }
            for(int mn = 1; mn <= nc; mn++) {
                for(int mx = mn; mx <= nc; mx++) {
                    f[i][j] = min(f[i][j], g[i][j][mn][mx] + a + b * (num[mx] - num[mn]) * (num[mx] - num[mn]));
                }
            }
        }
    }

    cout << f[1][n] << endl;

    return 0;
}