跳转至

主席树

主席树是可持久化线段树的简称,可以在 \(O(\log V)\) 的时间和空间复杂度内实现一次插入,同时可以保存所有历史版本。

  • 如果我们需要给序列的每个前缀/后缀都建一棵线段树,但是空间开不下,就可以使用主席树。主席树的每一个历史版本对应一个前缀/后缀。
  • 如果我们希望快速得到区间 \([l,r]\) 对应的线段树,并进行线段树上二分,就可以使用主席树+差分解决。

例题

P3919 【模板】可持久化线段树 1(可持久化数组)

题目大意

维护一个长度为 \(n\) 的数组,每次指定一个历史版本,进行单点修改或单点查询,同时生成一个新的版本。

模板代码
#include<iostream>
using namespace std;
const int N = 1E6 + 10;
const int LOGN = 25; 

int n, q;
int a[N], rt[N];

namespace Seg_T {

    int nn;
    int lc[N * LOGN], rc[N * LOGN], sum[N * LOGN];

    int addNode(int p) {
        int nw = ++nn;
        lc[nw] = lc[p];
        rc[nw] = rc[p];
        return nw;
    }

    inline void push_up(int p) {
        sum[p] = sum[lc[p]] + sum[rc[p]];
    }

    void build(int &p, int l, int r) {
        if(p == 0) p = ++nn;
        if(l == r) {
            sum[p] = a[l];
            return;
        }
        int mid = (l + r) >> 1;
        build(lc[p], l, mid);
        build(rc[p], mid + 1, r);
        push_up(p); 
    }

    int modify(int p1, int l, int r, int q, int v) {
        int p = addNode(p1);
        if(l == r) {
            sum[p] = v;
            return p;
        }
        int mid = (l + r) >> 1;
        if(mid >= q) lc[p] = modify(lc[p1], l, mid, q, v);
        else rc[p] = modify(rc[p1], mid + 1, r, q, v);
        push_up(p);
        return p;
    }

    int query(int p, int l, int r, int q) {
        if(l == r) {
            return sum[p];
        }
        int mid = (l + r) >> 1;
        if(mid >= q) return query(lc[p], l, mid, q);
        else return query(rc[p], mid + 1, r, q); 
    }

}

int main() {

    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    cin >> n >> q;
    for(int i = 1; i <= n; i++) {
        cin >> a[i];
    } 

    Seg_T::build(rt[0], 1, n);

    for(int i = 1; i <= q; i++) {
        int ver, op, pos, val;
        cin >> ver >> op;
        if(op == 1) {
            cin >> pos >> val;
            rt[i] = Seg_T::modify(rt[ver], 1, n, pos, val);
        } else {
            cin >> pos;
            cout << Seg_T::query(rt[ver], 1, n, pos) << '\n';
            rt[i] = rt[ver]; 
        }
    }

    return 0;
}

P3834 【模板】可持久化线段树 2

题目大意

给定一个长度为 \(n\) 的序列 \(a\),每次询问给定区间 \([l,r]\) 和整数 \(k\),要求查询区间第 \(k\) 小。

使用主席树我们可以对原数组的每一个前缀都“建”一棵权值线段树,通过对两棵不同下标对应的树作差,我们可以得到任意下标区间 \([l,r]\) 对应的权值线段树,通过线段树上二分,我们可以求出第 \(k\) 小的值。

时间复杂度 \(O(n\log V)\) - \(O(m\log V)\)

模板代码
#include<iostream>
#define int long long
using namespace std;
const int N = 2E5 + 10;
const int A = 1E9; 
const int LOGA = 20;

int n, m; 
int rt[N];

namespace Seg_T {

    int nn;
    int lc[N * LOGA], rc[N * LOGA], sum[N * LOGA];

    int addNode(int p1) {
        int p = ++nn;
        lc[p] = lc[p1];
        rc[p] = rc[p1];
        sum[p] = sum[p1];
        return p;
    }

    void push_up(int p) {
        sum[p] = sum[lc[p]] + sum[rc[p]];
    }

    int insert(int p1, int l, int r, int q, int v) {
        int p = addNode(p1);
        if(l == r) {
            sum[p] += v;
            return p;
        }
        int mid = (l + r) >> 1;
        if(mid >= q) lc[p] = insert(lc[p1], l, mid, q, v);
        else rc[p] = insert(rc[p1], mid + 1, r, q, v);
        push_up(p); 
        return p;
    }

    int queryKth(int pl, int pr, int l, int r, int k) {
        if(l == r) {
            return l;
        }
        int mid = (l + r) >> 1;
        if(k <= sum[lc[pr]] - sum[lc[pl]]) return queryKth(lc[pl], lc[pr], l, mid, k);
        else return queryKth(rc[pl], rc[pr], mid + 1, r, k - (sum[lc[pr]] - sum[lc[pl]]));
    }

}

signed main() {

    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    cin >> n >> m;
    for(int i = 1; i <= n; i++) {
        int x;
        cin >> x;
        rt[i] = Seg_T::insert(rt[i - 1], 0, A, x, 1);
    }

    for(int i = 1; i <= m; i++) {
        int l, r, k;
        cin >> l >> r >> k;
        cout << Seg_T::queryKth(rt[l - 1], rt[r], 0, A, k) << '\n';
    }

    return 0;
}

P7252 [JSOI2011] 棒棒糖

题目大意

给定一个序列 \(a\),每次询问一个区间 \([l,r]\),回答这个区间的绝对众数的值(没有输出 \(-1\)

注意到绝对众数有一些很好的性质。因为其出现次数大于总数的一半,所以可以直接在值域线段树上二分。因为绝对众数所在的值域区间的 sum 一定大于整个 sum 的一半(记作 \(k\))。

要想快速获得下标区间 \([l,r]\) 中所有元素组成的权值线段树,并在其上使用线段树上二分,使用主席树即可。

(本题也可以使用普通线段树+摩尔投票法+vector+lower_bound验证绝对众数) (本题也可以使用随机化+vector+lower_bound验证+)

代码
#include<iostream>
using namespace std;
const int N = 5E4 + 10;
const int V = 5E4 + 10;
const int LOGV = 16;

int n, m; 
int c[N], rt[N];

namespace Seg_T {

    int sum[N * LOGV], lc[N * LOGV], rc[N * LOGV], nn;
    inline void push_up(int p) {
        sum[p] = sum[lc[p]] + sum[rc[p]];
    } 
    int insert(int p1, int l, int r, int q, int v) {
        int p = ++nn;
        lc[p] = lc[p1];
        rc[p] = rc[p1];
        sum[p] = sum[p1];
        if(l == r) {
            sum[p] += v;
            return p;
        }
        int mid = (l + r) >> 1;
        if(mid >= q) lc[p] = insert(lc[p1], l, mid, q, v);
        else rc[p] = insert(rc[p1], mid + 1, r, q, v); 
        push_up(p);
        if(sum[p] != sum[p1] + 1) throw -1; 
        return p; 
    }
    int query(int pl, int pr, int l, int r, int k) {
        if(l == r) {
            return l; 
        }
        int mid = (l + r) >> 1;
        if(sum[lc[pr]] - sum[lc[pl]] >= k) return query(lc[pl], lc[pr], l, mid, k);
        if(sum[rc[pr]] - sum[rc[pl]] >= k) return query(rc[pl], rc[pr], mid + 1, r, k);
        return 0; 
    }

}

int main() {

    cin >> n >> m;
    for(int i = 1; i <= n; i++) {
        cin >> c[i];
        rt[i] = Seg_T::insert(rt[i - 1], 1, V, c[i], 1);
    }

    while(m--) {
        int l, r;
        cin >> l >> r;
        cout << Seg_T::query(rt[l - 1], rt[r], 1, V, (r - l + 1) / 2 + 1) << '\n'; 
    }

    return 0; 
} 

CF786C Till I Collapse

题目大意

对于 \(k=1,2,3,⋯,n\),分别求出最小的 \(m\),使得存在一种将 \(n\) 个数划分成 \(m\) 段的方案,每段中不同数字的种类不超过 \(k\) 个。

对于一个 \(k\),从左往右贪心是显然的。此时每段长度至少为 \(k\),最坏情况下会分成 \(\lceil\frac{n}{k}\rceil\) 段(序列中所有数字两两不同),总的段数就是 \(n\log n\)调和级数-时间复杂度)。问题就转化为:对于一个起点下标 \(x\),找到最靠左的下标 \(y\),满足 \([x,y]\) 中的数字种类恰好为 \(k\)

相等的数字产生的贡献归结到该数字第一次出现的位置上。我们希望对后缀 \([x,n]\) 可以获得一个线段树,在后缀中第一次出现的数字对应的位置为 \(1\),其余为 \(0\)。这样就可以直接进行线段树上二分,找到最后一个前缀和 \(\le k\) 的位置即是 \(y\)

想要获得所有后缀对应的线段树,考虑主席树解决。我们从右向左遍历数组,用桶数组记录每个数字上一次出现的位置。每遇到一个数字,如果其已经出现过,就将原来出现的位置置为 \(0\),本次出现的位置置为 \(1\)。记录两次修改后产生的新版本,即使每个后缀对应的线段树。

代码
#include<iostream>
using namespace std;
const int N = 1E5 + 10;
const int LOGN = 20;

int n;
int rt[N], a[N], last[N];

namespace Seg_T {

    int sum[2 * N * LOGN], lc[2 * N * LOGN], rc[2 * N * LOGN], nn;

    inline void push_up(int p) { sum[p] = sum[lc[p]] + sum[rc[p]]; }
    int insert(int p1, int l, int r, int q, int v) {
        int p = ++nn;
        lc[p] = lc[p1];
        rc[p] = rc[p1];
        sum[p] = sum[p1];
        if(l == r) {
            sum[p] += v;
            return p; 
        }
        int mid = (l + r) >> 1;
        if(mid >= q) lc[p] = insert(lc[p1], l, mid, q, v);
        else rc[p] = insert(rc[p], mid + 1, r, q, v);
        push_up(p);
        return p;
    }
    int query(int p, int l, int r, int k) {
        if(l == r) {
            return l;
        }
        int mid = (l + r) >> 1;
        if(sum[lc[p]] > k) return query(lc[p], l, mid, k);
        return query(rc[p], mid + 1, r, k - sum[lc[p]]); 
    }

}

int main() {

    cin >> n;
    for(int i = 1; i <= n; i++) {
        cin >> a[i];
    }

    for(int i = n; i >= 1; i--) {
        rt[i] = Seg_T::insert(rt[i + 1], 1, n + 1, i, 1);
        if(last[a[i]]) {
            rt[i] = Seg_T::insert(rt[i], 1, n + 1, last[a[i]], -1);
        }
        last[a[i]] = i;
    }

    for(int k = 1; k <= n; k++) {
        int cur = 1, ans = 0;
        while(cur <= n) {
            cur = Seg_T::query(rt[cur], 1, n + 1, k);
            ans++;
        }
        cout << ans << ' ';
    }
    cout << endl;

    return 0;
}