现在的位置: 首页 > 综合 > 正文

OpenCV:机器学习—Statistical Model

2018年05月13日 ⁄ 综合 ⁄ 共 6650字 ⁄ 字号 评论关闭

来看看MLL的主要构成:Statistical Model是个基类,下面的K-NN、SVM等都是其子类。

不太喜欢这个Statistical定语,Statistics在ML界横行的好多年,感觉温度已经降下来了。

来看下Statistical Model:

  1. class CV_EXPORTS_W CvStatModel  
  2. {  
  3. public:  
  4.     CvStatModel();  
  5.     virtual ~CvStatModel();  
  6.   
  7.     virtual void clear();  
  8.   
  9.     CV_WRAP virtual void save( const char* filename, const char* name=0 ) const;  
  10.     CV_WRAP virtual void load( const char* filename, const char* name=0 );  
  11.   
  12.     virtual void write( CvFileStorage* storage, const char* name ) const;  
  13.     virtual void read( CvFileStorage* storage, CvFileNode* node );  
  14.     virtual bool train(const Mat& train_data, const Mat& responses, Mat(), Mat(), CVParms params );  
  15.     virtual float predict(const Mat& sample, ...);  
  16.   
  17. protected:  
  18.     const char* default_model_name;  
  19. };  

void CvStatModel::clear()         清除内存重置模型状态;

void CvStatModel::save() /load()         保存/加载文件和模型;

void CvStatModel:read() /write()         读写文件和模型;

bool CvStatModel::train()     训练模型;

float CvStatModel::predict()  预测样本结果;

那么朴素贝叶斯、K-近邻、支持向量机、决策树等类都是继承CVStatModel;

使用这些方法的基本框架就是:

Method.train(train_data, responses, Mat(), Mat(),  params);

Method.predict(sampleMat);

======================================================
一个具体的例子<Support
Vector Machines for Non-Linearly Separable Data
>

  1. #include <iostream>  
  2. #include <opencv2/core/core.hpp>  
  3. #include <opencv2/highgui/highgui.hpp>  
  4. #include <opencv2/ml/ml.hpp>  
  5.   
  6. #define NTRAINING_SAMPLES   100         // Number of training samples per class  
  7. #define FRAC_LINEAR_SEP     0.9f        // Fraction of samples which compose the linear separable part  
  8.   
  9. using namespace cv;  
  10. using namespace std;  
  11.   
  12. void help()  
  13. {  
  14.     cout<< "\n--------------------------------------------------------------------------" << endl  
  15.         << "This program shows Support Vector Machines for Non-Linearly Separable Data. " << endl  
  16.         << "Usage:"                                                               << endl  
  17.         << "./non_linear_svms" << endl  
  18.         << "--------------------------------------------------------------------------"   << endl  
  19.         << endl;  
  20. }  
  21.   
  22. int main()  
  23. {  
  24.     help();  
  25.   
  26.     // Data for visual representation  
  27.     const int WIDTH = 512, HEIGHT = 512;  
  28.     Mat I = Mat::zeros(HEIGHT, WIDTH, CV_8UC3);  
  29.   
  30.     //--------------------- 1. Set up training data randomly ---------------------------------------  
  31.     Mat trainData(2*NTRAINING_SAMPLES, 2, CV_32FC1);  
  32.     Mat labels   (2*NTRAINING_SAMPLES, 1, CV_32FC1);  
  33.       
  34.     RNG rng(100); // Random value generation class  
  35.   
  36.     // Set up the linearly separable part of the training data  
  37.     int nLinearSamples = (int) (FRAC_LINEAR_SEP * NTRAINING_SAMPLES);  
  38.   
  39.     // Generate random points for the class 1  
  40.     Mat trainClass = trainData.rowRange(0, nLinearSamples);  
  41.     // The x coordinate of the points is in [0, 0.4)  
  42.     Mat c = trainClass.colRange(0, 1);  
  43.     rng.fill(c, RNG::UNIFORM, Scalar(1), Scalar(0.4 * WIDTH));  
  44.     // The y coordinate of the points is in [0, 1)  
  45.     c = trainClass.colRange(1,2);  
  46.     rng.fill(c, RNG::UNIFORM, Scalar(1), Scalar(HEIGHT));  
  47.   
  48.     // Generate random points for the class 2  
  49.     trainClass = trainData.rowRange(2*NTRAINING_SAMPLES-nLinearSamples, 2*NTRAINING_SAMPLES);  
  50.     // The x coordinate of the points is in [0.6, 1]  
  51.     c = trainClass.colRange(0 , 1);   
  52.     rng.fill(c, RNG::UNIFORM, Scalar(0.6*WIDTH), Scalar(WIDTH));  
  53.     // The y coordinate of the points is in [0, 1)  
  54.     c = trainClass.colRange(1,2);  
  55.     rng.fill(c, RNG::UNIFORM, Scalar(1), Scalar(HEIGHT));  
  56.   
  57.     //------------------ Set up the non-linearly separable part of the training data ---------------  
  58.   
  59.     // Generate random points for the classes 1 and 2  
  60.     trainClass = trainData.rowRange(  nLinearSamples, 2*NTRAINING_SAMPLES-nLinearSamples);  
  61.     // The x coordinate of the points is in [0.4, 0.6)  
  62.     c = trainClass.colRange(0,1);  
  63.     rng.fill(c, RNG::UNIFORM, Scalar(0.4*WIDTH), Scalar(0.6*WIDTH));   
  64.     // The y coordinate of the points is in [0, 1)  
  65.     c = trainClass.colRange(1,2);  
  66.     rng.fill(c, RNG::UNIFORM, Scalar(1), Scalar(HEIGHT));  
  67.       
  68.     //------------------------- Set up the labels for the classes ---------------------------------  
  69.     labels.rowRange(                0,   NTRAINING_SAMPLES).setTo(1);  // Class 1  
  70.     labels.rowRange(NTRAINING_SAMPLES, 2*NTRAINING_SAMPLES).setTo(2);  // Class 2  
  71.   
  72.     //------------------------ 2. Set up the support vector machines parameters --------------------  
  73.     CvSVMParams params;  
  74.     params.svm_type    = SVM::C_SVC;  
  75.     params.C           = 0.1;  
  76.     params.kernel_type = SVM::LINEAR;  
  77.     params.term_crit   = TermCriteria(CV_TERMCRIT_ITER, (int)1e7, 1e-6);  
  78.   
  79.     //------------------------ 3. Train the svm ----------------------------------------------------  
  80.     cout << "Starting training process" << endl;  
  81.     CvSVM svm;  
  82.     svm.train(trainData, labels, Mat(), Mat(), params);  
  83.     cout << "Finished training process" << endl;  
  84.       
  85.     //------------------------ 4. Show the decision regions ----------------------------------------  
  86.     Vec3b green(0,100,0), blue (100,0,0);  
  87.     for (int i = 0; i < I.rows; ++i)  
  88.         for (int j = 0; j < I.cols; ++j)  
  89.         {  
  90.             Mat sampleMat = (Mat_<float>(1,2) << i, j);  
  91.             float response = svm.predict(sampleMat);  
  92.   
  93.             if      (response == 1)    I.at<Vec3b>(j, i)  = green;  
  94.             else if (response == 2)    I.at<Vec3b>(j, i)  = blue;  
  95.         }  
  96.   
  97.     //----------------------- 5. Show the training data --------------------------------------------  
  98.     int thick = -1;  
  99.     int lineType = 8;  
  100.     float px, py;  
  101.     // Class 1  
  102.     for (int i = 0; i < NTRAINING_SAMPLES; ++i)  
  103.     {  
  104.         px = trainData.at<float>(i,0);  
  105.         py = trainData.at<float>(i,1);  
  106.         circle(I, Point( (int) px,  (int) py ), 3, Scalar(0, 255, 0), thick, lineType);  
  107.     }  
  108.     // Class 2  
  109.     for (int i = NTRAINING_SAMPLES; i <2*NTRAINING_SAMPLES; ++i)  
  110.     {  
  111.         px = trainData.at<float>(i,0);  
  112.         py = trainData.at<float>(i,1);  
  113.         circle(I, Point( (int) px, (int) py ), 3, Scalar(255, 0, 0), thick, lineType);  
  114.     }  
  115.   
  116.     //------------------------- 6. Show support vectors --------------------------------------------  
  117.     thick = 2;  
  118.     lineType  = 8;  
  119.     int x     = svm.get_support_vector_count();  
  120.   
  121.     for (int i = 0; i < x; ++i)  
  122.     {  
  123.         const float* v = svm.get_support_vector(i);  
  124.         circle( I,  Point( (int) v[0], (int) v[1]), 6, Scalar(128, 128, 128), thick, lineType);  
  125.     }  
  126.   
  127.     imwrite("result.png", I);                      // save the Image  
  128.     imshow("SVM for Non-Linear Training Data", I); // show it to the user  
  129.     waitKey(0);  
  130. }  


结果:



抱歉!评论已关闭.