DeepGBM 针对在线任务的GBDT2NN

CatNN、GBDT2NN

Posted by Jiayue Cai on August 21, 2019

Last updated on 2019-09-29…

《DeepGBM: A Deep Learning Framework Distilled by GBDT for Online Prediction Tasks》,发表在 KDD 2019

视频代码

介绍

CTR数据特征:

  • 通常包含离散和连续两种特征
  • 数据会实时增加,分布随之变化

问题的思考出发点:

  • GBDT:适合处理连续特征,但不善于处理大量离散特征,且每次全量数据训练。
  • NN:适合处理离散类别特征(要embedding),但不善于处理连续特征
  • GBDT+NN:由于GBDT的存在,依然难以在线训练

所以作者设计了DeepGBM,包含两部分:

  • CatNN:面向离散类别特征,实际就是DeepFM
  • GBDT2NN:面向连续数值特征,主要做了两件事:
    • 对gbdt生成的tree分组,然后用NN拟合叶子的index
    • 把叶子的index作embedding,单层全连接后sigmoid输出

离散部分 CatNN

实际就是DeepFM

连续部分 GBDT2NN

这部分主要做了三件事:

  • 特征选择:只使用tree中选定的特征
  • 结构蒸馏:使用NN拟合tree的叶子的index
  • 决策树输出:把叶子index映射到value,得到tree的输出

单树蒸馏

多树蒸馏

Tree Grouping

NN对应的一棵树转化代价太大,优化方法:降低蒸馏多颗tree的代价

实验

离线

上式右侧为树组T的嵌入损失,α、β为超参数

几种模型的效果(分类问题利用AUC,回归问题利用MSE)对比如下

离线实验表明:

  • 基于 GBDT 的模型较 NN 模型有更好的表现
  • GBDT2NN 在 GBDT 的基础上还有一定提升
  • DeepGBM 的表现比所有的baseline模型都要好

在线

在线时就不涉及GBDT,也不重新训练了

在线实验表明:

  • DeepGBM 契合在线场景,且效果比所有的baseline模型都要好

个人吐槽点

  • CatNN整个就是原来的DeepFM
  • online实验时的loss根本不涉及gbdt,也不涉及gbdt的更新