现在的位置: 首页 > 综合 > 正文

树状数组总结篇

2013年10月14日 ⁄ 综合 ⁄ 共 5302字 ⁄ 字号 评论关闭

最近几天系统的做了一下树状数组的各类题型,感觉树状数组真乃神器呀,编程复杂度低,同时功能强大,一维和二,三维的基本上共通。

首先复习一下:树状数组的几个基本操作。

在我的理解中树状数组可以分为两类,插段问点和插点问段,分别对应的是向上更新,向下统计和向下更新,向上统计(一定要真正理解这句话)。

1,lowbit,这就不多说了,精华之一。2,修改操作,对应于某个点的改变,同时被该店所管辖的点也要一起改变。3.统计sum,调用一次sum(a),就意味着对a之前的所有的点都进行了改变。

//注意。在插段问点和插点问段中的求和操作与修改操作是相反的。

 int lowbit(int i){
		int ans=i&(-i);
		return ans;
	}
	//插点问段
	 void modify(int k,int d){
		while(k<=N){
			c[k]+=d;
			k+=lowbit(k);
		}
	}
	 int sum(int n){
		int result=0;
		while(n>0){
			result+=c[n];
			n-=lowbit(n);
		}
		return result;
	}
	
	//插段问点
	
	 void modify1(int k,int num){
		while(k>0){
			c[k]+=num;
			k-=lowbit(k);
		}
	}
	
	 int sum1(int n){//用于统计某个点的出现的次数
		int s=0;
		while(n<=N){
			s+=c[n];
			n+=lowbit(n);
		}
		return s;
	}

2.二维树状数组模板

static void inc(int i, int j, int t) {

		int temj;
		for (; i <= Max; i += lowbit(i)) {
			for (temj = j; temj <= Max; temj += lowbit(temj)) {
				arr[i][temj] += t;
			}
		}

	}

	static int getSum(int i, int j) {
		int temj, sum = 0;
		for (; i > 0; i -= lowbit(i)) {
			for (temj = j; temj > 0; temj -= lowbit(temj))
				sum += arr[i][temj];
		}
		return sum;
	}

}


具体题目介绍:

在这里介绍两道中等难度的题目:

poj 3321 Apple Tree

这题需要把一个树形结构变换为线性结构,然后使用树状数组动态求和。变换方法为DFS一次这棵树,记录下每个节点第一次访问和最后一次访问时的顺序编号,这时,在某节点第一次访问和最后一次访问编号之间的访问编号,必定为该节点的子节点,按照访问编号建立树状数组,当改变某一节点苹果时,只需按照该节点第一次访问的编号在树状数组中修改值,查询时统计某节点在第一次访问和最后一次访问编号之间的和。

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.util.InputMismatchException;
public class E {
	class FastIO {
	    
	    int nextInt() throws IOException {
	        return Integer.valueOf(next());
	    }
	    
	    long nextLong() throws IOException {
	        return Long.valueOf(next());
	    }
	    
	    double nextDouble() throws IOException{
	        return Double.valueOf(next());
	    }

	    String next() throws IOException {
	        StringBuilder sb = new StringBuilder();
	        char c = skipChar();
	        while (Character.isWhitespace(c))
	            c = skipChar();
	        while (!Character.isWhitespace(c)) {
	            sb.append(c);
	            c = getChar();
	        }
	        return sb.toString();
	    }

	    char skipChar() throws IOException {
	        int tmp = bfr.read();
	        if (tmp == -1)
	            throw new InputMismatchException();
	        return (char) tmp;
	    }

	    char getChar() throws IOException {
	        return (char) bfr.read();
	    }

	    void end() {
	        out.flush();
	        out.close();
	    }

	    BufferedReader bfr = new BufferedReader(new InputStreamReader(System.in));
	    PrintWriter out = new PrintWriter(new BufferedOutputStream(System.out));
	}

	

	static int maxn=100010;
	static int c[]=new int[maxn];
	static int start[]=new int[maxn];
	static int end[]=new int[maxn];
	static int flag[]=new int[maxn];
	static int data[]=new int[maxn];
	static int head[]=new int[maxn];
	static int edgenum;
	static int idx;
	class Node{
		int v,next;
	}
	Node eg[]=new Node[maxn*2];
	static void init()
	{
		edgenum=0;
		idx=0;
		for(int i=0;i<maxn;i++)
		{
			
			c[i]=0;
			flag[i]=0;
			start[i]=0;
			end[i]=0;
			head[i]=-1;
			data[i]=1;
		}
	}
	
	void addedge(int u,int v)
	{
		eg[edgenum]=new Node();
		eg[edgenum].v=v;
		eg[edgenum].next=head[u];
		head[u]=edgenum++;
	}
	
	void dfs(int x)
	{
		int i;
		flag[x]=1;
		idx++;
		start[x]=idx;
		for( i=head[x];i!=-1;i=eg[i].next)
		{
			if(flag[eg[i].v]==0)
				dfs(eg[i].v);
		}
		end[x]=idx;
	}
	
	static int lowbit(int x)
	{
		return x&(-x);
	}
	
	static void modify(int x,int add)
	{
		for(int i=x;i<maxn;i+=lowbit(i))
		{
			c[i]+=add;
		}
	}
	
	static int sum(int a)
	{
		int sum=0;
		for(int i=a;i>0;i-=lowbit(i))
			sum+=c[i];

		return sum;
	}
	
	void run() throws IOException
	{
		FastIO in=new FastIO();
		int num=in.nextInt();
		init();
		for(int i=1;i<=num;i++)
			modify(i,1);
		for(int i=1;i<num;i++)
		{
			int u=in.nextInt();
			int v=in.nextInt();
			addedge(u,v);
			addedge(v,u);
		}
		dfs(1);
		int qnum=in.nextInt();
		
		while(qnum--!=0)
		{
			String inc=in.next();
			int ds=in.nextInt();
			
			if(inc.equals("Q"))
			{
				int ans=sum(end[ds])-sum(start[ds]-1);
				System.out.println(ans);
			}else
			{
				if(data[ds]==1)
				{
					modify(start[ds],-1);
				}else if(data[ds]==0)
				{
					modify(start[ds],1);
				}
				data[ds]^=1;
			}
		}
		
	}
	public static void main(String[] args) throws IOException {
		new E().run();
	}
	
}

POJ 2155

插段问点的典型问题:

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.util.InputMismatchException;

public class B {
	class FastIO {
	    
	    int nextInt() throws IOException {
	        return Integer.valueOf(next());
	    }
	    
	    long nextLong() throws IOException {
	        return Long.valueOf(next());
	    }
	    
	    double nextDouble() throws IOException{
	        return Double.valueOf(next());
	    }

	    String next() throws IOException {
	        StringBuilder sb = new StringBuilder();
	        char c = skipChar();
	        while (Character.isWhitespace(c))
	            c = skipChar();
	        while (!Character.isWhitespace(c)) {
	            sb.append(c);
	            c = getChar();
	        }
	        return sb.toString();
	    }

	    char skipChar() throws IOException {
	        int tmp = bfr.read();
	        if (tmp == -1)
	            throw new InputMismatchException();
	        return (char) tmp;
	    }

	    char getChar() throws IOException {
	        return (char) bfr.read();
	    }

	    void end() {
	        out.flush();
	        out.close();
	    }

	    BufferedReader bfr = new BufferedReader(new InputStreamReader(System.in));
	    PrintWriter out = new PrintWriter(new BufferedOutputStream(System.out));
	}
	static int maxn=1010;
	static int cc[][]=new int[maxn][maxn];
	
	//插段问点
	
	static void init()
	{
		for(int i=0;i<maxn;i++)
			for(int j=0;j<maxn;j++)
				cc[i][j]=0;
	}
	
	static int lowbit(int x)
	{
		return x&(-x);
	}
	
	static void modify(int a,int b,int add)
	{
		int i,temj;
		for(i=a;i>0;i-=lowbit(i))
		{
			for(temj=b;temj>0;temj-=lowbit(temj))
				cc[i][temj]+=add;
		}
		
	}
	static int sum(int a,int b)
	{
		
		int sum=0,temj;
		for(int i=a;i<maxn;i+=lowbit(i))
		{
			for(temj=b;temj<maxn;temj+=lowbit(temj))
				sum+=cc[i][temj];
		}
		return sum;
	}
	void run() throws IOException
	{
		FastIO in=new FastIO();
		int num=in.nextInt();
		while(num--!=0)
		{
			init();
			int N=in.nextInt();
			int qnum=in.nextInt();
			for(int i=1;i<=qnum;i++)
			{
				String inc=in.next();
				if(inc.equals("C"))
				{
					int x1=in.nextInt()+1;
					int y1=in.nextInt()+1;
					int x2=in.nextInt()+1;
					int y2=in.nextInt()+1;
					modify(x2,y2,1);
					modify(x1-1,y1-1,1);
					modify(x2,y1-1,-1);
					modify(x1-1,y2,-1);
				}else if(inc.equals("Q"))
				{
					int a=in.nextInt()+1;
					int b=in.nextInt()+1;
					System.out.println(sum(a,b)%2);
				}
			}
			System.out.println();
		}
	}
	public static void main(String[] args) throws IOException {
		new B().run();
	}
}







抱歉!评论已关闭.