ac自动机通常用来解决多字符串匹配问题,如从字符串s找字符串t[i](i<=n),如果直接用KMP那么时间复杂度为,而用ac自动机时间复杂度为。
ac自动机可以认为是kmp和trie的结合,因为ac自动机就是在trie的基础上怎加了fail变量,fail指向的是当前字符串的最长后缀的尾节点,
作用就是在当前匹配失败时,将当前指针转向fail指向的位置,继续匹配,这样就避免了重复匹配,类似于kmp的next数组的作用。
代码:
1 #include<iostream> 2 #include<cstdio> 3 #include<cstring> 4 #include<queue> 5 #include<algorithm> 6 using namespace std; 7 const int N=1e6+5; 8 const int MAX_Tot=5e5+5; 9 10 //建议还是用数组写 11 struct node{ 12 int next[26]; 13 int fail,cnt; 14 }trie[N]; 15 16 int idx; 17 char s[N]; 18 19 void init(){ 20 for(int i=0;i<MAX_Tot;i++){ 21 memset(trie[i].next,0,sizeof(trie[i].next)); 22 trie[i].fail=trie[i].cnt=0; 23 } 24 idx=0; 25 } 26 27 void Insert(char *s){ 28 int n=strlen(s); 29 int now=0; 30 for(int i=0;i<n;i++){ 31 char ch=s[i]; 32 if(!trie[now].next[ch-‘a‘]) 33 trie[now].next[ch-‘a‘]=++idx; 34 now=trie[now].next[ch-‘a‘]; 35 } 36 trie[now].cnt++; //以now节点为结尾的字符串数目+1 37 } 38 39 void getfail(){ 40 trie[0].fail=-1; 41 queue<int>q; 42 q.push(0); 43 44 while(!q.empty()){ 45 int u=q.front(); 46 q.pop(); 47 for(int i=0;i<26;i++){ 48 if(trie[u].next[i]){ 49 if(u==0) trie[trie[u].next[i]].fail=0; 50 else{ 51 int v=trie[u].fail; 52 while(v!=-1){ 53 if(trie[v].next[i]){ 54 trie[trie[u].next[i]].fail=trie[v].next[i]; 55 break; 56 } 57 v=trie[v].fail; 58 } 59 if(v==-1) trie[trie[u].next[i]].fail=0; 60 } 61 q.push(trie[u].next[i]); 62 } 63 else trie[u][i]=trie[fail[u]][i]; //这句按定义可以忽略,但是还是加上,有的题目必须要这句话 64 } 65 } 66 } 67 68 int get(int u){ 69 int res=0; 70 while(u){ 71 res+=trie[u].cnt; 72 trie[u].cnt=0; 73 u=trie[u].fail; 74 } 75 return res; 76 } 77 78 int match(char *s){ 79 int ans=0,now=0; 80 int n=strlen(s); 81 for(int i=0;i<n;i++){ 82 int ch=s[i]-‘a‘; 83 if(trie[now].next[ch]) 84 now=trie[now].next[ch]; 85 else{ 86 int p=trie[now].fail; 87 while(p!=-1&&trie[p].next[ch]==0) p=trie[p].fail; 88 if(p==-1) now=0; 89 else now=trie[p].next[ch]; 90 } 91 if(trie[now].cnt) 92 ans+=get(now); 93 } 94 return ans; 95 } 96 97 int main(){ 98 int t; 99 scanf("%d",&t); 100 while(t--){ 101 init(); 102 int n; 103 scanf("%d",&n); 104 for(int i=0;i<n;i++){ 105 scanf("%s",s); 106 Insert(s); 107 } 108 scanf("%s",s); 109 getfail(); 110 printf("%d\n",match(s)); 111 } 112 return 0; 113 }