NTT
对于一个长为 \(n=2^k\) 的序列 \(a_{0\sim n-1}\),考虑多项式函数 \(F(x)=\sum_{i=0}^{n-1}a_ix^i\),那么序列 \(a\) 的 NTT 也是一个长为 \(n\) 的序列,它的第 \(i\) 项是 \(F(w_{n}^{i})\)。其中 \(w_{n}^{i}=(w_{n})^i\),\(w_n=g^{(p-1)/n}\)。
可以证明,两个序列 \(a,b\) 和它们的卷积 \(a*b\) 满足 \(NTT(a)\cdot NTT(B)=NTT(a*b)\)。其中 '\(\cdot\)' 表示对应位置相乘。
同样,定义序列 \(a\) 的逆 \(NTT\) 也是一个长为 \(n\) 的序列,它的第 \(i\) 项是 \(\frac{1}{n}F(w_{n}^{-i})\)(对,就是把后 \(n-1\) 项翻转一下然后除以归一化系数,但是一般我们使用传入逆元的写法)。
代码
| void ntt(int a[], int g) {
for(int i = 0; i < n; i++) inv[i] = (inv[i >> 1] >> 1) | ((1 << 20) * (i & 1));
for(int i = 0; i < n; i++) if(i < inv[i]) swap(a[i], a[inv[i]]);
for(int k = 1; k < n; k <<= 1) {
for(int i = 0; i < n; i += 2 * k) {
int w0 = qpow(g, (MOD - 1) / (2 * k)), w = 1;
for(int j = 0; j < k; j++) {
int x = a[i + j], y = a[i + j + k] * w % MOD;
a[i + j] = (x + y) % MOD;
a[i + j + k] = (x + MOD - y) % MOD;
w = w * w0 % MOD;
}
}
}
}
|
逆 ntt 只需传入原根 \(3\) 的逆元 \(998244354/3\) 即可实现。注意归一化系数 \(\frac 1n\)。
NTT 加速递推
考虑一类递推式:
\[
f(i)=\sum_{j=1}^{i}f(i-j)g(j)
\]
朴素的递推是 \(O(n^2)\) 的。注意到其形式和卷积非常类似,但是其中的 \(f\) 是在线的,因此无法直接 NTT。
考虑分治,设当前区间为 \([l,r)\),计算 \(mid=\frac{l+r}{2}\),然后先递归处理 \([l,mid)\),然后一遍 NTT 处理 \([l,mid)\to [mid,r)\) 的贡献,然后再递归处理 \([mid,r)\) 内部的贡献。这样做我们可以保证每次计算贡献的时候,左侧的 \(f\) 都是最终结果。不难分析出这样做的时间复杂度为 \(O(n\log^2 n)\)。
模板
| #include<iostream>
#define int long long
using namespace std;
const int N = 1.5e5;
const int MOD = 998244353;
const int G = 3;
const int iG = 332748118;
int n = 131072;
int inv[N];
int f[N], g[N];
inline int qpow(int a, int b) {
int res = 1;
while(b) {
if(b & 1) res = res * a % MOD;
a = a * a % MOD;
b >>= 1;
}
return res;
}
void ntt(int a[], int n, int g) {
for(int i = 0; i < n; i++) inv[i] = (inv[i >> 1] >> 1) | ((n >> 1) * (i & 1));
for(int i = 0; i < n; i++) if(i < inv[i]) swap(a[i], a[inv[i]]);
for(int k = 1; k < n; k <<= 1) {
for(int i = 0; i < n; i += 2 * k) {
int w0 = qpow(g, (MOD - 1) / (2 * k)), w = 1;
for(int j = 0; j < k; j++) {
int x = a[i + j], y = a[i + j + k] * w % MOD;
a[i + j] = (x + y) % MOD;
a[i + j + k] = (x + MOD - y) % MOD;
w = w * w0 % MOD;
}
}
}
}
int a[N], b[N], c[N];
void solve(int l, int r) {
if(l + 1 == r) return;
int mid = (l + r) >> 1, len = r - l;
solve(l, mid);
for(int i = l; i < mid; i++) a[i - l] = f[i];
for(int i = mid; i < r; i++) a[i - l] = 0;
for(int i = 0; i < len; i++) b[i] = g[i];
ntt(a, len, G);
ntt(b, len, G);
for(int i = 0; i < len; i++) c[i] = a[i] * b[i] % MOD;
ntt(c, len, iG);
for(int i = mid, t = qpow(len, MOD - 2); i < r; i++) f[i] = (f[i] + c[i - l] * t % MOD) % MOD;
solve(mid, r);
}
signed main() {
int m;
cin >> m;
for(int i = 1; i <= m - 1; i++) cin >> g[i];
f[0] = 1;
solve(0, n);
for(int i = 0; i <= m - 1; i++) cout << f[i] << ' ';
cout << '\n';
return 0;
}
|
给定一个长为 \(n-1\) 的字符串 \(s\),只包含 '>' 和 '<' 两种字符。称一个长为 \(n\) 的排列 \(p\) 是合法的,当且仅当 \(\forall i\in [1,n-1],s_i=\text{'<'},~p_i<p_{i+1}\),否则 \(p_i>p_{i+1}\)。问合法排列数。
\(n\le 10^5\)
有一种显然的 \(O(n^2)\) dp,但是状态数就高达 \(n^2\),因此难以继续优化。
注意到 \(s\) 可以看成 \(n-1\) 个较为独立的限制,因此考虑容斥。由于两种类型的限制难以处理,因此考虑将一种限制容斥为另一种,这里将大于号容斥成小于号。先钦定所有小于号成立,然后设第 \(i\) 个限制表示不满足第 \(i\) 个大于号(也就是小于)。
我们每次钦定若干个大于号的位置填小于号,然后乘上对应的容斥系数即可。现在问题转化为:只有小于号和空位。
注意到一个小于号组成的连续段内,数字只有一种排列顺序。因此总方案数可以写成
\[
\frac{n!}{\prod_{i=1}len_i!}
\]
其中 \(len_i\) 表示第 \(i\) 个连续段的长度。由此我们也可以设计出一种连续段 dp:设 \(f_i\) 表示前缀 \(i\) 的答案。有转移:
\[
f_i=\sum_{1\le j<i}\frac{1}{i-j}(-1)^{cnt(j,i)}f_j
\]
考虑优化,注意到式子形如一种卷积的形式,考虑分治 ntt。对于容斥系数,在本层分治时将其贡献拆成中点左侧和右侧,卷积之前先把左半区间中后缀大于号数量为奇数的加一个负号;卷积之后,再把右半区间结果序列中前缀大于号数量为奇数的加一个负号即可。
代码
| #include<iostream>
#define int long long
using namespace std;
const int N = 1.5e5 + 10;
const int MOD = 998244353;
const int G = 3;
const int iG = 332748118;
int n;
char str[N];
int s[N], f[N];
inline int qpow(int a, int b) {
int res = 1;
while(b) {
if(b & 1) res = res * a % MOD;
a = a * a % MOD;
b >>= 1;
}
return res;
}
void ntt(int a[], int n, int g) {
static int inv[N];
for(int i = 0; i < n; i++) inv[i] = (inv[i >> 1] >> 1) | ((n >> 1) * (i & 1));
for(int i = 0; i < n; i++) if(i < inv[i]) swap(a[i], a[inv[i]]);
for(int k = 1; k < n; k <<= 1) {
for(int i = 0; i < n; i += 2 * k) {
int w0 = qpow(g, (MOD - 1) / (2 * k)), w = 1;
for(int j = 0; j < k; j++) {
int x = a[i + j], y = a[i + j + k] * w % MOD;
a[i + j] = (x + y) % MOD;
a[i + j + k] = (x + MOD - y) % MOD;
w = w * w0 % MOD;
}
}
}
}
int inv[N], fact[N], ifact[N];
int a[N], b[N], c[N];
void solve(int l, int r) {
if(l + 1 == r) return;
int mid = (l + r) >> 1, len = r - l;
solve(l, mid);
for(int i = mid - 1, j = 1; i >= l; i--) {
j ^= s[i];
a[i - l] = j ? (MOD - f[i]) % MOD : f[i];
}
for(int i = mid; i < r; i++) a[i - l] = 0;
for(int i = 0; i < len; i++) b[i] = ifact[i];
ntt(a, len, G);
ntt(b, len, G);
for(int i = 0; i < len; i++) c[i] = a[i] * b[i] % MOD;
ntt(c, len, iG);
for(int i = 0, in = qpow(len, MOD - 2); i < len; i++) c[i] = c[i] * in % MOD;
for(int i = mid, j = 0; i < r; i++) {
if(s[i]) f[i] = (f[i] + (j ? MOD - c[i - l] : c[i - l])) % MOD;
j ^= s[i];
}
solve(mid, r);
}
signed main() {
inv[1] = 1;
for(int i = 2; i < N; i++) inv[i] = MOD - inv[MOD % i] * (MOD / i) % MOD;
fact[0] = 1;
for(int i = 1; i < N; i++) fact[i] = fact[i - 1] * i % MOD;
ifact[0] = 1;
for(int i = 1; i < N; i++) ifact[i] = ifact[i - 1] * inv[i] % MOD;
cin >> (str + 1);
do ++n; while(str[n]);
str[0] = str[n] = '>';
for(int i = 0; i <= n; i++) s[i] = str[i] == '>';
f[0] = fact[n];
solve(0, 131072);
cout << f[n] << '\n';
return 0;
}
|