昨天看到一个非常屌的求最长回文子串的O(n)算法。
算法的过程是这样的:先对字符串进行预处理,相邻的字符串之间加上一个特殊的标记符号(这个符号在原串中是没有出现的),为了方便可以在开头加另外一个特殊的标记符号这样经过扩展的字符串就变成了从一开始的了。
用一个数组P标记以当前字符为中心的最长回文子串可以向右延伸多少位,当不能延伸时P [ i ] = 1。然后,可以发现,不管回文串长度是奇数还是偶数,以第 i 个字符为中心的回文串的长度是 P [ i ] - 1 。这样算出了P数组之后扫描一遍即可得到最长回文子序列的长度。
关键是如何求P数组。每次求得 id 位置的P[id]之后,更新一下当前能扫到的最远位置 max ;对于后面的字符,如果max超过了当前的下标,也就是之前这个位置已经被访问过,也就是当前位置的字符是在以id为中心的最长回文子串里面的。用和i关于id对称的一个点j(j=2*id-i)的P[j]值来更新。因为i和j是关于id对称的,又i包含在id的最长回文子串里面,所以可用p[2*id-i]来更新p[i]的值。又比如这种情况,0 0 0 0
0 j 0 0 0 id 0 0 0 i 0 max。用P[j]来更新P[i]的时候,明显有一个问题,j向右延伸的距离可能超过了max-i,而超出来的这部分在id那个回文子串里面是没有被访问过的,所以得取P [ 2 * id - i ] 和 max - i 的最小值,即为P [ i ];
程序实现如下:
char s[maxn]; char str[maxn]; int p[maxn]; int sol() { int mx=0; int id; int n=strlen(s); str[0]='&'; str[1]='#'; for(int i=0; i<n; i++) str[i*2+2]=s[i],str[i*2+3]='#'; n=n*2+2; for(int i=1; i<n; i++) { if(mx>i)p[i]=min(p[2*id-i],mx-i); else p[i]=1; while(str[i+p[i]]==str[i-p[i]])p[i]++; if(p[i]+i>mx)mx=p[i]+i,id=i; } int ans=0; for(int i=1; i<n; i++) ans=max(ans,p[i]-1); return ans; } int main() { while(scanf("%s",s)==1) { cout<<sol()<<endl; } return 0; }
另外,上交的那本《算法与实现》当中也有这个算法,只不过没搞明白。