当前位置: 首页 > 工具软件 > Nightfall > 使用案例 >

codeforces 989 E-A Trance of Nightfall(矩阵乘法优化DP)

贡斌
2023-12-01

感觉是一道非常优秀的题(对于我这种蒟蒻)。简略的题解。


思路

首先假设初始点已经钦定了,那么只要DP就可以算出他到每个点的概率。这里DP显然可以用矩阵乘法优化,转移矩阵就是走一步的到某个点概率。

这里关于如何找某个点有多少条直线经过,然后每条线上有哪些点的方法,我的可能不是很优秀。直接 O ( n 2 ) O(n^2) O(n2)求出每条直线的解析式(用了 y = k x + b y=kx+b y=kx+b的表达方式,这样判断是否是同一条直线比较方便,但是需要特判垂直于x轴的直线),排序去重。然后再 O ( n ∗ l ) O(n*l) O(nl)求每个点有多少条直线经过,每条线上有哪些点,存下来(后面找答案要用到)。

然后哪些初始点有可能得到最大的概率呢?第一是选一个不在S中的点,让这个点第一步只走到一条直线上(可以知道走到多条直线上不可能更优),这种相当于选某一条线,走m-1步。第二是选S中的一个点,他可能第一步就有多种走法。

所以现在复杂度是 O ( n 3 + q ∗ n ∗ n 2 ∗ l o g m ) O(n^3+q*n*n^2*logm) O(n3+qnn2logm)(q后面第一个n是枚举起始点, n 2 n^2 n2矩阵乘法)。

然后后面的我并不能自己想出来

然后因为这么求其实非常浪费,并不需要知道到t之外的点的概率。于是考虑倒过来求从t出发倒着跑m步的概率,于是少了一个复杂度n,成功变为 O ( n 3 + q ∗ n 2 ∗ l o g m ) O(n^3+q*n^2*logm) O(n3+qn2logm)

代码:

#include<bits/stdc++.h>
using namespace std;
const int N = 210;
const int E = 16;
const double eps = 1e-7;
int n, ln, q;
struct POINT{
	int x, y;
}p[N];
struct LINE{
	bool fl;
	double k, b;
	void get(POINT u, POINT v){
		if (u.x == v.x){
			fl = 1;
			k = u.x;
			b = 0;
			return;
		}
		fl = 0;
		k = 1.0*(u.y-v.y)/(u.x-v.x);
		b = 1.0*u.y-k*u.x;
	}
	bool operator < (const LINE &l){
		if (fl != l.fl) return fl > l.fl;
		if (abs(l.k-k) > eps) return k < l.k;
		return b < l.b;
	}
	bool operator != (const LINE &l){
		if (fl != l.fl) return true;
		if (fl) return abs(k-l.k) > eps;
		return abs(l.k-k) > eps || abs(l.b-b) > eps;
	}
}l[N*N];
vector<int> vl[N*N], vp[N];
double f[E][N][N], a[N];

void chkMax(double &x, double y){if (x < y) x = y;}

bool onLine(POINT u, LINE l)
{
	if (l.fl) return abs(u.x-l.k) < eps;
	return abs(l.k*u.x+l.b-u.y) < eps;
}

void Mul1(int o)
{
	double b[N];
	memset(b, 0, sizeof(b));
	for (int i = 0; i < n; ++ i)
		for (int j = 0; j < n; ++ j)
			b[i] += a[j]*f[o][i][j];
	for (int i = 0; i < n; ++ i)
		a[i] = b[i];
}

int main()
{
	scanf("%d", &n);
	for (int i = 0; i < n; ++ i)
		scanf("%d%d", &p[i].x, &p[i].y);
	ln = 0;
	for (int i = 0; i < n; ++ i)
		for (int j = i+1; j < n; ++ j)
			l[ln++].get(p[i], p[j]);
	sort(l, l+ln);
	{
		int lln = ln;
		ln = 1;
		for (int i = 1; i < lln; ++ i)
			if (l[i] != l[i-1])
				l[ln++] = l[i];
	}
	// cout << ln << endl;
	// for (int i = 0; i < ln; ++ i) cout << l[i].k << " " << l[i].b << endl;
	// cout << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!" << endl;

	for (int i = 0; i < n; ++ i)
		for (int j = 0; j < ln; ++ j)
			if (onLine(p[i], l[j])){
				// cout << i << " " << j << " " << endl;
				vp[i].push_back(j);
				vl[j].push_back(i);
			}
	// for (int i = 0; i < n; ++ i) cout << vp[i].size() << " "; cout << endl;

	memset(f, 0, sizeof(f));
	for (int i = 0; i < n; ++ i){
		int lln = vp[i].size();
		for (int j = 0; j < lln; ++ j){
			int nowl = vp[i][j];
			for (int k = 0, ppn = vl[nowl].size(); k < ppn; ++ k){
				int nxtp = vl[nowl][k];
				f[0][i][nxtp] += 1.0/lln/ppn;
				// cout << i << " " << nowl << " " << nxtp << " " << ppn << endl;
			}
		}
	}
	// for (int i = 0; i < n; ++ i){
	// 	for (int j = 0; j < n; ++ j)
	// 		cout << f[0][i][j] << " " << endl;
	// 	cout << endl;
	// }

	for (int o = 1; o < E; ++ o){
		for (int k = 0; k < n; ++ k)
			for (int i = 0; i < n; ++ i)
				for (int j = 0; j < n; ++ j){
					f[o][i][j] += f[o-1][i][k]*f[o-1][k][j];
				}
	}

	for (scanf("%d", &q); q; -- q){
		int t, m;
		double ans = 0;
		scanf("%d%d", &t, &m);
		-- t; -- m;
		memset(a, 0, sizeof(a));
		a[t] = 1.0;
		for (int i = 0; i < E; ++ i)
			if ((m>>i)&1){
				Mul1(i);
				// cout << i << " !!!!!!!!!!!" << endl;
			}
		for (int i = 0; i < ln; ++ i){
			double tmp = 0;
			int lln = vl[i].size();
			for (int j = 0; j < lln; ++ j)
				tmp += a[vl[i][j]];
			chkMax(ans, tmp/lln);
		}
		Mul1(0);
		for (int i = 0; i < n; ++ i)
			chkMax(ans, a[i]);
		printf("%.20f\n", ans);
	}
	return 0;
}
 类似资料: