给你一个序列,你要把它划分成 m 个连续的段,以最小化这个东西:
把每一段的数和表示为 w[i],则要最小化每个 (w[i]+p)^2 的和。
首先你发现这个
p
p
p 是没有关系的,你完全可以把它拆开来。
变成
p
2
∗
m
+
2
∗
s
n
∗
p
p^2*m+2*s_{n}*p
p2∗m+2∗sn∗p。
(
s
i
s_i
si 是前缀和,下同)
首先我们考虑 DP:
f
i
,
j
f_{i,j}
fi,j 为当前第
i
i
i 个数,划分成了
m
m
m 段落。
f
i
,
j
=
min
k
=
0
i
−
1
f
k
,
j
−
1
+
(
s
i
−
s
k
+
p
)
2
f_{i,j}=\min\limits_{k=0}^{i-1} f_{k,j-1}+(s_i-s_{k}+p)^2
fi,j=k=0mini−1fk,j−1+(si−sk+p)2
然后我们考虑能不能斜率优化(看着就很斜率优化)。
显然这里我们可以做
m
m
m 次斜率优化来达到
O
(
n
m
)
O(nm)
O(nm) 的复杂度。
(所以下面就不写第二维了,反正都是从
j
−
1
j-1
j−1 转移到
j
j
j,比较的都是
j
−
1
j-1
j−1 的)
k
1
k_1
k1 比
k
2
k_2
k2 优:(
k
1
>
k
2
k_1>k_2
k1>k2)
f
k
1
+
(
s
i
−
s
k
1
)
2
<
f
k
2
+
(
s
i
−
s
k
2
)
2
f_{k_1}+(s_i-s_{k_1})^2<f_{k_2}+(s_i-s_{k_2})^2
fk1+(si−sk1)2<fk2+(si−sk2)2
f
k
1
+
s
k
1
2
−
2
s
i
s
k
1
<
f
k
2
+
s
k
2
2
−
2
s
i
s
k
2
f_{k_1}+s_{k_1}^2-2s_is_{k_1}<f_{k_2}+s_{k_2}^2-2s_is_{k_2}
fk1+sk12−2sisk1<fk2+sk22−2sisk2
s
i
(
2
s
k
2
−
2
s
k
1
)
<
(
f
k
2
+
s
k
2
2
)
−
(
f
k
1
+
s
k
1
2
)
s_i(2s_{k_2}-2s_{k_1})<(f_{k_2}+s_{k_2}^2)-(f_{k_1}+s_{k_1}^2)
si(2sk2−2sk1)<(fk2+sk22)−(fk1+sk12)
s
i
>
(
f
k
2
+
s
k
2
2
)
−
(
f
k
1
+
s
k
1
2
)
2
s
k
2
−
2
s
k
1
s_i>\dfrac{(f_{k_2}+s_{k_2}^2)-(f_{k_1}+s_{k_1}^2)}{2s_{k_2}-2s_{k_1}}
si>2sk2−2sk1(fk2+sk22)−(fk1+sk12)
然后我们考虑如何优化掉
j
j
j 的那一维。
那这个时候就有一个神奇的东西叫做 wqs 二分。
就是我们考虑转移的时候加上一个费用
C
C
C,这个费用我们二分。
那这个费用是为了干什么呢,可以说是为了“平衡”,使得恰好选到
m
m
m 个段的时候最优。
然后最后我们再减去这个贡献即可,减去
C
∗
m
C*m
C∗m。
然后你会发现斜率优化 DP 的转移没有变。
#include<cstdio>
#include<vector>
#define ll long long
using namespace std;
ll n, m, p, a[100501], sta[100501];
ll f[100501], sz[100501];
ll Y(ll x) {
return f[x] + a[x] * a[x];
}
ll X(ll x) {
return 2 * a[x];
}
bool check(ll C) {
f[0] = 0; sz[0] = 0;
f[1] = a[1] * a[1] + C; sz[1] = 1;
ll l = 1, r = 2;
sta[1] = 0; sta[2] = 1;
for (ll i = 2; i <= n; i++) {
while (l < r && Y(sta[l]) - Y(sta[l + 1]) > a[i] * (X(sta[l]) - X(sta[l + 1])))
l++;
f[i] = f[sta[l]] + (a[i] - a[sta[l]]) * (a[i] - a[sta[l]]) + C;
sz[i] = sz[sta[l]] + 1;
while (l < r && (Y(i) - Y(sta[r])) * (X(sta[r]) - X(sta[r - 1])) < (Y(sta[r]) - Y(sta[r - 1])) * (X(i) - X(sta[r])))
r--;
sta[++r] = i;
}
return sz[n] <= m;
}
int main() {
// freopen("divide.in", "r", stdin);
// freopen("divide.out", "w", stdout);
scanf("%lld %lld %lld", &n, &m, &p);
for (ll i = 1; i <= n; i++) scanf("%lld", &a[i]), a[i] = a[i - 1] + a[i];
ll l = 0, r = 1e18, re = 0;
while (l <= r) {
ll mid = (l + r) >> 1;
if (check(mid)) re = mid, r = mid - 1;
else l = mid + 1;
}
check(re);
printf("%lld", f[n] - re * m + m * p * p + 2 * a[n] * p);
return 0;
}