现在的位置: 首页 > 综合 > 正文

poj 2778 ac自动机+矩阵相乘

2013年12月09日 ⁄ 综合 ⁄ 共 2953字 ⁄ 字号 评论关闭

这两天一直在看ac自动机,昨天上午写了一个最最基础的,结果tle,还一直没改出来,看了大牛的代码发现自己写的好纠结啊,果断的写了一个简单的,昨天下午看了一题,其中有个人写了两百多行的代码,当时好佩服啊,觉得这人有好的的毅力啊,结果自己的代码写出来,发现完全不短于那个人的。。。结果是一直RE,没改出来,然后看了 别人的代码感觉就要跪了,才几十行,啊啊啊啊!

几天又看了一道裸题,代码这次是自己写的,但是好多bug,根据别人的测试数据改 了好久,其实就是写之前没有将所有情况考虑清楚,还有加上马虎,不过最后确实AC了,没有去看别人的代码。。。

关于解题思路,推荐两篇博客:http://blog.henix.info/blog/poj-2778-aho-corasick-dp.html

                                                       http://hi.baidu.com/lilymona/item/fd18390b1885df883d42e25f

本人比较愚钝,两篇加起来才明白了用ac自动机到底应该怎么写,惭愧啊,还看了好多测试数据。。。

#include<iostream>
#include<cstdio>
#include<string>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<queue>
#include<stack>
#include<vector>
#include<climits>
#include<map>
using namespace std;

#define rep(i,n) for(int i=0; i<(n); ++i)
#define repf(i,n,m) for(int i=(n); i<=(m); ++i)
#define repd(i,n,m) for(int i=(n); i>=(m); --i)
#define ll long long
#define inf 1000000000
#define exp 0.000000001
#define pb(i) push_back(i)
#define mod 100000
#define N 101 

ll mu[101][101];
ll c[N][N];
struct node{
	bool vis;
	int fal;
    int sign;	
	int next[4];
};
node a[400];
char s[11];
int n,m,tot;
int len;
int trans(char c)
{
   if(c=='A') return 0;
   if(c=='T') return 1;
   if(c=='C') return 2;
   return  3;
}
struct no
{
	ll mp[N][N];
	no()
	{
		memset(mp,0,sizeof(mp));
		rep(i,len) mp[i][i]=1;
	}
	no(ll mp[][N])
	{
        rep(i,len) rep(j,len) this->mp[i][j]=mp[i][j];
	}
	no operator * (const no &b){
		no re;
		rep(i,len) 
			rep(j,len)
		    {
			  re.mp[i][j]=0;
			  rep(k,len)
				  re.mp[i][j]+=(mp[i][k]*b.mp[k][j])%mod;
			  re.mp[i][j]%=mod;
		    }
		return re;
	}
};

void init()
{
	tot=0;a[0].sign=-1;
	a[0].vis=false;
	rep(i,4) a[0].next[i]=-1; 
}
void tree()
{
	int len=strlen(s); int u=0;
    rep(i,len)
	{
		int x=trans(s[i]);
		if(a[u].next[x]==-1)
		{
			++tot;
			rep(j,4) a[tot].next[j]=-1;
			a[tot].fal=0;a[tot].vis=false;a[tot].sign=-1;
			a[u].next[x]=tot;
		}
		u=a[u].next[x];
		if(i==len-1) a[u].vis=true;
	}
}
int fail(int u,int t)
{	
	if(a[u].next[t]!=-1) 
	{
		if(a[a[u].next[t]].vis==true) return -1;
		return a[u].next[t];
	}
	if(u==0) return 0;
	return fail(a[u].fal,t);
}
void bulid()
{
	queue<int>q; q.push(0);
	while(!q.empty())
	{
		int x=q.front(); q.pop();
		if(a[x].vis==false)
		{
			rep(i,4)
				if(a[x].next[i]!=-1)
				{
					q.push(a[x].next[i]);
					if(x==0) a[a[x].next[i]].fal=0;
					else{
					   	int y=fail(a[x].fal,i);
						if(y==-1) a[a[x].next[i]].vis=true;
						else a[a[x].next[i]].fal=y;
					}
				}
		}
		else
		{
			rep(i,4)
				if(a[x].next[i]!=-1)
	        	{
					q.push(a[x].next[i]);
					a[a[x].next[i]].vis=true;
				}

		}
	}
}

int find(int u,int k)
{
	if(a[u].next[k]!=-1)
	{
	  if(a[a[u].next[k]].vis==false)	return a[u].next[k];
	  else return -1;
	}
	if(u==0) return 0;
	return find(a[u].fal,k);
}

no mul(no b,int n)
{
	no c;
	while(n!=1)
	{
		if(n%2!=0) c=c*b;
		b=b*b;
		n=n/2;
	//	rep(i,len){ rep(j,len) cout<<b.mp[i][j]<<" "; cout<<endl;}
	//	rep(i,len){ rep(j,len) cout<<c.mp[i][j]<<" "; cout<<endl;}
	}
	return b*c;
}
void solve()
{
	init();
    rep(i,n)
		scanf("%s",s),tree();
	a[0].fal=0;
	bulid();
//	repf(i,0,tot) cout<<a[i].vis<<" "<<a[i].fal<<endl;
    memset(mu,0,sizeof(mu));
    len=0;
	repf(i,0,tot)
	{
		if(a[i].vis==true) continue; 
		if(a[i].sign==-1) a[i].sign=len++;
		rep(j,4)
		{
			int y=find(i,j);
			if(y!=-1)
			{
				if(a[y].sign==-1) a[y].sign=len++;
				mu[a[i].sign][a[y].sign]++;
			}
		}
	}//矩阵的长度为len的
	no b(mu);
	if(m>1)
	   b=mul(b,m);
	ll sum=0;
	rep(i,len) sum+=b.mp[0][i];
	cout<<sum%mod<<endl;
}
int main()
{
	while(~scanf("%d%d",&n,&m))
		solve();
    return 0;
}

抱歉!评论已关闭.