跳转至

FWT(快速沃尔什变换)

FWT 可以在 \(O(p2^p)\)\(O(n\log n)\) 的时间内计算位运算卷积:

\[ c_i=\sum_{j\oplus k=i}a_{j}b_{k} \]

其中 \(\oplus\) 表示一种二进制位运算,通常是按位与、按位或、按位异或。

FWT

FWT 是作用于序列的一种线性变换,即

\[ \operatorname{FWT}(A)_i=\sum_{j=0}^{2^p-1}A_jc_{i,j}\tag{1} \]

因此

\[ \begin{align*} &\operatorname{FWT}(kA)=k\cdot \operatorname{FWT}(A)\\ &\operatorname{FWT}(A+B)=\operatorname{FWT}(A)+\operatorname{FWT}(B) \end{align*} \]

且对于序列 \(C=A*_{\oplus}B\),满足

\[ \operatorname{FWT}_{\oplus}(A)\cdot\operatorname{FWT}_{\oplus}(B)=\operatorname{FWT}_{\oplus}(C)\tag{2} \]

根据 \((2)\),我们能推导出 \(c\) 合法的一个充要条件:

\[ c_{i,j}c_{i,k}=c_{i,j\oplus k}\tag{3} \]
推导过程

\((2)\)

\[ \begin{align*} \big(\sum_{j=0}^{2^p-1}a_jc_{i,j}\big)\big(\sum_{j=0}^{2^p-1}b_jc_{i,j}\big)=&\ \sum_{j=0}^{2^p-1}(\sum_{j_1=0}^{2^p-1}a_{j_1}b_{j\oplus j_1})c_{i,j}\\ \sum_{j_1=0}^{2^p-1}\sum_{j_2=0}^{2^p-1}a_{j_1}b_{j_2}c_{i,j_1}c_{i,j_2}=&\ \sum_{j_1=0}^{2^p-1}\sum_{j_2=0}^{2^p-1}a_{j_1}b_{j_2}c_{i,j_1\oplus j_2} \end{align*} \]

恒成立,因此 \(c_{i,j}c_{i,k}=c_{i,j\oplus k}\)

为了便于计算,我们写出 \(c\) 成立的一个充分不必要条件:

\[ \begin{cases} c_{i,j}=\prod_{k=0}^{p-1} c_{[2^k]i,[2^k]j} \tag{4}\\ c_{i,j}c_{i,k}=c_{i,j\oplus k}\ (i,j,k\in \{0,1\}) \end{cases} \]

其中 \([2^k]j\) 表示 \(j\) 二进制表示下的第 \(k\) 位,取值为 \(\{0,1\}\)

推导过程
\[ \begin{align*} c_{i,j_1}c_{i,j_2}=&\ (\prod_{k=0}^{p-1}c_{[2^k]i,[2^k]j_1})(\prod_{k=0}^{p-1}{c_{[2^k]i,[2^k]j_2}})\\ =&\ \prod_{k=0}^{p-1}c_{[2^k]i,[2^k]j_1}c_{[2^k]i,[2^k]j_2}\\ =&\ \prod_{k=0}^{p-1}c_{[2^k]i,[2^k]j_1\oplus j_2}\\ =&\ c_{i,j_1\oplus j_2} \end{align*} \]

这样,我们只需要对运算 \(\oplus\) 找到符合条件的 \(2\times 2\) 矩阵:

\[ \begin{bmatrix} c_{0,0}& c_{0,1}\\ c_{1,0}& c_{1,1}\\ \end{bmatrix} \]

然后就能扩展到整个 \(c\) 矩阵。

现在考虑如何快速计算 \(\operatorname{FWT}(A)\)。由于 \(c_{i,j}\) 可以拆位,我们分别考虑每一位。假设当前考虑了二进制位的集合 \(S\),按下标的二进制去掉 \(S\) 的位之后,序列可以被分为若干类,我们现在只考虑同一类以内的贡献。同时,\(c\)\((4)\) 式中的连乘形式也只会考虑 \(S\) 中的位。

尝试向 \(S\) 中加入一个新的元素 \(k\)(新考虑了第 \(k\) 位),这个过程会将 \(2^{p-|S|}\) 类合并成 \(2^{p-|S|-1}\) 类。合并前,序列的第 \(i\) 项为:

\[ A_i=\sum_{[\overline S]j=[\overline S]i}{c_{i,j}a_j} \]

考虑第 \(k\) 位,它会使 \(S'\leftarrow S+\{k\}\)\(c'_{i,j}\leftarrow c_{i,j}\times c_{[2^k]i,[2^k]j}\),序列变为:

\[ \begin{align*} A'_i=\sum_{[\overline S']j=[\overline S']i}{c'_{i,j}a_j}&=\sum_{[\overline S']j=[\overline S']i}{c_{i,j}a_j\times c_{[2^k]i,[2^k]j}}\\ &=\sum_{[\overline S']j=[\overline S']i}{c_{i,j}a_j\times c_{[2^k]i,[2^k]j}\times \big[[2^k]j=0\big]}\\&+\sum_{[\overline S']j=[\overline S']i}{c_{i,j}a_j\times c_{[2^k]i,[2^k]j}\times \big[[2^k]j=1\big]}\\ &= \begin{cases} c_{0,0}A_{i_0}+c_{0,1}A_{i_1},\ [2^k]i=0\\ c_{1,0}A_{i_0}+c_{1,1}A_{i_1},\ [2^k]i=1 \end{cases} \end{align*} \]

其中 \(i_0\) 表示将 \(i\) 的第 \(k\) 位赋值为 \(0\)\(i_1\) 表示将 \(i\) 的第 \(k\) 位赋值位 \(1\)。上式正是我们熟见的形式,我们上面的过程也阐释了为什么 FWT 外层的循环可以枚举任意排列。写成代码如下:

// 异或
inline void fwt () {
    for(int k = 1; k < n; k <<= 1) {
        int l = k << 1;
        for(int i = 0; i < n; i += l) {
            for(int j = 0; j < k; j++) {
                int x = a[i + j], y = a[i + j + k];
                a[i + j] = (x + y) % MOD;
                a[i + j + k] = (x - y + MOD) % MOD;
            }
        }
    }
}

iFWT

iFWT 是 FWT 的逆变换;iFWT 能正常进行的充要条件是 \(c_{0\sim 1,0\sim 1}\) 满秩。iFWT 仍然是线性变换,即

\[ \operatorname{iFWT}(A)_i=\sum_{j=0}^{2^p-1}A_jd_{i,j} \]

且满足

\[ \operatorname{iFWT}(\operatorname{FWT(A)})=A \]

考察 FWT 的每一个小步都是用 \(c_{0\sim 1,0\sim 1}\) 去乘以 \([A_{i_0}\ \ A_{i_1}]^T\),因此我们令 \(d_{0\sim 1,0\sim 1}\)\(c_{0\sim 1,0\sim 1}\) 的逆矩阵,然后用 \(d_{0\sim 1,0\sim 1}\) 去乘以变换之后的 \([A_{i_0}\ \ A_{i_1}]^T\),即可得到变换之前的值。我们倒序进行每一步变换,一定可以实现 FWT 的逆变换。

同时我们发现,由于 FWT 的顺序是任意的,因此对应 iFWT 的顺序也是任意的;同时由于 FWT 的顺序不影响结果,因此 iFWT 的顺序也不影响结果。

写成代码如下:

// 异或
for(int k = 1; k < n; k <<= 1) {
    int l = k << 1;
    for(int i = 0; i < n; i += l) {
        for(int j = 0; j < k; j++) {
            int x = a[i + j] * inv[2] % MOD, y = a[i + j + k] * inv[2] % MOD;
            a[i + j] = (x + y) % MOD;
            a[i + j + k] = (x - y + MOD) % MOD;
        }
    }
}

更多位运算

我们只考虑具有交换律、不总返回 \(1\) 且不总返回 \(0\) 的位运算;此时真值表还剩 \(6\) 种位运算,分别是与、或、异或以及与非、或非、同或。因此我们只需要实现基本的与、或、异或就可以涵盖所有值得讨论的运算。

对于三种运算,我们分别写出满足 \((3)\) 且满秩的 \(2\times 2\)\(c\) 矩阵:

\[ C= \begin{bmatrix} 1& 1\\ 0& 1\\ \end{bmatrix} \]
\[ D= \begin{bmatrix} 1& -1\\ 0& 1 \end{bmatrix} \]

\[ C= \begin{bmatrix} 1& 0\\ 1& 1\\ \end{bmatrix} \]
\[ D= \begin{bmatrix} 1& 0\\ -1& 1 \end{bmatrix} \]

异或

\[ C= \begin{bmatrix} 1& 1\\ 1& -1\\ \end{bmatrix} \]
\[ D= \begin{bmatrix} 1/2& 1/2\\ 1/2& -1/2 \end{bmatrix} \]

P4717 【模板】快速莫比乌斯/沃尔什变换 (FMT/FWT)

模板代码
#include<iostream>
#include<vector>
#define int long long
using namespace std;
const int N = 17;
const int MOD = 998244353;

int op[2][2];

void FWT(int n, vector<int>& a) {
    for(int i = n; i >= 2; i >>= 1) {
        int o = i / 2;
        for(int j = 0; j < n; j += i) {
            for(int k = 0; k < o; k++) {
                int x = a[j + k], y = a[j + k + o];
                a[j + k] = (op[0][0] * x % MOD + op[0][1] * y % MOD) % MOD;
                a[j + k + o] = (op[1][0] * x % MOD + op[1][1] * y % MOD) % MOD;
            }
        }
    }
}

int n;
vector<int> a, b, a1, b1, c;

signed main() {

    cin >> n;
    n = 1 << n;
    a.resize(n);
    b.resize(n);
    c.resize(n);
    for(int i = 0; i < n; i++) cin >> a[i];
    for(int i = 0; i < n; i++) cin >> b[i];

    op[0][0] = 1;   op[0][1] = 0;
    op[1][0] = 1;   op[1][1] = 1;
    a1 = a, b1 = b;
    FWT(n, a1);
    FWT(n, b1);
    for(int i = 0; i < n; i++) c[i] = a1[i] * b1[i] % MOD;
    op[0][0] = 1;       op[0][1] = 0;
    op[1][0] = MOD - 1; op[1][1] = 1;
    FWT(n, c);
    for(int i = 0; i < n; i++) cout << c[i] << ' ';
    cout << '\n';

    op[0][0] = 1;   op[0][1] = 1;
    op[1][0] = 0;   op[1][1] = 1;
    a1 = a, b1 = b;
    FWT(n, a1);
    FWT(n, b1);
    for(int i = 0; i < n; i++) c[i] = a1[i] * b1[i] % MOD;
    op[0][0] = 1;   op[0][1] = MOD - 1;
    op[1][0] = 0;   op[1][1] = 1;
    FWT(n, c);
    for(int i = 0; i < n; i++) cout << c[i] << ' ';
    cout << '\n';

    int half = 499122177;

    op[0][0] = 1;   op[0][1] = 1;
    op[1][0] = 1;   op[1][1] = MOD - 1;
    a1 = a, b1 = b;
    FWT(n, a1);
    FWT(n, b1);
    for(int i = 0; i < n; i++) c[i] = a1[i] * b1[i] % MOD;
    op[0][0] = half;  op[0][1] = half;
    op[1][0] = half;  op[1][1] = MOD - half;
    FWT(n, c);
    for(int i = 0; i < n; i++) cout << c[i] << ' ';
    cout << '\n';

    return 0;
}

P3781 [SDOI2017] 切树游戏

本题考察 FWT 的线性性。

代码
#include<iostream>
#include<vector>
#include<cstring>
#include<cassert>
#define ll signed
#define int short
using namespace std;
const int N = 3e4 + 10;
const int MOD = 1e4 + 7;

struct Edge {
    int v;
    ll next;
} pool[2 * N];
ll ne, head[N];

void addEdge (int u, int v) {
    pool[++ne] = {v, head[u]};
    head[u] = ne;
}

int inv[MOD];

struct Vec {
    int a[128];
    inline Vec () { memset(a, 0, sizeof(a)); }
    inline void clear () { memset(a, 0, sizeof(a)); }
    inline void fill (int x) { for(int i = 0; i < 128; i++) a[i] = x; }
    inline int &operator[] (int index) { return a[index]; }
    inline const int &operator[] (int index) const { return a[index]; }
    inline void fwt () {
        for(int k = 1; k < 128; k <<= 1) {
            int l = k << 1;
            for(int i = 0; i < 128; i += l) {
                for(int j = 0; j < k; j++) {
                    int x = a[i + j], y = a[i + j + k];
                    a[i + j] = (x + y) % MOD;
                    a[i + j + k] = (x - y + MOD) % MOD;
                }
            }
        }
    }
    inline void ifwt () {
        for(int k = 1; k < 128; k <<= 1) {
            int l = k << 1;
            for(int i = 0; i < 128; i += l) {
                for(int j = 0; j < k; j++) {
                    int x = a[i + j] * inv[2] % MOD, y = a[i + j + k] * inv[2] % MOD;
                    a[i + j] = (x + y) % MOD;
                    a[i + j + k] = (x - y + MOD) % MOD;
                }
            }
        }
    }
};

inline void mul (const Vec &a, const Vec &b, Vec &res) {
    for(int i = 0; i < 128; i++) res[i] = a[i] * b[i] % MOD;
}
inline void mul_add (const Vec &a, const Vec &b, Vec &res) {
    for(int i = 0; i < 128; i++) res[i] = (res[i] + a[i] * b[i] % MOD) % MOD;
}

// 矩阵套向量
struct Matrix {
    Vec a[3][3];
    inline Vec* operator[] (int index) { return a[index]; }
    inline const Vec* operator[] (int index) const { return a[index]; }
    inline void clear () {
        a[0][0].clear(); a[0][1].clear(); a[0][2].clear();
        a[1][0].clear(); a[1][1].clear(); a[1][2].clear();
        a[2][0].clear(); a[2][1].clear(); a[2][2].clear();
    }
};

inline void mul (const Matrix &a, const Matrix &b, Matrix &res) {
    mul(a[0][0], b[0][0], res[0][0]);
    mul(a[0][0], b[0][2], res[0][2]);
    mul_add(a[0][2], b[2][2], res[0][2]);
    mul(a[1][0], b[0][0], res[1][0]);
    mul_add(a[1][1], b[1][0], res[1][0]);
    mul(a[1][1], b[1][1], res[1][1]);
    mul(a[1][0], b[0][2], res[1][2]);
    mul_add(a[1][1], b[1][2], res[1][2]);
    mul_add(a[1][2], b[2][2], res[1][2]);
    res[2][2].fill(1);
}

// 向量套向量
struct Node {
    Vec a[3];
    inline Node() { a[2].fill(1); }
    inline Vec &operator[] (int index) { return a[index]; }
    inline const Vec &operator[] (int index) const { return a[index]; }
};

inline void mul (const Matrix &a, Node b, Node &res) {
    mul(a[0][0], b[0], res[0]);
    mul_add(a[0][2], b[2], res[0]);
    mul(a[1][0], b[0], res[1]);
    mul_add(a[1][1], b[1], res[1]);
    mul_add(a[1][2], b[2], res[1]);
    res[2] = b[2];
}

int n, q, V;
int w[N];

int sz[N], son[N], fa[N];
int dfn[N], id[N], dt;
int top[N], bot[N];

Vec f[N], g[N];
Vec a[128];
Matrix m[N];
Vec zeroCnt[N];
Vec nozero[N];

void dfs1 (int u) {
    sz[u] = 1;
    for(ll i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(sz[v]) continue;
        fa[v] = u;
        dfs1(v);
        sz[u] += sz[v];
        if(sz[v] > sz[son[u]]) son[u] = v;
    }
}

void dfs2 (int u, int tp) {
    dfn[u] = ++dt;
    id[dt] = u;
    top[u] = tp;
    if(son[u]) {
        dfs2(son[u], tp);
        bot[u] = bot[son[u]];
    } else bot[u] = u;
    for(ll i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(dfn[v]) continue;
        dfs2(v, v);
    }
}

void dfs3 (int u) {
    nozero[u] = f[u] = a[0];
    if(son[u]) dfs3(son[u]);
    for(ll e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(v == fa[u] || v == son[u]) continue;
        dfs3(v);
        for(int i = 0; i < 128; i++) f[u][i] = f[u][i] * (f[v][i] + 1) % MOD;
        for(int i = 0; i < 128; i++) g[u][i] = (g[u][i] + g[v][i]) % MOD;
        for(int i = 0; i < 128; i++) ((f[v][i] + 1) % MOD) ? (nozero[u][i] = nozero[u][i] * (f[v][i] + 1) % MOD) : ++zeroCnt[u][i];
    }
    Matrix &t = m[dfn[u]];
    for(int i = 0; i < 128; i++) f[u][i] = f[u][i] * a[w[u]][i] % MOD;
    for(int i = 0; i < 128; i++) a[w[u]][i] ? (nozero[u][i] = nozero[u][i] * a[w[u]][i] % MOD) : ++zeroCnt[u][i];

    for(int i = 0; i < 128; i++) {
        if(zeroCnt[u][i]) assert(f[u][i] == 0), assert(nozero[u][i]);
    }

    t[0][0] = t[0][2] = t[1][0] = t[1][2] = f[u];
    for(int i = 0; i < 128; i++) t[1][1][i] = 1;
    for(int i = 0; i < 128; i++) t[2][2][i] = 1;
    for(int i = 0; i < 128; i++) t[1][2][i] = (t[1][2][i] + g[u][i]) % MOD;
    if(son[u]) {
        for(int i = 0; i < 128; i++) f[u][i] = f[u][i] * (f[son[u]][i] + 1) % MOD;
        for(int i = 0; i < 128; i++) g[u][i] = (g[u][i] + g[son[u]][i]) % MOD;
    }
    for(int i = 0; i < 128; i++) g[u][i] = (g[u][i] + f[u][i]) % MOD;
}

namespace SegT {
    Matrix tr[4 * N];
    inline ll lc (ll x) { return x << 1; }
    inline ll rc (ll x) { return x << 1 | 1; }
    inline void push_up (ll p) { mul(tr[lc(p)], tr[rc(p)], tr[p]); }
    void build (ll p, int l, int r) {
        if(l == r) {
            tr[p] = m[l];
            return;
        }
        int mid = (l + r) >> 1;
        build(lc(p), l, mid);
        build(rc(p), mid + 1, r);
        push_up(p);
    }
    void update (ll p, int l, int r, int q) {
        if(l == r) {
            tr[p] = m[l];
            return;
        }
        int mid = (l + r) >> 1;
        if(q <= mid) update(lc(p), l, mid, q);
        else update(rc(p), mid + 1, r, q);
        push_up(p);
    }
    void query (ll p, int l, int r, int ql, int qr, Node &res) {
        if(ql <= l && r <= qr) {
            mul(tr[p], res, res);
            return;
        }
        int mid = (l + r) >> 1;
        if(mid < qr) query(rc(p), mid + 1, r, ql, qr, res);
        if(ql <= mid) query(lc(p), l, mid, ql, qr, res);
    }
}

void modify (int p, int v) {
    {
        Matrix &t = m[dfn[p]];
        for(int i = 0; i < 128; i++) {
            int t1 = a[w[p]][i];
            int t2 = a[v][i];
            if(t1 == 0) --zeroCnt[p][i];
            else nozero[p][i] = nozero[p][i] * inv[t1] % MOD;
            if(t2 == 0) ++zeroCnt[p][i];
            else nozero[p][i] = nozero[p][i] * t2 % MOD;
            assert(zeroCnt[p][i] >= 0);
            if(zeroCnt[p][i]) {
                t[1][2][i] = (t[1][2][i] - t[0][0][i] + MOD) % MOD;
                t[0][0][i] = t[0][2][i] = t[1][0][i] = 0;
            } else {
                t[1][2][i] = (t[1][2][i] + nozero[p][i] - t[0][0][i] + MOD) % MOD;
                t[0][0][i] = t[0][2][i] = t[1][0][i] = nozero[p][i];
            }
        }
        w[p] = v;
        SegT::update(1, 1, n, dfn[p]);
    }
    p = top[p];
    while(p != 1) {
        Node tmp;
        SegT::query(1, 1, n, dfn[p], dfn[bot[p]], tmp);
        Matrix &t = m[dfn[fa[p]]];
        for(int i = 0; i < 128; i++) {
            t[1][2][i] = (t[1][2][i] + tmp[1][i] - g[p][i] + MOD) % MOD;
        }
        for(int i = 0; i < 128; i++) {
            int t1 = (f[p][i] + 1) % MOD;
            int t2 = (tmp[0][i] + 1) % MOD;
            int p1 = fa[p];
            if(t1 == 0) --zeroCnt[p1][i];
            else nozero[p1][i] = nozero[p1][i] * inv[t1] % MOD;
            if(t2 == 0) ++zeroCnt[p1][i];
            else nozero[p1][i] = nozero[p1][i] * t2 % MOD;
            if(zeroCnt[p1][i] < 0) {
                cout << p1 << ' ' << i << endl;
            }
            if(zeroCnt[p1][i]) {
                t[1][2][i] = (t[1][2][i] - t[0][0][i] + MOD) % MOD;
                t[0][0][i] = t[0][2][i] = t[1][0][i] = 0;
            } else {
                t[1][2][i] = (t[1][2][i] + nozero[p1][i] - t[0][0][i] + MOD) % MOD;
                t[0][0][i] = t[0][2][i] = t[1][0][i] = nozero[p1][i];
            }
        }
        SegT::update(1, 1, n, dfn[fa[p]]);
        f[p] = tmp[0];
        g[p] = tmp[1];
        p = top[fa[p]];
    }
}

int query (int k) {
    Node tmp;
    SegT::query(1, 1, n, dfn[1], dfn[bot[1]], tmp);
    tmp[1].ifwt();
    return tmp[1][k];
}

ll main () {

    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    inv[1] = 1;
    for(int i = 2; i < MOD; i++) inv[i] = (MOD - (MOD / i * inv[MOD % i] % MOD)) % MOD;

    cin >> n >> V;
    for(int i = 1; i <= n; i++) cin >> w[i];
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        addEdge(u, v);
        addEdge(v, u);
    }

    for(int i = 0; i < 128; i++) a[i][i] = 1;
    for(int i = 0; i < 128; i++) a[i].fwt();

    dfs1(1);
    dfs2(1, 1);
    dfs3(1);

    SegT::build(1, 1, n);

    cin >> q;
    while (q--) {
        string s;
        int x, y, k;
        cin >> s;
        if(s == "Change") {
            cin >> x >> y;
            modify(x, y);
        } else {
            cin >> k;
            cout << query(k) << '\n';
        }
    }

    return 0;
}