四边形不等式优化 DP
四边形不等式
对于定义在 \(\N^2\) 上的二元函数 \(w(x,y)\),若 \(\forall a\le b\le c\le d\)(\(a,b,c,d\in \N\)),都有
\[
w(a,c)+w(b,d)\le w(a,d)+w(b,c)
\]
则称 \(w(x,y)\) 满足四边形不等式。
性质
性质 1:\(w(x,y)\) 满足四边形不等式的充要条件是:\(\forall x<y\),有
\[
w(x,y)+w(x+1,y+1)\le w(x,y+1)+w(x+1,y)\tag{1.1}
\]
推论 1:\(w(x,y)\) 满足四边形不等式的充要条件是 \(w(x,y)\) 的二阶混合差分始终非正。
性质 2:若函数 \(w(x,y)\) 满足四边形不等式,则 \(f(x)+w(x,y)\) 也满足四边形不等式。
性质 3:设 \(h(u)\) 是一个单调增加的凸函数,若函数 \(w(x,y)\) 满足四边形不等式且具有区间包含单调性,则复合函数 \(h(w(x,y))\) 也满足四边形不等式和区间包含单调性。
性质 4:设 \(h(u)\) 是一个凸函数,若函数 \(w(x,y)\) 满足四边形恒等式且具有区间包含单调性,则复合函数 \(h(w(x,y))\) 也满足四边形不等式。
证明
性质 1
考虑归纳法。我们现在要证明:对于所有 \(k_1\) 和 \(x+k_1\le y\) 都有
\[
w(x,y)+w(x+k_1,y+1)\le w(x,y+1)+w(x+k_1,y) \tag{1.2}
\]
然后对于所有 \(k_2\) 再证明
\[
w(x,y)+w(x+k_1,y+k_2)\le w(x,y+k_2)+w(x+k_1,y) \tag{1.3}
\]
即可。
假设 \((1.2)\) 对所有 \(k_1\le p\) 都满足。将 \(x+p\) 代入 \((1.1)\):
\[
w(x+p,y)+w(x+p+1,y+1)\le w(x+p,y+1)+w(x+p+1,y) \tag{1.4}
\]
将 \(k_1=p\) 代入 \((1,2)\),和 \((1.4)\) 相加:
\[
w(x,y)+w(x+p+1,y+1)\le w(x,y+1)+w(x+p+1,y)
\]
\((1.3)\) 式同理归纳即可。
推论 1
将 \((1.1)\) 移项:
\[
w(i+1,j+1)+w(i,j)-w(i,j+1)-w(i+1,j)=\Delta_x\Delta_y w(i,j)
\]
性质 2
在四边形不等式中,左右两侧的 \(f(a)+f(b)\) 可以同时消掉。
性质 3 和性质 4 的证明真的不会了。
决策单调性
对于满足四边形不等式的函数 \(w(x,y)\),考虑如下形式的 dp 转移方程:
\[
f[i]=\min_{j< i}\{w(j,i)\}
\]
朴素的转移为 \(O(n^2)\)。记 \(p[i]\) 为 \(f_i\) 的最优决策点,即 \(f_i=w(p[i],i)\)。
对于
\[
f[i]=\min_{j< i}\{f[j]+w(j,i)\}
\]
的转移方程,根据四边形不等式性质 2,它等价于上面的情况。
性质 1:对于任意 \(j_1< j_2\),存在一个常数 \(i_0\),使得 \(\forall i\ge i_0\),\(w(j_2,i)< w(j_1,i)\);\(\forall i< i_0\),\(w(j_1,i)\le w(j_2,i)\);
推论:\(p[i]\) 单调不降;
证明
考虑两种情况:
- 对于下标 \(i_0\),若 \(w(j_2,i_0)< w(j_1,i_0)\),则 \(\forall i\ge i_0\),\(w(j_2,i)< w(j_1,i)\)(四边形不等式);
- 对于下标 \(i_0\),若 \(w(j_1,i_0)\le w(j_2,i_0)\),则 \(\forall i\le i_0\),\(w(j_1,i)\le w(j_2,i)\)(四边形不等式);
第一种情况下 \(i_0\) 的最小值、第二种情况下 \(i_0\) 的最大值加 \(1\),都是分界点。
两条推论启发我们可以通过维护 \(p[i]\) 来进行快速的转移。我们从左向右进行转移,假设已经考虑了前缀 \([1,i-1]\)。对于下标 \(i\),根据性质 1,其可能贡献到的 \(i'\) 是一个后缀(是指 \(i\) 比原先记载的 \(p[i']\) 更优)。如果我们能找到这个后缀,并修改后缀中所有位置的 \(p[i']\),那么向 \(f[i']\) 转移时,记录的 \(p[i']\) 就一定是最优解。

为什么“\(f[i]\) 可能贡献到的 \(i'\) 是一个后缀”
如果找到了一个分界点 \(x\) 满足 \(w(i,x)<w(p[x],i)\),则 \(\forall j<i\),\(w(i,x)<w(j,x)\)(最优性)。根据性质 1,对于 \(k\ge x\),都有 \(w(i,k)<w(j,k)\)。
因为每次修改的区间都是一个后缀,所以使用单调队列维护即可。
具体的,队列中的每个节点需要记录 \((l,r,p_0)\),表示对于 \(i\in[l,r]\),\(p[i]=p_0\)。每遍历到一个 \(i\):
- 弹出队首过时的区间;
- 考虑用 \(i\) 更新后面的 dp 值,如果队尾的 \(l\) 位置处,\(i\) 比 \(p[tail]\) 更优,则弹出队尾(因为队尾区间被 \(i\) 后缀完全包含了);
- 直到队尾不满足以上条件,说明分界点 \(i_0\) 就位于此时队尾所在的区间内。
- 二分找到分界点 \(i_0\),将队尾的 \(r\) 修改为 \(i_0-1\);判断分界点是否合法,将 \((i_0,n,i)\) 压入队尾。
因为每考虑一个 \(i\) 都需要进行一次二分,因此时间复杂度为 \(O(n\log n)\)。
例题
模板代码
| #include<iostream>
#include<deque>
#define int long long
using namespace std;
const int N = 5E4 + 10;
struct myPair {
int l, r, p;
};
int n, c;
int s[N], f[N];
inline int w(int l, int r) {
return (s[r] - s[l] + c) * (s[r] - s[l] + c);
}
myPair que[N];
int head = 1, tail;
signed main() {
cin >> n >> c;
for(int i = 1; i <= n; i++) {
cin >> s[i];
s[i] += s[i - 1] + 1;
}
c = -c - 1;
que[++tail] = {1, n, 0};
for(int i = 1; i <= n; i++) {
if(que[head].r == i - 1) head++;
else que[head].l = i;
f[i] = f[que[head].p] + w(que[head].p, i);
while(i < que[tail].l && f[i] + w(i, que[tail].l) < f[que[tail].p] + w(que[tail].p, que[tail].l)) --tail;
int l = que[tail].l, r = que[tail].r + 1, p = que[tail].p;
while(l < r) {
int mid = (l + r) >> 1;
if(f[i] + w(i, mid) < f[p] + w(p, mid)) {
r = mid;
} else l = mid + 1;
}
que[tail].r = l - 1;
if(l <= n) que[++tail] = {l, n, i};
}
cout << f[n] << endl;
return 0;
}
|
题目大意
给定一个长为 \(n\) 的数列 \(a_i\),你需要对每个 \(i\) 求出
\[
\max_{j=1}^{i-1}\{\Big\lceil a_j+\sqrt{i-j}\Big\rceil \}
\]
注意
\(y=\sqrt{x}\) 是上凸函数,但 \(y=\Big\lceil\sqrt{x}\Big\rceil\) 不具有凸性。因此本题需要在浮点数类型下计算,将上取整提到 \(\max\) 外。
注意到 \(w(j,i)=\sqrt{i-j}\) 满足反向四边形不等式,考虑四边形不等式优化。通过画图可以发现,每个 \(j\) 可能贡献到的 \(i\) 确实是一个后缀。直接套用四边形不等式的模板即可。
代码
| #include<iostream>
#include<algorithm>
#include<cmath>
#define int long long
#define ld long double
using namespace std;
const int N = 5E5 + 10;
const ld eps = 1e-8;
struct Range {
int l, r, p;
};
int n;
int a[N];
ld w[N], mx[N];
Range que[N];
int head, tail;
signed main() {
cin >> n;
for(int i = 1; i <= n; i++) {
cin >> a[i];
}
for(int i = 1; i <= n; i++) {
w[i] = sqrt(i);
}
head = 1, tail = 0;
que[++tail] = {1, n, 1};
for(int i = 2; i <= n; i++) {
while(que[head].r < i) ++head;
que[head].l = i;
mx[i] = max(mx[i], a[que[head].p] + w[i - que[head].p]);
while(head < tail && a[i] + w[que[tail].l - i] >= a[que[tail].p] + w[que[tail].l - que[tail].p]) --tail;
int l = que[tail].l, r = que[tail].r + 1, p = que[tail].p;
while(l < r) {
int mid = (l + r) >> 1;
if(a[i] + w[mid - i] >= a[p] + w[mid - p]) {
r = mid;
} else {
l = mid + 1;
}
}
que[tail].r = l - 1;
if(l <= n) que[++tail] = {l, n, i};
}
// 因为原题没有限制 j<i,因此需要反着跑一边 DP
reverse(a + 1, a + 1 + n);
head = 1, tail = 0;
que[++tail] = {1, n, 1};
for(int i = 2; i <= n; i++) {
while(que[head].r < i) ++head;
que[head].l = i;
mx[n - i + 1] = max(mx[n - i + 1], a[que[head].p] + w[i - que[head].p]);
while(head < tail && a[i] + w[que[tail].l - i] >= a[que[tail].p] + w[que[tail].l - que[tail].p]) --tail;
int l = que[tail].l, r = que[tail].r + 1, p = que[tail].p;
while(l < r) {
int mid = (l + r) >> 1;
if(a[i] + w[mid - i] >= a[p] + w[mid - p]) {
r = mid;
} else {
l = mid + 1;
}
}
que[tail].r = l - 1;
if(l <= n) que[++tail] = {l, n, i};
}
for(int i = 1; i <= n; i++) {
// 向上取整
cout << max((int)(mx[i] + 1 - eps) - a[n - i + 1], 0ll) << '\n';
}
return 0;
}
|
关于 long double
和 long long
的精度问题,请参考错题本。
代码
| #include<iostream>
#include<cstring>
#define int long long
#define ld long double
using namespace std;
const int N = 1E5 + 10;
const double V = 1E18;
const double INF = 1E20;
struct myPair {
int l, r, p;
};
int T;
int n, L, P;
int s[N], p[N];
ld f[N];
char str[N][40];
myPair que[N];
int head, tail;
inline ld qpow(ld a, int b) {
ld res = 1;
while(b) {
if(b & 1) res = res * a;
a = a * a;
b >>= 1;
}
return res;
}
inline ld w(int i, int j) {
return qpow(abs(s[j] - s[i] - L - 1), P);
}
void solve() {
head = 1, tail = 0;
que[++tail] = {1, n, 0};
for(int i = 1; i <= n; i++) {
if(que[head].r == i - 1) head++;
else que[head].l = i;
f[i] = f[que[head].p] + w(que[head].p, i);
p[i] = que[head].p;
while(head < tail && f[i] + w(i, que[tail].l) < f[que[tail].p] + w(que[tail].p, que[tail].l)) --tail;
int l = que[tail].l, r = que[tail].r + 1, p = que[tail].p;
while(l < r) {
int mid = (l + r) >> 1;
if(f[i] + w(i, mid) < f[p] + w(p, mid)) {
r = mid;
} else l = mid + 1;
}
que[tail].r = l - 1;
if(l <= n) que[++tail] = {l, n, i};
}
if(f[n] > V) throw 114514ll;
}
void outPut(int x) {
if(x == 0) return;
outPut(p[x]);
for(int i = p[x] + 1; i <= x; i++) {
for(int j = 0; j < s[i] - s[i - 1] - 1; j++) {
cout << str[i][j];
}
if(i < x) cout << ' ';
}
cout << '\n';
}
signed main() {
cin >> T;
while(T--) {
cin >> n >> L >> P;
getchar();
for(int i = 1; i <= n; i++) {
cin.getline(str[i], 40, '\n');
s[i] = strlen(str[i]);
s[i] += s[i - 1] + 1;
}
for(int i = 1; i <= n; i++) f[i] = 0;
for(int i = 1; i <= n; i++) que[i] = {0, 0, 0};
try {
solve();
} catch(int err) {
cout << "Too hard to arrange\n";
cout << "--------------------\n";
continue;
}
cout << (int)f[n] << '\n';
outPut(n);
cout << "--------------------\n";
}
return 0;
}
|