现在的位置: 首页 > 综合 > 正文

OpenCV:随机决策森林CvRTrees使用实例

2018年05月13日 ⁄ 综合 ⁄ 共 10363字 ⁄ 字号 评论关闭

本文介绍:OpenCV机器学习库MLL中随机森林Random Trees的使用

参考文献:

1.Breiman,
Leo
 (2001). "Random Forests". Machine
Learning
 

2.Random
Forests网站

不熟悉MLL的参考此文:OpenCV机器学习库MLL

OpenCV的机器学习算法都比较简单:train ——>predict

  1. class CV_EXPORTS_W CvRTrees : public CvStatModel  
  2. {  
  3. public:  
  4.     CV_WRAP CvRTrees();  
  5.     virtual ~CvRTrees();  
  6.     virtual bool train( const CvMat* trainData, int tflag,  
  7.                         const CvMat* responses, const CvMat* varIdx=0,  
  8.                         const CvMat* sampleIdx=0, const CvMat* varType=0,  
  9.                         const CvMat* missingDataMask=0,  
  10.                         CvRTParams params=CvRTParams() );  
  11.   
  12.     virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );  
  13.     virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;  
  14.     virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;  
  15.   
  16.     CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,  
  17.                        const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),  
  18.                        const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),  
  19.                        const cv::Mat& missingDataMask=cv::Mat(),  
  20.                        CvRTParams params=CvRTParams() );  
  21.     CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;  
  22.     CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;  
  23.     CV_WRAP virtual cv::Mat getVarImportance();  
  24.   
  25.     CV_WRAP virtual void clear();  
  26.   
  27.     virtual const CvMat* get_var_importance();  
  28.     virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,  
  29.         const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;  
  30.   
  31.     virtual float calc_error( CvMLData* data, int type , std::vector<float>* resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}  
  32.   
  33.     virtual float get_train_error();  
  34.   
  35.     virtual void read( CvFileStorage* fs, CvFileNode* node );  
  36.     virtual void write( CvFileStorage* fs, const char* name ) const;  
  37.   
  38.     CvMat* get_active_var_mask();  
  39.     CvRNG* get_rng();  
  40.   
  41.     int get_tree_count() const;  
  42.     CvForestTree* get_tree(int i) const;  
  43.   
  44. protected:  
  45.     virtual std::string getName() const;  
  46.   
  47.     virtual bool grow_forest( const CvTermCriteria term_crit );  
  48.   
  49.     // array of the trees of the forest  
  50.     CvForestTree** trees;  
  51.     CvDTreeTrainData* data;  
  52.     int ntrees;  
  53.     int nclasses;  
  54.     double oob_error;  
  55.     CvMat* var_importance;  
  56.     int nsamples;  
  57.   
  58.     cv::RNG* rng;  
  59.     CvMat* active_var_mask;  
  60. };  


使用CvRTrees类,来对手写体数据作分类

  1. // Example : random forest (tree) learning  
  2. // Author : Toby Breckon, toby.breckon@cranfield.ac.uk  
  3.   
  4. // Copyright (c) 2011 School of Engineering, Cranfield University  
  5. // License : LGPL - http://www.gnu.org/licenses/lgpl.html  
  6.   
  7. #include <cv.h>       // opencv general include file  
  8. #include <ml.h>         // opencv machine learning include file  
  9. #include <stdio.h>  
  10.   
  11. using namespace cv; // OpenCV API is in the C++ "cv" namespace  
  12.   
  13. /******************************************************************************/  
  14. // global definitions (for speed and ease of use)  
  15. //手写体数字识别  
  16.   
  17. #define NUMBER_OF_TRAINING_SAMPLES 3823  
  18. #define ATTRIBUTES_PER_SAMPLE 64  
  19. #define NUMBER_OF_TESTING_SAMPLES 1797  
  20.   
  21. #define NUMBER_OF_CLASSES 10  
  22.   
  23. // N.B. classes are integer handwritten digits in range 0-9  
  24.   
  25. /******************************************************************************/  
  26.   
  27. // loads the sample database from file (which is a CSV text file)  
  28.   
  29. int read_data_from_csv(const char* filename, Mat data, Mat classes,  
  30.                        int n_samples )  
  31. {  
  32.     float tmp;  
  33.   
  34.     // if we can't read the input file then return 0  
  35.     FILE* f = fopen( filename, "r" );  
  36.     if( !f )  
  37.     {  
  38.         printf("ERROR: cannot read file %s\n",  filename);  
  39.         return 0; // all not OK  
  40.     }  
  41.   
  42.     // for each sample in the file  
  43.   
  44.     for(int line = 0; line < n_samples; line++)  
  45.     {  
  46.         // for each attribute on the line in the file  
  47.         for(int attribute = 0; attribute < (ATTRIBUTES_PER_SAMPLE + 1); attribute++)  
  48.         {  
  49.             if (attribute < 64)  
  50.             {  
  51.                 // first 64 elements (0-63) in each line are the attributes  
  52.                 fscanf(f, "%f,", &tmp);  
  53.                 data.at<float>(line, attribute) = tmp;  
  54.                 // printf("%f,", data.at<float>(line, attribute));  
  55.             }  
  56.             else if (attribute == 64)  
  57.             {  
  58.                 // attribute 65 is the class label {0 ... 9}  
  59.                 fscanf(f, "%f,", &tmp);  
  60.                 classes.at<float>(line, 0) = tmp;  
  61.                 // printf("%f\n", classes.at<float>(line, 0));  
  62.             }  
  63.         }  
  64.     }  
  65.   
  66.     fclose(f);  
  67.     return 1; // all OK  
  68. }  
  69.   
  70. /******************************************************************************/  
  71.   
  72. int main( int argc, char** argv )  
  73. {  
  74.       
  75.     for (int i=0; i< argc; i++)  
  76.         std::cout<<argv[i]<<std::endl;  
  77.       
  78.       
  79.     // lets just check the version first  
  80.     printf ("OpenCV version %s (%d.%d.%d)\n",  
  81.             CV_VERSION,  
  82.             CV_MAJOR_VERSION, CV_MINOR_VERSION, CV_SUBMINOR_VERSION);  
  83.       
  84.     //定义训练数据与标签矩阵  
  85.     Mat training_data = Mat(NUMBER_OF_TRAINING_SAMPLES, ATTRIBUTES_PER_SAMPLE, CV_32FC1);  
  86.     Mat training_classifications = Mat(NUMBER_OF_TRAINING_SAMPLES, 1, CV_32FC1);  
  87.   
  88.     //定义测试数据矩阵与标签  
  89.     Mat testing_data = Mat(NUMBER_OF_TESTING_SAMPLES, ATTRIBUTES_PER_SAMPLE, CV_32FC1);  
  90.     Mat testing_classifications = Mat(NUMBER_OF_TESTING_SAMPLES, 1, CV_32FC1);  
  91.   
  92.     // define all the attributes as numerical  
  93.     // alternatives are CV_VAR_CATEGORICAL or CV_VAR_ORDERED(=CV_VAR_NUMERICAL)  
  94.     // that can be assigned on a per attribute basis  
  95.   
  96.     Mat var_type = Mat(ATTRIBUTES_PER_SAMPLE + 1, 1, CV_8U );  
  97.     var_type.setTo(Scalar(CV_VAR_NUMERICAL) ); // all inputs are numerical  
  98.   
  99.     // this is a classification problem (i.e. predict a discrete number of class  
  100.     // outputs) so reset the last (+1) output var_type element to CV_VAR_CATEGORICAL  
  101.   
  102.     var_type.at<uchar>(ATTRIBUTES_PER_SAMPLE, 0) = CV_VAR_CATEGORICAL;  
  103.   
  104.     double result; // value returned from a prediction  
  105.   
  106.     //加载训练数据集和测试数据集  
  107.     if (read_data_from_csv(argv[1], training_data, training_classifications, NUMBER_OF_TRAINING_SAMPLES) &&  
  108.             read_data_from_csv(argv[2], testing_data, testing_classifications, NUMBER_OF_TESTING_SAMPLES))  
  109.     {  
  110.       /********************************步骤1:定义初始化Random Trees的参数******************************/  
  111.         float priors[] = {1,1,1,1,1,1,1,1,1,1};  // weights of each classification for classes  
  112.         CvRTParams params = CvRTParams(25, // max depth  
  113.                                        5, // min sample count  
  114.                                        0, // regression accuracy: N/A here  
  115.                                        false// compute surrogate split, no missing data  
  116.                                        15, // max number of categories (use sub-optimal algorithm for larger numbers)  
  117.                                        priors, // the array of priors  
  118.                                        false,  // calculate variable importance  
  119.                                        4,       // number of variables randomly selected at node and used to find the best split(s).  
  120.                                        100,  // max number of trees in the forest  
  121.                                        0.01f,               // forrest accuracy  
  122.                                        CV_TERMCRIT_ITER |   CV_TERMCRIT_EPS // termination cirteria  
  123.                                       );  
  124.   
  125.         /****************************步骤2:训练 Random Decision Forest(RDF)分类器*********************/  
  126.         printf( "\nUsing training database: %s\n\n", argv[1]);  
  127.         CvRTrees* rtree = new CvRTrees;  
  128.         rtree->train(training_data, CV_ROW_SAMPLE, training_classifications,  
  129.                      Mat(), Mat(), var_type, Mat(), params);  
  130.   
  131.         // perform classifier testing and report results  
  132.         Mat test_sample;  
  133.         int correct_class = 0;  
  134.         int wrong_class = 0;  
  135.         int false_positives [NUMBER_OF_CLASSES] = {0,0,0,0,0,0,0,0,0,0};  
  136.   
  137.         printf( "\nUsing testing database: %s\n\n", argv[2]);  
  138.   
  139.         for (int tsample = 0; tsample < NUMBER_OF_TESTING_SAMPLES; tsample++)  
  140.         {  
  141.   
  142.             // extract a row from the testing matrix  
  143.             test_sample = testing_data.row(tsample);  
  144.         /********************************步骤3:预测*********************************************/  
  145.             result = rtree->predict(test_sample, Mat());  
  146.   
  147.             printf("Testing Sample %i -> class result (digit %d)\n", tsample, (int) result);  
  148.   
  149.             // if the prediction and the (true) testing classification are the same  
  150.             // (N.B. openCV uses a floating point decision tree implementation!)  
  151.             if (fabs(result - testing_classifications.at<float>(tsample, 0))  
  152.                     >= FLT_EPSILON)  
  153.             {  
  154.                 // if they differ more than floating point error => wrong class  
  155.                 wrong_class++;  
  156.                 false_positives[(int) result]++;  
  157.             }  
  158.             else  
  159.             {  
  160.                 // otherwise correct  
  161.                 correct_class++;  
  162.             }  
  163.         }  
  164.   
  165.         printf( "\nResults on the testing database: %s\n"  
  166.                 "\tCorrect classification: %d (%g%%)\n"  
  167.                 "\tWrong classifications: %d (%g%%)\n",  
  168.                 argv[2],  
  169.                 correct_class, (double) correct_class*100/NUMBER_OF_TESTING_SAMPLES,  
  170.                 wrong_class, (double) wrong_class*100/NUMBER_OF_TESTING_SAMPLES);  
  171.   
  172.         for (int i = 0; i < NUMBER_OF_CLASSES; i++)  
  173.         {  
  174.             printf( "\tClass (digit %d) false postives  %d (%g%%)\n", i,  
  175.                     false_positives[i],  
  176.                     (double) false_positives[i]*100/NUMBER_OF_TESTING_SAMPLES);  
  177.         }  
  178.   
  179.         // all matrix memory free by destructors  
  180.   
  181.         // all OK : main returns 0  
  182.         return 0;  
  183.     }  
  184.   
  185.     // not OK : main returns -1  
  186.     return -1;  
  187. }  
  188. /******************************************************************************/  

=============================================================================

手写体数据:

设置数据集 train test:

在test数据集上的正确率:

抱歉!评论已关闭.