题意:在一棵n个节点的树上修改一个节点的值或求一条链上所以节点的mex
题解:树上带修改莫队
直接上。不解释。
代码:
#include<bits/stdc++.h>
using namespace std;
int n,m,num=0,fst[50010],a[50010],nn=0,blocksize,tn=0,qn=0,t=0,cnt[50010],blocksize2,size[250],last[50010],aa[50010];
struct edge
{
int x,y,n;
}e[100010];
struct pnt
{
int dep,f[17],id,block;
}p[50010];
struct query
{
int x,y,id,t,ans;
}q[50010];
struct mod
{
int x,y,last;
}d[50010];
bool in[50010];
void ins(int x,int y)
{
e[++num]={x,y,fst[x]};
fst[x]=num;
}
int cmp(query x,query y)
{
return p[x.x].block==p[y.x].block?p[x.y].id<p[y.y].id:p[x.x].id<p[y.x].id;
}
void dfs(int x,int f)
{
p[x].dep=p[f].dep+1;
p[x].f[0]=f;
for(int i=1;i<17;i++)
p[x].f[i]=p[p[x].f[i-1]].f[i-1];
p[x].id=++nn;
p[x].block=p[x].id/blocksize;
for(int i=fst[x];i;i=e[i].n)
{
int y=e[i].y;
if(y==f)
continue;
dfs(y,x);
}
}
int getlca(int x,int y)
{
if(p[x].dep<p[y].dep)
swap(x,y);
for(int i=16;i>=0;i--)
if((1<<i)<=p[x].dep-p[y].dep)
x=p[x].f[i];
if(x==y)
return x;
for(int i=16;i>=0;i--)
if(p[x].f[i]!=p[y].f[i])
{
x=p[x].f[i];
y=p[y].f[i];
}
return p[x].f[0];
}
void add(int x,int y)
{
if(x>n)
return;
if(cnt[x]==0&&y==1)
size[x/blocksize2]++;
if(cnt[x]==1&&y==-1)
size[x/blocksize2]--;
cnt[x]+=y;
}
void chg(int x,int y)
{
if(in[x])
{
add(a[x],-1);
add(y,1);
}
a[x]=y;
}
int getans()
{
for(int i=0;;i++)
if(size[i]<blocksize2)
{
for(int j=i*blocksize2;;j++)
if(!cnt[j])
return j;
}
}
void move(int x,int y)
{
int lca=getlca(x,y);
while(x!=lca)
{
if(in[x])
add(a[x],-1);
else
add(a[x],1);
in[x]^=1;
x=p[x].f[0];
}
while(y!=lca)
{
if(in[y])
add(a[y],-1);
else
add(a[y],1);
in[y]^=1;
y=p[y].f[0];
}
}
int cmp2(query x,query y)
{
return x.id<y.id;
}
int main()
{
scanf("%d%d",&n,&m);
blocksize=int(ceil(pow(n,2.0/3)));
blocksize2=int(sqrt(n));
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
aa[i]=a[i];
}
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
ins(x,y);
ins(y,x);
}
dfs(1,0);
for(int i=0;i<m;i++)
{
int op,x,y;
scanf("%d%d%d",&op,&x,&y);
if(op==0)
{
d[++tn]={x,y,last[x]};
last[x]=tn;
}
else
{
q[qn++]={x,y,i,tn,0};
}
}
sort(q,q+qn,cmp);
int x=1,y=1;
for(int i=0;i<qn;i++)
{
while(t<q[i].t)
{
t++;
chg(d[t].x,d[t].y);
}
while(t>q[i].t)
{
int last=d[t].last;
if(last==0)
chg(d[t].x,aa[d[t].x]);
else
chg(d[t].x,d[last].y);
t--;
}
move(x,q[i].x);
move(y,q[i].y);
int lca=getlca(q[i].x,q[i].y);
add(a[lca],1);
q[i].ans=getans();
add(a[lca],-1);
x=q[i].x;
y=q[i].y;
}
sort(q,q+qn,cmp2);
for(int i=0;i<qn;i++)
printf("%d\n",q[i].ans);
}