写点什么

干货 | BERT fine-tune 终极实践教程

  • 2018-11-26
  • 本文字数:3768 字

    阅读完需:约 12 分钟

干货 | BERT fine-tune 终极实践教程

从 11 月初开始,google-research 就陆续开源了 BERT 的各个版本。google 此次开源的 BERT 是通过 tensorflow 高级 API—— tf.estimator 进行封装(wrapper)的。因此对于不同数据集的适配,只需要修改代码中的 processor 部分,就能进行代码的训练、交叉验证和测试。

在自己的数据集上运行 BERT

BERT 的代码同论文里描述的一致,主要分为两个部分。一个是训练语言模型(language model)的预训练(pretrain)部分。另一个是训练具体任务(task)的 fine-tune 部分。在开源的代码中,预训练的入口是在 run_pretraining.py 而 fine-tune 的入口针对不同的任务分别在 run_classifier.py 和 run_squad.py。其中 run_classifier.py 适用的任务为分类任务。如 CoLA、MRPC、MultiNLI 这些数据集。而 run_squad.py 适用的是阅读理解(MRC)任务,如 squad2.0 和 squad1.1。


预训练是 BERT 很重要的一个部分,与此同时,预训练需要巨大的运算资源。按照论文里描述的参数,其 Base 的设定在消费级的显卡 Titan x 或 Titan 1080ti(12GB RAM)上,甚至需要近几个月的时间进行预训练,同时还会面临显存不足的问题。不过所幸的是谷歌满足了 Issues#2 里各国开发者的请求,针对大部分语言都公布了 BERT 的预训练模型。因此在我们可以比较方便地在自己的数据集上进行 fine-tune。

下载预训练模型

对于中文而言,google 公布了一个参数较小的 BERT 预训练模型。具体参数数值如下所示:


Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters


模型的下载链接可以在 github 上 google 的开源代码里找到。对下载的压缩文件进行解压,可以看到文件里有五个文件,其中 bert_model.ckpt 开头的文件是负责模型变量载入的,而 vocab.txt 是训练时中文文本采用的字典,最后 bert_config.json 是 BERT 在训练时,可选调整的一些参数。

修改 processor

任何模型的训练、预测都是需要有一个明确的输入,而 BERT 代码中 processor 就是负责对模型的输入进行处理。我们以分类任务的为例,介绍如何修改 processor 来运行自己数据集上的 fine-tune。在 run_classsifier.py 文件中我们可以看到,google 对于一些公开数据集已经写了一些 processor,如 XnliProcessor,MnliProcessor,MrpcProcessor 和 ColaProcessor。这给我们提供了一个很好的示例,指导我们如何针对自己的数据集来写 processor。


对于一个需要执行训练、交叉验证和测试完整过程的模型而言,自定义的 processor 里需要继承 DataProcessor,并重载获取 label 的 get_labels 和获取单个输入的 get_train_examples,get_dev_examples 和 get_test_examples 函数。其分别会在 main 函数的 FLAGS.do_train、FLAGS.do_eval 和 FLAGS.do_predict 阶段被调用。


这三个函数的内容是相差无几的,区别只在于需要指定各自读入文件的地址。


以 get_train_examples 为例,函数需要返回一个由 InputExample 类组成的 list。InputExample 类是一个很简单的类,只有初始化函数,需要传入的参数中 guid 是用来区分每个 example 的,可以按照 train-%d’%(i)的方式进行定义。text_a 是一串字符串,text_b 则是另一串字符串。在进行后续输入处理后(BERT 代码中已包含,不需要自己完成) text_a 和 text_b 将组合成[CLS] text_a [SEP] text_b [SEP]的形式传入模型。最后一个参数 label 也是字符串的形式,label 的内容需要保证出现在 get_labels 函数返回的 list 里。


举一个例子,假设我们想要处理一个能够判断句子相似度的模型,现在在 data_dir 的路径下有一个名为 train.csv 的输入文件,如果我们现在输入文件的格式如下 csv 形式:


1,你好,您好0,你好,你家住哪 
复制代码


那么我们可以写一个如下的 get_train_examples 的函数。当然对于 csv 的处理,可以使用诸如 csv.reader 的形式进行读入。


def get_train_examples(self, data_dir):    file_path = os.path.join(data_dir, 'train.csv')    with open(file_path, 'r') as f:        reader = f.readlines()    examples = []    for index, line in enumerate(reader):        guid = 'train-%d'%index        split_line = line.strip().split(',')        text_a = tokenization.convert_to_unicode(split_line[1])        text_b = tokenization.convert_to_unicode(split_line[2])        label = split_line[0]        examples.append(InputExample(guid=guid, text_a=text_a,                                      text_b=text_b, label=label))    return examples
复制代码


同时对应判断句子相似度这个二分类任务,get_labels 函数可以写成如下的形式:


def get_labels(self):    return ['0','1']
复制代码


在对 get_dev_examples 和 get_test_examples 函数做类似 get_train_examples 的操作后,便完成了对 processor 的修改。其中 get_test_examples 可以传入一个随意的 label 数值,因为在模型的预测(prediction)中 label 将不会参与计算。

修改 processor 字典

修改完成 processor 后,需要在在原本 main 函数的 processor 字典里,加入修改后的 processor 类,即可在运行参数里指定调用该 processor。


 processors = {      "cola": ColaProcessor,      "mnli": MnliProcessor,      "mrpc": MrpcProcessor,      "xnli": XnliProcessor,       "selfsim": SelfProcessor #添加自己的processor  }
复制代码

运行 fine-tune

之后就可以直接运行 run_classsifier.py 进行模型的训练。在运行时需要制定一些参数,一个较为完整的运行参数如下所示:


export BERT_BASE_DIR=/path/to/bert/chinese_L-12_H-768_A-12 #全局变量 下载的预训练bert地址export MY_DATASET=/path/to/xnli #全局变量 数据集所在地址
python run_classifier.py \ --task_name=selfsim \ #自己添加processor在processors字典里的key名 --do_train=true \ --do_eval=true \ --dopredict=true \ --data_dir=$MY_DATASET \ --vocab_file=$BERT_BASE_DIR/vocab.txt \ --bert_config_file=$BERT_BASE_DIR/bert_config.json \ --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ --max_seq_length=128 \ #模型参数 --train_batch_size=32 \ --learning_rate=5e-5 \ --num_train_epochs=2.0 \ --output_dir=/tmp/selfsim_output/ #模型输出路径
复制代码

BERT 源代码里还有什么

在开始训练我们自己 fine-tune 的 BERT 后,我们可以再来看看 BERT 代码里除了 processor 之外的一些部分。


我们可以发现,process 在得到字符串形式的输入后,在 file_based_convert_examples_to_features 里先是对字符串长度,加入[CLS]和[SEP]等一些处理后,将其写入成 TFrecord 的形式。这是为了能在 estimator 里有一个更为高效和简易的读入。


我们还可以发现,在 create_model 的函数里,除了从 modeling.py 获取模型主干输出之外,还有进行 fine-tune 时候的 loss 计算。因此,如果对于 fine-tune 的结构有自定义的要求,可以在这部分对代码进行修改。如进行 NER 任务的时候,可以按照 BERT 论文里的方式,不只读第一位的 logits,而是将每一位 logits 进行读取。


BERT 这次开源的代码,由于是考虑在 google 自己的 TPU 上高效地运行,因此采用的 estimator 是 tf.contrib.tpu.TPUEstimator,虽然 TPU 的 estimator 同样可以在 gpu 和 cpu 上运行,但若想在 gpu 上更高效地做一些提升,可以考虑将其换成 tf.estimator.Estimator,于此同时 model_fn 里一些 tf.contrib.tpu.TPUEstimatorSpec 也需要修改成 tf.estimator.EstimatorSpec 的形式,以及相关调用参数也需要做一些调整。在转换成较普通的 estimator 后便可以使用常用的方式对 estimator 进行处理,如生成用于部署的.pb 文件等。

GitHub Issues 里一些有趣的内容

从 google 对 BERT 进行开源开始,Issues 里的讨论便异常活跃,BERT 论文第一作者 Jacob Devlin 也积极地在 Issues 里进行回应,在交流讨论中,产生了一些很有趣的内容。


在 GitHub Issues#95 中大家讨论了 BERT 模型在今年 AI-Challenger 比赛上的应用。我们也同样尝试了 BERT 在 AI-Challenger 的机器阅读理解(mrc)赛道的表现。如果简单得地将 mrc 的文本连接成一个长字符串的形式,可以在 dev 集上得到 79.1%的准确率。


如果参考 openAI 的 GPT 论文里 multi-choice 的形式对 BERT 的输入输出代码进行修改则可以将准确率提高到 79.3%。采用的参数都是 BERT 默认的参数,而单一模型成绩在赛道的 test a 排名中已经能超过榜单上的第一名。因此,在相关中文的任务中,bert 能有很大的想象空间。


在 GitHub Issues#123 中,@hanxiao 给出了一个采用 ZeroMQ 便捷部署 BERT 的 service,可以直接调用训练好的模型作为应用的接口。同时他将 BERT 改为一个大的 encode 模型,将文本通过 BERT 进行 encode,来实现句子级的 encode。此外,他对比了多 GPU 上的性能,发现 bert 在多 GPU 并行上的出色表现。

总结

总的来说,google 此次开源的 BERT 和其预训练模型是非常有价值的,可探索和改进的内容也很多。相关数据集上已经出现了对 BERT 进行修改后的复合模型,如 squad2.0 上哈工大(HIT)的 AoA + DA + BERT 以及西湖大学(DAMO)的 SLQA + BERT。 在感谢 google 这份付出的同时,我们也可以借此站在巨人的肩膀上,尝试将其运用在自然语言处理领域的方方面面,让人工智能的梦想更近一步。





2018-11-26 14:195607

评论 1 条评论

发布
暂无评论
发现更多内容

技术案例 | 云原生微服务落地难?百度自用CRM这样做

百度开发者中心

微服务 CRM #百度智能云#

Git学习游戏化,从Learn Git Branching 开始

程序老王

git 学习 学习方法 git 学习

智慧党建管理系统,智慧组工平台开发方案

13530558032

浅谈基于ARP协议的网络攻击

行者AI

网络安全

第四章作业(二)

LouisN

国产芯片WiFi物联网智能插座—电耗采集功能设计

不脱发的程序猿

28天写作 国产芯片 电耗检测 电压电流 华大MCU

技术解析 | Doris Compaction机制解析

百度开发者中心

百度 apache doris

887页Java面试“成神”手册,已助朋友狂砍9个一二线大厂Offer

Java架构追梦

Java 阿里巴巴 架构 面试 金三银四

区块链农产品溯源平台,农产品区块链防伪

13530558032

四面美团开发岗,成功斩获offer,分享个人面经

Java架构之路

Java 程序员 架构 面试 编程语言

EEPROM CAT24CXX实现分页读、写数据

不脱发的程序猿

28天写作 CAT24C08 EEPROM 嵌入式软件 单片机

【LeetCode】区域和检索 - 数组不可变Java题解

Albert

算法 LeetCode 28天写作

从0到1建立数据分析指标体系底层逻辑

小飞象@木木自由

数据分析 数据指标 数据分析体系

算力挖矿系统开发|算力挖矿软件APP开发

系统开发

Pgbouncer最佳实践:系列一

PostgreSQLChina

数据库 postgresql 软件 开源社区

使用 pyVmomi 采集 vSphere 监控指标

冯骐

Python 运维 监控 Open-Falcon vpshere

山东青岛推进平安小区建设!源中瑞智慧社区平台解决方案

源中瑞-龙先生

解决方案 山东 源中瑞 青岛 智慧社区

接口测试--apipost中cookie管理器的使用

测试人生路

接口 Cookie

OS命令--shell中数组的操作

cloudcoder

数组 Shell 循环引用

LeetCode题解:123. 买卖股票的最佳时机 III,动态规划,JavaScript,详细注释

Lee Chen

算法 大前端 LeetCode

#集赞送好礼#百度大脑AI开放平台的2020年

百度大脑

2021最新京东、字节跳动「3面面经」盘点大厂后端面试高频题

Java架构之路

Java 程序员 架构 面试 编程语言

程序员之禅(一)

每天读本书

读书笔记

Kubernetes 稳定性保障手册 -- 极简版

阿里巴巴云原生

云计算 容器 开发者 云原生 k8s

从0到1建立软件测试质量体系

程序员阿沐

软件测试 测试工程师 质量保证

Vim,人类史上最好用的文本编辑器

沉默王二

vim 开发工具 vim教程

QA视角看数据匿名化

BY林子

数据安全 测试右移 用户数据 数据脱敏

极限编程技术实践

Teobler

敏捷 敏捷开发 TDD 重构 极限编程

程序员成长第十七篇:项目转测

石云升

项目管理 程序员 28天写作 3月日更

Serverless 如何在阿里巴巴实现规模化落地?

阿里巴巴云原生

阿里巴巴 Serverless 容器 微服务 云原生

2021备战金三银四血拼一波算法:字节+百度+美团+网易+拼夕夕+腾讯+滴滴

比伯

Java 编程 程序员 架构 面试

干货 | BERT fine-tune 终极实践教程_AI&大模型_奇点机智_InfoQ精选文章