跳转至

线段树

线段树是一种 \(O(n)\) 建树,\(O(\log n)\) 查询区间半群信息(满足结合律),\(O(\log n)\) 单点赋值的数据结构。当和另一个运算满足(广义)分配律时,可以通过打 tag 的方式实现 \(O(\log n)\) 区间修改。如果修改操作和查询操作使用的运算相同,分配律退化为交换律。

维护差分数组

P4243 [JSOI2009] 等差数列

给定一个数列,实现两种操作:

  • 区间加一个等差数列;
  • 查询区间 \([l,r]\) 最少被拆成多少个等差数列;
代码
#include<iostream>
using namespace std;
const int N = 1E5 + 10;
const int INF = 1E7 + 10;

struct Node {
    int mans, lans, rans, ans, tag;
    int ld, rd, len;
    Node(int _ma = 0, int _la = 0, int _ra = 0, int _a = 0, int _ld = 0, int _rd = 0, int _tg = 0, int _len = 0) {
        mans = _ma;
        lans = _la;
        rans = _ra;
        ans = _a;
        ld = _ld;
        rd = _rd;
        tag = _tg;
        len = _len;
    }
};

inline Node operator+(const Node &a, const Node &b) {
    Node res;
    res.mans = min(a.rans + b.lans - (a.rd == b.ld), min(a.mans + b.lans, a.rans + b.mans));
    res.lans = min(a.ans + b.lans - (a.rd == b.ld), min(a.lans + b.lans, a.ans + b.mans));
    res.rans = min(a.rans + b.ans - (a.rd == b.ld), min(a.mans + b.ans, a.rans + b.rans));
    res.ans = min(a.ans + b.ans - (a.rd == b.ld), min(a.lans + b.ans, a.ans + b.rans));
    res.ld = a.ld;
    res.rd = b.rd;
    return res;
}

int n, q; 
int a[N], d[N];

namespace Seg_T {

    inline int lc(int x) { return x << 1; }
    inline int rc(int x) { return x << 1 | 1; }
    Node tr[4 * N];
    inline void push_up(int p) {
        tr[p] = tr[lc(p)] + tr[rc(p)];
    }
    inline void move_tag(int p, int tg) {
        tr[p].ld += tg;
        tr[p].rd += tg;
        tr[p].tag += tg;
    }
    inline void push_down(int p) {
        if(!tr[p].tag) return;
        move_tag(lc(p), tr[p].tag);
        move_tag(rc(p), tr[p].tag);
        tr[p].tag = 0; 
    }
    void build(int p, int l, int r) {
        if(l == r) {
            tr[p] = {0, 1, 1, 1, d[l], d[l], 0};
            return;
        }
        int mid = (l + r) >> 1;
        build(lc(p), l, mid);
        build(rc(p), mid + 1, r);
        push_up(p); 
    } 
    void modify(int p, int l, int r, int q, int v) {
        if(l == r) {
            tr[p].ld += v;
            tr[p].rd += v;
            return;
        }
        push_down(p);
        int mid = (l + r) >> 1;
        if(mid >= q) modify(lc(p), l, mid, q, v);
        else modify(rc(p), mid + 1, r, q, v);
        push_up(p);
    }
    void modify(int p, int l, int r, int ql, int qr, int v) {
        if(ql <= l && r <= qr) {
            move_tag(p, v);
            return; 
        }
        push_down(p);
        int mid = (l + r) >> 1;
        if(mid >= ql) modify(lc(p), l, mid, ql, qr, v);
        if(mid < qr) modify(rc(p), mid + 1, r, ql, qr, v);
        push_up(p);
    }
    Node query(int p, int l, int r, int ql, int qr) {
        if(ql <= l && r <= qr) {
            return tr[p];
        }
        push_down(p);
        int mid = (l + r) >> 1;
        if(mid >= qr) return query(lc(p), l, mid, ql, qr);
        if(mid < ql) return query(rc(p), mid + 1, r, ql, qr);
        return query(lc(p), l, mid, ql, qr) + query(rc(p), mid + 1, r, ql, qr); 
    }

}

int main() {

    cin >> n;
    for(int i = 1; i <= n; i++) {
        cin >> a[i];
        d[i] = a[i] - a[i - 1];
    }

    Seg_T::build(1, 1, n + 1);

    cin >> q;
    while(q--) {
        char op;
        cin >> op;
        if(op == 'A') {
            int l, r, a, b;
            cin >> l >> r >> a >> b;
            if(l == r) {
                Seg_T::modify(1, 1, n + 1, l, a);
                Seg_T::modify(1, 1, n + 1, r + 1, -a); 
                continue;
            }
            Seg_T::modify(1, 1, n + 1, l, a);
            Seg_T::modify(1, 1, n + 1, l + 1, r, b);
            Seg_T::modify(1, 1, n + 1, r + 1, -a - (r - l) * b); 
        } else {
            int l, r;
            cin >> l >> r;
            if(l == r) {
                cout << 1 << '\n';
                continue;
            } 
            cout << Seg_T::query(1, 1, n + 1, l + 1, r).ans << '\n'; 
        }
    }

    return 0;
}

P5278 算术天才⑨与等差数列

你需要维护一个长度为 \(n\) 的序列,支持两种操作:

  • 单点修改;
  • 给定区间 \([l,r]\) 和常数 \(k\),询问区间内的数字能否重排成公差为 \(k\) 的等差数列;

\(n,m\le 3\times 10^5,\ 0\le a_i,k\le 10^9\),强制在线。

特判区间长度等于 \(1\)\(k=0\) 的情况。发现区间合法当且仅当差分数组都是 \(k\) 的倍数,并且区间中不包含重复数字。第一个条件可以使用线段树维护区间 \(\gcd\),第二个条件使用 pre 数组处理即可。

代码
#include<bits/stl_algobase.h>
#include<ctype.h>
#include<cstdio>
#include<map>
#include<set>
#include<unistd.h>
using namespace std;
const int N = 3e5 + 10;

#define endl '\n'
namespace io {
    struct istream {
        char ch;
        inline istream &operator>>(int &x) {
            while(!isdigit(ch = getchar()));
            x = ch - '0';
            while(isdigit(ch = getchar())) x = x * 10 + ch - '0';
            return *this;
        }
    } cin;
    struct ostream {
        char buf[60], top;
        inline ostream() : top(0) {}
        inline ostream &operator<<(int x) {
            do buf[++top] = x % 10 + '0', x /= 10; while(x);
            while(top) putchar(buf[top--]);
            return *this;
        }
        inline ostream &operator<<(char c) {
            putchar(c);
            return *this;
        }
        inline ostream &operator<<(const char s[]) {
            for(int i = 0; s[i]; i++) putchar(s[i]);
            return *this;
        }
    } cout;
}

using io::cin;
using io::cout;

int n, q;
int a[N];

// 区间数颜色
namespace checker1 {
    set<int> st[2 * N];
    map<int, int> mp;
    int nn;
    struct myPair {
        int mn, mx;
        inline myPair operator+(const myPair &b) const {
            return {min(mn, b.mn), max(mx, b.mx)};
        }
    };
    namespace SegT {
        int pre[4 * N], mn[4 * N], mx[4 * N];
        inline int lc(int x) { return x << 1; }
        inline int rc(int x) { return x << 1 | 1; }
        void modify_pre(int p, int l, int r, int q, int v) {
            // if(p == 1) cout << q << ' ' << v << endl;
            if(l == r) {
                pre[p] = v;
                return;
            }
            int mid = (l + r) >> 1;
            if(q <= mid) modify_pre(lc(p), l, mid, q, v);
            else modify_pre(rc(p), mid + 1, r, q, v);
            pre[p] = max(pre[lc(p)], pre[rc(p)]);
        }
        int query_pre(int p, int l, int r, int ql, int qr) {
            if(ql <= l && r <= qr) return pre[p];
            int mid = (l + r) >> 1;
            return max(ql <= mid ? query_pre(lc(p), l, mid, ql, qr) : 0, mid < qr ? query_pre(rc(p), mid + 1, r, ql, qr) : 0);
        }
        void modify_mnx(int p, int l, int r, int q, int v) {
            if(l == r) {
                mn[p] = mx[p] = v;
                return;
            }
            int mid = (l + r) >> 1;
            if(q <= mid) modify_mnx(lc(p), l, mid, q, v);
            else modify_mnx(rc(p), mid + 1, r, q, v);
            mn[p] = min(mn[lc(p)], mn[rc(p)]);
            mx[p] = max(mx[lc(p)], mx[rc(p)]);
        }
        myPair query_mnx(int p, int l, int r, int ql, int qr) {
            // cout << p << ' ' << l << ' ' << r << ' ' << ql << ' ' << qr << endl;
            // usleep(100000);
            if(ql <= l && r <= qr) return {mn[p], mx[p]};
            int mid = (l + r) >> 1;
            if(qr <= mid) return query_mnx(lc(p), l, mid, ql, qr);
            if(mid < ql) return query_mnx(rc(p), mid + 1, r, ql, qr);
            return query_mnx(lc(p), l, mid, ql, qr) + query_mnx(rc(p), mid + 1, r, ql, qr);
        }
    }
    void f(int &v) {
        int &x = mp[v];
        if(!x) x = ++nn;
        v = x;
    }
    void insert(int p, int v) {
        SegT::modify_mnx(1, 1, n, p, v);
        f(v);
        int pre = 0, suc = 0;
        set<int>::iterator it;
        it = st[v].lower_bound(p);
        if(it != st[v].end()) suc = *it;
        if(it != st[v].begin()) pre = *--it;
        st[v].insert(p);
        if(suc) SegT::modify_pre(1, 1, n, suc, p);
        SegT::modify_pre(1, 1, n, p, pre);
    }
    void erase(int p, int v) {
        // throw;
        f(v);
        int pre = 0, suc = 0;
        set<int>::iterator it;
        it = st[v].lower_bound(p);
        if(it != st[v].begin()) pre = *--it;
        it = st[v].upper_bound(p);
        if(it != st[v].end()) suc = *it;
        st[v].erase(p);
        if(suc) SegT::modify_pre(1, 1, n, suc, pre);
    }
    bool check(int l, int r, int k) {
        myPair res = SegT::query_mnx(1, 1, n, l, r);
        if(k == 0) return res.mn == res.mx;
        // cout << SegT::query_pre(1, 1, n, l, r) << '\n';
        return (SegT::query_pre(1, 1, n, l, r) < l) && ((res.mx - res.mn) == (long long)(r - l) * k);
    }
}

namespace checker2 {
    inline int gcd(int a, int b) {
        if(a < b) swap(a, b);
        while(b) {
            a = a % b;
            swap(a, b);
        }
        return a;
    }
    namespace SegT {
        int g[4 * N];
        inline int lc(int x) { return x << 1; }
        inline int rc(int x) { return x << 1 | 1; }
        void modify(int p, int l, int r, int q, int v) {
            if(l == r) {
                g[p] = v;
                return;
            }
            int mid = (l + r) >> 1;
            if(q <= mid) modify(lc(p), l, mid, q, v);
            else modify(rc(p), mid + 1, r, q, v);
            g[p] = gcd(g[lc(p)], g[rc(p)]);
        }
        int query(int p, int l, int r, int ql, int qr) {
            if(ql <= l && r <= qr) return g[p];
            int mid = (l + r) >> 1;
            return gcd(ql <= mid ? query(lc(p), l, mid, ql, qr) : 0, mid < qr ? query(rc(p), mid + 1, r, ql, qr) : 0);
        }
    }
    void modify(int p, int v) {
        SegT::modify(1, 1, n, p, v < 0 ? -v : v);
    }
    bool check(int l, int r, int k) {
        ++l;
        return SegT::query(1, 1, n, l, r) == k;
    }
}

void modify(int x, int y) {
    checker1::erase(x, a[x]);
    checker1::insert(x, y);
    checker2::modify(x, y - a[x - 1]);
    checker2::modify(x + 1, a[x + 1] - y);
    a[x] = y;
}

bool query(int l, int r, int k) {
    if(l == r) return true;
    return checker1::check(l, r, k) && checker2::check(l, r, k);
}

int main() {

    cin >> n >> q;
    for(int i = 1; i <= n; i++) cin >> a[i];
    for(int i = 1; i <= n; i++) checker1::insert(i, a[i]);
    for(int i = 1; i <= n; i++) checker2::modify(i, a[i] - a[i - 1]);

    int lstans = 0;
    while(q--) {
        int op, x, y, l, r;
        cin >> op;
        if(op == 1) {
            cin >> x >> y;
            x ^= lstans;
            y ^= lstans;
            modify(x, y);
        } else {
            cin >> l >> r >> x;
            l ^= lstans;
            r ^= lstans;
            x ^= lstans;
            int res = query(l, r, x);
            lstans += res;
            cout << (res ? "Yes\n" : "No\n");
        }
    }

    return 0;
}

维护矩阵

矩阵的大部分运算具有结合律,可以使用线段树维护,从而进行快速的区间查询。例如普通矩阵乘法各种广义矩阵乘法

*线段树维护矩阵乘法优化 DP 见 DP 优化

2025 北京冬令营 B 班模拟赛 2 T1 optimization

题面

给定一个 \(m×n\) 的网格,称 \((i,j)\) 为第 \(i\) (\(1 \leq i \leq m\)) 行第 \(j\) (\(1 \leq j \leq n\)) 列的点。对任意 \((i,j)\),它与 \((i,j−1)\)\((i,j+1)\)\((i−1,j)\)\((i+1,j)\)(如果存在)均有双向边连接,边上有非负权值。

\(q\) 次询问,每次询问给定 \(a\), \(b\), \(c\), \(d\),求 \((a,b)\)\((c,d)\) 之间的最短路径的权值。

为了方便叙述,我们记:列的编号为 \(x\),行的编号为 \(y\)

容易发现一个性质:对于询问 \((sx,sy)\rightarrow(tx,ty)\),其中间的一条竖线 \(x=x_0\)\(x_0\in[sx,tx]\))中至少有一个节点会被最短路径经过。因此可以考虑分块、分治,但时间复杂度都不行。(我打的分治,\(80pts\)

对于最短路问题,其松弛操作 \(d[u][v]=\min(d[u][k]+d[k][v])\) 和矩阵乘法非常相似。事实上,这正是一种 \((\min,+)\) 运算的广义矩阵乘法。

最短路和广义矩阵乘法

最短路可以转化为 \((\min,+)\) 广义矩阵乘法模型,在一些情况下可以使用矩阵求解。

在思考过程中,很容易想到使用 dp 预处理,低复杂度查询的思路(虽然直接行不通,因为查询区间不是前缀)。再结合矩阵乘法优化 dp 的思路,将 dp 转移方程化为矩阵乘法,就可以使用线段树维护区间 dp 值了。

同时由于 \(m\le 10\) 非常小,可以开 \(m\times m\) 的矩阵。我们可以将矩阵 \(D_{i,j}\) 定义为:从区间左端点 \(l\)\(i\) 行,走到区间右端点 \(r\)\(j\) 行,最短路是多少(如果你考虑过 dp,状态应该就是 类似 这么定义的。当然,最短路可以经过区间之外的点。不过为了简化思路,可以先不考虑这一点)。

注意到这样定义状态有一些很好的性质:比如可以 \(O(m^3)\) 快速合并两个相邻的区间、可直接由区间对应的矩阵得到询问的答案。我们理所应当的使用线段树维护这个矩阵乘法(快速合并 \(\rightarrow\) 线段树)。

接下来考虑两个仍然很重要的问题:如何处理最短路超出区间的情况?边界条件如何处理?

首先,这两个问题并不影响两个区间可以快速合并,以及可以快速查询答案。

既然不影响区间合并,我们不如直接从边界情况入手。我们想要求出每一列的 \(m\) 个点之间的全源最短路 \(d_{i,u,v}\)(最短路可以经过其他列)。我们注意到:经过其他列的路线一定可以被分割成完全在左边的路径完全在右边的路径。如果将它们分开考虑,然后再进行合并,即可得到每一列的全源最短路。

对于一列点 \(x=x_0\),我们不妨先考虑完全在其左边的路径对 dis_{x_0} 的贡献。我们从左到右依次考虑每一列,并从前一列转移到当前列。注意到每次新考虑一列的点,会加入该列内部的 \(m-1\) 条边,以及和左侧其他列之间的 \(m\) 条边。我们通过后者将 \(dis_{x_0-1}\) 转移到 \(dis_{x_0}\),再加入前者产生的贡献。最后跑 Floyd 松弛,融合两者,即可得到完全在左边的最短路。

我们以相同的方式反方向跑一遍,得到完全在右边的最短路;最后跑 Floyd 松弛,将两者融合,即可得到全局最短路。

当然,我们无需开两个 dis 数组,只需要在第二次跑的时候直接将贡献更新在原先的 dis 上,Floyd 松弛即可。

代码
#include<iostream>
#define int long long
#define cint const int&
using namespace std;
const int N = 1E5 + 10;
const int M = 12;
const int INF = (int)0x3f3f3f3f3f3f3f3f;

struct Rdis {
    int l, r;
    int dis[M][M];
    Rdis(cint _l = 0, cint _r = 0) {
        l = _l, r = _r;
        for(int i = 0; i < M; i++) {
            for(int j = 0; j < M; j++) {
                dis[i][j] = INF;
            }
        }
    }
    inline void lr(cint _l, cint _r) {
        l = _l, r = _r;
    }
    inline const int* operator[](cint index) const {
        return dis[index];
    }
    inline int* operator[](cint index) {
        return dis[index];
    }
};

int m, n, q;
int w[2][N][M], d[N][M][M];

inline Rdis operator+(const Rdis &a, const Rdis &b) {
    Rdis res(a.l, b.r);
    for(int u = 1; u <= m; u++) {
        for(int v = 1; v <= m; v++) {
            for(int k = 1; k <= m; k++) {
                res[u][v] = min(res[u][v], a[u][k] + w[0][a.r][k] + b[k][v]);
            }
        }
    }
    return res;
}

inline void mul(const int *a, const Rdis &b, int *res, cint MODE) {
    for(int i = 1; i <= m; i++) {
        res[i] = INF;
        for(int j = 1; j <= m; j++) {
            res[i] = min(res[i], a[j] + w[0][b.l - 1][j] * MODE + b[j][i]);
        }
    }
}

namespace SegT {
    Rdis tr[4 * N];
    inline int lc(cint x) { return x << 1; }
    inline int rc(cint x) { return x << 1 | 1; }
    inline void push_up(cint p) {
        tr[p] = tr[lc(p)] + tr[rc(p)];
    }
    void build(cint p, cint l, cint r) {
        if(l == r) {
            tr[p].lr(l, l);
            for(int i = 1; i <= m; i++) {
                for(int j = 1; j <= m; j++) {
                    tr[p][i][j] = d[l][i][j];
                }
            }
            return;
        }
        int mid = (l + r) >> 1;
        build(lc(p), l, mid);
        build(rc(p), mid + 1, r);
        push_up(p);
    }
    void query(cint p, cint l, cint r, cint ql, cint qr, int* res) {
        if(ql <= l && r <= qr) {
            int tmp[M];
            for(int i = 1; i <= m; i++) tmp[i] = res[i];
            mul(tmp, tr[p], res, l > ql);
            return;
        }
        int mid = (l + r) >> 1;
        if(mid >= ql) query(lc(p), l, mid, ql, qr, res);
        if(mid < qr) query(rc(p), mid + 1, r, ql, qr, res);
    }
}

// 松弛整列
void floyd(cint col) {
    for(int k = 1; k <= m; k++) {
        for(int i = 1; i <= m; i++) {
            for(int j = 1; j <= m; j++) {
                d[col][i][j] = min(d[col][i][j], d[col][i][k] + d[col][k][j]);
            }
        }
    }
}

signed main() {

    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    cin >> m >> n >> q;
    for(int i = 1; i <= m; i++) {
        for(int j = 1; j <= n - 1; j++) {
            cin >> w[0][j][i];
        }
    }
    for(int i = 1; i <= m - 1; i++) {
        for(int j = 1; j <= n; j++) {
            cin >> w[1][j][i];
        }
    }

    for(int u = 1; u <= m; u++) {
        for(int v = 1; v <= m; v++) {
            d[0][u][v] = d[n + 1][u][v] = INF;
        }
    }

    for(int i = 1; i <= n; i++) {
        int s[M] = {0};
        for(int j = 2; j <= m; j++) {
            s[j] = s[j - 1] + w[1][i][j - 1];
        }
        for(int u = 1; u <= m; u++) {
            for(int v = 1; v <= m; v++) {
                if(u == v) continue;
                d[i][u][v] = min(d[i - 1][u][v] + w[0][i - 1][u] + w[0][i - 1][v], abs(s[u] - s[v]));
            }
        }
        floyd(i);
    }

    for(int i = n; i >= 1; i--) {
        int s[M] = {0};
        for(int j = 2; j <= m; j++) {
            s[j] = s[j - 1] + w[1][i][j - 1];
        }
        for(int u = 1; u <= m; u++) {
            for(int v = 1; v <= m; v++) {
                if(u == v) continue;
                d[i][u][v] = min(d[i][u][v], min(d[i + 1][u][v] + w[0][i][u] + w[0][i][v], abs(s[u] - s[v])));
            }
        }
        floyd(i);
    }

    SegT::build(1, 1, n);

    while(q--) {
        int sx, sy, tx, ty;
        cin >> sy >> sx >> ty >> tx;
        if(sx == tx && sy == ty) {
            cout << 0 << '\n';
            continue;
        }
        if(sx > tx) swap(sx, tx), swap(sy, ty);
        int res[M] = {0};
        for(int i = 1; i <= m; i++) res[i] = INF;
        res[sy] = 0;
        SegT::query(1, 1, n, sx, tx, res);
        cout << res[ty] << '\n';
    }

    return 0;
}