现在的位置: 首页 > 综合 > 正文

24点计算问题

2019年03月19日 ⁄ 综合 ⁄ 共 12274字 ⁄ 字号 评论关闭

问题描述:N个1到13之间的自然数,找出所有能通过加减乘除计算(每个数有且只能用一次)得到24的组合?


计算24点常用的算法有三种,第一种方法:任取两个数,计算后,将结果放回去,再从剩下的数中任取两个,如此反复直到只剩下一个数;第二种方法:先构建前缀/后缀表达式,再计算该表达式;第三种方法是用集合保存中间结果,集合间两两进行合并计算得到新集合(或者对给定的一个集合,对其所有的子集合进行合并计算)。本博文首先采用第一种方法。六种操作符:ADD(加)、SUB(减)、MUL()、DIV()、RSUB()、RDIV(反除),其中反减/除:先交换两个数,再减/除。显然,取两个数计算时,六种计算结果可能有重复,可以对这6个结果进行去重(实际上,只要分别对加减(ADD、SUB、RSUB)和乘除(MUL、DIV、RDIV)的3个计算结果进行去重判断就可以了,效率和对6个结果去重相差不大)。另外一种剪枝方法:保存每个数上次计算时所用的操作符(初始值为空)。所取的两个数:若某个数的上次操作符为减(SUB、RSUB),那么不进行加减(ADD、SUB、RSUB)计算。若某个数的上次操作符为除(DIV、RDIV),那么不进行乘除(MUL、DIV、RDIV)计算。比如:取的两个数为 a-b 和 c(c的上次操作符任意),如果进行加减计算的话.

a-b+c  和 c+a-b重复,
c-(a-b)和 c+b-a重复
a-b-c  和 c+b RSUB a重复

也就是说,上次操作符为减的,进行加减计算时,总可以转为某个上次操作符为加的表达式,因而可以不计算。同样,上次操作符为除的,不进行乘除计算。当然,还可以考虑记录位置进行剪枝,这样避免a+b+c和a+c+b都进行计算。但要注意的是:在给定的组合无解时,越多的剪枝方法,极有可能提高搜索效率,但在给定的组合有解时,很可能反而降低搜索效率。另外,对有解时输出的表达式的处理对程序的性能影响很大。如果每次计算都保存对应的表达式,会进行大量的字符串操作,严重影响性能。实际上,每次计算只要保存取出的两个数的位置和所进行计算的操作符就够了,最终需要输出表达式时,只要模拟一下递归函数调用过程,进行相应的字符串操作。

#include<iostream>
#include<fstream>
#include<sstream>
#include<vector>
#include<ctime>
#include<cmath>

using namespace std;

class Calc {
 public:
	 Calc(){};
	 void print_result() const;
	 bool run(const int src[], size_t sz, double n = 24.0, bool expr_calc = true, bool show = true); 
	 void calc_range(int first, int last,size_t N = 4,double M = 24, string filename = "24.out");

	 const string& get_expr() const { return expr;}
	 size_t get_count_expr() const { return count_expr;}
	 size_t get_count_func() const { return count_func;}

private:
	Calc(const Calc&);
	Calc& operator=(const Calc&);
	bool init(const int src[], size_t sz, double n);
	bool calc(size_t step);
	inline bool calc2(size_t step, size_t pos2,double na, double nb, int op);
	void calc_expr(); 
	
	void add_parentheses(string& str) {
		string tmp;  
		tmp.reserve(str.size() + 2);
		tmp += '(';  tmp += str;  tmp += ')';
		str.swap(tmp);
	} 
	
	char get_op_char(int op) { return char(op >> RSHIFT); }
	int get_opv(int op) { return op & OPV_MASK; }

  //0-2位表示操作符的优先级 加减: 1 乘除2  初始值4
  //+3位,即RFLAG标志,表示对减除法,交换两个操作数后再计算
  //4-7位表示操作符,8-15位表示该操作符的ascii值
	enum {
		OP_NULL = 4,
		RFLAG = 8, 
		RSHIFT = 8, 
		OPV_MASK = 7,
		FLAG_ADD = 0x10, 
		FLAG_SUB = 0x20, 
		FLAG_MUL = 0x40, 
		FLAG_DIV = 0x80,
		ADD = '+' << RSHIFT | FLAG_ADD | 1, 
		SUB = '-' << RSHIFT | FLAG_SUB | 1,
		MUL = '*' << RSHIFT | FLAG_MUL | 2, 
		DIV = '/' << RSHIFT | FLAG_DIV | 2,
		RSUB = SUB | RFLAG, RDIV = DIV | RFLAG,
	};
	
	struct Info_step {              //记录每一步取两个数所进行的计算
		size_t first;                 //第一个操作数位置 
		size_t second;                //第二个操作数位置
		int op;                       //操作符
	};
	
	size_t size;                    
	string expr;                    //得到的表达式 
	double result;                  //要得到的结果值
	size_t count_expr;              //处理的表达式总数     
	size_t count_func;              //函数被调用次数
	vector<int> old_number;         //要计算的数
	vector<double> number;          //中间计算结果           
	vector<int> ops;               //上一次计算所用的操作符,初始值要设为OP_NULL 
	vector<Info_step> info_step;  
};
 
bool Calc::init(const int src[], size_t sz, double n){
	if (sz == 0 || src == NULL) 
		return false;
	size = sz;
	expr.clear();
	result = n;
	count_expr = 0;
	count_func = 0;

	old_number.assign(src, src + sz); 
	number.assign(src, src+sz);
	ops.assign(sz, OP_NULL);

	info_step.clear();
	info_step.resize(sz);
	return true;
}

bool Calc::run(const int src[], size_t sz, double n, bool expr_calc, bool show){
	if (! init(src, sz, n)) 
		return false;
	bool ret = calc(sz);
	if (ret && expr_calc) 
		calc_expr();
	if (show) 
		print_result(); 
	return ret;
}

void Calc::calc_expr(){
	const size_t sz =  size;

	static vector<string> str;
	static vector<int> op_prev;
	static stringstream ss;

	str.resize(sz); 
	op_prev.assign(sz,OP_NULL);    //初始值
	for (size_t i = 0; i < sz; ++i) {
		ss << old_number[i];
		getline(ss, str[i]);
		ss.clear();
	}
	
	for (size_t k = sz; k-- > 1; ) {
		size_t i = info_step[k].first;
		size_t j = info_step[k].second;
		
		int op = info_step[k].op;
		int opv= get_opv(op);
		int op1v, op2v;
		
		if (op & RFLAG) { 
			str[i].swap(str[j]);
			op1v = get_opv(op_prev[j]);
			op2v = get_opv(op_prev[i]);
		} 
		else {
			op1v = get_opv(op_prev[i]);
			op2v = get_opv(op_prev[j]);
		}
		
		if (opv > op1v) 
			add_parentheses(str[i]);
		if (opv > op2v || (opv == op2v && (op & (FLAG_SUB | FLAG_DIV)))) 
			add_parentheses(str[j]);
		
		op_prev[i] = op;
		op_prev[j] = op_prev[k];
		
		str[i].reserve(str[i].size() + str[j].size() + 1);
		str[i] += get_op_char(op);
		str[i] += str[j]; 
		str[j].swap(str[k]);
	}
	expr.swap(str[0]);
}

bool Calc::calc(size_t step){
	++count_func;
	if (step <= 1) {
		++count_expr;
		const double zero = 1e-9;   
		if (fabs(result - number[0]) < zero) return true; 
		return false;
	}
	--step;
	for (size_t i = 0; i < step; i++){
		info_step[step].first = i;
		for(size_t j = i + 1; j <= step; j++) {
			info_step[step].second = j;
			
			double na = number[i];
			double nb = number[j];
			
			int op1 = ops[i];
			int op2 = ops[j];
			number[j] = number[step];
			ops[j] = ops[step]; 
			int tt = op1 | op2;
			bool ba = true, bb = true;
			if (tt & FLAG_SUB) 
				ba = false;
			if (tt & FLAG_DIV) 
				bb = false;
      
      if (ba) {
        if (calc2(step, i, na, nb, ADD)) return true;
        if (nb != 0 && calc2(step, i, na, nb, SUB)) return true;
        if (na != nb && na != 0 && calc2(step, i, na, nb, RSUB)) return true;
      }
      
      if (bb) {
        double nmul = na * nb;
        if (calc2(step, i, na, nb, MUL)) return true;
        if (na != 0 && nb !=0) {
          double ndiv = na / nb;
          if (nmul != ndiv && calc2(step,i,na, nb, DIV)) return true; 
          double nrdiv = nb / na;          
          if (nrdiv != ndiv && nrdiv != nmul && calc2(step,i,na, nb, RDIV))
            return true;
        }
      }
      number[i] = na;
      number[j] = nb;
      ops[i] = op1;
      ops[j] = op2;
    }
  }
  return false;
}

inline bool Calc::calc2(size_t step, size_t pos2,double na,double nb, int op){
  info_step[step].op = op;
  ops[pos2] = op;
  switch (op) {
    case ADD:   number[pos2] = na + nb; break;
    case SUB:   number[pos2] = na - nb; break;
    case MUL:   number[pos2] = na * nb; break;
    case DIV:   number[pos2] = na / nb; break;
    case RSUB:  number[pos2] = nb - na; break;
    case RDIV:  number[pos2] = nb / na; break;
    default : break;
  }
  return calc(step);
}

void Calc::print_result() const
{
  if (old_number.empty()) return;
  cout << "number: ";
  for (size_t i = 0; i < old_number.size(); ++i) 
    cout << old_number[i] << " ";
  cout << "\n";
  if (! expr.empty()) std::cout << "result: " << expr << "=" << result << "\n";
  cout << "expr/func: " << count_expr << "/" << count_func << "\n\n";
}

void Calc::calc_range(int first, int last,size_t N, double M, string filename){
  if (N ==0 || first >= last || filename.empty()) return;
  clock_t ta = clock();
  vector<int> vv(N, first);
  int *end = &vv[N-1], *p = end, *arr = &vv[0];
  ofstream ofs(filename.c_str());
  size_t count = 0;
  size_t count_b = 0;
  typedef unsigned long long size_tt;
  size_tt count_expr_a = 0;
  size_tt count_expr_b = 0;
  size_tt count_func_a = 0;
  size_tt count_func_b = 0;
  while(true){
    ++count_b;
    if (run(arr,N,M,true,false)) {
      ofs.width(4);
      ofs<<  ++count << "    ";
      for (size_t i = 0; i < N; ++i) { 
       ofs.width(2);        
        ofs << arr[i]<< " ";
      }  
      ofs<< "    " << get_expr() << "\n";
      count_expr_a += count_expr;     
      count_func_a += count_func; 
   } else {
      count_expr_b += count_expr;     
     count_func_b += count_func;
    }    
    while (p >= arr && *p >= last) --p;
    if (p < arr) break;
   int tmp = ++*p;
    while (p < end) *++p = tmp;
  }
 ofs.close();
  const char sep[] = "/";
  cout << "N: " << N << "    M: " << M 
      << "       range: " << first << "-" << last << "\n"  
       << "expr:  " << count_expr_a + count_expr_b << sep <<  count_expr_a 
       << sep << count_expr_b << "\n"
       << "func:  " << count_func_a + count_func_b << sep <<  count_func_a 
       << sep << count_func_b << "\n" 
       << "total: " << count << sep << count_b << "\n"
       << "time:  " << clock() - ta << "  ms\n\n" << std::flush; 
}

int main()
{
  Calc calc;
  int ra[4]={3,3,8,8};
  int rb[4]={5,7,7,11};
  int rc[4]={4,7,11,13};
  calc.run(ra,4);
  calc.run(rb,4);
  calc.run(rc,4);
  //calc.calc_range(1,13,4,24,"nul");
  //calc.calc_range(1,50,4,24,"nul");
  //calc.calc_range(1,100,4,24,"nul");
  calc.calc_range(1,13, 4,24,"a24-04t.txt");
  calc.calc_range(1,13, 5,24,"a24-05t.txt");
  calc.calc_range(1,13, 6,24,"a24-06t.txt");
  calc.calc_range(1,13, 7,24,"a24-07t.txt");
  calc.calc_range(1,13, 8,24,"a24-08t.txt");
  calc.calc_range(1,13, 9,24,"a24-09t.txt");
  calc.calc_range(1,13,10,24,"a24-10t.txt");
}

====PS:来自《编程之美》方法====

给定4个数,能否只通过加减乘除计算得到24?由于只有4个数,弄个多重循环,就可以。如果要推广到n个数,有两种思路:

① 采用前缀/后缀表达式。相当于将n个数用n-1个括号括起来,其数目就是一个catlan数。最多可得到 f(n) = (1/n * (2*n - 2)! / (n-1)! / (n-1)!) * n! * 4^(n-1) = 4^(n-1) * (2*n-2)! / (n-1)! 种表达式,当n=4时,共可得到 7680种。

② 从n个数中任意抽取2个,进行计算最多有6种结果,将计算结果放回去,再从剩下的n-1个任取2个,进行计算,如此反复,直到只剩下1个数。按这种算法,共要处理表达式:g(n)=(n*(n-1)/2*6) * ((n-1)*(n-2)/2*6) * ((n-2)*(n-3)/2*6) * (2*1/2*6) = n!*(n-1)!*3^(n-1)当n=4时,最多要处理3888种。 (书上的代码将这两种思路混在一块了。)

f(n) / g(n) = (4/3)^(n-1) * (2*n-2)! / n! / (n-1)! / (n-1)!

很明显,当n比较大时(比如n大于8),会有 f(n) < g(n)。比如:f(10)/g(10)=0.178。

从f(n)与g(n)的比值,可以看出,这两种解法都存在大量的不必要计算。当n比较大时,思路2的冗余计算已经严重影响了性能。要如何减少这些不必要计算呢?

可以记录得到某个计算结果时所进行操作。比如: a、b、c和d这4个数取前2个,进行加法计算得到 a+b,则记录‘+’。另外,假设加减号的优先级为0,乘除号的优先级为1。

a和b进行减/除计算时,实际上得到 a-b与b-a,a/b与b/a。

当取出2个数a和b,进行计算,这两个数上次的操作符有下面这几种情况:

① 都为空:要计算6个结果,即 a+b, a-b, b-a, a*b, a/b, b/a。

② 只有一个为空:假设: a = a1 op1 a2

   ⑴ 一种剪枝方法是: 若op1为减(除)号,则不进行加减(乘除)计算。    因为: (a-b)-c可以转为a-(b+c),这两个表达式只要计算一个就可以。

 ⑵ 另一种剪枝方法:额外记录每次计算最靠后的那个数的位置。比如位置顺序:a、b、c、d,进行a+c计算时,记录了c位置,再与数b计算时,由于b位置在c位置前,不允许计算 (a+c) + b 和 (a+c) – b这样就避免了表达式 a+b+c和 a-b+c被重复计算。

③ 都不为空: 假设: a = a1 op1 a2, b= b1 op2 b2

   要计算的结果: a op3 b = (a1 op1 a2)op3 (b1 op2 b2)

   ⑴如果 op1 和 op2的优先级相同,那么 op3 的优先级不能与它们相同,若相同,则原来的表达式可以转为 ((a1 op4 a2) op5 b1) op6 b2,因而没必要对原来的表达式进行计算。比如 (m1+m2)与(m3-m4)之间只进行乘除计算,而不进行加减计算。

    ⑵如果 op1 和 op2的优先级不同,那么 op3 无论怎么取,其优先级都必会与其中一个相同,则原表达式可以转化((c1 op4 c2) op5 c3) op6 c4这种形式,因而该表达式没必要计算。如(m1+m2)与(m3*m4),不进行任何计算。

总之:op1 op2优先级不同时,不进行计算。

        op1 与 op2优先级相同时,进行计算的操作符优先级不与它们相同。

要注意的是:剪枝不一定提高性能。如果n个数计算可得到24,过多的避免冗余计算,有可能严重降低性能。计算n=6时,碰到一个组合,仅使用了③的剪枝方法,得到结果时处理了四百个表达式,但再采用了②的第一种剪枝方法,处理的表达式达到五十三万多。(也许②的第二种剪枝方法不存在这么严重的问题。)与烙饼排序不同的是,烙饼排序总能找到一个结果,而n个数计算有可能无解。显然在无解时,采用尽可能多的剪枝方法,必然会极大的提高性能。

另外,对于输出表达式,书上的程序进行了大量的字符串操作,实际上可以只记录,每一步取出的两个数的位置(即记录i、j值),在需要输出时,再根据所记录的位置,进行相应的字符串操作就可以了。

#include <iostream>
#include <string>
#include <set>
#include <cmath>
using namespace std;

bool calc(int src[], size_t N, double M = 24.0){
  if (N == 0 || src == NULL) return false;
    set<double> result[1 << N];
    for (size_t i = 0; i < N; ++i) result[1<<i].insert((double)src[i]); 
  
    for (size_t i =1; i < (1<<N); ++i) {
      for (size_t j = 1; j <= (i >> 1); ++j) {
         if ((i & j) != j) continue;
         for (set<double>::iterator p = result[j].begin(); p != result[j].end(); ++p) {
           double va = *p;
           size_t k = i ^ j;
           for (set<double>::iterator q = result[k].begin(); q != result[k].end(); ++q) {
             double vb = *q;
             result[i].insert(va + vb);
             result[i].insert(va - vb);
             result[i].insert(vb - va);
             result[i].insert(va * vb);
             if (vb != 0.0) result[i].insert(va / vb);
             if (va != 0.0) result[i].insert(vb / va);
           }
        }
      } 
    }
  
    size_t j = (1 << N) - 1;
    const double zero = 1e-9;
    for (set<double>::iterator p = result[j].begin(); p != result[j].end(); ++p) {
      if (fabs(*p - M) < zero) return true;
    }
    return false;
 }

 int main(){
    int src[]={13, 773, 28, 98, 731, 97357246};
    cout << calc(src,sizeof(src)/sizeof(src[0]))<<endl;
 }

书上给出的最后一种解法,通过使用集合记录中间结果来减少冗余计算。本以为,程序会占用大量的内存,用一个极端的例子(13, 773, 28, 98, 731, 1357,97357246这7个数)测试了一下实现的程序,发现程序竟然占用了1G以上的内存(无论集合的实现采用STL中的set还是unordered_set),但后来,取7个均在1到13之间的数,再重新测试了下发现,程序所占用的内存比想像的小的多,也就几兆。对数值都在1到13之间的n个数的所有组合进行判断。在n等于4时,实现的程序约过1秒就给出了结果,而n等于5时,程序要运行58秒,效率实在差,可以通过这几方面优化:

  1. 保存每个集合的子集合的编号:对给定的n,共有1到2^n – 1个集合,每个集合的子集合编号是固定的,但原程序每计算一个n个数的组合,都要对这些子集合编号计算一遍,可以保存每个集合的子集合编号,减少大量的重复计算。
  2.  改进计算子集合编号的算法:原程序的算法的时间复杂度是O(4^n),在n较大时,相当慢。
  3.  对最后一个集合是否含有24的判断:原程序要遍历该集合所有的元素,效率比较差,可以考虑,在将两个集合合并到该集合时,只对取出的两个元素的计算结果是否为24进行判断,这样不仅省去最后的遍历,而且不用将结果插入到最后的那个集合中,省去了大量操作。

采用1和3两种改进方法后,程序运行时间由原来的58秒缩短到了14秒,但这还不能让人满意。对2,3,5,6,8这5个数,如果用书上的第一种方法,可以只调用4次函数就可以得到结果:2+3+5+6+8=24,但用集合的方法来处理,却要对所有的集合进行子集合合并后才能给出结果,这造成了效率极差,可以这样改进该算法:

初始有n个集合:每一次任意取出2个集合,合并后,放回去,再取出任意2个集合,重复前面的操作,直到只剩下一个集合为止。

例如:初始有5个数,把这5个数分别放到5个集合,并分别编号为:1、2、4、8、16。任意取出2个集合,假设为1和4,将1和4合并,得到编号为5(=1+4)的集合,剩下的集合为:5、2、16、8,再取出2个,假设为5和8,合并后,得到13、2、16,再取2个,假设为13和16,合并后得到29、2,当剩下2个集合时,可以直接对这两个集合间的的计算结果是否为24进行判断,直接得出结果,省去不必要的合并(合并后再判断是否有元素近似等于24,程序运行时间8s多,而直接对计算结果判断,程序只要运行1s多)。

优化后的程序,只要运行1s多。但其效率还是不如书上的第一种方法的改进版,仔细想想,n越大,集合的元素也就越多,两个集合之间的合并,就越耗时间。而且采用集合保存中间结果,表面上减少了重复状态,会提高效率,但实际上,由于采用了集合,就多了很多不必要的计算,(比如,对2+3+5+6+8=24,最少只要4次计算就能得出结果,采用集合合并后,则要计算几百次(约为6^4)),再加上实现集合所采用的数据结构的开销,效率高不了。

#include <iostream>
#include <fstream>
#include <unordered_set>
#include <vector>
#include <ctime>
#include <cmath>

using namespace std;
typedef unordered_set<double> mset;

unsigned long long all_size=0;
unsigned long long big_size=0;

bool calc(int src[], size_t N, double M = 24.0){
  if (N == 0 || src == NULL) return false;
  mset result[1 << N];
  for (size_t i = 0; i < N; ++i) result[1<<i].insert((double)src[i]); 
  for (size_t i =1; i < (1<<N); ++i) {
    for (size_t j = 1; j <= (i >> 1); ++j) {
       if ((i & j) != j) continue;
       for (mset::iterator p = result[j].begin(); p != result[j].end(); ++p) {
         double va = *p;
         size_t k = i ^ j;
         for (mset::iterator q = result[k].begin(); q != result[k].end(); ++q) {
           double vb = *q;
           result[i].insert(va + vb);
           result[i].insert(va - vb);
           result[i].insert(vb - va);
           result[i].insert(va * vb);
           if (vb != 0.0) result[i].insert(va / vb);
           if (va != 0.0) result[i].insert(vb / va);
         }
       }
    }  
  }
  
  size_t j = (1 << N) - 1;
  all_size=0;
  big_size=result[j].size();
  for (size_t i =1; i < (1<<N); ++i) all_size += result[i].size();
  for (mset::iterator p = result[j].begin(); p != result[j].end(); ++p) {
    if (fabs(M - *p) < 1e-9) return true;
  }
  return false;
}

void calc_range(int first, int last,size_t N, double M = 24, string filename="nul"){
  if (N ==0 || first >= last || filename.empty()) return;
  clock_t ta = clock();
  vector<int> vv(N, first);
  int *end = &vv[N-1], *p = end, *arr = &vv[0];
  ofstream ofs(filename.c_str());
  size_t count = 0;
  size_t count_b = 0;
  unsigned long long max_big_size =0, max_all_size = 0;
  while(true){
    ++count_b;
    if (calc(arr,N, M)) {
      ofs.width(4);
      ofs<<  ++count << "    ";
      for (size_t i = 0; i < N; ++i) { 
        ofs.width(2);        
        ofs << arr[i]<< " ";
      } 
      ofs << "\n";
      if (max_big_size < big_size) max_big_size = big_size;
      if (max_all_size < all_size) max_all_size = all_size;
    }   
    while (p >= arr && *p >= last) --p;
    if (p < arr) break;
    int tmp = ++*p;
    while (p < end) *++p = tmp;
  }
  ofs.close();
  const char sep[] = "/"; 
  cout  << "total: " << count << sep << count_b << "\n"
        << max_big_size << sep << max_all_size << "\n"
   << "time:  " << clock() - ta << "\n\n"; 
}

int main(){
  calc_range(1,13,5,24,"nul");cin.get();
}

敬请关注本博客和新浪微博songzi_tea.

抱歉!评论已关闭.