FWT(快速沃尔什变换)
FWT 可以在 \(O(p2^p)\) 或 \(O(n\log n)\) 的时间内计算位运算卷积:
\[
c_i=\sum_{j\oplus k=i}a_{j}b_{k}
\]
其中 \(\oplus\) 表示一种二进制位运算,通常是按位与、按位或、按位异或。
FWT
FWT 是作用于序列的一种线性变换,即
\[
\operatorname{FWT}(A)_i=\sum_{j=0}^{2^p-1}A_jc_{i,j}\tag{1}
\]
因此
\[
\begin{align*}
&\operatorname{FWT}(kA)=k\cdot \operatorname{FWT}(A)\\
&\operatorname{FWT}(A+B)=\operatorname{FWT}(A)+\operatorname{FWT}(B)
\end{align*}
\]
且对于序列 \(C=A*_{\oplus}B\),满足
\[
\operatorname{FWT}_{\oplus}(A)\cdot\operatorname{FWT}_{\oplus}(B)=\operatorname{FWT}_{\oplus}(C)\tag{2}
\]
根据 \((2)\),我们能推导出 \(c\) 合法的一个充要条件:
\[
c_{i,j}c_{i,k}=c_{i,j\oplus k}\tag{3}
\]
推导过程
由 \((2)\):
\[
\begin{align*}
\big(\sum_{j=0}^{2^p-1}a_jc_{i,j}\big)\big(\sum_{j=0}^{2^p-1}b_jc_{i,j}\big)=&\ \sum_{j=0}^{2^p-1}(\sum_{j_1=0}^{2^p-1}a_{j_1}b_{j\oplus j_1})c_{i,j}\\
\sum_{j_1=0}^{2^p-1}\sum_{j_2=0}^{2^p-1}a_{j_1}b_{j_2}c_{i,j_1}c_{i,j_2}=&\ \sum_{j_1=0}^{2^p-1}\sum_{j_2=0}^{2^p-1}a_{j_1}b_{j_2}c_{i,j_1\oplus j_2}
\end{align*}
\]
恒成立,因此 \(c_{i,j}c_{i,k}=c_{i,j\oplus k}\)。
为了便于计算,我们写出 \(c\) 成立的一个充分不必要条件:
\[
\begin{cases}
c_{i,j}=\prod_{k=0}^{p-1} c_{[2^k]i,[2^k]j} \tag{4}\\
c_{i,j}c_{i,k}=c_{i,j\oplus k}\ (i,j,k\in \{0,1\})
\end{cases}
\]
其中 \([2^k]j\) 表示 \(j\) 二进制表示下的第 \(k\) 位,取值为 \(\{0,1\}\)。
推导过程
\[
\begin{align*}
c_{i,j_1}c_{i,j_2}=&\ (\prod_{k=0}^{p-1}c_{[2^k]i,[2^k]j_1})(\prod_{k=0}^{p-1}{c_{[2^k]i,[2^k]j_2}})\\
=&\ \prod_{k=0}^{p-1}c_{[2^k]i,[2^k]j_1}c_{[2^k]i,[2^k]j_2}\\
=&\ \prod_{k=0}^{p-1}c_{[2^k]i,[2^k]j_1\oplus j_2}\\
=&\ c_{i,j_1\oplus j_2}
\end{align*}
\]
这样,我们只需要对运算 \(\oplus\) 找到符合条件的 \(2\times 2\) 矩阵:
\[
\begin{bmatrix}
c_{0,0}& c_{0,1}\\
c_{1,0}& c_{1,1}\\
\end{bmatrix}
\]
然后就能扩展到整个 \(c\) 矩阵。
现在考虑如何快速计算 \(\operatorname{FWT}(A)\)。由于 \(c_{i,j}\) 可以拆位,我们分别考虑每一位。假设当前考虑了二进制位的集合 \(S\),按下标的二进制去掉 \(S\) 的位之后,序列可以被分为若干类,我们现在只考虑同一类以内的贡献。同时,\(c\) 在 \((4)\) 式中的连乘形式也只会考虑 \(S\) 中的位。
尝试向 \(S\) 中加入一个新的元素 \(k\)(新考虑了第 \(k\) 位),这个过程会将 \(2^{p-|S|}\) 类合并成 \(2^{p-|S|-1}\) 类。合并前,序列的第 \(i\) 项为:
\[
A_i=\sum_{[\overline S]j=[\overline S]i}{c_{i,j}a_j}
\]
考虑第 \(k\) 位,它会使 \(S'\leftarrow S+\{k\}\),\(c'_{i,j}\leftarrow c_{i,j}\times c_{[2^k]i,[2^k]j}\),序列变为:
\[
\begin{align*}
A'_i=\sum_{[\overline S']j=[\overline S']i}{c'_{i,j}a_j}&=\sum_{[\overline S']j=[\overline S']i}{c_{i,j}a_j\times c_{[2^k]i,[2^k]j}}\\
&=\sum_{[\overline S']j=[\overline S']i}{c_{i,j}a_j\times c_{[2^k]i,[2^k]j}\times \big[[2^k]j=0\big]}\\&+\sum_{[\overline S']j=[\overline S']i}{c_{i,j}a_j\times c_{[2^k]i,[2^k]j}\times \big[[2^k]j=1\big]}\\
&=
\begin{cases}
c_{0,0}A_{i_0}+c_{0,1}A_{i_1},\ [2^k]i=0\\
c_{1,0}A_{i_0}+c_{1,1}A_{i_1},\ [2^k]i=1
\end{cases}
\end{align*}
\]
其中 \(i_0\) 表示将 \(i\) 的第 \(k\) 位赋值为 \(0\),\(i_1\) 表示将 \(i\) 的第 \(k\) 位赋值位 \(1\)。上式正是我们熟见的形式,我们上面的过程也阐释了为什么 FWT 外层的循环可以枚举任意排列。写成代码如下:
// 异或
inline void fwt () {
for(int k = 1; k < n; k <<= 1) {
int l = k << 1;
for(int i = 0; i < n; i += l) {
for(int j = 0; j < k; j++) {
int x = a[i + j], y = a[i + j + k];
a[i + j] = (x + y) % MOD;
a[i + j + k] = (x - y + MOD) % MOD;
}
}
}
}
iFWT
iFWT 是 FWT 的逆变换;iFWT 能正常进行的充要条件是 \(c_{0\sim 1,0\sim 1}\) 满秩。iFWT 仍然是线性变换,即
\[
\operatorname{iFWT}(A)_i=\sum_{j=0}^{2^p-1}A_jd_{i,j}
\]
且满足
\[
\operatorname{iFWT}(\operatorname{FWT(A)})=A
\]
考察 FWT 的每一个小步都是用 \(c_{0\sim 1,0\sim 1}\) 去乘以 \([A_{i_0}\ \ A_{i_1}]^T\),因此我们令 \(d_{0\sim 1,0\sim 1}\) 为 \(c_{0\sim 1,0\sim 1}\) 的逆矩阵,然后用 \(d_{0\sim 1,0\sim 1}\) 去乘以变换之后的 \([A_{i_0}\ \ A_{i_1}]^T\),即可得到变换之前的值。我们倒序进行每一步变换,一定可以实现 FWT 的逆变换。
同时我们发现,由于 FWT 的顺序是任意的,因此对应 iFWT 的顺序也是任意的;同时由于 FWT 的顺序不影响结果,因此 iFWT 的顺序也不影响结果。
写成代码如下:
// 异或
for(int k = 1; k < n; k <<= 1) {
int l = k << 1;
for(int i = 0; i < n; i += l) {
for(int j = 0; j < k; j++) {
int x = a[i + j] * inv[2] % MOD, y = a[i + j + k] * inv[2] % MOD;
a[i + j] = (x + y) % MOD;
a[i + j + k] = (x - y + MOD) % MOD;
}
}
}
更多位运算
我们只考虑具有交换律、不总返回 \(1\) 且不总返回 \(0\) 的位运算;此时真值表还剩 \(6\) 种位运算,分别是与、或、异或以及与非、或非、同或。因此我们只需要实现基本的与、或、异或就可以涵盖所有值得讨论的运算。
对于三种运算,我们分别写出满足 \((3)\) 且满秩的 \(2\times 2\) 的 \(c\) 矩阵:
与
\[
C=
\begin{bmatrix}
1& 1\\
0& 1\\
\end{bmatrix}
\]
\[
D=
\begin{bmatrix}
1& -1\\
0& 1
\end{bmatrix}
\]
或
\[
C=
\begin{bmatrix}
1& 0\\
1& 1\\
\end{bmatrix}
\]
\[
D=
\begin{bmatrix}
1& 0\\
-1& 1
\end{bmatrix}
\]
异或
\[
C=
\begin{bmatrix}
1& 1\\
1& -1\\
\end{bmatrix}
\]
\[
D=
\begin{bmatrix}
1/2& 1/2\\
1/2& -1/2
\end{bmatrix}
\]
模板代码
| #include<iostream>
#include<vector>
#define int long long
using namespace std;
const int N = 17;
const int MOD = 998244353;
int op[2][2];
void FWT(int n, vector<int>& a) {
for(int i = n; i >= 2; i >>= 1) {
int o = i / 2;
for(int j = 0; j < n; j += i) {
for(int k = 0; k < o; k++) {
int x = a[j + k], y = a[j + k + o];
a[j + k] = (op[0][0] * x % MOD + op[0][1] * y % MOD) % MOD;
a[j + k + o] = (op[1][0] * x % MOD + op[1][1] * y % MOD) % MOD;
}
}
}
}
int n;
vector<int> a, b, a1, b1, c;
signed main() {
cin >> n;
n = 1 << n;
a.resize(n);
b.resize(n);
c.resize(n);
for(int i = 0; i < n; i++) cin >> a[i];
for(int i = 0; i < n; i++) cin >> b[i];
op[0][0] = 1; op[0][1] = 0;
op[1][0] = 1; op[1][1] = 1;
a1 = a, b1 = b;
FWT(n, a1);
FWT(n, b1);
for(int i = 0; i < n; i++) c[i] = a1[i] * b1[i] % MOD;
op[0][0] = 1; op[0][1] = 0;
op[1][0] = MOD - 1; op[1][1] = 1;
FWT(n, c);
for(int i = 0; i < n; i++) cout << c[i] << ' ';
cout << '\n';
op[0][0] = 1; op[0][1] = 1;
op[1][0] = 0; op[1][1] = 1;
a1 = a, b1 = b;
FWT(n, a1);
FWT(n, b1);
for(int i = 0; i < n; i++) c[i] = a1[i] * b1[i] % MOD;
op[0][0] = 1; op[0][1] = MOD - 1;
op[1][0] = 0; op[1][1] = 1;
FWT(n, c);
for(int i = 0; i < n; i++) cout << c[i] << ' ';
cout << '\n';
int half = 499122177;
op[0][0] = 1; op[0][1] = 1;
op[1][0] = 1; op[1][1] = MOD - 1;
a1 = a, b1 = b;
FWT(n, a1);
FWT(n, b1);
for(int i = 0; i < n; i++) c[i] = a1[i] * b1[i] % MOD;
op[0][0] = half; op[0][1] = half;
op[1][0] = half; op[1][1] = MOD - half;
FWT(n, c);
for(int i = 0; i < n; i++) cout << c[i] << ' ';
cout << '\n';
return 0;
}
|
本题考察 FWT 的线性性。
代码
| #include<iostream>
#include<vector>
#include<cstring>
#include<cassert>
#define ll signed
#define int short
using namespace std;
const int N = 3e4 + 10;
const int MOD = 1e4 + 7;
struct Edge {
int v;
ll next;
} pool[2 * N];
ll ne, head[N];
void addEdge (int u, int v) {
pool[++ne] = {v, head[u]};
head[u] = ne;
}
int inv[MOD];
struct Vec {
int a[128];
inline Vec () { memset(a, 0, sizeof(a)); }
inline void clear () { memset(a, 0, sizeof(a)); }
inline void fill (int x) { for(int i = 0; i < 128; i++) a[i] = x; }
inline int &operator[] (int index) { return a[index]; }
inline const int &operator[] (int index) const { return a[index]; }
inline void fwt () {
for(int k = 1; k < 128; k <<= 1) {
int l = k << 1;
for(int i = 0; i < 128; i += l) {
for(int j = 0; j < k; j++) {
int x = a[i + j], y = a[i + j + k];
a[i + j] = (x + y) % MOD;
a[i + j + k] = (x - y + MOD) % MOD;
}
}
}
}
inline void ifwt () {
for(int k = 1; k < 128; k <<= 1) {
int l = k << 1;
for(int i = 0; i < 128; i += l) {
for(int j = 0; j < k; j++) {
int x = a[i + j] * inv[2] % MOD, y = a[i + j + k] * inv[2] % MOD;
a[i + j] = (x + y) % MOD;
a[i + j + k] = (x - y + MOD) % MOD;
}
}
}
}
};
inline void mul (const Vec &a, const Vec &b, Vec &res) {
for(int i = 0; i < 128; i++) res[i] = a[i] * b[i] % MOD;
}
inline void mul_add (const Vec &a, const Vec &b, Vec &res) {
for(int i = 0; i < 128; i++) res[i] = (res[i] + a[i] * b[i] % MOD) % MOD;
}
// 矩阵套向量
struct Matrix {
Vec a[3][3];
inline Vec* operator[] (int index) { return a[index]; }
inline const Vec* operator[] (int index) const { return a[index]; }
inline void clear () {
a[0][0].clear(); a[0][1].clear(); a[0][2].clear();
a[1][0].clear(); a[1][1].clear(); a[1][2].clear();
a[2][0].clear(); a[2][1].clear(); a[2][2].clear();
}
};
inline void mul (const Matrix &a, const Matrix &b, Matrix &res) {
mul(a[0][0], b[0][0], res[0][0]);
mul(a[0][0], b[0][2], res[0][2]);
mul_add(a[0][2], b[2][2], res[0][2]);
mul(a[1][0], b[0][0], res[1][0]);
mul_add(a[1][1], b[1][0], res[1][0]);
mul(a[1][1], b[1][1], res[1][1]);
mul(a[1][0], b[0][2], res[1][2]);
mul_add(a[1][1], b[1][2], res[1][2]);
mul_add(a[1][2], b[2][2], res[1][2]);
res[2][2].fill(1);
}
// 向量套向量
struct Node {
Vec a[3];
inline Node() { a[2].fill(1); }
inline Vec &operator[] (int index) { return a[index]; }
inline const Vec &operator[] (int index) const { return a[index]; }
};
inline void mul (const Matrix &a, Node b, Node &res) {
mul(a[0][0], b[0], res[0]);
mul_add(a[0][2], b[2], res[0]);
mul(a[1][0], b[0], res[1]);
mul_add(a[1][1], b[1], res[1]);
mul_add(a[1][2], b[2], res[1]);
res[2] = b[2];
}
int n, q, V;
int w[N];
int sz[N], son[N], fa[N];
int dfn[N], id[N], dt;
int top[N], bot[N];
Vec f[N], g[N];
Vec a[128];
Matrix m[N];
Vec zeroCnt[N];
Vec nozero[N];
void dfs1 (int u) {
sz[u] = 1;
for(ll i = head[u]; i; i = pool[i].next) {
int v = pool[i].v;
if(sz[v]) continue;
fa[v] = u;
dfs1(v);
sz[u] += sz[v];
if(sz[v] > sz[son[u]]) son[u] = v;
}
}
void dfs2 (int u, int tp) {
dfn[u] = ++dt;
id[dt] = u;
top[u] = tp;
if(son[u]) {
dfs2(son[u], tp);
bot[u] = bot[son[u]];
} else bot[u] = u;
for(ll i = head[u]; i; i = pool[i].next) {
int v = pool[i].v;
if(dfn[v]) continue;
dfs2(v, v);
}
}
void dfs3 (int u) {
nozero[u] = f[u] = a[0];
if(son[u]) dfs3(son[u]);
for(ll e = head[u]; e; e = pool[e].next) {
int v = pool[e].v;
if(v == fa[u] || v == son[u]) continue;
dfs3(v);
for(int i = 0; i < 128; i++) f[u][i] = f[u][i] * (f[v][i] + 1) % MOD;
for(int i = 0; i < 128; i++) g[u][i] = (g[u][i] + g[v][i]) % MOD;
for(int i = 0; i < 128; i++) ((f[v][i] + 1) % MOD) ? (nozero[u][i] = nozero[u][i] * (f[v][i] + 1) % MOD) : ++zeroCnt[u][i];
}
Matrix &t = m[dfn[u]];
for(int i = 0; i < 128; i++) f[u][i] = f[u][i] * a[w[u]][i] % MOD;
for(int i = 0; i < 128; i++) a[w[u]][i] ? (nozero[u][i] = nozero[u][i] * a[w[u]][i] % MOD) : ++zeroCnt[u][i];
for(int i = 0; i < 128; i++) {
if(zeroCnt[u][i]) assert(f[u][i] == 0), assert(nozero[u][i]);
}
t[0][0] = t[0][2] = t[1][0] = t[1][2] = f[u];
for(int i = 0; i < 128; i++) t[1][1][i] = 1;
for(int i = 0; i < 128; i++) t[2][2][i] = 1;
for(int i = 0; i < 128; i++) t[1][2][i] = (t[1][2][i] + g[u][i]) % MOD;
if(son[u]) {
for(int i = 0; i < 128; i++) f[u][i] = f[u][i] * (f[son[u]][i] + 1) % MOD;
for(int i = 0; i < 128; i++) g[u][i] = (g[u][i] + g[son[u]][i]) % MOD;
}
for(int i = 0; i < 128; i++) g[u][i] = (g[u][i] + f[u][i]) % MOD;
}
namespace SegT {
Matrix tr[4 * N];
inline ll lc (ll x) { return x << 1; }
inline ll rc (ll x) { return x << 1 | 1; }
inline void push_up (ll p) { mul(tr[lc(p)], tr[rc(p)], tr[p]); }
void build (ll p, int l, int r) {
if(l == r) {
tr[p] = m[l];
return;
}
int mid = (l + r) >> 1;
build(lc(p), l, mid);
build(rc(p), mid + 1, r);
push_up(p);
}
void update (ll p, int l, int r, int q) {
if(l == r) {
tr[p] = m[l];
return;
}
int mid = (l + r) >> 1;
if(q <= mid) update(lc(p), l, mid, q);
else update(rc(p), mid + 1, r, q);
push_up(p);
}
void query (ll p, int l, int r, int ql, int qr, Node &res) {
if(ql <= l && r <= qr) {
mul(tr[p], res, res);
return;
}
int mid = (l + r) >> 1;
if(mid < qr) query(rc(p), mid + 1, r, ql, qr, res);
if(ql <= mid) query(lc(p), l, mid, ql, qr, res);
}
}
void modify (int p, int v) {
{
Matrix &t = m[dfn[p]];
for(int i = 0; i < 128; i++) {
int t1 = a[w[p]][i];
int t2 = a[v][i];
if(t1 == 0) --zeroCnt[p][i];
else nozero[p][i] = nozero[p][i] * inv[t1] % MOD;
if(t2 == 0) ++zeroCnt[p][i];
else nozero[p][i] = nozero[p][i] * t2 % MOD;
assert(zeroCnt[p][i] >= 0);
if(zeroCnt[p][i]) {
t[1][2][i] = (t[1][2][i] - t[0][0][i] + MOD) % MOD;
t[0][0][i] = t[0][2][i] = t[1][0][i] = 0;
} else {
t[1][2][i] = (t[1][2][i] + nozero[p][i] - t[0][0][i] + MOD) % MOD;
t[0][0][i] = t[0][2][i] = t[1][0][i] = nozero[p][i];
}
}
w[p] = v;
SegT::update(1, 1, n, dfn[p]);
}
p = top[p];
while(p != 1) {
Node tmp;
SegT::query(1, 1, n, dfn[p], dfn[bot[p]], tmp);
Matrix &t = m[dfn[fa[p]]];
for(int i = 0; i < 128; i++) {
t[1][2][i] = (t[1][2][i] + tmp[1][i] - g[p][i] + MOD) % MOD;
}
for(int i = 0; i < 128; i++) {
int t1 = (f[p][i] + 1) % MOD;
int t2 = (tmp[0][i] + 1) % MOD;
int p1 = fa[p];
if(t1 == 0) --zeroCnt[p1][i];
else nozero[p1][i] = nozero[p1][i] * inv[t1] % MOD;
if(t2 == 0) ++zeroCnt[p1][i];
else nozero[p1][i] = nozero[p1][i] * t2 % MOD;
if(zeroCnt[p1][i] < 0) {
cout << p1 << ' ' << i << endl;
}
if(zeroCnt[p1][i]) {
t[1][2][i] = (t[1][2][i] - t[0][0][i] + MOD) % MOD;
t[0][0][i] = t[0][2][i] = t[1][0][i] = 0;
} else {
t[1][2][i] = (t[1][2][i] + nozero[p1][i] - t[0][0][i] + MOD) % MOD;
t[0][0][i] = t[0][2][i] = t[1][0][i] = nozero[p1][i];
}
}
SegT::update(1, 1, n, dfn[fa[p]]);
f[p] = tmp[0];
g[p] = tmp[1];
p = top[fa[p]];
}
}
int query (int k) {
Node tmp;
SegT::query(1, 1, n, dfn[1], dfn[bot[1]], tmp);
tmp[1].ifwt();
return tmp[1][k];
}
ll main () {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
inv[1] = 1;
for(int i = 2; i < MOD; i++) inv[i] = (MOD - (MOD / i * inv[MOD % i] % MOD)) % MOD;
cin >> n >> V;
for(int i = 1; i <= n; i++) cin >> w[i];
for(int i = 1; i <= n - 1; i++) {
int u, v;
cin >> u >> v;
addEdge(u, v);
addEdge(v, u);
}
for(int i = 0; i < 128; i++) a[i][i] = 1;
for(int i = 0; i < 128; i++) a[i].fwt();
dfs1(1);
dfs2(1, 1);
dfs3(1);
SegT::build(1, 1, n);
cin >> q;
while (q--) {
string s;
int x, y, k;
cin >> s;
if(s == "Change") {
cin >> x >> y;
modify(x, y);
} else {
cin >> k;
cout << query(k) << '\n';
}
}
return 0;
}
|