跳转至

四边形不等式优化 DP

四边形不等式

对于定义在 \(\N^2\) 上的二元函数 \(w(x,y)\),若 \(\forall a\le b\le c\le d\)\(a,b,c,d\in \N\)),都有

\[ w(a,c)+w(b,d)\le w(a,d)+w(b,c) \]

则称 \(w(x,y)\) 满足四边形不等式。

性质

性质 1\(w(x,y)\) 满足四边形不等式的充要条件是:\(\forall x<y\),有

\[ w(x,y)+w(x+1,y+1)\le w(x,y+1)+w(x+1,y)\tag{1.1} \]

推论 1\(w(x,y)\) 满足四边形不等式的充要条件是 \(w(x,y)\) 的二阶混合差分始终非正。

性质 2:若函数 \(w(x,y)\) 满足四边形不等式,则 \(f(x)+w(x,y)\) 也满足四边形不等式。

性质 3:设 \(h(u)\) 是一个单调增加的凸函数,若函数 \(w(x,y)\) 满足四边形不等式且具有区间包含单调性,则复合函数 \(h(w(x,y))\) 也满足四边形不等式和区间包含单调性。

性质 4:设 \(h(u)\) 是一个凸函数,若函数 \(w(x,y)\) 满足四边形恒等式且具有区间包含单调性,则复合函数 \(h(w(x,y))\) 也满足四边形不等式。

证明

性质 1

考虑归纳法。我们现在要证明:对于所有 \(k_1\)\(x+k_1\le y\) 都有

\[ w(x,y)+w(x+k_1,y+1)\le w(x,y+1)+w(x+k_1,y) \tag{1.2} \]

然后对于所有 \(k_2\) 再证明

\[ w(x,y)+w(x+k_1,y+k_2)\le w(x,y+k_2)+w(x+k_1,y) \tag{1.3} \]

即可。

假设 \((1.2)\) 对所有 \(k_1\le p\) 都满足。将 \(x+p\) 代入 \((1.1)\)

\[ w(x+p,y)+w(x+p+1,y+1)\le w(x+p,y+1)+w(x+p+1,y) \tag{1.4} \]

\(k_1=p\) 代入 \((1,2)\),和 \((1.4)\) 相加:

\[ w(x,y)+w(x+p+1,y+1)\le w(x,y+1)+w(x+p+1,y) \]

\((1.3)\) 式同理归纳即可。

推论 1

\((1.1)\) 移项:

\[ w(i+1,j+1)+w(i,j)-w(i,j+1)-w(i+1,j)=\Delta_x\Delta_y w(i,j) \]

性质 2

在四边形不等式中,左右两侧的 \(f(a)+f(b)\) 可以同时消掉。

性质 3 和性质 4 的证明真的不会了。

决策单调性

对于满足四边形不等式的函数 \(w(x,y)\),考虑如下形式的 dp 转移方程:

\[ f[i]=\min_{j< i}\{w(j,i)\} \]

朴素的转移为 \(O(n^2)\)。记 \(p[i]\)\(f_i\) 的最优决策点,即 \(f_i=w(p[i],i)\)

对于

\[ f[i]=\min_{j< i}\{f[j]+w(j,i)\} \]

的转移方程,根据四边形不等式性质 2,它等价于上面的情况。

性质 1:对于任意 \(j_1< j_2\),存在一个常数 \(i_0\),使得 \(\forall i\ge i_0\)\(w(j_2,i)< w(j_1,i)\)\(\forall i< i_0\)\(w(j_1,i)\le w(j_2,i)\)

推论\(p[i]\) 单调不降;

证明

考虑两种情况:

  • 对于下标 \(i_0\),若 \(w(j_2,i_0)< w(j_1,i_0)\),则 \(\forall i\ge i_0\)\(w(j_2,i)< w(j_1,i)\)(四边形不等式);
  • 对于下标 \(i_0\),若 \(w(j_1,i_0)\le w(j_2,i_0)\),则 \(\forall i\le i_0\)\(w(j_1,i)\le w(j_2,i)\)(四边形不等式);

第一种情况下 \(i_0\) 的最小值、第二种情况下 \(i_0\) 的最大值加 \(1\),都是分界点。

两条推论启发我们可以通过维护 \(p[i]\) 来进行快速的转移。我们从左向右进行转移,假设已经考虑了前缀 \([1,i-1]\)。对于下标 \(i\),根据性质 1,其可能贡献到的 \(i'\) 是一个后缀(是指 \(i\) 比原先记载的 \(p[i']\) 更优)。如果我们能找到这个后缀,并修改后缀中所有位置的 \(p[i']\),那么向 \(f[i']\) 转移时,记录的 \(p[i']\) 就一定是最优解。

加入直线的过程

为什么“\(f[i]\) 可能贡献到的 \(i'\) 是一个后缀”

如果找到了一个分界点 \(x\) 满足 \(w(i,x)<w(p[x],i)\),则 \(\forall j<i\)\(w(i,x)<w(j,x)\)(最优性)。根据性质 1,对于 \(k\ge x\),都有 \(w(i,k)<w(j,k)\)

因为每次修改的区间都是一个后缀,所以使用单调队列维护即可。

具体的,队列中的每个节点需要记录 \((l,r,p_0)\),表示对于 \(i\in[l,r]\)\(p[i]=p_0\)。每遍历到一个 \(i\)

  • 弹出队首过时的区间;
  • 考虑用 \(i\) 更新后面的 dp 值,如果队尾的 \(l\) 位置处,\(i\)\(p[tail]\) 更优,则弹出队尾(因为队尾区间被 \(i\) 后缀完全包含了);
  • 直到队尾不满足以上条件,说明分界点 \(i_0\) 就位于此时队尾所在的区间内。
  • 二分找到分界点 \(i_0\),将队尾的 \(r\) 修改为 \(i_0-1\);判断分界点是否合法,将 \((i_0,n,i)\) 压入队尾。

因为每考虑一个 \(i\) 都需要进行一次二分,因此时间复杂度为 \(O(n\log n)\)

例题

P3195 [HNOI2008] 玩具装箱

模板代码
#include<iostream>
#include<deque>
#define int long long
using namespace std;
const int N = 5E4 + 10;

struct myPair {
    int l, r, p;
};

int n, c;
int s[N], f[N];

inline int w(int l, int r) {
    return (s[r] - s[l] + c) * (s[r] - s[l] + c);
}

myPair que[N];
int head = 1, tail;

signed main() {

    cin >> n >> c;
    for(int i = 1; i <= n; i++) {
        cin >> s[i];
        s[i] += s[i - 1] + 1;
    }

    c = -c - 1;

    que[++tail] = {1, n, 0};

    for(int i = 1; i <= n; i++) {
        if(que[head].r == i - 1) head++;
        else que[head].l = i;
        f[i] = f[que[head].p] + w(que[head].p, i);
        while(i < que[tail].l && f[i] + w(i, que[tail].l) < f[que[tail].p] + w(que[tail].p, que[tail].l)) --tail;
        int l = que[tail].l, r = que[tail].r + 1, p = que[tail].p;
        while(l < r) {
            int mid = (l + r) >> 1;
            if(f[i] + w(i, mid) < f[p] + w(p, mid)) {
                r = mid;
            } else l = mid + 1;
        }
        que[tail].r = l - 1;
        if(l <= n) que[++tail] = {l, n, i};
    }

    cout << f[n] << endl;

    return 0;
}

P3515 [POI 2011] Lightning Conductor

题目大意

给定一个长为 \(n\) 的数列 \(a_i\),你需要对每个 \(i\) 求出

\[ \max_{j=1}^{i-1}\{\Big\lceil a_j+\sqrt{i-j}\Big\rceil \} \]

注意

\(y=\sqrt{x}\) 是上凸函数,但 \(y=\Big\lceil\sqrt{x}\Big\rceil\) 不具有凸性。因此本题需要在浮点数类型下计算,将上取整提到 \(\max\) 外。

注意到 \(w(j,i)=\sqrt{i-j}\) 满足反向四边形不等式,考虑四边形不等式优化。通过画图可以发现,每个 \(j\) 可能贡献到的 \(i\) 确实是一个后缀。直接套用四边形不等式的模板即可。

代码
#include<iostream>
#include<algorithm>
#include<cmath>
#define int long long
#define ld long double
using namespace std;
const int N = 5E5 + 10;
const ld eps = 1e-8;

struct Range {
    int l, r, p;
};

int n;
int a[N];
ld w[N], mx[N];

Range que[N];
int head, tail;

signed main() {

    cin >> n;
    for(int i = 1; i <= n; i++) {
        cin >> a[i];
    }

    for(int i = 1; i <= n; i++) {
        w[i] = sqrt(i);
    }

    head = 1, tail = 0;
    que[++tail] = {1, n, 1};
    for(int i = 2; i <= n; i++) {
        while(que[head].r < i) ++head;
        que[head].l = i;
        mx[i] = max(mx[i], a[que[head].p] + w[i - que[head].p]);
        while(head < tail && a[i] + w[que[tail].l - i] >= a[que[tail].p] + w[que[tail].l - que[tail].p]) --tail;
        int l = que[tail].l, r = que[tail].r + 1, p = que[tail].p;
        while(l < r) {
            int mid = (l + r) >> 1;
            if(a[i] + w[mid - i] >= a[p] + w[mid - p]) {
                r = mid;
            } else {
                l = mid + 1;
            }
        }
        que[tail].r = l - 1;
        if(l <= n) que[++tail] = {l, n, i};
    }

    // 因为原题没有限制 j<i,因此需要反着跑一边 DP
    reverse(a + 1, a + 1 + n);

    head = 1, tail = 0;
    que[++tail] = {1, n, 1};
    for(int i = 2; i <= n; i++) {
        while(que[head].r < i) ++head;
        que[head].l = i;
        mx[n - i + 1] = max(mx[n - i + 1], a[que[head].p] + w[i - que[head].p]);
        while(head < tail && a[i] + w[que[tail].l - i] >= a[que[tail].p] + w[que[tail].l - que[tail].p]) --tail;
        int l = que[tail].l, r = que[tail].r + 1, p = que[tail].p;
        while(l < r) {
            int mid = (l + r) >> 1;
            if(a[i] + w[mid - i] >= a[p] + w[mid - p]) {
                r = mid;
            } else {
                l = mid + 1;
            }
        }
        que[tail].r = l - 1;
        if(l <= n) que[++tail] = {l, n, i};
    }

    for(int i = 1; i <= n; i++) {
        // 向上取整
        cout << max((int)(mx[i] + 1 - eps) - a[n - i + 1], 0ll) << '\n';
    }

    return 0;
}

P1912 [NOI2009] 诗人小G

关于 long doublelong long 的精度问题,请参考错题本

代码
#include<iostream>
#include<cstring>
#define int long long
#define ld long double
using namespace std;
const int N = 1E5 + 10;
const double V = 1E18;
const double INF = 1E20;

struct myPair {
    int l, r, p;
};

int T;
int n, L, P;

int s[N], p[N];
ld f[N];
char str[N][40];

myPair que[N];
int head, tail;

inline ld qpow(ld a, int b) {
    ld res = 1;
    while(b) {
        if(b & 1) res = res * a;
        a = a * a;
        b >>= 1;
    }
    return res;
}

inline ld w(int i, int j) {
    return qpow(abs(s[j] - s[i] - L - 1), P);
}

void solve() {
    head = 1, tail = 0;
    que[++tail] = {1, n, 0};
    for(int i = 1; i <= n; i++) {
        if(que[head].r == i - 1) head++;
        else que[head].l = i;
        f[i] = f[que[head].p] + w(que[head].p, i);
        p[i] = que[head].p;
        while(head < tail && f[i] + w(i, que[tail].l) < f[que[tail].p] + w(que[tail].p, que[tail].l)) --tail;
        int l = que[tail].l, r = que[tail].r + 1, p = que[tail].p;
        while(l < r) {
            int mid = (l + r) >> 1;
            if(f[i] + w(i, mid) < f[p] + w(p, mid)) {
                r = mid;
            } else l = mid + 1;
        }
        que[tail].r = l - 1;
        if(l <= n) que[++tail] = {l, n, i};
    }
    if(f[n] > V) throw 114514ll;
}

void outPut(int x) {
    if(x == 0) return;
    outPut(p[x]);
    for(int i = p[x] + 1; i <= x; i++) {
        for(int j = 0; j < s[i] - s[i - 1] - 1; j++) {
            cout << str[i][j];
        }
        if(i < x) cout << ' ';
    }
    cout << '\n';
}

signed main() {

    cin >> T;
    while(T--) {
        cin >> n >> L >> P;
        getchar();
        for(int i = 1; i <= n; i++) {
            cin.getline(str[i], 40, '\n');
            s[i] = strlen(str[i]);
            s[i] += s[i - 1] + 1;
        }
        for(int i = 1; i <= n; i++) f[i] = 0;
        for(int i = 1; i <= n; i++) que[i] = {0, 0, 0};
        try {
            solve();
        } catch(int err) {
            cout << "Too hard to arrange\n";
            cout << "--------------------\n";
            continue;
        }
        cout << (int)f[n] << '\n';
        outPut(n);
        cout << "--------------------\n";
    }

    return 0;
}