大咖直播-鸿蒙原生开发与智能提效实战!>>> 了解详情
写点什么

如何使用半监督学习为结构化数据训练出更好的深度学习模型

  • 2020-10-22
  • 本文字数:2368 字

    阅读完需:约 8 分钟

如何使用半监督学习为结构化数据训练出更好的深度学习模型

本文最初发表于 Towards Data Science 博客,经原作者 Youness Mansar 授权,InfoQ 中文站翻译并分享。


众所周知,深度学习在应用于文本、音频或图像等非结构化数据时效果很好,但在应用于结构化或表格化数据时,深度学习有时会落后于其他机器学习方法,如梯度提升等。在本文中,我们将使用半监督学习来提高深度神经模型在低数据环境下应用于结构化数据时的性能。我们将展示通过使用无监督的预训练,可以使神经模型的性能优于梯度提升。


本文是基于以下两篇论文:



我们实现了一个类似于 AutoInt 论文中提出的深度神经结构,使用了多头自注意力和特征嵌入。预训练部分取自 TabNet 的论文。

方法说明

我们将处理结构化数据,这意味着可以将数据写成具有列(数字、分类、序号)和行的表。我们还假设我们有大量的未标记样本,可以用于预训练,以及少量的标记样本,可用于监督学习。在接下来的实验中,我们将模拟这个环境来绘制学习曲线,并在使用不同大小的标记集时对该方法进行评估。

数据准备

让我们用一个例子来描述在将数据提供给神经网络之前我们是如何准备数据的。



在这个例子中,我们有三个样本和三个特征 {F1,F2,F3} 和一个目标。F1 是分类特征,而 F2 F3 是数字特征。


我们将为 F1 的每个模态 X 创建一个新特征 F1_X,如果 F1==X,则为其赋值 1,否则等于 0。


转换后的样本将写入一组 (Feature_Name, Feature_Value)


例如:


第一个样本 → {(F1_A, 1), (F2, 0.3), (F3, 1.3)}


第二个样本 → {(F1_B, 1), (F2, 0.4), (F3, 0.9)}


第三个样本 → {(F1_C, 1), (F2, 0.1), (F3, 0.8)}


特征名称将被馈送到嵌入层,然后与特征值相乘。

模型:

这里使用的模型是一个多头注意力块序列和逐点前馈层。在训练时,我们也使用池化的注意力跳过连接。多头注意力模块允许我们对特征之间可能存在的交互进行建模,而池化的注意力跳过连接允许我们从一组特征嵌入中获得单个向量。


预训练

在预训练步骤中,我们使用完整的未标记数据集,输入特征的损坏版本,并训练模型来预测未损坏的特征,类似于在去噪自动编码器中所做的操作。

监督式训练

在训练的监督部分,我们在编码器部分和输出端之间添加跳过连接,并尝试预测目标。


实验

在接下来的实验中,我们将使用四个数据集,其中两个用于回归,两个用于分类。


  • Sarco:有大约 5 万个样本,21 个特征和 7 个连续目标。

  • Online News:有 4 万个左右的样本,61 个特征和 1 个连续目标。

  • Adult Census:有大约 4 万个样本、15 个特征和 1 个二元目标。

  • Forest Cover:有大约 50 万个样本,54 个特征和 1 个分类目标。


我们将比较一个预训练神经模型和一个从零开始训练的神经模型,将重点关注地数据状态下的性能,这意味着几百到几千个标记样本。我们还将于一个流行的名为lightgbm的梯度提升实现进行比较。

Forest Cover:

Adult Census:


对于这个数据集,我们可以看到,如果训练集小于 2000,那么预训练是非常有效的。

Online News:

对于 Online News 数据集,我们可以看到,预训练神经网络是非常有效的,甚至在所有样本大小为 500 或更大的情况下都超过了梯度提升。



对于 Sarco 数据集,我们可以看到,预训练神经网络是非常有效的,甚至在所有样本大小的情况下超过了梯度提升。


旁注:用于重现结果的代码

重现结果的代码可以在这里找到:


https://github.com/CVxTz/DeepTabular


使用这段代码,你可以很轻松地训练分类或回归模型:


import pandas as pdfrom sklearn.model_selection import train_test_splitfrom deeptabular.deeptabular import DeepTabularClassifierif __name__ == "__main__":data = pd.read_csv("../data/census/adult.csv")train, test = train_test_split(data, test_size=0.2, random_state=1337)target = "income"num_cols = ["age", "fnlwgt", "capital.gain", "capital.loss", "hours.per.week"]cat_cols = ["workclass","education","education.num","marital.status","occupation","relationship","race","sex","native.country",]for k in num_cols:mean = train[k].mean()std = train[k].std()train[k] = (train[k] - mean) / stdtest[k] = (test[k] - mean) / stdtrain[target] = train[target].map({"<=50K": 0, ">50K": 1})test[target] = test[target].map({"<=50K": 0, ">50K": 1})classifier = DeepTabularClassifier(num_layers=10, cat_cols=cat_cols, num_cols=num_cols, n_targets=1,)classifier.fit(train, target_col=target, epochs=128)pred = classifier.predict(test)classifier.save_config("census_config.json")classifier.save_weigts("census_weights.h5")new_classifier = DeepTabularClassifier()new_classifier.load_config("census_config.json")new_classifier.load_weights("census_weights.h5")new_pred = new_classifier.predict(test)
复制代码

结论

在计算机视觉或自然语言领域,无监督预训练可以提高神经网络的性能。在本文中,我们展示了它在应用于结构化数据时也能起作用,使其在低数据环境与其他机器学习方法(如梯度提升)具有竞争力。


作者简介:


Youness Mansar,供职于 Fortia Financial Solutions 的数据科学家。巴黎中央理工学院(Ecole Centrale Paris)应用数学硕士学位和巴黎-萨克雷高等师范学校(École normale supérieure Paris-Saclay)机器学习硕士。作为 Fortia 的数据科学家,曾参与过多个涉及自然语言处理和深度学习的项目。


原文链接:


https://towardsdatascience.com/training-better-deep-learning-models-for-structured-data-using-semi-supervised-learning-8acc3b536319


2020-10-22 09:002742
用户头像
刘燕 InfoQ高级技术编辑

发布了 1112 篇内容, 共 599.2 次阅读, 收获喜欢 1982 次。

关注

评论

发布
暂无评论
发现更多内容

耗时5小时,用低代码搭了2套应用,我才明白它为什么能火了

优秀

低代码 低代码开发 低代码开发平台 低代码平台

中国区块链产业全景图

CECBC

技术应用

GitHub开源的中国亲戚关系计算器

不脱发的程序猿

GitHub 开源 程序员 4月日更 中国亲戚关系

飞桨中国行落地合肥,与当地企业共话产业智能化升级

百度大脑

智能化 飞桨中国行

浙江宁波市区块链研究机构发布首个全国性公证联盟运营链

CECBC

区块链

迪安精选:那些好用的浏览器扩展

迪安

浏览器 插件 扩展

阿里码农肝了2晚,整理的Java语法总结,网友:考试复习全靠它了

飞飞JAva

NumPy之:数据类型对象dtype

程序那些事

Python 数据分析 Numpy 程序那些事

软件 IT 专业大学生职业方向情况调查

李孟聊AI

大学生日常 IT 大学生

人类视觉神经科学助力音视频产业革命-弱网下的极限实时通信

张音乐

音视频 笔记 弱网下的极限实时视频通信

SpringSecurity+JWT认证流程解析

学Java关注我

Java 编程 程序人生 计算机 架构】

云图说|ModelArts Pro,为企业级AI应用打造的专业开发套件

华为云开发者联盟

AI 企业应用 ModelArts Pro 开发套件

ceph-csi源码分析(3)-rbd driver-服务入口分析

良凯尔

Kubernetes 源码分析 Ceph CSI

ceph-csi源码分析(4)-rbd driver-controllerserver分析

良凯尔

Kubernetes 源码分析 Ceph CSI

vue+webpack+vue-cli

Vue js 打包 webpack vuecli

流水线成功涨薪到年薪30W 只有努力才能成功

学Java关注我

Java 架构 程序人生 编程语言

一文带你了解华为云GaussDB的五大黑科技

华为云开发者联盟

数据库 华为云 GaussDB(for Influx) 时间线 tpmC

让宝妈宝爸告别安全顾虑,区块链构建母婴行业新生态

CECBC

母婴

Boss直聘转发超100W次Java面试突击手册 火遍全网

比伯

Java 编程 程序员 架构 计算机

群英荟萃 | UINO优锘科技ThingJS平台亮相华为开发者大会

ThingJS数字孪生引擎

物联网 3D可视化 数字孪生

Python3 print变量打印输出功能后面隐含的几个知识点

老猿Python

Python print str repr

LeetCode题解:191. 位1的个数,位运算,JavaScript,详细注释

Lee Chen

算法 大前端 LeetCode

yarn的applicationMaster介绍

五分钟学大数据

YARN

量化策略倍投系统搭建,马丁策略交易

CloudQuery v1.3.7版本更新,新增「导出限制」

BinTools图尔兹

数据库 sql 数据安全 数据库管理

Faiss源码剖析:类结构分析

华为云开发者联盟

机器学习 KNN Faiss 类结构 Quantizer

抵制羊毛党,图计算“加持”互联网电商风控

华为云开发者联盟

风控 图计算 互联网电商 羊毛党

让电影票房飞一会儿,五一换个姿势重温经典

华为云开发者联盟

音视频 电影修复 视频超分 媒体处理 混合失真

图的学习总结

Nick

数据结构 数据结构与算法

Kubernetes 上如何控制容器的启动顺序?

张晓辉

Kubernetes istio

uni-app rtm插件集成指南及常见问题--iOS

anyRTC开发者

uni-app ios 音视频 WebRTC sdk

如何使用半监督学习为结构化数据训练出更好的深度学习模型_AI&大模型_Youness Mansar_InfoQ精选文章