## ----------------------------------------------------------##
## 决策树 ##
## ----------------------------------------------------------##
library("rpart")
library("rpart.plot")
## Tree-growing controls handed to rpart(); see ?rpart.control for the
## meaning of each field.
ct <- rpart.control(
  xval = 10,
  minsplit = 10,
  minbucket = 10,
  cp = 0.01,
  maxdepth = 10
)

## Classification tree: is_black explained by every other column of
## data_input (data_input is defined outside this chunk).
## Note: is_bank_card_null is column 2,
## i.e. which(colnames(data_input) == "is_bank_card_null") == 2.
fitFP <- rpart(
  is_black ~ .,
  data = data_input,
  method = "class",
  control = ct
)

## Per-variable importance stored on the fitted object.
var.imp <- fitFP$variable.importance

## Draw the fitted tree.
rpart.plot(fitFP, type = 2, extra = 2)

## Handy checks, left disabled:
# table(data_input$is_black)
# table(data_input[, c(1, 2)])
# fitFP$frame
# fitFP$splits

## summary() prints the detailed report; its value is kept for Tree_2_csv().
summaryFP <- summary(fitFP)
## ----------------------------------------------------------##
## 决策树结果, 链路输出 ##
## ----------------------------------------------------------##
library(stringr)
## Split one line of printed rpart output into its space-separated fields.
##
## Steps:
##   1. Trim leading/trailing whitespace.
##   2. Remove every run of spaces that directly follows `<`, `>` or `=`, so a
##      split condition such as "x< 2.5" or "x>=  3" stays a single field.
##   3. Drop all parentheses.
##   4. Split on spaces, discarding the empty pieces left by repeated spaces.
##
## A single regex pass in step 2 yields the same string that repeatedly
## replacing "> ", ">= ", "< ", "<= " and "= " until nothing changes would.
##
## @param char A single character string (one printed tree line).
## @return Character vector of fields; character(0) for a blank line.
## Example: "4) a> =12 b> 23 c<= 3 d< 98 nothing" ->
##   c("4", "a>=12", "b>23", "c<=3", "d<98", "nothing")
char_2_vec <- function(char) {
  char <- trimws(char)
  char <- gsub("([<>=]) +", "\\1", char)
  char <- gsub("[()]", "", char)
  pieces <- strsplit(char, " ", fixed = TRUE)[[1]]
  pieces[nzchar(pieces)]
}
## Flatten a printed rpart classification tree into a data frame with one row
## per node line.
##
## The argument is printed with capture.output() and the first 5 printed
## lines are skipped as header. NOTE(review): 5 matches print.rpart's layout
## ("n= ...", blank, column legend, "* denotes terminal node", blank);
## confirm for the rpart version in use.
##
## Each remaining non-blank line is split with char_2_vec(), and its fields
## are written left to right into the columns below. The layout therefore
## assumes a two-class response (exactly two class probabilities per line).
## Lines with more fields than columns raise an explicit error.
##
## @param summaryTree Object whose printed form is the rpart tree (e.g.
##   summaryFP in this script).
## @return data.frame with columns node, split, n, loss, yval, yprob_0,
##   yprob_1, is_leaf; is_leaf is "*" for terminal nodes and NA otherwise.
Tree_2_csv <- function(summaryTree) {
  ## Print the tree that was passed in, not a global variable.
  input <- capture.output(summaryTree)
  input1 <- trimws(input)[-seq_len(5)]
  ## Blank lines carry no fields; skip them instead of writing empty rows.
  input1 <- input1[nzchar(input1)]

  NAMES <- c("node", "split", "n", "loss",
             "yval", "yprob_0", "yprob_1", "is_leaf")
  CSV <- matrix(NA_character_, nrow = length(input1), ncol = length(NAMES),
                dimnames = list(NULL, NAMES))

  for (i in seq_along(input1)) {
    vec <- char_2_vec(input1[i])
    if (length(vec) > length(NAMES)) {
      stop("Tree_2_csv: node line has ", length(vec), " fields but only ",
           length(NAMES), " columns are supported (two-class trees only): ",
           input1[i], call. = FALSE)
    }
    ## Internal nodes have no trailing "*", so is_leaf stays NA for them.
    CSV[i, seq_along(vec)] <- vec
  }
  as.data.frame(CSV)
}
## For every terminal node of a flattened tree (as built by Tree_2_csv),
## collect per-leaf numbers and the chain of split labels from the root down
## to that leaf.
##
## Ancestors are found by repeated integer halving of the leaf's node number
## until it drops below 1. This relies on rpart's numbering, where the parent
## of node k is node floor(k / 2) and the root is node 1.
##
## @param CSV data.frame with (character or factor) columns node, split, n,
##   loss, yval and is_leaf ("*" marks a terminal node).
## @return list with
##   leaf_node - numeric ids of the terminal nodes, in table order;
##   leaf_info - list; per leaf c(node id, predicted class as numeric,
##               observations in the leaf, n - loss = observations of the
##               predicted class). yval must be a numeric-looking class
##               label, otherwise that entry is NA (with a warning);
##   CHAIN     - list; per leaf the split labels from the root to the leaf.
##   All three are empty when CSV has no terminal nodes.
CSV_2_CHAIN <- function(CSV) {
  ## Numeric value of column `col` at row `idx` (handles character/factor).
  num_at <- function(col, idx) as.numeric(as.character(CSV[[col]][idx]))

  node <- as.numeric(as.character(CSV$node))
  leaf_node <- node[which(CSV$is_leaf == "*")]

  n_leaf <- length(leaf_node)
  leaf_info <- vector("list", n_leaf)
  CHAIN <- vector("list", n_leaf)

  for (i in seq_len(n_leaf)) {
    ## Node ids from the leaf up to the root (leaf first).
    chain_num <- numeric(0)
    k <- leaf_node[i]
    while (k >= 1) {
      chain_num <- c(chain_num, k)
      k <- k %/% 2
    }

    ## Row positions of those ids; ids missing from the table are skipped.
    chain_ord <- match(chain_num, node)
    chain_ord <- chain_ord[!is.na(chain_ord)]

    leaf_ord <- chain_ord[1]
    leaf_n <- num_at("n", leaf_ord)
    leaf_info[[i]] <- c(
      num_at("node", leaf_ord),             # leaf node id
      num_at("yval", leaf_ord),             # predicted class
      leaf_n,                               # observations in the leaf
      leaf_n - num_at("loss", leaf_ord)     # observations of predicted class
    )

    ## Reverse so the chain reads root -> leaf.
    CHAIN[[i]] <- as.character(CSV$split[rev(chain_ord)])
  }

  list(leaf_node = leaf_node,
       leaf_info = leaf_info,
       CHAIN = CHAIN)
}
## Smoke run (测试): flatten the printed tree into a table, then derive the
## root-to-leaf split chain of every terminal node.
CSV <- Tree_2_csv(summaryFP)
RESULT <- CSV_2_CHAIN(CSV)
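## 将决策树的结果输出,输出叶节点的链路,在RESULT这个list的CHAIN中.
## (Decision-tree output: the root-to-leaf split chain of every leaf node is stored in RESULT$CHAIN.)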