HarmonyOS开发者限时福利来啦!最高10w+现金激励等你拿~ 了解详情
写点什么

基于 TensorFlow 2.0 的长短期记忆网络进行多类文本分类

  • 2019-12-12
  • 本文字数:7860 字

    阅读完需:约 26 分钟

基于 TensorFlow 2.0 的长短期记忆网络进行多类文本分类

文本分类是指将给定文本按照其内容判别到一个或多个预先确定的文本类别中的过程。文本分类是一种典型的有监督的学习过程,根据已经被标记的文本集合,通过学习,得到一个文本特征和文本类别之间的关系模型,然后利用这个关系模型对新文本进行类别判断。文本分类计数用于识别文档主题,并将之归类到预先定义的主题或主题集合中。

需要注意的是,多类文本分类与多标签分类并不同,其中多类分类区别于二分类问题,即在 个类别中互斥地选取一个作为输出;而多标签分类,是在 n 个标签中非互斥地选取 个标签作为输出。本文介绍了如何基于 TensorFlow 2.0 的长短期记忆网络进行多类文本分类,非常实用,希望对读者有所启迪。


对自然语言处理(Natural Language Processing,NLP)领域来说,很多创新之处都是关于如何在词向量中加入上下文。常用的方法之一就是使用递归神经网络(Recurrent Neural Networks,RNN)。下面是递归神经网络的概念:


  • 它们利用顺序信息。

  • 它们具备记忆能力,能够记住到目前为止计算过的内容,也就是说,我最后说的内容将影响我接下来要讲的内容。

  • 递归神经网络是文本和语音分析的理想选择。

  • 最常用的递归神经网络是长短期记忆网络(Long-Short Term Memory,LSTM)。



来源:https://colah.github.io/posts/2015-08-Understanding-LSTMs/


上图是递归神经网络的架构。


  • “A” 是前馈神经网络(Feedforward neural network)的一层。

  • 如果我们只看右边的话,它会递归地遍历每个序列的元素。

  • 如果我们将左边展开,它看起来将会跟右边一模一样。


译注: 前馈神经网络(Feedforward neural network),是最早发明、最简单的人工神经网络类型。在它内部,参数从输入层向输出层单向传播。和递归神经网络不通,它内部不会构成有向环。



来源:https://colah.github.io/posts/2015-08-Understanding-LSTMs


假设我们正在解决新闻文章数据集的文档分类问题。


  • 我们输入每个单词,这些单词以某种方式相互关联。

  • 当我们看到文章中所有的单词时,我们会在文章末尾做出预测。

  • 递归神经网络通过传递上一次输出的输入,能够保留信息,并能够在最后利用所有信息进行预测。



来源:https://colah.github.io/posts/2015-08-Understanding-LSTMs


  • 这对于短句很有效,但当我们处理一篇长文章时,将会有一个长期依赖问题。


因此,我们通常不是用普通的递归神经网络,而是使用长短期记忆网络。长短期记忆网络是一种递归神经网络,可以解决这种长期依赖问题。


译注: 长短期记忆网络(Long Short-Term Memory,LSTM),是一种时间递归神经网络,适合于处理和预测时间序列中间隔和延迟相对较长的重要事件。基于长短期记忆网络的系统可以实现机器翻译、视频分析、文档摘要、语音识别、图像识别、手写识别、控制聊天机器人、合成音乐等任务。



在我们的新闻文章文档分类示例中,有这种多对一的关系。输入是单词序列,而输出是单个类或标签。


现在,我们将使用 TensorFlow 2.0Keras,解决一个使用长短期记忆网络的 BBC 新闻文档分类问题。数据集可以点击此链接来获取。


  • 首先,我们导入库,并确保 TensorFlow 是正确的版本。


import csvimport tensorflow as tfimport numpy as npfrom tensorflow.keras.preprocessing.text import Tokenizerfrom tensorflow.keras.preprocessing.sequence import pad_sequencesfrom nltk.corpus import stopwordsSTOPWORDS = set(stopwords.words('english'))print(tf.__version__)
复制代码



  • 将超参数置于顶部,如下所示,便于进行更改和编辑。

  • 届时,我们将会讲解每个超参数是如何工作的。


vocab_size = 5000embedding_dim = 64max_length = 200trunc_type = 'post'padding_type = 'post'oov_tok = '<OOV>'training_portion = .8
复制代码


  • 定义两个包含文章和标签的列表。同时,我们删除了停用词。


articles = []labels = []with open("bbc-text.csv", 'r') as csvfile:    reader = csv.reader(csvfile, delimiter=',')    next(reader)    for row in reader:        labels.append(row[0])        article = row[1]        for word in STOPWORDS:            token = ' ' + word + ' '            article = article.replace(token, ' ')            article = article.replace(' ', ' ')        articles.append(article)print(len(labels))print(len(articles))
复制代码



数据中有 2225 篇新闻文章,我们将它们分为训练集和验证集,根据我们之前设置的参数,80% 用于训练,20% 用于验证。


train_size = int(len(articles) * training_portion)train_articles = articles[0: train_size]train_labels = labels[0: train_size]validation_articles = articles[train_size:]validation_labels = labels[train_size:]print(train_size)print(len(train_articles))print(len(train_labels))print(len(validation_articles))print(len(validation_labels))
复制代码



词法分析器(Tokenizer)为我们承担了所有繁重的工作。在我们的文章中,它将进行标记化,需要 5000 个最常见的单词。oov_token 是在遇到不可见的单词时放入一个特殊的值。这意味着我们希望 <OOV> 用于不在 word_index 中的单词。fit_on_text 将遍历所有文本,并创建如下词典:


tokenizer = Tokenizer(num_words = vocab_size, oov_token=oov_tok)tokenizer.fit_on_texts(train_articles)word_index = tokenizer.word_indexdict(list(word_index.items())[0:10])
复制代码


译注: 词法分析器(Tokenizer),是计算机科学中将字符串行转换为标记(token)串行的过程。进行词法分析的进程或者函数叫作词法分析器(lexical analyzer,简称 lexer),也叫扫描器(scanner)。词法分析器一般以函数的形式存在,供语法分析器调用。



我们可以看到,“”是我们语料库中最常见的令牌,其次是“said”、“mr”等等。


完成标记化之后,下一步就是将这些标记转换为序列列表。下面是已经转换成序列的训练数据中的第 11 篇文章。


train_sequences = tokenizer.texts_to_sequences(train_articles)print(train_sequences[10])



" 图 1"


当我们为自然语言处理训练神经网络时,我们需要相同大小的序列,这就是我们为什么使用填充的原因。如果你查看一下的话,就会发现,我们的 max_length 是 200,所以我们使用 pad_sequences ,将所有文章的长度都设置为 200。结果,你会看到第一篇文章长度为 426,变成了 200;第二篇是 192,也变成了 200。以此类推。


train_padded = pad_sequences(train_sequences, maxlen=max_length, padding=padding_type, truncating=trunc_type)print(len(train_sequences[0]))print(len(train_padded[0]))print(len(train_sequences[1]))print(len(train_padded[1]))print(len(train_sequences[10]))print(len(train_padded[10]))
复制代码



此外,还有 padding_typetruncating_type, 还有所有的 post,例如,第 11 篇文章的长度是 186,我们需要填充到 200,我们就在结尾处开始填充,也就是说,填充了 14 个 0。


print(train_padded[10])



" 图 2"


对于第一篇文章,它的长度为 426,我们需要将其截断到 200,我们就在结尾处截断。


然后,我们对验证序列执行同样的操作。


validation_sequences = tokenizer.texts_to_sequences(validation_articles)validation_padded = pad_sequences(validation_sequences, maxlen=max_length, padding=padding_type, truncating=trunc_type)print(len(validation_sequences))print(validation_padded.shape)
复制代码



现在,我们来看一下标签。因为我们的标签是文本,因此,我们将它们进行标记。在训练时,标签应该是 numpy 数组。所以,我们要将标签列表转换为 numpy 数组,如下所示:


label_tokenizer = Tokenizer()label_tokenizer.fit_on_texts(labels)training_label_seq = np.array(label_tokenizer.texts_to_sequences(train_labels))validation_label_seq = np.array(label_tokenizer.texts_to_sequences(validation_labels))print(training_label_seq[0])print(training_label_seq[1])print(training_label_seq[2])print(training_label_seq.shape)print(validation_label_seq[0])print(validation_label_seq[1])print(validation_label_seq[2])print(validation_label_seq.shape)
复制代码



在训练深度神经网络之前,我们应该探索一下我们的原始文章和填充后的文章是什么样子的。运行下面的代码,我们浏览第 11 篇文章,可以看到,一些单词变成了“”,因为它们没有进入前 5000。


reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])def decode_article(text):    return ' '.join([reverse_word_index.get(i, '?') for i in text])print(decode_article(train_padded[10]))print('---')print(train_articles[10])
复制代码



“图 3”


现在,是实施长短期记忆网络的时候了。


  • 我们构建了一个 tf.keras.Sequential 模型,从嵌入层开始。嵌入层为每个单词存储一个向量。调用时,它将单词索引序列转换为向量序列。经过训练后,具有相似意义的单词,通常会具有相似的向量。

  • 双向包装器(Bidirectional wrapper)与 LSTM 层一起使用,它通过 LSTM 层向前和向后传播输入,然后连接输出。这有助于长短期记忆网络学习长期依赖关系。然后我们将其拟合到密集神经网络(Dense Neural Network)中进行分类。

  • 我们使用 relu 代替 than 函数,因为这两个函数能够彼此很好地相互替代。

  • 我们添加了 6 个单位和 softmax 激活的密集层(Dense Layer)。当我们有多个输出时,softmax 将输出层转换为概率分布。


model = tf.keras.Sequential([    # Add an Embedding layer expecting input vocab of size 5000, and output embedding dimension of size 64 we set at the top    tf.keras.layers.Embedding(vocab_size, embedding_dim),    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(embedding_dim)),#    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),    # use ReLU in place of tanh function since they are very good alternatives of each other.    tf.keras.layers.Dense(embedding_dim, activation='relu'),    # Add a Dense layer with 6 units and softmax activation.    # When we have multiple outputs, softmax convert outputs layers into a probability distribution.    tf.keras.layers.Dense(6, activation='softmax')])model.summary()
复制代码



“图 4”


在我们的模型摘要中,我们有嵌入,双向包含长短期记忆网络,然后就是两个密集层(Dense layer)。双向的输出为 128,因为它是我们在长短期记忆网络中输入的两倍。我们也可以堆叠 LSTM 层,但我们发现,结果反而更糟。


print(set(labels))



我们总共有 5 个标签,但因为我们没有对标签进行独热编码(One-hot encode),因此,我们不得不使用


sparse_categorical_crossentropy 作为损失函数,它似乎认为 0 也是一个可能的标签,而词法分析器对象是从整数 1 开始标记化,而不是整数 0。结果,尽管从未使用过 0,但最后一个密集层需要标签 0、1、2、3、4、5 的输出。


如果你希望最后一个密集层为 5,那么你就需要从训练和验证标签中减去 1。我决定保持现状。


我决定训练 10 个轮数,正如你将看到的,这是很多轮数。


model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])num_epochs = 10history = model.fit(train_padded, training_label_seq, epochs=num_epochs, validation_data=(validation_padded, validation_label_seq), verbose=2)
复制代码



“图 5”


def plot_graphs(history, string):  plt.plot(history.history[string])  plt.plot(history.history['val_'+string])  plt.xlabel("Epochs")  plt.ylabel(string)  plt.legend([string, 'val_'+string])  plt.show()plot_graphs(history, "accuracy")plot_graphs(history, "loss")
复制代码


图6


我们可能只需 3 到 4 个轮数。在训练结束时,我们可以发现有点过拟合。


在后续文章中,我们将致力于改进这一模型。


你可以在 Github 找到本文的 Jupyter notebook


参考文献:


作者介绍:

Susan Li,是加拿大多伦多的高级数据科学家。她的理想是,每次发表文章,就改变世界。


原文链接:


https://towardsdatascience.com/multi-class-text-classification-with-lstm-using-tensorflow-2-0-d88627c10a35


2019-12-12 08:002725
用户头像
刘燕 InfoQ高级技术编辑

发布了 1112 篇内容, 共 532.7 次阅读, 收获喜欢 1976 次。

关注

评论

发布
暂无评论
发现更多内容

以海洋为主题的元宇宙Aquqnee,为GameFi带来新的标杆

BlockChain先知

【OB实践】意出望外的一次相遇|利楚初探 OceanBase

OceanBase 数据库

oceanbase

2022年最新面试手册,在Github爆火,96人拿下大厂offer

爱好编程进阶

Java 面试 后端开发

Java流程控制语句-分支结构(选择结构)

爱好编程进阶

Java 面试 后端开发

Kubernetes中,微服务自动化发布系统详解

爱好编程进阶

Java 面试 后端开发

Element-UI 要怎么学?官方文档!

爱好编程进阶

Java 面试 后端开发

Flink on Yarn三部曲之二:部署和设置

爱好编程进阶

Java 面试 后端开发

Java应届生如何找到心仪工作?只要你啃透这些大厂必问面试题,Offer拿到手软

爱好编程进阶

Java 面试 后端开发

Java进阶之路:看完这篇Kubernetes的深入分析后,我完全掌握了这门技术

爱好编程进阶

Java 面试 后端开发

Java面试经验

爱好编程进阶

Java 面试 后端开发

java后台开发面试题

爱好编程进阶

Java 面试 后端开发

为什么要对我们的sql进行优化

乌龟哥哥

4月月更

iReport 使用手册(生成 PDF 表单)

爱好编程进阶

Java 面试 后端开发

kubebuilder实战之三:基础知识速览

爱好编程进阶

Java 面试 后端开发

1-4 云商城项目工程搭建

爱好编程进阶

Java 面试 后端开发

InnoDB 和 MyISAM 的数据分布是什么样的?

爱好编程进阶

Java 面试 后端开发

JSP实现医院住院管理系统

爱好编程进阶

Java 面试 后端开发

Java-教你简单玩扑克

爱好编程进阶

Java 面试 后端开发

Java架构师进阶必备24种设计模式学习资源,速速看过来!

爱好编程进阶

Java 面试 后端开发

Bootstrap.yml的作用

Rubble

4月日更 4月月更

ElasticSearch java API - 聚合查询

爱好编程进阶

Java 面试 后端开发

Elasticsearch的安装和基本使用

爱好编程进阶

Java 面试 后端开发

HDU-3038-How Many Answers Are Wrong【 带权并查集 】题解

爱好编程进阶

Java 面试 后端开发

Java 里面的异常

爱好编程进阶

Java 面试 后端开发

Canal 如何实现数据库库事务的一致性

爱好编程进阶

Java 面试 后端开发

ETCD 安全模式

爱好编程进阶

Java 面试 后端开发

IntelliJ Idea 常用快捷键列表

爱好编程进阶

Java 面试 后端开发

java内存溢出问题分析过程

爱好编程进阶

Java 面试 后端开发

java没有那么难,跟着我一起看看java 条件语句

爱好编程进阶

Java 面试 后端开发

亚信科技两方案入围工信部“数字技术融合创新解决方案”评选

亚信AntDB数据库

AntDB #数据库 奖项

GitHub上已获赞百万!阿里架构师最新发布的图解网络协议文档(2021版)开源分享

爱好编程进阶

Java 面试 后端开发

基于 TensorFlow 2.0 的长短期记忆网络进行多类文本分类_AI&大模型_Susan Li_InfoQ精选文章