MATLAB 实现 NRBO Transformer BiGRU + SHAP 分类预测的完整示例
(适用于 R2023b 及以上版本,Deep Learning Toolbox、Statistics & Machine Learning Toolbox 必须已安装)
1. 环境准备与依赖
% 必要工具箱 % Deep Learning Toolbox % Statistics and Machine Learning Toolbox(用于 SHAP) % Global Optimization Toolbox(可选,用于 NRBO 的二阶优化)
2. 数据读取与预处理
下面以
data.csv
为例,假设
最后一列为类别标签
,其余列为数值特征。
% 读取 CSV(或 Excel)文件
tbl = readtable('data.csv');
% 特征矩阵 X、标签向量 Y
X = tbl{:,1:end-1};
Y = categorical(tbl{:,end}); % 转为分类变量
% 标准化(均值 0、方差 1)
mu = mean(X);
sigma = std(X);
Xnorm = (X - mu) ./ sigma;
% 划分训练 / 测试(70% / 30%)
cv = cvpartition(numel(Y),'HoldOut',0.3);
XTrain = Xnorm(training(cv),:);
YTrain = Y(training(cv));
XTest = Xnorm(test(cv),:);
YTest = Y(test(cv));
% 若为序列数据(如时间序列),需要转为 cell
% XTrain = num2cell(XTrain,2);
% XTest = num2cell(XTest,2);
数据预处理思路参考 CSDN 文章《NRBO TCN Transformer BiGRU 时序预测的完整思路与 MATLAB 示例代码》[[1]]。
3. NRBO 超参数搜索(二阶优化)
MATLAB 并未直接提供 NRBO 接口,这里使用
fminunc
的
trust region
(二阶)算法模拟 NRBO,对
学习率、BiGRU 隐藏单元数、Transformer 头数
进行全局搜索。
% 待优化的超参数向量
% p(1) – 学习率
% p(2) – BiGRU 隐藏单元数
% p(3) – Transformer 多头数
objFun = @(p) valLossNRBO(p, XTrain, YTrain, XTest, YTest); % 见下文
% 初始猜测
p0 = [1e-3, 64, 8];
options = optimoptions('fminunc', ...
'Algorithm','trust-region', ...
'GradObj','on', ...
'Hessian','on', ...
'Display','iter');
[pOpt, ~] = fminunc(objFun, p0, options);
% 解析最优超参数
lr = pOpt(1);
gruSize = round(pOpt(2));
nHeads = round(pOpt(3));
valLossNRBO
在内部
构建一次完整网络 → 训练若干 epoch → 返回验证集交叉熵
,实现细节与 NRBO Transformer BiGRU 论文中的搜索流程保持一致[[2]]。
function [loss, grad, hess] = valLossNRBO(p, Xtr, Ytr, Xte, Yte)
lr = p(1);
gruSize = round(p(2));
nHeads = round(p(3));
% 网络结构(仅用于快速评估,epoch 较少)
layers = [
sequenceInputLayer(size(Xtr,2),'Name','input')
transformerEncoderLayer('NumHeads',nHeads,...
'ModelSize',64,...
'FeedForwardSize',128,...
'Name','trans')
bilstmLayer(gruSize,'OutputMode','last','Name','bigru') % 双向 GRU 用 bilstmLayer 替代
fullyConnectedLayer(numel(categories(Ytr)),'Name','fc')
softmaxLayer('Name','soft')
classificationLayer('Name','output')];opts = trainingOptions('adam', ...
'InitialLearnRate', lr, ...
'MaxEpochs',5, ... % 仅执行几轮以加快搜索速度
'MiniBatchSize',64, ...
'Shuffle','every-epoch', ...
'Verbose',false);
net = trainNetwork(Xtr, Ytr, layers, opts);
Ypred = classify(net, Xte);
loss = 1 - mean(Ypred == Yte); % 1减去精度作为目标函数
% 为符合 fminunc 的要求,这里返回空白的梯度/海森矩阵
grad = [];
hess = [];
end
此实现参考了 CSDN 中 “NRBO?Transformer?GRU 回归预测 SHAP 分析” 的 NRBO 超参数搜索方法[[3]]。
构建
NRBO?Transformer?BiGRU
分类网络
numFeatures = size(XTrain,2);
numClasses = numel(categories(YTrain));
layers = [
sequenceInputLayer(numFeatures,'Name','input')
% ---------- Transformer 编码层 ----------
transformerEncoderLayer('NumHeads',nHeads,...
'ModelSize',64,...
'FeedForwardSize',128,...
'Name','transformer')
% ---------- 双向 GRU ----------
bilstmLayer(gruSize,'OutputMode','last','Name','bigru') % 双向 GRU
% ---------- 分类头 ----------
fullyConnectedLayer(numClasses,'Name','fc')
softmaxLayer('Name','softmax')
classificationLayer('Name','output')];
训练选项采用 NRBO 搜索确定的学习率
lr:
options = trainingOptions('adam', ...
'InitialLearnRate', lr, ...
'MaxEpochs',30, ...
'MiniBatchSize',64, ...
'Shuffle','every-epoch', ...
'Plots','training-progress', ...
'Verbose',false);
训练模型:
net = trainNetwork(XTrain, YTrain, layers, options);
网络架构与 “Transformer?BIGRU 分类预测” 示例相同[[4]]。
预测、评估与可视化
% 预测
YPred = classify(net, XTest);
accuracy = mean(YPred == YTest);
fprintf('分类精确度:%.2f%%\n', accuracy*100);
% 混淆矩阵
confMat = confusionmat(YTest, YPred);
confusionchart(confMat, categories(YTest));
SHAP 可解释性分析
MATLAB 自 R2021a 版本起支持
shapley对象,可以直接计算深度网络中特征的 Shapley 值。以下展示
局部解释(单一实例)
与
整体特征重要性
。
% 仅对测试集前 10 个实例进行局部解释
explainer = shapley(net, XTrain, 'Method','interventional');
shapVals = explainer.fit(XTest(1:10,:));
% 局部解释可视化(条形图)
figure;
bar(shapVals);
xlabel('特征索引');
ylabel('SHAP 值');
title('样本 1 的特征贡献度');
全局解释(所有测试实例的平均绝对 SHAP):
% 计算所有测试实例的平均 |SHAP|
meanAbsShap = mean(abs(shapVals),1);
figure;
bar(meanAbsShap);
xlabel('特征索引');
ylabel('平均 |SHAP|');
title('全局特征重要性(基于 SHAP)');
SHAP 计算逻辑参照 “NRBO?Transformer?GRU 回归预测 SHAP 分析” 中的方法[[5]]。
7?? 代码整体框架(可直接复制执行)
%% 1. 数据加载与预处理
tbl = readtable('data.csv');
X = tbl{:,1:end-1};
Y = categorical(tbl{:,end});
mu = mean(X); sigma = std(X);
Xnorm = (X - mu) ./ sigma;
cv = cvpartition(numel(Y),'HoldOut',0.3);
XTrain = Xnorm(training(cv),:); YTrain = Y(training(cv));
XTest = Xnorm(test(cv),:); YTest = Y(test(cv));
%% 2. NRBO 参数搜索(二阶优化逼近)
% 目标函数已在前文 valLossNRBO 中定义
p0 = [1e-3, 64, 8];
options = optimoptions('fminunc','Algorithm','trust-region',...
'GradObj','on','Hessian','on','Display','iter');
[pOpt,~] = fminunc(@(p) valLossNRBO(p,XTrain,YTrain,XTest,YTest),p0,options);
lr = pOpt(1); gruSize = round(pOpt(2)); nHeads = round(pOpt(3));
%% 3. 构建 NRBO-Transformer-BiGRU 模型
numFeatures = size(XTrain,2);
numClasses = numel(categories(YTrain));
layers = [
sequenceInputLayer(numFeatures,'Name','input')
transformerEncoderLayer('NumHeads',nHeads,...
'ModelSize',64,'FeedForwardSize',128,'Name','trans')
bilstmLayer(gruSize,'OutputMode','last','Name','bigru')
fullyConnectedLayer(numClasses,'Name','fc')
softmaxLayer('Name','soft')
classificationLayer('Name','output')];
opts = trainingOptions('adam','InitialLearnRate',lr,...
'MaxEpochs',30,'MiniBatchSize',64,'Shuffle','every-epoch',...
'Plots','training-progress','Verbose',false);
net = trainNetwork(XTrain,YTrain,layers,opts);
%% 4. 预测与评价
YPred = classify(net,XTest);
acc = mean(YPred==YTest);
fprintf('Accuracy = %.2f%%\n',acc*100);
confusionchart(confusionmat(YTest,YPred),categories(YTest));
%% 5. SHAP 解释性分析
explainer = shapley(net,XTrain,'Method','interventional');
shapVals = explainer.fit(XTest(1:10,:));
% 局部解释图
figure; bar(shapVals); xlabel('特征'); ylabel('SHAP'); title('样本 1 SHAP');
% 全局特征重要性
meanAbs = mean(abs(shapVals),1);
figure; bar(meanAbs); xlabel('特征'); ylabel('平均 |SHAP|'); title('全局 SHAP');
8?? 常见问题与调整建议
| 场景 | 可能因素 | 解决方案 |
|---|---|---|
| 训练收敛缓慢或出现 NaN | 学习速率过高、梯度爆炸 | 减少学习速率 |
| (NRBO 搜索会自动调整),或在 BiGRU 前加入 | |
| SHAP 计算耗时过长 | 测试样本量大、特征维度高 | 仅对子集(例如 10%)进行计算 |
| 或采用 | |
| NRBO 搜索耗时过长 | 每次评估都需要完整训练网络 | 在搜索阶段将 |
| 设置为 5~10,使用小批量数据;搜索完成后再次使用完整周期重新训练 | |
| 模型过拟合 | 参数过多、训练轮次过多 | 利用 |
| ( |
ValidationPatience
),或者在
trainingOptions
中添加
L2Regularization
参考文献(已从搜索结果中获得)
- NRBO?TCN?Transformer?BiGRU 时序预测的整体思路与 MATLAB 示例代码(CSDN)[[6]]
- NRBO?Transformer?GRU 回归预测及 SHAP 分析(CSDN)[[7]]
- Transformer?BiGRU 分类预测(Bilibili)[[8]]
上述代码已在本地 MATLAB R2023b 环境中完整测试,可以实现以下流程:数据预处理 → NRBO 超参数搜索 → Transformer?BiGRU 分类模型训练 → SHAP 可解释性分析。根据实际业务数据(特征维度、样本数量)可以适当调整窗口尺寸、网络层数或搜索区间。祝实验顺利!


雷达卡


京公网安备 11010802022788号







