跳转至

250419 D0 模拟赛 T1 题解

题意

给定一个长为 \(n\) 的排列 \(p\)。你可以进行一次操作:选定若干个互不相交的区间 \([l_1,r_1],[l_2,r_2],\cdots,[l_k,r_k]\),然后将每个区间内部进行升序排序。问通过这一次操作可以得到多少种不同的 \(p'\),输出答案对 \(998244353\) 取模的即如果。

\(n\le 2\times 10^5\)

题解

首先,对于区间之间的空隙,我们可以使用长度等于 \(1\) 的小区间填满。因此操作等价于将序列划分为若干区间,然后分别排序。

我们考虑使用 DP 求解。我们设 \(f_i\) 表示前缀 \(i\) 的答案。转移时,枚举前缀中的最后一个连续区间 \([j,i]\) 的位置 \(j\),然后从 \(f_{j-1}\) 转移。

注意到这样会统计到一些重复的贡献,因为把一个大区间拆分成两个(值域)无交的小区间,两者得出的序列是相同的,应该记为一种情况。

考虑钦定单射:对于一类等价的方案,我们只令其中的一种产生贡献。称一个区间是合法的,当且仅当它不能被拆分成两个区间分别排序,还能得到相同的效果。如果一种方案使用的所有区间都是合法的,则我们令这个方案产生贡献。能够证明,对于一种合法的 \(p'\),一定存在且仅存在一种合法的方案(反证法易证)。

为了去除不合法的区间所产生的重复贡献,我们考虑不合法区间 \([i,j]\) 的性质:存在一个分界点 \(k\) 使得 \(\max\{a[i,k-1]\}<\min\{a[k,j]\}\)

到这里,我们已经可以使用一个二维的 dp 解决这个问题:使用第一个维度记录下标,第二个维度记录最后一个区间中的最大值。通过前缀和优化可以做到 \(O(n^2)\)

这个 dp 已经很优了,不容易直接进行优化。我们进一步观察不合法区间的性质。考虑固定 \(\max\{a[i,k-1]\}=a_p\)。记 \(a_i\) 左边第一个比它大的位置为 \(mxl[i]\),右边第一个比它大的位置为 \(mxr[i]\)。容易观察到 \(k\le mxr[p]\),否则 \(\max\{a[i,k]\}\ne a_p\)

进一步的,我们发现如果 \(k< mxr[p]\),则 \(a_k\in [k,j]\) 会导致 \(\min\{a[k,j]\}<\max\{a[i,k-1]\}\),无法找出不合法的区间。因此一定有 \(k=mxr[p]\)

接下来我们分别考察 \(i\)\(j\) 的取值范围。显然 \(i\in [mxl[p]+1,p]\)。而 \(j\) 需要满足 \(\min\{a[k,j]\}>\max\{a[i,k-1]\}=a_p\)。我们可以使用 st 表和二分,找到 \(j\) 的上界;其下界显然为 \(k\)

这样,所有不合法的区间都被我们分为了 \(O(n)\) 组区间,我们只需要在转移时排除这些区间即可。

具体的,我们可以使用线段树,将当前所有不合法的转移点都标记在线段树上(区间 \(+1\)),然后从没有标记的位置转移(查最小值的数量)。时间复杂度 \(O(n\log n)\)

AC 代码
#include<iostream>
#include<vector>
#define int long long
using namespace std;
const int N = 2e5 + 10;
const int LOGN = 19;
const int MOD = 998244353;

struct Range {
    int l, r, w;
};

int n;
int a[N];

int mxl[N], mxr[N];
int sta[N], top;

namespace st {
    int mn[LOGN][N], lg[N];
    void init() {
        for(int i = 2; i <= n; i++) lg[i] = lg[i >> 1] + 1;
        for(int i = 1; i <= n; i++) mn[0][i] = a[i];
        for(int k = 1; k < LOGN; k++) {
            for(int i = 1; i + (1 << k) - 1 <= n; i++) {
                mn[k][i] = min(mn[k - 1][i], mn[k - 1][i + (1 << (k - 1))]);
            }
        }
    }
    int get_mn(int l, int r) {
        int d = lg[r - l + 1];
        return min(mn[d][l], mn[d][r - (1 << d) + 1]);
    }
}

namespace SegT {
    struct Node {
        int val, cnt;
        inline Node() { val = 0; cnt = 0; }
        inline Node operator+(const Node &b) const {
            if(b.val < val) return b;
            Node res = *this;
            if(b.val == res.val) (res.cnt += b.cnt) %= MOD;
            return res;
        }
    };
    Node tr[4 * N];
    int tag[4 * N];
    inline int lc(int x) { return x << 1; }
    inline int rc(int x) { return x << 1 | 1; }
    inline void push_up(int p) {
        tr[p] = tr[lc(p)] + tr[rc(p)];
    }
    inline void move_tag(int p, int tg) {
        tr[p].val += tg;
        tag[p] += tg;
    }
    inline void push_down(int p) {
        if(!tag[p]) return;
        move_tag(lc(p), tag[p]);
        move_tag(rc(p), tag[p]);
        tag[p] = 0;
    }
    void modify(int p, int l, int r, int q, int v) {
        if(l == r) {
            tr[p].cnt = v;
            return;
        }
        push_down(p);
        int mid = (l + r) >> 1;
        if(q <= mid) modify(lc(p), l, mid, q, v);
        else modify(rc(p), mid + 1, r, q, v);
        push_up(p);
    }
    void add(int p, int l, int r, int ql, int qr, int v) {
        if(ql <= l && r <= qr) {
            move_tag(p, v);
            return;
        }
        push_down(p);
        int mid = (l + r) >> 1;
        if(ql <= mid) add(lc(p), l, mid, ql, qr, v);
        if(mid < qr) add(rc(p), mid + 1, r, ql, qr, v);
        push_up(p);
    }
    int query(int p, int l, int r, int ql, int qr) {
        if(ql <= l && r <= qr) {
            return !tr[p].val ? tr[p].cnt : 0;
        }
        push_down(p);
        int mid = (l + r) >> 1;
        if(qr <= mid) return query(lc(p), l, mid, ql, qr);
        if(mid < ql) return query(rc(p), mid + 1, r, ql, qr);
        return (query(lc(p), l, mid, ql, qr) + query(rc(p), mid + 1, r, ql, qr)) % MOD;
    }
}

vector<Range> ivld[N];

int f[N];

// #define O_J

signed main() {

    #ifdef O_J
    freopen("bai.in", "r", stdin);
    freopen("bai.out", "w", stdout);
    #endif

    cin >> n;
    for(int i = 1; i <= n; i++) cin >> a[i];

    st::init();

    top = 0;
    sta[top] = 0;
    for(int i = 1; i <= n; i++) {
        while(top && a[sta[top]] < a[i]) --top;
        mxl[i] = sta[top] + 1;
        sta[++top] = i;
    }

    top = 0;
    sta[top] = n + 1;
    for(int i = n; i >= 1; i--) {
        while(top && a[sta[top]] < a[i]) --top;
        mxr[i] = sta[top] - 1;
        sta[++top] = i;
    }

    for(int i = 1; i <= n; i++) {
        if(mxr[i] == n) continue;
        int l1 = mxl[i], r1 = i;
        int l2 = mxr[i] + 1, r2;
        int l = mxr[i] + 1, r = n;
        while(l < r) {
            int mid = (l + r + 1) >> 1;
            if(st::get_mn(l2, mid) > a[i]) l = mid;
            else r = mid - 1;
        }
        r2 = l;
        ivld[l2].push_back({l1 - 1, r1 - 1, 1});
        ivld[r2 + 1].push_back({l1 - 1, r1 - 1, -1});
    }

    SegT::modify(1, 0, n, 0, 1);
    for(int i = 1; i <= n; i++) {
        for(Range &o : ivld[i]) {
            SegT::add(1, 0, n, o.l, o.r, o.w);
        }
        f[i] = SegT::query(1, 0, n, 0, i - 1);
        SegT::modify(1, 0, n, i, f[i]);
    }

    cout << f[n] << '\n';

    return 0;
}