# 算法可视化之CatBoost
**本系列为算法可视化的技术汇总,若需要了解算法本身的技术细节请移步算法拆解系列文章。**
### CatBoost简介
**CatBoost是由俄罗斯算法巨头Yandex开发的开源机器学习库。它是梯度提升决策树(GBDT)机器学习集成技术家族的重要成员,是在GBDT算法框架下的一种改进实现。我们知道,GBDT是大数据分析中分类和回归任务的强大工具。作为GBDT成员之一的CatBoost自2018年底首次亮相以来,研究人员已成功将其用于涉及大数据分析的各种机器学习项目。InfoWorld杂志曾将该库授予“最佳机器学习工具”。Kaggle将其列为世界上最常用的机器学习 (ML) 框架之一。它在 2020 年调查中被列为前 8 个最常用的 ML 框架,在 2021 年调查中被列为前 7 个最常用的 ML 框架。**
**CatBoost 作为基于决策树的算法,非常适合涉及分类、异构数据的机器学习任务。它提供了一个梯度提升框架,与经典算法相比,该框架尝试使用置换驱动的替代方案来解决分类特征。这一点从它的名字中可以看出来,CatBoost是由Categorical和Boosting组成。此外,CatBoost还解决了梯度偏差(Gradient Bias)以及预测偏移(Prediction shift)的问题,从而减少过拟合的发生,进而提高算法的准确性和泛化能力。**
**CatBoost的主要优点有:**
1. **在许多数据集上,与其他 GBDT 库相比,预测质量更高。**
2. **一流的预测速度。**
3. **同时支持数值和分类特征。**
4. **快速 GPU 和多 GPU 支持,开箱即用。**
5. **使用Apache Spark和CLI 进行快速且可重现的分布式训练。**
6. **包含直接可调用的可视化工具。**
**在机器学习领域,CatBoost是少数直接包含可视化工具的库之一,这也是该库的巨大优势。作为算法可视化的系列文章,我们来看一下如何调用CatBoost的可视化工具。**
### 代码实战
#### 训练模型
```
from catboost import CatBoostClassifier, Pool, metrics, cv # 导入库
category_feature = np.where(X.dtypes != float)[0] # 分类特征指示向量
model = CatBoostClassifier(
custom_metric=[metrics.Accuracy()], # 评估指标
) # 使用默认参数就可以获得性能非常好的模型
model.fit(
X_train, y_train,
cat_features= category_feature, # 分类特征指示
eval_set=(X_validation, y_validation),
plot=True # 重要参数,是否开启可视化工具
)
```
**这里的重要参数是plot = ture,将其打开后,就可调用CatBoost提供的非常炫酷的训练可视化功能,从下图可以看到Logloss正在不停的下降。另一个需要注意的参数是cat_features= ,CatBoost可以自动处理分类型特征,但是需要手动做一个用来指示哪个特征是分类型特征的指示向量,必须是array格式。上述代码生成的指示向量category_feature就是形如[ 0 1 2 3 5 6 ],表示训练集中这6个特征是分类特征。**
![1671171571337.png](/z_anli/upload/pgc/202212/19d93beec2c930ea27d22a5a563b0599.png)
**除了训练过程的可视化,CatBoost还提供了特征的统计信息的可视化包。若需要输出所有特征重要性的信息,则使用model.feature_importances_ 接口即可,和sklearn等机器学习库的用法相同。除此之外,CatBoost还提供了单个特征的统计信息,并且可以以可视化的形式呈现出来。数值型特征和分类型特征的输出结果有区别,我们分别来看**
#### 数值型特征
```
res = model.calc_feature_statistics(
X_train, y_train, feature= 'Fare', plot=True)
```
**生成的结果如下图,四个图例分别为**
1. **桶中的标签均值**
2. **桶中的预测值均值**
3. **桶中的样本数**
4. **对特征不同值的平均预测**
5.
![1671172537654.png](/z_anli/upload/pgc/202212/efd983af89dcac4d46610c2d5316f801.png)
#### 分类型特征
```
res = model.calc_feature_statistics(X_train, y_train, feature= 'Sex', plot=True)
```
![1671172603573.png](/z_anli/upload/pgc/202212/9635680a8c5e6c81b0248c5e837967a2.png)
评论(0)
暂无数据