描述
给定N个字符串S1,S2…SN,接下来进行M次询问,每次询问给定一个字符串T,求S1~SN中有多少个字符串是T的前缀。输入字符串的总长度不超过10^6,仅包含小写字母。
输入格式
第一行两个整数N,M。接下来N行每行一个字符串Si。接下来M行每行一个字符串表示询问。
输出格式
对于每个询问,输出一个整数表示答案
样例输入
3 2 ab bc abc abc efg
样例输出
2 0
题解:把N个字符串直接插入Trie中,为了防止重复字符串导致错误,末尾的标记要改成计数,这样每次查询一个字符串T,就只需要正常查询,每查询一个字符就进行一次累加。
//#include<bits/stdc++.h> #include<algorithm> #include <iostream> #include <cstdlib> #include <cstring> #include <cassert> #include <cstdio> #include <vector> #include <string> #include <cmath> #include <queue> #include <stack> #include <set> #include <map> using namespace std; #define P(a,b,c) make_pair(a,make_pair(b,c)) #define rep(i,a,n) for (int i=a;i<=n;i++) #define per(i,a,n) for (int i=n;i>=a;i--) #define CLR(vis) memset(vis,0,sizeof(vis)) #define MST(vis,pos) memset(vis,pos,sizeof(vis)) #define pb push_back #define mp make_pair #define all(x) (x).begin(),(x).end() #define fi first #define se second #define SZ(x) ((int)(x).size()) typedef pair<int,pair<int,int> >pii; typedef long long ll; typedef unsigned long long ull; const ll mod = 1000000007; const int INF = 0x3f3f3f3f; ll gcd(ll a, ll b) { return b ? gcd(b, a%b) : a; } const int maxn = 1e6+10; int trie[maxn][26],tot=1; int over[maxn]; char s[maxn]; void insert(char *str){ int len = strlen(str), p=1; for(int k=0;k<len;k++){ int ch=str[k]-'a'; if(trie[p][ch]==0)trie[p][ch]=++tot; p=trie[p][ch]; } over[p]++; } int search(char *str){ int len=strlen(str),p=1,cnt=0; for(int i=0;i<len;i++){ int ch=str[i]-'a'; if(trie[p][ch]==0)return cnt; p=trie[p][ch]; cnt+=over[p]; } return cnt; } int main(){ CLR(over); int n,m; scanf("%d%d", &n,&m); while(n--){ scanf("%s", s); insert(s); } while(m--){ scanf("%s", s); cout<<search(s)<<endl; } return 0; }