当前位置: 首页 > 工具软件 > Eden-PHP > 使用案例 >

HDU - 5977 Garden of Eden(树分治 子集合枚举)

贺高飞
2023-12-01

原题: http://acm.hdu.edu.cn/showproblem.php?pid=5977

题意:

给出一棵树,每个节点有值 v i ∈ [ 1 , k ] v_i\in[1,k] vi[1,k],现在要求有多少条路径使得路径上 [ 1 , k ] [1,k] [1,k]所有值都出现至少一次。(如果路径长度不为0贡献为0)

解析:

考虑树分治的思想去做。

对于当前重心,也就是根 G G G,dfs往下得到所有以 G G G为起点的路径。使用数组记录下每条路径的状态(状压,1表示这一位已经存在)。

那么枚举每条路径,对于这个路径的状态 S S S,计入答案的应该是:所有其他路径状态 S ′ S' S满足 S ∣ S ′ = A L L S|S'=ALL SS=ALL

这个怎么做呢?我们可以枚举 S S S的所有子状态 S i S_i Si,加入对应的状态 ( S i ′ = S i    x o r    A L L ) (S_i'=S_i\;xor\;ALL) (Si=SixorALL)的数量即可。

//枚举s的子集合
        for(int j=s;;j=(s&(j-1))){
            ->j
            if(j==0)break;// 空集
        }

当然还要和其他题目一样,减掉从一个儿子往下延伸的非法情况。

代码:

#include<bits/stdc++.h>
using namespace std;
const int N =5e4+5;
int read(){ int ans=0; char last=' ',ch=getchar();
while(ch<'0' || ch>'9')last=ch,ch=getchar();
while(ch>='0' && ch<='9')ans=ans*10+ch-'0',ch=getchar();
if(last=='-')ans=-ans; return ans;
}

int n,k,End;
int head[N],nex[2*N],to[2*N],now;
int bit[N];
void add(int a,int b){
    to[++now]=b;nex[now]=head[a];head[a]=now;
    to[++now]=a;nex[now]=head[b];head[b]=now;
}


int G;
int vis[N];//是否删除
int siz[N],maxn[N];//这棵子树大小,最大子树大小
void getG(int p,int fa,int sum){
    siz[p]=1;
    maxn[p]=0;
    for(int i=head[p];~i;i=nex[i]){
        if(!vis[to[i]]&&to[i]!=fa){
            getG(to[i],p,sum);//这个时候既可以得出G,又可以得出下面的siz
            siz[p]+=siz[to[i]];
            maxn[p]=max(maxn[p],siz[to[i]]);//找出最大子树
        }
    }
    maxn[p]=max(maxn[p],sum-siz[p]);//当然还要这棵树的上半部分比较
    if(maxn[G]>maxn[p])G=p;
}

int sta[N],num,ct[1059];
void dfs(int p,int fa,int preSta){
    preSta|=bit[p];
    sta[++num]=preSta;
    for(int i=head[p];~i;i=nex[i]){
        if(vis[to[i]]||fa==to[i])continue;
        dfs(to[i],p,preSta);
    }
}

#define LL long long
#define mmm(a,b) memset(a,b,sizeof a)

LL ans;
LL cal(int p,int preSta){
    sta[p]=preSta;
    num=0;
    mmm(ct,0);
    dfs(p,0,sta[p]);
    for(int i=1;i<=num;i++){
        ct[sta[i]]++;
    }
    LL A=0;
    for(int i=1;i<=num;i++){
        for(int j=sta[i];j;j=(sta[i]&(j-1))){
            A+=(LL)ct[j^End];
        }
        A+=(LL)ct[End];
    }
    return A;
}

void divide(int p){
    vis[p]=1;
    ans+=cal(p,0);
    for(int i=head[p];~i;i=nex[i]){
        if(vis[to[i]])continue;
        ans-=cal(to[i],bit[p]);
        G=0;//为了初始化
        maxn[0]=siz[to[i]];
        getG(to[i],0,siz[to[i]]);//一套做全都是0
        divide(G);
    }
}


int main(){
    while(cin>>n>>k){
        ans=0;mmm(head,-1);now=0;mmm(vis,0);End=(1<<k)-1;
        for(int i=1;i<=n;i++)bit[i]=(1<<read()-1);
        for(int i=1,a,b;i<n;i++)a=read(),b=read(),add(a,b);
        maxn[0]=n;
        G=0;getG(1,0,n);
        divide(G);
        printf("%lld\n",ans);
    }
}

 类似资料: