跳转至

AC 自动机(ACAM)

ACAM 是一种多模式串匹配的算法。对于 \(m\) 个模式串 \(t_i\),ACAM 可以在 \(O\big(\sum |t_i|+|s|\big)\) 的时间复杂度内求出每个 \(t_i\)\(s\) 中出现了多少次。

ACAM 会先将所有模式串建一个 Trie 树,然后给每个节点维护一个 fail 指针。fail[u] 定义为:\(u\) 子串的所有后缀中,存在于 Trie 上且最长的一个后缀。

在 Trie 上,如果节点 \(u\) 不存在 \(c\)-转移边(chd[u][c]),则在 ACAM 上定义 chd[u][c] = chd[fail[u]][c]

通过上面的改造,fail 指针可以用这个公式快速求出:fail[chd[u][c]] = chd[fail[u]][c]chd[u][c] 是原本就存在的)。

建立 ACAM

考虑在 Trie 树上跑 bfs

bfs 建立 ACAM
void build() {
    queue<int> que;
    // 先将根节点的所有儿子入队,否则这些节点的失配指针会指向自己
    // 有点类似 KMP 中预处理 nxt 数组时,i 从 1 开始而 j 从 0 开始,就是为了保证 fail 是真后缀
    for(int i = 0; i < 26; i++) if(chd[0][i]) que.push(chd[0][i]);
    while(!que.empty()) {
        int u = que.front();
        que.pop();
        for(int i = 0; i < 26; i++) {
            if(chd[u][i]) {
                fail[chd[u][i]] = chd[fail[u]][i];
                que.push(chd[u][i]);
            } else {
                chd[u][i] = chd[fail[u]][i];
            }
        }
    }
}

因为我们不能保证 fail[u] 总在 \(u\) 的返根链上,因此不能使用 dfs;否则在使用 chd[fail[u]][c] 时,fail[u] 可能还没有被遍历到。

匹配

经过这些改造以及 fail 指针的预处理,Trie 树已经变成了一个能快速匹配的自动机。

我们将文本串放在 ACAM 上跑;假设当前已经处理了一个前缀 \([1,x]\),位于自动机上的节点 \(u\);那么在 fail 树上 \(u\) 的祖先都是 \(s[1\sim x]\) 的后缀。也就是说,当前匹配了 \(1\to u\) 根链上的所有串。

因此匹配时需要记录 cnt,匹配结束后对 fail 树进行 dfs,求出 cnt 的子树和,才能知道当前子串被匹配了多少次。

注:以下题目若无特殊说明,字符集均为小写英文字母。

P5357 【模板】AC 自动机

求出每个模式串 \(t_i\)\(s\) 中出现的次数。

SAM 会被卡空间。

\(\sum{|t_i|}\le 2\times 10^5,\ s\le 2\times 10^6\)

模板代码
#include<iostream>
using namespace std;
const int N = 2e5 + 10;

struct Edge {
    int v, next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v) {
    pool[++ne] = {v, head[u]};
    head[u] = ne;
}

int n;
int pos[N], cnt[N];

int nn;
int chd[N][26], fail[N];

int insert(const string &s) {
    int cur = 0;
    for(int i = 0; i < (int)s.size(); i++) {
        if(!chd[cur][s[i] - 'a']) chd[cur][s[i] - 'a'] = ++nn;
        cur = chd[cur][s[i] - 'a'];
    }
    return cur;
}

int que[N], hd = 1, tl = 0;
void build() {
    for(int i = 0; i < 26; i++) if(chd[0][i]) que[++tl] = chd[0][i];
    while(hd <= tl) {
        int u = que[hd++];
        for(int i = 0; i < 26; i++) {
            if(chd[u][i]) {
                fail[chd[u][i]] = chd[fail[u]][i];
                que[++tl] = chd[u][i];
            } else {
                chd[u][i] = chd[fail[u]][i];
            }
        }
    }
}

void dfs(int u) {
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        dfs(v);
        cnt[u] += cnt[v];
    }
}

string s;

int main() {

    cin >> n;
    for(int i = 1; i <= n; i++) {
        string t;
        cin >> t;
        pos[i] = insert(t);
    }

    build();

    for(int i = 1; i <= nn; i++) addEdge(fail[i], i);

    cin >> s;
    int cur = 0;
    for(int i = 0; i < (int)s.size(); i++) {
        cur = chd[cur][s[i] - 'a'];
        ++cnt[cur];
    }

    dfs(0);

    for(int i = 1; i <= n; i++) {
        cout << cnt[pos[i]] << '\n';
    }

    return 0;
}

P2292 [HNOI2004] L 语言

题意

给定 \(n\)\(n\le 20\))个长度不超过 \(20\) 的字符串 \(t_i\),称为字典。

称一个字符串 \(s\)可被理解的,当且仅当 \(s\) 可以被完全分解为若干字典中的单词。

\(m\) 次询问,每次询问给定一个文本串 \(s_i\),求其最长的,可被理解的前缀。

\(m\le 50,\ s_i\le 2\times 10^6\)(要求线性)

考虑如何判断一个字符串 \(s\) 可被理解。我们用一个数组 \(f[i]\) 记录 \(i\) 前缀是否匹配。先对字典建出 ACAM,把 \(s\) 放在 ACAM 上跑,每走到一个节点 \(u\),就跳 fail 找出所有可能匹配的 \(t_i\)。若 \(t_i\) 匹配成功,则 \(f\big[i-|t_i|\big]\rightarrow f[i]\)

由于需要跳 fail,因此时间复杂度无法通过本题。考虑优化。

我们希望在节点 \(u\) 上无需跳 fail 就能找到所有匹配的 \(t_i\)。更进一步的,只有 \(|t_i|\) 就可以了。注意到 \(t_i\)\(n\) 都不超过 \(20\),因此我们可以用状压来记录所有匹配到的 \(|t_i|\)。这只需要在建出 ACAM 后通过一遍 dfs 即可求出。

考虑如何优化转移。由于 \(|t_i|\le 20\),因此只有 \(f[i-20,i-1]\) 是有效的。注意到 \(f\) 只记录了一些 bool 类型的值,因此可以通过状压来保存 \(f[i-20,i-1]\)。转移时只需要 \(O(1)\) 的位运算即可。

代码
#include<iostream>
#include<cstring>
using namespace std;
const int N = 20;
const int N2 = 410;

struct Edge {
    int v, next;
} pool[N2];
int ne, head[N2];

void addEdge(int u, int v) {
    pool[++ne] = {v, head[u]};
    head[u] = ne;
}

int n, m;

int nn;
int chd[N2][26], fail[N2], mat[N2];

void insert(const string& s, int id) {
    int cur = 0;
    for(int i = 0; i < s.size(); i++) {
        if(!chd[cur][s[i] - 'a']) chd[cur][s[i] - 'a'] = ++nn;
        cur = chd[cur][s[i] - 'a'];
    }
    mat[cur] |= 1 << id;
}

int que[N2], hd = 1, tl = 0;
void build() {
    for(int i = 0; i < 26; i++) if(chd[0][i]) que[++tl] = chd[0][i];
    while(hd <= tl) {
        int u = que[hd++];
        for(int i = 0; i < 26; i++) {
            if(chd[u][i]) {
                fail[chd[u][i]] = chd[fail[u]][i];
                que[++tl] = chd[u][i];
            } else {
                chd[u][i] = chd[fail[u]][i];
            }
        }
    }
}

void dfs(int u) {
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        mat[v] |= mat[u];
        dfs(v);
    }
}

int work(const string &s) {
    int cur = 0, res = 0;
    int sta = 1, mask = 0x000fffff;
    for(int i = 0; i < s.size(); i++) {
        cur = chd[cur][s[i] - 'a'];
        sta <<= 1;
        sta &= mask;
        if(!sta) return res;
        sta |= (bool)(sta & mat[cur]);
        if(sta & 1) res = i + 1;
    }
    return res;
}

int main() {

    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    cin >> n >> m;
    for(int i = 1; i <= n; i++) {
        string s;
        cin >> s;
        insert(s, s.size());
    }

    build();

    for(int i = 1; i <= nn; i++) addEdge(fail[i], i);

    dfs(0);

    while(m--) {
        string s;
        cin >> s;
        cout << work(s) << '\n';
    }

    return 0;
}

P3735 [HAOI2017] 字符串

题意

给定常数 \(k\);定义两个字符串 \(a,b\) 相等当且仅当:

  • \(|a|=|b|\)
  • \(\forall a_i\ne b_i,a_j\ne b_j,\ |i-j|<k\)

如果 \(|a|=|b|\le k\),则认为两个字符串相等。

给定一个文本串 \(s\)\(n\) 个模式串 \(t_i\),问对于每个 \(i\in [1,n]\)\(t_i\)\(s\) 中出现的次数。

\(|s|,\sum{|t_i|}\le 2\times 10^5,\ |\Sigma|=95\)

先特判掉 \(t_i\le k\) 的情况。

注意到我们可以枚举 \(t_i\) 匹配的断点,然后枚举 \(s\) 上的断点位置,尝试快速处理贡献。

考察匹配时的情形:\(t_i\)\(s\) 各被分成了 \(3\) 段,中间一段是长度为 \(k\) 的失配区间,左边具有后缀关系,右边具有前缀关系(可以被视为反串上的后缀关系)。

示意图

刻画为二维偏序

后缀关系可以用某些 fail 指针来刻画。具体的,设 \(t_i\) 分成了 \(u,w_1,v\) 三段,\(s\) 分成了 \(x,w_2,y\) 三段;那么 \(u\)fail 树上是 \(x\) 的祖先,\(v\) 在反串的 fail 树上是 \(y\) 的祖先。

我们知道,子树关系可以使用 dfn 序转换成偏序问题。那么在两棵不同树上的父子关系就可以转化成一个二维偏序问题。我们先求出 \(t_i\) 的所有 \((u,v)\) 划分方案,将它们作为数点的查询操作(不难发现总数量不超过 \(\sum|t_i|\));再求出 \(s\) 的所有 \((x,y)\) 划分方案,将它们作为数点的加点操作。

SAM / ACAM

因为我们要处理多个模式串的问题,因此我们不使用 KMP(只能处理单模式串)。考虑使用 SAM 或者 ACAM。

先对 \(s\)\(s\) 的反串分别建出 SAM,然后枚举所有的 \((x,y)\) 划分方案,找到 \(x,y\) 串分别对应在两个 SAM 上的节点编号(在 insert 时记录即可),直接将 dfn 序扔进扫描线;

然后依次枚举每个 \(t_i\)\((u,v)\) 的划分方案,将 \(u,v\) 分别放在原串 SAM 和反串 SAM 上跑,将得到的节点做成子树查询扔进扫描线即可。

先将所有 \(t_i\) 的原串和反串分别建出 ACAM,同时在建立的过程中记录 \((u,v)\) 的划分方案,将 \(u,v\) 对应在 ACAM 上的节点做成子树查询扔进扫描线;

然后将 \(s\)\(s\) 的反串分别放到两个 ACAM 上跑,记录 \((x,y)\) 的划分方案,将 \(x,y\) 对应的状态的 dfn 序扔进扫描线;

容斥

然而我们注意到,当 \(t_i\) 在某个位置的实际失配区间长度少于 \(k\) 时,该位置的贡献可能被计算多次。这是因为可能有多个长度为 \(k\) 的区间包含这个小区间。考虑容斥:如果我们只让最小的 \(pos\) 产生贡献,那么每个匹配位置就恰好只产生一次贡献。因此我们再统计从 \(pos\) 开始的长度为 \(k-1\) 的失配区间能否匹配这个串,如果仍然可以,那么 \(pos\) 就不是最小的。

注意:统计 \(k-1\) 的答案时要考虑 \([pos,pos+k-2]\) 紧贴 \(t_i\) 左边和右边的情况。紧贴左边时,\(pos\) 一定是最小的,不能去掉;紧贴右边时,贡献不会在 \(k\) 处统计到。因此应去掉两个边界。

代码
#include<iostream>
#include<cstring>
#include<vector>
#include<cassert>
#include<algorithm>
using namespace std;
const int N = 2e5 + 10;
const int Sigma = 95;

struct myPair {
    int p, id;
};

struct Op {
    int id, x, y1, y2, w;
    inline bool operator<(const Op &b) const {
        if(x != b.x) return x < b.x;
        return id < b.id;
    }
};

int n, k, m;
int ans[N];
string s;
string t[N];

vector<myPair> mp1[N], mp2[N];

int sta[N], rst[N];
int dfn1[N], dfn2[N], sz1[N], sz2[N], dt1, dt2;

struct ACAM_Tp {

    struct Edge {
        int v, next;
    } pool[N];
    int ne, head[N];
    void addEdge(int u, int v) {
        pool[++ne] = {v, head[u]};
        head[u] = ne;
    }

    int nn;
    int chd[N][Sigma], fail[N];

    inline ACAM_Tp() {
        memset(head, 0, sizeof(head));
        memset(chd, 0, sizeof(chd));
        memset(fail, 0, sizeof(fail));
        ne = nn = 0;
    }

    static int tmp[N];
    inline void insert1(const string &str) {
        int cur = 0;
        for(int i = 0; i < str.size(); i++) {
            int c = str[i] - 33;
            if(!chd[cur][c]) chd[cur][c] = ++nn;
            cur = chd[cur][c];
            tmp[i + 1] = cur;
        }
    }

    inline void insert2(const string &str, int id) {
        int cur = 0;
        mp1[tmp[str.size() - k]].push_back({cur, id});
        for(int i = str.size() - 1; i >= 0; i--) {
            int c = str[i] - 33;
            if(!chd[cur][c]) chd[cur][c] = ++nn;
            cur = chd[cur][c];
            if(i >= k) mp1[tmp[i - k]].push_back({cur, id});
            if(i >= k) mp2[tmp[i - k + 1]].push_back({cur, id});
        }
    }

    inline void build() {
        static int que[N], hd, tl;
        hd = 1, tl = 0;
        for(int i = 0; i < Sigma; i++) if(chd[0][i]) que[++tl] = chd[0][i];
        while(hd <= tl) {
            int u = que[hd++];
            for(int i = 0; i < Sigma; i++) {
                if(chd[u][i]) {
                    fail[chd[u][i]] = chd[fail[u]][i];
                    que[++tl] = chd[u][i];
                } else {
                    chd[u][i] = chd[fail[u]][i];
                }
            }
        }
    }

    void init_Edge() {
        for(int i = 1; i <= nn; i++) addEdge(fail[i], i);
    }

    void dfs(int u, int dfn[], int sz[], int &dt) {
        dfn[u] = ++dt;
        sz[u] = 1;
        for(int e = head[u]; e; e = pool[e].next) {
            int v = pool[e].v;
            dfs(v, dfn, sz, dt);
            sz[u] += sz[v];
        }
    }

} am, ram;

int ACAM_Tp::tmp[N];

vector<Op> op;

namespace BIT {
    int sum[N];
    inline int lowbit(int x) { return x & -x; }
    inline void clear() {
        memset(sum, 0, sizeof(sum));
    }
    inline void add(int p, int v) {
        for(int i = p + 1; i <= m + 2; i += lowbit(i)) sum[i] += v;
    }
    inline int query(int p) {
        int res = 0;
        for(int i = p + 1; i > 0; i -= lowbit(i)) res += sum[i];
        return res;
    }
    inline int query(int l, int r) {
        return query(r) - query(l - 1);
    }
}

void add_Pt(int x, int y, int w) {
    op.push_back({0, x, y, 0, w});
}

void add_Qr(int x1, int x2, int y1, int y2, int id) {
    op.push_back({id, x1 - 1, y1, y2, -1});
    op.push_back({id, x2, y1, y2, 1});
}

int main() {

    cin >> k >> s >> n;
    for(int i = 1; i <= n; i++) cin >> t[i];
    for(int i = 1; i <= n; i++) m += t[i].size();

    for(int i = 1; i <= n; i++) {
        if(t[i].size() <= k) ans[i] = s.size() - t[i].size() + 1;
        else {
            am.insert1(t[i]);
            ram.insert2(t[i], i);
        }
    }

    am.build();
    ram.build();

    am.init_Edge();
    ram.init_Edge();

    am.dfs(0, dfn1, sz1, dt1);
    ram.dfs(0, dfn2, sz2, dt2);

    for(int i = 0, cur = 0; i < s.size(); i++) {
        cur = am.chd[cur][s[i] - 33];
        sta[i + 1] = cur;
    }
    for(int i = s.size() - 1, cur = 0; i >= 0; i--) {
        cur = ram.chd[cur][s[i] - 33];
        rst[i + 1] = cur;
    }

    for(int i = 0, x, y; i + k <= s.size(); i++) {
        x = sta[i];
        y = rst[i + k + 1];
        add_Pt(dfn1[x], dfn2[y], 1);
    }
    for(int i = 0; i <= am.nn; i++) {
        for(myPair &j : mp1[i]) {
            int x = i, y = j.p;
            add_Qr(dfn1[x], dfn1[x] + sz1[x] - 1, dfn2[y], dfn2[y] + sz2[y] - 1, j.id);
        }
    }
    sort(op.begin(), op.end());
    for(Op &o : op) {
        if(o.id) {
            ans[o.id] += BIT::query(o.y1, o.y2) * o.w;
        } else {
            BIT::add(o.y1, o.w);
        }
    }

    BIT::clear();
    op.clear();

    for(int i = 0, x, y; i + k - 1 <= s.size(); i++) {
        x = sta[i];
        y = rst[i + k];
        add_Pt(dfn1[x], dfn2[y], -1);
    }
    for(int i = 0; i <= am.nn; i++) {
        for(myPair &j : mp2[i]) {
            int x = i, y = j.p;
            add_Qr(dfn1[x], dfn1[x] + sz1[x] - 1, dfn2[y], dfn2[y] + sz2[y] - 1, j.id);
        }
    }
    sort(op.begin(), op.end());
    for(Op &o : op) {
        if(o.id) {
            ans[o.id] += BIT::query(o.y1, o.y2) * o.w;
        } else {
            BIT::add(o.y1, o.w);
        }
    }

    for(int i = 1; i <= n; i++) cout << ans[i] << '\n';

    return 0;
}

P2414 [NOI2011] 阿狸的打字机

题意

有一个字符串 \(t\),初始为空。有若干次操作,都是如下的三种:

  • 加入一个小写英文字母到 \(t\) 的末尾;
  • 删除 \(t\) 末尾的一个字符;
  • ++nn\(t\to s[nn]\)

然后,有 \(m\) 次询问,每次询问给定 \(x,y\),表示询问 \(s[x]\)\(s[y]\) 中出现了多少次。

注意到题目给出的这种构造字符串的方法,虽然 \(\sum|s_i|\) 没有保证,但是建出的 Trie 的总节点数是有保障的。

考虑如何统计出现次数。由经典 Trick,\(t\)\(s\) 中每出现一次,就唯一对应 \(s\) 的一个前缀,其中 \(t\) 是这个前缀的一个后缀。回到原问题,我们就是要统计满足一下条件的节点数量:

  • 该节点是 \(s[y]\) 在 Trie 上的祖先;
  • 该节点是 \(s[x]\)\(\operatorname{fail}\) 树上的儿子;

这可以使用 dfs 序刻画为二维偏序。离线扫描线即可。

代码
#include<iostream>
#include<string>
#include<vector>
#include<algorithm>
using namespace std;
const int N = 1e5 + 10;

struct Edge {
    int v, next;
} pool[N];
int ne, head[N];

void addEdge(int u, int v) {
    pool[++ne] = {v, head[u]};
    head[u] = ne;
}

void clear_Edge(int n) {
    ne = 0;
    for(int i = 0; i <= n; i++) head[i] = 0;
}

int n, q;
int pt[N];
int chd[N][26], fa[N], fail[N], nn;

void build() {
    static int que[N], hd = 1, tl = 0;
    for(int c = 0; c < 26; c++) if(chd[0][c]) que[++tl] = chd[0][c];
    while(hd <= tl) {
        int u = que[hd++];
        for(int c = 0; c < 26; c++) {
            if(chd[u][c]) {
                fail[chd[u][c]] = chd[fail[u]][c];
                que[++tl] = chd[u][c];
            } else {
                chd[u][c] = chd[fail[u]][c];
            }
        }
    }
}

int dfn1[N], dfn2[N], sz1[N], sz2[N], dt1, dt2;

void dfs(int u, int dfn[], int sz[], int &dt) {
    sz[u] = 1;
    dfn[u] = ++dt;
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        dfs(v, dfn, sz, dt);
        sz[u] += sz[v];
    }
}

struct Op {
    int tp, x, y, z;
    inline bool operator<(const Op &b) const {
        if(x != b.x) return x < b.x;
        return tp < b.tp;
    }
};

namespace BIT {
    int sum[N];
    inline int lowbit(int x) { return x & -x; }
    inline void add(int p, int v) {
        for(int i = p; i <= nn + 1; i += lowbit(i)) sum[i] += v;
    }
    inline int query(int p) {
        int res = 0;
        for(int i = p; i > 0; i -= lowbit(i)) res += sum[i];
        return res;
    }
    inline void add(int l, int r, int v) {
        add(l, v);
        add(r + 1, -v);
    }
}

int ans[N];
vector<Op> op;

void add_Qr(int x1, int x2, int y, int id) {
    op.push_back({id, x1 - 1, y, -1});
    op.push_back({id, x2, y, 1});
}

void add_Pt(int x, int y1, int y2) {
    op.push_back({0, x, y1, y2});
}

string s;

int main() {

    cin >> s;
    int cur = 0;
    for(int i = 0; i < s.size(); i++) {
        char c = s[i];
        if(c == 'B') cur = fa[cur];
        else if(c == 'P') pt[++n] = cur;
        else {
            if(!chd[cur][c - 'a']) chd[cur][c - 'a'] = ++nn;
            fa[chd[cur][c - 'a']] = cur;
            cur = chd[cur][c - 'a'];
        }
    }

    build();

    for(int i = 1; i <= nn; i++) addEdge(fa[i], i);
    dfs(0, dfn1, sz1, dt1);
    clear_Edge(nn);
    for(int i = 1; i <= nn; i++) addEdge(fail[i], i);
    dfs(0, dfn2, sz2, dt2);

    cin >> q;
    for(int i = 1; i <= q; i++) {
        int x, y;
        cin >> x >> y;
        x = pt[x], y = pt[y];
        add_Qr(dfn2[x], dfn2[x] + sz2[x] - 1, dfn1[y], i);
    }
    for(int i = 0; i <= nn; i++) {
        add_Pt(dfn2[i], dfn1[i], dfn1[i] + sz1[i] - 1);
    }

    sort(op.begin(), op.end());

    for(Op &o : op) {
        if(o.tp) {
            ans[o.tp] += BIT::query(o.y) * o.z;
        } else {
            BIT::add(o.y, o.z, 1);
        }
    }

    for(int i = 1; i <= q; i++) cout << ans[i] << '\n';

    return 0;
}