CF627E Orchestra [矩阵计数]

鄂曦之
2023-12-01

也许更好的阅读体验

D e s c r i p t i o n \mathcal{Description} Description

题目大意 有一个 r ∗ c r * c rc的矩阵上有 n n n个点,问有多少个子矩阵里包含至少 k k k个点

输入格式 第一行四个数 r , c , n , k r,c,n,k r,c,n,k,接下来 n n n行,每行两个数 x i , y i x_i,y_i xi,yi,表示第 i i i个点的坐标

输出格式 一行一个数表示合法子矩阵个数

数据范围 1 ≤ r , c , n ≤ 3000 , 1 ≤ k ≤ min ⁡ ( n , 10 ) 1\leq r,c,n\leq 3000 , 1\leq k\leq \min(n,10) 1r,c,n3000,1kmin(n,10)

S o l u t i o n \mathcal{Solution} Solution
这题好难

题目要求至少包含了 k k k个点的子矩阵, k k k很小
朴素的想法是枚举上边界,下边界,然后在中间尺取,这样是 O ( n 3 ) O\left(n^3\right) O(n3)

反过来考虑,从下往上枚举每一行
先算出以这一行的每个顶点为左上角,最下面一行为第 c c c的所有合法矩阵个数
计算方法考虑尺取即可
再考虑最下面一行逐渐往上走的合法矩阵个数

这个过程中,点是在不断变少的
考虑算出将最下面一行变为其上面那一行时合法矩阵的减少量,再用原来的合法矩阵量减去,得到最下面一行往上移一行后的合法矩阵数量

如何计算合法矩阵的减少量
枚举这一行中的每个点,计算出包含这个点且点数刚好为 k k k的矩阵个数,然后删去这个点
计算方法也是尺取,另外,尺取时用一个链表表示该行中下一个点和上一个点的位置,这样尺取复杂度就是该行点数,而不是 m m m
这样就能算出删去这一行中所有点后不合法的矩阵个数
那么往上移一行合法矩阵就会减去这些矩阵

如此计算,总复杂度为
O ( n 2 k + n p k ) O\left(n^2k+npk\right) O(n2k+npk)

如有不懂之处,请参考代码
如仍有不懂之处可在评论提出,博主看到会给予回复
C o d e \mathcal{Code} Code

/*******************************
Author:Morning_Glory
LANG:C++
Created Time:2019年10月15日 星期二 15时37分21秒
*******************************/
#include <cstdio>
#include <fstream>
#include <vector>
#define ll long long
using namespace std;
const int maxn = 3005;
int n,m,p,k;
int s[maxn],t[maxn],last[maxn],nxt[maxn];
ll ans;
vector <int> dot[maxn];
int main()
{
	scanf("%d%d%d%d",&n,&m,&p,&k);//求n行m列p个点至少包含k个点的矩阵个数
	for (int i=1;i<=p;++i){
		int x,y;
		scanf("%d%d",&x,&y);
		dot[x].push_back(y);//每行的点
	}

	for (int u=n;u>=1;--u){
		for (int i=0;i<dot[u].size();++i)	++s[dot[u][i]];//s[i] -> 第n行到第u行上第i列的点数
		int l=1,r=0,num=0;
		ll inc=0;
		while (true){
			while (r+1<=m&&num<k)	num+=s[++r];
			if (num<k)	break;
			while (l<=r&&num>=k)	num-=s[l++],inc+=m-r+1;//inc 合法矩阵个数
		}

		for (int i=1;i<=m;++i)	t[i]=s[i];
		for (int i=0;i<=m+1;++i)	last[i]=i-1,nxt[i]=i+1;
		for (int i=1;i<=m;++i)
			if (!s[i])	last[nxt[i]]=last[i],nxt[last[i]]=nxt[i];

		for (int d=n;d>=u;--d){
			ans+=inc;

			for (int i=0;i<dot[d].size();++i){

				int cur=dot[d][i],sl=t[cur],sr=0;
				l=r=cur;
				while (nxt[r]<=m&&sl+sr+t[nxt[r]]<=k)	sr+=t[r=nxt[r]];

				while (true){
					if (sl+sr==k)	inc-=(l-last[l])*(nxt[r]-r);
					sl+=t[l=last[l]];
					if (!l||sl>k)	break;
					while (sl+sr>k)	sr-=t[r],r=last[r];
				}

				--t[cur];
				if (!t[cur])	last[nxt[cur]]=last[cur],nxt[last[cur]]=nxt[cur];
			}
		}
	}

	printf("%lld\n",ans);
	return 0;
}

如有哪里讲得不是很明白或是有错误,欢迎指正
如您喜欢的话不妨点个赞收藏一下吧

 类似资料: