原题: http://acm.hdu.edu.cn/showproblem.php?pid=5977
题意:
给出一棵树,每个节点有值 v i ∈ [ 1 , k ] v_i\in[1,k] vi∈[1,k],现在要求有多少条路径使得路径上 [ 1 , k ] [1,k] [1,k]所有值都出现至少一次。(如果路径长度不为0贡献为0)
解析:
考虑树分治的思想去做。
对于当前重心,也就是根 G G G,dfs往下得到所有以 G G G为起点的路径。使用数组记录下每条路径的状态(状压,1表示这一位已经存在)。
那么枚举每条路径,对于这个路径的状态 S S S,计入答案的应该是:所有其他路径状态 S ′ S' S′满足 S ∣ S ′ = A L L S|S'=ALL S∣S′=ALL。
这个怎么做呢?我们可以枚举 S S S的所有子状态 S i S_i Si,加入对应的状态 ( S i ′ = S i x o r A L L ) (S_i'=S_i\;xor\;ALL) (Si′=SixorALL)的数量即可。
//枚举s的子集合
for(int j=s;;j=(s&(j-1))){
->j
if(j==0)break;// 空集
}
当然还要和其他题目一样,减掉从一个儿子往下延伸的非法情况。
代码:
#include<bits/stdc++.h>
using namespace std;
const int N =5e4+5;
int read(){ int ans=0; char last=' ',ch=getchar();
while(ch<'0' || ch>'9')last=ch,ch=getchar();
while(ch>='0' && ch<='9')ans=ans*10+ch-'0',ch=getchar();
if(last=='-')ans=-ans; return ans;
}
int n,k,End;
int head[N],nex[2*N],to[2*N],now;
int bit[N];
void add(int a,int b){
to[++now]=b;nex[now]=head[a];head[a]=now;
to[++now]=a;nex[now]=head[b];head[b]=now;
}
int G;
int vis[N];//是否删除
int siz[N],maxn[N];//这棵子树大小,最大子树大小
void getG(int p,int fa,int sum){
siz[p]=1;
maxn[p]=0;
for(int i=head[p];~i;i=nex[i]){
if(!vis[to[i]]&&to[i]!=fa){
getG(to[i],p,sum);//这个时候既可以得出G,又可以得出下面的siz
siz[p]+=siz[to[i]];
maxn[p]=max(maxn[p],siz[to[i]]);//找出最大子树
}
}
maxn[p]=max(maxn[p],sum-siz[p]);//当然还要这棵树的上半部分比较
if(maxn[G]>maxn[p])G=p;
}
int sta[N],num,ct[1059];
void dfs(int p,int fa,int preSta){
preSta|=bit[p];
sta[++num]=preSta;
for(int i=head[p];~i;i=nex[i]){
if(vis[to[i]]||fa==to[i])continue;
dfs(to[i],p,preSta);
}
}
#define LL long long
#define mmm(a,b) memset(a,b,sizeof a)
LL ans;
LL cal(int p,int preSta){
sta[p]=preSta;
num=0;
mmm(ct,0);
dfs(p,0,sta[p]);
for(int i=1;i<=num;i++){
ct[sta[i]]++;
}
LL A=0;
for(int i=1;i<=num;i++){
for(int j=sta[i];j;j=(sta[i]&(j-1))){
A+=(LL)ct[j^End];
}
A+=(LL)ct[End];
}
return A;
}
void divide(int p){
vis[p]=1;
ans+=cal(p,0);
for(int i=head[p];~i;i=nex[i]){
if(vis[to[i]])continue;
ans-=cal(to[i],bit[p]);
G=0;//为了初始化
maxn[0]=siz[to[i]];
getG(to[i],0,siz[to[i]]);//一套做全都是0
divide(G);
}
}
int main(){
while(cin>>n>>k){
ans=0;mmm(head,-1);now=0;mmm(vis,0);End=(1<<k)-1;
for(int i=1;i<=n;i++)bit[i]=(1<<read()-1);
for(int i=1,a,b;i<n;i++)a=read(),b=read(),add(a,b);
maxn[0]=n;
G=0;getG(1,0,n);
divide(G);
printf("%lld\n",ans);
}
}