据说很多公司都有这样一道面试题:给你几个 G 的字符串,让你想办法快速地找出其中的很多个需要和谐的敏感词。
这个问题里,如果“需要和谐的字符串”称为“模式串”,“待被查的字符串”称为文本串。对于这样的问题,如果暴力做,复杂度就是 $\Theta(N \ast M \ast Len)$……用 AC 自动机这种高级的算法,可以在 $\Theta (N)$ 左右复杂度内得出答案。Excited!
引
这个 AC 自动机可不是 Accepted 自动机……其实是“Aho–Corasick 算法”,是由 Alfred V. Aho 和 Margaret J.Corasick 发明的字符串搜索算法,
维基百科上说:
该算法主要依靠构造一个有限状态机(类似于在一个trie树中添加失配指针)来实现。这些额外的失配指针允许在查找字符串失败时进行回退(例如设Trie树的单词cat匹配失败,但是在Trie树中存在另一个单词cart,失配指针就会指向前缀ca),转向某前缀的其他分支,免于重复匹配前缀,提高算法效率。
是不是听起来就很高级!
AC 自动机与 KMP
其实按照 AC 自动机的思想可以发现:这其实就是树上 KMP,即 Trie 树与 KMP 算法的结合。网上的几乎各路题解都说:“在学习 AC 自动机前要先掌握 Trie 树与 KMP”……
但是其实学 AC 自动机不需要掌握 KMP。
AC 自动机的实际应用
前面已经提到:据说很多公司都有这样一道面试题(据说是美团 2015 面试题?):给你几个 G 的字符串,让你想办法快速地找出其中的很多个需要和谐的敏感词。这几乎是很多互联网公司要做的事情。这种东西用 AC 自动机来做就会大大提高效率。
据说,UNIX 系统中的一个命令 fgrep 就是以 AC 自动机算法作为基础实现的。高级!
失配指针(Fail 指针)
首先我们对于所有的模式串,构造一棵 Trie 树(即字典树)。举个例子,假设文本串是 $shers$,模式串分别是:$sher$,$hers$,$er$,那么构造的 Trie 树是这样的:
graph TB;
R((Root))---S1((S))
S1---H1((H))
H1---E1((E))
E1---R1((R))
R---H2((H))
H2---E2((E))
E2---R2((R))
R2---S2((S))
R---E3((E))
E3---R3((R))
从 Trie 树到 Trie 图
在查询的时候,查询的文本串是 $sher$,自然一直往最左边一路;走的时候可以发现前面四个一直匹配;当走到左边一颗子树的最底下的 R 结点,再走发现 R 没有儿子了,也就是说不再匹配。那么这时候就称为 R 失配了。
那么当 R 失配了,我们下一步要从哪里开始走呢?我们想要找一段 $sher$ 的最长后缀,这样才能保证这段后缀仍然是文本串的子串,并且答案最优。显然,第二棵子树里 $h - e -r$ 这条路径是最优解。当 R 失配后,我们可以直接跳到左起第二棵子树的 R 结点,这样保证 $her$ 是 $sher$ 的最长后缀。
所以最左边一条子树的 R 结点的失配指针(fail 指针)指向左起第二棵子树的 R。类似地,造出所有失配指针,用虚线连接:
graph TB;
R((Root))---S1((S))
S1---H1((H))
H1---E1((E))
E1---R1((R))
R---H2((H))
H2---E2((E))
E2---R2((R))
R2---S2((S))
R---E3((E))
E3---R3((R))
H1.->H2
E1.->E2
R1.->R2
E2.->E3
R2.->R3
构造失配指针之后,后面的处理就很方便了。
构造失配指针
怎么构造失配指针呢?
首先我们设 fail[x] 表示 x 的失配指针。假设当前的结点不存在 c 儿子了:
- 如果其父节点的失配指针存在 c 儿子,则可以“直接转移失配指针”,fail[x]=ch[fail[fa]][c];
- 否则,就去看 fail[fail[fa]] 有没有 c 儿子,如果没有再看 fal[fail[fail[fa]]]……直到找到 c 儿子;
- 如果一直找不到,则 fail[u]=0,也就是指向 Root。
因此,用一个 BFS 就可以方便地构造出失配指针了。
但是这样找 fail 指针可能会很慢,因为要一直 fail fail 地找,十分暴力,有没有优化方法呢?
答案是肯定的。可以方便地通过 Trie 图实现。
还可以更快
仔细思考,其实我们可以当 x 没有儿子时直接把 x 与 fail[x] 连边,反正都是要走过去的。然后就可以将所有空儿子都连上后继结点。这样对于每次寻找失配指针,我们都可以 $O(1)$ 找到,不需要每次都 fail 去找了。
代码如下:
inline void BuildFail(){ // Build Mismatch Pointer
for (int i=0;i<26;i++) if (c[0][i]) que.push(c[0][i]);
while (!que.empty()){
int x=que.front();que.pop();
for (int i=0;i<26;i++) if (c[x][i]){
fail[c[x][i]]=c[fail[x]][i];
que.push(c[x][i]);
} else c[x][i]=c[fail[x]][i];
}
}
查询的实现
查询其实很方便了,我们只需要按照文本串在字典树上走,把走过的点的 value 标记为 -1。因为任意点只会走到一次。没走到一个点就一直 fail、累计即可。
判断一个数 x 是否为 -1 的高端写法是:如果 ~ x==false,则 x 为 -1。因为 -1 是 32 位整数里每位都为 1,按位取反就是 0 了。
代码如下:
inline int query(char s[]){
int len=strlen(s),ret=0,x=0;
for (int i=0;i<len;i++){
int now=s[i]-'a';
x=c[x][now];
for (int t=x;t&&~val[t];t=fail[t]) ret+=val[t],val[t]=-1;
}
return ret;
}
拓展应用
通过上面的 Trie 图可以发现,AC 自动机几乎把字符串题目转化为了图论题目。这样我们就可以结合树形 DP、结合 DFS 序等等等等衍生出一大堆用途。这也是 AC 自动机的意义之一了。
(未完待续……)
模板题
当然可以先去做洛谷上的模板题了~
这也是道模板题:HDU 2222 Keywords Search
代码
HDU 2222
// Aho–Corasick Algorithm
#include<cstdio>
#include<cstring>
#include<iostream>
#include<queue>
using namespace std;
const int maxn=1000005,maxlen=55,maxlen_q=1000005;
int T,n,cnt=0;
int c[maxn][30],fail[maxn],val[maxn];
char s[maxlen],q[maxlen_q];
queue<int> que;
inline int read(){
int ret=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9') {if (ch=='-') f=-1;ch=getchar();}
while (ch>='0'&&ch<='9') ret=ret*10+ch-'0',ch=getchar();
return ret*f;
}
inline void init(){
cnt=0;
memset(c,0,sizeof(c));
memset(fail,0,sizeof(fail));
memset(val,0,sizeof(val));
}
inline void insert(char s[]){ // Insert a word to Trie Tree
int len=strlen(s),x=0;
for (int i=0;i<len;i++){
int now=s[i]-'a';
if (!c[x][now]) c[x][now]=++cnt;
x=c[x][now];
}
val[x]++;
}
inline void BuildFail(){ // Build Mismatch Pointer
for (int i=0;i<26;i++) if (c[0][i]) que.push(c[0][i]);
while (!que.empty()){
int x=que.front();que.pop();
for (int i=0;i<26;i++) if (c[x][i]){
fail[c[x][i]]=c[fail[x]][i];
que.push(c[x][i]);
} else c[x][i]=c[fail[x]][i];
}
}
inline int query(char s[]){
int len=strlen(s),ret=0,x=0;
for (int i=0;i<len;i++){
int now=s[i]-'a';
x=c[x][now];
for (int t=x;t&&~val[t];t=fail[t]) ret+=val[t],val[t]=-1;
}
return ret;
}
int main(){
T=read();
while (T--){
n=read();init();
for (int i=1;i<=n;i++) scanf("%s",s),insert(s);
BuildFail();
scanf("%s",q);
printf("%d\n",query(q));
}
return 0;
}