楼主: A蓝白红
1401 0

[问答] 用CART方法自定义最优分箱函数时遇到的递归函数问题 [推广有奖]

  • 2关注
  • 0粉丝

本科生

59%

还不是VIP/贵宾

-

威望
0
论坛币
256 个
通用积分
1.0500
学术水平
0 点
热心指数
0 点
信用等级
0 点
经验
934 点
帖子
43
精华
0
在线时间
156 小时
注册时间
2011-11-8
最后登录
2025-12-17

楼主
A蓝白红 发表于 2019-7-21 20:41:58 |AI写论文
100论坛币
  1. calc_score_median<- function(df, col){

  2. # 计算相邻评分的中位数,以便进行决策树二元切分
  3. #
  4. # param df: 待切分样本
  5. # param col: 分割变量名称
  6.   var_list <- sort(unique(df[,col]))
  7.   var_median_list <- numeric(length(var_list)-1)
  8.   for(i in 1:(length(var_list)-1)){
  9.     var_median_list[i] = (var_list[i] + var_list[i+1]) / 2
  10.   }
  11.   return(var_median_list)
  12. }





  13. choose_best_split <- function(df,col,target,min_sample){
  14.   # 使用CART分类决策树选择最好的样本切分点
  15.   # 返回切分点
  16.   # param sample_set: 待切分样本
  17.   # param var: 分割变量名称
  18.   # param min_sample: 待切分样本的最小样本量(限制条件),如果小于该阈值则不进行切分,可设置为整体样本量的5%
  19.   var_median_list <- calc_score_median(df,col)
  20.   median_len = length(var_median_list)
  21.   sample_cnt <- dim(df)[1]
  22.   sample1_cnt <- sum(df[,target])
  23.   sample0_cnt <- sample_cnt-sample1_cnt
  24.   Gini <- 1- ((sample1_cnt/sample_cnt)**2)-((sample0_cnt/sample_cnt)**2)
  25.   bestGini <- 0.0
  26.   bestSplit_point <- 0.0
  27.   bestSplit_position <- 0.0
  28.   for(i in 1:median_len){
  29.     left <- df[which(df[,col] < var_median_list[i]),]
  30.     right <- df[which(df[,col] > var_median_list[i]),]
  31.     left_cnt <- dim(left)[1]
  32.     right_cnt <- dim(right)[1]
  33.     left1_cnt <- sum(left[,target])
  34.     right1_cnt <- sum(right[,target])
  35.     left0_cnt =  left_cnt - left1_cnt
  36.     right0_cnt =  right_cnt - right1_cnt
  37.     left_ratio = left_cnt / sample_cnt
  38.     right_ratio = right_cnt / sample_cnt

  39.     Gini_left = 1 - ((left1_cnt / left_cnt)**2) - ((left0_cnt / left_cnt)**2)
  40.     Gini_right = 1 - ((right1_cnt / right_cnt)**2) - ((right0_cnt / right_cnt)**2)
  41.     Gini_temp = Gini - (left_ratio * Gini_left + right_ratio * Gini_right)
  42.     if(left_cnt < min_sample || right_cnt < min_sample){
  43.       next
  44.     }
  45.     if(Gini_temp > bestGini){
  46.       bestGini = Gini_temp
  47.       bestSplit_point = var_median_list[i]
  48.       if(median_len>1){
  49.         bestSplit_position = i / (median_len - 1)
  50.       }else{
  51.         bestSplit_position = i / median_len
  52.       }
  53.     }else{
  54.       next
  55.     }
  56.   }
  57.   Gini = Gini - bestGini
  58.   return(list(bestSplit_point, bestSplit_position))
  59. }




  60. bining_data_split <- function(df, col, target,min_sample,split_list){
  61.   # 划分数据找到最优分割点list
  62.   # param sample_set: 待切分样本
  63.   # param var: 分割变量名称
  64.   # param min_sample: 待切分样本的最小样本量(限制条件),如果小于该阈值则不进行切分,可设置为整体样本量的5%
  65.   # param split_list: 最优分割点list,split_list 参数是用来保存返回的切分点,每次切分后返回的切分点存入该list
  66.   middle <- choose_best_split(df,col,target,min_sample)
  67.   split <- middle[[1]]
  68.   position <- middle[[2]]
  69.   if(split != 0){
  70.     split_list <- c(split_list,split)
  71.   }
  72.   # 根据分割点划分数据集,继续进行划分

  73.   sample_set_left <- df[which(df[,col] < split),]
  74.   sample_set_right <- df[which(df[,col] > split),]
  75.   # 如果左子树样本量超过2倍最小样本量,且分割点不是第一个分割点,则切分左子树
  76.   #在这里判断切分点分割的左子树和右子树是否满足“内部节点再划分所需的最小样本数>=总样本量的10%”的条件,如果满足则进行递归调用。
  77.   if(length(sample_set_left) >= (min_sample * 2) && ! position %in% c(0, 1)){
  78.     bining_data_split(sample_set_left, col, target,min_sample, split_list)
  79.   }else{
  80.     return(NULL)
  81.   }

  82.   if(length(sample_set_right) >= (min_sample * 2) && ! position %in% c(0, 1)){
  83.     bining_data_split(sample_set_right, col, target,min_sample, split_list)
  84.   }else{
  85.     return(NULL)
  86.   }

  87. }



  88. get_bestsplit_list <- function(df, col,target,min_sample_ratio){
  89.   # 计算最小样本阈值(终止条件)
  90.   min_sample <- dim(df)[1]*min_sample_ratio
  91.   split_list = c()
  92.   bining_data_split(df, col, target,min_sample, split_list)
  93.   return(split_list)
  94. }
复制代码
bining_data_split函数被设计成一个递归函数,其中有一个关键的地方就是参数split_list,每次bining_data_split循环调用自身时希望会值能够不断的添加到split_list,最终用get_bestsplit_list 求split_list的结果,但是bining_data_split 函数中if(split != 0){
split_list <- c(split_list,split)
}这种设计好像是有问题的,导致结果出不来


求大神能帮忙解答解答,万分感谢

关键词:CART 自定义 CAR ART median R语言 递归 CART 最优分箱

您需要登录后才可以回帖 登录 | 我要注册

本版微信群
加好友,备注cda
拉您进交流群
GMT+8, 2026-1-3 23:22