NTT
对于一个长为 \(n=2^k\) 的序列 \(\{a_{0\sim n-1}\}\),考虑多项式
\[
F(x)=\sum_{i=0}^{n-1}a_ix^i
\]
定义 \(\{b_{0\sim n-1}\}=\operatorname{NTT}(\{a_{0\sim n-1}\})\) 满足
\[
b_i=F(w_n^i)
\]
其中 \(w_{n}\) 是 \(n\) 次单位根。NTT 本质上就是将若干点值代入了以 \(a\) 为系数的多项式。那么根据多项式乘法的定义,考虑 \(a,b,c\) 三个序列,其中 \(c_i=\sum_{j=0}^{i}a_jb_{i-j}\),那么 \(F_c(x)=F_a(x)\times F_b(x)\),因此 \(\operatorname{NTT}(c)_i=\operatorname{NTT}(a)_i\times \operatorname{NTT}(b)_i\)。如果我们能在 \(O(n\log n)\) 的时间复杂度内实现 \(\operatorname{NTT}\) 和逆 \(\operatorname{NTT}\),那么只需要做一次点乘就可以实现卷积了。
由于我不想再花时间写一遍证明了,我甚至都不想把式子再写一遍,所以直接贴个模板在下面。
代码
| inline void ntt(int a[], int n, int g) {
static int swp[N2];
for(int i = 1; i < n; i++) swp[i] = (swp[i >> 1] >> 1) | (n >> 1) * (i & 1);
for(int i = 0; i < n; i++) if(i < swp[i]) swap(a[i], a[swp[i]]);
for(int k = 1; k < n; k <<= 1)
for(int i = 0, w0 = qpow(g, (MOD - 1) / (k << 1)); i < n; i += (k << 1))
for(int j = 0, w = 1; j < k; j++, w = (ull)w * w0 % MOD) {
int x = a[i + j], y = (ull)a[i + j + k] * w % MOD;
a[i + j] = mod(x + y), a[i + j + k] = mod(x - y + MOD);
}
if(g == 332748118) for(int i = 0, inv_n = qpow(n, MOD - 2); i < n; i++) a[i] = (ull)a[i] * inv_n % MOD;
}
|
其中,如果要进行逆 ntt,则传入原根 \(g=332748118\) 即可。
分治 NTT / 半在线卷积
考虑递推式:
\[
f(i)=\sum_{j=1}^{i}f(i-j)g(j)
\]
朴素的递推是 \(O(n^2)\) 的。注意到其形式和卷积非常类似,但是其中的 \(f\) 是在线的,因此无法直接 NTT。
考虑分治,设当前区间为 \([l,r)\),计算 \(mid=\frac{l+r}{2}\),然后先递归处理 \([l,mid)\),然后一遍 NTT 处理 \([l,mid)\to [mid,r)\) 的贡献,然后再递归处理 \([mid,r)\) 内部的贡献。这样做我们可以保证每次计算贡献的时候,左侧的 \(f\) 都是最终结果。不难分析出这样做的时间复杂度为 \(O(n\log^2 n)\)。
模板
| #include<iostream>
#define int long long
using namespace std;
const int N = 1.5e5;
const int MOD = 998244353;
const int G = 3;
const int iG = 332748118;
int n = 131072;
int inv[N];
int f[N], g[N];
inline int qpow(int a, int b) {
int res = 1;
while(b) {
if(b & 1) res = res * a % MOD;
a = a * a % MOD;
b >>= 1;
}
return res;
}
void ntt(int a[], int n, int g) {
for(int i = 0; i < n; i++) inv[i] = (inv[i >> 1] >> 1) | ((n >> 1) * (i & 1));
for(int i = 0; i < n; i++) if(i < inv[i]) swap(a[i], a[inv[i]]);
for(int k = 1; k < n; k <<= 1) {
for(int i = 0; i < n; i += 2 * k) {
int w0 = qpow(g, (MOD - 1) / (2 * k)), w = 1;
for(int j = 0; j < k; j++) {
int x = a[i + j], y = a[i + j + k] * w % MOD;
a[i + j] = (x + y) % MOD;
a[i + j + k] = (x + MOD - y) % MOD;
w = w * w0 % MOD;
}
}
}
}
int a[N], b[N], c[N];
void solve(int l, int r) {
if(l + 1 == r) return;
int mid = (l + r) >> 1, len = r - l;
solve(l, mid);
for(int i = l; i < mid; i++) a[i - l] = f[i];
for(int i = mid; i < r; i++) a[i - l] = 0;
for(int i = 0; i < len; i++) b[i] = g[i];
ntt(a, len, G);
ntt(b, len, G);
for(int i = 0; i < len; i++) c[i] = a[i] * b[i] % MOD;
ntt(c, len, iG);
for(int i = mid, t = qpow(len, MOD - 2); i < r; i++) f[i] = (f[i] + c[i - l] * t % MOD) % MOD;
solve(mid, r);
}
signed main() {
int m;
cin >> m;
for(int i = 1; i <= m - 1; i++) cin >> g[i];
f[0] = 1;
solve(0, n);
for(int i = 0; i <= m - 1; i++) cout << f[i] << ' ';
cout << '\n';
return 0;
}
|
给定一个长为 \(n-1\) 的字符串 \(s\),只包含 '>' 和 '<' 两种字符。称一个长为 \(n\) 的排列 \(p\) 是合法的,当且仅当 \(\forall i\in [1,n-1],s_i=\text{'<'},~p_i<p_{i+1}\),否则 \(p_i>p_{i+1}\)。问合法排列数。
\(n\le 10^5\)
有一种显然的 \(O(n^2)\) dp,但是状态数就高达 \(n^2\),因此难以继续优化。
注意到 \(s\) 可以看成 \(n-1\) 个较为独立的限制,因此考虑容斥。由于两种类型的限制难以处理,因此考虑将一种限制容斥为另一种,这里将大于号容斥成小于号。先钦定所有小于号成立,然后设第 \(i\) 个限制表示不满足第 \(i\) 个大于号(也就是小于)。
我们每次钦定若干个大于号的位置填小于号,然后乘上对应的容斥系数即可。现在问题转化为:只有小于号和空位。
注意到一个小于号组成的连续段内,数字只有一种排列顺序。因此总方案数可以写成
\[
\frac{n!}{\prod_{i=1}len_i!}
\]
其中 \(len_i\) 表示第 \(i\) 个连续段的长度。由此我们也可以设计出一种连续段 dp:设 \(f_i\) 表示前缀 \(i\) 的答案。有转移:
\[
f_i=\sum_{1\le j<i}\frac{1}{i-j}(-1)^{cnt(j,i)}f_j
\]
考虑优化,注意到式子形如一种半在线卷积,考虑分治 NTT。对于容斥系数,在本层分治时将其贡献拆成中点左侧和右侧,卷积之前先把左半区间中后缀大于号数量为奇数的加一个负号;卷积之后,再把右半区间结果序列中前缀大于号数量为奇数的加一个负号即可。
代码
| #include<iostream>
#define int long long
using namespace std;
const int N = 1.5e5 + 10;
const int MOD = 998244353;
const int G = 3;
const int iG = 332748118;
int n;
char str[N];
int s[N], f[N];
inline int qpow(int a, int b) {
int res = 1;
while(b) {
if(b & 1) res = res * a % MOD;
a = a * a % MOD;
b >>= 1;
}
return res;
}
void ntt(int a[], int n, int g) {
static int inv[N];
for(int i = 0; i < n; i++) inv[i] = (inv[i >> 1] >> 1) | ((n >> 1) * (i & 1));
for(int i = 0; i < n; i++) if(i < inv[i]) swap(a[i], a[inv[i]]);
for(int k = 1; k < n; k <<= 1) {
for(int i = 0; i < n; i += 2 * k) {
int w0 = qpow(g, (MOD - 1) / (2 * k)), w = 1;
for(int j = 0; j < k; j++) {
int x = a[i + j], y = a[i + j + k] * w % MOD;
a[i + j] = (x + y) % MOD;
a[i + j + k] = (x + MOD - y) % MOD;
w = w * w0 % MOD;
}
}
}
}
int inv[N], fact[N], ifact[N];
int a[N], b[N], c[N];
void solve(int l, int r) {
if(l + 1 == r) return;
int mid = (l + r) >> 1, len = r - l;
solve(l, mid);
for(int i = mid - 1, j = 1; i >= l; i--) {
j ^= s[i];
a[i - l] = j ? (MOD - f[i]) % MOD : f[i];
}
for(int i = mid; i < r; i++) a[i - l] = 0;
for(int i = 0; i < len; i++) b[i] = ifact[i];
ntt(a, len, G);
ntt(b, len, G);
for(int i = 0; i < len; i++) c[i] = a[i] * b[i] % MOD;
ntt(c, len, iG);
for(int i = 0, in = qpow(len, MOD - 2); i < len; i++) c[i] = c[i] * in % MOD;
for(int i = mid, j = 0; i < r; i++) {
if(s[i]) f[i] = (f[i] + (j ? MOD - c[i - l] : c[i - l])) % MOD;
j ^= s[i];
}
solve(mid, r);
}
signed main() {
inv[1] = 1;
for(int i = 2; i < N; i++) inv[i] = MOD - inv[MOD % i] * (MOD / i) % MOD;
fact[0] = 1;
for(int i = 1; i < N; i++) fact[i] = fact[i - 1] * i % MOD;
ifact[0] = 1;
for(int i = 1; i < N; i++) ifact[i] = ifact[i - 1] * inv[i] % MOD;
cin >> (str + 1);
do ++n; while(str[n]);
str[0] = str[n] = '>';
for(int i = 0; i <= n; i++) s[i] = str[i] == '>';
f[0] = fact[n];
solve(0, 131072);
cout << f[n] << '\n';
return 0;
}
|