当前位置: 首页 > 工具软件 > ka-map > 使用案例 >

ACM-ICPC 2018 沈阳赛区网络预赛 J Ka Chang 分块

宋正真
2023-12-01

https://nanti.jisuanke.com/t/31451

按每一层（同一深度）的节点个数分块，阈值为 block。
当某个深度的节点个数 > block 时（称为"大层"）：
对每个点暴力预处理其子树中属于该深度的节点个数；
这样的大层最多有 n/block 个，
因此预处理复杂度为 O(n*n/block)。
修改大层时，只需累加记录"该层总共加了多少"，复杂度 O(1)。

当某个深度的节点个数 <= block 时（称为"小层"）：
修改时暴力更新该层的每个节点（最多 block 个）；
用树状数组（BIT）或线段树维护按 dfs 序排列的点权，支持子树区间求和，单次修改 O(block*log n)。

查询时分别统计小层和大层的贡献：
答案 = BIT 上该子树对应 dfs 序区间 [p1, p2] 的和 + Σ（该子树中每个大层的节点个数 × 该大层累计加的值），单次查询 O(log n + n/block)。

#include <iostream>
#include <algorithm>
#include <sstream>
#include <string>
#include <queue>
#include <cstdio>
#include <map>
#include <set>
#include <utility>
#include <stack>
#include <cstring>
#include <cmath>
#include <vector>
#include <ctime>
#include <bitset>
#include <assert.h>
using namespace std;
// ---- Competitive-programming shorthand macros ----
// Every macro argument that appears inside an expression is parenthesized so that
// passing an expression (e.g. lowbit(i + j), all(*p)) expands to the intended code.
#define pb push_back
// scanf shorthands: sd* read int(s), sld* read long long(s), sf* read double(s),
// ss reads a whitespace-delimited C string into `str`.
#define sd(n) scanf("%d",&n)
#define sdd(n,m) scanf("%d%d",&n,&m)
#define sddd(n,m,k) scanf("%d%d%d",&n,&m,&k)
#define sld(n) scanf("%lld",&n)
#define sldd(n,m) scanf("%lld%lld",&n,&m)
#define slddd(n,m,k) scanf("%lld%lld%lld",&n,&m,&k)
#define sf(n) scanf("%lf",&n)
#define sff(n,m) scanf("%lf%lf",&n,&m)
#define sfff(n,m,k) scanf("%lf%lf%lf",&n,&m,&k)
#define ss(str) scanf("%s",str)
// Print a variable that must be named `ans` (int / long long) plus a newline.
#define ansn() printf("%d\n",ans)
#define lansn() printf("%lld\n",ans)
// Counting loops: r0 -> i in [0, n), r1 -> i in [1, e], rn -> i from e down to 1.
#define r0(i,n) for(int i=0;i<(n);++i)
#define r1(i,e) for(int i=1;i<=(e);++i)
#define rn(i,e) for(int i=(e);i>=1;--i)
// memset the whole object/array `abc` to byte value `bca`.
#define mst(abc,bca) memset(abc,bca,sizeof(abc))
// Lowest set bit of a (Fenwick-tree step). The old form (a&(-a)) turned
// lowbit(4+4) into (4+4&(-4+4)) == 0.
#define lowbit(a) ((a)&(-(a)))
// Iterator pair for a container expression.
#define all(a) (a).begin(),(a).end()
#define pii pair<int,int>
#define pll pair<long long,long long>
#define mp(aa,bb) make_pair(aa,bb)
// Segment-tree child indices of node `rt` (parenthesized so lrt+1 is (rt<<1)+1).
#define lrt (rt<<1)
#define rrt (rt<<1|1)
#define X first
#define Y second
#define PI (acos(-1.0))
double pi = acos(-1.0);
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
const ll mod = 1000000007;
const double eps=1e-12;
const int inf=0x3f3f3f3f;
//const ll infl = 100000000000000000;//1e17
// Array capacities used by the global tables below.
const int maxn=  2e5+20;
const int maxm = 5e3+20;
// Template note (unused here); presumably the linear-time modular-inverse recurrence:
//muv[i]=(p-(p/i))*muv[p%i]%p;
// Fast reader for one (optionally negative) decimal int.
// Skips every character that is neither '-' nor a digit, then parses digits.
// @param ret receives the parsed value; left untouched when -1 is returned.
// @param fp  stream to read from; defaults to stdin (the old getchar() behaviour).
// @return 1 on success, -1 if end of input is reached before a number starts.
// Note: no overflow checking is done on very long digit runs.
inline int in(int &ret, FILE *fp = stdin) {
    // int, not char: EOF (-1) is not reliably representable in a plain char.
    int c = getc(fp);
    // Stop at EOF here too; the old loop spun forever on trailing whitespace.
    while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getc(fp);
    if (c == EOF) return -1;
    const bool neg = (c == '-');
    ret = neg ? 0 : (c - '0');
    // The first non-digit (or EOF) terminates the number and is consumed.
    while ((c = getc(fp)) >= '0' && c <= '9') ret = ret * 10 + (c - '0');
    if (neg) ret = -ret;
    return 1;
}
// Depth of each node; node 1 is the root and keeps the zero-initialised depth 0.
int dep[maxn];
// cnt[d] = number of nodes at depth d (filled by dfs()).
int cnt[maxn];
// Adjacency lists: for each input pair "a b", b is appended to g[a] (stored one-way, a -> b).
vector<int>g[maxn];
// had[u][k] = number of nodes in u's subtree whose depth is big[k] (filled by dfs2()).
vector<int>had[maxn];
// Depths with more than `block` nodes, in increasing order (searched with lower_bound).
vector<int>big;
// bi[d] is true iff depth d is in `big`.
bool bi[maxn];
// dfs-order interval of each node's subtree: positions [p1[u], p2[u]].
int p1[maxn],p2[maxn];
// dfs timestamp counter used to assign p1/p2.
int tot;
// Threshold separating "big" depths (> block nodes) from "small" ones.
int block;
// Largest depth seen during dfs().
int mxdep;
// Pre-order traversal of u's subtree using an explicit stack. The previous recursive
// version could overflow the call stack on a path-shaped tree with ~1e5 levels.
// For every node x reached it sets:
//   p1[x]      - pre-order timestamp (++tot),
//   p2[x]      - last timestamp inside x's subtree, so the subtree is [p1[x], p2[x]],
//   dep[child] - dep[x] + 1 (dep[u] itself must already be set by the caller),
// and updates cnt[dep[x]] and mxdep. Children are visited in g[] order, exactly
// as the recursive version did.
// @param u root of the traversal
// @param f parent of u (unused; kept for interface compatibility)
void dfs(int u,int f) {
    (void)f;
    // Frame = (node, index of next child to visit); -1 means "not entered yet".
    std::vector<std::pair<int, int> > stk;
    stk.push_back(std::make_pair(u, -1));
    while (!stk.empty()) {
        const int x = stk.back().first;
        int &nxt = stk.back().second;
        if (nxt < 0) {
            // First visit: pre-order bookkeeping.
            p1[x] = ++tot;
            ++cnt[dep[x]];
            mxdep = std::max(mxdep, dep[x]);
            nxt = 0;
        } else if (nxt < (int)g[x].size()) {
            const int v = g[x][nxt++];
            dep[v] = dep[x] + 1;
            // push_back may invalidate `nxt`; it is not used after this point.
            stk.push_back(std::make_pair(v, -1));
        } else {
            // All children done: close x's dfs-order interval.
            p2[x] = tot;
            stk.pop_back();
        }
    }
}
// sml[d] = nodes at a "small" depth d (at most `block` of them); updates to such a
// depth are applied node by node to the Fenwick tree.
vector<int>sml[maxn];
void dfs2(int u,int f) {
    int sz = g[u].size();
    int udep = dep[u];
    int bs = big.size();
    had[u].resize(bs,0);
    for(int i=0; i<sz; ++i) {
        int v = g[u][i];
        dfs2(v,u);
        for(int j=0; j<bs; ++j)had[u][j] += had[v][j];
    }
    if(bi[udep]) {
        int pos = lower_bound(all(big),udep) - big.begin();
        ++had[u][pos];
    } else sml[udep].pb(u);
}
// add[d] = total value added so far to every node at big depth d (kept lazily,
// never pushed into the Fenwick tree).
ll add[maxn];
// Fenwick tree over dfs-order positions 1..n holding the values of small-depth nodes.
ll bit[maxn];
// Number of nodes; also the Fenwick tree size.
int n;
// Fenwick point update: add v at 1-based position x, walking up through every
// node of `bit` whose range covers x, until past n.
void upd(int x,int v) {
    while (x <= n) {
        bit[x] += v;
        x += x & -x;   // jump to the next covering node
    }
}
// Fenwick prefix sum: returns the total of positions 1..x (0 when x == 0).
ll query(int x) {
    ll sum = 0;
    while (x != 0) {
        sum += bit[x];
        x &= x - 1;    // drop the lowest set bit
    }
    return sum;
}
int main() {
#ifdef LOCAL
    freopen("input.txt","r",stdin);
//    freopen("output.txt","w",stdout);
#endif // LOCAL

    int q;
    sdd(n,q);
    block = sqrt(n/log2(n));
    r0(i,n-1) {
        int a,b;
        sdd(a,b);
        g[a].pb(b);
    }
    dfs(1,0);
    for(int i=0; i<=mxdep; ++i)
        if(cnt[i]>block)bi[i] = 1,big.pb(i);
    dfs2(1,0);
    for(; q--;) {
        int op;
        sd(op);
        if(op&1) {
            int x,v;
            sdd(x,v);
            if(bi[x])add[x] += v;
            else {
                int sz = sml[x].size();
                for(int i=0; i<sz; ++i) {
                    int id = sml[x][i];
                    int p = p1[id];
                    upd(p,v);
                }
            }
        } else {
            int x;
            sd(x);
            ll ans = query(p2[x]) - query(p1[x]-1);
            int bs = big.size();
            r0(i,bs) {
                int cnt = had[x][i];
                ll v = add[big[i]];
                ans += v*cnt;
            }
            lansn();
        }
    }
    return 0;
}
 类似资料: