当前位置: 首页 > 工具软件 > WMQ > 使用案例 >

xijtuoj wmq的A×B Problem FFT+原根

云伯寅
2023-12-01

题目连接点这里

神套路题

因为m为素数所以必定有原根,设为x,

根据原根那套理论,x^(0)mod m,x^(1)mod m,,,,x^(m-2)mod m,的值互不相同,取遍 1到m-1.所以我们可以把原数组的每个数根据mod m等于多少,可以唯一的用x^(t)代替。

然后将t看出数组下标。。就可以FFT啦。

然后,,mod m等于0的情况貌似无法处理,,我是单独算的

#include<algorithm>
#include<iostream>
#include<stdio.h>
#include<cstring>
#include<cmath>
using namespace std;
#define mem(x,y) memset(x,y,sizeof(x))
#define FIN freopen("input.txt","r",stdin)
#define fuck(x) cout<<x<<endl
const double  eps=1e-7;
const int MX=333333;
#define INF 0x3f3f3f3f
#define INFLL 0x3f3f3f3f3f3f3f3f
typedef long long LL;
typedef pair<LL,LL> PLL;
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
int n,m,root;
bool isprime[MX];
int prime[MX],prime_cnt;
void init()
{
    mem(isprime,1);
    for(int i=2; i<MX; i++)
    {
        if(isprime[i]) prime[prime_cnt++]=i;
        for(int j=0; j<prime_cnt&&i*prime[j]<MX; j++)
        {
            isprime[i*prime[j]]=0;
            if(i%prime[j]==0)break;
        }
    }
}
int fac[MX],fac_cnt;
void divide(int n)
{
    for(int i=0; i<prime_cnt&&prime[i]<=sqrt(n+0.5); i++)
    {
        if(n%prime[i]==0) fac[fac_cnt++]=prime[i];
        while(n%prime[i]==0)n/=prime[i];
    }
    if(n!=1) fac[fac_cnt++]=n;
}
int quick_pow(int a,int x,int mod)
{
    int e=1;
    while(x)
    {
        if(x&1)e=((LL)e*a)%mod;
        a=(LL)a*a%mod;
        x>>=1;
    }
    return e;
}
int getroot(int n)
{
    fac_cnt=0;
    divide(n-1);
    for(int i=2;; i++)
    {
        int flag=1;
        for(int j=0; j<fac_cnt; j++)if(quick_pow(i,(n-1)/fac[j],n)==1)flag=0;
        if(flag) return i;
    }
}
int w[MX];
void root_init()
{
    int x=1;
    w[1]=0;
    for(int i=1; i<m-1; i++)
    {
        x=(LL)x*root%m;
        w[x]=i;
    }
}
const double pi = acos(-1.0);
int len,mx;//开大4倍
LL res[MX];
struct Complex
{
    long double r,i;
    Complex(double r=0,double i=0):r(r),i(i) {};
    Complex operator+(const Complex &rhs)
    {
        return Complex(r + rhs.r,i + rhs.i);
    }
    Complex operator-(const Complex &rhs)
    {
        return Complex(r - rhs.r,i - rhs.i);
    }
    Complex operator*(const Complex &rhs)
    {
        return Complex(r*rhs.r - i*rhs.i,i*rhs.r + r*rhs.i);
    }
} va[MX],vb[MX];
int arr[MX];
void rader(Complex F[],int len)   //len = 2^M,reverse F[i] with  F[j] j为i二进制反转
{
    int j = len >> 1;
    for(int i = 1; i < len - 1; ++i)
    {
        if(i < j) swap(F[i],F[j]);  // reverse
        int k = len>>1;
        while(j>=k)
        {
            j -= k;
            k >>= 1;
        }
        if(j < k) j += k;
    }
}
void FFT(Complex F[],int len,int t)
{
    rader(F,len);
    for(int h=2; h<=len; h<<=1)
    {
        Complex wn(cos(-t*2*pi/h),sin(-t*2*pi/h));
        for(int j=0; j<len; j+=h)
        {
            Complex E(1,0); //旋转因子
            for(int k=j; k<j+h/2; ++k)
            {
                Complex u = F[k];
                Complex v = E*F[k+h/2];
                F[k] = u+v;
                F[k+h/2] = u-v;
                E=E*wn;
            }
        }
    }
    if(t==-1)   //IDFT
        for(int i=0; i<len; ++i)
            F[i].r/=len;
}
void Conv(Complex a[],Complex b[],int len)   //求卷积
{
    FFT(a,len,1);
    FFT(b,len,1);
    for(int i=0; i<len; ++i) a[i] = a[i]*b[i];
    FFT(a,len,-1);
}
void gao()
{
    int len=1;
    while(len<2*m-3)len<<=1;
    Conv(va,vb,len);
    for(int i=0; i<len; ++i)
    {
        res[i]=va[i].r + 0.5;
        if(i%2==0) res[i]-=arr[i/2];
        res[i]/=2;
    }
    for(int i=m-1; i<len; i++)res[i%(m-1)]+=res[i];
}
int cnt0;
int main()
{
    //FIN;
    init();
    int T;
    cin>>T;
    while(T--)
    {
        mem(va,0);
        mem(vb,0);
        mem(arr,0);
        cnt0=0;
        scanf("%d%d",&n,&m);
        root=getroot(m);
        root_init();
        for(int i=1; i<=n; i++)
        {
            int x;
            scanf("%d",&x);
            x%=m;
            if(x==0)cnt0++;
            else
            {
                arr[w[x]]++;
                va[w[x]].r++;
                vb[w[x]].r++;
            }
        }
        gao();
        printf("%lld\n",(LL)cnt0*(n-cnt0)+(LL)cnt0*(cnt0-1)/2);
        for(int i=1; i<=m-1; i++)printf("%lld\n",res[w[i]]);
    }
    return 0;
}






 类似资料: