楼主: zhushiyeye
82 0

决策树实战:从模型构建到性能对比 [推广有奖]

  • 0关注
  • 0粉丝

等待验证会员

学前班

40%

还不是VIP/贵宾

-

威望
0
论坛币
0 个
通用积分
0
学术水平
0 点
热心指数
0 点
信用等级
0 点
经验
20 点
帖子
1
精华
0
在线时间
0 小时
注册时间
2018-3-26
最后登录
2018-3-26

楼主
zhushiyeye 发表于 2025-11-13 15:23:48 |AI写论文

+2 论坛币
k人 参与回答

经管之家送您一份

应届毕业生专属福利!

求职就业群
赵安豆老师微信:zhaoandou666

经管之家联合CDA

送您一个全额奖学金名额~ !

感谢您参与论坛问题回答

经管之家送您两个论坛币!

+2 论坛币

决策树是机器学习中一种直观且易于解释的算法,广泛应用于分类任务。本文将通过贷款审批案例,详细介绍 ID3 和 CART 两种经典决策树的构建原理、实现细节、可视化方法及性能对比,并补充模型优化思路与实际应用场景分析,帮助读者全面理解决策树算法。

一、决策树基础原理深度解析

决策树通过模拟人类决策过程实现分类:从根节点开始,根据特征判断向下分支,最终到达叶子节点得到分类结果。其核心优势在于可解释性强(决策路径清晰)和无需特征归一化(对数值缩放不敏感)。

1.1 特征选择准则对比

两种常用算法的核心区别在于特征选择标准:

  • ID3 算法:采用信息增益作为特征选择标准。信息增益 = 父节点熵 - 子节点加权熵。熵衡量数据不确定性,公式为: 信息增益越大,说明该特征对降低数据不确定性的贡献越大。
  • CART 算法:使用基尼指数作为分裂准则。基尼指数衡量数据纯度,公式为: 基尼指数越小,数据纯度越高。CART 通过计算基尼增益(父节点基尼指数 - 子节点加权基尼指数)选择最优特征。

1.2 树结构差异

ID3 构建多叉树(特征有多少取值就有多少分支),CART 构建二叉树(无论特征有多少取值,均分为 "是 / 否" 两个分支)。

二、完整代码实现

以下是包含数据处理、模型构建、可视化及性能评估的完整代码,可直接运行:

from PIL import Image, ImageDraw, ImageFont
import os
import math
import random

# -------------------------- 通用配置 --------------------------
FONT_SIZE = 12
# 决策树参数
DECISION_BG = (220, 230, 242)  # 决策节点浅蓝色
LEAF_BG = (220, 245, 220)      # 叶子节点浅绿色
BORDER = (0, 0, 0)
LINE_COLOR = (0, 0, 0)
NODE_SIZE = (160, 60)
LEAF_SIZE = (120, 40)
# 评估表格参数
TABLE_BG = (30, 30, 30)        # 表格深色背景
CELL_BG = (50, 50, 50)         # 单元格背景
HEADER_BG = (60, 60, 60)       # 表头背景
TEXT_COLOR = (255, 255, 255)   # 文字白色
CELL_PAD = 10

# -------------------------- 字体工具 --------------------------
def get_font(size=FONT_SIZE):
    try:
        return ImageFont.truetype("C:/Windows/Fonts/simhei.ttf", size)  # 黑体(支持中文+等宽)
    except:
        try:
            return ImageFont.truetype("/System/Library/Fonts/PingFang.ttc", size)
        except:
            return ImageFont.load_default()

# -------------------------- 数据处理 --------------------------
def train_test_split(X, y, test_size=0.3, random_state=42):
    random.seed(random_state)
    data = list(zip(X, y))
    random.shuffle(data)
    split_idx = int(len(data) * (1 - test_size))
    train_data, test_data = data[:split_idx], data[split_idx:]
    X_train, y_train = zip(*train_data) if train_data else ([], [])
    X_test, y_test = zip(*test_data) if test_data else ([], [])
    return list(X_train), list(X_test), list(y_train), list(y_test)

def load_data():
    # 贷款数据集:[年龄, 有工作, 有自己的房子, 信贷情况, 是否贷款]
    dataset = [
        [0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 1, 0, 1, 1],
        [0, 1, 1, 0, 1], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0],
        [1, 0, 0, 1, 0], [1, 1, 1, 1, 1], [1, 0, 1, 2, 1],
        [1, 0, 1, 2, 1], [2, 0, 1, 2, 1], [2, 0, 1, 1, 1],
        [2, 1, 0, 1, 1], [2, 1, 0, 2, 1], [2, 0, 0, 0, 0]
    ]
    feat_names = ['年龄', '有工作', '有自己的房子', '信贷情况']
    X = [data[:-1] for data in dataset]
    y = [data[-1] for data in dataset]
    return train_test_split(X, y, test_size=0.3, random_state=42) + (feat_names,)

# -------------------------- 模型评估指标 --------------------------
def accuracy_score(y_true, y_pred):
    if len(y_true) != len(y_pred):
        return 0.0
    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return correct / len(y_true)

def precision_score(y_true, y_pred):
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    return tp / (tp + fp) if (tp + fp) != 0 else 0.0

def recall_score(y_true, y_pred):
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    return tp / (tp + fn) if (tp + fn) != 0 else 0.0

def f1_score(y_true, y_pred):
    prec = precision_score(y_true, y_pred)
    rec = recall_score(y_true, y_pred)
    return 2 * (prec * rec) / (prec + rec) if (prec + rec) != 0 else 0.0

# -------------------------- ID3决策树 --------------------------
class ID3DecisionTree:
    def __init__(self, feat_names):
        self.feat_names = feat_names.copy()
        self.tree = None

    def _calc_entropy(self, y):
        total = len(y)
        if total == 0:
            return 0.0
        label_counts = {}
        for label in y:
            label_counts[label] = label_counts.get(label, 0) + 1
        entropy = 0.0
        for count in label_counts.values():
            prob = count / total
            entropy -= prob * math.log2(prob + 1e-10)
        return entropy

    def _split_data(self, X, y, feat_idx, feat_val):
        X_sub = []
        y_sub = []
        for x, label in zip(X, y):
            if x[feat_idx] == feat_val:
                X_sub.append(x[:feat_idx] + x[feat_idx+1:])
                y_sub.append(label)
        return X_sub, y_sub

    def _select_best_feat(self, X, y):
        n_feats = len(X[0]) if X else 0
        if n_feats == 0:
            return -1
        base_entropy = self._calc_entropy(y)
        best_gain = -1
        best_feat_idx = -1

        for feat_idx in range(n_feats):
            feat_vals = set(x[feat_idx] for x in X)
            cond_entropy = 0.0
            for val in feat_vals:
                _, y_sub = self._split_data(X, y, feat_idx, val)
                prob = len(y_sub) / len(y)
                cond_entropy += prob * self._calc_entropy(y_sub)
            info_gain = base_entropy - cond_entropy
            if info_gain > best_gain:
                best_gain = info_gain
                best_feat_idx = feat_idx
        return best_feat_idx, best_gain

    def _majority_vote(self, y):
        label_counts = {}
        for label in y:
            label_counts[label] = label_counts.get(label, 0) + 1
        return max(label_counts, key=label_counts.get)

    def _build_tree(self, X, y, feat_names):
        if len(set(y)) == 1:
            return y[0]
        if len(feat_names) == 0:
            return self._majority_vote(y)

        best_feat_idx, best_gain = self._select_best_feat(X, y)
        if best_feat_idx == -1:
            return self._majority_vote(y)
        best_feat_name = feat_names[best_feat_idx]
        tree = {
            'name': best_feat_name,
            'gain': round(best_gain, 4),
            'children': {}
        }
        remaining_feats = feat_names[:best_feat_idx] + feat_names[best_feat_idx+1:]

        feat_vals = set(x[best_feat_idx] for x in X)
        for val in feat_vals:
            val_str = "是" if val == 1 else "否" if val == 0 else val
            X_sub, y_sub = self._split_data(X, y, best_feat_idx, val)
            child = self._build_tree(X_sub, y_sub, remaining_feats)
            if isinstance(child, dict):
                tree['children'][val_str] = {'node': child}
            else:
                tree['children'][val_str] = {'label': '是(贷款)' if child == 1 else '否(不贷款)'}
        return tree

    def fit(self, X, y):
        self.tree = self._build_tree(X, y, self.feat_names.copy())

    def predict(self, X):
        def _predict_single(x, tree, feat_names):
            if 'label' in tree:
                return 1 if tree['label'] == '是(贷款)' else 0
            root_idx = feat_names.index(tree['name'])
            x_val = x[root_idx]
            x_val_str = "是" if x_val == 1 else "否" if x_val == 0 else x_val
            
            child = tree['children'].get(x_val_str, None)
            if child is None:
                return 0
            if 'node' in child:
                return _predict_single(x, child['node'], feat_names)
            else:
                return 1 if child['label'] == '是(贷款)' else 0

        return [_predict_single(x, self.tree, self.feat_names) for x in X]

# -------------------------- CART决策树 --------------------------
class CARTClassifier:
    def __init__(self, feat_names, max_depth=3):
        self.feat_names = feat_names.copy()
        self.max_depth = max_depth
        self.tree = None

    def _gini(self, y):
        total = len(y)
        if total == 0:
            return 1.0
        label_counts = {}
        for label in y:
            label_counts[label] = label_counts.get(label, 0) + 1
        gini = 1.0
        for count in label_counts.values():
            prob = count / total
            gini -= prob **2
        return gini

    def _split_data(self, X, y, feat_idx, threshold):
        X_left, y_left = [], []
        X_right, y_right = [], []
        for x, label in zip(X, y):
            if x[feat_idx] == threshold:
                X_left.append(x)
                y_left.append(label)
            else:
                X_right.append(x)
                y_right.append(label)
        return X_left, y_left, X_right, y_right

    def _select_best_split(self, X, y):
        best_gini = 1.0
        best_feat_idx = -1
        best_threshold = None
        best_gain = 0.0
        n_feats = len(X[0]) if X else 0
        base_gini = self._gini(y)

        for feat_idx in range(n_feats):
            thresholds = set(x[feat_idx] for x in X)
            for threshold in thresholds:
                _, y_left, _, y_right = self._split_data(X, y, feat_idx, threshold)
                if not y_left or not y_right:
                    continue
                gini = (len(y_left)*self._gini(y_left) + len(y_right)*self._gini(y_right)) / len(y)
                gini_gain = base_gini - gini
                if gini < best_gini:
                    best_gini = gini
                    best_feat_idx = feat_idx
                    best_threshold = threshold
                    best_gain = gini_gain
        return best_feat_idx, best_threshold, best_gain

    def _majority_vote(self, y):
        label_counts = {}
        for label in y:
            label_counts[label] = label_counts.get(label, 0) + 1
        return max(label_counts, key=label_counts.get)

    def _build_tree(self, X, y, depth=0):
        if len(set(y)) == 1 or depth >= self.max_depth:
            return {'label': '是(贷款)' if self._majority_vote(y) == 1 else '否(不贷款)'}

        best_feat_idx, best_threshold, best_gain = self._select_best_split(X, y)
        if best_feat_idx == -1:
            return {'label': '是(贷款)' if self._majority_vote(y) == 1 else '否(不贷款)'}

        best_feat_name = self.feat_names[best_feat_idx]
        threshold_str = "是" if best_threshold == 1 else "否" if best_threshold == 0 else best_threshold
        
        X_left, y_left, X_right, y_right = self._split_data(X, y, best_feat_idx, best_threshold)
        left_child = self._build_tree(X_left, y_left, depth + 1)
        right_child = self._build_tree(X_right, y_right, depth + 1)

        return {
            'name': best_feat_name,
            'gain': round(best_gain, 4),
            'threshold': threshold_str,
            'children': {
                '是': left_child,
                '否': right_child
            }
        }

    def fit(self, X, y):
        self.tree = self._build_tree(X, y)

    def predict(self, X):
        def _predict_single(x, tree):
            if 'label' in tree:
                return 1 if tree['label'] == '是(贷款)' else 0
            feat_idx = self.feat_names.index(tree['name'])
            x_val = x[feat_idx]
            x_val_str = "是" if x_val == 1 else "否" if x_val == 0 else x_val
            
            if x_val_str == tree['threshold']:
                return _predict_single(x, tree['children']['是'])
            else:
                return _predict_single(x, tree['children']['否'])

        return [_predict_single(x, self.tree) for x in X]

# -------------------------- 1. 绘制决策树 --------------------------
def draw_tree(tree_data, filename):
    width, height = 600, 400
    img = Image.new('RGB', (width, height), 'white')
    draw = ImageDraw.Draw(img)
    font = get_font()

    # 根节点
    root_x, root_y = width//2, 50
    root_text = f"{tree_data['name']}\n增益: {tree_data['gain']}"
    root_bbox = (root_x-80, root_y-30, root_x+80, root_y+30)
    draw.rectangle(root_bbox, fill=DECISION_BG, outline=BORDER, width=2)
    draw.multiline_text((root_x, root_y), root_text, font=font, anchor='mm', align='center')

    # 处理根节点子节点
    if 'threshold' in tree_data:  # CART树(二叉树)
        # 右子节点(否)
        right_child = tree_data['children']['否']
        right_x, right_y = root_x+150, root_y+120
        draw.line([(root_x+10, root_y+30), (right_x-30, right_y-20)], fill=LINE_COLOR, width=2)
        draw.text(((root_x+right_x)//2+30, (root_y+right_y)//2), "否", font=font, anchor='mm')
        if 'label' in right_child:  # 叶子节点
            right_bbox = (right_x-60, right_y-20, right_x+60, right_y+20)
            draw.ellipse(right_bbox, fill=LEAF_BG, outline=BORDER, width=2)
            draw.text((right_x, right_y), right_child['label'], font=font, anchor='mm')
        else:  # 决策节点(递归绘制)
            pass  # 简化版只处理两层树

        # 左子节点(是)
        left_child = tree_data['children']['是']
        left_x, left_y = root_x-150, root_y+120
        draw.line([(root_x-10, root_y+30), (left_x+30, left_y-20)], fill=LINE_COLOR, width=2)
        draw.text(((root_x+left_x)//2-30, (root_y+left_y)//2), "是", font=font, anchor='mm')
        if 'label' in left_child:  # 叶子节点
            left_bbox = (left_x-60, left_y-20, left_x+60, left_y+20)
            draw.ellipse(left_bbox, fill=LEAF_BG, outline=BORDER, width=2)
            draw.text((left_x, left_y), left_child['label'], font=font, anchor='mm')
        else:  # 中间决策节点
            left_text = f"{left_child['name']}\n增益: {left_child['gain']}"
            left_bbox = (left_x-80, left_y-30, left_x+80, left_y+30)
            draw.rectangle(left_bbox, fill=DECISION_BG, outline=BORDER, width=2)
            draw.multiline_text((left_x, left_y), left_text, font=font, anchor='mm', align='center')
            
            # 中间节点子节点
            mid_left = left_child['children']['是']
            mid_left_x, mid_left_y = left_x-100, left_y+120
            draw.line([(left_x-10, left_y+30), (mid_left_x+30, mid_left_y-20)], fill=LINE_COLOR, width=2)
            draw.text(((left_x+mid_left_x)//2-20, (left_y+mid_left_y)//2), "是", font=font, anchor='mm')
            mid_left_bbox = (mid_left_x-60, mid_left_y-20, mid_left_x+60, mid_left_y+20)
            draw.ellipse(mid_left_bbox, fill=LEAF_BG, outline=BORDER, width=2)
            draw.text((mid_left_x, mid_left_y), mid_left['label'], font=font, anchor='mm')

            mid_right = left_child['children']['否']
            mid_right_x, mid_right_y = left_x+100, left_y+120
            draw.line([(left_x+10, left_y+30), (mid_right_x-30, mid_right_y-20)], fill=LINE_COLOR, width=2)
            draw.text(((left_x+mid_right_x)//2+20, (left_y+mid_right_y)//2), "否", font=font, anchor='mm')
            mid_right_bbox = (mid_right_x-60, mid_right_y-20, mid_right_x+60, mid_right_y+20)
            draw.ellipse(mid_right_bbox, fill=LEAF_BG, outline=BORDER, width=2)
            draw.text((mid_right_x, mid_right_y), mid_right['label'], font=font, anchor='mm')

    else:  # ID3树(多叉树)
        # 右子节点(是)
        right_child = tree_data['children']['是']
        right_x, right_y = root_x+150, root_y+120
        draw.line([(root_x+10, root_y+30), (right_x-30, right_y-20)], fill=LINE_COLOR, width=2)
        draw.text(((root_x+right_x)//2+30, (root_y+right_y)//2), "是", font=font, anchor='mm')
        right_bbox = (right_x-60, right_y-20, right_x+60, right_y+20)
        draw.ellipse(right_bbox, fill=LEAF_BG, outline=BORDER, width=2)
        draw.text((right_x, right_y), right_child['label'], font=font, anchor='mm')

        # 左子节点(否)
        left_child = tree_data['children']['否']['node']
        left_x, left_y = root_x-150, root_y+120
        draw.line([(root_x-10, root_y+30), (left_x+30, left_y-20)], fill=LINE_COLOR, width=2)
        draw.text(((root_x+left_x)//2-30, (root_y+left_y)//2), "否", font=font, anchor='mm')
        left_text = f"{left_child['name']}\n增益: {left_child['gain']}"
        left_bbox = (left_x-80, left_y-30, left_x+80, left_y+30)
        draw.rectangle(left_bbox, fill=DECISION_BG, outline=BORDER, width=2)
        draw.multiline_text((left_x, left_y), left_text, font=font, anchor='mm', align='center')

        # 中间节点子节点
        mid_left = left_child['children']['否']
        mid_left_x, mid_left_y = left_x-100, left_y+120
        draw.line([(left_x-10, left_y+30), (mid_left_x+30, mid_left_y-20)], fill=LINE_COLOR, width=2)
        draw.text(((left_x+mid_left_x)//2-20, (left_y+mid_left_y)//2), "否", font=font, anchor='mm')
        mid_left_bbox = (mid_left_x-60, mid_left_y-20, mid_left_x+60, mid_left_y+20)
        draw.ellipse(mid_left_bbox, fill=LEAF_BG, outline=BORDER, width=2)
        draw.text((mid_left_x, mid_left_y), mid_left['label'], font=font, anchor='mm')

        mid_right = left_child['children']['是']
        mid_right_x, mid_right_y = left_x+100, left_y+120
        draw.line([(left_x+10, left_y+30), (mid_right_x-30, mid_right_y-20)], fill=LINE_COLOR, width=2)
        draw.text(((left_x+mid_right_x)//2+20, (left_y+mid_right_y)//2), "是", font=font, anchor='mm')
        mid_right_bbox = (mid_right_x-60, mid_right_y-20, mid_right_x+60, mid_right_y+20)
        draw.ellipse(mid_right_bbox, fill=LEAF_BG, outline=BORDER, width=2)
        draw.text((mid_right_x, mid_right_y), mid_right['label'], font=font, anchor='mm')

    img.save(filename)
    print(f"决策树保存:{filename}")

# -------------------------- 2. 绘制评估表格 --------------------------
def draw_table(title, y_true, y_pred, filename):
    # 计算评估指标
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    
    # 按类别统计
    support_0 = sum(1 for t in y_true if t == 0)
    support_1 = sum(1 for t in y_true if t == 1)
    pred_0_correct = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    pred_1_correct = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    
    metrics = {
        '否(不贷款)': {
            'precision': pred_0_correct / support_0 if support_0 > 0 else 0,
            'recall': pred_0_correct / support_0 if support_0 > 0 else 0,
            'f1-score': 1.0 if support_0 == 0 else 2 * ( (pred_0_correct/support_0) * (pred_0_correct/support_0) ) / ( (pred_0_correct/support_0) + (pred_0_correct/support_0) ),
            'support': support_0
        },
        '是(贷款)': {
            'precision': pred_1_correct / support_1 if support_1 > 0 else 0,
            'recall': pred_1_correct / support_1 if support_1 > 0 else 0,
            'f1-score': 1.0 if support_1 == 0 else 2 * ( (pred_1_correct/support_1) * (pred_1_correct/support_1) ) / ( (pred_1_correct/support_1) + (pred_1_correct/support_1) ),
            'support': support_1
        }
    }

    col_widths = [100, 100, 100, 100, 60]  # 类别+4个指标
    row_height = 25
    total_rows = 7  # 标题+准确率+表头+2类别+2平均
    width = sum(col_widths) + CELL_PAD*2
    height = row_height*total_rows + CELL_PAD*2

    img = Image.new('RGB', (width, height), TABLE_BG)
    draw = ImageDraw.Draw(img)
    font = get_font()
    current_y = CELL_PAD

    # 标题行
    title_bbox = (CELL_PAD, current_y, width-CELL_PAD, current_y+row_height)
    draw.rectangle(title_bbox, fill=HEADER_BG, outline=BORDER, width=1)
    draw.text((width//2, current_y+row_height//2), title, font=font, anchor='mm', fill=TEXT_COLOR)
    current_y += row_height

    # 准确率行
    acc_bbox = (CELL_PAD, current_y, width-CELL_PAD, current_y+row_height)
    draw.rectangle(acc_bbox, fill=CELL_BG, outline=BORDER, width=1)
    draw.text((CELL_PAD+10, current_y+row_height//2), f"准确率: {accuracy:.2f}", font=font, anchor='lm', fill=TEXT_COLOR)
    current_y += row_height

    # 表头行
    headers = ['', 'precision', 'recall', 'f1-score', 'support']
    current_x = CELL_PAD
    for i, h in enumerate(headers):
        bbox = (current_x, current_y, current_x+col_widths[i], current_y+row_height)
        draw.rectangle(bbox, fill=HEADER_BG, outline=BORDER, width=1)
        draw.text((current_x+col_widths[i]//2, current_y+row_height//2), h, font=font, anchor='mm', fill=TEXT_COLOR)
        current_x += col_widths[i]
    current_y += row_height

    # 类别数据行
    for label in metrics:
        current_x = CELL_PAD
        # 类别名称
        bbox = (current_x, current_y, current_x+col_widths[0], current_y+row_height)
        draw.rectangle(bbox, fill=CELL_BG, outline=BORDER, width=1)
        draw.text((current_x+10, current_y+row_height//2), label, font=font, anchor='lm', fill=TEXT_COLOR)
        current_x += col_widths[0]
        # 指标数据
        for col in ['precision', 'recall', 'f1-score', 'support']:
            bbox = (current_x, current_y, current_x+col_widths[headers.index(col)], current_y+row_height)
            draw.rectangle(bbox, fill=CELL_BG, outline=BORDER, width=1)
            val = metrics[label][col]
            draw.text((current_x+col_widths[headers.index(col)]//2, current_y+row_height//2),
                     f"{val:.4f}" if col!='support' else f"{val}", font=font, anchor='mm', fill=TEXT_COLOR)
            current_x += col_widths[headers.index(col)]
        current_y += row_height

    # 平均行
    avg_rows = [
        {'label': 'accuracy', 'vals': {'precision': accuracy, 'recall': accuracy, 'f1-score': accuracy, 'support': len(y_true)}},
        {'label': 'weighted avg', 'vals': {'precision': precision, 'recall': recall, 'f1-score': f1, 'support': len(y_true)}}
    ]
    for avg in avg_rows:
        current_x = CELL_PAD
        bbox = (current_x, current_y, current_x+col_widths[0], current_y+row_height)
        draw.rectangle(bbox, fill=HEADER_BG, outline=BORDER, width=1)
        draw.text((current_x+10, current_y+row_height//2), avg['label'], font=font, anchor='lm', fill=TEXT_COLOR)
        current_x += col_widths[0]
        for col in ['precision', 'recall', 'f1-score', 'support']:
            bbox = (current_x, current_y, current_x+col_widths[headers.index(col)], current_y+row_height)
            draw.rectangle(bbox, fill=HEADER_BG, outline=BORDER, width=1)
            val = avg['vals'][col]
            draw.text((current_x+col_widths[headers.index(col)]//2, current_y+row_height//2),
                     f"{val:.4f}" if col!='support' else f"{val}", font=font, anchor='mm', fill=TEXT_COLOR)
            current_x += col_widths[headers.index(col)]
        current_y += row_height

    # 底部横线
    draw.line([(CELL_PAD, current_y), (width-CELL_PAD, current_y)], fill=(200,200,200), width=1)
    img.save(filename)
    print(f"评估表格保存:{filename}")

# -------------------------- 主程序 --------------------------
if __name__ == "__main__":
    output_dir = 'decision_tree_results'
    os.makedirs(output_dir, exist_ok=True)

    # 加载数据
    X_train, X_test, y_train, y_test, feat_names = load_data()
    print(f"数据集划分:训练集{len(X_train)}样本,测试集{len(X_test)}样本")

    # 训练ID3决策树
    id3 = ID3DecisionTree(feat_names)
    id3.fit(X_train, y_train)
    y_pred_id3 = id3.predict(X_test)
    draw_tree(id3.tree, f"{output_dir}/id3_tree.png")
    draw_table("【ID3决策树性能评估】", y_test, y_pred_id3, f"{output_dir}/id3_evaluation.png")

    # 训练CART决策树
    cart = CARTClassifier(feat_names, max_depth=3)
    cart.fit(X_train, y_train)
    y_pred_cart = cart.predict(X_test)
    draw_tree(cart.tree, f"{output_dir}/cart_tree.png")
    draw_table("【CART决策树性能评估】", y_test, y_pred_cart, f"{output_dir}/cart_evaluation.png")

    print(f"所有结果已保存到 {output_dir} 文件夹")

三、代码解析:深入理解决策树构建逻辑

1. 数据处理模块

本模块的核心作用是为模型提供标准化的输入数据,主要包含两个关键功能:

  • 数据集分割:
    train_test_split
    函数采用随机抽样方式将数据分为训练集(70%)和测试集(30%),通过设置随机种子(random_state=42)保证实验可重复性。这种划分方式能有效评估模型的泛化能力 —— 训练集用于构建模型,测试集用于验证模型在未见过的数据上的表现。
  • 数据集加载:贷款审批数据集包含 15 个样本,每个样本由 4 个特征和 1 个标签组成: 特征:年龄(0 = 青年,1 = 中年,2 = 老年)、有工作(0 = 否,1 = 是)、有自己的房子(0 = 否,1 = 是)、信贷情况(0 = 一般,1 = 好,2 = 非常好) 标签:是否贷款(0 = 否,1 = 是) 该数据集虽小,但涵盖了分类任务的典型特征类型(离散型),适合作为决策树入门案例。

2. 决策树实现:两种算法的核心差异

(1) ID3 决策树:基于信息增益的多叉树。ID3 算法的核心思想是通过降低信息熵(不确定性)来选择最优特征,具体实现包含四个关键步骤:

  • 信息熵计算:
    _calc_entropy
    函数衡量样本集合的不确定性,计算公式为:\(H(D) = -\sum_{k=1}^{K} \frac{|C_k|}{|D|} \log_2 \frac{|C_k|}{|D|}\),其中 D 为样本集合,\(C_k\) 为第 k 类样本子集。熵值越高,样本越混乱。
  • 特征分裂:
    _split_data
    函数根据特征值将样本划分为多个子集,例如将 "有工作" 特征值为 1 的样本归为一类,为 0 的归为另一类。
  • 信息增益计算:
    _select_best_feat
    函数通过比较分裂前后的熵值变化选择最优特征:\(Gain(D, a) = H(D) - \sum_{v=1}^{V} \frac{|D_v|}{|D|} H(D_v)\)。信息增益越大,说明该特征对降低不确定性的贡献越大,越适合作为当前节点的分裂特征。
  • 树构建递归逻辑:
    _build_tree
    函数通过递归方式构建多叉树:若当前节点样本全属于同一类别,停止分裂(叶子节点);若没有特征可分裂,采用多数投票决定类别;否则选择信息增益最大的特征继续分裂。

(2) CART 决策树:基于基尼指数的二叉树。CART 是更灵活的决策树算法,既能处理分类也能处理回归任务,其核心特点是始终构建二叉树:

  • 基尼指数计算:
    _gini
    函数衡量节点纯度,计算公式为:\(Gini(D) = 1 - \sum_{k=1}^{K} \left( \frac{|C_k|}{|D|} \right)^2\)。基尼指数越小,节点纯度越高(样本越集中于某一类别)。
  • 最优分裂选择:
    _select_best_split
    函数通过遍历所有特征的所有可能阈值,选择基尼指数最小的分裂方式。与 ID3 不同,CART 每次分裂仅产生两个子节点(是 / 否分支),例如 "有自己的房子 = 是" 为左分支,"有自己的房子 = 否" 为右分支。
  • 树深度控制:通过
    max_depth
    参数限制树的最大深度(本实现中设为 3),这是一种简单有效的剪枝策略,可防止模型过度拟合训练数据。

3. 可视化模块:让决策过程可见

(1) 决策树可视化

draw_tree
函数将抽象的树结构转化为直观图像,关键设计:

决策节点(浅蓝色矩形):显示特征名称和增益值(信息增益 / 基尼增益),清晰呈现 "为什么选择该特征分裂"。

叶子节点(浅绿色椭圆形):直接展示分类结果("是(贷款)" 或 "否(不贷款)")。

分支标注:用 "是 / 否" 清晰指示分裂条件,完整展现从根节点到叶子节点的决策路径。

这种可视化不仅能帮助理解模型逻辑,还能用于向非技术人员解释决策依据,体现了决策树 "白盒模型" 的优势。

(2)性能评估表格

draw_table

函数生成标准化评估报告,包含四类核心指标:

  • 准确率(Accuracy):正确分类的样本占比,衡量整体分类效果。
  • 精确率(Precision):预测为 "贷款" 的样本中实际确实贷款的比例,衡量预测的 "精确度"。
  • 召回率(Recall):实际贷款的样本中被正确预测的比例,衡量预测的 "全面性"。
  • F1 分数:精确率和召回率的调和平均,综合评价模型性能。

表格按类别("否(不贷款)" 和 "是(贷款)")分别展示指标,便于分析模型在不同类别上的表现差异。

四、实验结果与深度分析

1. 决策树结构对比

ID3 决策树解析

ID3 树的分裂路径呈现明显的优先级:

  • 根节点选择 "有自己的房子"(信息增益 0.12),说明该特征对贷款决策影响最大 —— 有房客户直接批准贷款(符合 "抵押品降低风险" 的金融逻辑)。
  • 无房客户进入下一层,通过 "有工作" 特征(信息增益 0.42)进一步筛选 —— 有稳定收入来源的客户仍可获得贷款。
  • 信息增益的差异(0.12 < 0.42)表明:对于无房客户,"有工作" 比 "有自己的房子" 更能区分贷款资格,这体现了 ID3 算法 "动态选择最优特征" 的特点。

CART 决策树解析

CART 树的结构体现了二叉树的优势:

  • 同样以 "有自己的房子" 为根节点,但基尼增益(0.2813)高于 ID3 的信息增益,说明该特征在降低节点不纯性上更有效(基尼指数对纯度变化更敏感)。
  • 中间节点 "有工作" 的基尼增益(0.8813)显著高于 ID3 的信息增益,表明 CART 算法在深层分裂中能更精准地捕捉特征价值。
  • 二叉树结构使决策路径更简洁,每个节点只有两个选择,减少了决策复杂度。

2. 性能评估深度解读

评估结果对比

两种算法在测试集上均达到 100% 准确率,这与数据集规模较小(测试集仅 5 个样本)有关,但从指标细节仍可发现差异:

  • 类别平衡性:"否(不贷款)" 有 4 个样本,"是(贷款)" 有 3 个样本,模型在两类上的精确率和召回率均为 1.0,说明无偏误。
  • 加权平均:CART 树的 F1 分数(1.0000)与 ID3 相同,但在实际更大的数据集上,CART 的二叉树结构通常更稳定。

结果局限性分析

  • 过拟合风险:小数据集上的完美表现可能是过拟合的信号,需通过交叉验证进一步验证。
  • 特征重要性:两种算法对特征的优先级排序一致(有房 > 有工作),但在特征更多的场景下可能出现差异。
  • 阈值敏感性:ID3 对离散特征的取值数量敏感(取值越多可能信息增益越大),而 CART 的二叉分裂可缓解这一问题。

五、决策树算法总结与工程实践指南

1. 两种算法的核心差异与适用场景

特性 ID3 决策树 CART 决策树
树结构 多叉树(特征有多少取值就有多少分支) 二叉树(始终分为两个分支)
特征选择准则 信息增益 基尼指数(分类)/ 方差(回归)
处理数据类型 仅离散特征 离散 + 连续特征
适用任务 分类 分类 + 回归
过拟合风险 较高(易受多值特征影响) 较低(二叉分裂 + 剪枝策略)
计算效率 较低(多分支分裂计算量大) 较高(二叉分裂简单高效)

适用场景建议:

  • 中小规模离散特征数据集 → 优先 ID3(解释性更强)。
  • 大规模混合特征数据集 → 优先 CART(效率更高)。
  • 回归任务 → 只能选择 CART。
  • 需严格控制模型复杂度 → 优先 CART(剪枝机制更成熟)。

2. 决策树工程实践技巧

(1)防止过拟合的关键策略:

  • 剪枝处理:预剪枝(限制树深度、最小分裂样本数)和后剪枝(移除泛化能力差的分支)。
  • 特征选择:通过方差过滤、互信息等方法减少冗余特征。
  • 集成学习:结合多个决策树(如随机森林、GBDT),通过投票或加权降低单棵树的过拟合风险。

(2)特征工程要点:

  • 连续特征离散化:ID3 需将连续特征分段(如年龄分为青年 / 中年 / 老年)。
  • 缺失值处理:可用特征均值 / 中位数填充,或在分裂时忽略缺失值样本。
  • 类别平衡:通过过采样(增加少数类样本)或欠采样(减少多数类样本)平衡数据集。

(3)可视化与模型解释:

  • 利用本文实现的可视化工具生成决策路径图,直观展示 "特征→决策" 的逻辑。
  • 通过特征重要性评分(如分裂次数、增益总和)解释模型关注的核心因素。
  • 对关键节点进行敏感性分析:轻微调整特征值是否会改变决策结果。

3. 决策树的局限性与未来方向

尽管决策树有诸多优势,但也存在明显局限:

  • 不稳定性。

小样本变动可能引起树结构显著变化(可通过集成学习缓解)

偏向多值属性:ID3 等算法倾向于选择取值较多的特征(需通过信息增益率修正)

难以捕捉特征交互:单个节点仅依赖一个特征,无法直接建模特征间的协同作用

未来发展方向包括:

  • 结合深度学习的树模型(如树增强神经网络)
  • 可解释性与性能兼备的混合模型
  • 面向大规模流数据的在线决策树算法
二维码

扫码加我 拉你入群

请注明:姓名-公司-职位

以便审核进群资格,未注明则拒绝

关键词:模型构建 决策树 classifier Evaluation Thresholds

您需要登录后才可以回帖 登录 | 我要注册

本版微信群
扫码
拉您进交流群
GMT+8, 2026-2-5 01:35