跳转至

NTT

对于一个长为 \(n=2^k\) 的序列 \(a_{0\sim n-1}\),考虑多项式函数 \(F(x)=\sum_{i=0}^{n-1}a_ix^i\),那么序列 \(a\) 的 NTT 也是一个长为 \(n\) 的序列,它的第 \(i\) 项是 \(F(w_{n}^{i})\)。其中 \(w_{n}^{i}=(w_{n})^i\)\(w_n=g^{(p-1)/n}\)

可以证明,两个序列 \(a,b\) 和它们的卷积 \(a*b\) 满足 \(NTT(a)\cdot NTT(B)=NTT(a*b)\)。其中 '\(\cdot\)' 表示对应位置相乘。

同样,定义序列 \(a\) 的逆 \(NTT\) 也是一个长为 \(n\) 的序列,它的第 \(i\) 项是 \(\frac{1}{n}F(w_{n}^{-i})\)(对,就是把后 \(n-1\) 项翻转一下然后除以归一化系数,但是一般我们使用传入逆元的写法)。

代码
void ntt(int a[], int g) {
    for(int i = 0; i < n; i++) inv[i] = (inv[i >> 1] >> 1) | ((1 << 20) * (i & 1));
    for(int i = 0; i < n; i++) if(i < inv[i]) swap(a[i], a[inv[i]]);
    for(int k = 1; k < n; k <<= 1) {
        for(int i = 0; i < n; i += 2 * k) {
            int w0 = qpow(g, (MOD - 1) / (2 * k)), w = 1;
            for(int j = 0; j < k; j++) {
                int x = a[i + j], y = a[i + j + k] * w % MOD;
                a[i + j] = (x + y) % MOD;
                a[i + j + k] = (x + MOD - y) % MOD;
                w = w * w0 % MOD;
            }
        }
    }
}

逆 ntt 只需传入原根 \(3\) 的逆元 \(998244354/3\) 即可实现。注意归一化系数 \(\frac 1n\)

NTT 加速递推

考虑一类递推式:

\[ f(i)=\sum_{j=1}^{i}f(i-j)g(j) \]

朴素的递推是 \(O(n^2)\) 的。注意到其形式和卷积非常类似,但是其中的 \(f\) 是在线的,因此无法直接 NTT。

考虑分治,设当前区间为 \([l,r)\),计算 \(mid=\frac{l+r}{2}\),然后先递归处理 \([l,mid)\),然后一遍 NTT 处理 \([l,mid)\to [mid,r)\) 的贡献,然后再递归处理 \([mid,r)\) 内部的贡献。这样做我们可以保证每次计算贡献的时候,左侧的 \(f\) 都是最终结果。不难分析出这样做的时间复杂度为 \(O(n\log^2 n)\)

模板
#include<iostream>
#define int long long
using namespace std;
const int N = 1.5e5;
const int MOD = 998244353;
const int G = 3;
const int iG = 332748118;

int n = 131072;

int inv[N];
int f[N], g[N];

inline int qpow(int a, int b) {
    int res = 1;
    while(b) {
        if(b & 1) res = res * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return res;
}

void ntt(int a[], int n, int g) {
    for(int i = 0; i < n; i++) inv[i] = (inv[i >> 1] >> 1) | ((n >> 1) * (i & 1));
    for(int i = 0; i < n; i++) if(i < inv[i]) swap(a[i], a[inv[i]]);
    for(int k = 1; k < n; k <<= 1) {
        for(int i = 0; i < n; i += 2 * k) {
            int w0 = qpow(g, (MOD - 1) / (2 * k)), w = 1;
            for(int j = 0; j < k; j++) {
                int x = a[i + j], y = a[i + j + k] * w % MOD;
                a[i + j] = (x + y) % MOD;
                a[i + j + k] = (x + MOD - y) % MOD;
                w = w * w0 % MOD;
            }
        }
    }
}

int a[N], b[N], c[N];

void solve(int l, int r) {
    if(l + 1 == r) return;
    int mid = (l + r) >> 1, len = r - l;
    solve(l, mid);
    for(int i = l; i < mid; i++) a[i - l] = f[i];
    for(int i = mid; i < r; i++) a[i - l] = 0;
    for(int i = 0; i < len; i++) b[i] = g[i];
    ntt(a, len, G);
    ntt(b, len, G);
    for(int i = 0; i < len; i++) c[i] = a[i] * b[i] % MOD;
    ntt(c, len, iG);
    for(int i = mid, t = qpow(len, MOD - 2); i < r; i++) f[i] = (f[i] + c[i - l] * t % MOD) % MOD;
    solve(mid, r);
}

signed main() {

    int m;
    cin >> m;
    for(int i = 1; i <= m - 1; i++) cin >> g[i];

    f[0] = 1;
    solve(0, n);

    for(int i = 0; i <= m - 1; i++) cout << f[i] << ' ';
    cout << '\n';

    return 0;
}

loj575 不等关系

给定一个长为 \(n-1\) 的字符串 \(s\),只包含 '>' 和 '<' 两种字符。称一个长为 \(n\) 的排列 \(p\) 是合法的,当且仅当 \(\forall i\in [1,n-1],s_i=\text{'<'},~p_i<p_{i+1}\),否则 \(p_i>p_{i+1}\)。问合法排列数。

\(n\le 10^5\)

有一种显然的 \(O(n^2)\) dp,但是状态数就高达 \(n^2\),因此难以继续优化。

注意到 \(s\) 可以看成 \(n-1\) 个较为独立的限制,因此考虑容斥。由于两种类型的限制难以处理,因此考虑将一种限制容斥为另一种,这里将大于号容斥成小于号。先钦定所有小于号成立,然后设第 \(i\) 个限制表示不满足第 \(i\) 个大于号(也就是小于)。

我们每次钦定若干个大于号的位置填小于号,然后乘上对应的容斥系数即可。现在问题转化为:只有小于号和空位。

注意到一个小于号组成的连续段内,数字只有一种排列顺序。因此总方案数可以写成

\[ \frac{n!}{\prod_{i=1}len_i!} \]

其中 \(len_i\) 表示第 \(i\) 个连续段的长度。由此我们也可以设计出一种连续段 dp:设 \(f_i\) 表示前缀 \(i\) 的答案。有转移:

\[ f_i=\sum_{1\le j<i}\frac{1}{i-j}(-1)^{cnt(j,i)}f_j \]

考虑优化,注意到式子形如一种卷积的形式,考虑分治 ntt。对于容斥系数,在本层分治时将其贡献拆成中点左侧和右侧,卷积之前先把左半区间中后缀大于号数量为奇数的加一个负号;卷积之后,再把右半区间结果序列中前缀大于号数量为奇数的加一个负号即可。

代码
#include<iostream>
#define int long long
using namespace std;
const int N = 1.5e5 + 10;
const int MOD = 998244353;
const int G = 3;
const int iG = 332748118;

int n;
char str[N];
int s[N], f[N];

inline int qpow(int a, int b) {
    int res = 1;
    while(b) {
        if(b & 1) res = res * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return res;
}

void ntt(int a[], int n, int g) {
    static int inv[N];
    for(int i = 0; i < n; i++) inv[i] = (inv[i >> 1] >> 1) | ((n >> 1) * (i & 1));
    for(int i = 0; i < n; i++) if(i < inv[i]) swap(a[i], a[inv[i]]);
    for(int k = 1; k < n; k <<= 1) {
        for(int i = 0; i < n; i += 2 * k) {
            int w0 = qpow(g, (MOD - 1) / (2 * k)), w = 1;
            for(int j = 0; j < k; j++) {
                int x = a[i + j], y = a[i + j + k] * w % MOD;
                a[i + j] = (x + y) % MOD;
                a[i + j + k] = (x + MOD - y) % MOD;
                w = w * w0 % MOD;
            }
        }
    }
}

int inv[N], fact[N], ifact[N];
int a[N], b[N], c[N];

void solve(int l, int r) {
    if(l + 1 == r) return;
    int mid = (l + r) >> 1, len = r - l;
    solve(l, mid);
    for(int i = mid - 1, j = 1; i >= l; i--) {
        j ^= s[i];
        a[i - l] = j ? (MOD - f[i]) % MOD : f[i];
    }
    for(int i = mid; i < r; i++) a[i - l] = 0;
    for(int i = 0; i < len; i++) b[i] = ifact[i];
    ntt(a, len, G);
    ntt(b, len, G);
    for(int i = 0; i < len; i++) c[i] = a[i] * b[i] % MOD;
    ntt(c, len, iG);
    for(int i = 0, in = qpow(len, MOD - 2); i < len; i++) c[i] = c[i] * in % MOD;
    for(int i = mid, j = 0; i < r; i++) {
        if(s[i]) f[i] = (f[i] + (j ? MOD - c[i - l] : c[i - l])) % MOD;
        j ^= s[i];
    }
    solve(mid, r);
}

signed main() {

    inv[1] = 1;
    for(int i = 2; i < N; i++) inv[i] = MOD - inv[MOD % i] * (MOD / i) % MOD;
    fact[0] = 1;
    for(int i = 1; i < N; i++) fact[i] = fact[i - 1] * i % MOD;
    ifact[0] = 1;
    for(int i = 1; i < N; i++) ifact[i] = ifact[i - 1] * inv[i] % MOD;

    cin >> (str + 1);
    do ++n; while(str[n]);
    str[0] = str[n] = '>';

    for(int i = 0; i <= n; i++) s[i] = str[i] == '>';

    f[0] = fact[n];
    solve(0, 131072);

    cout << f[n] << '\n';

    return 0;
}