算法可视化之CatBoost

CDA老师1

Python 数据
# 算法可视化之CatBoost **本系列为算法可视化的技术汇总,若需要了解算法本身的技术细节请移步算法拆解系列文章。** ### CatBoost简介 **CatBoost是由俄罗斯算法巨头Yandex开发的开源机器学习库。它是梯度提升决策树(GBDT)机器学习集成技术家族的重要成员,是在GBDT算法框架下的一种改进实现。我们知道,GBDT是大数据分析中分类和回归任务的强大工具。作为GBDT成员之一的CatBoost自2018年底首次亮相以来,研究人员已成功将其用于涉及大数据分析的各种机器学习项目。InfoWorld杂志曾将该库授予“最佳机器学习工具”。Kaggle将其列为世界上最常用的机器学习 (ML) 框架之一。它在 2020 年调查中被列为前 8 个最常用的 ML 框架,在 2021 年调查中被列为前 7 个最常用的 ML 框架。** **CatBoost 作为基于决策树的算法,非常适合涉及分类、异构数据的机器学习任务。它提供了一个梯度提升框架,与经典算法相比,该框架尝试使用置换驱动的替代方案来解决分类特征。这一点从它的名字中可以看出来,CatBoost是由Categorical和Boosting组成。此外,CatBoost还解决了梯度偏差(Gradient Bias)以及预测偏移(Prediction shift)的问题,从而减少过拟合的发生,进而提高算法的准确性和泛化能力。** **CatBoost的主要优点有:** 1. **在许多数据集上,与其他 GBDT 库相比,预测质量更高。** 2. **一流的预测速度。** 3. **同时支持数值和分类特征。** 4. **快速 GPU 和多 GPU 支持,开箱即用。** 5. **使用Apache Spark和CLI 进行快速且可重现的分布式训练。** 6. **包含直接可调用的可视化工具。** **在机器学习领域,CatBoost是少数直接包含可视化工具的库之一,这也是该库的巨大优势。作为算法可视化的系列文章,我们来看一下如何调用CatBoost的可视化工具。** ### 代码实战 #### 训练模型 ``` from catboost import CatBoostClassifier, Pool, metrics, cv # 导入库 category_feature = np.where(X.dtypes != float)[0]   # 分类特征指示向量 model = CatBoostClassifier(   custom_metric=[metrics.Accuracy()],   # 评估指标 )   # 使用默认参数就可以获得性能非常好的模型 model.fit(   X_train, y_train,   cat_features= category_feature,   # 分类特征指示   eval_set=(X_validation, y_validation),   plot=True   # 重要参数,是否开启可视化工具 ) ``` **这里的重要参数是plot = ture,将其打开后,就可调用CatBoost提供的非常炫酷的训练可视化功能,从下图可以看到Logloss正在不停的下降。另一个需要注意的参数是cat_features= ,CatBoost可以自动处理分类型特征,但是需要手动做一个用来指示哪个特征是分类型特征的指示向量,必须是array格式。上述代码生成的指示向量category_feature就是形如[ 0 1 2 3 5 6 ],表示训练集中这6个特征是分类特征。** ![1671171571337.png](/z_anli/upload/pgc/202212/19d93beec2c930ea27d22a5a563b0599.png) **除了训练过程的可视化,CatBoost还提供了特征的统计信息的可视化包。若需要输出所有特征重要性的信息,则使用model.feature_importances_ 接口即可,和sklearn等机器学习库的用法相同。除此之外,CatBoost还提供了单个特征的统计信息,并且可以以可视化的形式呈现出来。数值型特征和分类型特征的输出结果有区别,我们分别来看** #### 数值型特征 ``` res = model.calc_feature_statistics(     X_train, y_train, feature= 'Fare', plot=True) ``` **生成的结果如下图,四个图例分别为** 1. **桶中的标签均值** 2. **桶中的预测值均值** 3. **桶中的样本数** 4. **对特征不同值的平均预测** 5. ![1671172537654.png](/z_anli/upload/pgc/202212/efd983af89dcac4d46610c2d5316f801.png) #### 分类型特征 ``` res = model.calc_feature_statistics(X_train, y_train, feature= 'Sex', plot=True) ``` ![1671172603573.png](/z_anli/upload/pgc/202212/9635680a8c5e6c81b0248c5e837967a2.png)
0.2224 4 0 关注作者 收藏 2022-12-16   阅读量: 836

评论(0)


暂无数据

博客推荐