2天时间,聊今年最热的 Agent、上下文工程、AI 产品创新等话题。2025 年最后一场~ 了解详情
写点什么

干货 | BERT fine-tune 终极实践教程

  • 2018-11-26
  • 本文字数:3768 字

    阅读完需:约 12 分钟

干货 | BERT fine-tune 终极实践教程

从 11 月初开始,google-research 就陆续开源了 BERT 的各个版本。google 此次开源的 BERT 是通过 tensorflow 高级 API—— tf.estimator 进行封装(wrapper)的。因此对于不同数据集的适配,只需要修改代码中的 processor 部分,就能进行代码的训练、交叉验证和测试。

在自己的数据集上运行 BERT

BERT 的代码同论文里描述的一致,主要分为两个部分。一个是训练语言模型(language model)的预训练(pretrain)部分。另一个是训练具体任务(task)的 fine-tune 部分。在开源的代码中,预训练的入口是在 run_pretraining.py 而 fine-tune 的入口针对不同的任务分别在 run_classifier.py 和 run_squad.py。其中 run_classifier.py 适用的任务为分类任务。如 CoLA、MRPC、MultiNLI 这些数据集。而 run_squad.py 适用的是阅读理解(MRC)任务,如 squad2.0 和 squad1.1。


预训练是 BERT 很重要的一个部分,与此同时,预训练需要巨大的运算资源。按照论文里描述的参数,其 Base 的设定在消费级的显卡 Titan x 或 Titan 1080ti(12GB RAM)上,甚至需要近几个月的时间进行预训练,同时还会面临显存不足的问题。不过所幸的是谷歌满足了 Issues#2 里各国开发者的请求,针对大部分语言都公布了 BERT 的预训练模型。因此在我们可以比较方便地在自己的数据集上进行 fine-tune。

下载预训练模型

对于中文而言,google 公布了一个参数较小的 BERT 预训练模型。具体参数数值如下所示:


Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters


模型的下载链接可以在 github 上 google 的开源代码里找到。对下载的压缩文件进行解压,可以看到文件里有五个文件,其中 bert_model.ckpt 开头的文件是负责模型变量载入的,而 vocab.txt 是训练时中文文本采用的字典,最后 bert_config.json 是 BERT 在训练时,可选调整的一些参数。

修改 processor

任何模型的训练、预测都是需要有一个明确的输入,而 BERT 代码中 processor 就是负责对模型的输入进行处理。我们以分类任务的为例,介绍如何修改 processor 来运行自己数据集上的 fine-tune。在 run_classsifier.py 文件中我们可以看到,google 对于一些公开数据集已经写了一些 processor,如 XnliProcessor,MnliProcessor,MrpcProcessor 和 ColaProcessor。这给我们提供了一个很好的示例,指导我们如何针对自己的数据集来写 processor。


对于一个需要执行训练、交叉验证和测试完整过程的模型而言,自定义的 processor 里需要继承 DataProcessor,并重载获取 label 的 get_labels 和获取单个输入的 get_train_examples,get_dev_examples 和 get_test_examples 函数。其分别会在 main 函数的 FLAGS.do_train、FLAGS.do_eval 和 FLAGS.do_predict 阶段被调用。


这三个函数的内容是相差无几的,区别只在于需要指定各自读入文件的地址。


以 get_train_examples 为例,函数需要返回一个由 InputExample 类组成的 list。InputExample 类是一个很简单的类,只有初始化函数,需要传入的参数中 guid 是用来区分每个 example 的,可以按照 train-%d’%(i)的方式进行定义。text_a 是一串字符串,text_b 则是另一串字符串。在进行后续输入处理后(BERT 代码中已包含,不需要自己完成) text_a 和 text_b 将组合成[CLS] text_a [SEP] text_b [SEP]的形式传入模型。最后一个参数 label 也是字符串的形式,label 的内容需要保证出现在 get_labels 函数返回的 list 里。


举一个例子,假设我们想要处理一个能够判断句子相似度的模型,现在在 data_dir 的路径下有一个名为 train.csv 的输入文件,如果我们现在输入文件的格式如下 csv 形式:


1,你好,您好0,你好,你家住哪 
复制代码


那么我们可以写一个如下的 get_train_examples 的函数。当然对于 csv 的处理,可以使用诸如 csv.reader 的形式进行读入。


def get_train_examples(self, data_dir):    file_path = os.path.join(data_dir, 'train.csv')    with open(file_path, 'r') as f:        reader = f.readlines()    examples = []    for index, line in enumerate(reader):        guid = 'train-%d'%index        split_line = line.strip().split(',')        text_a = tokenization.convert_to_unicode(split_line[1])        text_b = tokenization.convert_to_unicode(split_line[2])        label = split_line[0]        examples.append(InputExample(guid=guid, text_a=text_a,                                      text_b=text_b, label=label))    return examples
复制代码


同时对应判断句子相似度这个二分类任务,get_labels 函数可以写成如下的形式:


def get_labels(self):    return ['0','1']
复制代码


在对 get_dev_examples 和 get_test_examples 函数做类似 get_train_examples 的操作后,便完成了对 processor 的修改。其中 get_test_examples 可以传入一个随意的 label 数值,因为在模型的预测(prediction)中 label 将不会参与计算。

修改 processor 字典

修改完成 processor 后,需要在在原本 main 函数的 processor 字典里,加入修改后的 processor 类,即可在运行参数里指定调用该 processor。


 processors = {      "cola": ColaProcessor,      "mnli": MnliProcessor,      "mrpc": MrpcProcessor,      "xnli": XnliProcessor,       "selfsim": SelfProcessor #添加自己的processor  }
复制代码

运行 fine-tune

之后就可以直接运行 run_classsifier.py 进行模型的训练。在运行时需要制定一些参数,一个较为完整的运行参数如下所示:


export BERT_BASE_DIR=/path/to/bert/chinese_L-12_H-768_A-12 #全局变量 下载的预训练bert地址export MY_DATASET=/path/to/xnli #全局变量 数据集所在地址
python run_classifier.py \ --task_name=selfsim \ #自己添加processor在processors字典里的key名 --do_train=true \ --do_eval=true \ --dopredict=true \ --data_dir=$MY_DATASET \ --vocab_file=$BERT_BASE_DIR/vocab.txt \ --bert_config_file=$BERT_BASE_DIR/bert_config.json \ --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ --max_seq_length=128 \ #模型参数 --train_batch_size=32 \ --learning_rate=5e-5 \ --num_train_epochs=2.0 \ --output_dir=/tmp/selfsim_output/ #模型输出路径
复制代码

BERT 源代码里还有什么

在开始训练我们自己 fine-tune 的 BERT 后,我们可以再来看看 BERT 代码里除了 processor 之外的一些部分。


我们可以发现,process 在得到字符串形式的输入后,在 file_based_convert_examples_to_features 里先是对字符串长度,加入[CLS]和[SEP]等一些处理后,将其写入成 TFrecord 的形式。这是为了能在 estimator 里有一个更为高效和简易的读入。


我们还可以发现,在 create_model 的函数里,除了从 modeling.py 获取模型主干输出之外,还有进行 fine-tune 时候的 loss 计算。因此,如果对于 fine-tune 的结构有自定义的要求,可以在这部分对代码进行修改。如进行 NER 任务的时候,可以按照 BERT 论文里的方式,不只读第一位的 logits,而是将每一位 logits 进行读取。


BERT 这次开源的代码,由于是考虑在 google 自己的 TPU 上高效地运行,因此采用的 estimator 是 tf.contrib.tpu.TPUEstimator,虽然 TPU 的 estimator 同样可以在 gpu 和 cpu 上运行,但若想在 gpu 上更高效地做一些提升,可以考虑将其换成 tf.estimator.Estimator,于此同时 model_fn 里一些 tf.contrib.tpu.TPUEstimatorSpec 也需要修改成 tf.estimator.EstimatorSpec 的形式,以及相关调用参数也需要做一些调整。在转换成较普通的 estimator 后便可以使用常用的方式对 estimator 进行处理,如生成用于部署的.pb 文件等。

GitHub Issues 里一些有趣的内容

从 google 对 BERT 进行开源开始,Issues 里的讨论便异常活跃,BERT 论文第一作者 Jacob Devlin 也积极地在 Issues 里进行回应,在交流讨论中,产生了一些很有趣的内容。


在 GitHub Issues#95 中大家讨论了 BERT 模型在今年 AI-Challenger 比赛上的应用。我们也同样尝试了 BERT 在 AI-Challenger 的机器阅读理解(mrc)赛道的表现。如果简单得地将 mrc 的文本连接成一个长字符串的形式,可以在 dev 集上得到 79.1%的准确率。


如果参考 openAI 的 GPT 论文里 multi-choice 的形式对 BERT 的输入输出代码进行修改则可以将准确率提高到 79.3%。采用的参数都是 BERT 默认的参数,而单一模型成绩在赛道的 test a 排名中已经能超过榜单上的第一名。因此,在相关中文的任务中,bert 能有很大的想象空间。


在 GitHub Issues#123 中,@hanxiao 给出了一个采用 ZeroMQ 便捷部署 BERT 的 service,可以直接调用训练好的模型作为应用的接口。同时他将 BERT 改为一个大的 encode 模型,将文本通过 BERT 进行 encode,来实现句子级的 encode。此外,他对比了多 GPU 上的性能,发现 bert 在多 GPU 并行上的出色表现。

总结

总的来说,google 此次开源的 BERT 和其预训练模型是非常有价值的,可探索和改进的内容也很多。相关数据集上已经出现了对 BERT 进行修改后的复合模型,如 squad2.0 上哈工大(HIT)的 AoA + DA + BERT 以及西湖大学(DAMO)的 SLQA + BERT。 在感谢 google 这份付出的同时,我们也可以借此站在巨人的肩膀上,尝试将其运用在自然语言处理领域的方方面面,让人工智能的梦想更近一步。





2018-11-26 14:195984

评论 1 条评论

发布
暂无评论
发现更多内容

代码理解技术应用实践介绍

百度Geek说

数据库 百度 企业号10月PK榜 代码理解

鞍钢集团财务共享平台部长高歌:与用友共建财务共享领先实践

用友BIP

2023全球商业创新大会

聊聊技术之外的面试问题-下

老张

面试 职场成长

Apache Kyuubi & Celeborn,助力 Spark 拥抱云原生

阿里云大数据AI技术

云原生

Java 21新特性-虚拟线程

越长大越悲伤

Java

Java基础面试题【六】线程(2)

派大星

Java 面试题

秒懂算法 | 字符串匹配算法实例分析之潜伏者、最低三元字符串

TiAmo

算法 字符串匹配

可以替代Mac访达的文件管理工具Path Finder

展初云

Mac软件 文件管理工具

为什么说代码注释是程序员必备的技能?

小齐写代码

Databend 开源周报第 114 期

Databend

IDEA工具第一篇:细节使用-习惯设置 | 京东云技术团队

京东科技开发者

Mac windows IDEA 企业号10月PK榜

从内核世界透视 mmap 内存映射的本质(源码实现篇)

bin的技术小屋

Linux 操作系统 内存管理 Linux Kenel mmap内存映射

视频回放编辑软件Mitti最新免激活版

胖墩儿不胖y

Mac软件 音频编辑 音频处理工具

开放原子开源基金会九月新增捐赠人

开放原子开源基金会

模型UV纹理设置工具

3D建模设计

材质 纹理 贴图

GLTF纹理贴图工具让模型更逼真

3D建模设计

材质 纹理 贴图

秒验:可以自定义UI的一键登录服务

MobTech袤博科技

大数据 智能推送

简单好用的Mac清理工具 BuhoCleaner

展初云

Mac软件 清理软件

Tongsuo 8.4.0-pre3 发布!

铜锁开源密码库

开源 算法 安全 同态加密 密码学

建立可观测性宏观认知-从概念到过去10年的实践发展

刘绍

运维 软件工程 可观测性 基础架构

我在前端写Java SpringBoot项目 | 京东云技术团队

京东科技开发者

MySQL Node Nest.js 企业号10月PK榜 Sequelize

全球采购,打造企业韧性供应链

用友BIP

全球采购

聚合电商API接口平台:让数据成为生产力!

Noah

数据api API 安全 电商api接口

国内首档大模型技术直播专栏重磅推出!

飞桨PaddlePaddle

开发者说 文心大模型

什么是 API 以及电子商务网站为何使用它们

Noah

电商 数据api API 安全

直播预约丨《实时湖仓实践五讲》第二讲:实时湖仓功能架构设计与落地实战

袋鼠云数栈

大数据 数据中台 湖仓一体 实时湖仓 直播课程

干货 | BERT fine-tune 终极实践教程_AI&大模型_奇点机智_InfoQ精选文章