本文章遵守知识共享协议 CC-BY-NC-SA,同步发表于洛谷题解区,转载时须在文章的任一位置附上原文链接和作者署名(rickyxrc)。推荐在我的个人博客阅读。
题面大意
给你 个单词和 个字符串,求对于每个字符串,对于每个字符串 ,找出一种拆分方法,使得 的某个前缀可以恰好被拆分成任意个这些单词拼接起来的结果,你需要输出这个最长前缀的长度。
解题思路
一看题,不难想到一个 naive 的思路:建立 AC 自动机,在 AC 自动机上对于所有 fail 指针的子串转移,最后取最大值得到答案。
主要代码如下(若不熟悉代码中的类型定义可以跳到末尾的完整代码):
void query(char *s)
{
int u = 1, len = strlen(s), l = 0;
for (int i = 0; i < len; i++)
{
int v = s[i] - 'a';
int k = trie[u].son[v];
while (k > 1)
{
if (trie[k].flag && (dp[i - trie[k].len] || i - trie[k].len == -1))
dp[i] = dp[i - trie[k].len] + trie[k].len;
k = trie[k].fail;
}
u = trie[u].son[v];
}
}
主函数里取 max 即可。
for (int i = 0, e = strlen(T); i < e; i++)
mx = std::max(mx, dp[i]);
但是这样的思路复杂度不是线性(因为要跳每个节点的 fail),会被 subtask#2 卡到 T,所以我们需要一个优化的思路。
我们再看看题目的特殊性质,我们发现所有单词的长度只有 ,所以可以想到状态压缩优化。
具体怎么优化呢?我们发现,目前的时间瓶颈主要在跳 fail 这一步,如果我们可以将这一步优化到 ,就可以保证整个问题在严格线性的时间内被解出。
那我们就将前 位字母中,可能的子串长度存下来,并压缩到状态中,存在每个子节点中。
那么我们在 buildfail 的时候就可以这么写:
void getfail(void)
{
for (int i = 0; i < 26; i++)
trie[0].son[i] = 1;
q.push(1);
trie[1].fail = 0;
while (!q.empty())
{
int u = q.front();
q.pop();
int Fail = trie[u].fail;
// 对状态的更新在这里
trie[u].stat = trie[Fail].stat;
if (trie[u].flag)
trie[u].stat |= 1 << trie[u].depth;
for (int i = 0; i < 26; i++)
{
int v = trie[u].son[i];
if (!v)
trie[u].son[i] = trie[Fail].son[i];
else
{
trie[v].depth = trie[u].depth + 1;
trie[v].fail = trie[Fail].son[i];
q.push(v);
}
}
}
}
然后查询时就可以去掉跳 fail 的循环,将代码简化如下:
int query(char *s)
{
int u = 1, len = strlen(s), mx = 0;
unsigned st = 1;
for (int i = 0; i < len; i++)
{
int v = s[i] - 'a';
u = trie[u].son[v];
// 因为往下跳了一位每一位的长度都+1
st <<= 1;
// 这里的 & 十分妙,下文会讲到
if (trie[u].stat & st)
st |= 1,
mx = i + 1;
}
return mx;
}
下面来解答一下 &
在上文代码中的作用。
我们的 trie[u].stat
维护的是从 u 节点开始,整条 fail 链上的长度集(因为长度集小于 所以不影响),而 st
则维护的是查询字符串走到现在,前 位(因为状态压缩自然溢出)的长度集。
&
值不为 ,则代表两个长度集的交集非空,我们此时就找到了一个匹配。
代码如下:
#include <stdio.h>
#include <string.h>
#include <queue>
#define maxn 3000001
char T[maxn];
int n, cnt, vis[maxn], ans, m, dp[maxn];
struct trie_node
{
int son[26];
int fail, flag, depth;
unsigned stat;
void init();
} trie[maxn];
std::queue<int> q;
void init();
void insert(char *s, int num);
void getfail(void)
{
for (int i = 0; i < 26; i++)
trie[0].son[i] = 1;
q.push(1);
trie[1].fail = 0;
while (!q.empty())
{
int u = q.front();
q.pop();
int Fail = trie[u].fail;
trie[u].stat = trie[Fail].stat;
if (trie[u].flag)
trie[u].stat |= 1 << trie[u].depth;
for (int i = 0; i < 26; i++)
{
int v = trie[u].son[i];
if (!v)
trie[u].son[i] = trie[Fail].son[i];
else
{
trie[v].depth = trie[u].depth + 1;
trie[v].fail = trie[Fail].son[i];
q.push(v);
}
}
}
}
int query(char *s)
{
int u = 1, len = strlen(s), mx = 0;
unsigned st = 1;
for (int i = 0; i < len; i++)
{
int v = s[i] - 'a';
u = trie[u].son[v];
st <<= 1;
if (trie[u].stat & st)
st |= 1,
mx = i + 1;
}
return mx;
}
int main()
{
scanf("%d%d", &n, &m);
init();
for (int i = 1; i <= n; i++)
{
scanf("%s", T);
insert(T, i);
}
getfail();
for (int i = 1; i <= m; i++)
{
scanf("%s", T);
printf("%d\n", query(T));
}
}