我们定义 f ( S , i ) f(S,i) f(S,i) 是将字符串 S S S 的前 i i i 个字符删掉拼接到末尾的字符串。如 f ( ′ a b d ′ , 2 ) = ′ d a b ′ f('abd~',2)='dab~' f(′abd ′,2)=′dab ′。
给定两个长度为 n n n 的字符串 S S S 和 T T T,求对于所有 0 ≤ i , j ≤ n − 1 0 \le i,j \le n-1 0≤i,j≤n−1 的 ( i , j ) (i,j) (i,j),字典序 f ( S , i ) f(S,i) f(S,i) 小于等于 f ( T , j ) f(T,j) f(T,j) 的个数。
1 ≤ n ≤ 2 ∗ 1 0 5 1 \le n \le 2*10^5 1≤n≤2∗105
不难想到在字符串后面拼接一个相同的串,这样 f ( S , i ) f(S,i) f(S,i) 就等价于从位置 i + 1 i+1 i+1 开始往后数 n n n 个的字符串。这样能把全部的 n n n 个字符串在一个串中表示出来。
然后将两字符串拼接起来求一下 S A [ ] SA[] SA[],按排名扫描一遍,答案就是 s a [ i ] < s a [ j ] ( 1 ≤ i ≤ n , 2 n + 2 ≤ j ≤ 3 n + 1 ) sa[i]<sa[j]~(1 \le i \le n, 2n+2 \le j\le 3n+1) sa[i]<sa[j] (1≤i≤n,2n+2≤j≤3n+1) 的对数。注意拼接时间往两串间添加一个小于 ′ a ′ 'a' ′a′ 的字符,最后要补一个大于 ’ z ′ ’z' ’z′ 的字符,这样才能保证 S S S 串和 T T T 串相等时 S S S 串排前面。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e6+5;
int s1[N], s2[N];
char s[N];
int num[N], cnt[N], sa[N], rk[N], tmp[N];
void get_sa(int n){
int m=500;
for(int i=1;i<=n;i++) num[i]=s[i];
for(int i=1;i<=n;i++) cnt[num[i]]++;
for(int i=2;i<=m;i++) cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--) sa[cnt[num[i]]--]=i;
for(int k=1;k<=n;k<<=1){
int tot=0;
for(int i=n-k+1;i<=n;i++) rk[++tot]=i;
for(int i=1;i<=n;i++){
if(sa[i]>k) rk[++tot]=sa[i]-k;
}
for(int i=1;i<=m;i++) cnt[i]=0;
for(int i=1;i<=n;i++) cnt[num[i]]++;
for(int i=2;i<=m;i++) cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--){
sa[cnt[num[rk[i]]]--] = rk[i];
tmp[i]=num[i];
}
num[sa[1]]=1, tot=1;
for(int i=2;i<=n;i++){
if(tmp[sa[i]]==tmp[sa[i-1]] && tmp[sa[i]+k]==tmp[sa[i-1]+k]) num[sa[i]]=tot;
else num[sa[i]]=++tot;
}
m=tot;
if(m==n) break;
}
}
int main(){
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
int n;cin>>n;
for(int i=1;i<=n;i++){
char c;cin>>c;
s1[i]=c;
}
for(int i=1;i<=n;i++){
char c;cin>>c;
s2[i]=c;
}
int tot=0;
for(int i=1;i<=n;i++) s[++tot]=s1[i];
for(int i=1;i<=n;i++) s[++tot]=s1[i];
s[++tot]='a'-2;
for(int i=1;i<=n;i++) s[++tot]=s2[i];
for(int i=1;i<=n;i++) s[++tot]=s2[i];
s[++tot]='z'+2;
get_sa(tot); // 对s跑一遍后缀数组
// for(int i=1;i<=tot;i++) cout<<sa[i]<<" ";cout<<endl<<endl;
ll ans=0;
ll cnt=0;
for(int i=1;i<=tot;i++){
if(1<=sa[i] && sa[i]<=n) cnt++;
else if(2*n+2<=sa[i] && sa[i]<=3*n+1) ans+=cnt;
}
cout<<ans<<endl;
return 0;
}