写点什么

干货 | BERT fine-tune 终极实践教程

  • 2018-11-26
  • 本文字数:3768 字

    阅读完需:约 12 分钟

干货 | BERT fine-tune 终极实践教程

从 11 月初开始,google-research 就陆续开源了 BERT 的各个版本。google 此次开源的 BERT 是通过 tensorflow 高级 API—— tf.estimator 进行封装(wrapper)的。因此对于不同数据集的适配,只需要修改代码中的 processor 部分,就能进行代码的训练、交叉验证和测试。

在自己的数据集上运行 BERT

BERT 的代码同论文里描述的一致,主要分为两个部分。一个是训练语言模型(language model)的预训练(pretrain)部分。另一个是训练具体任务(task)的 fine-tune 部分。在开源的代码中,预训练的入口是在 run_pretraining.py 而 fine-tune 的入口针对不同的任务分别在 run_classifier.py 和 run_squad.py。其中 run_classifier.py 适用的任务为分类任务。如 CoLA、MRPC、MultiNLI 这些数据集。而 run_squad.py 适用的是阅读理解(MRC)任务,如 squad2.0 和 squad1.1。


预训练是 BERT 很重要的一个部分,与此同时,预训练需要巨大的运算资源。按照论文里描述的参数,其 Base 的设定在消费级的显卡 Titan x 或 Titan 1080ti(12GB RAM)上,甚至需要近几个月的时间进行预训练,同时还会面临显存不足的问题。不过所幸的是谷歌满足了 Issues#2 里各国开发者的请求,针对大部分语言都公布了 BERT 的预训练模型。因此在我们可以比较方便地在自己的数据集上进行 fine-tune。

下载预训练模型

对于中文而言,google 公布了一个参数较小的 BERT 预训练模型。具体参数数值如下所示:


Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters


模型的下载链接可以在 github 上 google 的开源代码里找到。对下载的压缩文件进行解压,可以看到文件里有五个文件,其中 bert_model.ckpt 开头的文件是负责模型变量载入的,而 vocab.txt 是训练时中文文本采用的字典,最后 bert_config.json 是 BERT 在训练时,可选调整的一些参数。

修改 processor

任何模型的训练、预测都是需要有一个明确的输入,而 BERT 代码中 processor 就是负责对模型的输入进行处理。我们以分类任务的为例,介绍如何修改 processor 来运行自己数据集上的 fine-tune。在 run_classsifier.py 文件中我们可以看到,google 对于一些公开数据集已经写了一些 processor,如 XnliProcessor,MnliProcessor,MrpcProcessor 和 ColaProcessor。这给我们提供了一个很好的示例,指导我们如何针对自己的数据集来写 processor。


对于一个需要执行训练、交叉验证和测试完整过程的模型而言,自定义的 processor 里需要继承 DataProcessor,并重载获取 label 的 get_labels 和获取单个输入的 get_train_examples,get_dev_examples 和 get_test_examples 函数。其分别会在 main 函数的 FLAGS.do_train、FLAGS.do_eval 和 FLAGS.do_predict 阶段被调用。


这三个函数的内容是相差无几的,区别只在于需要指定各自读入文件的地址。


以 get_train_examples 为例,函数需要返回一个由 InputExample 类组成的 list。InputExample 类是一个很简单的类,只有初始化函数,需要传入的参数中 guid 是用来区分每个 example 的,可以按照 train-%d’%(i)的方式进行定义。text_a 是一串字符串,text_b 则是另一串字符串。在进行后续输入处理后(BERT 代码中已包含,不需要自己完成) text_a 和 text_b 将组合成[CLS] text_a [SEP] text_b [SEP]的形式传入模型。最后一个参数 label 也是字符串的形式,label 的内容需要保证出现在 get_labels 函数返回的 list 里。


举一个例子,假设我们想要处理一个能够判断句子相似度的模型,现在在 data_dir 的路径下有一个名为 train.csv 的输入文件,如果我们现在输入文件的格式如下 csv 形式:


1,你好,您好0,你好,你家住哪 
复制代码


那么我们可以写一个如下的 get_train_examples 的函数。当然对于 csv 的处理,可以使用诸如 csv.reader 的形式进行读入。


def get_train_examples(self, data_dir):    file_path = os.path.join(data_dir, 'train.csv')    with open(file_path, 'r') as f:        reader = f.readlines()    examples = []    for index, line in enumerate(reader):        guid = 'train-%d'%index        split_line = line.strip().split(',')        text_a = tokenization.convert_to_unicode(split_line[1])        text_b = tokenization.convert_to_unicode(split_line[2])        label = split_line[0]        examples.append(InputExample(guid=guid, text_a=text_a,                                      text_b=text_b, label=label))    return examples
复制代码


同时对应判断句子相似度这个二分类任务,get_labels 函数可以写成如下的形式:


def get_labels(self):    return ['0','1']
复制代码


在对 get_dev_examples 和 get_test_examples 函数做类似 get_train_examples 的操作后,便完成了对 processor 的修改。其中 get_test_examples 可以传入一个随意的 label 数值,因为在模型的预测(prediction)中 label 将不会参与计算。

修改 processor 字典

修改完成 processor 后,需要在在原本 main 函数的 processor 字典里,加入修改后的 processor 类,即可在运行参数里指定调用该 processor。


 processors = {      "cola": ColaProcessor,      "mnli": MnliProcessor,      "mrpc": MrpcProcessor,      "xnli": XnliProcessor,       "selfsim": SelfProcessor #添加自己的processor  }
复制代码

运行 fine-tune

之后就可以直接运行 run_classsifier.py 进行模型的训练。在运行时需要制定一些参数,一个较为完整的运行参数如下所示:


export BERT_BASE_DIR=/path/to/bert/chinese_L-12_H-768_A-12 #全局变量 下载的预训练bert地址export MY_DATASET=/path/to/xnli #全局变量 数据集所在地址
python run_classifier.py \ --task_name=selfsim \ #自己添加processor在processors字典里的key名 --do_train=true \ --do_eval=true \ --dopredict=true \ --data_dir=$MY_DATASET \ --vocab_file=$BERT_BASE_DIR/vocab.txt \ --bert_config_file=$BERT_BASE_DIR/bert_config.json \ --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ --max_seq_length=128 \ #模型参数 --train_batch_size=32 \ --learning_rate=5e-5 \ --num_train_epochs=2.0 \ --output_dir=/tmp/selfsim_output/ #模型输出路径
复制代码

BERT 源代码里还有什么

在开始训练我们自己 fine-tune 的 BERT 后,我们可以再来看看 BERT 代码里除了 processor 之外的一些部分。


我们可以发现,process 在得到字符串形式的输入后,在 file_based_convert_examples_to_features 里先是对字符串长度,加入[CLS]和[SEP]等一些处理后,将其写入成 TFrecord 的形式。这是为了能在 estimator 里有一个更为高效和简易的读入。


我们还可以发现,在 create_model 的函数里,除了从 modeling.py 获取模型主干输出之外,还有进行 fine-tune 时候的 loss 计算。因此,如果对于 fine-tune 的结构有自定义的要求,可以在这部分对代码进行修改。如进行 NER 任务的时候,可以按照 BERT 论文里的方式,不只读第一位的 logits,而是将每一位 logits 进行读取。


BERT 这次开源的代码,由于是考虑在 google 自己的 TPU 上高效地运行,因此采用的 estimator 是 tf.contrib.tpu.TPUEstimator,虽然 TPU 的 estimator 同样可以在 gpu 和 cpu 上运行,但若想在 gpu 上更高效地做一些提升,可以考虑将其换成 tf.estimator.Estimator,于此同时 model_fn 里一些 tf.contrib.tpu.TPUEstimatorSpec 也需要修改成 tf.estimator.EstimatorSpec 的形式,以及相关调用参数也需要做一些调整。在转换成较普通的 estimator 后便可以使用常用的方式对 estimator 进行处理,如生成用于部署的.pb 文件等。

GitHub Issues 里一些有趣的内容

从 google 对 BERT 进行开源开始,Issues 里的讨论便异常活跃,BERT 论文第一作者 Jacob Devlin 也积极地在 Issues 里进行回应,在交流讨论中,产生了一些很有趣的内容。


在 GitHub Issues#95 中大家讨论了 BERT 模型在今年 AI-Challenger 比赛上的应用。我们也同样尝试了 BERT 在 AI-Challenger 的机器阅读理解(mrc)赛道的表现。如果简单得地将 mrc 的文本连接成一个长字符串的形式,可以在 dev 集上得到 79.1%的准确率。


如果参考 openAI 的 GPT 论文里 multi-choice 的形式对 BERT 的输入输出代码进行修改则可以将准确率提高到 79.3%。采用的参数都是 BERT 默认的参数,而单一模型成绩在赛道的 test a 排名中已经能超过榜单上的第一名。因此,在相关中文的任务中,bert 能有很大的想象空间。


在 GitHub Issues#123 中,@hanxiao 给出了一个采用 ZeroMQ 便捷部署 BERT 的 service,可以直接调用训练好的模型作为应用的接口。同时他将 BERT 改为一个大的 encode 模型,将文本通过 BERT 进行 encode,来实现句子级的 encode。此外,他对比了多 GPU 上的性能,发现 bert 在多 GPU 并行上的出色表现。

总结

总的来说,google 此次开源的 BERT 和其预训练模型是非常有价值的,可探索和改进的内容也很多。相关数据集上已经出现了对 BERT 进行修改后的复合模型,如 squad2.0 上哈工大(HIT)的 AoA + DA + BERT 以及西湖大学(DAMO)的 SLQA + BERT。 在感谢 google 这份付出的同时,我们也可以借此站在巨人的肩膀上,尝试将其运用在自然语言处理领域的方方面面,让人工智能的梦想更近一步。





2018-11-26 14:195785

评论 1 条评论

发布
暂无评论
发现更多内容

selenium源码通读·10 |webdriver/common/proxy.py-Proxy类分析

Python 测试 自动化测试 源码剖析 selenium

北大GPT解题有数学老师内味了,用人话讲难题,从高中数学到高数都能搞定

Openlab_cosmoplat

人工智能 开源社区

提交代码「前置处理」,向前一小步,效率提升「亿点点」

极狐GitLab

DevOps 极狐GitLab git hook lefthook 代码前置

MegEngine 使用小技巧:量化

MegEngineBot

量化 MegEngine

九科信息流程挖掘产品bit-Miner即将开放面向对象流程挖掘能力

九科Ninetech

流程挖掘

MobTech MobLink|引流统计一站式服务

MobTech袤博科技

面试还不懂JVM性能调优,看这篇文章就够了!

程序员小毕

程序员 面试 后端 JVM jvm调优

乌合之众再次上演,打工人将被AI一键淘汰?

引迈信息

人工智能 AI 低代码 AIGC ChatGPT

“精准测试” 在商家地址专项的探索 | 得物技术

得物技术

手势识别:让你的手成为计算机的新界面

数据堂

行业分析| 视频监控——AI自动巡检

anyRTC开发者

人工智能 音视频 视频监控 自动巡检

构建云边端一体的分布式云架构,软硬结合驱动边缘计算创新场景

百度开发者中心

云计算 存储 边缘云

iOS MachineLearning 系列(4)—— 静态图像分析之物体识别与分类

珲少

世界读书日特辑 | 华为阅读深耕精品书,让读书变得赏心“悦目”

最新动态

vue 入门知识点有哪些?

海拥(haiyong.site)

三周年连更

烟雾弹?突然转变?如何看待微软发声:中国是主要的对手

加入高科技仿生人

人工智能 AI 数智化 ChatGPT

您有一份直播回放待查收!

BinTools图尔兹

直播回放 版本发布

如何开发一个小程序自定义组件

Onegun

小程序 前端 小程序组件

在企业内容城池边,它建立起一支保卫军

ToB行业头条

卷起来了!阿里最新出品“微服务全阶笔记”,涵盖微服务全部操作

收到请回复

架构 #编程 #微服务

深度分享 | API 测试经济学与 API First 践行

Apifox

程序员 前端 接口 后端 API

中国垂直行业SaaS,这样走可能是新出路

ToB行业头条

HarmonyOS Codelabs最新参考

坚果

OpenHarmony 三周年连更

让 AI 更简单 人工智能平台 SEAL 携手龙蜥落地达摩院算法能力 | 龙蜥案例

OpenAnolis小助手

开源 操作系统 SEAL 达摩院 龙蜥案例

向量嵌入:AutoGPT的幻觉解法?

OneFlow

给广场舞大妈讲讲什么是大语言模型!

FN0

AIGC

知名直播App被苹果商店下架,或涉及侵权问题

曲多多(嗨翻屋)版权音乐

ios iphone 软件开发

走进社区客户端测试 | 得物技术

得物技术

测试

智慧园区数字转型下的移动App建设策略

Onegun

移动应用 智慧城市 智慧园区

视频大文件传输的演变:从“卷轴男孩”到自动化

镭速

把“ai模型+低代码”应用在项目管理中,效率翻了好几倍

优秀

AI 低代码

干货 | BERT fine-tune 终极实践教程_AI&大模型_奇点机智_InfoQ精选文章