楼主: 金财晟
126 0

ID3 与 C4.5 决策树完整代码 [推广有奖]

  • 0关注
  • 0粉丝

等待验证会员

学前班

40%

还不是VIP/贵宾

-

威望
0
论坛币
0 个
通用积分
0
学术水平
0 点
热心指数
0 点
信用等级
0 点
经验
20 点
帖子
1
精华
0
在线时间
0 小时
注册时间
2018-11-12
最后登录
2018-11-12

楼主
金财晟 发表于 2025-11-13 16:31:21 |AI写论文

+2 论坛币
k人 参与回答

经管之家送您一份

应届毕业生专属福利!

求职就业群
赵安豆老师微信:zhaoandou666

经管之家联合CDA

送您一个全额奖学金名额~ !

感谢您参与论坛问题回答

经管之家送您两个论坛币!

+2 论坛币

一、ID3 决策树完整代码(treesID3.py)

from math import log
import operator
import matplotlib.pyplot as plt

# 解决中文显示问题
plt.rcParams["font.family"] = ["sans-serif"]
plt.rcParams["font.sans-serif"] = ["SimHei", "Arial Unicode MS"]
plt.rcParams["axes.unicode_minus"] = False


# ---------------------- ID3核心算法 ----------------------
def calcShannonEnt(dataSet):
    """计算数据集的香农熵"""
    numEntries = len(dataSet)
    labelCounts = {}
    # 统计每个标签出现的次数
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    # 计算香农熵
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def splitDataSet(dataSet, axis, value):
    """按照给定特征划分数据集"""
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # 去除当前特征列
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    """选择最优特征(基于信息增益)"""
    numFeatures = len(dataSet[0]) - 1  # 特征数量(最后一列是标签)
    baseEntropy = calcShannonEnt(dataSet)  # 基础熵
    bestInfoGain = 0.0
    bestFeature = -1
    infoGains = []  # 保存所有特征的信息增益

    for i in range(numFeatures):
        # 获取当前特征的所有取值
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0

        # 计算条件熵
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)

        # 计算信息增益
        infoGain = baseEntropy - newEntropy
        infoGains.append(infoGain)

        # 更新最佳特征
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i

    return bestFeature, infoGains  # 返回最佳特征索引和所有特征的信息增益


def majorityCnt(classList):
    """多数表决法决定叶子节点类别"""
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    # 按出现次数排序
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    """构建ID3决策树"""
    classList = [example[-1] for example in dataSet]
    
    # 递归终止条件1:所有样本属于同一类别
    if classList.count(classList[0]) == len(classList):
        return "贷款" if classList[0] == 1 else "不贷款"
    
    # 递归终止条件2:没有特征可分,返回多数类
    if len(dataSet[0]) == 1:
        majority = majorityCnt(classList)
        return "贷款" if majority == 1 else "不贷款"

    # 选择最佳特征
    bestFeat, infoGains = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    infoGain = infoGains[bestFeat]  # 当前节点的信息增益

    # 构建树节点
    myTree = {
        'label': bestFeatLabel,
        'info_gain': round(infoGain, 4),  # 保留4位小数
        'children': {}
    }

    # 递归构建子树
    subLabels = labels[:]  # 复制标签列表(避免修改原列表)
    del(subLabels[bestFeat])  # 删除已使用的特征
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)

    for value in uniqueVals:
        myTree['children'][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), 
            subLabels
        )
    return myTree


def classify(inputTree, featLabels, testVec):
    """使用决策树进行分类预测"""
    if isinstance(inputTree, str):  # 叶子节点
        return 1 if inputTree == "贷款" else 0

    firstLabel = inputTree['label']
    featIndex = featLabels.index(firstLabel)
    secondDict = inputTree['children']

    # 递归查找匹配的子节点
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            return classify(secondDict[key], featLabels, testVec)
    return None  # 未找到匹配


# ---------------------- 可视化功能 ----------------------
# 节点样式定义
decision_node = dict(boxstyle="sawtooth,pad=1.2", fc="lightblue", ec="black", lw=1.5)
leaf_node = dict(boxstyle="round4,pad=1.2", fc="lightgreen", ec="black", lw=1.5)
arrow_args = dict(arrowstyle="<-", color="gray", lw=2,
                  shrinkA=0, shrinkB=22,
                  connectionstyle="arc3,rad=0.05")


def getNumLeafs(myTree):
    """获取叶子节点数量"""
    if isinstance(myTree, str):
        return 1
    numLeafs = 0
    for key in myTree['children'].keys():
        numLeafs += getNumLeafs(myTree['children'][key])
    return numLeafs


def getTreeDepth(myTree):
    """获取树的深度"""
    if isinstance(myTree, str):
        return 1
    maxDepth = 0
    for key in myTree['children'].keys():
        thisDepth = 1 + getTreeDepth(myTree['children'][key])
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """绘制节点"""
    plt.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                 xytext=centerPt, textcoords='axes fraction',
                 va="center", ha="center", bbox=nodeType,
                 arrowprops=arrow_args, fontsize=12)


def plotMidText(cntrPt, parentPt, txtString):
    """在父子节点之间添加文本(特征取值)"""
    xMid = (parentPt[0] + cntrPt[0]) / 2.0
    yMid = (parentPt[1] + cntrPt[1]) / 2.0

    # 微调文本位置避免重叠
    dx = cntrPt[0] - parentPt[0]
    dy = cntrPt[1] - parentPt[1]
    offset = 0.02
    if abs(dx) < 0.05:
        xMid += offset * (1 if cntrPt[0] > parentPt[0] else -1)
    else:
        yMid += offset * (1 if cntrPt[1] > parentPt[1] else -1)

    plt.text(xMid, yMid, txtString, va="center", ha="center",
             fontsize=10, color='darkred', weight='bold',
             bbox=dict(boxstyle="round,pad=0.2", facecolor='white', alpha=0.9))


def plotInfoGain(centerPt, infoGain):
    """在决策节点下方绘制信息增益"""
    plt.text(centerPt[0], centerPt[1] - 0.045,
             f"信息增益: {infoGain}",
             va="center", ha="center",
             fontsize=9, color='blue', weight='bold',
             bbox=dict(boxstyle="round,pad=0.2", facecolor='lightyellow', alpha=0.8))


def plotTree(myTree, parentPt, nodeTxt):
    """递归绘制决策树"""
    if isinstance(myTree, str):  # 叶子节点
        plotTree.xOff += 1.0 / plotTree.totalW
        leaf_center = (plotTree.xOff, plotTree.yOff)
        plotNode(myTree, leaf_center, parentPt, leaf_node)
        plotMidText(leaf_center, parentPt, nodeTxt)
        return

    numLeafs = getNumLeafs(myTree)
    firstStr = myTree['label']
    infoGain = myTree['info_gain']

    # 计算当前节点位置
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)

    # 绘制当前节点
    plotNode(firstStr, cntrPt, parentPt, decision_node)
    plotInfoGain(cntrPt, infoGain)
    plotMidText(cntrPt, parentPt, nodeTxt)

    # 递归绘制子节点
    secondDict = myTree['children']
    plotTree.yOff -= 1.1 / plotTree.totalD
    for key in secondDict.keys():
        plotTree(secondDict[key], cntrPt, str(key))
    plotTree.yOff += 1.1 / plotTree.totalD


def createPlot(inTree):
    """创建决策树可视化图表"""
    fig = plt.figure(1, facecolor='white', figsize=(16, 12))
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    plt.subplot(111, frameon=False, **axprops)

    # 初始化树布局参数
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0

    # 绘制树
    plotTree(inTree, (0.5, 1.0), '')

    # 添加图例
    plt.figtext(0.02, 0.02,
                "● 决策节点(蓝色): 特征名称\n● 叶节点(绿色): 分类结果\n● 边上数字: 特征取值",
                fontsize=10,
                bbox=dict(boxstyle="round", facecolor='lightgray', alpha=0.8))

    plt.title('决策树可视化 (ID3算法)', fontsize=16, pad=40, fontweight='bold')
    plt.tight_layout()
    plt.subplots_adjust(left=0.05, right=0.95, top=0.9, bottom=0.12)
    plt.show()


# ---------------------- 数据加载与主函数 ----------------------
def loadDataSet(filename):
    """从文件加载数据集"""
    dataSet = []
    try:
        with open(filename, 'r', encoding='utf-8') as fr:
            for line in fr.readlines():
                line = line.strip()
                if not line:  # 跳过空行
                    continue
                example = list(map(int, line.split(',')))
                dataSet.append(example)
        return dataSet
    except FileNotFoundError:
        print(f"错误:未找到文件 {filename},请放在代码同一目录下")
        return None


if __name__ == "__main__":
    # 特征标签
    labels = ['年龄段', '有工作', '有自己的房', '信贷情况']
    
    # 加载数据
    train_data = loadDataSet('dataset.txt')
    test_data = loadDataSet('testset.txt')

    if not train_data or not test_data:
        exit()

    # 构建决策树
    myTree = createTree(train_data, labels.copy())
    print("ID3决策树结构:")
    print(myTree)
    print("\n" + "-" * 50 + "\n")

    # 可视化决策树
    createPlot(myTree)

    # 测试集验证
    correct = 0
    print("测试集预测结果:")
    print(f"{'样本特征':<16} {'预测':<5} {'真实':<6} 结果")
    print("-" * 40)
    for sample in test_data:
        features = sample[:-1]
        true_label = sample[-1]
        pred_label = classify(myTree, labels, features)
        res = "正确" if pred_label == true_label else "错误"
        if res == "正确":
            correct += 1
        print(f"{str(features):<16} {pred_label:<5} {true_label:<6} {res}")
    
    # 计算准确率
    accuracy = correct / len(test_data) * 100
    print(f"\n准确率: {accuracy:.2f}%")

二、C4.5 决策树完整代码(treesC4.5.py)

from math import log
import operator
import matplotlib.pyplot as plt

# 解决中文显示问题
plt.rcParams["font.family"] = ["sans-serif"]
plt.rcParams["font.sans-serif"] = ["SimHei", "Arial Unicode MS"]
plt.rcParams["axes.unicode_minus"] = False


# ---------------------- C4.5核心算法 ----------------------
def calcShannonEnt(dataSet):
    """计算数据集的香农熵(与ID3相同)"""
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def splitDataSet(dataSet, axis, value):
    """按照给定特征划分数据集(与ID3相同)"""
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


def calcFeatureEntropy(dataSet, axis):
    """计算特征本身的熵(分裂信息)- C4.5特有"""
    numEntries = len(dataSet)
    featureCounts = {}
    # 统计特征每个取值的出现次数
    for featVec in dataSet:
        featureVal = featVec[axis]
        if featureVal not in featureCounts:
            featureCounts[featureVal] = 0
        featureCounts[featureVal] += 1

    # 计算特征熵
    featureEntropy = 0.0
    for key in featureCounts:
        prob = float(featureCounts[key]) / numEntries
        featureEntropy -= prob * log(prob, 2)
    return featureEntropy


def chooseBestFeatureToSplit(dataSet):
    """选择最优特征(基于信息增益率)- C4.5核心差异"""
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestGainRatio = 0.0
    bestFeature = -1
    infoGainRatios = []  # 保存所有特征的信息增益率

    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0

        # 1. 计算条件熵(与ID3相同)
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)

        # 2. 计算信息增益
        infoGain = baseEntropy - newEntropy

        # 3. 计算特征熵(分裂信息)
        featureEntropy = calcFeatureEntropy(dataSet, i)

        # 4. 计算信息增益率(避免除零错误)
        if featureEntropy == 0:
            gainRatio = 0
        else:
            gainRatio = infoGain / featureEntropy

        infoGainRatios.append(gainRatio)

        # 更新最佳特征
        if gainRatio > bestGainRatio:
            bestGainRatio = gainRatio
            bestFeature = i

    return bestFeature, infoGainRatios


def majorityCnt(classList):
    """多数表决法决定叶子节点类别(与ID3相同)"""
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    """构建C4.5决策树"""
    classList = [example[-1] for example in dataSet]
    
    # 递归终止条件(与ID3相同)
    if classList.count(classList[0]) == len(classList):
        return "贷款" if classList[0] == 1 else "不贷款"
    if len(dataSet[0]) == 1:
        majority = majorityCnt(classList)
        return "贷款" if majority == 1 else "不贷款"

    # 选择最佳特征(使用信息增益率)
    bestFeat, infoGainRatios = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    infoGainRatio = infoGainRatios[bestFeat]  # 当前节点的信息增益率

    # 构建树节点(存储信息增益率)
    myTree = {
        'label': bestFeatLabel,
        'info_gain_ratio': round(infoGainRatio, 4),  # 保留4位小数
        'children': {}
    }

    # 递归构建子树(与ID3流程相同)
    subLabels = labels[:]
    del(subLabels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)

    for value in uniqueVals:
        myTree['children'][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), 
            subLabels
        )
    return myTree


def classify(inputTree, featLabels, testVec):
    """分类预测函数(与ID3相同)"""
    if isinstance(inputTree, str):
        return 1 if inputTree == "贷款" else 0

    firstLabel = inputTree['label']
    featIndex = featLabels.index(firstLabel)
    secondDict = inputTree['children']

    for key in secondDict.keys():
        if testVec[featIndex] == key:
            return classify(secondDict[key], featLabels, testVec)
    return None


# ---------------------- 可视化功能 ----------------------
# 节点样式(与ID3相同)
decision_node = dict(boxstyle="sawtooth,pad=1.2", fc="lightblue", ec="black", lw=1.5)
leaf_node = dict(boxstyle="round4,pad=1.2", fc="lightgreen", ec="black", lw=1.5)
arrow_args = dict(arrowstyle="<-", color="gray", lw=2,
                  shrinkA=0, shrinkB=22,
                  connectionstyle="arc3,rad=0.05")


def getNumLeafs(myTree):
    """获取叶子节点数量(与ID3相同)"""
    if isinstance(myTree, str):
        return 1
    numLeafs = 0
    for key in myTree['children'].keys():
        numLeafs += getNumLeafs(myTree['children'][key])
    return numLeafs


def getTreeDepth(myTree):
    """获取树的深度(与ID3相同)"""
    if isinstance(myTree, str):
        return 1
    maxDepth = 0
    for key in myTree['children'].keys():
        thisDepth = 1 + getTreeDepth(myTree['children'][key])
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """绘制节点(与ID3相同)"""
    plt.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                 xytext=centerPt, textcoords='axes fraction',
                 va="center", ha="center", bbox=nodeType,
                 arrowprops=arrow_args, fontsize=12)


def plotMidText(cntrPt, parentPt, txtString):
    """在父子节点之间添加文本(与ID3相同)"""
    xMid = (parentPt[0] + cntrPt[0]) / 2.0
    yMid = (parentPt[1] + cntrPt[1]) / 2.0

    dx = cntrPt[0] - parentPt[0]
    dy = cntrPt[1] - parentPt[1]
    offset = 0.02
    if abs(dx) < 0.05:
        xMid += offset * (1 if cntrPt[0] > parentPt[0] else -1)
    else:
        yMid += offset * (1 if cntrPt[1] > parentPt[1] else -1)

    plt.text(xMid, yMid, txtString, va="center", ha="center",
             fontsize=10, color='darkred', weight='bold',
             bbox=dict(boxstyle="round,pad=0.2", facecolor='white', alpha=0.9))


def plotInfoGainRatio(centerPt, infoGainRatio):
    """在决策节点下方绘制信息增益率 - C4.5特有"""
    plt.text(centerPt[0], centerPt[1] - 0.045,
             f"信息增益率: {infoGainRatio}",
             va="center", ha="center",
             fontsize=9, color='blue', weight='bold',
             bbox=dict(boxstyle="round,pad=0.2", facecolor='lightyellow', alpha=0.8))


def plotTree(myTree, parentPt, nodeTxt):
    """递归绘制决策树(显示信息增益率)"""
    if isinstance(myTree, str):  # 叶子节点
        plotTree.xOff += 1.0 / plotTree.totalW
        leaf_center = (plotTree.xOff, plotTree.yOff)
        plotNode(myTree, leaf_center, parentPt, leaf_node)
        plotMidText(leaf_center, parentPt, nodeTxt)
        return

    numLeafs = getNumLeafs(myTree)
    firstStr = myTree['label']
    infoGainRatio = myTree['info_gain_ratio']  # 使用信息增益率

    # 计算当前节点位置
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)

    # 绘制当前节点
    plotNode(firstStr, cntrPt, parentPt, decision_node)
    plotInfoGainRatio(cntrPt, infoGainRatio)  # 绘制信息增益率
    plotMidText(cntrPt, parentPt, nodeTxt)

    # 递归绘制子节点
    secondDict = myTree['children']
    plotTree.yOff -= 1.1 / plotTree.totalD
    for key in secondDict.keys():
        plotTree(secondDict[key], cntrPt, str(key))
    plotTree.yOff += 1.1 / plotTree.totalD


def createPlot(inTree):
    """创建决策树可视化图表(标题和图例适配C4.5)"""
    fig = plt.figure(1, facecolor='white', figsize=(16, 12))
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    plt.subplot(111, frameon=False, **axprops)

    # 初始化树布局参数
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0

    # 绘制树
    plotTree(inTree, (0.5, 1.0), '')

    # 添加图例(说明使用信息增益率)
    plt.figtext(0.02, 0.02,
                "● 决策节点(蓝色): 特征名称\n● 叶节点(绿色): 分类结果\n● 边上数字: 特征取值\n● C4.5算法使用信息增益率",
                fontsize=10,
                bbox=dict(boxstyle="round", facecolor='lightgray', alpha=0.8))

    plt.title('决策树可视化 (C4.5算法)', fontsize=16, pad=40, fontweight='bold')
    plt.tight_layout()
    plt.subplots_adjust(left=0.05, right=0.95, top=0.9, bottom=0.12)
    plt.show()


# ---------------------- 数据加载与主函数 ----------------------
def loadDataSet(filename):
    """从文件加载数据集(与ID3相同)"""
    dataSet = []
    try:
        with open(filename, 'r', encoding='utf-8') as fr:
            for line in fr.readlines():
                line = line.strip()
                if not line:
                    continue
                example = list(map(int, line.split(',')))
                dataSet.append(example)
        return dataSet
    except FileNotFoundError:
        print(f"错误:未找到文件 {filename},请放在代码同一目录下")
        return None


if __name__ == "__main__":
    labels = ['年龄段', '有工作', '有自己的房', '信贷情况']
    train_data = loadDataSet('dataset.txt')
    test_data = loadDataSet('testset.txt')

    if not train_data or not test_data:
        exit()

    # 构建C4.5决策树
    myTree = createTree(train_data, labels.copy())
    print("C4.5决策树结构:")
    print(myTree)
    print("\n" + "-" * 50 + "\n")

    # 可视化决策树
    createPlot(myTree)

    # 测试集验证
    correct = 0
    print("测试集预测结果:")
    print(f"{'样本特征':<16} {'预测':<5} {'真实':<6} 结果")
    print("-" * 40)
    for sample in test_data:
        features = sample[:-1]
        true_label = sample[-1]
        pred_label = classify(myTree, labels, features)
        res = "正确" if pred_label == true_label else "错误"
        if res == "正确":
            correct += 1
        print(f"{str(features):<16} {pred_label:<5} {true_label:<6} {res}")
    
    # 计算准确率
    accuracy = correct / len(test_data) * 100
    print(f"\n准确率: {accuracy:.2f}%")

三、使用说明

将代码保存为 treesID3.py 和 treesC4.5.py。

准备 dataset.txt(训练数据)和 testset.txt(测试数据)。

直接运行脚本即可:

  • 输出决策树结构
  • 显示可视化决策树
  • 输出测试集预测结果及准确率

运行后会自动生成 matplotlib 绘制的决策树图形,清晰展示两种算法的分割策略差异。

二维码

扫码加我 拉你入群

请注明:姓名-公司-职位

以便审核进群资格,未注明则拒绝

关键词:决策树 Matplotlib Majority Children features

您需要登录后才可以回帖 登录 | 我要注册

本版微信群
扫码
拉您进交流群
GMT+8, 2026-5-3 06:52