跳转至

线段树上二分

有些二分答案的 check() 函数需要快速查询区间信息,例如区间最值、区间和,此时可以使用线段树维护。直接在线段树上进行二分答案的时间复杂度为 \(O(\log^2 n)\)。通过线段树上二分可以将单次查询的时间复杂度降低到 \(O(\log n)\)

\(O(\log n)\) 解决的问题:

  • 动态查询任意位置 \(p_0\) 左(右)侧第一个满足 check( SegT::query(p0, p) ) 的位置 \(p\)

基本原理

注意到线段树正是一种形如二分的结构。

如果查询区间是 \([1,n]\),则可以直接利用线段树上每个节点保存的区间信息进行二分。例如,我们要查询 \([1,n]\) 范围内最靠左的满足 \(a_p\ge v_0\) 的位置 \(p\)

*样例中 \(n=8\)\(v_0=9\),答案 \(p=3\)

线段树上二分1

1
2
3
4
5
6
7
8
int query(int p, int l, int r, int v) {
    if(mx[p] < v) return -1;
    if(l == r) return l;
    int mid = (l + r) >> 1;
    int tmp = query(lc(p), l, mid, v);
    if(tmp != -1) return tmp;
    return query(rc(p), mid + 1, r, v);
}

但是如果查询区间不是 \([1,n]\),而是任意区间 \([l,r]\),我们就可以把 \([l,r]\) 拆成 \(O(\log n)\) 段线段树上的区间,按从左向右的顺序依次遍历所有区间,找到第一个 \(mx[p]\ge v_0\) 的区间:

线段树上二分2

这个区间可以看作是一棵完整的线段树,直接使用上文提到的方法二分即可。

更改处:

传入查询区间 int ql, int qr

int query(int p, int l, int r, int ql, int qr, int v);

添加剪枝,剪掉和查询区间无交的区间:

if(r < ql || l > qr) return -1;

这样程序就只会递归 \([l_q,r_q]\) 所包含的区间(和它们的父区间)。同时因为递归的顺序是从左向右的,因此返回的答案一定是第一个满足条件的。

完整代码
1
2
3
4
5
6
7
8
9
int query(int p, int l, int r, int ql, int qr, int v) {
    if(mx[p] < v) return -1;
    if(r < ql || l > qr) return -1;
    if(l == r) return l;
    int mid = (l + r) >> 1;
    int tmp = query(lc(p), l, mid, v);
    if(tmp != -1) return tmp;
    return query(rc(p), mid + 1, r, v);
}

例题

北京冬令营07B-C expand

题面

给定一个长为 \(n\) 的序列 \(a\)。你当前有一个区间 \([l,r]\) 和权值 \(v\),在此基础上你可以做以下操作:

  1. \(l\ne 1\)\(a_{l-1}\le v\),将 \(l\) 减去 \(1\)
  2. \(r\ne n\)\(a_{r+1}\le v\),将 \(r\) 加上 \(1\)
  3. \(v\) 变为 \(\min\{a_{l-1},a_{r+1}\}\),这里规定 \(a_0=a_{n+1}=+\infty\)

定义 \(f(i)\) 为,假设初始区间为 \([i,i]\)\(v=a_i\),将区间变为 \([1,n]\),最少需要做多少次第三种操作。

你需要处理 \(m\) 次以下两种操作:

  1. 给定 \(x\),交换 \(a_x, a_{x+1}\)
  2. 给定 \(l,r\),求出 \(\sum_{i=l}^{r}f(i)\)

观察题面容易发现:记集合 \(A\) 为区间 \([1,i]\) 的所有后缀的最大值组成的集合,\(B\) 为区间 \([i,n]\) 的所有前缀的最大值组成的集合,则 \(f(i)=|A\cup B|-1\)

先考虑没有修改的情况如何求出所有 \(f(i)\)。注意到如果直接对每个元素取当前单调栈的深度,则无法去掉两侧重复的数字。

考虑每个元素 \(j\) 可以贡献到什么样的 \(f(i)\)。记 \(l_j\)\(j\) 左侧第一个 \(a_l\ge a_j\) 的位置;\(r_j\)\(j\) 右侧第一个 \(a_r\ge a_j\) 的位置(“第一个”均指离 \(j\) 最近的一个)。容易发现,\(j\) 可以贡献到 \((l_j,j)\)\((j,r_j)\) 两个开区间的所有数字;而若 \(a_l=a_j\),则不应贡献到 \((l_j,j)\),因为 \((l_j,j)\) 已经被 \(l_j\) 统计过了。

注意到这样做是容易的,因为我们容易判断 \(a_l=a_j\) 的情况。只需要跑正反两遍单调栈即可。

在此我们额外介绍一种静态求 \(f(i)\) 的方法。注意到 \(f(i)\) 可由其左侧或右侧第一个 \(a_j\ge a_i\)\(j\) 转移而来,即 \(f(i)=f(j)+[a_j\ne a_i]\),因此直接记忆化搜索即可。这种方法也适用于小规模快速单点查询 \(f(i)\)

接下来考虑带修的情况。我们分讨 \(a_x\)\(a_{x+1}\) 的大小关系。如果 \(a_x=a_{x+1}\) 则直接跳过;而 \(a_x>a_{x+1}\) 的情况其实和 \(a_x<a_{x+1}\) 的情况本质相同。考虑如何处理 \(a_x<a_{x+1}\)。我们依照上文单调栈的思路,先去除 \(a_x\) 在左侧的贡献,然后加入 \(a_x\) 在右侧产生的贡献。而这需要我们动态查询每个数的 \(l_i\)\(r_i\)。这可以直接由上文的线段树上二分解决。

考虑如何处理 \(a_x\)\(a_{x+1}\)\(f\)。我们可以使用第二种求 \(f(i)\) 的思路,找到第一个比它大的 \(j\),然后从 \(j\) 转移到 \(i\)。分讨 \(j=l_x\)\(j=x+1\) 的情况即可。

代码
#include<iostream>
#define int long long
#define cint const int&
using namespace std;
const int N = 3E5 + 10;
const int INF = (int)0x3f3f3f3f3f3f3f3f;

int n, q;
int a[N];
int sta[N], top;

namespace BIT {
    int sum1[N], sum2[N];
    inline int lowbit(cint x) { return x & -x; }
    inline void add(cint p, cint v) {
        for(int i = p; i <= n + 4; i += lowbit(i)) {
            sum1[i] += v;
            sum2[i] += p * v;
        }
    }
    inline int query(cint p) {
        int res = 0;
        for(int i = p; i > 0; i -= lowbit(i)) {
            res += (p + 1) * sum1[i];
            res -= sum2[i];
        }
        return res;
    }
    inline void add(cint l, cint r, cint v) {
        if(l > r) return;
        add(l, v);
        add(r + 1, -v);
    }
    inline int query(cint l, cint r) {
        return query(r) - query(l - 1);
    }
    inline void set(cint p, cint v) {
        add(p, p, v - query(p, p));
    }
    inline int query_pt(cint p) {
        return query(p, p);
    }
}

namespace SegT {
    int mx[4 * N];
    inline int lc(cint x) { return x << 1; }
    inline int rc(cint x) { return x << 1 | 1; }
    inline void push_up(cint p) {
        mx[p] = max(mx[lc(p)], mx[rc(p)]);
    }
    void build(cint p, cint l, cint r) {
        if(l == r) {
            mx[p] = a[l];
            return;
        }
        int mid = (l + r) >> 1;
        build(lc(p), l, mid);
        build(rc(p), mid + 1, r);
        push_up(p);
    }
    void modify(cint p, cint l, cint r, cint q, cint v) {
        if(l == r) {
            mx[p] = v;
            return;
        }
        int mid = (l + r) >> 1;
        if(mid >= q) modify(lc(p), l, mid, q, v);
        else modify(rc(p), mid + 1, r, q, v);
        push_up(p);
    }
    int query_pre(cint p, cint l, cint r, cint q, cint v) {
        if(q < l) return -1;
        if(mx[p] < v) return -1;
        if(l == r) {
            return l;
        }
        int mid = (l + r) >> 1;
        int tmp = query_pre(rc(p), mid + 1, r, q, v);
        if(~tmp) return tmp;
        return query_pre(lc(p), l, mid, q, v);
    }
    int query_suc(cint p, cint l, cint r, cint q, cint v) {
        if(q > r) return -1;
        if(mx[p] < v) return -1;
        if(l == r) {
            return l;
        }
        int mid = (l + r) >> 1;
        int tmp = query_suc(lc(p), l, mid, q, v);
        if(~tmp) return tmp;
        return query_suc(rc(p), mid + 1, r, q, v);
    }
}

signed main() {

    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    cin >> n >> q;
    for(int i = 1; i <= n; i++) {
        cin >> a[i];
    }

    a[0] = INF;
    a[n + 1] = INF;
    SegT::build(1, 0, n + 1);

    for(int i = 0; i <= n + 1; i++) {
        while(top && a[i] > a[sta[top]]) --top;
        if(top && a[i] != a[sta[top]]) BIT::add(sta[top] + 1, i - 1, 1);
        while(top && a[sta[top]] == a[i]) --top;
        sta[++top] = i;
    }
    top = 0;
    for(int i = n + 1; i >= 0; i--) {
        while(top && a[i] > a[sta[top]]) --top;
        if(top) BIT::add(i + 1, sta[top] - 1, 1);
        while(top && a[sta[top]] == a[i]) --top;
        sta[++top] = i;
    }

    while(q--) {
        int op, l, r, x;
        cin >> op;
        if(op == 1) {
            cin >> x;
            if(a[x] == a[x + 1]) continue;
            if(a[x] < a[x + 1]) {
                int pr = SegT::query_pre(1, 0, n + 1, x - 1, a[x]);
                int nx = SegT::query_suc(1, 0, n + 1, x + 2, a[x]);
                if(a[pr] != a[x]) BIT::add(pr + 1, x - 1, -1);
                if(a[nx] != a[x]) BIT::add(x + 2, nx - 1, 1);
                int tmp; 
                BIT::set(x, tmp = BIT::query(x + 1, x + 1));
                if(a[nx] <= a[x + 1]) {
                    BIT::set(x + 1, BIT::query(nx, nx) + (a[nx] != a[x]));
                } else {
                    BIT::set(x + 1, tmp + 1);
                }
            } else {
                int pr = SegT::query_pre(1, 0, n + 1, x - 1, a[x + 1]);
                int nx = SegT::query_suc(1, 0, n + 1, x + 2, a[x + 1]);
                if(a[pr] != a[x + 1]) BIT::add(pr + 1, x - 1, 1);
                if(a[nx] != a[x + 1]) BIT::add(x + 2, nx - 1, -1);
                int tmp;
                BIT::set(x + 1, tmp = BIT::query(x, x));
                if(a[pr] <= a[x]) {
                    BIT::set(x, BIT::query(pr, pr) + (a[pr] != a[x + 1]));
                } else {
                    BIT::set(x, tmp + 1);
                }
            }
            swap(a[x], a[x + 1]);
            SegT::modify(1, 0, n + 1, x, a[x]);
            SegT::modify(1, 0, n + 1, x + 1, a[x + 1]);
        } else {
            cin >> l >> r;
            cout << BIT::query(l, r) - (r - l + 1) << '\n';
        }
    }

    return 0;
}