现在的位置: 首页 > 综合 > 正文

树状数组

2019年04月20日 ⁄ 综合 ⁄ 共 1993字 ⁄ 字号 评论关闭

POJ 2481：一直超时……

#include<stdio.h>
#include <vector>
#include <algorithm>

using namespace std;

const int N = 100000 + 1;  // capacity of the global arrays c[] and rst[]
// One span read from the input: start s, end e, and the 0-based position
// (id) at which it appeared, so answers can be printed in input order.
struct term
{
	term(int start, int end, int index) : s(start), e(end), id(index) {}
	int s;   // span start
	int e;   // span end
	int id;  // input position
};

// Sort order used by the sweep in main: end descending, then start ascending.
//
// std::sort requires a strict weak ordering: cmp(x, x) must be false.
// The previous tie-break `t1.s <= t2.s` returned true for equal spans, which
// is undefined behaviour and can make std::sort read past the range or hang.
class Cmp
{
public:
	bool operator()(const term & t1,const term &t2) const
	{
		if (t1.e != t2.e)
			return t1.e > t2.e;
		return t1.s < t2.s;  // strict '<' so identical spans compare equivalent
	}
};

// c:   1-based Fenwick (binary indexed) tree; slot 0 is never touched.
// rst: answers, indexed by each span's input position (term::id).
int c[N], rst[N];

// Value of the lowest set bit of i (e.g. 12 -> 4); the Fenwick step size.
int lowbit(int i)
{
	const int lowest = i & (~i + 1);  // ~i + 1 is the two's-complement negation
	return lowest;
}

// Fenwick prefix query: total of all values added at positions 1..i.
int sum(int i)
{
	int total = 0;
	for (; i > 0; i -= lowbit(i))
		total += c[i];
	return total;
}

// Fenwick point update: add val at position pos (pos must be >= 1;
// pos == 0 would never advance because lowbit(0) == 0).
//
// c has N slots (indices 0..N-1), so the loop must stop before N.
// The previous bound `pos <= N` could write c[N], one past the end of the array.
void update(int pos,int val)
{
	while (pos < N)
	{
		c[pos] += val;
		pos += lowbit(pos);
	}
}

int main()
{
	int num;
	vector<term> cows; 
	while (scanf("%d",&num) && num!=0)
	{
		for (int i=0;i<num;++i)
		{
			int s,e;
			scanf("%d%d",&s,&e);
			cows.push_back(term(s,e,i));
		}
		sort(cows.begin(),cows.end(),Cmp());
		
		term last = cows[0];//设当前为Scur,Ecur,求s在[0-Scur]之间的span的个数。
		for (int i=0;i<num;++i)
		{
			term t=cows[i];
			if(last.s == t.s && last.e == t.e)
			{
				rst[t.id]=rst[last.id];
			}
			else
				rst[t.id] = sum(t.s+1);
			last = t;
			update(t.s+1,1);
		}

		for (int i=0;i<num;++i)
		{
			if(i!=0)
				printf(" ");
			printf("%d",rst[i]);
		}
		printf("\n");

	}
	return 0;
}

以下是网上找到的代码：

#include <queue>
#include <stack>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <iostream>
#include <limits.h>
#include <string.h>
#include <algorithm>
#define MAX 100010
using namespace std;
// One input span: start s, end e, and ind, its 0-based input position.
struct SE
{
	int s;
	int e;
	int ind;
};
// Global storage, all sized MAX.
SE se[MAX];                      // spans of the current test case
int c[MAX], ind[MAX], out[MAX];  // Fenwick tree, (NOTE(review): ind looks unused), answers by input position
// Sort order: end descending; equal ends broken by start ascending.
// Both comparisons are strict, so this is a valid std::sort ordering.
bool cmp( SE a ,SE b )
{
	return (a.e != b.e) ? (a.e > b.e) : (a.s < b.s);
}
// Value of the lowest set bit of x (e.g. 12 -> 4); the Fenwick step size.
int Lowbit(int x)
{
	const int lowest = x & (~x + 1);  // ~x + 1 is the two's-complement negation
	return lowest;
}
// Fenwick point update: add 1 at position x (x must be >= 1),
// touching every node index below MAX that covers x.
void Updata(int x)
{
	for (; x < MAX; x += Lowbit(x))
		++c[x];
}
// Fenwick prefix query: total count added at positions 1..x.
int Getsum(int x)
{
	int total = 0;
	for (int k = x; k > 0; k -= Lowbit(k))
		total += c[k];
	return total;
}
// Reads test cases until a 0 (or end of input). Each span is shifted +1 to
// get 1-based Fenwick indices, and spans are sorted by end desc / start asc.
// Each span is answered with the number of earlier spans whose start is <= its
// own; identical consecutive spans reuse the previous answer. Answers are
// printed space-separated in input order, one line per case.
int main()
{
	int i,n;
	// Require exactly one conversion: `~scanf(...)` only stopped on EOF and
	// looped forever (with n unchanged) on a matching failure, where scanf returns 0.
	while( scanf("%d",&n) == 1 && n )
	{
		// Only the Fenwick tree carries state between cases. se[0..n-1] and
		// out[0..n-1] are fully overwritten below, so clearing them was wasted work.
		memset(c,0,sizeof(c));
		for(i=0; i<n; i++)
		{
			scanf("%d%d",&se[i].s,&se[i].e);
			se[i].s++; se[i].e++;  // shift to 1-based Fenwick indices
			se[i].ind = i;
		}

		sort(se,se+n,cmp);

		for(i=0; i<n; i++)
		{
			// Identical spans are adjacent after sorting and share one answer.
			if( i > 0 && se[i].s == se[i-1].s && se[i].e == se[i-1].e )
				out[se[i].ind] = out[se[i-1].ind];
			else
				out[se[i].ind] = Getsum(se[i].s);
			Updata(se[i].s);
		}
		for(i=0; i<n; i++)
		{
			if( i != 0 )
				printf(" ");
			printf("%d",out[i]);
		}
		printf("\n");  // was "/n", which printed a literal slash-n instead of a newline
	}
	return 0;
}

抱歉!评论已关闭.