开心 · 2018-2-8 09:12:48
```python
# coding: utf-8
# ID3 decision tree: Shannon entropy, information gain, recursive tree construction
from math import log
import operator


##### Shannon entropy
def calcShannonEnt(dataSet):
    """Shannon entropy (in bits) of the class labels stored in the last column."""
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def createDataSet():
    """Toy data set: two binary features plus a class label in each row."""
    dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


############# Split the data set
def splitDataSet(dataSet, axis, value):
    """Rows whose feature `axis` equals `value`, with that feature column removed."""
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


####### Choose the best feature to split on
def chooseBestFeatureToSplit(dataSet):
    """Index of the feature with the largest information gain."""
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    # Fall back to feature 0 when no feature has positive gain;
    # a default of -1 would make createTree split on the class column.
    bestFeature = 0
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        # Weighted average entropy of the subsets produced by splitting on feature i
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


#### Majority vote decides the label
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    # dict.iteritems() exists only in Python 2; items() works in Python 3
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


##### Build the decision tree recursively
def createTree(dataSet, labels):
    labels = labels[:]  # work on a copy so the caller's label list is not modified
    classList = [example[-1] for example in dataSet]
    # Stop if every row has the same class
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop if no features are left: fall back to a majority vote
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


#### Test
myDat, labels = createDataSet()
print(myDat)
print(splitDataSet(myDat, 0, 0))       # [[1, 'no'], [1, 'no']]
print(chooseBestFeatureToSplit(myDat))  # 0
myTree = createTree(myDat, labels)
print(myTree)  # {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
```
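For reference, the two quantities the code relies on are the standard ID3 definitions below (the symbols are introduced here, not in the original post). For a data set $D$ whose rows fall into classes with proportions $p_k$, `calcShannonEnt` computes the entropy $H(D)$, and `chooseBestFeatureToSplit` returns the feature $a$ with the largest information gain, where $D_v$ is the subset of rows with $a = v$:

$$
H(D) = -\sum_{k} p_k \log_2 p_k
\qquad
\mathrm{Gain}(D, a) = H(D) - \sum_{v} \frac{|D_v|}{|D|}\, H(D_v)
$$

For the sample data from `createDataSet` (2 `yes`, 3 `no`):

$$
\begin{aligned}
H(D) &= -\tfrac{2}{5}\log_2\tfrac{2}{5} - \tfrac{3}{5}\log_2\tfrac{3}{5} \approx 0.971\\
\mathrm{Gain}(D, \text{no surfacing}) &\approx 0.971 - \tfrac{3}{5}(0.918) - \tfrac{2}{5}(0) \approx 0.420\\
\mathrm{Gain}(D, \text{flippers}) &\approx 0.971 - \tfrac{4}{5}(1) - \tfrac{1}{5}(0) \approx 0.171
\end{aligned}
$$

so `chooseBestFeatureToSplit(myDat)` returns 0 and the root of the tree splits on `no surfacing`.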
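As a cross-check on the entropy and information-gain arithmetic, here is a compact sketch of the same computation using `collections.Counter`. The helper names `entropy` and `information_gain` are hypothetical (they are not in the original code), and grouping rows with `dict.setdefault` is a swapped-in alternative to calling `splitDataSet` once per value; for the sample data the results should agree with `calcShannonEnt` and with the gain computed inside `chooseBestFeatureToSplit`.

```python
from collections import Counter
from math import log2


def entropy(rows):
    """Hypothetical helper: Shannon entropy (bits) of the class label in the last column."""
    counts = Counter(row[-1] for row in rows)
    n = len(rows)
    return -sum((c / n) * log2(c / n) for c in counts.values())


def information_gain(rows, axis):
    """Hypothetical helper: entropy of `rows` minus the weighted entropy after splitting on column `axis`."""
    n = len(rows)
    groups = {}
    for row in rows:
        # Bucket rows by their value in the chosen column
        groups.setdefault(row[axis], []).append(row)
    remainder = sum(len(g) / n * entropy(g) for g in groups.values())
    return entropy(rows) - remainder


data = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
print(round(entropy(data), 3))              # 0.971
print(round(information_gain(data, 0), 3))  # 0.42
print(round(information_gain(data, 1), 3))  # 0.171
```

Feature 0 (`no surfacing`) has the larger gain, which is why the tree's root splits on it.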
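`majorityCnt` counts votes in a dict and then sorts by count. A shorter sketch with `collections.Counter` follows; `majority_label` is a hypothetical name, not from the original code. `Counter.most_common` orders equal counts by first occurrence, and Python's sort is stable, so ties should resolve the same way in both versions (the label seen first wins), assuming Python 3.7+ where dicts keep insertion order.

```python
from collections import Counter


def majority_label(class_list):
    """Hypothetical helper: the most common label; ties go to the label seen first."""
    # most_common(1) returns [(label, count)]; equal counts keep first-seen order
    return Counter(class_list).most_common(1)[0][0]


print(majority_label(['no', 'yes', 'no']))  # no
```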
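The nested dictionary returned by `createTree` can be used for prediction by following one branch per level until a non-dict leaf is reached. The `classify` function below is a hypothetical sketch, not part of the original post, and the tree literal is the one the test code produces for the sample data set. Pass the complete list of feature names: `createTree` as originally posted deletes entries from the `labels` list it receives, so build a fresh list rather than reusing that one.

```python
def classify(tree, feat_labels, test_vec):
    """Hypothetical helper: walk a nested-dict tree and return the predicted label.

    `feat_labels` must be the full, unmodified list of feature names so each
    internal node's feature name can be mapped back to a column of `test_vec`.
    """
    node = tree
    while isinstance(node, dict):
        feat_name = next(iter(node))  # each internal node has exactly one key
        value = test_vec[feat_labels.index(feat_name)]
        node = node[feat_name][value]  # KeyError if this value never occurred in training
    return node


tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
labels = ['no surfacing', 'flippers']
print(classify(tree, labels, [1, 0]))  # no
print(classify(tree, labels, [1, 1]))  # yes
```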