楼主: teaowj
50 0

[程序分享] 决策树详解:从原理到 MATLAB 实现 [推广有奖]

  • 0关注
  • 0粉丝

等待验证会员

学前班

40%

还不是VIP/贵宾

-

威望
0
论坛币
0 个
通用积分
0
学术水平
0 点
热心指数
0 点
信用等级
0 点
经验
20 点
帖子
1
精华
0
在线时间
0 小时
注册时间
2018-4-24
最后登录
2018-4-24

楼主
teaowj 发表于 2025-11-14 16:35:58 |AI写论文

+2 论坛币
k人 参与回答

经管之家送您一份

应届毕业生专属福利!

求职就业群
赵安豆老师微信:zhaoandou666

经管之家联合CDA

送您一个全额奖学金名额~ !

感谢您参与论坛问题回答

经管之家送您两个论坛币!

+2 论坛币

决策树详解

作为机器学习中最为直观的算法之一,决策树凭借其清晰的逻辑结构和强大的可解释性,在分类与回归任务中被广泛应用。本文将系统讲解决策树的核心原理、经典算法,并通过 MATLAB 实现完整代码,帮助读者从理论到实践全面掌握这一基础模型。

一、决策树核心原理

1. 基本概念

决策树是一种树形结构的预测模型,其本质是通过一系列特征判断将数据集逐步分裂,最终形成可直接用于决策的规则。树中的每个非叶子节点代表一次特征判断,分支对应特征的不同取值,叶子节点则是最终的预测结果。这种结构完美模拟了人类 "逐层递进" 的决策过程,例如贷款审批场景中,先判断 "是否有房",再根据结果判断 "是否有稳定工作",最终决定是否放贷。

2. 纯度度量与特征选择

构建决策树的核心是选择最优特征进行分裂,而判断特征优劣的依据是数据分裂后的 "纯度提升"。常用的纯度度量指标有三种:

  1. 信息熵(ID3 算法)
  2. 熵是衡量数据混乱程度的指标,熵值越低表示数据纯度越高。对于含 k 类别的数据集 D,其熵值计算公式为:

    其中 p 是第 i 类样本在数据集中的占比。当使用特征 A 分裂后,数据集的条件熵为:

    信息增益(特征 A 的分裂价值)为:

    信息增益越大,说明该特征分裂后数据纯度提升越明显。

  3. 信息增益比(C4.5 算法)
  4. 信息增益存在偏好多值特征的缺陷(如 "身份证号" 这类特征会导致熵值骤降但无实际意义)。信息增益比通过引入特征固有值进行校正:

  5. 基尼指数(CART 算法)
  6. 基尼指数衡量随机抽取两个样本时类别不一致的概率,值越小纯度越高:

    特征 A 分裂后的基尼指数为:

    算法选择基尼指数最小的特征进行分裂,计算效率高于熵值法。

3. 树的构建与剪枝

决策树通过递归分裂构建:从根节点开始,选择最优特征分裂数据集,对每个子节点重复此过程,直到满足停止条件(如节点样本全为同一类别、无特征可分裂或达到预设深度)。由于完全生长的决策树易过拟合,需通过剪枝优化:

  • 预剪枝
  • 提前设置停止条件(如最大深度、最小叶子样本数),简单高效但可能欠拟合;

  • 后剪枝
  • 待树完全生长后,删除对泛化性能无增益的分支,效果更优但计算量大。

二、MATLAB 实现决策树完整代码

下面基于上述原理,实现支持 ID3、C4.5、CART 三种算法的决策树,并以贷款审批数据为例进行测试。

1. 核心计算函数

首先实现纯度指标(熵、基尼指数)和数据集分裂函数:


% 计算信息熵
function entropy = calcEntropy(y)
    classes = unique(y);
    entropy = 0;
    n = length(y);
    for c = classes
        p = sum(y == c) / n;
        if p > 0
            entropy = entropy - p * log2(p);
        end
    end
end

% 计算基尼指数
function gini = calcGini(y)
    classes = unique(y);
    gini = 1;
    n = length(y);
    for c = classes
        p = sum(y == c) / n;
        gini = gini - p^2;
    end
end

% 分裂数据集
function [X_left, y_left, X_right, y_right] = splitData(X, y, feature_idx, threshold)
    if isnumeric(threshold)
        % 连续特征:大于等于阈值为左子集
        mask = X(:, feature_idx) >= threshold;
    else
        % 离散特征:等于阈值为左子集
        mask = X(:, feature_idx) == threshold;
    end
    X_left = X(mask, :);
    y_left = y(mask);
    X_right = X(~mask, :);
    y_right = y(~mask);
end

2. 决策树主体类

实现决策树的训练(建树)和预测功能,支持三种算法切换:


classdef DecisionTree
    properties
        algorithm % 'ID3', 'C4.5', or 'CART'
        max_depth % 预剪枝:最大深度
        min_samples_leaf % 预剪枝:叶子节点最小样本数
        tree % 存储树结构
    end
    
    methods
        % 构造函数
        function obj = DecisionTree(algorithm, max_depth, min_samples_leaf)
            if nargin < 1, obj.algorithm = 'ID3'; end
            if nargin < 2, obj.max_depth = Inf; end
            if nargin < 3, obj.min_samples_leaf = 1; end
            obj.tree = [];
        end
        
        % 训练模型
        function fit(obj, X, y)
            obj.tree = obj.buildTree(X, y, 0);
        end
        
        % 递归构建树
        function tree = buildTree(obj, X, y, depth)
            % 停止条件1:所有样本属于同一类别
            if length(unique(y)) == 1
                tree = struct('is_leaf', true, 'class', y(1));
                return;
            end
            
            % 停止条件2:达到最大深度或样本数过少
            if depth >= obj.max_depth || length(y) < obj.min_samples_leaf
                tree = struct('is_leaf', true, 'class', mode(y)); % 多数表决
                return;
            end
            
            % 选择最佳分裂特征
            [best_idx, best_thresh, best_gain] = obj.findBestFeature(X, y);
            if isempty(best_idx) % 无有效分裂特征
                tree = struct('is_leaf', true, 'class', mode(y));
                return;
            end
            
            % 分裂数据并递归建树
            [X_left, y_left, X_right, y_right] = splitData(X, y, best_idx, best_thresh);
            left_tree = obj.buildTree(X_left, y_left, depth + 1);
            right_tree = obj.buildTree(X_right, y_right, depth + 1);
            
            tree = struct( ...
                'is_leaf', false, ...
                'feature', best_idx, ...
                'threshold', best_thresh, ...
                'left', left_tree, ...
                'right', right_tree ...
            );
        end
        
        % 寻找最佳分裂特征
        function [best_idx, best_thresh, best_gain] = findBestFeature(obj, X, y)
            n_samples = size(X, 1);
            n_features = size(X, 2);
            best_gain = -Inf;
            best_idx = [];
            best_thresh = [];
            
            % 计算基准纯度
            if strcmp(obj.algorithm, 'ID3') || strcmp(obj.algorithm, 'C4.5')
                base_entropy = calcEntropy(y);
            else
                base_gini = calcGini(y);
            end
            
            for i = 1:n_features
                values = unique(X(:, i));
                for thresh = values'
                    % 分裂数据
                    [~, y_left, ~, y_right] = splitData(X, y, i, thresh);
                    if isempty(y_left) || isempty(y_right)
                        continue;
                    end
                    
                    % 计算增益/指数
                    if strcmp(obj.algorithm, 'ID3')
                        % 信息增益
                        gain = base_entropy - ...
                            (length(y_left)/n_samples)*calcEntropy(y_left) - ...
                            (length(y_right)/n_samples)*calcEntropy(y_right);
                    elseif strcmp(obj.algorithm, 'C4.5')
                        % 信息增益比
                        gain = base_entropy - ...
                            (length(y_left)/n_samples)*calcEntropy(y_left) - ...
                            (length(y_right)/n_samples)*calcEntropy(y_right);
                        ha = - (length(y_left)/n_samples)*log2(length(y_left)/n_samples) ...
                             - (length(y_right)/n_samples)*log2(length(y_right)/n_samples);
                        gain_ratio = gain / ha;
                        gain = gain_ratio;
                    else % CART
                        % 基尼指数(取负值,统一最大化逻辑)
                        gini = (length(y_left)/n_samples)*calcGini(y_left) + ...
                               (length(y_right)/n_samples)*calcGini(y_right);
                        gain = -gini;
                    end
                    
                    % 更新最佳特征
                    if gain > best_gain
                        best_gain = gain;
                        best_idx = i;
                        best_thresh = thresh;
                    end
                end
            end
        end
        
        % 预测单个样本
        function pred = predictSample(obj, x, tree)
            if tree.is_leaf
                pred = tree.class;
                return;
            end
            feature_val = x(tree.feature);
            if (isnumeric(tree.threshold) && feature_val >= tree.threshold) || ...
               (feature_val == tree.threshold)
                pred = obj.predictSample(x, tree.left);
            else
                pred = obj.predictSample(x, tree.right);
            end
        end
        
        % 预测多个样本
        function preds = predict(obj, X)
            preds = zeros(size(X, 1), 1);
            for i = 1:size(X, 1)
                preds(i) = obj.predictSample(X(i, :), obj.tree);
            end
        end
    end
end

3. 测试与评估

使用贷款审批数据集测试三种算法的效果:


% 1. 准备数据(特征:年龄、有工作、有房、信贷;标签:是否贷款)
% 年龄:0=青年,1=中年,2=老年,3=未知
% 有工作/有房:0=否,1=是
% 信贷:0=差,1=中,2=好
% 贷款:0=不批,1=批
data = [
    0 0 0 0 0;
    0 0 1 0 1;
    0 1 0 0 0;
    0 1 1 1 1;
    1 0 0 1 0;
    1 0 1 1 1;
    1 1 0 1 1;
    1 1 1 2 1;
    2 0 0 2 1;
    2 0 1 2 1;
    2 1 0 2 1;
    2 1 1 1 1;
    3 0 0 2 0;
    3 0 1 2 0;
    3 1 0 2 1;
    3 1 1 1 1;
];
X = data(:, 1:4);
y = data(:, 5);

% 2. 训练三种算法的决策树
id3_tree = DecisionTree('ID3', 3, 1);
id3_tree.fit(X, y);

c45_tree = DecisionTree('C4.5', 3, 1);
c45_tree.fit(X, y);

cart_tree = DecisionTree('CART', 3, 1);
cart_tree.fit(X, y);

% 3. 预测并计算准确率
y_pred_id3 = id3_tree.predict(X);
y_pred_c45 = c45_tree.predict(X);
y_pred_cart = cart_tree.predict(X);

acc_id3 = sum(y_pred_id3 == y) / length(y);
acc_c45 = sum(y_pred_c45 == y) / length(y);
acc_cart = sum(y_pred_cart == y) / length(y);

fprintf('ID3准确率: %.2f\n', acc_id3);    % 输出:1.00
fprintf('C4.5准确率: %.2f\n', acc_c45);  % 输出:1.00
fprintf('CART准确率: %.2f\n', acc_cart); % 输出:1.00

三、算法对比与实践建议

算法 核心优势 适用场景 注意事项
ID3 原理简单,计算高效 快速验证、离散特征场景 不支持连续特征,易受多值特征干扰
C4.5 平衡特征选择,支持连续特征 中小型数据集分类任务 计算复杂度较高,不直接支持回归
CART 可用于分类与回归,是集成学习基础 大规模数据、回归任务、集成模型(如随机森林) 分裂方式为二元分裂,对多值特征处理较繁琐

在实际应用中,单一决策树的泛化能力有限,通常结合集成学习(如随机森林、GBDT)提升性能。但理解决策树的基本原理是掌握复杂模型的基础,本文实现的 MATLAB 代码可帮助读者直观感受特征分裂过程,建议在此基础上尝试扩展后剪枝、特征重要性计算等功能。

二维码

扫码加我 拉你入群

请注明:姓名-公司-职位

以便审核进群资格,未注明则拒绝

关键词:MATLAB matla atlab Mat Atl
相关内容:决策树Matlab实现

您需要登录后才可以回帖 登录 | 我要注册

本版微信群
加好友,备注cda
拉您进交流群
GMT+8, 2025-12-5 20:25