写点什么

更小、更快、更便宜、更轻量:开源 DistilBERT,BERT 的精简版本

  • 2019-08-30
  • 本文字数:4284 字

    阅读完需:约 14 分钟

更小、更快、更便宜、更轻量:开源DistilBERT,BERT的精简版本

感兴趣的朋友可以点击此处获取重现 DistilBERT 训练以及 DistilBERT 预训练权重的代码。


过去 18 个月以来,大规模语言模型的迁移学习可谓遍地开花,在几乎所有自然语言处理任务当中都实现了显著的性能改进。


作为以 Vaswani 等人的 Transformer 架构为基础的解决方案,这些经过预训练的语言模型正变得日益庞大,并仍在立足更大的数据集进行训练。英伟达公司的最新模型拥有多达 83 亿个参数:24 倍于 BERT,5 倍于 GPT-2;而 Facebook AI 拿出的 RoBERTa 则利用 160 GB 文本训练而成。



社区中的不少从业者开始怀疑,到底有没有必要训练这些越来越臃肿的 Transformer,毕竟其在训练的经济与环境成本方面已经呈现出失控的状态。通过上图,我们一起来看部分最新大型模型及其参数数量(单位为百万)。


在 Hugging Face,我们亲身体验了这些模型高涨的人气,因为我们的 NLP 库(打包了其中大部分模型)在短短几个月之内就得到超过 40 万次安装。


然而,随着这些模型进入更大的社区,一个重要甚至说极具挑战性的问题开始出现——我们该如何把这些庞然大物投入生产?我们如何在低延迟约束条件下使用这类大型模型?我们是否需要昂贵的 GPU 服务器才能实现大规模服务?


对于许多研究人员及开发人员而言,这可能是个最现实的问题。


为了构建更尊重隐私的系统,我们意识到有必要在边缘位置运行机器学习系统,从而尽可能避免以调用云 API 的方式将个人数据发送至服务器端。这就意味着,我们需要能够在智能手机等小型设备上运行轻、反应灵敏且资源需求量较低的模型版本!


最后但同样重要的是,我们也越来越关注模型扩展过程当中,严苛计算资源需求所带来的环境成本。


那么,我们该如何帮助这些庞然大物成功瘦身?


不少现有技术都有望解决问题。最常见的工具包括量化(对准确率影响较小的网络权重进行近似化)以及权重修剪(删除网络中的某些连接)。对于此类技术,推荐大家参阅 Rasa 发布的BERT量化博文


但我们最终决定专注于模型蒸馏:这是一种能够将大型模型(被称为「老师」)压缩为较小模型(即「学生」)的技术。

知识蒸馏:迁移泛化能力

知识蒸馏(有时也称为师生学习)是一种压缩技术,要求对小型模型进行训练,以使其拥有类似于大型模型(或者模型集合)的行为特征。这项技术由 Bucila 等人提出,并得到了 Hinton 等人的推广。我们这里采用的,正是 Hinton 采取的方法。


在监督学习领域,我们在训练分类模型时往往会利用对数似然信号实现概率最大化(logits 的 softmax),进而预测出正确类。在大多数情况下,性能良好的模型能够利用具有高概率的正确类预测输出分布,同时其它类的发生概率则接近于零。


但是,某些“接近于零”的概率要比其它概率更大,这在一定程度上反映出模型的泛化能力。


例如,把普通椅子误认为扶手椅虽然属于错误,但这种错误远比将其误认为蘑菇来得轻微。这种不确定性,有时被称为“暗知识”。


我们也可以从另一个角度来理解蒸馏——用于防止模型对预测结果太过确定(类似于标签平滑)。


以下为具体实例。在语言建模当中,我们可以通过查看词汇表中的分布轻松观察到这种不确定性。下图为 BERT 对《卡萨布兰卡》电影当中经典台词下一句用词的猜测:



BERT 提出的 20 大高概率用词猜测结果。语言模型确定了两个可能性最高的选项(day 与 life),接下来的词汇相比之下概率要低得多。

我们如何复制这些“暗知识”?

在师生训练当中,我们训练学生网络,用于模拟老师网络的全部输出分布(也就是知识)。


我们通过匹配输出分布的方式训练学生网络,从而实现与老师网络相同的泛化方式。


我们并没有在硬目标上使用交叉熵训练(正确类的独热编码),而是通过软目标(老师概率)将交叉熵从老师处传递给学生。我们的训练损失因此变为:



其中 t 为来自老师的 logit,s 为学生的 logit。


这个损失函数属于更丰富的训练信号,因为单一示例要比单一硬目标拥有更高的强制约束效果。


为了进一步揭示分类结果的质量,Hinton 等人提出了 softmax 温度的概念:



T 为该温度参数。


T → 0 时,分布变为 Kronecker(相当于独热目标矢量);当 T →+∞时,则变为均匀分布。在训练过程中,将相同的温度参数应用于学生与老师网络,即可进一步为每个训练示例揭示更多信号。在推论当中,T 被设置为 1 以恢复标准 Softmax。

PyTorch 编码——压缩 BERT

我们希望利用蒸馏方法对大型语言模型加以压缩。在蒸馏方面,我们使用 Kullback-Leibler 损失函数,因为其拥有相同的优化效果:



在计算关于 q(学生网络分布)的梯度时,我们获得了相同的梯度结果。我们可以利用 PyTorch 实现加快计算速度:



PyTorch 中的知识蒸馏训练步骤。点击此处复制 gist。


利用老师信号,我们能够训练出一套较小的语言模型,我们称之为 DistilBERT,属于 BERT 的监督产物(我们使用 BERT 的英文 bert-base-uncased 版本)。


根据 Hinton 等人的发现,训练损失函数属于蒸馏损失与 masked 语言建模损失的线性组合。我们的学生网络属于 BERT 的一套小型版本,其中删除了 token-type 嵌入与 pooler(用于下一句分类任务),但其余部分架构保持不变,而层数也减少至原本的二分之一。


总体而言,我们的蒸馏模型 DistilBERT 在总体参数数量上约为 BERT 的一半,但在 GLUE 语言理解基准测试中能够保留 95%的 BERT 性能表现。


注 1 — 为什么不降低隐藏层的大小?

将 768 层减至 512 层,意味着总参数量约下降至原本的二分之一。希,在现代框架当中,大多数运算都经过高度优化,而且张量的最终维度(隐藏维度)的变化会对 Transformer 架构(线性分层与层规范化)中的大部分运算产生小幅影响。在我们的实验中,层数对于推理时间的影响要远高于隐藏层的大小。

因此,更小并不代表着一定更快……

注 2 — Tang 等人在蒸馏工作当中,直接在下游任务内使用 L2 距离作为蒸馏损失

我们的早期实验结果表明,在本案例中,交叉熵损失会明显提高性能水平。我们假定在语言建模设置当中,输出空间(词汇表)要明显大于下游任务输出空间的维度。因此,logits 可以在 L2 损失中相互补偿。


训练子网络的核心不只是建立架构,还要求我们为子网络找到正确的初始化方式以实现收敛。因此,我们以作为老师的 Bert 为基础对学生 DistilBERT 进行初始化,将层数削减一半,并采用相同的隐藏大小。


我们还用到了最近 RoBERTa 论文当中提到的一些训练技巧,这也再次证明 BERT 模型的训练方式对其最终表现有着至关重要的影响。与 RoBERTa 类似,我们对 DIstilBERT 进行大批次训练,使用梯度累积(每批最多 4000 个例子)、配合动态遮挡并删除了下一句预测目标。


我们在训练设置中对资源进行了主动限制。我们利用多伦多图书语料库与英语维基百科的串联数据集(与原始 BERT 相同),并配合八块 16 GB V100 GPU 进行了约三天半的训练。


DistilBert 的代码部分来自 Facebook XLM,也有一部分来自我们 PyTorch 版本的 Google AI Bert(可点此获取),以及针对 DistilBert 进行的精心调优。这一切,都是为了更好地重现 BERT 的预测性能。

DistilBERT 模型性能测试

我们将 DistilBERT 在 GLUE 基准测试开发集上的性能与两项基准进行了比较:其一为 BERT 基础(DistilBERT 的老师),其二为来自纽约大学的强大非 transformer 基准——ELMo 上的两个 BilSTM。我们利用纽约大学的 jiant 库获取 ELMo 基准,并使用 pytorch-transformers 获取 BERT 基准。


如下表所示,DistilBERT 的性能与基准相比更好一些,而参数数量只分别相当于二者的一半以及三分之一。在 9 项任务当中,DistilBERT 的 ELMo 基准成绩一直等同或者领先(在 QNLI 上的准确率高出 14%)。DistilBERT 的表现确实远超 BERT:我们保留了 95%以上的性能,同时将参数减少了 40%。



在 GLUE 基准测试开发集中的比较结果以及由作者上报的 ELMo 结果。BERT 与 DistilBERT 结果来自 5 次单独运行后的中位数取值。


在推理时间方面,DistilBERT 比 BERT 快 60%,体积比 BERT 小 60%,比 ELMo + BiLSTM 快 120%且模型体积更小。



为了进一步研究 DistilBERT 的加速/大小平衡点,我们在上表中比较了各个模型的参数数量,以及在 CPU 上完全处理 STS-B 开发集(批量大小为 1)所需要的推理时间。

下游任务:蒸馏与迁移学习

我们进一步研究了 DistilBERT 在有效推理约束下的下游应用效果。我们通过分类任务调优,实现对这套紧凑预训练语言模型的迁移。事实证明,这是种实现蒸馏预训练与迁移学习的好方法!



从 IMDB Review 数据集中提取到的电影评论。


我们选择了 IMDB 影评情感区中的素材,该分区共包含 5 万条英文评论,且标记为正面或负面:我们使用其中 2 万 5 千条进行训练,另外 2 万 5 千条进行测试(同时配合平衡类)。整个训练过程在单一 12 GB K80 上进行。


首先,我们在自己的数据集上训练 bert-base-uncased。我们亲爱的 BERT 老师达到了 99.98%的准确率(3 次运行取平均值)。相当完美!


接下来,我们训练 DistilBERT,使用同样的超参数。压缩模型的准确率达到 99.53%(3 次运行取平均值)。性能的绝对差为 0.5%,延迟降低 60%,大小减少 40%。


NLP 技术的另一种常见应用是问题解答。我们在 SQuAD 1.1 数据集上比较了 BERT bert-base-uncased 版本与 DistilBERT 的结果。在开发集上,BERT 的 F1 得分为 88.5,EM(完全匹配)得分为 81.2。我们利用同样的超参数进行 DistilBERT 训练,F1 分数与 EM 分数分别为 85.1 与 76.5,同 BERT 成绩的差距分别为 3 分与 5 分。


我们还研究了能否在适应阶段利用经过调优的 BERT 作为老师,配合知识蒸馏损失对 DistilBERT 实现 SQuAD 数据集上的调优。


在新案例中,我们将问题回答模型蒸馏为以往通过知识蒸馏预训练完成的语言模型,从而实现调优!这样,老师与学生将能够相互转换。


如此一来,考虑到网络规模,我们能够获得非常有趣的结果:F1 得分为 86.2,EM 得分为 78.1。与完整模型相比,差距保持在 3 分以内!

少即是多:小型模型也能带来理想性能

我们对 DistilBERT 的潜力感到非常兴奋。目前的成果只是刚刚起步,也给我们提出了很多新的问题:我们能够利用知识蒸馏技术将这些模型压缩到怎样的程度?这些技术能否用于进一步理解大型模型中存储的知识?在这类压缩当中,我们损失掉的是哪些语言/语义元素?……


在 HuggingFadce,我们一直将开源与知识共享视为自己的使命。所以,大家可以点击此处访问我们的 GitHub 库,这是我们每个人参与 NLP 深度学习项目并获取卓越成果的最简单、也最公平的方式。


因此,配合本篇博文,我们在pytorch-transformer库当中发布了实验代码(主要是重现训练与 DistilBERT 调优代码)以及一套经过训练的 DistilBERT 版本,感兴趣的朋友可以随意取用。


原文链接:


https://medium.com/huggingface/distilbert-8cf3380435b5


2019-08-30 08:008151

评论

发布
暂无评论
发现更多内容

「Java并发编程」从源码分析几道必问线程池的面试题?

Java架构师迁哥

深度解析ThreadLocal原理

AI乔治

Java 架构 线程 ThreadLocal

O'Reilly出版社又一经典之作——Python设计模式

计算机与AI

Python

当人脸识别对准执法者,AI的应用边界博弈

脑极体

Rethink:多版本文件的命名细节

小匚

团队 随笔杂谈

cglib入门后篇

Rayjun

Java cglib

代码简易调试方法.md

Albert

Java LeetCode 调试

实时指挥调度的发展和优势

anyRTC开发者

ios android 音视频 WebRTC RTC

记不住Spring中Scheduled中的Cron语法?让我们看看源码吧

AI乔治

Java spring 编程 架构

低代码开发平台核心功能设计——组件自定义交互实现

徐小夕

大前端 编辑器 H5 大屏可视化 lowcode

数字人民币都来了 黄金还有什么用?

CECBC

数字货币

可以解除程序员中年危机的职业规划

Java架构师迁哥

靠脑机接口“隔空探物”,大脑植入芯片可实现“心灵感应”

脑极体

Reactor中的Thread和Scheduler

程序那些事

响应式编程 reactor 多线程 程序那些事 reactivex

当我们在讨论实时性的时候,我们在讨论什么?

VoltDB

数据分析 5G 工业互联网

涨薪神作!华为内部操作系统与网络协议笔记爆火,Java程序员有福了

Java架构之路

Java 程序员 面试 编程语言

从零到千万用户,我是如何一步步优化MySQL数据库的?

冰河

数据库 架构 性能优化 分布式数据库 分布式存储

架构师训练营第 1 期第 8 周学习总结

好吃不贵

极客大学架构师训练营

Spring bean 加载顺序导致的 bug 问题

AI乔治

Java 架构 Spring Boot

甲方日常 48

句子

工作 随笔杂谈 日常

这份笔记我必啃完!美团T9首发内部JVM高级特性笔记,差距不止一点点

Java架构追梦

Java 源码 架构 面试 JVM

【T1543.003】利用 ACL 隐藏恶意 Windows 服务

比伯

Java 大数据 编程 架构 计算机

甲方日常 47

句子

工作 随笔杂谈 日常

微信视频号强制置顶朋友圈:盈利不可牺牲用户体验

石头IT视角

一个技术总监的忠告:精通那么多技术,你为何还是受不到重用?

四猿外

程序人生 技术管理 加薪 职场成长 源码阅读

区块链产业,怎样“链”住未来?

CECBC

区块链

2020双11:每秒58.3万笔!阿里云又扛住了!

云计算 互联网 运维 云原生 科技

什么?美团T9首发内部JVM高级特性笔记,看完差距不止一点

小Q

Java 学习 程序员 架构 面试

如何预防工业物联网中的恶意攻击?

VoltDB

大数据 数据分析 5G 工业互联网

5G为数字化转型插上翅膀

CECBC

5G网络安全

如何应对大促流量洪峰?揭秘京东技术人的备战手册

京东科技开发者

云计算 大数据 亿级流量

更小、更快、更便宜、更轻量:开源DistilBERT,BERT的精简版本_语言 & 开发_Victor Sanh_InfoQ精选文章