跳转至

wqs 二分

wqs 二分可以解决具有这样一类限制的问题:要求恰好选 \(m\) 个特殊结构(区间、物品、特殊节点、边 等等),最小化总代价(或最大化总价值)。

记恰好选 \(m\) 个时的最小代价为 \(g(m)\)。wqs 二分有一些使用条件:

  • 假如原问题没有“恰好 \(m\) 个”的限制的话,可以在很低的时间复杂度内求出最优解 \(\min\{g(m)\}\),或者存在很简单的解法;
  • 原问题的答案 \(g(m)\) 关于 \(m\) 具有凸性;

wqs 二分的基本思路是:给每一个特殊结构都增加一个固定的代价 \(k\),从而控制特殊结构的数量。二分出合适的 \(k\),使得子问题的最优解恰好选择了 \(m\) 个特殊结构。此时将子问题的答案减去 \(mk\) 即是选 \(m\) 个的最优解。

形式化的,如果我们可以 \(O(n)\) 求出 \(\frac{d}{dx}g(x)=0\) 的位置 \(x\)(显然,我们可以数出最优解中特殊结构的数量);wqs 二分就是找到一个 \(k\) 使得 \(\frac{d}{dx}(g(x)+kx)=0\) 的解为 \(x=m\)

wqs 二分

如何判断无解

使用 wqs 二分时,最保险的判断无解的方法是:

if(check(INF) > k) no();
if(check(-INF) < k) no();

如果不使用双关键字排序,check 可能返回直线和凸包相切的线段上的任何一点。此时 l = -INFl = INF 不能说明一定无解。

P4072 [SDOI2016] 征途

给定一个长为 \(n\) 的序列 \(a\),要求你将它划分成恰好 \(m\) 段连续的区间,最小化 \(\sum_{i=1}^m{(s[r_i]-s[l_i-1])^2}\)

这道题显然可以使用斜率优化 \(O(nm)\) 实现。考虑 wqs 二分,我们需要证明:在固定 \(n,a_i\) 的情况下,答案关于 \(m\) 是下凸的。

证明

记恰好 \(m\) 段时最优解为 \(g(m)\)。我们现在证明 \(\forall i\in [2,n-1],\ 2g(i)\le g(i-1)+g(i+1)\)

考虑 \(m=i-1\)\(m=i+1\) 时的最优划分:\([a_1,d_1],\cdots,[a_{i-1},d_{i-1}]\)\([b_1,c_1],\cdots,[b_{i+1},c_{i+1}]\)。找到满足 \(c_{j+1}<d_j\) 且最小的 \(j\),则有 \(a_j\ge b_{j+1}\)(否则 \(j\) 不是最小的)。

考虑将两个解交换 \([j,n]\) 的部分:

\[ [a_1,d_1],\cdots,[a_{j},c_{j+1}],[b_{j+2},c_{j+2}],\cdots,[b_{i+1},c_{i+1}]\\ [b_1,c_1],\cdots,[b_{j+1},d_{j}],[a_{j+1},d_{j+1}],\cdots,[a_{i-1},d_{i-1}] \]

此时得到的两个划分都有恰好 \(i\) 段。根据 \(g(i)\) 的最优性和四边形不等式:

\[ \begin{align*} 2g(i)\le&\ w(a_1,d_1)+\cdots+w(a_{j},c_{j+1})+\cdots+w(b_{i+1},c_{i+1})\\ &+\ w(b_1, c_1)+\cdots+w(b_{j+1},d_j)+\cdots+w(a_{i-1},d_{i-1})\\ \le &\ w(a_1,d_1)+\cdots+w(a_j,d_j)+w(b_{j+2},c_{j+2})+\cdots+w(b_{i+1},c_{i+1})\\ &+\ w(b_1, c_1)+\cdots+w(b_{j+1}, c_{j+1})+w(a_{j+1},d_{j+1})+\cdots+w(a_{i-1},d_{i-1})\\ =&\ w(a_1,d_1)+\cdots+w(a_j,d_j)+w(a_{j+1},d_{j+1})+\cdots+w(a_{i-1},d_{i-1})\\ &+\ w(b_1, c_1)+\cdots+w(b_{j+1}, c_{j+1})+w(b_{j+2},c_{j+2})+\cdots+w(b_{i+1},c_{i+1})\\ =&\ g(i-1)+g(i+1) \end{align*} \]

得证。

代码
#include<iostream>
#define ld long double
#define ll long long
using namespace std;
const int N = 3010;
const ll INF = 0x3f3f3f3f3f3f3f3f;

int n, m;
ll s[N], f[N], g[N];

int que[N], hd, tl;

inline ll pw2(ll x) { return x * x; }

inline ll X(int id) { return s[id]; }
inline ll Y(int id) { return f[id] + s[id] * s[id]; }

inline ld K(int i1, int i2) { return (ld)(Y(i2) - Y(i1)) / (X(i2) - X(i1)); }

ll check(int w) {
    hd = 1, tl = 0;
    que[++tl] = 0;
    for(int i = 1; i <= n; i++) {
        while(hd < tl && K(que[hd], que[hd + 1]) <= 2 * s[i]) ++hd;
        f[i] = f[que[hd]] + pw2(s[i] - s[que[hd]]) + w;
        g[i] = g[que[hd]] + 1;
        while(hd < tl && K(que[tl], i) <= K(que[tl - 1], que[tl])) --tl;
        que[++tl] = i;
    }
    return g[n];
}

ll work() {
    int l = 0, r = 1e9;
    ll ans, aw;
    while(l < r) {
        int mid = (l + r + 1) >> 1;
        if(check(mid) >= m) {
            ans = f[n];
            aw = mid;
            l = mid;
        } else r = mid - 1;
    }
    return ans - aw * m;
}

int main() {

    cin >> n >> m;
    for(int i = 1; i <= n; i++) cin >> s[i];
    for(int i = 1; i <= n; i++) s[i] += s[i - 1];

    ll res = work();
    cout << (m * res - s[n] * s[n]) << endl;

    return 0;
}

P2619 [国家集训队] Tree I

代码
#include<iostream>
#include<algorithm>
using namespace std;
const int N = 5E4 + 10;
const int M = 1E5 + 10;

struct Edge {
    int u, v, w, c;
    inline bool operator<(const Edge &other) const {
        if(w != other.w) return w < other.w;
        return c < other.c;
    }
} edg[M];

int n, m, k;

int fa[N];
int find(int x) {
    if(fa[x] == x) return x;
    return fa[x] = find(fa[x]);
}
void merge(int x, int y) {
    fa[find(x)] = find(y);
}

int mst, cnt;
int getmst(int b) {
    mst = cnt = 0;
    for(int i = 1; i <= m; i++) {
        edg[i].w += edg[i].c * b;
    }
    sort(edg + 1, edg + 1 + m);
    for(int i = 0; i <= n; i++) fa[i] = i;
    for(int i = 1; i <= m; i++) {
        int u = edg[i].u, v = edg[i].v;
        if(find(u) != find(v)) {
            merge(u, v);
            mst += edg[i].w;
            cnt += edg[i].c;
        }
    }
    for(int i = 1; i <= m; i++) {
        edg[i].w -= edg[i].c * b;
    }
    return cnt;
}

signed main() {

    cin >> n >> m >> k;
    for(int i = 1; i <= m; i++) {
        cin >> edg[i].u >> edg[i].v >> edg[i].w >> edg[i].c;
        edg[i].c ^= 1;
    }

    int l = 0, r = 410;
    int ans;
    while(l < r) {
        int mid = (l + r) >> 1;
        if(getmst(mid - 205) <= k) {
            r = mid;
            ans = mst - k * (r - 205);
        } else l = mid + 1;
    }

    cout << ans << endl;

    return 0;
} 

P5633 最小度限制生成树

代码
#include<iostream>
using namespace std;
const int N = 5e4 + 10;
const int M = 5e5 + 10;
const int V = 3e4 + 10;

inline void no() { cout << "Impossible\n"; exit(0); }

struct Edge {
    int u, v, w;
} edg[M];
int swp[M];

int n, m, s, k;
int ans, res;

int fa[N];
int cnt[4 * V];

void sort() {
    for(int i = 0; i < 4 * V; i++) cnt[i] = 0;
    for(int i = 1; i <= m; i++) ++cnt[edg[i].w + V];
    for(int i = 1; i < 4 * V; i++) cnt[i] += cnt[i - 1];
    for(int i = 1; i <= m; i++) swp[cnt[edg[i].w + V]--] = i;
}

int find(int x) {
    if(fa[x] == x) return x;
    return fa[x] = find(fa[x]);
}

int check(int ww) {
    for(int i = 1; i <= m; i++) {
        if(edg[i].u == s || edg[i].v == s) edg[i].w += ww;
    }
    sort();
    res = 0;
    int cc = 0, j = 0;
    for(int i = 1; i <= n; i++) fa[i] = i;
    for(int i = 1; i <= m && j <= n - 1; i++) {
        Edge e = edg[swp[i]];
        if(find(e.u) != find(e.v)) {
            res += e.w;
            fa[find(e.u)] = find(e.v);
            ++j;
            if(e.u == s || e.v == s) ++cc;
        }
    }
    if(j != n - 1) no();
    for(int i = 1; i <= m; i++) {
        if(edg[i].u == s || edg[i].v == s) edg[i].w -= ww;
    }
    return cc;
}

int main() {

    int tmp = 0;

    cin >> n >> m >> s >> k;
    for(int i = 1; i <= m; i++) {
        cin >> edg[i].u >> edg[i].v >> edg[i].w;
        if(edg[i].u == s || edg[i].v == s) ++tmp;
    }

    if(check(-30005) < k) no();

    int l = -30005, r = 30005;
    while(l < r) {
        int mid = (l + r) >> 1;
        if(check(mid) <= k) {
            r = mid;
            ans = res - k * mid;
        } else l = mid + 1;
    }

    if(l == 30005) {
        no();
    }

    cout << ans << '\n';

    return 0;
}