直径的求法:两遍搜索 (BFS or DFS)
任选一点w为起点,对树进行搜索,找出离w最远的点u。
以u为起点,再进行搜索,找出离u最远的点v。则u到v的路径长度即为树的直径。
简单证明:
如果w在直径上,那么u一定是直径的一个端点。反证:若u不是端点,则从直径另一端点到w再到u的距离比直径更长,与假设矛盾。
如果w不在直径上,且w到其距最远点u的路径与直径一定有一交点c,那么由上一个证明可知,u是直径的一个端点。
如果w到最远点u的路径与直径没有交点,设直径的两端为S与T,那么(w->u)>(w->c)+(c->T),推出(w->u)+(S->c)+(w->c)>(S->c)+(c->T)=(S->T)与假设矛盾。
因此w到最远点u的路径与直径必有交点。
附上bfs的模板代码:
#include <queue>//树的直径bfs
#include <cstdio>
#include <cstring>
using namespace std;
const int maxn = 1e5+7;
const int maxm = 1e5+7;
struct node{
int id;
int d;
int Next;
}side[maxm];
int head[maxn];
int cnt = 0;
void Init()
{
memset(head,-1,sizeof(head));
cnt = 0;
}
void add(int x,int y,int d)
{
side[cnt].id = y;
side[cnt].d = d;
side[cnt].Next = head[x];
head[x] = cnt++;
}
struct Node{
int id;
int d;
};
int mk[maxn];//标记已经遍历的点
int bfs(int x,int &d)//每个点遍历一次
{
memset(mk,0,sizeof(mk));
queue <Node> q;
Node tmp;
tmp.id = x;
tmp.d = 0;
q.push(tmp);
mk[x] = 1;
int mx = 0,ans = x;
while(q.size())
{
tmp = q.front();
q.pop();
for(int i = head[tmp.id]; i!=-1; i=side[i].Next)
{
int y = side[i].id;
if(mk[y]) continue;
mk[y] = 1;
Node nd;
nd.id = y;
nd.d = tmp.d+side[i].d;
if(nd.d > mx)
{
mx = nd.d;
ans = y;
}
q.push(nd);
}
}
d = mx;//记录最远距离
return ans;//返回距离最远处的那个点
}
int main()
{
Init();
int x,y,w;
while(scanf("%d%d%d",&x,&y,&w)!=EOF)
{
add(x,y,w);
add(y,x,w);
}
int d;
int u = bfs(1,d);//返回的最远点
int v = bfs(u,d);//两部即可完成
printf("%d\n",d);
return 0;
}
附上dfs的模板代码:
#include<iostream>//树的直径dfs 方法和bfs类似
#include<cstdio>
#include<cstring>
#include<algorithm>
#define MAX 100000
using namespace std;
struct node{
int id;
int d;
int next;
}side[400000];
int head[100000];
int vis[100000];
int n,m,f,ans;
int cnt=0;
void add(int x,int y,int w)
{
side[cnt].id=y;
side[cnt].d=w;
side[cnt].next=head[x];
head[x]=cnt++;
}//存图
void Init()
{
memset(head,-1,sizeof(head));
cnt=0;
}
void dfs(int x,int num)
{
vis[x]=1;
if(num>ans)
{
ans=num;
f=x;
}
for(int i=head[x];i!=-1;i=side[i].next)
{
if(!vis[side[i].id])
dfs(side[i].id,num+side[i].d);
}
}
int main()
{
Init();
scanf("%d",&n);
m=n-1;
while(m--)
{
int x,y,w;
scanf("%d%d",&x,&y);
add(x,y,1);
add(y,x,1);
}
ans=0;
memset(vis,0,sizeof(vis));
dfs(1,0);
memset(vis,0,sizeof(vis));
dfs(f,0);
printf("%d\n",ans);
return 0;
}