现在的位置: 首页 > 综合 > 正文

hdu 1754 I Hate It (树状数组)

2017年11月23日 ⁄ 综合 ⁄ 共 2097字 ⁄ 字号 评论关闭

小记:对于求区间的最值问题,如何利用树状数组来解决它:

idx[i]表示1-i区间中最大的值,a[]是存放元素数组

利用树状数组来求解区间最值问题的原理就是树状数组对二进制的利用,对于树状数组的第k个数组元素值它的意义代表着区间[k - lowbit(k) +1, k ]的最大值,在它的下面有lowbit(k)个类似的区间,我们统计出其中的每一个区间的最值,然后求出最大的就是它自己的值了。

void init(){
    for(int i = 1; i <= n; ++i){
        for(int j = i; j <= n; j+= lowbit(j)){
	    idx[j] = max(idx[j],a[j]);
	}
    }
}

这里是初始化idx数组元素的值,但是在执行init之前我们必须对idx初始化一个不影响结果的小值。我们可以改用另一种:

void init(){
    for(int i = 1; i <= n; ++i){
        idx[i] = a[i];
        for(int j = 1; j < lowbit(i); j<<=1){
            idx[i] = max(idx[i],idx[i-j]);
        }
    }
}

这样就可以直接执行init函数了,原理就是前面说的,通过它自己包含的子区间来更新自己的idx的值。

然后我们要查询某个区间的最值,例如[l,r],那么我们就从右边开始算起,只要它所包含的区间在[l,r]之间,那么我就取出它的区间最值,对于某个r,只要其区间[r-lowbit(r)+1,r]是在[l,r]之间的,那么我们就将其idx[r]值来更新结果。如果不在,或者区间超过了r-l+1,那么我们就令r在减去个1,这样我们可以看r的下一个区间,只要一直走就可以更新完所有在[l,r]之间的所有区间,然后我们的结果就出来了。

int query(int l,int r){
	int ans = a[r];
	while(true){
		ans = max(ans,a[r]);
		if(l == r)break;
		for(r--; r - l >= lowbit(r); r -= lowbit(r)){
			ans = max(ans,idx[r]);
		}
	}
	return ans;
}

然后为了能够实现边查边改我们,通过初始化的那样的方法,对某点更新了a[]值之后,那么我就要根据其子节点的值更新其idx的值,然后再往上更新其父亲节点的值,当碰到idx的值大于我更新的那个值时,就可以退出更新了,因为后面的都是一样的。

void modify(int p,int v){
	a[p] = v;
	for(int i = p; i <= n; i+=lowbit(i)){
		if(idx[i] < v)
			idx[i] = v;
		else break;
		for(int j = 1; j < lowbit(i); j<<=1){
			idx[i] = max(idx[i],idx[i-j]);
		}
	}
}

遍历其子节点的依据是:

for(int j = 1; j < lowbit(i); j<<=1){
	idx[i] = max(idx[i],idx[i-j]);
}

根据树状数组的二进制原理。

参考:树状数组求区间最值 树状数组详解

代码:

#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <iostream>

using namespace std;

#define max(a,b) ((a)>(b))?(a):(b)
#define MAX_ 200010
#define N 1000010

int  a[MAX_], idx[MAX_];

int n;

int lowbit(int x){return x&(-x);}

void init(){
	for(int i = 1; i <= n; ++i){
		idx[i] = a[i];
		for(int j = 1; j < lowbit(i); j<<=1){
			idx[i] = max(idx[i],idx[i-j]);
		}
	}
}

void modify(int p,int v){
	a[p] = v;
	for(int i = p; i <= n; i+=lowbit(i)){
		if(idx[i] < v)
			idx[i] = v;
		else break;
		for(int j = 1; j < lowbit(i); j<<=1){
			idx[i] = max(idx[i],idx[i-j]);
		}
	}
}

int query(int l,int r){
	int ans = a[r];

	while(true){
		ans = max(a[r],ans);
		if(l == r)break;
		for(r--; r-l >= lowbit(r); r -= lowbit(r)){
			ans = max(ans,idx[r]);
		}
	}
	return ans;
}


int main(){
	int T,m,i, s, t;
	char str[10];

	while(~scanf("%d%d",&n,&m)){
		
		for( i = 1; i <= n; ++i){
			scanf("%d",&a[i]);
		}
		init();
		for( i = 0; i < m; ++i){
			scanf("%s%d%d",str,&s,&t);
			switch(str[0]){
			case 'Q':printf("%d\n",query(s,t));
				break;
			case 'U':modify(s,t);
				break;
			}
		}
	}
	return 0;
}

抱歉!评论已关闭.