题目类型 DP
题目意思
给出一个最多150字符长的只有a或b或c组成的字符串
对于每个操作可以把前面一个字符变成后面一个字符或者把后面的一个字符变成前面一个字符
即可以执行赋值语句 str[i+1] = str[i]; 或者 str[i] = str[i+1];
如果原字符串在执行若干次操作后变成一个a,b,c的字符数量相互不超过1的 字符串, 那么称得到的串为一个合法串
问合法串有多少个
例如输入字符串 aaacb 其中a有3个,b有1个,c有1个 3-1>1 不合法
但是可以把 第3个a变成 c -> aaccb a有2个,b有1个,c有2个, 相互之间的数量都不超过1 所以 aaccb 就是一个合法的串
解题方法
假设输入的字符串是 A, 字符串 A'是 A串的相同字符压缩成一个后的结果 (例如 aacbbb 压缩成 acb)
假设一个合法串B, B'是B压缩后的字符串
这时会发现 如果 B 是由 A 经过若干次操作得到的结果, 那么 B' 是 A'的一个子序列
如果 B' 是 A' 的一个子序列, 那么说明 B 能从 A 通过若干次操作得到
这样很自然可以想到 如果用一个数组 dp[i] 表示以 A'中的第 i 个字符为结尾的串的数量
由于题目要求字符a, b, c的数量相互不超过1 所以要同时记录字符a, b, c使用了的数量即
状态 dp[i][na][nb][nc] 表示 以 A'中第i个字符为结尾的字符a数量为na, 字符b数量为nb, 字符c数量为nc的字符串的数量
由于A'中只有排在前面的字符才能更新后面的字符所以不断用当前状态去更新后面的状态 当枚举到后面的状态时该状态已经是最优的了
更新时用 完全背包 的思想, 即可以在当前状态后面加任意个的a 或 b 或 c (详情看代码)
if(构造的下一个字符是a) dp[next[i]][na+1][nb][nc] += dp[i][na][nb][nc]; next[i]表示 满足条件(A'[j] == 'a' && j >= i)中最小的 j
if(构造的下一个字符是b) dp[next[i]][na][nb+1][nc] += dp[i][na][nb][nc];
next[i]表示 满足条件(A'[j] == 'b' && j >= i)中最小的 j
next[i]表示 满足条件(A'[j] == 'b' && j >= i)中最小的 j
if(构造的下一个字符是c)
dp[next[i]][na][nb][nc+1] += dp[i][na][nb][nc]; next[i]表示 满足条件(A'[j] == 'c' && j >= i)中最小的 j
dp[next[i]][na][nb][nc+1] += dp[i][na][nb][nc]; next[i]表示 满足条件(A'[j] == 'c' && j >= i)中最小的 j
为什么是最小的 j ? 如果不是最小的 j 是不是会使一部分字符浪费了?
由于要求合法串的a,b,c字符数量相互不超过1 如果字符串长度 n % 3 == 0 则a,b,c字符数量均为 n/3
如果字符串长度 n % 3 == 1 则a,b,c字符数量可能为 n/3+1,n/3,n/3 或 n/3,n/3+1,n/3 或 n/3,n/3,n/3+1
如果字符串长度 n % 3 == 2 则a,b,c字符数量可能为 n/3,n/3+1,n/3+1 或 n/3+1,n/3,n/3+1 或 n/3+1,n/3+1,n/3
即 na, nb, nc的最大值只是50
所以时间复杂度大约为 150*50*50*50 = 18750000 不会超时
参考代码 - 有疑问的地方在下方留言 看到会尽快回复的
#include <iostream> #include <cstdio> #include <cstring> #include <cmath> #include <set> #include <map> #include <string> #include <algorithm> using namespace std; typedef long long LL; const int MAXN = 1e3 + 10; const int MOD = 51123987; char str[MAXN]; char s[MAXN]; int dp[200][53][53][53]; int nex[200][3]; int main() { int n; while(scanf("%d", &n) != EOF) { memset(dp, 0, sizeof(dp)); scanf("%s", str); int k = 0; s[k++] = str[0]; for( int i=1; i<n; i++ ) { if(str[i] != str[i-1]) s[k++] = str[i]; } memset(nex, -1, sizeof(nex)); for( int i=0; i<k; i++ ) { for( int j=i; j<k; j++ ) { if(s[j] == 'a' && nex[i][0] == -1) nex[i][0] = j; if(s[j] == 'b' && nex[i][1] == -1) nex[i][1] = j; if(s[j] == 'c' && nex[i][2] == -1) nex[i][2] = j; } } int tn = n / 3 + 1; dp[0][0][0][0] = 1; for( int i=1; i<=k; i++ ) { //枚举A'串的各个位置 for( int A=0; A<=tn; A++ ) { // A, B, C的枚举顺序是不是很熟悉?这里可以用完全背包的思想理解 for( int B=0; B<=tn; B++ ) { for( int C=0; C<=tn; C++ ) { if(dp[i-1][A][B][C] == 0) continue; int dex = nex[i-1][0]; if(dex != -1) { dp[dex][A+1][B][C] += dp[i-1][A][B][C]; dp[dex][A+1][B][C] %= MOD; } dex = nex[i-1][1]; if(dex != -1) { dp[dex][A][B+1][C] += dp[i-1][A][B][C]; dp[dex][A][B+1][C] %= MOD; } dex = nex[i-1][2]; if(dex != -1) { dp[dex][A][B][C+1] += dp[i-1][A][B][C]; dp[dex][A][B][C+1] %= MOD; } } } } } int res = 0; if(n % 3 == 0) { for( int i=0; i<k; i++ ) { res += dp[i][n/3][n/3][n/3]; res %= MOD; } } else if(n % 3 == 1) { for( int i=0; i<k; i++ ) { res += dp[i][n/3+1][n/3][n/3]; res %= MOD; res += dp[i][n/3][n/3+1][n/3]; res %= MOD; res += dp[i][n/3][n/3][n/3+1]; res %= MOD; } } else { for( int i=0; i<k; i++ ) { res += dp[i][n/3+1][n/3+1][n/3]; res %= MOD; res += dp[i][n/3+1][n/3][n/3+1]; res %= MOD; res += dp[i][n/3][n/3+1][n/3+1]; res %= MOD; } } printf("%d\n", res); } return 0; }