很久没写含这么多stl的程序了,很故意的用set,map,vector,熟手一下。
也记录一下吧,虽然写得比较渣。
三个文件:
测试数据:data.txt
D1 Sunny Hot High Weak No D2 Sunny Hot High Strong No D3 Overcast Hot High Weak Yes D4 Rain Mild High Weak Yes D5 Rain Cool Normal Weak Yes D6 Rain Cool Normal Strong No D7 Overcast Cool Normal Strong Yes D8 Sunny Mild High Weak No D9 Sunny Cool Normal Weak Yes D10 Rain Mild Normal Weak Yes D11 Sunny Mild Normal Strong Yes D12 Overcast Mild High Strong Yes D13 Overcast Hot Normal Weak Yes D14 Rain Mild High Strong No
程序头文件:id3.h
#ifndef ID3_H #define ID3_H #include<fstream> #include<iostream> #include<vector> #include<map> #include<set> #include<cmath> using namespace std; const int DataRow=14; const int DataColumn=6; struct Node { double value;//代表此时yes的概率。 int attrid; Node * parentNode; vector<Node*> childNode; }; #endif
程序源文件id3.cpp
#include "id3.h" string DataTable[DataRow][DataColumn]; map<string,int> str2int; set<int> S; set<int> Attributes; string attrName[DataColumn]={"Day","Outlook","Temperature","Humidity","Wind","PlayTennis"}; string attrValue[DataColumn][DataRow]= { {},//D1,D2这个属性不需要 {"Sunny","Overcast","Rain"}, {"Hot","Mild","Cool"}, {"High","Normal"}, {"Weak","Strong"}, {"No","Yes"} }; int attrCount[DataColumn]={14,3,3,2,2,2}; double lg2(double n) { return log(n)/log(2); } void Init() { ifstream fin("data.txt"); for(int i=0;i<14;i++) { for(int j=0;j<6;j++) { fin>>DataTable[i][j]; } } fin.close(); for(int i=1;i<=5;i++) { str2int[attrName[i]]=i; for(int j=0;j<attrCount[i];j++) { str2int[attrValue[i][j]]=j; } } for(int i=0;i<DataRow;i++) S.insert(i); for(int i=1;i<=4;i++) Attributes.insert(i); } double Entropy(const set<int> &s) { double yes=0,no=0,sum=s.size(),ans=0; for(set<int>::iterator it=s.begin();it!=s.end();it++) { string s=DataTable[*it][str2int["PlayTennis"]]; if(s=="Yes") yes++; else no++; } if(no==0||yes==0) return ans=0; ans=-yes/sum*lg2(yes/sum)-no/sum*lg2(no/sum); return ans; } double Gain(const set<int> & example,int attrid) { int attrcount=attrCount[attrid]; double ans=Entropy(example); double sum=example.size(); set<int> * pset=new set<int>[attrcount]; for(set<int>::iterator it=example.begin();it!=example.end();it++) { pset[str2int[DataTable[*it][attrid]]].insert(*it); } for(int i=0;i<attrcount;i++) { ans-=pset[i].size()/sum*Entropy(pset[i]); } return ans; } int FindBestAttribute(const set<int> & example,const set<int> & attr) { double mx=0; int k=-1; for(set<int>::iterator i=attr.begin();i!=attr.end();i++) { double ret=Gain(example,*i); if(ret>mx) { mx=ret; k=*i; } } if(k==-1) cout<<"FindBestAttribute error!"<<endl; return k; } Node * Id3_solution(set<int> example,set<int> & attributes,Node * parent) { Node *now=new Node;//创建树节点。 now->parentNode=parent; if(attributes.empty())//如果此时属性列表已用完,即为空,则返回。 return now; /* * 统计一下example,如果都为正或者都为负则表示已经抵达决策树的叶子节点 * 叶子节点的特征是有childNode为空。 */ int yes=0,no=0,sum=example.size(); for(set<int>::iterator it=example.begin();it!=example.end();it++) { string s=DataTable[*it][str2int["PlayTennis"]]; if(s=="Yes") yes++; else no++; } if(yes==sum||yes==0) { now->value=yes/sum; return now; } /*找到最高信息增益的属性并将该属性从attributes集合中删除*/ int bestattrid=FindBestAttribute(example,attributes); now->attrid=bestattrid; attributes.erase(attributes.find(bestattrid)); /*将exmple根据最佳属性的不同属性值分成几个分支,每个分支有即一个子树*/ vector< set<int> > child=vector< set<int> >(attrCount[bestattrid]); for(set<int>::iterator i=example.begin();i!=example.end();i++) { int id=str2int[DataTable[*i][bestattrid]]; child[id].insert(*i); } for(int i=0;i<child.size();i++) { Node * ret=Id3_solution(child[i],attributes,now); now->childNode.push_back(ret); } return now; } int main() { Init(); Node * Root=Id3_solution(S,Attributes,NULL); return 0; }