wqs 二分
wqs 二分可以解决具有这样一类限制的问题:要求恰好选 \(m\) 个特殊结构(区间、物品、特殊节点、边 等等),最小化总代价(或最大化总价值)。
记恰好选 \(m\) 个时的最小代价为 \(g(m)\)。wqs 二分有一些使用条件:
- 假如原问题没有“恰好 \(m\) 个”的限制的话,可以在很低的时间复杂度内求出最优解 \(\min\{g(m)\}\),或者存在很简单的解法;
- 原问题的答案 \(g(m)\) 关于 \(m\) 具有凸性;
wqs 二分的基本思路是:给每一个特殊结构都增加一个固定的代价 \(k\),从而控制特殊结构的数量。二分出合适的 \(k\),使得子问题的最优解恰好选择了 \(m\) 个特殊结构。此时将子问题的答案减去 \(mk\) 即是选 \(m\) 个的最优解。
形式化的,如果我们可以 \(O(n)\) 求出 \(\frac{d}{dx}g(x)=0\) 的位置 \(x\)(显然,我们可以数出最优解中特殊结构的数量);wqs 二分就是找到一个 \(k\) 使得 \(\frac{d}{dx}(g(x)+kx)=0\) 的解为 \(x=m\)。

如何判断无解
使用 wqs 二分时,最保险的判断无解的方法是:
if(check(INF) > k) no();
if(check(-INF) < k) no();
如果不使用双关键字排序,check
可能返回直线和凸包相切的线段上的任何一点。此时 l = -INF
或 l = INF
不能说明一定无解。
例题
题意
给定一个长为 \(n\) 的序列 \(a\),要求你将它划分成恰好 \(m\) 段连续的区间,最小化 \(\sum_{i=1}^m{(s[r_i]-s[l_i-1])^2}\)。
这道题显然可以使用斜率优化 \(O(nm)\) 实现。考虑 wqs 二分,我们需要证明:在固定 \(n,a_i\) 的情况下,答案关于 \(m\) 是下凸的。
证明
记恰好 \(m\) 段时最优解为 \(g(m)\)。我们现在证明 \(\forall i\in [2,n-1],\ 2g(i)\le g(i-1)+g(i+1)\)。
考虑 \(m=i-1\) 和 \(m=i+1\) 时的最优划分:\([a_1,d_1],\cdots,[a_{i-1},d_{i-1}]\) 和 \([b_1,c_1],\cdots,[b_{i+1},c_{i+1}]\)。找到满足 \(c_{j+1}<d_j\) 且最小的 \(j\),则有 \(a_j\ge b_{j+1}\)(否则 \(j\) 不是最小的)。
考虑将两个解交换 \([j,n]\) 的部分:
\[
[a_1,d_1],\cdots,[a_{j},c_{j+1}],[b_{j+2},c_{j+2}],\cdots,[b_{i+1},c_{i+1}]\\
[b_1,c_1],\cdots,[b_{j+1},d_{j}],[a_{j+1},d_{j+1}],\cdots,[a_{i-1},d_{i-1}]
\]
此时得到的两个划分都有恰好 \(i\) 段。根据 \(g(i)\) 的最优性和四边形不等式:
\[
\begin{align*}
2g(i)\le&\ w(a_1,d_1)+\cdots+w(a_{j},c_{j+1})+\cdots+w(b_{i+1},c_{i+1})\\
&+\ w(b_1, c_1)+\cdots+w(b_{j+1},d_j)+\cdots+w(a_{i-1},d_{i-1})\\
\le &\ w(a_1,d_1)+\cdots+w(a_j,d_j)+w(b_{j+2},c_{j+2})+\cdots+w(b_{i+1},c_{i+1})\\
&+\ w(b_1, c_1)+\cdots+w(b_{j+1}, c_{j+1})+w(a_{j+1},d_{j+1})+\cdots+w(a_{i-1},d_{i-1})\\
=&\ w(a_1,d_1)+\cdots+w(a_j,d_j)+w(a_{j+1},d_{j+1})+\cdots+w(a_{i-1},d_{i-1})\\
&+\ w(b_1, c_1)+\cdots+w(b_{j+1}, c_{j+1})+w(b_{j+2},c_{j+2})+\cdots+w(b_{i+1},c_{i+1})\\
=&\ g(i-1)+g(i+1)
\end{align*}
\]
得证。
代码
| #include<iostream>
#define ld long double
#define ll long long
using namespace std;
const int N = 3010;
const ll INF = 0x3f3f3f3f3f3f3f3f;
int n, m;
ll s[N], f[N], g[N];
int que[N], hd, tl;
inline ll pw2(ll x) { return x * x; }
inline ll X(int id) { return s[id]; }
inline ll Y(int id) { return f[id] + s[id] * s[id]; }
inline ld K(int i1, int i2) { return (ld)(Y(i2) - Y(i1)) / (X(i2) - X(i1)); }
ll check(int w) {
hd = 1, tl = 0;
que[++tl] = 0;
for(int i = 1; i <= n; i++) {
while(hd < tl && K(que[hd], que[hd + 1]) <= 2 * s[i]) ++hd;
f[i] = f[que[hd]] + pw2(s[i] - s[que[hd]]) + w;
g[i] = g[que[hd]] + 1;
while(hd < tl && K(que[tl], i) <= K(que[tl - 1], que[tl])) --tl;
que[++tl] = i;
}
return g[n];
}
ll work() {
int l = 0, r = 1e9;
ll ans, aw;
while(l < r) {
int mid = (l + r + 1) >> 1;
if(check(mid) >= m) {
ans = f[n];
aw = mid;
l = mid;
} else r = mid - 1;
}
return ans - aw * m;
}
int main() {
cin >> n >> m;
for(int i = 1; i <= n; i++) cin >> s[i];
for(int i = 1; i <= n; i++) s[i] += s[i - 1];
ll res = work();
cout << (m * res - s[n] * s[n]) << endl;
return 0;
}
|