现在的位置: 首页 > 综合 > 正文

URAL 1471(lca tarjan算法)

2018年01月14日 ⁄ 综合 ⁄ 共 2396字 ⁄ 字号 评论关闭

题意:给定一棵树,查询时给定两个点,求出两个点的距离。

暴力做肯定超时的。我的做法是采用lca(最近公共祖先)的离线算法,即tarjan算法(据说Tarjan提出了很多算法,可能还有很多tarjan算法),算法里用到了并查集。在输入完所有查询之后,在求出答案。tarjan算法的做法是:一开始vis数组初始化为0,从树根开始递归往下对点进行染色,刚到一个点的时候将vis取为-1,在继续递归;遍历完子节点返回之后vis变为1。在vis变为1之前,检索一下当前节点的所有查询,设查询中的另外一个节点为To,如果vis[To]==0,就continue,因为To还没有处理,不知道它的信息;如果vis[To]==-1,说明To被访问了一次,但是还没有返回到,这意味着To是当前节点的祖先,因此To就是当前节点的最近公共祖先;如果vis[To]==1,说明To已经处理完了,这时候并查集就派上用场了。在递归时,当一个节点处理完返回到父亲那里时,就把父亲变成其所在集合的代表元素。在刚才讨论到vis[To]==1的情况中,可以知道find(To)(即To所在集合的代表元素)就是To和当前节点的最近公共祖先了(这个可以画图演算一下)。在这道题中,我们一开始可以用一个简单的递归算出每个点到根节点的距离dis[i]。那么对于一个查询的两个点fir和sec,它们的距离就是dis[fir]-dis[lca]+dis[sec]-dis[lca],lca是fir和sec的最近公共祖先。

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<string>
#include<cmath>
#include<set>
#include<climits>
#include<queue>
#include<vector>
#include<map>
using namespace std;

struct node
{
	int to,id;
	node(int t,int i)
	{
		to=t;
		id=i;
	}
	node(){}
};
const int maxn=50005;
vector<node>vec[maxn];
vector<pair<int,int>>query;
int father[maxn],fir[maxn<<1],nxt[maxn<<1],vv[maxn<<1],val[maxn<<1],dis[maxn],ans[75005],e;
int vis[maxn];//0 means it's white,-1 means it's grey, 1 means it's black

int findn(int n)
{
	if(n!=father[n]) father[n]=findn(father[n]);
	return father[n];
}

void add(int a,int b,int c,int i)
{
	vv[e]=b;
	val[e]=c;
	nxt[e]=fir[a];
	fir[a]=e++;
}

void get_height(int sroot,int dist)
{
	vis[sroot]=1;
	dis[sroot]=dist;
	for(int i=fir[sroot];i!=-1;i=nxt[i])
	{
		int v=vv[i];
		if(!vis[v])
		{
			get_height(v,dist+val[i]);
		}
	}
}

void dfs(int cur,int fa)
{
	vis[cur]=-1;
	for(int i=fir[cur];i!=-1;i=nxt[i])
	{
		int v=vv[i];
		if(!vis[v])
		{
			dfs(v,cur);
			father[v]=cur;
		}
	}
	int size=vec[cur].size();
	
	for(int i=0;i<size;i++)
	{
		node nxt=vec[cur][i];
		if(!vis[nxt.to]) continue;
		if(-1==vis[nxt.to])
		{
			ans[nxt.id]=nxt.to;
		}
		else if(1==vis[nxt.to])
		{
			ans[nxt.id]=findn(nxt.to);
		}
	}
	vis[cur]=1;
}
int main()
{
	#pragma comment(linker, "/STACK:102400000,102400000")//此代码需要扩栈,可能在递归时耗的内存有点大
	int n;
	while(scanf("%d",&n)!=EOF)
	{
		for(int i=0;i<=n;i++) 
		{
			father[i]=i;
			fir[i]=-1;
			vis[i]=0;
			vec[i].clear();
		}
		e=0;//important
		int a,b,c;
		for(int i=0;i<n-1;i++)
		{
			scanf("%d%d%d",&a,&b,&c);
			add(a,b,c,i);
			add(b,a,c,i);
		}
		get_height(0,0);
		int q;
		scanf("%d",&q);
		for(int i=0;i<q;i++)
		{
			scanf("%d%d",&a,&b);
			vec[a].push_back(node(b,i));
			vec[b].push_back(node(a,i));
			query.push_back(make_pair<int,int>(a,b));
		}
		for(int i=0;i<=n;i++) vis[i]=0;
		dfs(0,0);
		int size=query.size();
		for(int i=0;i<size;i++)
		{
			int fir=query[i].first;
			int sec=query[i].second;
			int lca=ans[i];
			int distance=abs(dis[lca]-dis[fir])+abs(dis[lca]-dis[sec]);
			printf("%d\n",distance);
		}
	}
}

抱歉!评论已关闭.