跳转至

P12496 [集训队互测 2024] 又一个欧拉数问题

给定长度 \(n\),常数 \(k\)\(2^{k-1}\) 个数 \(w[0\sim 2^{k-1}-1]\)。定义一个长为 \(n\) 的排列的权值为

\[ \prod_{i=1}^{n-k+1}w\Big[\sum_{j=0}^{k-2}[p_{i+j}<p_{i+j+1}]2^j\Big] \]

求所有排列的权值和。

\(2\le k\le 4,~k\le n\le 10^5\)

考虑对上升序列 \(a_i=[p_i<p_{i+1}]\) 进行计数。如果已知 \(a\),那么我们可以通过一些经典的容斥来计数 \(p\) 的数量,具体的,设序列 \(b\) 取遍所有满足 \(a_i\le b_i\le 1\) 的序列,即把下降的限制容斥成没有限制或钦定上升。此时 \(p\) 的方案数等于 \(n!\) 除以 \(b\) 的所有 \(1\) 连续段的长度加一的阶乘,也就是一个可重集排列。

然后现在我们只需要解决:设长为 \(n-1\) 的序列二元组 \((a,b)\) 取遍所有满足 \(\forall i,~0\le a_i\le b_i\le 1\) 的序列,设序列二元组的权值为

\[ \frac{n!}{\prod_{i=1}^{m}(len_i+1)!}\prod_{i=1}^{n-k+1}w\Big[\sum_{j=0}^{k-2}a_{i+j}2^j\Big]\prod_{i=1}^{n}(-1)^{b_i-a_i} \]

求全体序列二元组的权值和。其中 \(len_{1\sim m}\) 表示 \(b\) 中各个连续段的长度。

现在就可以 dp 了,我们设计状态 \(f_{i,S}\) 表示考虑了序列的前 \(i\) 位,最后 \(k-2\)\(a_i\) 的状态为 \(S\),对应的权值和。每次转移考虑添加一个 \(b_i=1\) 的连续段,段间以 \(0\) 分开。这样的复杂度差不多是 \(O(2^kn^2)\)

细节差不多就是,每个段形如 01111..111(一个 \(0\) 和一些 \(1\)),然后把这些段拼接到前缀的连续 \(1\) 上。

考虑优化,由于状态数(\(S\))较小,考虑用矩阵乘法刻画。乘以一个矩阵表示加入一个连续段,为了限制长度恰好为 \(n-1\),考虑令矩阵中的每个元素都是一个多项式,若连续段加上段前面的 \(0\) 总共长度为 \(t\),那么将其贡献放到 \(x^t\) 的系数中。

然后矩阵的行列下标肯定就是代表 \(S\) 的初末状态,所以矩阵大小是 \(2^{k-2}\times 2^{k-2}\) 的。那么矩阵显然可以 \(O(4^kn)\) 预处理出来。

考虑初值怎么搞,因为如果太短的话是不能产生 \(w\) 的贡献的。直接跑上面的朴素 dp,但是长度超过 \(k-1\) 的部分就不继续刷表转移了。这样复杂度其实是 \(O(2^knk)\)。初值形如一个 \(2^{k-2}\) 列,值为多项式的行向量,第 \(i\) 个位置中的 \(x^j\) 系数表示末位状态为 \(i\),长度为 \(j\) 的权值之和。多项式的 \(x^{k-3}\) 次及以下的位置显然都是 \(0\)

设初值向量为 \(f\),转移矩阵为 \(op\),那么我们要求的就是

\[ \begin{align*} g&=f\cdot \Big(\sum_{i=0}^{+\infty}op^i\Big)\\ &=f\cdot \frac{1}{I-op} \end{align*} \]

然而朴素多项式矩阵求逆的复杂度高达 \(O(8^kn\log n)\),不能通过。考虑将多项式牛顿迭代推广到矩阵上,猜测出一个式子:

\[ X=2X_0-X_0FX_0 \]

代入检验:

\[ \begin{align*} FX&=2FX_0-FX_0FX_0\\ &=2FX_0-(FX_0-I)(FX_0-I)-FX_0-FX_0+I\\ &\equiv I\pmod {x^{2k}} \end{align*} \]

单次矩阵乘法只需要做 \(4^k\) 次 ntt,\(8^k\) 次点乘,由于牛顿迭代的多项式长度总和是 \(O(n)\) 的,所以时间复杂度 \(O(4^kn\log n+8^kn)\)

代码
#include<iostream>
#include<vector>
#include<cassert>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int N = 1e5 + 10;
const int N2 = 262144;
const int MOD = 998244353;

inline void print_frac(int x) {
    if(x <= 200) cout << x << ' ';
    else {
        for(int i = 1; i <= 200; i++)
            if((ll)x * i % MOD <= 200) {
                cout << (ll)x * i % MOD << '/' << i << ' ';
                return;
            }
        cout << x << ' ';
    }
}

inline void add(int &a, int b) { a += b; (a >= MOD) && (a -= MOD); }
inline int mod(int a) { return (a >= MOD) ? (a - MOD) : a; }
inline int qpow(int a, int b) {
    int res = 1;
    while(b) {
        if(b & 1) res = (ull)res * a % MOD;
        a = (ull)a * a % MOD;
        b >>= 1;
    }
    return res;
}
inline void ntt(int a[], int n, int g) {
    static int swp[N2];
    for(int i = 1; i < n; i++) swp[i] = (swp[i >> 1] >> 1) | (n >> 1) * (i & 1);
    for(int i = 0; i < n; i++) if(i < swp[i]) swap(a[i], a[swp[i]]);
    for(int k = 1; k < n; k <<= 1)
        for(int i = 0, w0 = qpow(g, (MOD - 1) / (k << 1)); i < n; i += (k << 1))
            for(int j = 0, w = 1; j < k; j++, w = (ull)w * w0 % MOD) {
                int x = a[i + j], y = (ull)a[i + j + k] * w % MOD;
                a[i + j] = mod(x + y), a[i + j + k] = mod(x - y + MOD);
            }
    if(g == 332748118) for(int i = 0, inv_n = qpow(n, MOD - 2); i < n; i++) a[i] = (ull)a[i] * inv_n % MOD;
}

int n, m, k, ans;
int inv[N], fac[N], ifac[N];
int w[8];

// 感觉我的模板写的挺好看的,可以用作参考
size_t mod_len;
struct poly {
    vector<int> a;
    inline int &operator[](size_t index) { return a[index]; }
    inline const int &operator[](size_t index) const { return a[index]; }
    inline poly() {}
    inline poly(size_t _n) { a.resize(_n); }
    inline poly(const poly &b) { a.resize(mod_len); for(size_t i = 0; i < mod_len; i++) a[i] = b[i]; }
    inline poly(poly &&b) { a = move(b.a); a.resize(mod_len); }
    inline void resize(size_t _n) { a.resize(_n); }
    inline poly &operator=(const poly &b) { a.resize(mod_len); for(size_t i = 0; i < mod_len; i++) a[i] = b[i]; return *this; }
    inline poly operator+(const poly &b) const { poly res(mod_len); for(size_t i = 0; i < mod_len; i++) res[i] = mod(a[i] + b[i]); return res; }
    inline poly &operator+=(const poly &b) { for(size_t i = 0; i < mod_len; i++) add(a[i], b[i]); return *this; }
    inline poly operator-(const poly &b) const { poly res(mod_len); for(size_t i = 0; i < mod_len; i++) res[i] = mod(a[i] + MOD - b[i]); return res; }
    inline poly operator-() const { poly res(mod_len); for(size_t i = 0; i < mod_len; i++) res[i] = mod(MOD - a[i]); return res; }
    inline poly dot(const poly &b) const { poly res(mod_len); for(size_t i = 0; i < mod_len; i++) res[i] = (ull)a[i] * b[i] % (size_t)MOD; return res; }
    inline poly &dot_eq(const poly &b) { for(size_t i = 0; i < mod_len; i++) a[i] = (ull)a[i] * b[i] % (size_t)MOD; return *this; }
    inline void ntt(int g) { if(mod_len > a.size()) resize(mod_len); ::ntt(a.data(), mod_len, g); }
    inline poly operator*(const poly &b) const { poly x(*this), y(b); mod_len <<= 1, x.ntt(3), y.ntt(3), x.dot_eq(y), x.ntt(332748118), mod_len >>= 1, x.resize(mod_len); return x; }
};
struct mat {
    poly a[4][4];
    inline poly *operator[](size_t index) { return a[index]; }
    inline const poly *operator[](size_t index) const { return a[index]; }
    inline mat() {}
    inline mat(size_t _n) { for(size_t i = 0; i < m; i++) for(size_t j = 0; j < m; j++) a[i][j].resize(_n); }
    inline mat(const mat &b) { for(size_t i = 0; i < m; i++) for(size_t j = 0; j < m; j++) a[i][j] = b[i][j]; }
    inline mat(mat &&b) { for(size_t i = 0; i < m; i++) for(size_t j = 0; j < m; j++) a[i][j] = move(b[i][j]); }
    inline void resize(size_t _n) { for(size_t i = 0; i < m; i++) for(size_t j = 0; j < m; j++) a[i][j].resize(_n); }
    inline mat &operator=(const mat &b) { for(size_t i = 0; i < m; i++) for(size_t j = 0; j < m; j++) a[i][j] = b[i][j]; return *this; }
    inline mat operator+(const mat &b) const { mat res(mod_len); for(size_t i = 0; i < m; i++) for(size_t j = 0; j < m; j++) res[i][j] = a[i][j] + b[i][j]; return res; }
    inline mat operator-(const mat &b) const { mat res(mod_len); for(size_t i = 0; i < m; i++) for(size_t j = 0; j < m; j++) res[i][j] = a[i][j] - b[i][j]; return res; }
    inline mat operator-() const { mat res(mod_len); for(size_t i = 0; i < m; i++) for(size_t j = 0; j < m; j++) res[i][j] = -a[i][j]; return res; }
    inline mat dot(const mat &b) const { mat res(mod_len); for(size_t i = 0; i < m; i++) for(size_t j = 0; j < m; j++) for(size_t k = 0; k < m; k++) res[i][k] += a[i][j].dot(b[j][k]); return res; }
    inline void ntt(int g) { for(size_t i = 0; i < m; i++) for(size_t j = 0; j < m; j++) a[i][j].ntt(g); }
    inline mat operator*(const mat &b) const { mat x(*this), y(b); mod_len <<= 1, x.ntt(3), y.ntt(3), x = x.dot(y), x.ntt(332748118), mod_len >>= 1, x.resize(mod_len); return x; }
    inline mat get_inv() const {
        mat res(mod_len); for(size_t i = 0; i < m; i++) for(size_t j = 0; j < m; j++) res[i][j][0] = (assert(a[i][j][0] == (i == j)), a[i][j][0]);
        size_t pre_modlen = mod_len;
        for(mod_len = 2; mod_len <= pre_modlen; mod_len <<= 1) {
            mat F = (*this); mod_len <<= 1, res.ntt(3), F.ntt(3);
            res = res + res - res.dot(F).dot(res);
            res.ntt(332748118), mod_len >>= 1, res.resize(mod_len);
        } mod_len = pre_modlen; return res;
    }
} I;
struct vec {
    poly a[4];
    inline poly &operator[](size_t index) { return a[index]; }
    inline const poly &operator[](size_t index) const { return a[index]; }
    inline vec() {}
    inline vec(size_t _n) { for(size_t i = 0; i < m; i++) a[i].resize(_n); }
    inline vec(const vec &b) { for(size_t i = 0; i < m; i++) a[i] = b[i]; }
    inline vec(vec &&b) { for(size_t i = 0; i < m; i++) a[i] = move(b[i]); }
    inline void resize(size_t _n) { for(size_t i = 0; i < m; i++) a[i].resize(_n); }
    inline vec &operator=(const vec &b) { for(size_t i = 0; i < m; i++) a[i] = b[i]; return *this; }
    inline vec operator+(const vec &b) const { vec res(mod_len); for(size_t i = 0; i < m; i++) res[i] = a[i] + b[i]; return res; }
    inline vec operator-(const vec &b) const { vec res(mod_len); for(size_t i = 0; i < m; i++) res[i] = a[i] - b[i]; return res; }
    inline vec operator-() const { vec res(mod_len); for(size_t i = 0; i < m; i++) res[i] = -a[i]; return res; }
    inline vec dot(const mat &b) const { vec res(mod_len); for(size_t i = 0; i < m; i++) for(size_t j = 0; j < m; j++) res[j] += a[i].dot(b[i][j]); return res; }
    inline void ntt(int g) { for(size_t i = 0; i < m; i++) a[i].ntt(g); }
    inline vec operator*(const mat &b) const { vec x(*this); mat y(b); mod_len <<= 1, x.ntt(3), y.ntt(3), x = x.dot(y), x.ntt(332748118), mod_len >>= 1, x.resize(mod_len); return x; }
};

vec f, g;
mat op;

int dp[N][4];
inline int trs(int s, int c) { return k == 1 ? 0 : s >> 1 | (c << k - 2); }

int main() {

    inv[1] = fac[0] = ifac[0] = 1;
    for(int i = 2; i < N; i++) inv[i] = MOD - (ll)inv[MOD % i] * (MOD / i) % MOD;
    for(int i = 1; i < N; i++) fac[i] = (ll)fac[i - 1] * i % MOD;
    for(int i = 1; i < N; i++) ifac[i] = (ll)ifac[i - 1] * inv[i] % MOD;

    cin >> n >> k; k--; mod_len = n; m = (1 << k - 1);
    while(mod_len != (mod_len & -mod_len)) mod_len++;
    I.resize(mod_len); for(int i = 0; i < m; i++) I[i][i][0] = 1;
    for(int i = 0; i < (1 << k); i++) cin >> w[i];

    dp[0][0] = 1; f.resize(mod_len);
    for(int i = 0; i < max(1, k - 1); i++) {
        int tmp[4] = {};
        for(int s = 0; s < m; s++) add(tmp[trs(s, 0)], (ll)dp[i][s] * (i + 1 >= k ? w[s] : 1) % MOD);
        for(int s = 0; s < m; s++) add(dp[i + 1][s], tmp[s]);
        for(int j = 1; i + 1 + j < n; j++) {
            int nw[4] = {};
            for(int s = 0; s < m; s++) {
                add(nw[trs(s, 0)], MOD - (ll)tmp[s] * (i + 1 + j >= k ? w[s] : 1) % MOD);
                add(nw[trs(s, 1)],       (ll)tmp[s] * (i + 1 + j >= k ? w[s | 1 << k - 1] : 1) % MOD);
            }
            for(int s = 0; s < m; s++) {
                add(dp[i + 1 + j][s], (ll)nw[s] * ifac[j + 1] % MOD);
                tmp[s] = nw[s];
            }
        }
        // 只有 i=0 时可以添加前缀连续 1。
        if(i == 0) {
            for(int s = 0; s < m; s++) tmp[s] = 0;
            for(int s = 0; s < m; s++) tmp[s] = dp[i][s];
            for(int j = 1; i + j < n; j++) {
                int nw[4] = {};
                for(int s = 0; s < m; s++) {
                    add(nw[trs(s, 0)], MOD - (ll)tmp[s] * (i + j >= k ? w[s] : 1) % MOD);
                    add(nw[trs(s, 1)],       (ll)tmp[s] * (i + j >= k ? w[s | 1 << k - 1] : 1) % MOD);
                }
                for(int s = 0; s < m; s++) {
                    add(dp[i + j][s], (ll)nw[s] * ifac[j + 1] % MOD);
                    tmp[s] = nw[s];
                }
            }
        }
        for(int s = 0; s < m; s++) dp[i][s] = 0;
    }

    for(int s = 0; s < m; s++) {
        for(int i = 0; i < n; i++) f[s][i] = dp[i][s];
    }

    op.resize(mod_len);
    for(int s = 0; s < m; s++) {
        int tmp[4] = {};
        tmp[trs(s, 0)] = w[s];
        add(op[s][trs(s, 0)][1], w[s]);
        for(int i = 1; i < n - 1; i++) {
            int nw[4] = {};
            for(int t = 0; t < m; t++) {
                add(nw[trs(t, 0)], MOD - (ll)tmp[t] * w[t] % MOD);
                add(nw[trs(t, 1)],       (ll)tmp[t] * w[t | 1 << k - 1] % MOD);
            }
            for(int t = 0; t < m; t++) {
                add(op[s][t][i + 1], (ll)nw[t] * ifac[i + 1] % MOD);
                tmp[t] = nw[t];
            }
        }
    }

    g = f * (I - op).get_inv();

    for(int s = 0; s < m; s++) add(ans, g[s][n - 1]);
    cout << (ll)ans * fac[n] % MOD << '\n';

    return 0;
}