现在的位置: 首页 > 综合 > 正文

SPOJ 913 Query on a tree II ( 树链剖分 + 倍增 )

2019年02月22日 ⁄ 综合 ⁄ 共 2218字 ⁄ 字号 评论关闭

题目链接~~>

做题感悟:感觉又充实了一些。

解题思路:树链剖分 + 倍增

     开始看时,第一问还好,第二问就不知道怎么解了。其实这两问都可以用倍增法解决。

               先解释一下我理解的倍增 :记录 u 结点的 第 2 ^ i 个祖先,然后求u 的第 k 个祖先的时候,就相当于用 2 ^ i 去组合 k ,不断向上,一直到达第 k 个节点,其实每次更新的时k 的二进制中为 1 的位置。如下图,计算 u 的第 5 个祖先结点(这里不包括 u),先到达 u' 节点,然后再从 u' ,到 u'' (5 的二进制 101) 。会倍增算法后就好做了,计算第一问的时候
dis = dis[ u ] + dis[v] - 2 * dis[ LCA(u ,v)] 
,第二问先判断一下要求的点在 u 到交点的链上还是在 v 到交点的链上然后再结合倍增做就ok了。

代码:

#include<iostream>
#include<sstream>
#include<map>
#include<cmath>
#include<fstream>
#include<queue>
#include<vector>
#include<sstream>
#include<cstring>
#include<cstdio>
#include<stack>
#include<bitset>
#include<ctime>
#include<string>
#include<cctype>
#include<iomanip>
#include<algorithm>
using namespace std  ;
#define INT long long int
#define L(x)  (x * 2)
#define R(x)  (x * 2 + 1)
const int INF = 0x3f3f3f3f ;
const double esp = 0.0000000001 ;
const double PI = acos(-1.0) ;
const INT mod = 1000000007 ;
const int MY = 1400 + 5 ;
const int MX = 20000 + 5 ;
int num ,S = 20 ,n ;
int head[MX] ,dep[MX] ,dis[MX] ,p[MX][30] ;
struct NODE
{
    int v ,w ,next ;
}E[MX] ;
void addedge(int u ,int v ,int w)
{
    E[num].v = v ; E[num].w = w ; E[num].next = head[u] ; head[u] = num++ ;
    E[num].v = u ; E[num].w = w ; E[num].next = head[v] ; head[v] = num++ ;
}
void dfs_find(int u ,int fa ,int w)   // 处理深度、距离
{
   dep[u] = dep[fa] + 1 ;
   dis[u] = w ;
   p[u][0] = fa ;
   for(int i = 1 ;i <= S ; ++i) // 处理祖先
        p[u][i] = p[p[u][i-1]][i-1] ;
   for(int i = head[u] ;i != -1 ;i = E[i].next)
   {
       int v = E[i].v ;
       if(v == fa)  continue ;
       dfs_find(v ,u ,w + E[i].w) ;
   }
}
int LCA(int u ,int v) // 计算公共交点
{
    if(dep[u] > dep[v])  swap(u ,v) ;  // u 的深度小于等于 v
    if(dep[u] < dep[v]) // 处理成同一深度
    {
        int d = dep[v] - dep[u] ;  // 深度差
        for(int i = 0 ;i < S ; ++i)
          if(d&(1<<i))
             v = p[v][i] ;
    }
    if(u != v)  // 已经变成同一深度
    {
        for(int i = S ;i >= 0 ; --i)
          if(p[u][i] != p[v][i])
          {
              u = p[u][i] ;
              v = p[v][i] ;
          }
          u = p[u][0] ;
          v = p[v][0] ;
    }
    return u ;
}
int cunt(int u ,int k) // 计算 u 的第 k 个节点
{
    for(int i = 0 ;i < S ; ++i)
      if(k&(1<<i))
         u = p[u][i] ;
    return u ;
}
int Query(int u ,int v ,int k) // 从 u 到 v 的路径上的第 k 个节点
{
    int z = LCA(u ,v) ; // 公共交点
    if(dep[u] - dep[z] + 1 >= k) // 在 u 的这条链上
             return cunt(u ,k-1) ;
    else // 在 v 的这条线上
    {
        k -= dep[u] - dep[z] ;
        k = dep[v] - dep[z] - k + 1 ;
        return cunt(v ,k) ;
    }
}
int main()
{
    //freopen("input.txt" ,"r" ,stdin) ;
    char s[10] ;
    int Tx ,u ,v ,w ,k ;
    scanf("%d" ,&Tx) ;
    while(Tx--)
    {
        scanf("%d" ,&n) ;
        num = 0 ;
        memset(head ,-1 ,sizeof(head)) ;
        for(int i = 1 ;i < n ; ++i)
        {
            scanf("%d%d%d" ,&u ,&v ,&w) ;
            addedge(u ,v ,w) ;
        }
        dep[1] = 0 ;
        dfs_find(1 ,1 ,0) ;
        while(scanf("%s" ,s) && strcmp(s ,"DONE"))
        {
            if(s[0] == 'D')  // 求任意两点之间的距离
            {
                scanf("%d%d" ,&u ,&v) ;
                printf("%d\n" ,dis[u] + dis[v] - 2 *dis[LCA(u ,v)]) ;
            }
            else   // 询问第 k 个节点
            {
                scanf("%d%d%d" ,&u ,&v ,&k) ;
                printf("%d\n" ,Query(u ,v ,k)) ;
            }
        }
    }
    return 0 ;
}


抱歉!评论已关闭.