AICon 上海站|日程100%上线,解锁Al未来! 了解详情
写点什么

干货 | BERT fine-tune 终极实践教程

  • 2018-11-26
  • 本文字数:3768 字

    阅读完需:约 12 分钟

干货 | BERT fine-tune 终极实践教程

从 11 月初开始,google-research 就陆续开源了 BERT 的各个版本。google 此次开源的 BERT 是通过 tensorflow 高级 API—— tf.estimator 进行封装(wrapper)的。因此对于不同数据集的适配,只需要修改代码中的 processor 部分,就能进行代码的训练、交叉验证和测试。

在自己的数据集上运行 BERT

BERT 的代码同论文里描述的一致,主要分为两个部分。一个是训练语言模型(language model)的预训练(pretrain)部分。另一个是训练具体任务(task)的 fine-tune 部分。在开源的代码中,预训练的入口是在 run_pretraining.py 而 fine-tune 的入口针对不同的任务分别在 run_classifier.py 和 run_squad.py。其中 run_classifier.py 适用的任务为分类任务。如 CoLA、MRPC、MultiNLI 这些数据集。而 run_squad.py 适用的是阅读理解(MRC)任务,如 squad2.0 和 squad1.1。


预训练是 BERT 很重要的一个部分,与此同时,预训练需要巨大的运算资源。按照论文里描述的参数,其 Base 的设定在消费级的显卡 Titan x 或 Titan 1080ti(12GB RAM)上,甚至需要近几个月的时间进行预训练,同时还会面临显存不足的问题。不过所幸的是谷歌满足了 Issues#2 里各国开发者的请求,针对大部分语言都公布了 BERT 的预训练模型。因此在我们可以比较方便地在自己的数据集上进行 fine-tune。

下载预训练模型

对于中文而言,google 公布了一个参数较小的 BERT 预训练模型。具体参数数值如下所示:


Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters


模型的下载链接可以在 github 上 google 的开源代码里找到。对下载的压缩文件进行解压,可以看到文件里有五个文件,其中 bert_model.ckpt 开头的文件是负责模型变量载入的,而 vocab.txt 是训练时中文文本采用的字典,最后 bert_config.json 是 BERT 在训练时,可选调整的一些参数。

修改 processor

任何模型的训练、预测都是需要有一个明确的输入,而 BERT 代码中 processor 就是负责对模型的输入进行处理。我们以分类任务的为例,介绍如何修改 processor 来运行自己数据集上的 fine-tune。在 run_classsifier.py 文件中我们可以看到,google 对于一些公开数据集已经写了一些 processor,如 XnliProcessor,MnliProcessor,MrpcProcessor 和 ColaProcessor。这给我们提供了一个很好的示例,指导我们如何针对自己的数据集来写 processor。


对于一个需要执行训练、交叉验证和测试完整过程的模型而言,自定义的 processor 里需要继承 DataProcessor,并重载获取 label 的 get_labels 和获取单个输入的 get_train_examples,get_dev_examples 和 get_test_examples 函数。其分别会在 main 函数的 FLAGS.do_train、FLAGS.do_eval 和 FLAGS.do_predict 阶段被调用。


这三个函数的内容是相差无几的,区别只在于需要指定各自读入文件的地址。


以 get_train_examples 为例,函数需要返回一个由 InputExample 类组成的 list。InputExample 类是一个很简单的类,只有初始化函数,需要传入的参数中 guid 是用来区分每个 example 的,可以按照 train-%d’%(i)的方式进行定义。text_a 是一串字符串,text_b 则是另一串字符串。在进行后续输入处理后(BERT 代码中已包含,不需要自己完成) text_a 和 text_b 将组合成[CLS] text_a [SEP] text_b [SEP]的形式传入模型。最后一个参数 label 也是字符串的形式,label 的内容需要保证出现在 get_labels 函数返回的 list 里。


举一个例子,假设我们想要处理一个能够判断句子相似度的模型,现在在 data_dir 的路径下有一个名为 train.csv 的输入文件,如果我们现在输入文件的格式如下 csv 形式:


1,你好,您好0,你好,你家住哪 
复制代码


那么我们可以写一个如下的 get_train_examples 的函数。当然对于 csv 的处理,可以使用诸如 csv.reader 的形式进行读入。


def get_train_examples(self, data_dir):    file_path = os.path.join(data_dir, 'train.csv')    with open(file_path, 'r') as f:        reader = f.readlines()    examples = []    for index, line in enumerate(reader):        guid = 'train-%d'%index        split_line = line.strip().split(',')        text_a = tokenization.convert_to_unicode(split_line[1])        text_b = tokenization.convert_to_unicode(split_line[2])        label = split_line[0]        examples.append(InputExample(guid=guid, text_a=text_a,                                      text_b=text_b, label=label))    return examples
复制代码


同时对应判断句子相似度这个二分类任务,get_labels 函数可以写成如下的形式:


def get_labels(self):    return ['0','1']
复制代码


在对 get_dev_examples 和 get_test_examples 函数做类似 get_train_examples 的操作后,便完成了对 processor 的修改。其中 get_test_examples 可以传入一个随意的 label 数值,因为在模型的预测(prediction)中 label 将不会参与计算。

修改 processor 字典

修改完成 processor 后,需要在在原本 main 函数的 processor 字典里,加入修改后的 processor 类,即可在运行参数里指定调用该 processor。


 processors = {      "cola": ColaProcessor,      "mnli": MnliProcessor,      "mrpc": MrpcProcessor,      "xnli": XnliProcessor,       "selfsim": SelfProcessor #添加自己的processor  }
复制代码

运行 fine-tune

之后就可以直接运行 run_classsifier.py 进行模型的训练。在运行时需要制定一些参数,一个较为完整的运行参数如下所示:


export BERT_BASE_DIR=/path/to/bert/chinese_L-12_H-768_A-12 #全局变量 下载的预训练bert地址export MY_DATASET=/path/to/xnli #全局变量 数据集所在地址
python run_classifier.py \ --task_name=selfsim \ #自己添加processor在processors字典里的key名 --do_train=true \ --do_eval=true \ --dopredict=true \ --data_dir=$MY_DATASET \ --vocab_file=$BERT_BASE_DIR/vocab.txt \ --bert_config_file=$BERT_BASE_DIR/bert_config.json \ --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ --max_seq_length=128 \ #模型参数 --train_batch_size=32 \ --learning_rate=5e-5 \ --num_train_epochs=2.0 \ --output_dir=/tmp/selfsim_output/ #模型输出路径
复制代码

BERT 源代码里还有什么

在开始训练我们自己 fine-tune 的 BERT 后,我们可以再来看看 BERT 代码里除了 processor 之外的一些部分。


我们可以发现,process 在得到字符串形式的输入后,在 file_based_convert_examples_to_features 里先是对字符串长度,加入[CLS]和[SEP]等一些处理后,将其写入成 TFrecord 的形式。这是为了能在 estimator 里有一个更为高效和简易的读入。


我们还可以发现,在 create_model 的函数里,除了从 modeling.py 获取模型主干输出之外,还有进行 fine-tune 时候的 loss 计算。因此,如果对于 fine-tune 的结构有自定义的要求,可以在这部分对代码进行修改。如进行 NER 任务的时候,可以按照 BERT 论文里的方式,不只读第一位的 logits,而是将每一位 logits 进行读取。


BERT 这次开源的代码,由于是考虑在 google 自己的 TPU 上高效地运行,因此采用的 estimator 是 tf.contrib.tpu.TPUEstimator,虽然 TPU 的 estimator 同样可以在 gpu 和 cpu 上运行,但若想在 gpu 上更高效地做一些提升,可以考虑将其换成 tf.estimator.Estimator,于此同时 model_fn 里一些 tf.contrib.tpu.TPUEstimatorSpec 也需要修改成 tf.estimator.EstimatorSpec 的形式,以及相关调用参数也需要做一些调整。在转换成较普通的 estimator 后便可以使用常用的方式对 estimator 进行处理,如生成用于部署的.pb 文件等。

GitHub Issues 里一些有趣的内容

从 google 对 BERT 进行开源开始,Issues 里的讨论便异常活跃,BERT 论文第一作者 Jacob Devlin 也积极地在 Issues 里进行回应,在交流讨论中,产生了一些很有趣的内容。


在 GitHub Issues#95 中大家讨论了 BERT 模型在今年 AI-Challenger 比赛上的应用。我们也同样尝试了 BERT 在 AI-Challenger 的机器阅读理解(mrc)赛道的表现。如果简单得地将 mrc 的文本连接成一个长字符串的形式,可以在 dev 集上得到 79.1%的准确率。


如果参考 openAI 的 GPT 论文里 multi-choice 的形式对 BERT 的输入输出代码进行修改则可以将准确率提高到 79.3%。采用的参数都是 BERT 默认的参数,而单一模型成绩在赛道的 test a 排名中已经能超过榜单上的第一名。因此,在相关中文的任务中,bert 能有很大的想象空间。


在 GitHub Issues#123 中,@hanxiao 给出了一个采用 ZeroMQ 便捷部署 BERT 的 service,可以直接调用训练好的模型作为应用的接口。同时他将 BERT 改为一个大的 encode 模型,将文本通过 BERT 进行 encode,来实现句子级的 encode。此外,他对比了多 GPU 上的性能,发现 bert 在多 GPU 并行上的出色表现。

总结

总的来说,google 此次开源的 BERT 和其预训练模型是非常有价值的,可探索和改进的内容也很多。相关数据集上已经出现了对 BERT 进行修改后的复合模型,如 squad2.0 上哈工大(HIT)的 AoA + DA + BERT 以及西湖大学(DAMO)的 SLQA + BERT。 在感谢 google 这份付出的同时,我们也可以借此站在巨人的肩膀上,尝试将其运用在自然语言处理领域的方方面面,让人工智能的梦想更近一步。





2018-11-26 14:195716

评论 1 条评论

发布
暂无评论
发现更多内容

2023-02-12:给定正数N,表示用户数量,用户编号从0~N-1, 给定正数M,表示实验数量,实验编号从0~M-1, 给定长度为N的二维数组A, A[i] = { a, b, c }表示,用户i报

福大大架构师每日一题

算法 rust 福大大

KMP算法详解

javaadu

数据结构 字符串 KMP

架构实战营-模块一作业

🐢先生

架构实战营

三次握手与四次挥的问题,怎么回答?

loveX001

JavaScript

FL Studio2023最新版本音乐编曲制作软件

茶色酒

FL Studio2023

Shell分支语句

圆弧

分支 条件 shell脚本

零基础入门AI?先来把机器学习捣鼓明白吧

博文视点Broadview

百度前端常考vue面试题(附答案)

bb_xiaxia1998

Vue

Vue.$nextTick的原理是什么-vue面试进阶

bb_xiaxia1998

Vue

vivo 自研Jenkins资源调度系统设计与实践

vivo互联网技术

运维 jenkins 资源调度

前端必会面试题

loveX001

JavaScript

用Docker搭建更酷的本地开发环境

致知Fighting

Java Docker Linux 后端 开发

前端react面试题指南

beifeng1996

React

焕新启航,「龙蜥大讲堂」2023 年度招募来了!13 场技术分享先睹为快

OpenAnolis小助手

直播 开源社区 龙蜥大讲堂 机密计算 月度主题

为什么补码是取反加1?

Dinfan

问:React的setState为什么是异步的?

beifeng1996

React

promise执行顺序面试题令我头秃,你能作对几道

loveX001

JavaScript

1行Python代码去除图片水印,网友:干干净净!

程序员晚枫

Python GitHub 开源 去水印 自动化办公

为什么用元空间替代永久代?

王磊

java面试

前端react面试题(边面边更)

beifeng1996

React

Vue的computed和watch的区别是什么?

bb_xiaxia1998

Vue

高级前端二面vue面试题(持续更新中)

bb_xiaxia1998

Vue

手写JS函数的call、apply、bind

helloworld1024fd

JavaScript

一个容器,但是一整个k8s集群

newbe36524

C# Docker Kubernetes

实现一个简单的Database9(译文)

GreatSQL

sqlite greatsql greatsql社区

老生常谈React的diff算法原理-面试版

beifeng1996

React

产品的可持续发展

ShineScrum

产品 产品负责人 产品的可持续发展

A-Ops性能火焰图——适用于云原生的全栈持续性能监测工具

openEuler

Linux 运维 操作系统 定位 性能监控

前端一面常考手写面试题整理

helloworld1024fd

JavaScript

被流量和热度裹挟,自媒体行业必须坚守职业道德

石头IT视角

2023我的前端面试小结

loveX001

JavaScript

干货 | BERT fine-tune 终极实践教程_AI&大模型_奇点机智_InfoQ精选文章