本文章遵守知识共享协议 CC-BY-NC-SA,同步发表于洛谷题解区,转载时须在文章的任一位置附上原文链接和作者署名(rickyxrc)。推荐在我的个人博客阅读。
这是我的第一篇 NTT 题解,思路和其它题解几乎一样,不过推式子的过程和讲解更加详细,希望可以成为新人学习 NTT 的有力基础,当然会有讲的不周到的地方,欢迎各位补充与指正。
阅读本篇题解,您需要掌握拉格朗日插值的基本知识并 AC 拉格朗日插值。
题面大意
给定一个不超过 n 次的多项式的 n+1 个点值 f(0),f(1)…f(n),和一个正整数 m,求 f(m),f(m+1)…f(m+n)。
解题思路
首先我们写出拉格朗日插值的式子:
li(x)=yij=i∏xi−xjx−xj
f(x)=i=0∑nli(x)
这个式子我不会证明,不过简单地来想想,将 xi 带入该式子,就会将除了 li(x) 之外的值全部消掉(等于 0),得到 yi,所以可以认为它是对的(当然这显然不是严谨的证明)。
把所有的 li(x) 带入 f(x),我们就得到了拉格朗日插值的另一种表达形式:
f(x)=i=0∑nyij=i∏xi−xjx−xj
我们现在需要求出的是 f(m),f(m+1),...,f(m+n),我们做如下推导:
f(m+x)=i=0∑nyij=i∏xi−xj(m+x)−xj
因为给入的 x 是连续的且是 [0,n],所以式子可以简化如下:
f(m+x)=i=0∑nf(i)j=i∏i−j(m+x)−j
这一段连乘可以变成阶乘的形式,但是我没一步推出来所以我们写详细一点。
我们先下放连乘符号,于是式子可以写成这样:
f(m+x)=i=0∑nf(i)∏j=i(i−j)∏j=i(m+x−j)
然后我们先看分子:
j=i∏(m+x−j)
=m+x−i(m+x−1)×(m+x−2)×(m+x−3)×⋯×(m+x−n)
明显可以简化成阶乘的式子:
=(m+x−i)×(m+x−n−1)!(m−x−1)!
然后看分母:
j=i∏(i−j)
这个式子可以化成:
j<i∏(i−j)j=i+1∏n(i−j)
=i!×j=i+1∏n−(j−i)
=i!×(−1)n−i×j=i+1∏n(j−i)
=i!×(−1)n−i×(n−i)!
所以整段分式就可以化成:
i!×(−1)n−i×(n−i)!(m+x−i)×(m+x−n−1)!(m−x−1)!
=(m+x−i)×(m+x−n−1)!×i!×(−1)n−i×(n−i)!(m−x−1)!
原式就变成了:
f(m+x)=i=0∑n(m+x−i)×(m+x−n−1)!×i!×(−1)n−i×(n−i)!f(i)×(m−x−1)!
提出公因式:
=(m+x−n−1)!(m−x−1)!i=0∑n(m+x−i)×i!×(−1)n−i×(n−i)!f(i)
因为我们需要 NTT 求解,所以我们选择将其凑成卷积的形式:
=(m+x−n−1)!(m−x−1)!i=0∑n(m+x−i)1×(n−i)!×i!×(−1)n−if(i)
然后答案就比较显然了,为两个数列卷积之后再略微处理的形式。
形式化地说,设 Ai=(n−i)!×i!×(−1)n−if(i),Bi=m−n+i1,则令 F=A∗B,(m+x−n−1)!Fi×(m−x−1)! 即为答案。
然后需要求逆元和阶乘逆元,最好能线性求。
因为阶乘可能很大,所以我们不能预处理阶乘逆元,而是预处理上升幂,参考代码中 jinv 和 mjcsinv 的部分。
代码如下:
#include <stdio.h>
#define maxn 1000007
typedef long long i64;
i64 g = 3, gi;
const i64 mod = 998244353;
i64 t[maxn], a[maxn], b[maxn], n, m, len = 1, l, r[maxn], jcs[maxn], jinv[maxn], mjcs[maxn], mjcsinv[maxn], minv[maxn], ninv[maxn];
inline i64 pow(i64 x, i64 p);
inline i64 inv(i64 x);
void ntt(i64 *c, i64 op);
void getinv(i64 n, i64 m)
{
jcs[0] = mjcs[0] = 1;
for (i64 i = 1; i <= 2 * n + 1; i++)
jcs[i] = jcs[i - 1] * i % mod,
mjcs[i] = mjcs[i - 1] * (m - n + i - 1) % mod;
jinv[2 * n + 1] = inv(jcs[2 * n + 1]);
mjcsinv[2 * n + 1] = inv(mjcs[2 * n + 1]);
for (int i = 2 * n + 1; i; i--)
jinv[i - 1] = jinv[i] * i % mod,
mjcsinv[i - 1] = mjcsinv[i] * (m - n + i - 1) % mod,
ninv[i] = jinv[i] * jcs[i - 1] % mod,
minv[i] = mjcsinv[i] * mjcs[i - 1] % mod;
minv[0] = 1;
// for(int i=1;i<=2*n+1;i++)
// printf("[%lld %lld %lld %lld] ",
// i*ninv[i]%mod,(m-n+i-1)*minv[i]%mod,jcs[i]*jinv[i]%mod,mjcs[i]*mjcsinv[i]%mod);
}
int main()
{
gi = inv(g);
scanf("%lld%lld", &n, &m);
for (i64 i = 0; i <= n; i++)
scanf("%lld", t + i);
getinv(n, m);
for (i64 i = 0; i <= n; i++)
{
a[i] = t[i] * jinv[i] % mod * jinv[n - i] % mod;
if ((n - i) & 1)
a[i] = mod - a[i];
}
for (i64 i = 0; i <= 2 * n; i++)
b[i] = minv[i + 1];
while (len <= 2 * n)
len <<= 1, l++;
for (i64 i = 0; i < len; i++)
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
ntt(a, 1);
ntt(b, 1);
for (i64 i = 0; i < len; i++)
a[i] = a[i] * b[i] % mod;
ntt(a, -1);
i64 linv = pow(len, mod - 2);
for (i64 i = n; i <= 2 * n; i++)
printf("%lld ", mjcs[i + 1] * a[i] % mod * mjcsinv[i - n] % mod * linv % mod);
return 0;
}