Skip to content

题解-P2292

Posted on:2023年7月16日 at 11:10

本文章遵守知识共享协议 CC-BY-NC-SA,同步发表于洛谷题解区,转载时须在文章的任一位置附上原文链接和作者署名(rickyxrc)。推荐在我的个人博客阅读。

题面大意

给你 nn 个单词和 mm 个字符串,求对于每个字符串,对于每个字符串 SS,找出一种拆分方法,使得 SS 的某个前缀可以恰好被拆分成任意个这些单词拼接起来的结果,你需要输出这个最长前缀的长度。

解题思路

一看题,不难想到一个 naive 的思路:建立 AC 自动机,在 AC 自动机上对于所有 fail 指针的子串转移,最后取最大值得到答案。

主要代码如下(若不熟悉代码中的类型定义可以跳到末尾的完整代码):

void query(char *s)
{
    int u = 1, len = strlen(s), l = 0;
    for (int i = 0; i < len; i++)
    {
        int v = s[i] - 'a';
        int k = trie[u].son[v];
        while (k > 1)
        {
            if (trie[k].flag && (dp[i - trie[k].len] || i - trie[k].len == -1))
                dp[i] = dp[i - trie[k].len] + trie[k].len;
            k = trie[k].fail;
        }
        u = trie[u].son[v];
    }
}

主函数里取 max 即可。

for (int i = 0, e = strlen(T); i < e; i++)
    mx = std::max(mx, dp[i]);

但是这样的思路复杂度不是线性(因为要跳每个节点的 fail),会被 subtask#2 卡到 T,所以我们需要一个优化的思路。

我们再看看题目的特殊性质,我们发现所有单词的长度只有 2020,所以可以想到状态压缩优化。

具体怎么优化呢?我们发现,目前的时间瓶颈主要在跳 fail 这一步,如果我们可以将这一步优化到 O(1)O(1),就可以保证整个问题在严格线性的时间内被解出。

那我们就将前 2020 位字母中,可能的子串长度存下来,并压缩到状态中,存在每个子节点中。

那么我们在 buildfail 的时候就可以这么写:

void getfail(void)
{
    for (int i = 0; i < 26; i++)
        trie[0].son[i] = 1;
    q.push(1);
    trie[1].fail = 0;
    while (!q.empty())
    {
        int u = q.front();
        q.pop();
        int Fail = trie[u].fail;
        // 对状态的更新在这里
        trie[u].stat = trie[Fail].stat;
        if (trie[u].flag)
            trie[u].stat |= 1 << trie[u].depth;
        for (int i = 0; i < 26; i++)
        {
            int v = trie[u].son[i];
            if (!v)
                trie[u].son[i] = trie[Fail].son[i];
            else
            {
                trie[v].depth = trie[u].depth + 1;
                trie[v].fail = trie[Fail].son[i];
                q.push(v);
            }
        }
    }
}

然后查询时就可以去掉跳 fail 的循环,将代码简化如下:

int query(char *s)
{
    int u = 1, len = strlen(s), mx = 0;
    unsigned st = 1;
    for (int i = 0; i < len; i++)
    {
        int v = s[i] - 'a';
        u = trie[u].son[v];
        // 因为往下跳了一位每一位的长度都+1
        st <<= 1;
        // 这里的 & 十分妙,下文会讲到
        if (trie[u].stat & st)
            st |= 1,
            mx = i + 1;
    }
    return mx;
}

下面来解答一下 & 在上文代码中的作用。

我们的 trie[u].stat 维护的是从 u 节点开始,整条 fail 链上的长度集(因为长度集小于 3232 所以不影响),而 st 则维护的是查询字符串走到现在,前 3232 位(因为状态压缩自然溢出)的长度集。

& 值不为 00,则代表两个长度集的交集非空,我们此时就找到了一个匹配。

代码如下:

#include <stdio.h>
#include <string.h>
#include <queue>

#define maxn 3000001
char T[maxn];
int n, cnt, vis[maxn], ans, m, dp[maxn];
struct trie_node
{
    int son[26];
    int fail, flag, depth;
    unsigned stat;
    void init();
} trie[maxn];

std::queue<int> q;

void init();
void insert(char *s, int num);

void getfail(void)
{
    for (int i = 0; i < 26; i++)
        trie[0].son[i] = 1;
    q.push(1);
    trie[1].fail = 0;
    while (!q.empty())
    {
        int u = q.front();
        q.pop();
        int Fail = trie[u].fail;
        trie[u].stat = trie[Fail].stat;
        if (trie[u].flag)
            trie[u].stat |= 1 << trie[u].depth;
        for (int i = 0; i < 26; i++)
        {
            int v = trie[u].son[i];
            if (!v)
                trie[u].son[i] = trie[Fail].son[i];
            else
            {
                trie[v].depth = trie[u].depth + 1;
                trie[v].fail = trie[Fail].son[i];
                q.push(v);
            }
        }
    }
}

int query(char *s)
{
    int u = 1, len = strlen(s), mx = 0;
    unsigned st = 1;
    for (int i = 0; i < len; i++)
    {
        int v = s[i] - 'a';
        u = trie[u].son[v];
        st <<= 1;
        if (trie[u].stat & st)
            st |= 1,
                mx = i + 1;
    }
    return mx;
}

int main()
{
    scanf("%d%d", &n, &m);
    init();
    for (int i = 1; i <= n; i++)
    {
        scanf("%s", T);
        insert(T, i);
    }
    getfail();
    for (int i = 1; i <= m; i++)
    {
        scanf("%s", T);
        printf("%d\n", query(T));
    }
}


在 Rickyxrc's blog 出现的文章,若无特殊注明,均采用 CC BY-NC-SA 4.0 协议共享,也就是转载时需要注明本文章的地址,并且引用本文章的文章也要使用相同的方式共享。