多项式求逆

多项式求逆指的是给定一个多项式F(x)F(x),求出一个多项式G(x)G(x)满足

F(x)G(x)1(modxn)F(x)*G(x)\equiv1\pmod {x^n}

它是怎么做的?

我们称一个多项式的“度”为其最高次项系数+1+1

首先,我们知道当n=1n=1的时候,显然G(x)G(x)即为F(x)F(x)的常数项之逆元

我们将原式写成模xn2x^{\lceil\frac n 2\rceil}意义下的形式:

F(x)G(x)1(modxn2)F(x)*G(x)\equiv1\pmod {x^{\lceil\frac n 2\rceil}}

假设我们已经求出B(x)B(x)满足

F(x)B(x)1(modxn2)F(x)*B(x)\equiv1\pmod {x^{\lceil\frac n 2\rceil}}

将两个式子相减

G(x)B(x)0(modxn2)G(x)-B(x)\equiv0\pmod{x^{\lceil\frac n 2\rceil}}

平方一下

G2(x)2G(x)B(x)+B2(x)0(modxn)G^2(x)-2G(x)B(x)+B^2(x)\equiv0\pmod{x^n}

两边乘上F(x)F(x)

G(x)2B(x)+F(x)B2(x)0(modxn)G(x)-2B(x)+F(x)B^2(x)\equiv0\pmod{x^n}

(这里由于F(x)G(x)1(modxn)F(x)*G(x)\equiv1\pmod{x^n},消去了一些部分)

移项整理得

G(x)(2F(x)B(x))B(x)(modxn)G(x)\equiv(2-F(x)B(x))B(x)\pmod{x^n}

多项式乘法可以用FFT/NTT加速

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <algorithm>
#include <cstring>
#include <iostream>
#define inv(x) (fastpow((x),mod-2))
using namespace std;
typedef long long ll;

template <typename T>void read(T &t)
{
t=0;int f=0;char c=getchar();
while(!isdigit(c)){f|=c=='-';c=getchar();}
while(isdigit(c)){t=t*10+c-'0';c=getchar();}
if(f)t=-t;
}

const ll mod=998244353,gg=3,ig=332748118;
const int maxn=100000+5;
int n;
ll a[maxn<<2],b[maxn<<2];

ll fastpow(ll a,ll b)
{
ll re=1,base=a;
while(b)
{
if(b&1)
re=re*base%mod;
base=base*base%mod;
b>>=1;
}
return re;
}

int len;
int r[maxn<<2];
void NTT(ll *f,int type)
{
for(register int i=0;i<len;++i)
if(i<r[i])
swap(f[i],f[r[i]]);
for(register int p=2;p<=len;p<<=1)
{
int length=p>>1;
ll unr=fastpow(type?gg:ig,(mod-1)/p);
for(register int l=0;l<len;l+=p)
{
ll w=1;
for(register int i=l;i<l+length;++i,w=w*unr%mod)
{
ll tt=f[i+length]*w%mod;
f[i+length]=(f[i]-tt+mod)%mod;
f[i]=(f[i]+tt)%mod;
}
}
}
if(!type)
{
ll ilen=inv(len);
for(register int i=0;i<len;++i)
f[i]=f[i]*ilen%mod;
}
}

ll c[maxn<<2];
void getinv(int deg,ll *a,ll *b)
{
if(deg==1)
{
b[0]=inv(a[0]);
return;
}
getinv((deg+1)>>1,a,b);
for(len=1;len<=(deg<<1);len<<=1);
for(register int i=0;i<len;++i)
{
r[i]=(r[i>>1]>>1)|((i&1)?len>>1:0);
c[i]=(i<deg?a[i]:0);
}
NTT(c,1),NTT(b,1);
for(register int i=0;i<len;++i)
b[i]=(2ll-c[i]*b[i]%mod+mod)%mod*b[i]%mod;
NTT(b,0);
fill(b+deg,b+len,0);//重要,因为是在模 x^deg 意义下
}

int main()
{
read(n);
for(register int i=0;i<n;++i)read(a[i]);
getinv(n,a,b);
for(register int i=0;i<n;++i)
printf("%lld ",b[i]);
return 0;
}