(以下讲解的下标均从1开始)
给出一个数组 a[1~n],a[i] = 1或2, 求每个区间的众数之和。
我们假设 two[i] 为 a[1~i] 的“2”的数量,one[i] 为 a[1~i] 的“1”的数量,假设m为众数为2的区间数量,那么有下面的公式:
其中[...]表示如果括号内的条件满足,则=1,否则=0.
这个式子可以转换为:
设一个数组 diff[i] = two[i] - one[i]。那么式子再写成:
可以看到这个式子非常的熟悉,可能有人会想到一维偏序什么的。
首先,-n <= diff[i] <= n。我们不喜欢负数的存在,由于比较大小时两边都加同一个数没啥毛病,考虑每个diff[i]都 +(n+1),这样就都是正数了。
回到正题,如果暴力求解这个式子的话,显然要用O(n^2)的查询。我们希望遍历r时,可以马上获取数组中在r左边,且 < diff[r]的数的数量,这里,我当时就用的线段树。
(不知道线段树的朋友可以先去了解一下)
建立一个线段树tree,tree[i] 表示树的第i号节点,我们先假设其表示的区间为[left,right]。则tree[i]的值就是diff[1~r-1]的所有值处于[left,right]的元素数量。比如diff[] = {3,5,4,7,1, ... },[left,right] = [1,4],假设r此时为5.则tree[i]就是3,5,4,7(下标<r)中位于[1,4]的数:3,4,它们数量的和:2。则tree[i] = 2.
假设我们处理了diff[1~r-1],把它们写到了线段树里,现在轮到diff[r]了。
我们要判断diff[1~r-1]中有多少<diff[r]的,怎么办?我们就去线段树中去找在区间[left,right] = [0,diff[r]-1]内的tree值。这个[0,diff[r]-1]是可以用线段树的性质拼接出来的,对每个组成其一部分的子区间的tree值求和即可得到答案,对应的增加众数为2的区间数量的贡献。
查找完之后我们还要将diff[r]的值加到线段树里,这样在继续找diff[r+1,n]的答案时可以把diff[r]的贡献算进去。
最后,由于总区间个数为sum = n*(n+1)/2,众数为2的区间个数为m,则众数为1的区间个数为p = sum - m,答案为m*2 + p。
时间复杂度O(nlogn),其中线段树查询和单点修改为O(logn),做n次
空间复杂度O(n)
#include <iostream> #include <set> #include <algorithm> #include <vector> #include <cstring> using namespace std; using ll = long long; const int N = (int)2e5+5; ll n; ll a[200005]; ll tree[N << 3]; // 开大点没关系 ll diff[200005]; // 将数值x放入线段树,此时所在的区间为[l,r],线段树下标为id void push(ll x,int l,int r,int id){ if(l >= r){ tree[id]++; return; } int mid = (l+r)>>1; if(x <= mid){ push(x,l,mid,id<<1); } else{ push(x,mid+1,r,(id<<1)+1); } tree[id] = tree[id<<1] + tree[(id<<1)+1]; } // 求<x的diff数量 ll getsum(ll x,int l,int r,int id){ if(l >= r){ if(x <= l) return 0; return tree[id]; } ll res = 0; int mid = (l+r)>>1; if(x > mid){ res += tree[id<<1]; res += getsum(x,mid+1,r,(id<<1)+1); } else{ res += getsum(x,l,mid,(id<<1)); } return res; } void solve(){ ll m = 0; for(int i = 1;i <= n;++i){ cin >> a[i]; // 1 <= a[i] <= 2 diff[i] = diff[i-1] + (a[i] == 2 ? 1 : -1); } for(int i = 0;i <= n;++i){ diff[i] += n+1; // 保证所有数为正数 } push(diff[0],0,2*n+1,1); // 由于l-1可以为0,因此先把diff[0]放进去 for(int i = 1;i <= n;++i){ m += getsum(diff[i],0,2*n+1,1); // 从最大的区间[0,2n+1]开始递归查询 push(diff[i],0,2*n+1,1); // 单点修改,其所在的区间的值都要+1 } ll p = n*(n+1)/2 - m; ll ans = m*2 + p; cout << ans << "\n"; } int main() { ios::sync_with_stdio(false); cin.tie(nullptr); cout.tie(nullptr); while(cin >> n){ solve(); } return 0; }