Alien 的数列

茅秦斩
2023-12-01

DP好题:Alien 的数列题解

题目背景

Alien 们很迷信,所以对于一个数列,它们如果觉得它不吉利,就要将这个数列进行处理,但处理方式很诡异。

题目描述

对于一个数列 A1, A2, A3 … AN,如果它不是不下降的,那么 Alien们就认为这个是不吉利的。Alien 们要尽力把不吉利的数列修改成为吉利的,可以把这个数列的中的某个数修改为 New,代价是|Ai-New|。 现在它们委托你帮忙修改一下, 你的目的是将它们给出的一个数列改成不下降的,而且代价最小。

输入格式

第一行一个整数 N。
接下来 N 行,每行一个数,表示 Alien 们给出的原数列。

输出格式

输出一行一个数 Ret,表示最小代价。

题目窗口

点此进入,报名即可看题

思路解析

这道题的 O ( n 3 ) O(n^3) O(n3) O ( n 2 ) O(n^2) O(n2) O ( n 2 l o g ( n ) ) O(n^2 log(n)) O(n2log(n)) O ( n l o g ( n ) ) O(nlog(n)) O(nlog(n))做法应该都有。但 O ( n l o g ( n ) ) O(nlog(n)) O(nlog(n))做法对我来说比较困难,大体是在 O ( n 2 ) O(n^2) O(n2)基础上根据图像与数据结构优化的,在这里就不做重点了。
相信大家一看到这题,就会想到DP,而且感觉有些套路。我们可以设 d p [ i ] [ j ] dp[i][j] dp[i][j]为到第 i i i位,且第 i i i位修改为 j j j的满足条件的最小代价。在这里, j j j肯定不是任意一个整数,不然空间时间都会爆。我们可以瞎胡一波结论,就是数值 j j j肯定在原数列中。这是因为在实际操作过程中,一个数要么不变,要么改为与它左边的数或右边的数相同(数列不下降即可)。如果有疑问,你可以假设它不是这么修改的,那么你可以一波微操,以相同的代价按上述方法修改。好了,现在我们的 d p [ i ] [ j ] dp[i][j] dp[i][j]为到第 i i i位,且第 i i i位修改为 a [ j ] a[j] a[j]的满足条件的最小代价。我们可以推出转移方程式: d p [ i ] [ j ] = ∣ a [ i ] − a [ j ] ∣ + m i n ( d p [ i − 1 ] [ k ] ) dp[i][j]= |a[i]-a[j]|+min(dp[i-1][k]) dp[i][j]=a[i]a[j]+min(dp[i1][k]) a [ k ] ≤ a [ j ] a[k] \leq a[j] a[k]a[j]。于是 O ( n 3 ) O(n^3) O(n3)的做法就浮出水面,暴力枚举 i , j , k i,j,k ijk即可。
然而,我们知道这样有一些时间被浪费了,因为有一些 d p [ i − 1 ] [ k ] dp[i-1][k] dp[i1][k]是显然不能转移到 d p [ i ] [ j ] dp[i][j] dp[i][j]的。所以我们可以在每个 i i i循环下,对 d p [ i − 1 ] [ k ] dp[i-1][k] dp[i1][k]按照 a [ k ] a[k] a[k]从小到大进行排序,若 a [ k ] a[k] a[k]相等,则 d p [ i − 1 ] [ k ] dp[i-1][k] dp[i1][k]的值小的排在前面。这样,我们按照 排序后的顺序 枚举 j j j,每次不断取 m i n min min值,当更新到当前 j j j的时候,就把 m i n min min赋值给 d p [ i ] [ j ] dp[i][j] dp[i][j],因为再接下来枚举的都是 a [ j ] a[j] a[j]大于它的,没必要了,就算是等于,我们已经把值小的排在前面了。在代码中,我用了一个临时结构体数组 p p p来实现转移。这样的时间复杂度是 O ( n 2 l o g ( n ) ) O(n^2 log(n)) O(n2log(n))的,跑 n = 5000 n=5000 n=5000时,常规要5 ~ 6s左右,卡常后也要跑3 ~ 4s左右。不过 n = 2000 n=2000 n=2000可以跑1s内。
那么,就继续优化吧。我们可以发现每一次排序都是按 a a a数组大小排的。而 a a a数组从头到尾都没变过。每次枚举 i i i都排一遍有些浪费。我们是否可以事先排好呢?答案是肯定的。突破口还是在 d p [ i ] [ j ] dp[i][j] dp[i][j]的更新顺序上。详细就见代码吧。

代码

O ( n 3 ) O(n^3) O(n3)

//ZJ_MRZ's Code
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<string>
#define N 5010
#define INF 1000000000000
using namespace std;
long long dp[N][N];
long long n,a[N];
int main() {
	scanf("%lld",&n);
	for(int i=1;i<=n;i++)
		scanf("%lld",&a[i]);
	for(int i=1;i<=n;i++) {
		dp[1][i]=abs(a[i]-a[1]);
	}
	for(int i=2;i<=n;i++) {
		for(int j=1;j<=n;j++) {
			dp[i][j]=INF;
			for(int k=1;k<=n;k++) {
				if(a[j]>=a[k]) {
					dp[i][j]=min(dp[i][j],dp[i-1][k]);
				}
			}
			dp[i][j]+=abs(a[i]-a[j]);
		}
	}
	long long ans=INF;
	for(int i=1;i<=n;i++) {
		ans=min(ans,dp[n][i]);
	}
	printf("%lld\n",ans);
	return 0;
}

O ( n 2 l o g ( n ) ) O(n^2log(n)) O(n2log(n))

//ZJ_MRZ's Code
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<string>
#define N 5010
#define INF 1000000000000
using namespace std;
long long dp[N][N];
long long n,a[N];
struct mrz {
	long long aa,dd;
	int id;
} p[N];
bool cmp(mrz k1,mrz k2) {
	if(k1.aa!=k2.aa)
		return k1.aa<k2.aa;
	else
		return k1.dd<k2.dd;
}
int main() {
	scanf("%lld",&n);
	for(int i=1;i<=n;i++)
		scanf("%lld",&a[i]);
	for(int i=1;i<=n;i++) {
		dp[1][i]=abs(a[i]-a[1]);
	}
	for(int i=2;i<=n;i++) {
		for(int j=1;j<=n;j++) {
			p[j].dd=dp[i-1][j];
			p[j].aa=a[j];
			p[j].id=j;
			dp[i][j]=INF;
		}
		sort(p+1,p+n+1,cmp);
		long long minn=INF;
		for(int j=1;j<=n;j++) {
			minn=min(minn,p[j].dd);
			dp[i][p[j].id]=minn;
		}
		for(int j=1;j<=n;j++) {
			dp[i][j]+=abs(a[i]-a[j]);
		}
	}	
	long long ans=INF;
	for(int i=1;i<=n;i++) {
		ans=min(ans,dp[n][i]);
	}
	printf("%lld\n",ans);
	return 0;
}

O ( n 2 ) O(n^2) O(n2)

//ZJ_MRZ's Code
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<string>
#define N 5010
#define INF 1000000000000000
using namespace std;
long long dp[N][N];
long long n,a[N];
struct mrz {
	long long val;
	int id;
} p[N];
inline bool cmp(mrz k1,mrz k2) {
	return k1.val<k2.val;
}
inline long long readl() {
    long long x=0,f=1;char ch=getchar();
    while(ch>'9' || ch<'0') { if(ch=='-') f=-1; ch=getchar(); }
    while(ch>='0' && ch<='9') { x=x*10+ch-'0'; ch=getchar(); }
    return x*f;
}
int main() {
	n=readl();
	for(int i=1;i<=n;i++) {
		a[i]=readl();
		p[i].val=a[i];
		p[i].id=i;
	}
	sort(p+1,p+n+1,cmp);
	for(int i=1;i<=n;i++) {
		dp[1][i]=abs(a[i]-a[1]);
	}
	for(int i=2;i<=n;i++) {
		for(int j=1;j<=n;j++) {
			dp[i][j]=INF;
		}
		long long minn=INF,ret;
		int ls=1;
		for(int j=1;j<=n;j++) {
			ret=minn;
			minn=min(minn,dp[i-1][p[j].id]);
			while(p[j].val>p[ls].val) {
				dp[i][p[ls].id]=ret;
				ls++;
			}
		}
		dp[i][p[n].id]=minn;
		for(int j=1;j<=n;j++) {
			dp[i][j]+=abs(a[i]-a[j]);
		}
	}	
	register long long ans=INF;
	for(int i=1;i<=n;i++) {
		ans=min(ans,dp[n][i]);
	}
	printf("%lld\n",ans);
	return 0;
}
 类似资料: