跳转至

FWT(快速沃尔什变换)

推荐大家去看一看 Alex_Wei 的博客,肯定写的比我好。里面还有集合幂级数。

FWT 可以在 \(O(k2^k)\)\(O(n\log n)\))的时间内计算位运算卷积:

\[ c_i=\sum_{j\oplus k=i}a_{j}b_{k} \]

其中 \(\oplus\) 表示一种二进制位运算,通常是按位与、按位或、按位异或。

FWT

FWT 是作用于序列的一种线性变换,即

\[ \operatorname{FWT}(A)_i=\sum_{j=0}^{2^p-1}A_jc_{i,j} \]

因此

\[ \begin{align*} &\operatorname{FWT}(kA)=k\cdot \operatorname{FWT}(A)\\ &\operatorname{FWT}(A+B)=\operatorname{FWT}(A)+\operatorname{FWT}(B) \end{align*} \]

且对于序列 \(C=A*B\),满足

\[ \operatorname{FWT}(A)\cdot\operatorname{FWT}(B)=\operatorname{FWT}(C) \]

这样,我们能推导出 \(c\) 合法的一个充要条件:

\[ c_{i,j}c_{i,k}=c_{i,j\oplus k} \]
推导过程
\[ \begin{align*} \big(\sum_{j=0}^{2^p-1}a_jc_{i,j}\big)\big(\sum_{j=0}^{2^p-1}b_jc_{i,j}\big)=&\ \sum_{j=0}^{2^p-1}(\sum_{j_1=0}^{2^p-1}a_{j_1}b_{j\oplus j_1})c_{i,j}\\ \sum_{j_1=0}^{2^p-1}\sum_{j_2=0}^{2^p-1}a_{j_1}b_{j_2}c_{i,j_1}c_{i,j_2}=&\ \sum_{j_1=0}^{2^p-1}\sum_{j_2=0}^{2^p-1}a_{j_1}b_{j_2}c_{i,j_1\oplus j_2} \end{align*} \]

恒成立,因此 \(c_{i,j}c_{i,k}=c_{i,j\oplus k}\)

由于位运算的各位之间是独立的,因此我们可以钦定 \(c_{i,j}=\prod_{p=0}^{k-1} c_{[2^p]i,[2^p]j}\)。这样,我们只需考察 \(c_{0\sim 1,0\sim 1}\) 即可。为了使 \(\operatorname{FWT}\) 有逆,\(c_{0\sim 1,0\sim 1}\) 也应该有逆。

或运算

我们直接提供一个满秩的 \(c_{0\sim 1,0\sim 1}\) 矩阵:

\[ \begin{bmatrix} 1& 0\\ 1& 1\\ \end{bmatrix} \]

或者也可以理解为 \(\operatorname{FWT}(A)_i=\sum[j\subseteq i]a_j\)(FMT)。可以从这个条件推出 \(\operatorname{FWT}(A)\cdot \operatorname{FWT}(B)=\operatorname{FWT}(A*B)\)

与运算

基本同上,把矩阵转置即可。

异或运算

异或运算的 \(c_{0\sim 1,0\sim 1}\) 为:

\[ \begin{bmatrix} 1& 1\\ 1& -1\\ \end{bmatrix} \]

也可以理解为:\(\operatorname{FWT}(A)_i=\sum_{j}(-1)^{\operatorname{popcnt}(i~\cap~j)}a_j\)。Alex_wei 的博客里面有构造证明。

inline void fwt () {
    for(int k = 1; k < n; k <<= 1) {
        int l = k << 1;
        for(int i = 0; i < n; i += l) {
            for(int j = 0; j < k; j++) {
                int x = a[i + j], y = a[i + j + k];
                a[i + j] = (x + y) % MOD;
                a[i + j + k] = (x - y + MOD) % MOD;
            }
        }
    }
}

CF1906K Deck-Building Game

给定 \(n\) 个数 \(a_{1\sim n}\),求:

\[ \sum_{S\subseteq \{1,2,3,\cdots,n\}} \left[\bigoplus a[S_i]=0\right]2^{|S|} \]

\(n\le 10^5\)

定义 \(x_p\cdot x_q=x_{p\oplus q}\)(异或卷积),那么答案即

\[ [1]\prod_{i=1}^{n}(2x^{a_i}+1) \]

现在问题变为如何快速求出 \(10^5\) 个二项式的异或卷积。注意到在我们钦定了值域的前 \(k\) 位之后,剩下的数字只能组合出 \(2^{17-k}\) 级别个异或和。因为前 \(k\) 位要么不变,要么全是 \(0\),只和奇偶性有关。而后 \(17-k\) 位显然只有 \(2^{17-k}\) 种组合。

考虑在值域上分治,钦定一个二进制位的前缀。由于前面保证了一段区间能组合出的值是和区间长度同阶的,这里直接模仿分治 ntt 就行了。由于需要讨论当前第 \(k\) 位是多少,因此需要记一下区间内集合大小的奇偶性。

代码
#include<iostream>
#define int long long
using namespace std;
const int N = 1.5e5 + 10;
const int MOD = 998244353;

int n;
int cnt[N];
int res[2][N];

int inv[N], fact[N], ifact[N], pw2[N];

inline int C(int a, int b) { return fact[a] * ifact[a - b] % MOD * ifact[b] % MOD; }

void fwt(int a[], int n, int t = 1) {
    for(int k = 1; k < n; k <<= 1) {
        for(int i = 0; i < n; i += 2 * k) {
            for(int j = 0; j < k; j++) {
                int x = a[i + j], y = a[i + j + k];
                a[i + j] = (x + y % MOD) % MOD;
                a[i + j + k] = (x + MOD - y % MOD) % MOD;
            }
        }
    }
    if(t == -1)
    for(int i = 0; i < n; i++) a[i] = a[i] * inv[n] % MOD;
}

void solve(int l, int r) {
    if(l + 1 == r) {
        for(int i = 0; i <= cnt[l]; i++) (res[i & 1][l] += C(cnt[l], i) * pw2[i] % MOD) %= MOD;
        return;
    }
    int mid = (l + r) >> 1, len = mid - l;
    solve(l, mid);
    solve(mid, r);
    fwt(res[0] + l, len);
    fwt(res[0] + mid, len);
    fwt(res[1] + l, len);
    fwt(res[1] + mid, len);
    for(int i = l; i < mid; i++) {
        int x = res[0][i], y = res[1][i], z = res[0][i + len], w = res[1][i + len];
        res[0][i] = x * z % MOD;
        res[1][i] = y * z % MOD;
        res[0][i + len] = y * w % MOD;
        res[1][i + len] = x * w % MOD;
    }
    fwt(res[0] + l, len, -1);
    fwt(res[0] + mid, len, -1);
    fwt(res[1] + l, len, -1);
    fwt(res[1] + mid, len, -1);
}

signed main() {

    ios::sync_with_stdio(0);
    cin.tie(0);

    inv[1] = fact[0] = ifact[0] = pw2[0] = 1;
    for(int i = 2; i < N; i++) inv[i] = MOD - inv[MOD % i] * (MOD / i) % MOD;
    for(int i = 1; i < N; i++) fact[i] = fact[i - 1] * i % MOD;
    for(int i = 1; i < N; i++) ifact[i] = ifact[i - 1] * inv[i] % MOD;
    for(int i = 1; i < N; i++) pw2[i] = pw2[i - 1] * 2 % MOD;

    cin >> n;
    for(int i = 1; i <= n; i++) {
        int x;
        cin >> x;
        ++cnt[x];
    }

    solve(0, 131072);

    cout << (res[0][0] + res[1][0]) % MOD << '\n';

    return 0;
}