题目大意
有 \(n\) 个数字(\(n\le 50\)),第 \(i\) 个数字为 \(w_i\),每次操作你可以选择一段连续的区间 \([l,r]\),以
\[
a+b(\max_{i\in[l,r]}\{w_i\}-\min_{i\in[l,r]}\{w_i\})
\]
的代价将此区间内的所有数字删除,删除后该区间的左右两边即变为相邻。问将所有数字都删完的最小代价是多少。
我们注意到 \(n\) 非常小,考虑 DP。我们设 \(f_{i,j}\) 表示将区间 \([i,j]\) 里的所有数字都删除的最小代价。然而我们发现,对于一段区间,我们可以先删除区间中的若干部分,再一次把剩下的都删除。这种情况无法仅通过 \(f\) 进行简单的转移。
为了处理这种情况,我们改变我们原先的状态,在 \([i,j]\) 删除之后仍然保留部分元素供后面的区间使用,然后使用新状态转移到 \(f\) 状态。
我们考虑区间 \([i,j-1]\) 向 \([i,j]\) 的转移。对于 \(w_j\),它有几种情况:
- 自己单独删除;
- 自己和前面的区间 \([i,j-1]\) 中的一些元素一起删除(在区间内自己解决 \(j\) 的问题);
- 先保留,以后和区间 \(j\) 右边的元素一起删除;
为了方便讨论和计算,将第一种情况归于第二种情况,即 \([i,j-1]\) 中没有元素和 \(j\) 一起删除。
我们先进行一些简单的分析:
- 对于第二种情况,设 \([i,j-1]\) 中和 \(j\) 一起删除的所有元素中,最靠左的为 \(k\),此情况就等价于将 \([k,j]\) 区间内的所有元素全部删除(\(f_{k,j}\))。
我们现在想要知道,如何记录保留了哪些元素,以及如何对保留的元素进行转移。
我们注意到,删除区间的代价仅和区间的两个最值有关。我们可以将最大值和最小值纳入状态。具体的,我们设 \(g_{i,j,mn,mx}\) 表示区间 \([i,j]\) 经过删除之后剩下了一些数字,最小的是 \(mn\),最大的是 \(mx\)。
具体的,对于第二种情况:
\[
g_{i,j,mn,mx}=\min_{k\in[i,j-1]}\{g_{i,k,mn,mx}+f_{k+1,j}\}
\]
这样转移,我们就可以无需考虑 \([i,k]\) 和 \([k+1,j]\) 内部是如何处理的,只需要将它们拼接即可。
对于第三种情况:
\[
\begin{cases}
\min\{g_{i,j-1,mn,mx}\}\rightarrow g_{i,j,w_j,mx} ,&w_j \in [1,mn]\\
\min\{g_{i,j-1,mn,mx}\}\rightarrow g_{i,j,mn,w_j} ,&w_j \in [mx,nc]\\
\min\{g_{i,j-1,mn,mx}\}\rightarrow g_{i,j,mn,mx} ,&w_j \in [mn,mx]
\end{cases}
\]
通过这种状态的设置,可以容易地写出 \(g\) 向 \(f\) 的转移:
\[
f_{i,j}=\max_{mn,mx\in[1,nc]}\{g_{i,j,mn,mx}\}
\]
代码
| #include<iostream>
#include<cstring>
#include<unordered_map>
#include<algorithm>
using namespace std;
const int N = 55;
const int INF = 0x3f3f3f3f;
int n, a, b;
int w[N];
unordered_map<int, int> mp;
int num[N], nc;
int f[N][N], g[N][N][N][N];
int main() {
cin >> n >> a >> b;
for(int i = 1; i <= n; i++) {
cin >> w[i];
num[++nc] = w[i];
}
sort(num + 1, num + 1 + nc);
nc = unique(num + 1, num + 1 + nc) - (num + 1);
for(int i = 1; i <= nc; i++) {
mp[num[i]] = i;
}
for(int i = 1; i <= n; i++) {
w[i] = mp[w[i]];
}
memset(f, 0x3f, sizeof(f));
memset(g, 0x3f, sizeof(g));
for(int i = 1; i <= n; i++) {
for(int j = i + 1; j <= n; j++) {
f[j][i] = 0;
}
}
for(int i = 1; i <= n; i++) {
g[i][i][w[i]][w[i]] = 0;
f[i][i] = g[i][i][0][0] = a;
}
for(int len = 2; len <= n; len++) {
for(int i = 1; i + len - 1 <= n; i++) {
int j = i + len - 1;
for(int mn = w[j]; mn <= nc; mn++) {
for(int mx = mn; mx <= nc; mx++) {
g[i][j][w[j]][mx] = min(g[i][j][w[j]][mx], g[i][j - 1][mn][mx]);
}
}
for(int mx = w[j]; mx >= 1; mx--) {
for(int mn = mx; mn >= 1; mn--) {
g[i][j][mn][w[j]] = min(g[i][j][mn][w[j]], g[i][j - 1][mn][mx]);
}
}
for(int mn = 1; mn <= w[j]; mn++) {
for(int mx = w[j]; mx <= nc; mx++) {
g[i][j][mn][mx] = min(g[i][j][mn][mx], g[i][j - 1][mn][mx]);
}
}
for(int k = i + 1; k <= j; k++) {
for(int mn = 1; mn <= nc; mn++) {
for(int mx = mn; mx <= nc; mx++) {
g[i][j][mn][mx] = min(g[i][j][mn][mx], g[i][k - 1][mn][mx] + f[k][j]);
}
}
}
for(int mn = 1; mn <= nc; mn++) {
for(int mx = mn; mx <= nc; mx++) {
f[i][j] = min(f[i][j], g[i][j][mn][mx] + a + b * (num[mx] - num[mn]) * (num[mx] - num[mn]));
}
}
}
}
cout << f[1][n] << endl;
return 0;
}
|