现在的位置: 首页 > 综合 > 正文

POJ 1330 Nearest Common Ancestors(tarjan , 倍增法求LCA) – from lanshui_Yang

2018年02月21日 ⁄ 综合 ⁄ 共 3018字 ⁄ 字号 评论关闭

          题目大意:给你一棵树,让你求结点 u 和 v 的最近公共祖先(即LCA)。

        解题思路:这道题我学习了两种方法。一种是 tarjan 算法(dfs + 并查集) ,另一种是倍增法。tarjan算法是一种离线算法,较易理解,不再详述。着重谈一下在线算法 : 倍增法求LCA 。  

         tarjan 算法程序如下:

#include<iostream>
#include<cstring>
#include<cstdio>
#include<string>
#include<algorithm>
#include<cmath>
#include<vector>
#define mem(a , b) memset(a , b , sizeof(a))
using namespace std ;
const int MAXN = 1e4 + 5 ;
int n ;
int root ;
int u , v ;
vector<int> G[MAXN] ;
bool vis[MAXN] ;
int set[MAXN] ;
void chu()
{
    int i ;
    for(i = 0 ; i <= n ; i ++)
    {
        G[i].clear() ;
        set[i] = i ;
    }
    mem(vis , 0) ;
}
int find(int x)
{
    int r = x ;
    while (r != set[r])
    {
        r = set[r] ;
    }
    int t ;
    while (x != r)
    {
        t = set[x] ;
        set[x] = r ;
        x = t ;
    }
    return r ;
}
void init()
{
    scanf("%d" , &n) ;
    chu() ;
    int i ;
    for(i = 0 ; i < n - 1 ; i ++)
    {
        int a , b ;
        scanf("%d%d" , &a , &b) ;
        G[a].push_back(b) ;

        vis[b] = true ;
    }
    scanf("%d%d" , &u , &v) ;
    for(i = 1 ; i <= n ; i ++)
    {
        if(!vis[i])
        {
            root = i ;
            break ;
        }
    }
    mem(vis , 0) ;
}
void LCA(int x)
{
    vis[x] = true ;
    int i ;
    for(i = 0 ; i < G[x].size() ; i ++)
    {
        int y = G[x][i] ;
        if(!vis[y])
        {
            if(y == u)
            {
                if(vis[v])
                {
                    printf("%d\n" , find(v)) ;
                    return  ;
                }
            }
            else if(y == v)
            {
                if(vis[u])
                {
                    printf("%d\n" , find(u)) ;
                    return  ;
                }
            }
            LCA(y) ;
            set[y] = x ;
        }
    }
}
void solve()
{
    LCA(root) ;
}
int main()
{
    int T ;
    scanf("%d" , &T) ;
    while (T --)
    {
        init() ;
        solve() ;
    }
    return 0 ;
}

         倍增法:

        基本思想是:

        deep[i] 表示 i节点的深度, fa[i,j]表示 i 的 2^j (即2的j次方) 倍祖先,那么fa[i , 0]即为节点i 的父亲,然后就有一个递推式子:

                                                      fa[i,j]= fa [ fa [i,j-1] , j-1 ] ,可以这样理解:

设tmp = fa [i, j - 1] ,tmp2 = fa [tmp, j - 1 ] ,即tmp 是i 的第2 ^ (j - 1) 倍祖先,tmp2 是tmp 的第2 ^ (j - 1) 倍祖先 , 所以tmp2 是i 的第 2 ^ (j - 1) + 2 ^ (j - 1) =  2^ j 倍祖先,注意:这里的“倍”可不能理解为倍数的意思,而是距离节点i有多远的意思,节点i的第2
^ j
倍祖先表示的节点u满足deep[ u ] - deep[ i ] = 2 ^ j
        这样子一个O(NlogN)的预处理求出每个节点的 2^k 的祖先  
        然后对于每一个询问的点对a, b的最近公共祖先就是: 

 先判断是否 d[x]< d[y] ,如果是的话就交换一下(保证 x 的深度大于 y 的深度), 然后把 x 调到与 y 同深度, 同深度以后再把a, b 同时往上调,调到有一个最小的 j 满足fa [x,j] != fa [y,j] (x,y是在不断更新的), 最后再把(x,y)往上调(x=p[x,0], y=p[y,0])  ,一个一个向上调直到x = y, 这时 x或y 就是他们的最近公共祖先。

         Ps:如果还是不明白,就手动模拟一棵节点数为9的树(如下图所示),很快就会理解的。还有我不得不感叹一句 :二进制真的很神奇!!                  

         请看代码:

#include<iostream>
#include<cstring>
#include<algorithm>
#include<string>
#include<cmath>
#include<vector>
#include<cstdio>
#define mem(a , b) memset(a , b , sizeof(a))
using namespace std ;
inline void RD(int &a)
{
    a = 0 ;
    char t ;
    do
    {
        t = getchar() ;
    }
    while (t < '0' || t > '9') ;
    a = t - '0' ;
    while ((t = getchar()) >= '0' && t <= '9')
    {
        a = a * 10 + t - '0' ;
    }
}
inline void OT(int a)
{
    if(a >= 10)
    {
        OT(a / 10) ;
    }
    putchar(a % 10 + '0') ;
}
const int MAXN = 10005 ;
const int M = 30 ;
vector<int> G[MAXN] ;
bool vis[MAXN] ;
int deep[MAXN] ;
int fa[MAXN][M] ;
int n ;
int root ;
void chu()
{
    mem(vis , 0) ;
    mem(deep , 0) ;
    mem(fa , 0) ;
    int i ;
    for(i = 0 ; i <= n ; i ++)
        G[i].clear() ;
}
void dfs(int u)
{
    vis[u] = true ;
    int i ;
    for(i = 0 ; i < G[u].size() ; i ++)
    {
        int v = G[u][i] ;
        if(!vis[v])
        {
            deep[v] = deep[u] + 1 ;
            dfs(v) ;
        }
    }
}
void bz()  // 倍增祖先
{
    int i , j ;
    for(j = 1 ; j < M ; j ++)
    {
        for(i = 1 ; i <= n ; i ++)
        {
            fa[i][j] = fa[ fa[i][j - 1] ][j - 1] ;
        }
    }
}
void swap(int &x , int &y)
{
    int tmp = x ;
    x = y ;
    y = tmp ;
}
int LCA(int u , int v)
{
    if(deep[u] < deep[v]) swap(u , v) ;
    int d = deep[u] - deep[v] ;
    int i ;
    for(i = 0 ; i < M ; i ++)
    {
        if( (1 << i) & d )  // 注意此处,动手模拟一下,就会明白的
        {
            u = fa[u][i] ;
        }
    }
    if(u == v) return u ;
    for(i = M - 1 ; i >= 0 ; i --)
    {
        if(fa[u][i] != fa[v][i])
        {
            u = fa[u][i] ;
            v = fa[v][i] ;
        }
    }
    u = fa[u][0] ;
    return u ;
}
void init()
{
    scanf("%d" , &n) ;
    chu() ;
    int i ;
    for(i = 0 ; i < n - 1 ; i ++)
    {
        int a , b ;
        scanf("%d%d" , &a , &b) ;
        G[a].push_back(b) ;
        fa[b][0] = a ;
        if(fa[a][0] == 0)
        {
            root = a ;
        }
    }
    deep[root] = 1 ;
    dfs(root) ;
    bz() ;
    int u , v ;
    scanf("%d%d" , &u , &v) ;
    printf("%d\n", LCA(u , v)) ;
}
int main()
{
    int T ;
    scanf("%d" , &T) ;
    while (T --)
    {
        init() ;
    }
    return 0 ;
}

抱歉!评论已关闭.