给定一个长度为 n n n的数组,里面元素只有 1 1 1跟 − 1 -1 −1,问选出一个长度为 l e n len len的区间使得,这个区间的前缀和时刻大于零,后缀和时刻大于零,输出最大长度 l e n len len,
考虑枚举 l l l端点,我们可以二分出最大的 r r r,满足 p r e _ s u m pre\_sum pre_sum时刻大于等于零,设为 [ l , r ] [l, r] [l,r],
考虑枚举 R R R端点,我们可以二分出最小的 L L L,满足 s u c _ s u m suc\_sum suc_sum时刻大于等于零,设为 [ L , R ] [L, R] [L,R],
则答案一定是在所有上述点对中的 l , R l, R l,R中的一个,且有 l ≤ R ≤ r l \leq R \leq r l≤R≤r, L ≤ l ≤ R L \leq l \leq R L≤l≤R,
假设我们已经把上述满足要求的两种点对都算出来了,考虑新开一个线段树,
我们把第二种点对 [ L , R ] [L, R] [L,R],放进线段树上维护,在 R R R点记录符合要求的最小的 L L L,
考虑枚举 [ l , r ] [l, r] [l,r]点对,在区间 [ l , r ] [l, r] [l,r]中寻找一个最大的 R R R,使得于 R R R相对应的 L L L,满足 L ≤ l L \leq l L≤l,这个时候我们的答案就是 R − l + 1 R - l + 1 R−l+1的最大值了。
上述操作都利用 S T ST ST表,然后二分一下即可,整体复杂度 O ( n log n ) O (n \log n) O(nlogn)。
#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 10, logn = 20;
int a[N], sum[N], b[N], Log[N], f[N][logn + 1], n;
char str[N];
vector<pair<int, int>> vt, v;
void init() {
Log[1] = 0, Log[2] = 1;
for (int i = 3; i < N; i++) {
Log[i] = Log[i / 2] + 1;
}
}
int main() {
// freopen("in.txt", "r", stdin);
// freopen("out.txt", "w", stdout);
init();
scanf("%d %s", &n, str + 1);
for (int i = 1; i <= n; i++) {
a[i] = str[i] == 'p' ? 1 : -1;
}
for (int i = 1; i <= n; i++) {
sum[i] = a[i] + sum[i - 1], f[i][0] = sum[i];
}
for (int j = 1; j <= logn; j++) {
for (int i = 1; i + (1 << j) - 1 <= n; i++) {
f[i][j] = min(f[i][j - 1], f[i + (1 << j - 1)][j - 1]);
}
}
for (int i = 1; i <= n; i++) {
if (a[i] == -1) {
continue;
}
int l = i, r = n;
while (l < r) {
int mid = l + r + 1 >> 1, s = Log[mid - i + 1];
if (min(f[i][s], f[mid - (1 << s) + 1][s]) >= sum[i - 1]) {
l = mid;
}
else {
r = mid - 1;
}
}
// printf("%d %d\n", i, l);
vt.push_back({i, l});
}
for (int i = 1; i <= n; i++) {
sum[i] = a[n - i + 1] + sum[i - 1], f[i][0] = sum[i];
}
for (int j = 1; j <= logn; j++) {
for (int i = 1; i + (1 << j) - 1 <= n; i++) {
f[i][j] = min(f[i][j - 1], f[i + (1 << j - 1)][j - 1]);
}
}
memset(b, 0x3f, sizeof b);
for (int i = 1; i <= n; i++) {
if (a[n - i + 1] == -1) {
continue;
}
int l = i, r = n;
while (l < r) {
int mid = l + r + 1 >> 1, s = Log[mid - i + 1];
if (min(f[i][s], f[mid - (1 << s) + 1][s]) >= sum[i - 1]) {
l = mid;
}
else {
r = mid - 1;
}
}
b[n - i + 1] = n - l + 1;
// printf("%d %d\n", n - l + 1, n - i + 1);
}
for (int i = 1; i <= n; i++) {
f[i][0] = b[i];
}
for (int j = 1; j <= logn; j++) {
for (int i = 1; i + (1 << j) - 1 <= n; i++) {
f[i][j] = min(f[i][j - 1], f[i + (1 << j - 1)][j - 1]);
}
}
int ans = 0;
for (auto it : vt) {
int L = it.first, R = it.second;
int l = it.first, r = it.second;
while (L < R) {
// [mid + 1, r]
int mid = L + R >> 1, s = Log[r - mid];
if (min(f[mid + 1][s], f[r - (1 << s) + 1][s]) <= l) {
L = mid + 1;
}
else {
R = mid;
}
}
ans = max(ans, L - l + 1);
}
printf("%d\n", ans);
return 0;
}