一、ID3 决策树完整代码(treesID3.py)
from math import log
import operator
import matplotlib.pyplot as plt
# Configure matplotlib so CJK labels render correctly (fall back through
# SimHei / Arial Unicode MS) and so the minus sign is not garbled.
plt.rcParams["font.family"] = ["sans-serif"]
plt.rcParams["font.sans-serif"] = ["SimHei", "Arial Unicode MS"]
plt.rcParams["axes.unicode_minus"] = False
# ---------------------- ID3核心算法 ----------------------
def calcShannonEnt(dataSet):
    """Return the Shannon entropy of the class labels in dataSet.

    Each sample is a list whose last element is its class label.
    """
    total = len(dataSet)
    counts = {}
    # Tally how often each label occurs.
    for sample in dataSet:
        label = sample[-1]
        counts[label] = counts.get(label, 0) + 1
    # H = -sum(p * log2(p)) over the label distribution.
    entropy = 0.0
    for occurrences in counts.values():
        p = float(occurrences) / total
        entropy -= p * log(p, 2)
    return entropy
def splitDataSet(dataSet, axis, value):
    """Select the samples whose feature at index `axis` equals `value`.

    The matched feature column is removed from each returned sample;
    new lists are built, the input is never mutated.
    """
    return [
        vec[:axis] + vec[axis + 1:]
        for vec in dataSet
        if vec[axis] == value
    ]
def chooseBestFeatureToSplit(dataSet):
    """Pick the split feature with the highest information gain (ID3).

    Returns (best feature index, list of every feature's information gain).
    An index of -1 means no feature achieved a positive gain.
    """
    featureCount = len(dataSet[0]) - 1  # last column is the class label
    baseEntropy = calcShannonEnt(dataSet)
    gains = []
    bestIndex, bestGain = -1, 0.0
    for idx in range(featureCount):
        # Conditional entropy: weighted entropy over each value's subset.
        conditional = 0.0
        for val in {sample[idx] for sample in dataSet}:
            subset = splitDataSet(dataSet, idx, val)
            weight = len(subset) / float(len(dataSet))
            conditional += weight * calcShannonEnt(subset)
        gain = baseEntropy - conditional
        gains.append(gain)
        if gain > bestGain:
            bestGain, bestIndex = gain, idx
    return bestIndex, gains
def majorityCnt(classList):
    """Return the most frequent class label (majority vote for leaves)."""
    tally = {}
    for label in classList:
        tally[label] = tally.get(label, 0) + 1
    # max keeps the first-seen label on ties, matching insertion order.
    return max(tally, key=tally.get)
def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree.

    Leaves are the strings "贷款"/"不贷款"; internal nodes are dicts with
    'label' (feature name), 'info_gain' (rounded to 4 places) and
    'children' (feature value -> subtree).
    """
    classList = [sample[-1] for sample in dataSet]
    # Base case 1: every sample carries the same label.
    if classList.count(classList[0]) == len(classList):
        return "贷款" if classList[0] == 1 else "不贷款"
    # Base case 2: only the label column remains -> majority vote.
    if len(dataSet[0]) == 1:
        return "贷款" if majorityCnt(classList) == 1 else "不贷款"
    bestFeat, infoGains = chooseBestFeatureToSplit(dataSet)
    node = {
        'label': labels[bestFeat],
        'info_gain': round(infoGains[bestFeat], 4),
        'children': {},
    }
    # Recurse on each feature value with the used feature dropped;
    # slicing builds a fresh label list so the caller's list is untouched.
    remaining = labels[:bestFeat] + labels[bestFeat + 1:]
    for value in {sample[bestFeat] for sample in dataSet}:
        node['children'][value] = createTree(
            splitDataSet(dataSet, bestFeat, value),
            remaining,
        )
    return node
def classify(inputTree, featLabels, testVec):
    """Predict a sample's class with a built tree.

    Returns 1 for "贷款", 0 for "不贷款", or None when the sample's
    feature value has no matching branch in the tree.
    """
    node = inputTree
    # Walk down until a leaf (a plain string) is reached.
    while not isinstance(node, str):
        index = featLabels.index(node['label'])
        children = node['children']
        if testVec[index] not in children:
            return None  # unseen feature value
        node = children[testVec[index]]
    return 1 if node == "贷款" else 0
# ---------------------- Visualization ----------------------
# Node styles: decision nodes are sawtooth light-blue boxes,
# leaf nodes are rounded light-green boxes.
decision_node = dict(boxstyle="sawtooth,pad=1.2", fc="lightblue", ec="black", lw=1.5)
leaf_node = dict(boxstyle="round4,pad=1.2", fc="lightgreen", ec="black", lw=1.5)
# Edge arrows point from child back to parent ("<-"), slightly curved.
arrow_args = dict(arrowstyle="<-", color="gray", lw=2,
                  shrinkA=0, shrinkB=22,
                  connectionstyle="arc3,rad=0.05")
def getNumLeafs(myTree):
    """Count the leaf nodes (plain class-name strings) under myTree."""
    if isinstance(myTree, str):
        return 1
    return sum(getNumLeafs(child) for child in myTree['children'].values())
def getTreeDepth(myTree):
    """Return the tree depth; a bare leaf string counts as depth 1."""
    if isinstance(myTree, str):
        return 1
    deepest = 0
    for child in myTree['children'].values():
        deepest = max(deepest, 1 + getTreeDepth(child))
    return deepest
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one node box at centerPt with an arrow back to parentPt."""
    plt.annotate(
        nodeTxt,
        xy=parentPt, xycoords='axes fraction',
        xytext=centerPt, textcoords='axes fraction',
        va="center", ha="center",
        bbox=nodeType, arrowprops=arrow_args, fontsize=12,
    )
def plotMidText(cntrPt, parentPt, txtString):
    """Label the edge between a parent and child node (the feature value)."""
    xMid = (parentPt[0] + cntrPt[0]) / 2.0
    yMid = (parentPt[1] + cntrPt[1]) / 2.0
    # Nudge the label sideways on near-vertical edges, vertically
    # otherwise, so neighbouring edge labels do not collide.
    dx = cntrPt[0] - parentPt[0]
    nudge = 0.02
    if abs(dx) < 0.05:
        xMid += nudge if cntrPt[0] > parentPt[0] else -nudge
    else:
        yMid += nudge if cntrPt[1] > parentPt[1] else -nudge
    plt.text(xMid, yMid, txtString, va="center", ha="center",
             fontsize=10, color='darkred', weight='bold',
             bbox=dict(boxstyle="round,pad=0.2", facecolor='white', alpha=0.9))
def plotInfoGain(centerPt, infoGain):
    """Print the node's information gain just below a decision node."""
    x, y = centerPt
    plt.text(x, y - 0.045,
             f"信息增益: {infoGain}",
             va="center", ha="center",
             fontsize=9, color='blue', weight='bold',
             bbox=dict(boxstyle="round,pad=0.2", facecolor='lightyellow', alpha=0.8))
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw the decision tree.

    Layout state lives in function attributes (plotTree.xOff, .yOff,
    .totalW, .totalD) initialized by createPlot and mutated across the
    recursion — statement order here is load-bearing.
    """
    if isinstance(myTree, str):  # leaf node: a bare class-name string
        plotTree.xOff += 1.0 / plotTree.totalW  # advance one leaf slot
        leaf_center = (plotTree.xOff, plotTree.yOff)
        plotNode(myTree, leaf_center, parentPt, leaf_node)
        plotMidText(leaf_center, parentPt, nodeTxt)
        return
    numLeafs = getNumLeafs(myTree)
    firstStr = myTree['label']
    infoGain = myTree['info_gain']
    # Center the decision node above the horizontal span of its leaves.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    # Draw the node box, its info-gain annotation, and the edge label.
    plotNode(firstStr, cntrPt, parentPt, decision_node)
    plotInfoGain(cntrPt, infoGain)
    plotMidText(cntrPt, parentPt, nodeTxt)
    # Descend one level, draw each child, then restore the depth.
    secondDict = myTree['children']
    plotTree.yOff -= 1.1 / plotTree.totalD
    for key in secondDict.keys():
        plotTree(secondDict[key], cntrPt, str(key))
    plotTree.yOff += 1.1 / plotTree.totalD
def createPlot(inTree):
    """Render the ID3 decision tree with matplotlib and show the figure."""
    fig = plt.figure(1, facecolor='white', figsize=(16, 12))
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    plt.subplot(111, frameon=False, **axprops)
    # Layout state consumed by plotTree (stored as function attributes):
    # total leaf count fixes the x scale, tree depth fixes the y scale.
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW  # start half a slot left of leaf 1
    plotTree.yOff = 1.0
    # Draw the tree from the top-center anchor.
    plotTree(inTree, (0.5, 1.0), '')
    # Legend explaining node colors and edge labels.
    plt.figtext(0.02, 0.02,
                "● 决策节点(蓝色): 特征名称\n● 叶节点(绿色): 分类结果\n● 边上数字: 特征取值",
                fontsize=10,
                bbox=dict(boxstyle="round", facecolor='lightgray', alpha=0.8))
    plt.title('决策树可视化 (ID3算法)', fontsize=16, pad=40, fontweight='bold')
    plt.tight_layout()
    plt.subplots_adjust(left=0.05, right=0.95, top=0.9, bottom=0.12)
    plt.show()
# ---------------------- 数据加载与主函数 ----------------------
def loadDataSet(filename):
    """Load a comma-separated integer dataset from `filename`.

    Each non-empty line becomes one sample: a list of ints with the
    class label in the last position.

    Returns:
        list[list[int]] of samples, or None when the file is missing.
    """
    dataSet = []
    try:
        with open(filename, 'r', encoding='utf-8') as fr:
            for line in fr:  # iterate lazily instead of readlines()
                line = line.strip()
                if not line:  # skip blank lines
                    continue
                dataSet.append([int(tok) for tok in line.split(',')])
        return dataSet
    except FileNotFoundError:
        # Bug fix: the f-string had no placeholder, so the message never
        # said WHICH file was missing. Interpolate the filename.
        print(f"错误:未找到文件 {filename},请放在代码同一目录下")
        return None
if __name__ == "__main__":
    # Feature names for the loan dataset: age bracket, has a job,
    # owns a house, credit rating.
    labels = ['年龄段', '有工作', '有自己的房', '信贷情况']
    # Load training and test data (CSV of ints, label in the last column).
    train_data = loadDataSet('dataset.txt')
    test_data = loadDataSet('testset.txt')
    if not train_data or not test_data:
        exit()
    # Build the ID3 tree; pass a copy so the original labels list is
    # definitely untouched for the later classify() calls.
    myTree = createTree(train_data, labels.copy())
    print("ID3决策树结构:")
    print(myTree)
    print("\n" + "-" * 50 + "\n")
    # Show the tree with matplotlib.
    createPlot(myTree)
    # Evaluate on the held-out test set, printing one row per sample.
    correct = 0
    print("测试集预测结果:")
    print(f"{'样本特征':<16} {'预测':<5} {'真实':<6} 结果")
    print("-" * 40)
    for sample in test_data:
        features = sample[:-1]
        true_label = sample[-1]
        pred_label = classify(myTree, labels, features)
        res = "正确" if pred_label == true_label else "错误"
        if res == "正确":
            correct += 1
        print(f"{str(features):<16} {pred_label:<5} {true_label:<6} {res}")
    # Report accuracy as a percentage.
    accuracy = correct / len(test_data) * 100
    print(f"\n准确率: {accuracy:.2f}%")
二、C4.5 决策树完整代码(treesC4.5.py)
from math import log
import operator
import matplotlib.pyplot as plt
# Configure matplotlib so CJK labels render correctly and the minus
# sign is not garbled by the CJK font.
plt.rcParams["font.family"] = ["sans-serif"]
plt.rcParams["font.sans-serif"] = ["SimHei", "Arial Unicode MS"]
plt.rcParams["axes.unicode_minus"] = False
# ---------------------- C4.5核心算法 ----------------------
def calcShannonEnt(dataSet):
    """Shannon entropy of the class labels (last element of each sample)."""
    n = len(dataSet)
    frequency = {}
    for row in dataSet:
        frequency[row[-1]] = frequency.get(row[-1], 0) + 1
    # Accumulate -p*log2(p) over the label distribution.
    result = 0.0
    for count in frequency.values():
        p = count / float(n)
        result -= p * log(p, 2)
    return result
def splitDataSet(dataSet, axis, value):
    """Samples where feature `axis` equals `value`, with that column removed.

    New lists are returned; the input dataset is not mutated.
    """
    return [row[:axis] + row[axis + 1:] for row in dataSet if row[axis] == value]
def calcFeatureEntropy(dataSet, axis):
    """Split information (C4.5): entropy of feature `axis`'s own values."""
    n = len(dataSet)
    frequency = {}
    for row in dataSet:
        frequency[row[axis]] = frequency.get(row[axis], 0) + 1
    # Same -p*log2(p) formula, but over the feature's value distribution.
    result = 0.0
    for count in frequency.values():
        p = count / float(n)
        result -= p * log(p, 2)
    return result
def chooseBestFeatureToSplit(dataSet):
    """Pick the split feature with the highest gain ratio (C4.5).

    Returns (best feature index, list of every feature's gain ratio).
    An index of -1 means no feature achieved a positive gain ratio.
    """
    featureCount = len(dataSet[0]) - 1  # last column is the class label
    baseEntropy = calcShannonEnt(dataSet)
    ratios = []
    bestIndex, bestRatio = -1, 0.0
    for idx in range(featureCount):
        # Conditional entropy of the labels given this feature.
        conditional = 0.0
        for val in {row[idx] for row in dataSet}:
            subset = splitDataSet(dataSet, idx, val)
            conditional += (len(subset) / float(len(dataSet))) * calcShannonEnt(subset)
        gain = baseEntropy - conditional
        # Normalize by the split information; guard against a
        # single-valued feature whose split info is zero.
        splitInfo = calcFeatureEntropy(dataSet, idx)
        ratio = 0 if splitInfo == 0 else gain / splitInfo
        ratios.append(ratio)
        if ratio > bestRatio:
            bestRatio, bestIndex = ratio, idx
    return bestIndex, ratios
def majorityCnt(classList):
    """Most frequent class label (majority vote for leaf nodes)."""
    tally = {}
    for label in classList:
        tally[label] = tally.get(label, 0) + 1
    # Ties resolve to the first-seen label, same as the stable sort did.
    return max(tally, key=tally.get)
def createTree(dataSet, labels):
    """Recursively build a C4.5 decision tree.

    Leaves are the strings "贷款"/"不贷款"; internal nodes are dicts with
    'label' (feature name), 'info_gain_ratio' (rounded to 4 places) and
    'children' (feature value -> subtree).
    """
    classList = [row[-1] for row in dataSet]
    # Base case 1: all samples share one label.
    if classList.count(classList[0]) == len(classList):
        return "贷款" if classList[0] == 1 else "不贷款"
    # Base case 2: features exhausted -> majority vote.
    if len(dataSet[0]) == 1:
        return "贷款" if majorityCnt(classList) == 1 else "不贷款"
    bestFeat, infoGainRatios = chooseBestFeatureToSplit(dataSet)
    node = {
        'label': labels[bestFeat],
        'info_gain_ratio': round(infoGainRatios[bestFeat], 4),
        'children': {},
    }
    # Drop the used feature via slicing so the caller's list is untouched.
    remaining = labels[:bestFeat] + labels[bestFeat + 1:]
    for value in {row[bestFeat] for row in dataSet}:
        node['children'][value] = createTree(
            splitDataSet(dataSet, bestFeat, value),
            remaining,
        )
    return node
def classify(inputTree, featLabels, testVec):
    """Predict a sample's class: 1 for "贷款", 0 for "不贷款", None if unmatched."""
    node = inputTree
    while not isinstance(node, str):
        branch = node['children']
        feature = featLabels.index(node['label'])
        if testVec[feature] not in branch:
            return None  # feature value never seen during training
        node = branch[testVec[feature]]
    return 1 if node == "贷款" else 0
# ---------------------- Visualization ----------------------
# Node styles: sawtooth light-blue boxes for decisions, rounded
# light-green boxes for leaves; arrows point child -> parent.
decision_node = dict(boxstyle="sawtooth,pad=1.2", fc="lightblue", ec="black", lw=1.5)
leaf_node = dict(boxstyle="round4,pad=1.2", fc="lightgreen", ec="black", lw=1.5)
arrow_args = dict(arrowstyle="<-", color="gray", lw=2,
                  shrinkA=0, shrinkB=22,
                  connectionstyle="arc3,rad=0.05")
def getNumLeafs(myTree):
    """Count the leaf nodes (plain strings) under myTree."""
    if isinstance(myTree, str):
        return 1
    total = 0
    for subtree in myTree['children'].values():
        total += getNumLeafs(subtree)
    return total
def getTreeDepth(myTree):
    """Depth of the tree; a bare leaf string counts as depth 1."""
    if isinstance(myTree, str):
        return 1
    deepest = 0
    for subtree in myTree['children'].values():
        deepest = max(deepest, 1 + getTreeDepth(subtree))
    return deepest
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one node box at centerPt with an arrow back to parentPt."""
    plt.annotate(
        nodeTxt,
        xy=parentPt, xycoords='axes fraction',
        xytext=centerPt, textcoords='axes fraction',
        va="center", ha="center",
        bbox=nodeType, arrowprops=arrow_args, fontsize=12,
    )
def plotMidText(cntrPt, parentPt, txtString):
    """Label the edge between a parent and child node (the feature value)."""
    xMid = (parentPt[0] + cntrPt[0]) / 2.0
    yMid = (parentPt[1] + cntrPt[1]) / 2.0
    # Shift sideways on near-vertical edges, vertically otherwise,
    # to keep neighbouring edge labels from overlapping.
    dx = cntrPt[0] - parentPt[0]
    nudge = 0.02
    if abs(dx) < 0.05:
        xMid += nudge if cntrPt[0] > parentPt[0] else -nudge
    else:
        yMid += nudge if cntrPt[1] > parentPt[1] else -nudge
    plt.text(xMid, yMid, txtString, va="center", ha="center",
             fontsize=10, color='darkred', weight='bold',
             bbox=dict(boxstyle="round,pad=0.2", facecolor='white', alpha=0.9))
def plotInfoGainRatio(centerPt, infoGainRatio):
    """Print the node's gain ratio just below a decision node (C4.5)."""
    x, y = centerPt
    plt.text(x, y - 0.045,
             f"信息增益率: {infoGainRatio}",
             va="center", ha="center",
             fontsize=9, color='blue', weight='bold',
             bbox=dict(boxstyle="round,pad=0.2", facecolor='lightyellow', alpha=0.8))
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw the C4.5 tree (annotating gain ratios).

    Layout state lives in function attributes (plotTree.xOff, .yOff,
    .totalW, .totalD) initialized by createPlot and mutated across the
    recursion — statement order here is load-bearing.
    """
    if isinstance(myTree, str):  # leaf node: a bare class-name string
        plotTree.xOff += 1.0 / plotTree.totalW  # advance one leaf slot
        leaf_center = (plotTree.xOff, plotTree.yOff)
        plotNode(myTree, leaf_center, parentPt, leaf_node)
        plotMidText(leaf_center, parentPt, nodeTxt)
        return
    numLeafs = getNumLeafs(myTree)
    firstStr = myTree['label']
    infoGainRatio = myTree['info_gain_ratio']  # gain ratio, not raw gain
    # Center the decision node above the horizontal span of its leaves.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    # Draw the node box, its gain-ratio annotation, and the edge label.
    plotNode(firstStr, cntrPt, parentPt, decision_node)
    plotInfoGainRatio(cntrPt, infoGainRatio)
    plotMidText(cntrPt, parentPt, nodeTxt)
    # Descend one level, draw each child, then restore the depth.
    secondDict = myTree['children']
    plotTree.yOff -= 1.1 / plotTree.totalD
    for key in secondDict.keys():
        plotTree(secondDict[key], cntrPt, str(key))
    plotTree.yOff += 1.1 / plotTree.totalD
def createPlot(inTree):
    """Render the C4.5 decision tree with matplotlib and show the figure."""
    fig = plt.figure(1, facecolor='white', figsize=(16, 12))
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    plt.subplot(111, frameon=False, **axprops)
    # Layout state consumed by plotTree (stored as function attributes):
    # total leaf count fixes the x scale, tree depth fixes the y scale.
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW  # start half a slot left of leaf 1
    plotTree.yOff = 1.0
    # Draw the tree from the top-center anchor.
    plotTree(inTree, (0.5, 1.0), '')
    # Legend; notes that C4.5 splits on the information gain RATIO.
    plt.figtext(0.02, 0.02,
                "● 决策节点(蓝色): 特征名称\n● 叶节点(绿色): 分类结果\n● 边上数字: 特征取值\n● C4.5算法使用信息增益率",
                fontsize=10,
                bbox=dict(boxstyle="round", facecolor='lightgray', alpha=0.8))
    plt.title('决策树可视化 (C4.5算法)', fontsize=16, pad=40, fontweight='bold')
    plt.tight_layout()
    plt.subplots_adjust(left=0.05, right=0.95, top=0.9, bottom=0.12)
    plt.show()
# ---------------------- 数据加载与主函数 ----------------------
def loadDataSet(filename):
    """Load a comma-separated integer dataset from `filename`.

    Each non-empty line becomes one sample: a list of ints with the
    class label in the last position.

    Returns:
        list[list[int]] of samples, or None when the file is missing.
    """
    dataSet = []
    try:
        with open(filename, 'r', encoding='utf-8') as fr:
            for line in fr:  # iterate lazily instead of readlines()
                line = line.strip()
                if not line:  # skip blank lines
                    continue
                dataSet.append([int(tok) for tok in line.split(',')])
        return dataSet
    except FileNotFoundError:
        # Bug fix: the f-string had no placeholder, so the message never
        # said WHICH file was missing. Interpolate the filename.
        print(f"错误:未找到文件 {filename},请放在代码同一目录下")
        return None
if __name__ == "__main__":
    # Feature names for the loan dataset: age bracket, has a job,
    # owns a house, credit rating.
    labels = ['年龄段', '有工作', '有自己的房', '信贷情况']
    # Load training and test data (CSV of ints, label in the last column).
    train_data = loadDataSet('dataset.txt')
    test_data = loadDataSet('testset.txt')
    if not train_data or not test_data:
        exit()
    # Build the C4.5 tree; pass a copy so the original labels list is
    # definitely untouched for the later classify() calls.
    myTree = createTree(train_data, labels.copy())
    print("C4.5决策树结构:")
    print(myTree)
    print("\n" + "-" * 50 + "\n")
    # Show the tree with matplotlib.
    createPlot(myTree)
    # Evaluate on the held-out test set, printing one row per sample.
    correct = 0
    print("测试集预测结果:")
    print(f"{'样本特征':<16} {'预测':<5} {'真实':<6} 结果")
    print("-" * 40)
    for sample in test_data:
        features = sample[:-1]
        true_label = sample[-1]
        pred_label = classify(myTree, labels, features)
        res = "正确" if pred_label == true_label else "错误"
        if res == "正确":
            correct += 1
        print(f"{str(features):<16} {pred_label:<5} {true_label:<6} {res}")
    # Report accuracy as a percentage.
    accuracy = correct / len(test_data) * 100
    print(f"\n准确率: {accuracy:.2f}%")
三、使用说明
将代码保存为 treesID3.py 和 treesC4.5.py。
准备 dataset.txt(训练数据)和 testset.txt(测试数据)。
直接运行脚本即可:
- 输出决策树结构
- 显示可视化决策树
- 输出测试集预测结果及准确率
运行后会自动生成 matplotlib 绘制的决策树图形,清晰展示两种算法的分割策略差异。


雷达卡


京公网安备 11010802022788号







