trie树
trie树又称前缀树,是一种有序树,常用于检索字符串、AC自动机、维护异或极值、维护异或和、01-trie树、可持久化字典树等等
对acmer来说是个比较常见的东西,特别是涉及到前缀之类的字符串题
它的核心思想是让拥有相同前缀的字符串共享同一条路径,从而达到快速检索的目的
代码实现
trie的思想很简单,但如果你没有见过他的代码你可能会觉得无从下手,所以这里我主要讲一下trie树的实现思想
其实和链式前向星建图差不多:tot用来给新节点分配编号,tr[root][id]里存的是下一个节点的编号,靠它进行跳转,id则是当前字符对应的下标
int tot;          // number of nodes allocated so far; also the index of the newest node
int id;           // child-slot index of the character currently being processed
int root;         // cursor: node currently being visited (node 0 is the trie root)
int tr[MAX][30];  // tr[node][c]: index of the child of node along character c, 0 if absent
bool vis[MAX];    // vis[node]: true when an inserted string ends at node
//建树
void insert(string s){
root = 0;//0节点不放东西的
for(int i = 0; i < s.size(); ++i){
id = s[i] - 'a';
if(!tr[root][id])tr[root][id] = ++tot;//如果当前位置的id位置没有值,说明没有字符串能到目前这个位置,那我们就给他赋个值就行
root = tr[root][id];//root进行跳转,跳到当前的位置来
}
vis[root] = 1;//就是在结尾标记一下,表示root的位置有字符串
}
//查询,这个函数是可以根据需要来改的,比较灵活,这里写的是判读串t是否出现过
bool find(string t){
root = 0;
for(int i = 0; i < t.size(); ++i){
id = t[i] - 'a';
if(!tr[root][id])return false;//如果当前位置没来过,说明之前并不存在t,返回就可以
root = tr[root][id];//跳转
}
if(vis[root])return true;//如果root出有字符串就返回出现过
else return false;//否则就没有出现过
}
如果不会就多做几个题,写多了就懂了
板子
放个判是否是前缀的板子
string s;         // most recently read input string
int root;         // cursor: node currently being visited (node 0 is the trie root)
int tot;          // number of trie nodes allocated so far
int id;           // child-slot index of the current character
int num;          // counter; not referenced by insert/judge below
int cnt;          // counter; not referenced by insert/judge below
int tr[MAX][5];   // tr[node][c]: child of node along digit c, 0 if absent
bool vis[MAX];    // vis[node]: true when an inserted string ends at node
string t[MAX];    // storage for input strings
void insert(string s){
root = 0;
int len = s.size();
for(int i = 0; i < len; ++i){
id = s[i] - '0';
if(!tr[root][id])tr[root][id] = ++tot;
root = tr[root][id];
}
vis[root] = 1;
}
// Reports whether some string stored in the trie is a proper prefix of s.
// Only the first size-1 characters are walked, so s itself never counts as its own prefix.
// Returns true when such a stored prefix exists, false otherwise (including when s is
// empty or its path leaves the trie).
bool judge(string s){
    root = 0;
    int len = s.size();
    for(int i = 0; i < len - 1; ++i){
        id = s[i] - '0';
        root = tr[root][id];
        // Child index 0 means the edge does not exist. Stop here: continuing would restart
        // the walk from the root and could report a match that is not a prefix of s.
        if(!root)return false;
        if(vis[root])return true;
    }
    return false;
}
入门例题
单词数
题目描述:
看题目名字就知道,这个题就是统计单词数量
思路:
trie板子题
多组输入,记得初始化
#include<map>
#include<set>
#include<stack>
#include<queue>
#include<cmath>
#include<cstdio>
#include<string>
#include<vector>
#include<sstream>
#include<cstring>
#include<stdlib.h>
#include<iostream>
#include<algorithm>
using namespace std;
#define eps 1.0E-8
#define endl '\n'
#define inf 0x3f3f3f3f
#define MAX 1000000 + 50
//#define mod 1000000007
#define lowbit(x) (x & (-x))
#define sd(n) scanf("%d",&n)
#define sdd(n,m) scanf("%d %d",&n,&m)
#define pd(n) printf("%d\n", (n))
#define pdd(n,m) printf("%d %d\n",n, m)
#define sddd(n,m,z) scanf("%lld %lld %lld",&n,&m,&z)
#define io ios::sync_with_stdio(false); cin.tie(0); cout.tie(0)
#define mem(a,b) memset((a),(b),sizeof(a))
#define max(a,b) (((a)>(b)) ? (a):(b))
#define min(a,b) (((a)>(b)) ? (b):(a))
typedef long long ll ;
typedef unsigned long long ull;
//不开longlong见祖宗!不看范围见祖宗!
// Reads the next decimal integer from stdin using getchar().
// Skips every non-digit character first; a '-' seen while skipping makes the result negative.
// Returns 0 if end of input is reached before any digit is found.
inline int IntRead(){
    // getchar() returns int so that EOF can be told apart from real characters.
    int ch = getchar();
    int sign = 1;
    // Stop skipping at EOF; otherwise the loop would never terminate on exhausted input.
    while(ch != EOF && (ch < '0' || ch > '9')){
        if(ch == '-') sign = -1;
        ch = getchar();
    }
    int value = 0;
    while(ch >= '0' && ch <= '9'){
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}
int id;            // child-slot index of the current character
int tot;           // number of trie nodes allocated so far
int root;          // cursor: node currently being visited (node 0 is the trie root)
int ans;           // number of distinct words counted on the current line
string s;          // current input line
string t;          // current word extracted from the line
stringstream ss;   // splits the current line into words
int tr[MAX][30];   // tr[node][c]: child of node along letter c, 0 if absent
bool vis[MAX];     // vis[node]: true when an inserted word ends at node
void insert(string t){
root = 0;
for(int i = 0; i < t.size(); ++i){
id = t[i] - 'a';
if(!tr[root][id])tr[root][id] = ++tot;
root = tr[root][id];
}
vis[root] = 1;
}
bool judge(string t){
root = 0;
for(int i = 0; i < t.size(); ++i){
id = t[i] - 'a';
if(!tr[root][id])return false;
root = tr[root][id];
}
if(vis[root])return true;
else return false;
}
// Resets the trie between input lines. Only nodes 0..tot were touched, so only those
// rows are cleared (columns 0..26 of each), then the counters are zeroed.
void init(){
    for(int node = tot; node >= 0; --node){
        vis[node] = false;
        fill(tr[node], tr[node] + 27, 0);
    }
    ans = 0;
    tot = 0;
}
// Driver: for every input line up to one starting with '#', prints the number of
// distinct whitespace-separated words on that line.
int main(){
    for(;;){
        if(!getline(cin, s)) break;
        if(s[0] == '#') break;
        // Reset the stream's state flags, then feed it the new line.
        ss.clear();
        ss << s;
        for(; ss >> t; ){
            if(judge(t)) continue; // word already seen on this line
            ++ans;
            insert(t);
        }
        cout<<ans<<endl;
        init(); // multiple test lines: wipe the trie and the counter
    }
    return 0;
}
统计难题
题目描述:
给一批单词,再给另一批单词,对第二批的每个单词,问第一批中有多少个是以他为前缀的
思路:
板子题
输入第一批的时候建树,维护一个sum数组
输入第二批的时候就挨个查询即可
#include<map>
#include<set>
#include<stack>
#include<queue>
#include<cmath>
#include<cstdio>
#include<string>
#include<vector>
#include<sstream>
#include<cstring>
#include<stdlib.h>
#include<iostream>
#include<algorithm>
using namespace std;
#define eps 1.0E-8
#define endl '\n'
#define inf 0x3f3f3f3f
#define MAX 2000000 + 50
#define mod 1000000007
#define lowbit(x) (x & (-x))
#define sd(n) scanf("%d",&n)
#define sdd(n,m) scanf("%d %d",&n,&m)
#define pd(n) printf("%d\n", (n))
#define pdd(n,m) printf("%d %d\n",n, m)
#define sddd(n,m,z) scanf("%d %d %d",&n,&m,&z)
#define io ios::sync_with_stdio(false); cin.tie(0); cout.tie(0)
#define mem(a,b) memset((a),(b),sizeof(a))
#define max(a,b) (((a)>(b)) ? (a):(b))
//#define min(a,b) (((a)>(b)) ? (b):(a))
typedef long long ll ;
typedef unsigned long long ull;
//不开longlong见祖宗!不看范围见祖宗!
// Reads the next decimal integer from stdin using getchar().
// Skips every non-digit character first; a '-' seen while skipping makes the result negative.
// Returns 0 if end of input is reached before any digit is found.
inline int IntRead(){
    // getchar() returns int so that EOF can be told apart from real characters.
    int ch = getchar();
    int sign = 1;
    // Stop skipping at EOF; otherwise the loop would never terminate on exhausted input.
    while(ch != EOF && (ch < '0' || ch > '9')){
        if(ch == '-') sign = -1;
        ch = getchar();
    }
    int value = 0;
    while(ch >= '0' && ch <= '9'){
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}
int cnt;           // counter; not referenced by insert/getans below
int tot;           // number of trie nodes allocated so far
int root;          // cursor: node currently being visited (node 0 is the trie root)
int id;            // child-slot index of the current character
int tr[MAX][30];   // tr[node][c]: child of node along letter c, 0 if absent
char s[20];        // fixed-size character buffer for one input line
int sum[MAX];      // sum[node]: number of inserted words whose path passes through node
void insert(string s){
root = 0;
for(int i = 0; i < s.size(); ++i){
id = s[i] - 'a';
if(!tr[root][id])tr[root][id] = ++tot;
root = tr[root][id];
++sum[root];//更新前缀的数量
}
}
int getans(string s){
root = 0;
for(int i = 0; i < s.size(); ++i){
id = s[i] - 'a';
if(!tr[root][id])return 0;
root = tr[root][id];
}
return sum[root];
}
// Driver: reads dictionary words (one per line) until the first empty line, inserting each
// into the trie; then, for every remaining line, prints how many dictionary words start
// with that line's text.
int main(){
    // Lines are read into a std::string with getline: gets() was removed from the language
    // in C++14 and cannot bound how many characters it writes into a fixed buffer.
    string line;
    while (getline(cin, line)) {
        if(line.empty())break; // blank line separates the dictionary from the queries
        insert(line);
    }
    while (getline(cin, line)) {
        cout<<getans(line)<<endl;
    }
    return 0;
}
Phone List
题目描述:
给一堆电话号码,问有没有某个电话号码是别的电话号码的前缀
思路
板子题,输入的时候建树,并存下来,然后等输入完了再从头开始扫,对第i个串判断有没有他的前缀存在
记得初始化
#include<map>
#include<set>
#include<stack>
#include<queue>
#include<cmath>
#include<cstdio>
#include<string>
#include<vector>
#include<sstream>
#include<cstring>
#include<stdlib.h>
#include<iostream>
#include<algorithm>
using namespace std;
#define eps 1.0E-8
#define endl '\n'
#define inf 0x3f3f3f3f
#define MAX 100000 + 50
#define mod 1000000007
#define lowbit(x) (x & (-x))
#define sd(n) scanf("%d",&n)
#define sdd(n,m) scanf("%d %d",&n,&m)
#define pd(n) printf("%d\n", (n))
#define pdd(n,m) printf("%d %d\n",n, m)
#define sddd(n,m,z) scanf("%d %d %d",&n,&m,&z)
#define io ios::sync_with_stdio(false); cin.tie(0); cout.tie(0)
#define mem(a,b) memset((a),(b),sizeof(a))
#define max(a,b) (((a)>(b)) ? (a):(b))
//#define min(a,b) (((a)>(b)) ? (b):(a))
typedef long long ll ;
typedef unsigned long long ull;
//不开longlong见祖宗!不看范围见祖宗!
// Reads the next decimal integer from stdin using getchar().
// Skips every non-digit character first; a '-' seen while skipping makes the result negative.
// Returns 0 if end of input is reached before any digit is found.
inline int IntRead(){
    // getchar() returns int so that EOF can be told apart from real characters.
    int ch = getchar();
    int sign = 1;
    // Stop skipping at EOF; otherwise the loop would never terminate on exhausted input.
    while(ch != EOF && (ch < '0' || ch > '9')){
        if(ch == '-') sign = -1;
        ch = getchar();
    }
    int value = 0;
    while(ch >= '0' && ch <= '9'){
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}
int t;             // number of test cases still to process
int n;             // number of phone numbers in the current test case
string s[MAX];     // phone numbers of the current test case, 1-based
bool ans;          // true once some number is found to have a stored proper prefix
int tot;           // number of trie nodes allocated so far
int root;          // cursor: node currently being visited (node 0 is the trie root)
int id;            // child-slot index of the current digit
int tr[MAX][15];   // tr[node][d]: child of node along digit d, 0 if absent
bool vis[MAX];     // vis[node]: true when an inserted number ends at node
void insert(string s){
root = 0;
for(int i = 0; i < s.size(); ++i){
id = s[i] - '0';
if(!tr[root][id])tr[root][id] = ++tot;
root = tr[root][id];
}
vis[root] = 1;
}
// Checks s against the numbers stored in the trie.
// Returns false when some stored number is a proper prefix of s, true otherwise.
// The last character of s is not walked, so s never conflicts with itself.
bool judge(string s){
    root = 0;
    // `i + 1 < size` instead of `i < size - 1`: size() is unsigned, so the subtraction
    // would wrap around for an empty string and index far past the end.
    for(string::size_type i = 0; i + 1 < s.size(); ++i){
        id = s[i] - '0';
        root = tr[root][id];
        // Child index 0 means the edge is absent. Nothing stored can extend along s past
        // this point, and continuing would restart matching from the root.
        if(!root)return true;
        if(vis[root])return false;
    }
    return true;
}
// Driver: for each test case, builds a trie from all phone numbers, then prints "NO" if
// any number has another stored number as a proper prefix and "YES" otherwise.
int main(){
    io;
    cin>>t;
    for(; t--; ){
        // Fresh trie for every test case.
        mem(tr, 0);
        mem(vis, 0);
        tot = 0;
        ans = false;
        cin>>n;
        for(int i = 1; i <= n; ++i){
            cin>>s[i];
            insert(s[i]);
        }
        // Stop at the first number that has a stored proper prefix.
        for(int i = 1; i <= n && !ans; ++i){
            ans = !judge(s[i]);
        }
        cout<<(ans ? "NO\n" : "YES\n");
    }
    return 0;
}
Immediate Decodability
题目描述&思路:
和上一个题一样,双倍经验get
#include<map>
#include<set>
#include<stack>
#include<queue>
#include<cmath>
#include<cstdio>
#include<string>
#include<vector>
#include<sstream>
#include<cstring>
#include<stdlib.h>
#include<iostream>
#include<algorithm>
using namespace std;
#define eps 1.0E-8
#define endl '\n'
#define inf 0x3f3f3f3f
#define MAX 100000 + 50
#define mod 1000000007
#define lowbit(x) (x & (-x))
#define sd(n) scanf("%d",&n)
#define sdd(n,m) scanf("%d %d",&n,&m)
#define pd(n) printf("%d\n", (n))
#define pdd(n,m) printf("%d %d\n",n, m)
#define sddd(n,m,z) scanf("%d %d %d",&n,&m,&z)
#define io ios::sync_with_stdio(false); cin.tie(0); cout.tie(0)
#define mem(a,b) memset((a),(b),sizeof(a))
#define max(a,b) (((a)>(b)) ? (a):(b))
//#define min(a,b) (((a)>(b)) ? (b):(a))
typedef long long ll ;
typedef unsigned long long ull;
//不开longlong见祖宗!不看范围见祖宗!
// Reads the next decimal integer from stdin using getchar().
// Skips every non-digit character first; a '-' seen while skipping makes the result negative.
// Returns 0 if end of input is reached before any digit is found.
inline int IntRead(){
    // getchar() returns int so that EOF can be told apart from real characters.
    int ch = getchar();
    int sign = 1;
    // Stop skipping at EOF; otherwise the loop would never terminate on exhausted input.
    while(ch != EOF && (ch < '0' || ch > '9')){
        if(ch == '-') sign = -1;
        ch = getchar();
    }
    int value = 0;
    while(ch >= '0' && ch <= '9'){
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}
string s;         // most recently read token
int root;         // cursor: node currently being visited (node 0 is the trie root)
int tot;          // number of trie nodes allocated so far
int id;           // child-slot index of the current character
int num;          // number of codes collected for the current set
int cnt;          // index of the set being reported
int tr[MAX][5];   // tr[node][c]: child of node along character c, 0 if absent
bool vis[MAX];    // vis[node]: true when an inserted code ends at node
string t[MAX];    // codes of the current set, 1-based
void insert(string s){
root = 0;
for(int i = 0; i < s.size(); ++i){
id = s[i] - '0';
if(!tr[root][id])tr[root][id] = ++tot;
root = tr[root][id];
}
vis[root] = 1;
}
// Reports whether some code stored in the trie is a proper prefix of s.
// Returns true when such a prefix exists, false otherwise. The last character of s is not
// walked, so s never counts as its own prefix.
bool judge(string s){
    root = 0;
    // `i + 1 < size` instead of `i < size - 1`: size() is unsigned, so the subtraction
    // would wrap around for an empty string and index far past the end.
    for(string::size_type i = 0; i + 1 < s.size(); ++i){
        id = s[i] - '0';
        root = tr[root][id];
        // Child index 0 means the edge is absent; stop instead of restarting from the root,
        // which could report a match that is not a prefix of s.
        if(!root)return false;
        if(vis[root])return true;
    }
    return false;
}
// Driver: collects codes until a token starting with '9' closes the set, then reports
// whether any code of the set has another code as a proper prefix and resets the trie.
int main(){
    io;
    while (cin>>s) {
        if(s[0] != '9'){
            // Ordinary code: remember it and add it to the trie.
            t[++num] = s;
            insert(s);
            continue;
        }
        // Separator reached: evaluate the set that has just ended.
        ++cnt;
        bool ans = false;
        for(int i = 1; i <= num && !ans; ++i){
            ans = judge(t[i]);
        }
        if(!ans)cout<<"Set "<<cnt<<" is immediately decodable\n";
        else cout<<"Set "<<cnt<<" is not immediately decodable\n";
        // Start the next set with an empty trie.
        mem(vis, 0);
        mem(tr, 0);
        tot = 0;
        num = 0;
    }
    return 0;
}
Hat’s Words
题目描述:
给你一堆串,问你这些串中有多少个串可以由别的两个串拼接而成即C = A + B,按字典序输出这些串
思路:
这个题有点意思,最开始我想的是用trie树暴力,枚举每个串的分割位置然后去查询,但是有50000个单词,单词的长度还不知道,感觉必T
正解应该是维护两个trie树,一个插正的字符串,另一个将字符串反着插,用两个vis数组来标记前缀和后缀,查询的时候就维护一个sum数组,对两个树都查,如果当前位置存在字符串就给当前长度位置++,最后扫一遍sum数组,如果有哪个位置前后缀都标记过了,也就是sum=2,就说明这个单词可以由别的单词拼接而成
这个题我调了快三个小时,重写了两遍,用尽一切办法,终于发现是函数传参出了问题,不要传字符串数组进去!!!
#include<map>
#include<set>
#include<stack>
#include<queue>
#include<cmath>
#include<cstdio>
#include<string>
#include<vector>
#include<sstream>
#include<cstring>
#include<stdlib.h>
#include<iostream>
#include<algorithm>
using namespace std;
#define eps 1.0E-8
#define endl '\n'
#define inf 0x3f3f3f3f
#define MAX 200000 + 50
#define mod 1000000007
#define lowbit(x) (x & (-x))
#define sd(n) scanf("%d",&n)
#define sdd(n,m) scanf("%d %d",&n,&m)
#define pd(n) printf("%d\n", (n))
#define pdd(n,m) printf("%d %d\n",n, m)
#define sddd(n,m,z) scanf("%d %d %d",&n,&m,&z)
#define io ios::sync_with_stdio(false); cin.tie(0); cout.tie(0)
#define mem(a,b) memset((a),(b),sizeof(a))
#define max(a,b) (((a)>(b)) ? (a):(b))
//#define min(a,b) (((a)>(b)) ? (b):(a))
typedef long long ll ;
typedef unsigned long long ull;
//不开longlong见祖宗!不看范围见祖宗!
// Reads the next decimal integer from stdin using getchar().
// Skips every non-digit character first; a '-' seen while skipping makes the result negative.
// Returns 0 if end of input is reached before any digit is found.
inline int IntRead(){
    // getchar() returns int so that EOF can be told apart from real characters.
    int ch = getchar();
    int sign = 1;
    // Stop skipping at EOF; otherwise the loop would never terminate on exhausted input.
    while(ch != EOF && (ch < '0' || ch > '9')){
        if(ch == '-') sign = -1;
        ch = getchar();
    }
    int value = 0;
    while(ch >= '0' && ch <= '9'){
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}
int tot;                     // nodes allocated in the forward trie (tr)
int cnt;                     // nodes allocated in the reversed trie (tree)
int num;                     // index of the last slot of str[] written by main
int root;                    // cursor: node currently being visited (node 0 is the root)
int id;                      // child-slot index of the current letter
string str[MAX];             // input words, 1-based
bool vis1[MAX];              // vis1[node]: a word ends at node in the forward trie
bool vis2[MAX];              // vis2[node]: a reversed word ends at node in the reversed trie
int sum[MAX];                // per-position hit counter filled by update1/update2
int tr[MAX][30];             // forward trie: tr[node][c] is the child along letter c, 0 if absent
int tree[MAX][30];           // reversed trie, same layout as tr
set<string>se;               // result words; iterated in sorted order
set<string>::iterator it;    // iterator used to print se
inline void insert1(string s){
root = 0;
int len = (int)s.size();
for(int i = 0; i < len; ++i){
id = s[i] - 'a';
if(!tr[root][id])tr[root][id] = ++tot;
root = tr[root][id];
}
vis1[root] = 1;
}
// Walks s through the forward trie and, for every proper prefix s[0..i] that is itself a
// stored word, increments sum[i]. The last character is not walked, so the whole word is
// never counted as its own prefix.
inline void update1(string s){
    root = 0;
    int len = (int)s.size();
    for(int i = 0; i < len - 1; ++i){
        id = s[i] - 'a';
        root = tr[root][id];
        // Child index 0 means the edge is absent: no longer prefix can be stored, and
        // continuing would restart from the root and record bogus split points.
        if(!root)return;
        if(vis1[root]){
            ++sum[i];
        }
    }
}
inline void insert2(string s){
root = 0;
int len = (int)s.size();
for(int i = 0; i < len; ++i){
id = s[i] - 'a';
if(!tree[root][id])tree[root][id] = ++cnt;
root = tree[root][id];
}
vis2[root] = 1;
}
// Walks s (expected to be an already-reversed word) through the reversed trie. When the
// first i+1 characters of s form a stored reversed word, i.e. the original word has a
// stored suffix of length i+1, increments sum[len - i - 2], which is the index of the last
// character of the matching prefix part in the original word.
inline void update2(string s){
    root = 0;
    int len = (int)s.size();
    for(int i = 0; i < len - 1; ++i){
        id = s[i] - 'a';
        root = tree[root][id];
        // Child index 0 means the edge is absent: no longer suffix can be stored, and
        // continuing would restart from the root and record bogus split points.
        if(!root)return;
        if(vis2[root]){
            ++sum[len - i - 2];// mind the index: it is expressed in forward-word positions
        }
    }
}
// Returns true when s can be split into two stored words: some position i must be hit both
// by update1 (prefix ending at i) and by update2 (suffix starting at i + 1), giving sum[i] == 2.
bool judge(string s){
    const int len = (int)s.size();
    fill(sum, sum + len, 0); // clear only the positions this word uses
    update1(s);
    reverse(s.begin(), s.end());
    update2(s);
    bool splittable = false;
    for(int pos = 0; pos < len && !splittable; ++pos){
        splittable = (sum[pos] == 2);
    }
    return splittable;
}
int main(){
while (cin>>str[++num]) {
string s = str[num];
if(str[num] == "p")break;
insert1(s);
reverse(s.begin(), s.end());
insert2(s);
reverse(s.begin(), s.end());
}
for(int i = 1; i <= num; ++i){
if(judge(str[i]))se.insert(str[i]);
}
for(it = se.begin(); it != se.end(); ++it)cout<<*it<<endl;
return 0;
}