跳转至

权值线段树

线段树版的桶

桶可以记录数据的分布。若要动态查询数据中,给定区间内共有多少个数,就可以将桶用线段树进行处理,这种线段树叫做权值线段树。

  • 记录权值的线段树。记录权值指的是,每个点上存的是区间内数字出现的总次数。比如一个长度为 \(10\) 的数组 [1,1,2,3,3,4,4,4,4,5]
  • 其中数字 \(1\) 出现了两次,那么 \([1,1]\) 这个区间的值为 \(2\),数字 \(2\) 出现了一次,那么 \([2,2]\) 这个区间的值为 \(1\)
  • 通过 push_up 可以得到,\([1,2]\) 这个节点的值为 \(3\),即 \(1\) 出现的次数和 \(2\) 出现的次数加和。
  • 如果我想要知道这个数组上的第 \(k\) 小,我就可以在这个权值线段树上用 \(O(\log n)\)的时间来实现。
  • 如果原始输入的值域范围比较大,可能需要先离散化。

权值线段树

例题

P1908 逆序对

代码
#include<iostream>
#include<algorithm>
#include<map>
using namespace std;
const int N = 1E6 + 10;

int n;
int a[N], num[N];

int sum[4 * N];
map<int, int> mp;

inline int lc(int x){ return x << 1; }
inline int rc(int x){ return x << 1 | 1; }

void push_up(int u){
    sum[u] = sum[lc(u)] + sum[rc(u)];
}

void update(int u, int l, int r, int q, int val){
    if(l == r){
    sum[u] += val;
        return;
    }
    int mid = (l + r) >> 1;
    if(mid >= q) update(lc(u), l, mid, q, val);
    else update(rc(u), mid + 1, r, q, val);
    push_up(u);
}

int query(int u, int l, int r, int ql, int qr){
    if(ql <= l && r <= qr){
        return sum[u];
    }
    int mid = (l + r) >> 1, res = 0;
    if(mid >= ql) res += query(lc(u), l, mid, ql, qr);
    if(mid < qr) res += query(rc(u), mid + 1, r, ql, qr);
    return res;
}

int main(){

    cin >> n;
    for(int i = 1; i <= n; i++){
        cin >> a[i];
        num[i] = a[i];
    }

    sort(num + 1, num + 1 + n);
    int newN = unique(num + 1, num + 1 + n) - num - 1;
    for(int i = 1; i <= newN; i++){
        mp[num[i]] = i;
    }

    newN += 5;

    long long ans = 0;
    for(int i = 1; i <= n; i++){
        ans += query(1, 1, newN, mp[a[i]] + 1, newN);
        update(1, 1, newN, mp[a[i]], 1);
    }
    cout << ans << endl;

    return 0;
}

权值线段树的动态开点

使用上面的普通线段树改造的权值线段树存在各种问题:

  • 有时,线段树需要维护的值域很大,但是实际用到的节点很少;
  • 强制在线不能离散化;

于是,我们可以写出动态开点的权值线段树:

  • 初始时线段树中没有节点,用到的时候再向内存要。每一次单点修改都会产生一条从根走向叶子的路径,总空间消耗 \(O(n\log V)\)\(n\) 表示调用 insert() 的次数);
  • 查询时若遇到空节点,则说明该区间内没有信息,返回 \(0\)\(\pm \infty\) 即可。

例题

U74894 有便便的厕所

代码
#include<iostream>
using namespace std;
const int N = 1E5 + 10;
const int INF = 1E9 + 10;

struct Node{
    int lc;
    int rc;
    int cnt;
    int del;
} tree[35 * N];

int q, nn = 1;

void move_tag(int p){
    tree[p].cnt = 0;
    tree[p].del = 1;
}

void push_down(int p){
    if(tree[p].del){
        move_tag(tree[p].lc);
        move_tag(tree[p].rc);
        tree[p].del = 0;
    }
}

void push_up(int p){
    tree[p].cnt = tree[tree[p].lc].cnt + tree[tree[p].rc].cnt;
}

void insert(int p, int l, int r, int q){
    if(l == r){
        tree[p].cnt++;
        tree[p].del = 0;
        return;
    }
    int mid = (l + r) >> 1;
    push_down(p);
    if(mid >= q){
        if(tree[p].lc == 0) tree[p].lc = ++nn;
        insert(tree[p].lc, l, mid, q);
    } else{
        if(tree[p].rc == 0) tree[p].rc = ++nn;
        insert(tree[p].rc, mid + 1, r, q);
    }
    push_up(p);
}

void remove(int p, int l, int r, int ql, int qr){
    if(p == 0 || tree[p].del == 1 || tree[p].cnt == 0) return;
    if(ql <= l && r <= qr){
        tree[p].cnt = 0;
        tree[p].del = 1;
        return;
    }
    int mid = (l + r) >> 1;
    push_down(p);
    if(mid >= ql) remove(tree[p].lc, l, mid, ql, qr);
    if(mid < qr) remove(tree[p].rc, mid + 1, r, ql, qr);
    push_up(p);
}

int query_sum(int p, int l, int r, int ql, int qr){
    if(p == 0 || tree[p].del == 1 || tree[p].cnt == 0) return 0;
    if(ql <= l && r <= qr){
        return tree[p].cnt;
    }
    int mid = (l + r) >> 1;
    int res = 0;
    push_down(p);
    if(mid >= ql) res += query_sum(tree[p].lc, l, mid, ql, qr);
    if(mid < qr) res += query_sum(tree[p].rc, mid + 1, r, ql, qr);
    return res;
}

int queryK(int p, int l, int r, int ql, int qr, int k){
    if(p == 0 || tree[p].del == 1 || tree[p].cnt == 0) return -1;
    if(l == r){
        if(tree[p].cnt >= k) return l;
        return -1;
    }
    int mid = (l + r) >> 1, cnt = 0;
    push_down(p);
    if(mid < qr){
        cnt = query_sum(tree[p].rc, mid + 1, r, ql, qr);
        if(cnt >= k) return queryK(tree[p].rc, mid + 1, r, ql, qr, k);
    }
    if(mid >= ql){
        return queryK(tree[p].lc, l, mid, ql, qr, k - cnt);
    }
    return -1;
}

int main(){

    cin >> q;
    while(q--){
        int op, x, l, r, k;
        cin >> op;
        if(op == 1){
            cin >> x;
            insert(1, 1, INF, x);
        } else if(op == 2){
            cin >> l >> r;
            remove(1, 1, INF, l, r);
        } else{
            cin >> l >> r >> k;
            cout << queryK(1, 1, INF, l, r, k) << endl;
        }
    }

    return 0;
}

U74895 有便便的厕所2

代码
#include<iostream>
#define ll long long
using namespace std;
const int N = 1E6 + 10;
const int LOGN = 20;

int q, nn = 1;
struct Node{
    int lc;
    int rc;
    ll sum1, sum2; //和,平方和
    int tag;
} tree[LOGN * N];

inline ll getSum1(ll l, ll r){
    return (l + r) * (r - l + 1) >> 1;
}

inline ll getSum2(ll l, ll r){
    return (r * (r + 1) * (r << 1 | 1) - l * (l - 1) * ((l - 1) << 1 | 1)) / 6;
}

void move_tag(int p, int l, int r, int tg){
    tree[p].sum1 += getSum1(l, r) * tg;
    tree[p].sum2 += getSum2(l, r) * tg;
    tree[p].tag += tg;
}

void push_down(int p, int l, int r){
    if(tree[p].lc == 0) tree[p].lc = ++nn;
    if(tree[p].rc == 0) tree[p].rc = ++nn;
    if(tree[p].tag != 0){
        int mid = (l + r) >> 1;
        move_tag(tree[p].lc, l, mid, tree[p].tag);
        move_tag(tree[p].rc, mid + 1, r, tree[p].tag);
        tree[p].tag = 0;
    }
}

void push_up(int p){
    int lc = tree[p].lc;
    int rc = tree[p].rc;
    tree[p].sum1 = tree[lc].sum1 + tree[rc].sum1;
    tree[p].sum2 = tree[lc].sum2 + tree[rc].sum2;
}

void add(int p, int l, int r, int ql, int qr){
    if(ql <= l && r <= qr){
        tree[p].sum1 += getSum1(l, r);
        tree[p].sum2 += getSum2(l, r);
        tree[p].tag ++;
        return;
    }
    push_down(p, l, r);
    int mid = (l + r) >> 1;
    if(mid >= ql) add(tree[p].lc, l, mid, ql, qr);
    if(mid < qr) add(tree[p].rc, mid + 1, r, ql, qr);
    push_up(p);
}

ll query1(int p, int l, int r, int ql, int qr){
    if(ql <= l && r <= qr){
        return tree[p].sum1;
    }
    push_down(p, l, r);
    int mid = (l + r) >> 1;
    ll res = 0;
    if(mid >= ql) res += query1(tree[p].lc, l, mid, ql, qr);
    if(mid < qr) res += query1(tree[p].rc, mid + 1, r, ql, qr);
    return res;
}

ll query2(int p, int l, int r, int ql, int qr){
        if(ql <= l && r <= qr){
                return tree[p].sum2;
        }
        push_down(p, l, r);
        int mid = (l + r) >> 1;
        ll res = 0;
        if(mid >= ql) res += query2(tree[p].lc, l, mid, ql, qr);
        if(mid < qr) res += query2(tree[p].rc, mid + 1, r, ql, qr);
        return res;
}

int main(){

    cin >> q;
    while(q--){
        int op, x, l, r;
        cin >> op;
        if(op == 1){
            cin >> x;
            add(1, 1, N, x, x);
        } else if(op == 2){
            cin >> l >> r;
            add(1, 1, N, l, r);
        } else{
            cin >> l >> r;
            cout << query1(1, 1, N, l, r) << ' ';
            cout << query2(1, 1, N, l, r) << endl;
        }
    }

    return 0;
}