当前位置: 首页 > 工具软件 > Nightfall > 使用案例 >

2018.11.08【CodeForces989】E. A Trance of Nightfall(矩阵快速幂)(倍增)

戴靖
2023-12-01

传送门


解析:

考场上本来想写倍增来着结果发现这个东西可以矩阵快速幂转移,所以倍增数组就用来优化矩阵快速幂了。。。(省去每次求出转移矩阵的一个 O ( n ) O(n) O(n),过会看完优化的第二种就行了)

如果只有一次询问,我们只需要求出每个点在跑了 m m m次后以及每条线在跑了 m m m次后到达目标点的距离。这个显然只需要把每个点走 m − 1 m-1 m1次的概率 D P DP DP出来计算直线最大,因为直线需要走一步到点上面,然后把每个点走 m m m次的概率求出来,所有点取最大。然后两者取一下最大。

有一个地方需要读者思考一下,为什么不需要考虑线的交点?

很显然的结论,考场上只卡了我30s就想出来了,因为几个数的平均数肯定不大于它们当中的最大数,选择交点一定不会比直接选择概率最大的线优。

那么怎么计算每个点走 s t e p step step步达到 g o a l goal goal的概率。考虑现在已经求出矩阵 A i A_i Ai A i , u , v A_{i,u,v} Ai,u,v表示 u u u i i i步后停留在节点 v v v的概率,我们只要知道 A 1 A_1 A1。就可以知道 A i + 1 A_{i+1} Ai+1. A i + 1 , u , v = ∑ j = 1 n A i , u , j × A 1 , j , v A_{i+1,u,v}=\sum_{j=1}^{n}A_{i,u,j}\times A_{1,j,v} Ai+1,u,v=j=1nAi,u,j×A1,j,v

然后发现这个东西其实就是矩阵乘法,于是可以矩阵快速幂优化一下。

然后就是每次都来一下 O ( n 3 log ⁡ m ) O(n^3\log m) O(n3logm)的快速幂时间复杂度是会爆炸的。所以我们可以再优化一下。

第一种优化是基于 B S G S BSGS BSGS的思想,我们只需要预处理所有 A k m A_{k\sqrt{m}} Akm 的矩阵以及所有 ( 0 ≤ i ≤ m ) A i (0\leq i \leq \sqrt{m})A_i (0im )Ai就可以直接找到两个矩阵做一次 O ( n 3 ) O(n^3) O(n3)的乘法就可以构造任意矩阵,优化效果 O ( n 3 log ⁡ m ) − > O ( n 3 ) O(n^3\log m)->O(n^3) O(n3logm)>O(n3),没测试,不知道能不能过,不过应该不可以,刚好卡复杂度上界啊。。。

第二种优化直接优化一个 O ( n ) O(n) O(n)
由于我们只需要求出到达 g o a l goal goal的距离,所以构造一个只有一列的矩阵 g g g,其中 g g o a l = 1 g_{goal}=1 ggoal=1,表示一步不走的时候只有 g o a l goal goal到达 g o a l goal goal的概率是 1 1 1。用这个矩阵转移出来任何时候都只有一列矩阵,矩阵乘法的复杂度就降为 O ( n 2 ) O(n^2) O(n2),优化效果 O ( n 3 log ⁡ m ) − > O ( n 2 log ⁡ m ) O(n^3\log m)->O(n^2\log m) O(n3logm)>O(n2logm)

那么考虑怎么求出最初的矩阵 A 1 A_1 A1?

首先我们需要处理出哪些点在同一条直线上面,这个用点斜式判一下就好了。
我们同时需要处理出经过一个点的有多少直线。同直线上的点相互到达的概率就是 1 / s i z e ( l i n e ) 1/size(line) 1/size(line),其中 s i z e size size表示直线上的点数。

同时每个点向外的初始转移需要除以 c n t u cnt_u cntu,其中 c n t u cnt_u cntu表示经过点 u u u的有多少直线。

然后我们就用倍增预处理出矩阵就好了


代码(由于很多该封装的没有封装,所以很丑,但是也正因为这样才跑的那么快,所以不封装有不封装的好处):

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define re register
#define gc getchar
#define pc putchar
#define cs const

inline int getint(){
	re int num;
	re char c;
	re bool f=0;
	while(!isdigit(c=gc()))if(c=='-')f=1;num=c^48;
	while(isdigit(c=gc()))num=(num<<1)+(num<<3)+(c^48);
	return f?-num:num;
}

cs int N=202,logM=15;

int x[N],y[N],cnt[N];
vector<int> G[N][N];
double f[logM+1][N][N];
double g[N];
double tmp[N];
bool vis[N];
vector<pair<int,int> > line;

inline bool judge(int u,int v,int i){
	return (x[i]-x[v])*(y[v]-y[u])==(x[v]-x[u])*(y[i]-y[v]);
}
int n,q;
signed main(){
	n=getint();
	for(int re i=1;i<=n;++i){
		x[i]=getint();
		y[i]=getint();
	}
	for(int re u=1;u<=n;++u){
		memset(vis,0,sizeof vis);
		for(int re v=1;v<=n;++v){
			if(u==v)continue;
			if(vis[v])continue;++cnt[u];
			for(int re i=1;i<=n;++i){
				if(judge(u,v,i))G[u][v].push_back(i),vis[i]=true;
			}
			line.push_back(make_pair(G[u][v][0],G[u][v][1]));
		}
	}
	
	sort(line.begin(),line.end());
	line.erase(unique(line.begin(),line.end()),line.end());
	
	for(int re kkk=0;kkk<line.size();++kkk){
		vector<int> &vec=G[line[kkk].first][line[kkk].second];
		for(int re i=0;i<vec.size();++i)
		for(int re j=0;j<vec.size();++j)f[0][vec[i]][vec[j]]+=1.0/(1.0*vec.size());
	}
	
	for(int re i=1;i<=n;++i){
		for(int re j=1;j<=n;++j)f[0][i][j]/=cnt[i];
	}
	
	for(int re i=1;i<=logM;++i){
		for(int re u=1;u<=n;++u)
		for(int re v=1;v<=n;++v)
		if(f[i-1][u][v]>1e-6)for(int re k=1;k<=n;++k)f[i][u][k]+=f[i-1][u][v]*f[i-1][v][k];
	}
	
	q=getint();
	while(q--){
		int goal=getint(),step=getint()-1;
		memset(g,0,sizeof g);
		g[goal]=1;
		for(int re i=0;i<=logM;++i){
			if((1<<i)>step)break;
			if((1<<i)&step){
				memset(tmp,0,sizeof tmp);
				for(int re u=1;u<=n;++u)
				if(g[u]>1e-6)for(int re k=1;k<=n;++k)
				tmp[k]+=f[i][k][u]*g[u];
				memcpy(g,tmp,sizeof tmp);
			}
		}
		
		double ans=0.0;
		for(int re i=0;i<line.size();++i){
			vector<int> &vec=G[line[i].first][line[i].second];
			double ttt=0;
			for(int re i=0;i<vec.size();++i){
				ttt+=g[vec[i]];
			}
			ttt/=1.0*vec.size();
			ans=max(ttt,ans);
		}
		
		memset(tmp,0,sizeof tmp);
		for(int re u=1;u<=n;++u)
		if(g[u]>1e-6)for(int re k=1;k<=n;++k)
		tmp[k]+=f[0][k][u]*g[u];
		memcpy(g,tmp,sizeof tmp);
		
		for(int re u=1;u<=n;++u)ans=max(ans,g[u]);
		
		printf("%.12lf\n",ans);
	}
	return 0;
}
 类似资料: