当前位置: 首页 > 工具软件 > Salad > 使用案例 >

P3564 [POI2014]BAR-Salad Bar(ST表 + 二分)

程胤运
2023-12-01

P3564 [POI2014]BAR-Salad Bar

给定一个长度为 n n n的数组,里面元素只有 1 1 1 − 1 -1 1,问选出一个长度为 l e n len len的区间使得,这个区间的前缀和时刻大于零,后缀和时刻大于零,输出最大长度 l e n len len

考虑枚举 l l l端点,我们可以二分出最大的 r r r,满足 p r e _ s u m pre\_sum pre_sum时刻大于等于零,设为 [ l , r ] [l, r] [l,r]

考虑枚举 R R R端点,我们可以二分出最小的 L L L,满足 s u c _ s u m suc\_sum suc_sum时刻大于等于零,设为 [ L , R ] [L, R] [L,R]

则答案一定是在所有上述点对中的 l , R l, R l,R中的一个,且有 l ≤ R ≤ r l \leq R \leq r lRr L ≤ l ≤ R L \leq l \leq R LlR

假设我们已经把上述满足要求的两种点对都算出来了,考虑新开一个线段树,

我们把第二种点对 [ L , R ] [L, R] [L,R],放进线段树上维护,在 R R R点记录符合要求的最小的 L L L

考虑枚举 [ l , r ] [l, r] [l,r]点对,在区间 [ l , r ] [l, r] [l,r]中寻找一个最大的 R R R,使得于 R R R相对应的 L L L,满足 L ≤ l L \leq l Ll,这个时候我们的答案就是 R − l + 1 R - l + 1 Rl+1的最大值了。

上述操作都利用 S T ST ST表,然后二分一下即可,整体复杂度 O ( n log ⁡ n ) O (n \log n) O(nlogn)

#include <bits/stdc++.h>

using namespace std;

const int N = 1e6 + 10, logn = 20;

int a[N], sum[N], b[N], Log[N], f[N][logn + 1], n;

char str[N];

vector<pair<int, int>> vt, v;

void init() {
  Log[1] = 0, Log[2] = 1;
  for (int i = 3; i < N; i++) {
    Log[i] = Log[i / 2] + 1;
  }
}

int main() {
  // freopen("in.txt", "r", stdin);
  // freopen("out.txt", "w", stdout);
  init();
  scanf("%d %s", &n, str + 1);
  for (int i = 1; i <= n; i++) {
    a[i] = str[i] == 'p' ? 1 : -1;
  }
  for (int i = 1; i <= n; i++) {
    sum[i] = a[i] + sum[i - 1], f[i][0] = sum[i];
  }
  for (int j = 1; j <= logn; j++) {
    for (int i = 1; i + (1 << j) - 1 <= n; i++) {
      f[i][j] = min(f[i][j - 1], f[i + (1 << j - 1)][j - 1]);
    }
  }
  for (int i = 1; i <= n; i++) {
    if (a[i] == -1) {
      continue;
    }
    int l = i, r = n;
    while (l < r) {
      int mid = l + r + 1 >> 1, s = Log[mid - i + 1];
      if (min(f[i][s], f[mid - (1 << s) + 1][s]) >= sum[i - 1]) {
        l = mid;
      }
      else {
        r = mid - 1;
      }
    }
    // printf("%d %d\n", i, l);
    vt.push_back({i, l});
  }
  for (int i = 1; i <= n; i++) {
    sum[i] = a[n - i + 1] + sum[i - 1], f[i][0] = sum[i];
  }
  for (int j = 1; j <= logn; j++) {
    for (int i = 1; i + (1 << j) - 1 <= n; i++) {
      f[i][j] = min(f[i][j - 1], f[i + (1 << j - 1)][j - 1]);
    }
  }
  memset(b, 0x3f, sizeof b);
  for (int i = 1; i <= n; i++) {
    if (a[n - i + 1] == -1) {
      continue;
    }
    int l = i, r = n;
    while (l < r) {
      int mid = l + r + 1 >> 1, s = Log[mid - i + 1];
      if (min(f[i][s], f[mid - (1 << s) + 1][s]) >= sum[i - 1]) {
        l = mid;
      }
      else {
        r = mid - 1;
      }
    }
    b[n - i + 1] = n - l + 1;
    // printf("%d %d\n", n - l + 1, n - i + 1);
  }
  for (int i = 1; i <= n; i++) {
    f[i][0] = b[i];
  }
  for (int j = 1; j <= logn; j++) {
    for (int i = 1; i + (1 << j) - 1 <= n; i++) {
      f[i][j] = min(f[i][j - 1], f[i + (1 << j - 1)][j - 1]);
    }
  }
  int ans = 0;
  for (auto it : vt) {
    int L = it.first, R = it.second;
    int l = it.first, r = it.second;
    while (L < R) {
      // [mid + 1, r]
      int mid = L + R >> 1, s = Log[r - mid];
      if (min(f[mid + 1][s], f[r - (1 << s) + 1][s]) <= l) {
        L = mid + 1;
      }
      else {
        R = mid;
      }
    }
    ans = max(ans, L - l + 1);
  }
  printf("%d\n", ans);
  return 0;
}
 类似资料: