#include<iostream>
#include<fstream>
#include<sstream>
#include<string>
#include<map>
#include<list>
#include<set>
#include<queue>
#include<utility>
#include<vector>
#include<cmath>
using namespace std;

/*
 * Upper-tail 5% critical values of the chi-square distribution (its 0.95
 * quantile), indexed by degrees of freedom: CHI[df-1] for df = 1..30.
 * A split is kept only when its chi-square statistic reaches this value, i.e.
 * when "branch and class are independent" is rejected at the 0.95 confidence
 * level.
 * NOTE: the previous table (0.004, 0.103, ...) held the 0.05 quantiles (the
 * alpha=0.95 column of an upper-tail table), which let almost any split pass.
 */
const int CHI_TABLE_SIZE=30;
const double CHI[CHI_TABLE_SIZE]={
    3.841, 5.991, 7.815, 9.488, 11.070, 12.592, 14.067, 15.507, 16.919, 18.307,
    19.675,21.026,22.362,23.685,24.996, 26.296, 27.587, 28.869, 30.144, 31.410,
    32.671,33.924,35.172,36.415,37.652, 38.885, 40.113, 41.337, 42.557, 43.773};

/*
 * Chi-square critical value at the 0.95 confidence level for `df` degrees of
 * freedom. Uses the table for df <= 30 and the Wilson-Hilferty approximation
 * beyond it. Precondition: df >= 1.
 */
double chi_critical(int df){
    if(df<=CHI_TABLE_SIZE)
        return CHI[df-1];
    const double z=1.6449;          // standard normal 0.95 quantile
    double k=2.0/(9.0*df);
    double t=1.0-k+z*sqrt(k);
    return df*t*t*t;
}

/*
 * Pearson chi-square statistic of a row x col contingency table `arr`.
 * Cells whose expected count is 0 (empty row or column) are skipped.
 * Returns 0 for an all-zero table instead of dividing 0 by 0.
 */
template<typename Comparable>
double cal_chi(Comparable **arr,int row,int col){
    vector<Comparable> rowsum(row);
    vector<Comparable> colsum(col);
    Comparable totalsum=static_cast<Comparable>(0);
    for(int i=0;i<row;++i){
        for(int j=0;j<col;++j){
            totalsum+=arr[i][j];
            rowsum[i]+=arr[i][j];
            colsum[j]+=arr[i][j];
        }
    }
    if(totalsum==static_cast<Comparable>(0))
        return 0.0;
    double rect=0.0;
    for(int i=0;i<row;++i){
        for(int j=0;j<col;++j){
            double expected=1.0*rowsum[i]*colsum[j]/totalsum;
            if(expected!=0){
                double diff=arr[i][j]-expected;
                rect+=diff*diff/expected;
            }
        }
    }
    return rect;
}

/*
 * Pruning candidate: (alpha, number of leaves of the subtree, node index).
 * operator< is written for std::priority_queue, which is a max-heap: "a < b"
 * means "b should be pruned before a". The top of the queue is therefore the
 * candidate with the SMALLEST alpha (the weakest link); ties go to the
 * candidate with the smaller subtree.
 */
class MyTriple{
public:
    double first;   // alpha of the node
    int second;     // number of leaves in the node's subtree
    int third;      // breadth-first index of the node (see index())
    MyTriple():first(0.0),second(0),third(0){}
    MyTriple(double f,int s,int t):first(f),second(s),third(t){}
    bool operator< (const MyTriple &obj) const{
        // Compare the doubles directly. The old code stored the difference in
        // an int, so any two alphas less than 1 apart compared equal.
        if(first!=obj.first)
            return first>obj.first;
        return second>obj.second;
    }
};

typedef map<string,int> MAP_REST_COUNT;          // class label -> count
typedef map<string,MAP_REST_COUNT> MAP_ATTR_REST; // attribute value -> class counts
typedef vector<MAP_ATTR_REST> VEC_STATI;          // one entry per attribute

const int ATTR_NUM=8;                 // number of independent attributes
vector<string> X(ATTR_NUM);           // attribute names (from the header line)
int rest_number;                      // number of distinct classes
vector<pair<string,int> > classes;    // every class label with its record count
int total_record_number;              // total number of training records
vector<vector<string> > inputData;    // raw records: name, ATTR_NUM values, class

/* A node of the binary classification tree. */
class node{
public:
    node* parent;       // parent node (NULL for the root)
    node* leftchild;    // child for records that satisfy `cond`
    node* rightchild;   // child for all other records
    string cond;        // branching condition "attribute=value"
    string decision;    // class predicted at this node (majority class)
    double precision;   // fraction of this node's records that belong to `decision`
    int record_number;  // number of training records reaching this node
    int size;           // number of leaves in the subtree rooted here
    int index;          // breadth-first sequence number, assigned by index()
    double alpha;       // increase of the apparent error rate per removed leaf

    node(){ init(NULL); }
    explicit node(node* p){ init(p); }
    node(node* p,string c,string d):cond(c),decision(d){ init(p); }

    /* Prints this node's fields on one line. */
    void printInfo(){
        cout<<"index:"<<index<<"\tdecisoin:"<<decision<<"\tprecision:"<<precision<<"\tcondition:"<<cond<<"\tsize:"<<size;
        if(parent!=NULL)
            cout<<"\tparent index:"<<parent->index;
        if(leftchild!=NULL)
            cout<<"\tleftchild:"<<leftchild->index<<"\trightchild:"<<rightchild->index;
        cout<<endl;
    }

    /* Prints the subtree rooted here in pre-order. */
    void printTree(){
        printInfo();
        if(leftchild!=NULL)
            leftchild->printTree();
        if(rightchild!=NULL)
            rightchild->printTree();
    }

private:
    /* Shared constructor body: a fresh leaf attached to parent p. */
    void init(node* p){
        parent=p;
        leftchild=NULL;
        rightchild=NULL;
        precision=0.0;
        record_number=0;
        size=1;
        index=0;
        alpha=1.0;
    }
};

/* Frees every node of the subtree rooted at n, including n itself. */
void destroyTree(node *n){
    if(n==NULL)
        return;
    destroyTree(n->leftchild);
    destroyTree(n->rightchild);
    delete n;
}

/*
 * Reads the training file into the globals X, inputData, classes,
 * rest_number and total_record_number.
 * Line 1 is a header: a label for the name column followed by ATTR_NUM
 * attribute names. Each following line holds a record name, ATTR_NUM attribute
 * values and the class label, separated by whitespace. Blank lines are ignored.
 * Lines with fewer fields are skipped with a warning instead of being filled
 * with stale tokens.
 * Returns 0 on success, -1 if the file cannot be opened.
 */
int readInput(string filename){
    ifstream ifs(filename.c_str());
    if(!ifs){
        cerr<<"open inputfile failed!"<<endl;
        return -1;
    }
    map<string,int> catg;   // class label -> number of records
    string line;
    string item;
    if(getline(ifs,line)){
        istringstream header(line);
        header>>item;       // skip the label of the name column
        for(size_t i=0;i<X.size();++i){
            header>>item;
            X[i]=item;
        }
    }
    const size_t fields=ATTR_NUM+2;
    while(getline(ifs,line)){
        istringstream strstm(line);
        vector<string> conts;
        while(conts.size()<fields && strstm>>item)
            conts.push_back(item);
        if(conts.empty())
            continue;       // blank line
        if(conts.size()<fields){
            cerr<<"skipping malformed record: "<<line<<endl;
            continue;
        }
        catg[conts.back()]++;
        inputData.push_back(conts);
    }
    total_record_number=static_cast<int>(inputData.size());
    for(map<string,int>::const_iterator itr=catg.begin();itr!=catg.end();++itr)
        classes.push_back(make_pair(itr->first,itr->second));
    rest_number=static_cast<int>(classes.size());
    return 0;
}

/*
 * Builds the statistics for `inputData`:
 * stati[a][value][label] = number of records whose attribute a (0-based,
 * column a+1 of a record) equals `value` and whose class is `label`.
 */
void statistic(vector<vector<string> > &inputData,VEC_STATI &stati){
    for(int i=1;i<ATTR_NUM+1;++i){
        MAP_ATTR_REST attr_rest;
        for(size_t j=0;j<inputData.size();++j)
            attr_rest[inputData[j][i]][inputData[j][ATTR_NUM+1]]++;
        stati.push_back(attr_rest);
    }
}

/*
 * Partitions `inputData` on attribute `fitIndex` (0-based). Records whose
 * value equals `cond` are appended to LinputData; all others go to RinputData.
 */
void splitInput(vector<vector<string> > &inputData,int fitIndex,string cond,vector<vector<string> > &LinputData,vector<vector<string> > &RinputData){
    for(size_t i=0;i<inputData.size();++i){
        if(inputData[i][fitIndex+1]==cond)
            LinputData.push_back(inputData[i]);
        else
            RinputData.push_back(inputData[i]);
    }
}

/* Debug helper: prints the statistics built by statistic(). */
void printStati(VEC_STATI &stati){
    for(size_t i=0;i<stati.size();i++){
        for(MAP_ATTR_REST::const_iterator itr=stati[i].begin();itr!=stati[i].end();++itr){
            cout<<itr->first;
            for(MAP_REST_COUNT::const_iterator iter=itr->second.begin();iter!=itr->second.end();++iter)
                cout<<"\t"<<iter->first<<"\t"<<iter->second;
            cout<<endl;
        }
        cout<<endl;
    }
}

/*
 * Finds the most frequent class in `counts`; the first one wins on ties.
 * Stores its label in `decision` (left untouched if every count is 0) and
 * returns its count.
 */
int majorityClass(const vector<pair<string,int> > &counts,string &decision){
    int best=0;
    for(size_t i=0;i<counts.size();++i){
        if(counts[i].second>best){
            best=counts[i].second;
            decision=counts[i].first;
        }
    }
    return best;
}

/* Gini impurity 1 - sum(p_k^2) of a class distribution over `total` records (0 if empty). */
double gini(const vector<pair<string,int> > &counts,int total){
    if(total==0)
        return 0.0;
    double g=1.0;
    for(size_t k=0;k<counts.size();++k){
        double p=1.0*counts[k].second/total;
        g-=p*p;
    }
    return g;
}

/*
 * Recursively grows the tree below `root` from the records in `inputData`.
 * `classes` holds every class label, in the order of the global `classes`,
 * with the number of those records that carry it.
 * Sets the node's record count, majority decision and precision. It then picks
 * the binary test "attribute == value" with the smallest weighted Gini
 * impurity. The test is applied only if a chi-square test at the 0.95 level
 * shows that the branch is not independent of the class.
 */
void split(node *root,vector<vector<string> > &inputData,const vector<pair<string,int> > &classes){
    root->record_number=static_cast<int>(inputData.size());

    // Every node, including the root (which used to be skipped), gets a decision.
    string decision;
    int maxcount=majorityClass(classes,decision);
    root->decision=decision;
    root->precision=root->record_number>0 ? 1.0*maxcount/root->record_number : 0.0;

    const int ncls=static_cast<int>(classes.size());
    if(inputData.empty() || ncls<2)
        return;

    VEC_STATI stati;
    statistic(inputData,stati);

    /* Find the split with the smallest weighted Gini impurity. */
    double minGain=1.0;
    int fitIndex=-1;
    string fitCond;
    vector<pair<string,int> > fitleftclasses;
    vector<pair<string,int> > fitrightclasses;
    int fitleftnumber=0;
    int fitrightnumber=0;
    for(size_t i=0;i<stati.size();++i){                   // every attribute
        for(MAP_ATTR_REST::const_iterator itr=stati[i].begin();itr!=stati[i].end();++itr){   // every value
            vector<pair<string,int> > leftclasses(classes);   // class counts going left
            vector<pair<string,int> > rightclasses(classes);  // class counts going right
            int leftnumber=0;
            int rightnumber=0;
            for(int j=0;j<ncls;++j){
                MAP_REST_COUNT::const_iterator found=itr->second.find(leftclasses[j].first);
                leftclasses[j].second=(found==itr->second.end()) ? 0 : found->second;
                rightclasses[j].second-=leftclasses[j].second;
                leftnumber+=leftclasses[j].second;
                rightnumber+=rightclasses[j].second;
            }
            int total=leftnumber+rightnumber;
            double gain=1.0*leftnumber/total*gini(leftclasses,leftnumber)
                       +1.0*rightnumber/total*gini(rightclasses,rightnumber);
            if(gain<minGain){
                fitIndex=static_cast<int>(i);
                fitCond=itr->first;
                fitleftclasses=leftclasses;
                fitrightclasses=rightclasses;
                fitleftnumber=leftnumber;
                fitrightnumber=rightnumber;
                minGain=gain;
            }
        }
    }
    if(fitIndex<0)
        return;

    /* Chi-square test: is the split worth making? */
    vector<int> leftrow(ncls);
    vector<int> rightrow(ncls);
    int nonEmptyCols=0;
    for(int j=0;j<ncls;++j){
        leftrow[j]=fitleftclasses[j].second;
        rightrow[j]=fitrightclasses[j].second;
        if(leftrow[j]+rightrow[j]>0)
            ++nonEmptyCols;
    }
    int nonEmptyRows=(fitleftnumber>0 ? 1 : 0)+(fitrightnumber>0 ? 1 : 0);
    // One branch empty or node already pure: nothing to split. This also
    // prevents children with zero records (and 0/0 precision).
    if(nonEmptyRows<2 || nonEmptyCols<2)
        return;
    int df=nonEmptyCols-1;              // (rows-1)*(cols-1) with rows == 2
    int *table[2]={&leftrow[0],&rightrow[0]};
    if(cal_chi(table,2,ncls)<chi_critical(df))
        return;                         // branch independent of class: keep as leaf

    /* Split. */
    root->cond=X[fitIndex]+"="+fitCond;
    for(node *travel=root;travel!=NULL;travel=travel->parent)
        (travel->size)++;               // root and all ancestors gain one leaf
    node *LChild=new node(root);
    node *RChild=new node(root);
    root->leftchild=LChild;
    root->rightchild=RChild;

    vector<vector<string> > LinputData,RinputData;
    splitInput(inputData,fitIndex,fitCond,LinputData,RinputData);
    split(LChild,LinputData,fitleftclasses);
    split(RChild,RinputData,fitrightclasses);
}

/* Error cost of the subtree: sum over its leaves of misclassified records / total records. */
double calR2(node *root){
    if(root->leftchild==NULL)
        return (1-root->precision)*root->record_number/total_record_number;
    return calR2(root->leftchild)+calR2(root->rightchild);
}

/*
 * Numbers the nodes breadth-first starting at 1. For every internal node it
 * also computes alpha = (R(node) - R(subtree)) / (leaves - 1) and pushes it
 * onto `pq` as a pruning candidate.
 */
void index(node *root,priority_queue<MyTriple> &pq){
    int i=1;
    queue<node*> que;
    que.push(root);
    while(!que.empty()){
        node* n=que.front();
        que.pop();
        n->index=i++;
        if(n->leftchild!=NULL){
            que.push(n->leftchild);
            que.push(n->rightchild);
            double r1=(1-n->precision)*n->record_number/total_record_number;  // cost as a leaf
            double r2=calR2(n);                                                // cost as a subtree
            n->alpha=(r1-r2)/(n->size-1);
            pq.push(MyTriple(n->alpha,n->size,n->index));
        }
    }
}

/*
 * Prunes the candidate on top of `pq` (smallest alpha). Its subtrees are freed,
 * the node becomes a leaf, and the leaf counts of it and its ancestors are
 * updated. Does nothing when there are no candidates.
 */
void prune(node *root,priority_queue<MyTriple> &pq){
    if(pq.empty())
        return;
    int target=pq.top().third;
    queue<node*> que;
    que.push(root);
    while(!que.empty()){
        node* n=que.front();
        que.pop();
        if(n->index==target){
            cout<<"将要剪掉"<<target<<"的左右子树"<<endl;
            destroyTree(n->leftchild);
            destroyTree(n->rightchild);
            n->leftchild=NULL;
            n->rightchild=NULL;
            n->cond.clear();            // a leaf has no branching condition
            int s=n->size-1;
            for(node *trav=n;trav!=NULL;trav=trav->parent)
                trav->size-=s;
            break;
        }
        else if(n->leftchild!=NULL){
            que.push(n->leftchild);
            que.push(n->rightchild);
        }
    }
}

/*
 * Classifies every record of `filename` (same layout as the training file;
 * the header line is skipped). For each record it prints the name, the
 * predicted class and the leaf's precision. Blank or short lines are skipped.
 */
void test(string filename,node *root){
    ifstream ifs(filename.c_str());
    if(!ifs){
        cerr<<"open inputfile failed!"<<endl;
        return;
    }
    string line;
    getline(ifs,line);                  // skip the header line
    map<string,string> independent;     // attribute name -> value of the current record
    const size_t needed=ATTR_NUM+1;     // name + attribute values
    while(getline(ifs,line)){
        istringstream strstm(line);
        vector<string> fields;
        string item;
        while(fields.size()<needed && strstm>>item)
            fields.push_back(item);
        if(fields.size()<needed)
            continue;
        cout<<fields[0]<<"\t";
        for(int i=0;i<ATTR_NUM;++i)
            independent[X[i]]=fields[i+1];
        node *trav=root;
        while(trav!=NULL){
            if(trav->leftchild==NULL){
                cout<<(trav->decision)<<"\t置信度:"<<(trav->precision)<<endl;
                break;
            }
            const string &cond=trav->cond;
            string::size_type pos=cond.find('=');
            string pre=cond.substr(0,pos);
            string post=cond.substr(pos+1);
            if(independent[pre]==post)
                trav=trav->leftchild;
            else
                trav=trav->rightchild;
        }
    }
}

int main(){
    string inputFile="animal";
    if(readInput(inputFile)!=0)
        return 1;
    if(inputData.empty()){
        cerr<<"no training records found!"<<endl;
        return 1;
    }

    node *root=new node();
    split(root,inputData,classes);      // grow the tree from the root
    priority_queue<MyTriple> pq;
    index(root,pq);
    root->printTree();
    cout<<"剪枝前使用该决策树最多进行"<<root->size-1<<"次条件判断"<<endl;
    test(inputFile,root);
    prune(root,pq);
    cout<<"剪枝后使用该决策树最多进行"<<root->size-1<<"次条件判断"<<endl;
    test(inputFile,root);
    destroyTree(root);
    return 0;
}
人 恒温 毛发 是 否 否 否 是 否 哺乳类
巨蟒 冷血 鳞片 否 是 否 否 否 是 爬行类
鲑鱼 冷血 鳞片 否 是 否 是 否 否 鱼类
鲸 恒温 毛发 是 否 否 是 否 否 哺乳类
蛙 冷血 无 否 是 否 有时 是 是 两栖类
巨蜥 冷血 鳞片 否 是 否 否 是 否 爬行类
蝙蝠 恒温 毛发 是 否 是 否 是 否 哺乳类
猫 恒温 皮 是 否 否 否 是 否 哺乳类
豹纹鲨 冷血 鳞片 是 否 否 是 否 否 鱼类
海龟 冷血 鳞片 否 是 否 有时 是 否 爬行类
豪猪 恒温 刚毛 是 否 否 否 是 是 哺乳类
鳗 冷血 鳞片 否 是 否 是 否 否 鱼类
蝾螈 冷血 无 否 是 否 有时 是 是 两栖类