楼主: fantuanxiaot
17293 216

[源码分享] [红包]基于C++的RandomForest随机森林总结   [推广有奖]

Ψ▄┳一大卫卍卐席尔瓦

大师

9%

还不是VIP/贵宾

-

威望
7
论坛币
-234475 个
通用积分
124.1424
学术水平
3783 点
热心指数
3819 点
信用等级
3454 点
经验
150417 点
帖子
7616
精华
32
在线时间
1327 小时
注册时间
2013-2-3
最后登录
2022-2-24

初级学术勋章 初级热心勋章 中级热心勋章 中级学术勋章 初级信用勋章 中级信用勋章 高级热心勋章 高级学术勋章 特级学术勋章 特级热心勋章 高级信用勋章 特级信用勋章

相似文件 换一批

+2 论坛币
k人 参与回答

经管之家送您一份

应届毕业生专属福利!

求职就业群
赵安豆老师微信:zhaoandou666

经管之家联合CDA

送您一个全额奖学金名额~ !

感谢您参与论坛问题回答

经管之家送您两个论坛币!

+2 论坛币

来自:http://www.cnblogs.com/hrlnw/p/3850459.html


  1. #include <cv.h>       // opencv general include file
  2. #include <ml.h>          // opencv machine learning include file
  3. #include <stdio.h>

  4. using namespace cv; // OpenCV API is in the C++ "cv" namespace

  5. /******************************************************************************/
  6. // global definitions (for speed and ease of use)
  7. //手写体数字识别

  8. #define NUMBER_OF_TRAINING_SAMPLES 60000
  9. #define ATTRIBUTES_PER_SAMPLE 784
  10. #define NUMBER_OF_TESTING_SAMPLES 10000

  11. #define NUMBER_OF_CLASSES 10

  12. // N.B. classes are integer handwritten digits in range 0-9

  13. /******************************************************************************/

  14. // loads the sample database from file (which is a CSV text file)

  15. inline void revertInt(int&x)
  16. {
  17.     x=((x&0x000000ff)<<24)|((x&0x0000ff00)<<8)|((x&0x00ff0000)>>8)|((x&0xff000000)>>24);
  18. };

  19. int read_data_from_csv(const char* samplePath,const char* labelPath, Mat data, Mat classes,
  20.                        int n_samples )
  21. {
  22.     FILE* sampleFile=fopen(samplePath,"rb");
  23.     FILE* labelFile=fopen(labelPath,"rb");
  24.     int mbs=0,number=0,col=0,row=0;
  25.     fread(&mbs,4,1,sampleFile);
  26.     fread(&number,4,1,sampleFile);
  27.     fread(&row,4,1,sampleFile);
  28.     fread(&col,4,1,sampleFile);
  29.     revertInt(mbs);
  30.     revertInt(number);
  31.     revertInt(row);
  32.     revertInt(col);
  33.     fread(&mbs,4,1,labelFile);
  34.     fread(&number,4,1,labelFile);
  35.     revertInt(mbs);
  36.     revertInt(number);
  37.     unsigned char temp;
  38.     for(int line = 0; line < n_samples; line++)
  39.     {
  40.         // for each attribute on the line in the file
  41.         for(int attribute = 0; attribute < (ATTRIBUTES_PER_SAMPLE + 1); attribute++)
  42.         {
  43.             if (attribute < ATTRIBUTES_PER_SAMPLE)
  44.             {
  45.                 // first 64 elements (0-63) in each line are the attributes
  46.                 fread(&temp,1,1,sampleFile);
  47.                 //fscanf(f, "%f,", &tmp);
  48.                 data.at<float>(line, attribute) = static_cast<float>(temp);
  49.                 // printf("%f,", data.at<float>(line, attribute));
  50.             }
  51.             else if (attribute == ATTRIBUTES_PER_SAMPLE)
  52.             {
  53.                 // attribute 65 is the class label {0 ... 9}
  54.                 fread(&temp,1,1,labelFile);
  55.                 //fscanf(f, "%f,", &tmp);
  56.                 classes.at<float>(line, 0) = static_cast<float>(temp);
  57.                 // printf("%f\n", classes.at<float>(line, 0));
  58.             }
  59.         }
  60.     }
  61.     fclose(sampleFile);
  62.     fclose(labelFile);
  63.     return 1; // all OK
  64. }

  65. /******************************************************************************/

  66. int main( int argc, char** argv )
  67. {
  68.    
  69.     for (int i=0; i< argc; i++)
  70.         std::cout<<argv[i]<<std::endl;
  71.    
  72.     // lets just check the version first
  73.     printf ("OpenCV version %s (%d.%d.%d)\n",
  74.             CV_VERSION,
  75.             CV_MAJOR_VERSION, CV_MINOR_VERSION, CV_SUBMINOR_VERSION);
  76.    
  77.     //定义训练数据与标签矩阵
  78.     Mat training_data = Mat(NUMBER_OF_TRAINING_SAMPLES, ATTRIBUTES_PER_SAMPLE, CV_32FC1);
  79.     Mat training_classifications = Mat(NUMBER_OF_TRAINING_SAMPLES, 1, CV_32FC1);

  80.     //定义测试数据矩阵与标签
  81.     Mat testing_data = Mat(NUMBER_OF_TESTING_SAMPLES, ATTRIBUTES_PER_SAMPLE, CV_32FC1);
  82.     Mat testing_classifications = Mat(NUMBER_OF_TESTING_SAMPLES, 1, CV_32FC1);

  83.     // define all the attributes as numerical
  84.     // alternatives are CV_VAR_CATEGORICAL or CV_VAR_ORDERED(=CV_VAR_NUMERICAL)
  85.     // that can be assigned on a per attribute basis

  86.     Mat var_type = Mat(ATTRIBUTES_PER_SAMPLE + 1, 1, CV_8U );
  87.     var_type.setTo(Scalar(CV_VAR_NUMERICAL) ); // all inputs are numerical

  88.     // this is a classification problem (i.e. predict a discrete number of class
  89.     // outputs) so reset the last (+1) output var_type element to CV_VAR_CATEGORICAL

  90.     var_type.at<uchar>(ATTRIBUTES_PER_SAMPLE, 0) = CV_VAR_CATEGORICAL;

  91.     double result; // value returned from a prediction

  92.     //加载训练数据集和测试数据集
  93.     if (read_data_from_csv(argv[1],argv[2], training_data, training_classifications, NUMBER_OF_TRAINING_SAMPLES) &&
  94.             read_data_from_csv(argv[3],argv[4], testing_data, testing_classifications, NUMBER_OF_TESTING_SAMPLES))
  95.     {
  96.       /********************************步骤1:定义初始化Random Trees的参数******************************/
  97.         float priors[] = {1,1,1,1,1,1,1,1,1,1};  // weights of each classification for classes
  98.         CvRTParams params = CvRTParams(20, // max depth
  99.                                        50, // min sample count
  100.                                        0, // regression accuracy: N/A here
  101.                                        false, // compute surrogate split, no missing data
  102.                                        15, // max number of categories (use sub-optimal algorithm for larger numbers)
  103.                                        priors, // the array of priors
  104.                                        false,  // calculate variable importance
  105.                                        50,       // number of variables randomly selected at node and used to find the best split(s).
  106.                                        100,     // max number of trees in the forest
  107.                                        0.01f,                // forest accuracy
  108.                                        CV_TERMCRIT_ITER |    CV_TERMCRIT_EPS // termination cirteria
  109.                                       );

  110.         /****************************步骤2:训练 Random Decision Forest(RDF)分类器*********************/
  111.         printf( "\nUsing training database: %s\n\n", argv[1]);
  112.         CvRTrees* rtree = new CvRTrees;
  113.         bool train_result=rtree->train(training_data, CV_ROW_SAMPLE, training_classifications,
  114.                      Mat(), Mat(), var_type, Mat(), params);
  115. //        float train_error=rtree->get_train_error();
  116. //        printf("train error:%f\n",train_error);
  117.         // perform classifier testing and report results
  118.         Mat test_sample;
  119.         int correct_class = 0;
  120.         int wrong_class = 0;
  121.         int false_positives [NUMBER_OF_CLASSES] = {0,0,0,0,0,0,0,0,0,0};

  122.         printf( "\nUsing testing database: %s\n\n", argv[2]);

  123.         for (int tsample = 0; tsample < NUMBER_OF_TESTING_SAMPLES; tsample++)
  124.         {

  125.             // extract a row from the testing matrix
  126.             test_sample = testing_data.row(tsample);
  127.         /********************************步骤3:预测*********************************************/
  128.             result = rtree->predict(test_sample, Mat());

  129.             printf("Testing Sample %i -> class result (digit %d)\n", tsample, (int) result);

  130.             // if the prediction and the (true) testing classification are the same
  131.             // (N.B. openCV uses a floating point decision tree implementation!)
  132.             if (fabs(result - testing_classifications.at<float>(tsample, 0))
  133.                     >= FLT_EPSILON)
  134.             {
  135.                 // if they differ more than floating point error => wrong class
  136.                 wrong_class++;
  137.                 false_positives[(int) result]++;
  138.             }
  139.             else
  140.             {
  141.                 // otherwise correct
  142.                 correct_class++;
  143.             }
  144.         }

  145.         printf( "\nResults on the testing database: %s\n"
  146.                 "\tCorrect classification: %d (%g%%)\n"
  147.                 "\tWrong classifications: %d (%g%%)\n",
  148.                 argv[2],
  149.                 correct_class, (double) correct_class*100/NUMBER_OF_TESTING_SAMPLES,
  150.                 wrong_class, (double) wrong_class*100/NUMBER_OF_TESTING_SAMPLES);

  151.         for (int i = 0; i < NUMBER_OF_CLASSES; i++)
  152.         {
  153.             printf( "\tClass (digit %d) false postives     %d (%g%%)\n", i,
  154.                     false_positives[i],
  155.                     (double) false_positives[i]*100/NUMBER_OF_TESTING_SAMPLES);
  156.         }

  157.         // all matrix memory free by destructors

  158.         // all OK : main returns 0
  159.         return 0;
  160.     }

  161.     // not OK : main returns -1
  162.     return -1;
  163. }
复制代码
MNIST样本可以在这个网址http://yann.lecun.com/exdb/mnist/下载,改一下路径可以直接跑的。
3.如何自己设计随机森林程序
    有时现有的库无法满足要求,就需要自己设计一个分类器算法,这部分讲一下如何设计自己的随机森林分类器,代码实现就不贴了,因为在工作中用到了,因此比较敏感。
    首先,要有一个RandomForest类,里面保存整个树需要的一些参数,包括但不限于:训练样本数量、测试样本数量、特征维数、每个节点随机提取的特征维数、CART树的数量、树的最大深度、类别数量(如果是分类问题)、一些终止条件、指向所有树的指针,指向训练集和测试集的指针,指向训练集label的指针等。还要有一些函数,至少要有train和predict吧。train里面直接调用每棵树的train方法即可,predict同理,但要对每棵树的预测输出做处理,得到森林的预测输出。
    其次,要有一个sample类,这个类可不是用来存储训练集和对应label的,这是因为,每棵树、每个节点都有自己的样本集和,如果你的存储每个样本集和的话,需要的内存实在是太过巨大了,假设样本数量为M,特征维数为N,则整个训练集大小为M×N,而每棵树的每层都有这么多样本,树的深度为D,共有S棵树的话,则需要存储M×N×D×S的存储空间。这实在是太大了。因此,每个节点训练时用到的训练样本和特征,我们都用序号数组来代替,sample类就是干这个的。sample的函数基本需要两个就行,第一个是从现有训练集有放回的随机抽取一个新的训练集,当然,只包含样本的序号。第二个函数是从现有的特征中无放回的随机抽取一定数量的特征,同理,也是特征序号即可。

二维码

扫码加我 拉你入群

请注明:姓名-公司-职位

以便审核进群资格,未注明则拒绝

关键词:randomForest random Forest 随机森林 Rest Adele 森林 信息 样本

已有 2 人评分学术水平 热心指数 信用等级 收起 理由
newfei188 + 1 精彩帖子
oink-oink + 4 + 4 + 4 精彩帖子

总评分: 学术水平 + 5  热心指数 + 4  信用等级 + 4   查看全部评分

本帖被以下文库推荐

沙发
fjrong 在职认证  发表于 2015-3-11 00:20:29 |只看作者 |坛友微信交流群

回帖奖励 +6

使用道具

藤椅
fjrong 在职认证  发表于 2015-3-11 00:21:04 |只看作者 |坛友微信交流群

回帖奖励 +6

楼主大好人啊

使用道具

板凳
fjrong 在职认证  发表于 2015-3-11 00:21:35 |只看作者 |坛友微信交流群

回帖奖励 +6

使用道具

报纸
fjrong 在职认证  发表于 2015-3-11 00:22:14 |只看作者 |坛友微信交流群

回帖奖励 +6

果断顶起!

使用道具

地板
fjrong 在职认证  发表于 2015-3-11 00:22:55 |只看作者 |坛友微信交流群

回帖奖励 +6

使用道具

7
fjrong 在职认证  发表于 2015-3-11 00:23:33 |只看作者 |坛友微信交流群

回帖奖励 +6

谢谢楼主拉!!!

使用道具

8
水墨Melody 发表于 2015-3-11 00:24:12 |只看作者 |坛友微信交流群

回帖奖励 +6

使用道具

9
fantuanxiaot 发表于 2015-3-11 00:26:19 |只看作者 |坛友微信交流群
水墨Melody 发表于 2015-3-11 00:24
回复6次

使用道具

10
水墨Melody 发表于 2015-3-11 00:27:21 |只看作者 |坛友微信交流群

回帖奖励 +6

酷!

使用道具

您需要登录后才可以回帖 登录 | 我要注册

本版微信群
加好友,备注jr
拉您进交流群

京ICP备16021002-2号 京B2-20170662号 京公网安备 11010802022788号 论坛法律顾问:王进律师 知识产权保护声明   免责及隐私声明

GMT+8, 2024-4-26 19:23