http://acm.hdu.edu.cn/showproblem.php?pid=5977
点分治,对于每个重心u计算出经过u且能构成当前需要的满状态ss的方案数。
对于每条到u的路径,我们得到这条路径上有的种类的状态压缩的数字a[i]
然后枚举a[i]的每个子集(可以在calc中枚举,也可以在外面枚举存在vector)) s0,sum[s0]++,说明a[i]是包含s0的,s0是a[i]的子集,sum[s0]就是包含s0的路径有多少条,这样就可以计算出所需方案数了。
注意我们要减去同一棵子树中的非法组合,而对于u的所有子节点v来说,v的路径能组成(mi[k+1]-1)^(mi[mp[u]])也就是默认带上u节点上的苹果,所能凑出所有苹果的路径组合的方案数,就是子树v对于u来说的非法组合数。
O(nlogn*2^k)的复杂度,虽然很大,不过其实子集数量其实是比较少的,且递归越深子集数量越少,基本卡不掉。
#include<bits/stdc++.h>
using namespace std;
const int maxl=5e4+10;
int n,k,m,cnt,len;
long long ans;
int a[maxl],ehead[maxl],typ[maxl],mi[12],dy[1024];
int st[maxl],sum[maxl];
struct ed
{
int to,nxt;
}e[maxl<<1];
vector <int> ziji[1024];
bool vis[maxl];
struct centertree
{
int n,ans,mini;
int son[maxl];
inline void dfs(int u,int fa)
{
son[u]=1;int mx=0,v;
for(int i=ehead[u];i;i=e[i].nxt)
{
v=e[i].to;
if(v==fa || vis[v]) continue;
dfs(v,u);
son[u]+=son[v];
mx=max(son[v],mx);
}
mx=max(mx,n-son[u]);
if(mx<mini)
{
mini=mx;
ans=u;
}
}
inline int getcenter(int u)
{
ans=0;mini=2e9;
dfs(u,0);
return ans;
}
}tree;
inline void add(int u,int v)
{
e[++cnt].to=v;e[cnt].nxt=ehead[u];ehead[u]=cnt;
}
inline void prework()
{
for(int i=1;i<=n;i++)
ehead[i]=0,vis[i]=false;
for(int i=1;i<=n;i++)
scanf("%d",&typ[i]);
int u,v;cnt=0;
for(int i=1;i<n;i++)
{
scanf("%d%d",&u,&v);
add(u,v);add(v,u);
}
}
inline void getst(int u,int fa)
{
a[++len]=st[u];
int v;
for(int i=ehead[u];i;i=e[i].nxt)
{
v=e[i].to;
if(v==fa || vis[v]) continue;
st[v]=st[u] | mi[typ[v]];
getst(v,u);
}
}
inline long long calc(int u,int ss)
{
len=0;
st[u]=mi[typ[u]];
getst(u,0);
for(int i=0;i<mi[k+1];i++)
sum[i]=0;
long long ret=0;
for(int i=1;i<=len;i++)
for(auto s0 : ziji[a[i]])
sum[s0]++;
sum[0]=len;
int s;
for(int i=1;i<=len;i++)
{
s=ss&a[i];
ret+=sum[ss^s];
}
return ret;
}
inline void solv(int u)
{
vis[u]=true;
ans+=calc(u,mi[k+1]-1);
int v,rt;
for(int i=ehead[u];i;i=e[i].nxt)
{
v=e[i].to;
if(vis[v]) continue;
ans-=calc(v,(mi[k+1]-1)^mi[typ[u]]);
tree.n=tree.son[v];
rt=tree.getcenter(v);
solv(rt);
}
}
inline void mainwork()
{
ans=0;tree.n=n;
int rt=tree.getcenter(1);
solv(rt);
}
inline void print()
{
printf("%lld\n",ans);
}
int main()
{
for(int i=1;i<=11;i++)
mi[i]=1<<(i-1),dy[mi[i]]=i;
for(int i=1;i<1024;i++)
for(int s0=i;s0;s0=(s0-1)&i)
ziji[i].push_back(s0);
while(~scanf("%d%d",&n,&k))
{
prework();
mainwork();
print();
}
return 0;
}