当前位置: 首页 > 工具软件 > mtt > 使用案例 >

洛谷 - P4245 【模板】任意模数多项式乘法(三模NTT+中国剩余定理/五次FFT的MTT)

楚灿
2023-12-01

题目链接:点击查看

题目大意:给出一个长度为 $n$ 的多项式 $F(x)$ 和一个长度为 $m$ 的多项式 $G(x)$,求 $F(x) * G(x)$,模数任意,值域 $10^9$

题目分析:如果不取模的话极限情况下会达到 $10^9 \times 10^9 \times 10^5 = 10^{23}$

众所周知,NTT 支持值域更大的多项式相乘,但是不支持任意模数。

FFT 支持任意模数,但是不支持值域很大的多项式相乘。

所以本题就针对两种做法各自进行了变形,今天先补一下针对 NTT 的做法,FFT 的后续补


NTT做法:

既然一个模数不够用,那我们直接参考哈希,将模数扩展成三个,只需要取三个模数,使其乘积大于极限情况,最后再用中国剩余定理合并就好了

注意在中国剩余定理时会爆 `long long`,所以可以全部定义成 `__int128`


FFT做法:

考虑设置一个阈值 $w=2^{15}$,那么每个数都可以拆成 $x=a \cdot w+b$ 的形式,类似的,将多项式也拆成这种形式。设初始的多项式为 $A(x)$ 和 $B(x)$,我们拆成 $A(x)=A_1(x) \cdot w+A_2(x)$,$B(x)=B_1(x) \cdot w+B_2(x)$

那么 $A(x)B(x)=A_1(x)B_1(x)w^2+A_1(x)B_2(x)w+A_2(x)B_1(x)w+A_2(x)B_2(x)$

这样一来每一项的乘积的上限是 $2^{15} \times 2^{15} \times n \approx 10^{14}$,满足了 FFT 的值域范围,也就可以进行求解了,这种设置阈值的方法称为 MTT,不过这种做法需要进行四次插值,点乘后还需要四次插值进行还原,共八次 FFT,常数过大,下面考虑优化

在复数域中,设 $P(x)=A_1(x)+A_2(x)i$,$P'(x)=A_1(x)-A_2(x)i$,$Q(x)=B_1(x)+B_2(x)i$

那么有

$T_1(x)=P(x)Q(x)=A_1(x)B_1(x)-A_2(x)B_2(x)+\big(A_1(x)B_2(x)+A_2(x)B_1(x)\big)i$

$T_2(x)=P'(x)Q(x)=A_1(x)B_1(x)+A_2(x)B_2(x)+\big(A_1(x)B_2(x)-A_2(x)B_1(x)\big)i$

然后将 $T_1(x)$ 与 $T_2(x)$ 通过求和和作差就可以求出我们需要的任意一项了,总共需要三次插值和两次还原,共五次 FFT

代码:

三模NTT:

// Problem: P4245 【模板】任意模数多项式乘法
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P4245
// Memory Limit: 500 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

// #pragma GCC optimize(2)
// #pragma GCC optimize("Ofast","inline","-ffast-math")
// #pragma GCC target("avx,sse2,sse3,sse4,mmx")
#include<iostream>
#include<cstdio>
#include<string>
#include<ctime>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<stack>
#include<climits>
#include<queue>
#include<map>
#include<set>
#include<sstream>
#include<cassert>
#include<bitset>
#include<list>
#include<unordered_map>
#define lowbit(x) (x&-x)
using namespace std;
typedef long long LL;
typedef __int128 ll;
typedef unsigned long long ull;
// Reads one decimal integer from stdin into x.
// Every character before the first digit is skipped; if any skipped character
// is '-', the value is negated. The first character after the digits is consumed.
// At end of input (no digit found) x is set to 0 and the function returns.
template<typename T>
inline void read(T &x)
{
	T f=1;x=0;
	// getchar() returns int: keeping the full int lets EOF be told apart from
	// every real character, so the skip loop below cannot spin forever.
	int ch=getchar();
	while(ch!=EOF&&(ch<'0'||ch>'9')){if(ch=='-')f=-1;ch=getchar();}
	// x*10 written as (x<<1)+(x<<3); digits are compared directly so no <cctype> is needed.
	while(ch>='0'&&ch<='9') x=(x<<1)+(x<<3)+ch-'0',ch=getchar();
	x*=f;
}
// Writes the decimal representation of x to stdout with putchar (no separator, no newline).
// Negative values are printed with a leading '-'.
template<typename T>
inline void write(T x)
{
	if(x<0){
		putchar('-');
		// Split off the last digit while x is still negative: negating x directly
		// would overflow for the minimum value of a signed type.
		T q=x/10;
		int d=(int)(x%10);	// in [-9,0] because division truncates toward zero
		if(q!=0)write(-q);
		putchar('0'-d);
		return;
	}
	if(x>9)write(x/10);
	putchar((int)(x%10)+'0');
}
// Large sentinel value; not referenced by the code visible in this program.
const int inf=0x3f3f3f3f;
// Capacity of the work arrays below; must be at least the transform length chosen by init().
const int N=1e6+100;
// The three NTT moduli; NTT() uses 3 as the generator for each of them.
int mod[]={998244353,1004535809,469762049};
// n, m: coefficient counts of the two inputs (main() increments the values it reads);
// limit: transform length, a power of two; L: log2(limit); r: bit-reversal permutation.
int n,m,limit=1,L,r[N];
// a[j], b[j]: the two inputs reduced modulo mod[j];
// res[i][j]: i-th coefficient of the product modulo mod[j], filled by the inverse NTT.
LL a[3][N],b[3][N],res[N][3];
// Modular exponentiation by repeated squaring: returns a^k modulo mod.
// The base is not reduced before the first multiplication, so the caller must
// pass a value whose square fits in a long long.
inline LL fastpow(LL a, LL k,int mod) {
	LL acc = 1;
	for(; k; k >>= 1) {
		// Fold the current power of a into the result whenever the low bit is set.
		if(k & 1) acc = acc * a % mod;
		a = a * a % mod;
	}
	return acc % mod;
}
// In-place iterative number-theoretic transform of A (length `limit`) modulo `mod`.
//   A    : coefficient array, entries expected in [0, mod)
//   mod  : NTT-friendly prime for which 3 is a primitive root
//   type : 1 for the forward transform, -1 for the inverse transform
//   id   : column of the global `res` that receives the inverse-transform output
// For type == -1 the values scaled by 1/limit are written to res[i][id];
// A itself is left unscaled.
inline void NTT(LL *A,int mod,int type,int id) {
	// Reorder into bit-reversed order using the table built by init().
	for(int i = 0; i < limit; i++)
		if(i < r[i]) swap(A[i], A[r[i]]);
	// Generator for this direction. The inverse of 3 comes from Fermat's little
	// theorem; (mod+1)/3 would only be correct when mod % 3 == 2.
	const LL G = type == 1 ? 3 : fastpow(3, mod - 2, mod);
	for(int mid = 1; mid < limit; mid <<= 1) {
		// Primitive (2*mid)-th root of unity for this layer.
		const LL Wn = fastpow(G, (mod - 1) / (mid << 1), mod);
		for(int j = 0; j < limit; j += (mid << 1)) {
			LL w = 1;
			for(int k = 0; k < mid; k++, w = (w * Wn) % mod) {
				// Keep the butterfly operands in 64 bits: x + y can approach 2*mod,
				// which does not fit in a 32-bit int for moduli above 2^30.
				const LL x = A[j + k], y = w * A[j + k + mid] % mod;
				A[j + k] = (x + y) % mod;
				A[j + k + mid] = (x - y + mod) % mod;
			}
		}
	}
	if(type==-1) {
		const LL inv=fastpow(limit,mod-2,mod);
		for(int i=0;i<limit;i++) {
			res[i][id]=A[i]*inv%mod;
		}
	}
}
// Chooses the transform length and builds the bit-reversal table.
// Starting from the current values of the globals `limit` and `L`, doubles
// `limit` until it exceeds n+m (counting the doublings in `L`), then fills r[0..limit).
void init() {
	for(;limit<=n+m;limit<<=1) L++;
	const int top=L-1;	// position of the highest bit of an index
	for(int i=0;i<limit;i++) {
		// Reverse of i = reverse of i/2 shifted down, with i's low bit moved to the top.
		int rev=r[i>>1]>>1;
		if(i&1) rev|=1<<top;
		r[i]=rev;
	}
}
// Extended Euclidean algorithm.
// On return d = gcd(a,b) and x, y satisfy a*x + b*y = d
// (provided d, x and y refer to distinct variables).
void ex_gcd(ll a,ll b,ll &d,ll &x,ll &y)
{
    if(b==0){
        // gcd(a,0) = a = a*1 + 0*0
        d=a;
        x=1;
        y=0;
        return;
    }
    // Solve for (b, a mod b) with the roles of x and y swapped, then correct y.
    ex_gcd(b,a%b,d,y,x);
    y-=(a/b)*x;
}
// Chinese remainder theorem.
//   n : number of congruences
//   m : the n moduli, assumed pairwise coprime (not checked)
//   a : the n residues, a[i] taken modulo m[i]
// Returns the unique x in [0, M), M = m[0]*...*m[n-1], with x ≡ a[i] (mod m[i]).
ll China(int n,int *m,LL *a)
{
    ll M=1;
    for(int i=0;i<n;i++) M*=m[i];
    ll x=0;
    for(int i=0;i<n;i++){
        const ll mi=m[i];
        const ll w=M/mi;
        // mi*s + w*t = g. Distinct output variables are used so that no two
        // reference parameters of ex_gcd alias each other.
        ll g,s,t;
        ex_gcd(mi,w,g,s,t);
        // Reduce the inverse and the residue modulo mi before scaling by w:
        // each term then stays below M, so nothing relies on spare __int128 range.
        const ll inv=(t%mi+mi)%mi;
        const ll ai=(a[i]%mi+mi)%mi;
        const ll coef=inv*ai%mi;
        x=(x+coef*w)%M;
    }
    return x;
}
int main()
{
#ifndef ONLINE_JUDGE
//	freopen("data.in.txt","r",stdin);
//	freopen("data.out.txt","w",stdout);
#endif
//	ios::sync_with_stdio(false);
	int p;
	read(n),read(m),read(p);
	n++,m++;
	init();
	for(int i=0,x;i<n;i++) {
		read(x);
		for(int j=0;j<3;j++) {
			a[j][i]=x%mod[j];
		}
	}
	for(int i=0,x;i<m;i++) {
		read(x);
		for(int j=0;j<3;j++) {
			b[j][i]=x%mod[j];
		}
	}
	for(int j=0;j<3;j++) {
		NTT(a[j],mod[j],1,j),NTT(b[j],mod[j],1,j);
		for(int i=0;i<limit;i++) {
			a[j][i]=a[j][i]*b[j][i]%mod[j];
		}
		NTT(a[j],mod[j],-1,j);
	}
	for(int i=0;i<n+m-1;i++) {
		printf("%lld ",(LL)(China(3,mod,res[i])%p));
	}
	return 0;
}

MTT:

// Problem: P4245 【模板】任意模数多项式乘法
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P4245
// Memory Limit: 500 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

// #pragma GCC optimize(2)
// #pragma GCC optimize("Ofast","inline","-ffast-math")
// #pragma GCC target("avx,sse2,sse3,sse4,mmx")
#include<iostream>
#include<cstdio>
#include<string>
#include<ctime>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<stack>
#include<climits>
#include<queue>
#include<map>
#include<set>
#include<sstream>
#include<cassert>
#include<bitset>
#include<list>
#include<unordered_map>
#define lowbit(x) (x&-x)
using namespace std;
typedef long long LL;
typedef unsigned long long ull;
// Reads one decimal integer from stdin into x.
// Every character before the first digit is skipped; if any skipped character
// is '-', the value is negated. The first character after the digits is consumed.
// At end of input (no digit found) x is set to 0 and the function returns.
template<typename T>
inline void read(T &x)
{
	T f=1;x=0;
	// getchar() returns int: keeping the full int lets EOF be told apart from
	// every real character, so the skip loop below cannot spin forever.
	int ch=getchar();
	while(ch!=EOF&&(ch<'0'||ch>'9')){if(ch=='-')f=-1;ch=getchar();}
	// x*10 written as (x<<1)+(x<<3); digits are compared directly so no <cctype> is needed.
	while(ch>='0'&&ch<='9') x=(x<<1)+(x<<3)+ch-'0',ch=getchar();
	x*=f;
}
// Writes the decimal representation of x to stdout with putchar (no separator, no newline).
// Negative values are printed with a leading '-'.
template<typename T>
inline void write(T x)
{
	if(x<0){
		putchar('-');
		// Split off the last digit while x is still negative: negating x directly
		// would overflow for the minimum value of a signed type.
		T q=x/10;
		int d=(int)(x%10);	// in [-9,0] because division truncates toward zero
		if(q!=0)write(-q);
		putchar('0'-d);
		return;
	}
	if(x>9)write(x/10);
	putchar((int)(x%10)+'0');
}
// Large sentinel value; not referenced by the code visible in this program.
const int inf=0x3f3f3f3f;
// Upper bound on the number of coefficients of one input polynomial.
const int N=1e5+100;
// Splitting base: each coefficient v is handled as (v / BASE) * BASE + (v % BASE).
const int BASE=1<<15;
// Pi at full long double precision. The long double overload is selected by the
// -1.0L argument; acos(-1.0) would be evaluated in double and then widened,
// leaving the twiddle factors only double-accurate.
const long double Pi=std::acos(-1.0L);
// Minimal complex number in long double precision: x is the real part,
// y the imaginary part. Both default to 0.
struct complex
{
    long double x,y;
    complex (long double xx=0,long double yy=0):x(xx),y(yy){}
};
// Global FFT work buffers used by MTT().
complex P1[N<<2],P2[N<<2],Q[N<<2];
// Component-wise sum of two complex numbers.
complex operator + (complex a,complex b)
{
    const long double re=a.x+b.x;
    const long double im=a.y+b.y;
    return complex(re,im);
}
// Component-wise difference of two complex numbers.
complex operator - (complex a,complex b)
{
    const long double re=a.x-b.x;
    const long double im=a.y-b.y;
    return complex(re,im);
}
// Complex product: (a.x + a.y i)(b.x + b.y i).
complex operator * (complex a,complex b)
{
    const long double re=a.x*b.x-a.y*b.y;
    const long double im=a.x*b.y+a.y*b.x;
    return complex(re,im);
}
// limit: current transform length (set by init); r: bit-reversal permutation for that length;
// a, b: input coefficients read by main(); res: product coefficients written by MTT().
int limit,r[N<<2],a[N],b[N],res[N<<2];
// In-place iterative radix-2 FFT over the first `limit` entries of A.
//   type : 1 uses the root cos(Pi/mid) + i*sin(Pi/mid), -1 uses its conjugate.
// No 1/limit scaling is applied here for either direction.
void FFT(complex *A,int type)
{
	// Bit-reversal permutation, using the table prepared by init().
	for(int i=0;i<limit;i++) {
		const int j=r[i];
		if(i<j) swap(A[i],A[j]);
	}
	for(int half=1;half<limit;half<<=1) {
		const int span=half<<1;
		// Twiddle step for blocks of length span.
		const complex step( cos(Pi/half) , type*sin(Pi/half) );
		for(int base=0;base<limit;base+=span) {
			complex w(1,0);
			for(int k=0;k<half;k++) {
				complex &lo=A[base+k];
				complex &hi=A[base+half+k];
				const complex t=w*hi;
				// Butterfly: hi is written first, while lo still holds its old value.
				hi=lo-t;
				lo=lo+t;
				w=w*step;
			}
		}
	}
}
// Sets `limit` to the smallest power of two strictly greater than n and fills
// the bit-reversal table r[1..limit). r[0] is not written here.
void init(int n) {
	int len=1;
	while(len<=n) len<<=1;
	limit=len;
	const int top=len>>1;	// value of the highest bit of an index
	for(int i=1;i<len;i++) {
		// Reverse of i = reverse of i/2 shifted down, with i's low bit moved to the top.
		int rev=r[i>>1]>>1;
		if(i&1) rev|=top;
		r[i]=rev;
	}
}
void MTT(int *a,int *b,int n,int m,int *res,int mod) {
	init(n+m);
	for(int i=0;i<n;i++) {
		P1[i]={a[i]/BASE,a[i]%BASE};
		P2[i]={a[i]/BASE,-a[i]%BASE};
	}
	for(int i=n;i<limit;i++) {
		P1[i]={0,0},P2[i]={0,0};
	}
	for(int i=0;i<m;i++) {
		Q[i]={b[i]/BASE,b[i]%BASE};
	}
	for(int i=m;i<limit;i++) {
		Q[i]={0,0};
	}
	FFT(P1,1),FFT(P2,1),FFT(Q,1);
	for(int i=0;i<limit;i++) {
		Q[i].x/=limit,Q[i].y/=limit;
		P1[i]=P1[i]*Q[i],P2[i]=P2[i]*Q[i];
	}
	FFT(P1,-1),FFT(P2,-1);
	for(int i=0;i<limit;i++) {
		long long a1b1,a1b2,a2b1,a2b2;
		a1b1=(long long)floor((P1[i].x+P2[i].x)/2+0.5)%mod;
		a1b2=(long long)floor((P1[i].y+P2[i].y)/2+0.5)%mod;
		a2b1=(long long)floor((P1[i].y-P2[i].y)/2+0.5)%mod;
		a2b2=(long long)floor((P2[i].x-P1[i].x)/2+0.5)%mod;
		res[i]=((a1b1*BASE+(a1b2+a2b1))*BASE+a2b2)%mod;
		res[i]=(res[i]+mod)%mod;
	}
}
// Reads the two degrees and the modulus p, then both coefficient lists,
// multiplies the polynomials with MTT() and prints the product coefficients
// separated by spaces.
int main()
{
#ifndef ONLINE_JUDGE
//	freopen("data.in.txt","r",stdin);
//	freopen("data.out.txt","w",stdout);
#endif
//	ios::sync_with_stdio(false);
	int n,m,p;
	read(n);
	read(m);
	read(p);
	// The input gives degrees; turn them into coefficient counts.
	++n;
	++m;
	for(int *it=a;it!=a+n;++it) read(*it);
	for(int *it=b;it!=b+m;++it) read(*it);
	MTT(a,b,n,m,res,p);
	const int terms=n+m-1;
	for(int i=0;i<terms;i++) printf("%d ",res[i]);
	return 0;
}
 类似资料: