题目大意:
问有多少
{
L
n
}
,
{
R
n
}
,
∀
i
∈
[
1
,
n
]
L
i
,
R
i
∈
[
1
,
n
]
\{L_n\},\{R_n\},\forall i\in[1,n]L_i,R_i\in[1,n]
{Ln},{Rn},∀i∈[1,n]Li,Ri∈[1,n]满足从任意
x
∈
[
1
,
n
]
x\in[1,n]
x∈[1,n]出发,先做
A
A
A次
x
=
L
x
x=L_x
x=Lx,再做一次
x
=
R
x
x=R_x
x=Rx,再做
B
B
B次
x
=
L
x
x=L_x
x=Lx,再做一次
x
=
R
x
x=R_x
x=Rx,再做
C
C
C次
x
=
L
x
x=L_x
x=Lx后变成原来的值。
题解:
首先特判掉
A
=
B
=
C
=
0
A=B=C=0
A=B=C=0的情况,其余情况均满足
L
,
R
L,R
L,R分别是两个排列(否则至少存在两个不同的点某一步会走到同一个点然后之后就分不开了)。
通过打表,观察到如下性质:
A
n
s
(
n
,
A
,
B
,
C
)
=
A
n
s
(
n
,
C
,
B
,
A
)
=
A
n
s
(
n
,
A
,
B
−
1
,
C
−
1
)
=
A
n
s
(
n
,
A
−
1
,
B
−
1
,
C
)
\mathrm{Ans}(n,A,B,C)=\mathrm{Ans}(n,C,B,A)\\=\mathrm{Ans}(n,A,B-1,C-1)=\mathrm{Ans}(n,A-1,B-1,C)
Ans(n,A,B,C)=Ans(n,C,B,A)=Ans(n,A,B−1,C−1)=Ans(n,A−1,B−1,C)
所以问题转化为
A
n
s
(
n
,
A
,
0
,
C
)
\mathrm{Ans}(n,A,0,C)
Ans(n,A,0,C)或者
A
n
s
(
n
,
0
,
B
,
0
)
\mathrm{Ans}(n,0,B,0)
Ans(n,0,B,0)
然后还可以知道:
A
n
s
(
n
,
A
,
0
,
C
)
=
A
n
s
(
b
,
A
+
C
,
0
,
0
)
A
n
s
(
n
,
0
,
B
,
0
)
=
A
n
s
(
n
,
B
,
0
,
0
)
\mathrm{Ans}(n,A,0,C)=\mathrm{Ans}(b,A+C,0,0)\\ \mathrm{Ans}(n,0,B,0)=\mathrm{Ans}(n,B,0,0)
Ans(n,A,0,C)=Ans(b,A+C,0,0)Ans(n,0,B,0)=Ans(n,B,0,0)
所以问题可以转化为
A
n
s
(
n
,
∣
A
+
C
−
B
∣
,
0
,
0
)
\mathrm{Ans}(n,|A+C-B|,0,0)
Ans(n,∣A+C−B∣,0,0),记为H(n,A)。
然后考虑现在问题是从任意点出发,沿着L走A步,R走两步回到自己。
考虑每个点沿着L走A步后得到的新的排列(记作L’)长啥样,显然原来一个长度时x的环,会变成 d = gcd ( x , A ) d=\gcd(x,A) d=gcd(x,A)个环,每个环环长 x d \frac xd dx
这样相当于在L’上做A=1的问题,以及对每个L’求可以通过多少L弄出来(显然有些L’是无法通过L和原来的A得到的,而有些L’可以通过很多L弄出来)。
后面的问题比较简单先讨论后面的问题。
显然你只能把相同长度的环长拼起来,现在要计算把k个长度为x的环拼起来的方案数。
首先要保证
gcd
(
k
x
,
A
)
=
k
\gcd(kx,A)=k
gcd(kx,A)=k,然后答案是
(
k
−
1
)
!
x
k
−
1
(k-1)!x^{k-1}
(k−1)!xk−1(前面那个是换排列,后面是换排列的时候每个点是原来的一个环的某个点)。
然后考虑某个L’,那些R是合法的。
假设x是某个点,然后R(x)=z。
分类:
当x和z在L’的同一个环C上的话,那么若C的环长是偶数,则必定不存在某个R合法,否咋存在恰好一个R。
否则你会发现x所在环和z所在环的R就都确定了,而且必须有x所在环的环长等于z所在环的环长。
以上自己手画一下就会发现。
因此这个问题也是不同环长独立。
也就是整个问题可以分为很多不同环长的问题。
一下考虑环长是x用了k个的答案。
首先如何拼成L上面提到过了,拼成R还剩一步:
对于第k个环,要么和之前某个环确认一个R(这时候要乘以x,因为对于第k个环的某个点要确认其R对应你之前那个环的哪个点)。
若x为奇数,还可以直接自己和自己确认R。
最后还要乘以一些组合数之类的去个重啥的,细节还挺多的(感觉)。
这版代码加了点卡常代码不是很好理解:
#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define Rep(i,v) rep(i,0,(int)v.size()-1)
#define lint long long
#define mod 998244353
#define ull unsigned lint
#define db long double
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
#define gc getchar()
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
typedef pair<int,int> pii;
typedef set<int>::iterator sit;
inline lint inln()
{
lint x,ch;while((ch=gc)<'0'||ch>'9');
x=ch^'0';while((ch=gc)>='0'&&ch<='9')
x=(x<<1)+(x<<3)+(ch^'0');return x;
}
inline int inn() { return int(inln()); }
inline lint gcd(lint x,lint y) { return x?gcd(y%x,x):y; }
const int N=1010;
int a[N],C[N][N],g[N],f[N],fac[N],fmi[N],dp[N],facinv[N],inv[N],mi[N],gd[N],gsav[N];
inline int fast_pow(int x,int k,int ans=1) { for(;k;k>>=1,x=(lint)x*x%mod) (k&1)?ans=(lint)ans*x%mod:0;return ans; }
inline int H(int n,lint A)
{
int m=0;rep(i,1,n) if(A%i==0) a[++m]=i;
memset(g,0,sizeof(int)*(n+1)),g[0]=1;
rep(i,1,n)
{
int u=n/i;memset(f,0,sizeof(int)*(u+1)),f[0]=1;
mi[0]=1;rep(j,1,u) mi[j]=(lint)mi[j-1]*i%mod;
rep(k,1,m&&a[k]<=u) if((gd[k]=gcd(i,A/a[k]))==1)
gsav[k]=(lint)fac[a[k]-1]%mod*mi[a[k]-1]%mod;
rep(j,1,u) rep(k,1,m&&a[k]<=j) if(gd[k]==1)
f[j]=(f[j]+(lint)f[j-a[k]]*C[j-1][a[k]-1]%mod*gsav[k])%mod;
memset(dp,0,sizeof(int)*(u+1)),dp[0]=1,dp[1]=(i&1);
rep(j,2,u) dp[j]=((i&1)*dp[j-1]+(j-1ll)*i%mod*dp[j-2])%mod;
fmi[0]=1;rep(j,1,u) fmi[j]=(lint)fmi[j-1]*inv[i]%mod;
rep(j,0,u) fmi[j]=(lint)fmi[j]*fac[j*i]%mod*facinv[j]%mod;
rep(k,0,u) f[k]=(lint)f[k]*dp[k]%mod*fmi[k]%mod;
for(int j=n;j>=0;j--) rep(k,1,j/i)
g[j]=(g[j]+(lint)C[j][k*i]*g[j-k*i]%mod*f[k])%mod;
}
return g[n];
}
inline int prelude(int n)
{
fac[0]=1;rep(i,1,n) fac[i]=(lint)fac[i-1]*i%mod;rep(i,0,n) C[i][0]=1;
rep(i,0,n) facinv[i]=fast_pow(fac[i],mod-2),inv[i]=fast_pow(i,mod-2);
rep(i,1,n) rep(j,1,n) C[i][j]=C[i-1][j]+C[i-1][j-1],(C[i][j]>=mod?C[i][j]-=mod:0);
return 0;
}
int main()
{
prelude(N-1);
for(int T=inn();T;T--)
{
int n=inn();lint A=inln(),B=inln(),C=inln();
if(A+B+C==0)
{
dp[0]=dp[1]=1;rep(i,2,n) dp[i]=(dp[i-1]+(i-1ll)*dp[i-2])%mod;
int ans=dp[n];rep(i,1,n) ans=(lint)ans*n%mod;printf("%d\n",ans);
}
else if(B>=A+C) printf("%d\n",H(n,B-A-C));
else{
lint t=min(A,B);A-=t,B-=t;
t=min(B,C),B-=t,C-=t;
printf("%d\n",H(n,A+C));
}
}
return 0;
}
这是最初版本:
#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define Rep(i,v) rep(i,0,(int)v.size()-1)
#define lint long long
#define mod 998244353
#define ull unsigned lint
#define db long double
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
#define gc getchar()
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
typedef pair<int,int> pii;
typedef set<int>::iterator sit;
inline lint inln()
{
lint x,ch;while((ch=gc)<'0'||ch>'9');
x=ch^'0';while((ch=gc)>='0'&&ch<='9')
x=(x<<1)+(x<<3)+(ch^'0');return x;
}
inline int inn() { return int(inln()); }
inline lint gcd(lint x,lint y) { return x?gcd(y%x,x):y; }
const int N=1010;
int a[N],C[N][N],g[N][N],f[N],fac[N],fmi[N],dp[N],facinv[N],inv[N],mi[N];
inline int fast_pow(int x,int k,int ans=1) { for(;k;k>>=1,x=(lint)x*x%mod) (k&1)?ans=(lint)ans*x%mod:0;return ans; }
inline int H(int n,lint A)
{
int m=0,t=0;
rep(i,1,n) if(A%i==0) a[++m]=i;
rep(i,0,t) memset(g[i],0,sizeof(int)*(n+1));
g[0][0]=1;
rep(i,1,n)
{
int u=n/i;
memset(f,0,sizeof(int)*(u+1)),f[0]=1;
mi[0]=1;rep(j,1,u) mi[j]=(lint)mi[j-1]*i%mod;
rep(j,1,u) rep(k,1,m&&a[k]<=j) if(gcd(i,A/a[k])==1)
f[j]=(f[j]+(lint)f[j-a[k]]*C[j-1][a[k]-1]%mod*fac[a[k]-1]%mod*mi[a[k]-1])%mod;
memset(dp,0,sizeof(int)*(u+1)),dp[0]=1,dp[1]=(i&1);
rep(j,2,u) dp[j]=((i&1)*dp[j-1]+(j-1ll)*i%mod*dp[j-2])%mod;
fmi[0]=1;rep(j,1,u) fmi[j]=(lint)fmi[j-1]*inv[i]%mod;
rep(j,0,u) fmi[j]=(lint)fmi[j]*fac[j*i]%mod*facinv[j]%mod;
memset(g[i],0,sizeof(int)*(n+1));
rep(j,0,n) rep(k,0,j/i)
g[i][j]=(g[i][j]+(lint)C[j][k*i]*g[i-1][j-k*i]%mod*fmi[k]%mod*f[k]%mod*dp[k])%mod;
}
return g[n][n];
}
inline int prelude(int n)
{
fac[0]=1;rep(i,1,n) fac[i]=(lint)fac[i-1]*i%mod;rep(i,0,n) C[i][0]=1;
rep(i,0,n) facinv[i]=fast_pow(fac[i],mod-2),inv[i]=fast_pow(i,mod-2);
rep(i,1,n) rep(j,1,n) C[i][j]=C[i-1][j]+C[i-1][j-1],(C[i][j]>=mod?C[i][j]-=mod:0);
return 0;
}
int main()
{
prelude(N-1);
for(int T=inn();T;T--)
{
int n=inn();lint A=inln(),B=inln(),C=inln();
if(A+B+C==0)
{
dp[0]=dp[1]=1;rep(i,2,n) dp[i]=(dp[i-1]+(i-1ll)*dp[i-2])%mod;
int ans=dp[n];rep(i,1,n) ans=(lint)ans*n%mod;printf("%d\n",ans);
}
else if(B>=A+C) printf("%d\n",H(n,B-A-C));
else{
lint t=min(A,B);A-=t,B-=t;
t=min(B,C),B-=t,C-=t;
printf("%d\n",H(n,A+C));
}
}
return 0;
}