QCon北京「鸿蒙专场」火热来袭!即刻报名,与创新同行~ 了解详情
写点什么

基于 TensorFlow 2.0 的长短期记忆网络进行多类文本分类

  • 2019-12-12
  • 本文字数:7860 字

    阅读完需:约 26 分钟

基于 TensorFlow 2.0 的长短期记忆网络进行多类文本分类

文本分类是指将给定文本按照其内容判别到一个或多个预先确定的文本类别中的过程。文本分类是一种典型的有监督的学习过程,根据已经被标记的文本集合,通过学习,得到一个文本特征和文本类别之间的关系模型,然后利用这个关系模型对新文本进行类别判断。文本分类计数用于识别文档主题,并将之归类到预先定义的主题或主题集合中。

需要注意的是,多类文本分类与多标签分类并不同,其中多类分类区别于二分类问题,即在 个类别中互斥地选取一个作为输出;而多标签分类,是在 n 个标签中非互斥地选取 个标签作为输出。本文介绍了如何基于 TensorFlow 2.0 的长短期记忆网络进行多类文本分类,非常实用,希望对读者有所启迪。


对自然语言处理(Natural Language Processing,NLP)领域来说,很多创新之处都是关于如何在词向量中加入上下文。常用的方法之一就是使用递归神经网络(Recurrent Neural Networks,RNN)。下面是递归神经网络的概念:


  • 它们利用顺序信息。

  • 它们具备记忆能力,能够记住到目前为止计算过的内容,也就是说,我最后说的内容将影响我接下来要讲的内容。

  • 递归神经网络是文本和语音分析的理想选择。

  • 最常用的递归神经网络是长短期记忆网络(Long-Short Term Memory,LSTM)。



来源:https://colah.github.io/posts/2015-08-Understanding-LSTMs/


上图是递归神经网络的架构。


  • “A” 是前馈神经网络(Feedforward neural network)的一层。

  • 如果我们只看右边的话,它会递归地遍历每个序列的元素。

  • 如果我们将左边展开,它看起来将会跟右边一模一样。


译注: 前馈神经网络(Feedforward neural network),是最早发明、最简单的人工神经网络类型。在它内部,参数从输入层向输出层单向传播。和递归神经网络不通,它内部不会构成有向环。



来源:https://colah.github.io/posts/2015-08-Understanding-LSTMs


假设我们正在解决新闻文章数据集的文档分类问题。


  • 我们输入每个单词,这些单词以某种方式相互关联。

  • 当我们看到文章中所有的单词时,我们会在文章末尾做出预测。

  • 递归神经网络通过传递上一次输出的输入,能够保留信息,并能够在最后利用所有信息进行预测。



来源:https://colah.github.io/posts/2015-08-Understanding-LSTMs


  • 这对于短句很有效,但当我们处理一篇长文章时,将会有一个长期依赖问题。


因此,我们通常不是用普通的递归神经网络,而是使用长短期记忆网络。长短期记忆网络是一种递归神经网络,可以解决这种长期依赖问题。


译注: 长短期记忆网络(Long Short-Term Memory,LSTM),是一种时间递归神经网络,适合于处理和预测时间序列中间隔和延迟相对较长的重要事件。基于长短期记忆网络的系统可以实现机器翻译、视频分析、文档摘要、语音识别、图像识别、手写识别、控制聊天机器人、合成音乐等任务。



在我们的新闻文章文档分类示例中,有这种多对一的关系。输入是单词序列,而输出是单个类或标签。


现在,我们将使用 TensorFlow 2.0Keras,解决一个使用长短期记忆网络的 BBC 新闻文档分类问题。数据集可以点击此链接来获取。


  • 首先,我们导入库,并确保 TensorFlow 是正确的版本。


import csvimport tensorflow as tfimport numpy as npfrom tensorflow.keras.preprocessing.text import Tokenizerfrom tensorflow.keras.preprocessing.sequence import pad_sequencesfrom nltk.corpus import stopwordsSTOPWORDS = set(stopwords.words('english'))print(tf.__version__)
复制代码



  • 将超参数置于顶部,如下所示,便于进行更改和编辑。

  • 届时,我们将会讲解每个超参数是如何工作的。


vocab_size = 5000embedding_dim = 64max_length = 200trunc_type = 'post'padding_type = 'post'oov_tok = '<OOV>'training_portion = .8
复制代码


  • 定义两个包含文章和标签的列表。同时,我们删除了停用词。


articles = []labels = []with open("bbc-text.csv", 'r') as csvfile:    reader = csv.reader(csvfile, delimiter=',')    next(reader)    for row in reader:        labels.append(row[0])        article = row[1]        for word in STOPWORDS:            token = ' ' + word + ' '            article = article.replace(token, ' ')            article = article.replace(' ', ' ')        articles.append(article)print(len(labels))print(len(articles))
复制代码



数据中有 2225 篇新闻文章,我们将它们分为训练集和验证集,根据我们之前设置的参数,80% 用于训练,20% 用于验证。


train_size = int(len(articles) * training_portion)train_articles = articles[0: train_size]train_labels = labels[0: train_size]validation_articles = articles[train_size:]validation_labels = labels[train_size:]print(train_size)print(len(train_articles))print(len(train_labels))print(len(validation_articles))print(len(validation_labels))
复制代码



词法分析器(Tokenizer)为我们承担了所有繁重的工作。在我们的文章中,它将进行标记化,需要 5000 个最常见的单词。oov_token 是在遇到不可见的单词时放入一个特殊的值。这意味着我们希望 <OOV> 用于不在 word_index 中的单词。fit_on_text 将遍历所有文本,并创建如下词典:


tokenizer = Tokenizer(num_words = vocab_size, oov_token=oov_tok)tokenizer.fit_on_texts(train_articles)word_index = tokenizer.word_indexdict(list(word_index.items())[0:10])
复制代码


译注: 词法分析器(Tokenizer),是计算机科学中将字符串行转换为标记(token)串行的过程。进行词法分析的进程或者函数叫作词法分析器(lexical analyzer,简称 lexer),也叫扫描器(scanner)。词法分析器一般以函数的形式存在,供语法分析器调用。



我们可以看到,“”是我们语料库中最常见的令牌,其次是“said”、“mr”等等。


完成标记化之后,下一步就是将这些标记转换为序列列表。下面是已经转换成序列的训练数据中的第 11 篇文章。


train_sequences = tokenizer.texts_to_sequences(train_articles)print(train_sequences[10])



" 图 1"


当我们为自然语言处理训练神经网络时,我们需要相同大小的序列,这就是我们为什么使用填充的原因。如果你查看一下的话,就会发现,我们的 max_length 是 200,所以我们使用 pad_sequences ,将所有文章的长度都设置为 200。结果,你会看到第一篇文章长度为 426,变成了 200;第二篇是 192,也变成了 200。以此类推。


train_padded = pad_sequences(train_sequences, maxlen=max_length, padding=padding_type, truncating=trunc_type)print(len(train_sequences[0]))print(len(train_padded[0]))print(len(train_sequences[1]))print(len(train_padded[1]))print(len(train_sequences[10]))print(len(train_padded[10]))
复制代码



此外,还有 padding_typetruncating_type, 还有所有的 post,例如,第 11 篇文章的长度是 186,我们需要填充到 200,我们就在结尾处开始填充,也就是说,填充了 14 个 0。


print(train_padded[10])



" 图 2"


对于第一篇文章,它的长度为 426,我们需要将其截断到 200,我们就在结尾处截断。


然后,我们对验证序列执行同样的操作。


validation_sequences = tokenizer.texts_to_sequences(validation_articles)validation_padded = pad_sequences(validation_sequences, maxlen=max_length, padding=padding_type, truncating=trunc_type)print(len(validation_sequences))print(validation_padded.shape)
复制代码



现在,我们来看一下标签。因为我们的标签是文本,因此,我们将它们进行标记。在训练时,标签应该是 numpy 数组。所以,我们要将标签列表转换为 numpy 数组,如下所示:


label_tokenizer = Tokenizer()label_tokenizer.fit_on_texts(labels)training_label_seq = np.array(label_tokenizer.texts_to_sequences(train_labels))validation_label_seq = np.array(label_tokenizer.texts_to_sequences(validation_labels))print(training_label_seq[0])print(training_label_seq[1])print(training_label_seq[2])print(training_label_seq.shape)print(validation_label_seq[0])print(validation_label_seq[1])print(validation_label_seq[2])print(validation_label_seq.shape)
复制代码



在训练深度神经网络之前,我们应该探索一下我们的原始文章和填充后的文章是什么样子的。运行下面的代码,我们浏览第 11 篇文章,可以看到,一些单词变成了“”,因为它们没有进入前 5000。


reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])def decode_article(text):    return ' '.join([reverse_word_index.get(i, '?') for i in text])print(decode_article(train_padded[10]))print('---')print(train_articles[10])
复制代码



“图 3”


现在,是实施长短期记忆网络的时候了。


  • 我们构建了一个 tf.keras.Sequential 模型,从嵌入层开始。嵌入层为每个单词存储一个向量。调用时,它将单词索引序列转换为向量序列。经过训练后,具有相似意义的单词,通常会具有相似的向量。

  • 双向包装器(Bidirectional wrapper)与 LSTM 层一起使用,它通过 LSTM 层向前和向后传播输入,然后连接输出。这有助于长短期记忆网络学习长期依赖关系。然后我们将其拟合到密集神经网络(Dense Neural Network)中进行分类。

  • 我们使用 relu 代替 than 函数,因为这两个函数能够彼此很好地相互替代。

  • 我们添加了 6 个单位和 softmax 激活的密集层(Dense Layer)。当我们有多个输出时,softmax 将输出层转换为概率分布。


model = tf.keras.Sequential([    # Add an Embedding layer expecting input vocab of size 5000, and output embedding dimension of size 64 we set at the top    tf.keras.layers.Embedding(vocab_size, embedding_dim),    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(embedding_dim)),#    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),    # use ReLU in place of tanh function since they are very good alternatives of each other.    tf.keras.layers.Dense(embedding_dim, activation='relu'),    # Add a Dense layer with 6 units and softmax activation.    # When we have multiple outputs, softmax convert outputs layers into a probability distribution.    tf.keras.layers.Dense(6, activation='softmax')])model.summary()
复制代码



“图 4”


在我们的模型摘要中,我们有嵌入,双向包含长短期记忆网络,然后就是两个密集层(Dense layer)。双向的输出为 128,因为它是我们在长短期记忆网络中输入的两倍。我们也可以堆叠 LSTM 层,但我们发现,结果反而更糟。


print(set(labels))



我们总共有 5 个标签,但因为我们没有对标签进行独热编码(One-hot encode),因此,我们不得不使用


sparse_categorical_crossentropy 作为损失函数,它似乎认为 0 也是一个可能的标签,而词法分析器对象是从整数 1 开始标记化,而不是整数 0。结果,尽管从未使用过 0,但最后一个密集层需要标签 0、1、2、3、4、5 的输出。


如果你希望最后一个密集层为 5,那么你就需要从训练和验证标签中减去 1。我决定保持现状。


我决定训练 10 个轮数,正如你将看到的,这是很多轮数。


model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])num_epochs = 10history = model.fit(train_padded, training_label_seq, epochs=num_epochs, validation_data=(validation_padded, validation_label_seq), verbose=2)
复制代码



“图 5”


def plot_graphs(history, string):  plt.plot(history.history[string])  plt.plot(history.history['val_'+string])  plt.xlabel("Epochs")  plt.ylabel(string)  plt.legend([string, 'val_'+string])  plt.show()plot_graphs(history, "accuracy")plot_graphs(history, "loss")
复制代码


图6


我们可能只需 3 到 4 个轮数。在训练结束时,我们可以发现有点过拟合。


在后续文章中,我们将致力于改进这一模型。


你可以在 Github 找到本文的 Jupyter notebook


参考文献:


作者介绍:

Susan Li,是加拿大多伦多的高级数据科学家。她的理想是,每次发表文章,就改变世界。


原文链接:


https://towardsdatascience.com/multi-class-text-classification-with-lstm-using-tensorflow-2-0-d88627c10a35


2019-12-12 08:002795
用户头像
刘燕 InfoQ高级技术编辑

发布了 1112 篇内容, 共 554.8 次阅读, 收获喜欢 1978 次。

关注

评论

发布
暂无评论
发现更多内容

2021年Android网络编程总结篇,retrofit面试

android 面试 移动开发

2021年Java开发突破20k有哪些有效的路径,2021Java面试笔试总结

Java 面试 后端

2021年Java程序员职业规划,华为Java面试题目

Java 面试 后端

2021年Java笔试题总,教你抓住面试的重点

Java 面试 后端

2021年Android开发者常见面试题,涨薪7K

android 面试 移动开发

2021年Android开发者跳槽指南,附超全教程文档

android 面试 移动开发

2021年Java网络编程总结篇,红黑树详细分析(图文详解)

Java 面试 后端

【等保知识】十个等保常见问题解答汇总

行云管家

网络安全 信息安全 等级保护 过等保 数据审计

2021年Java开发突破20k有哪些有效的路径,JVM发生内存溢出的8种原因

Java 面试 后端

2021年Android程序员职业规划,小白勿进

android 面试 移动开发

2021年Java者未来的出路在哪里,Java开发校招面试题

Java 面试 后端

2021年Android笔试题总,详解Android架构进阶面试题

android 面试 移动开发

Github上线仅六天,收获Star超55K+,这套笔记足够你拿下90%以上的Java面试!

Java 架构 面试 后端 计算机

2021年Java者未来的出路在哪里,让人抓狂的Nginx性能调优

Java 面试 后端

IT运维和自动化运维以及运维开发有啥不同?能解释下吗?

行云管家

互联网 运维 IT运维 自动化运维 云运维

2021年Java常见面试题,面试官让我回家等通知

Java 面试 后端

2021年Android社招面试题精选,附答案解析

android 面试 移动开发

2021年Java开发者常见面试题,初级Java面试题及答案

Java 面试 后端

2021年Android开发陷入饱和,又是一年金九银十

android 面试 移动开发

2021年Java开发前景如何,大厂Java面试真题精选

Java 面试 后端

2021年Java面经分享,别再说你不会JVM性能监控和调优了

Java 面试 后端

2021年Android程序员职业规划,阿里P7大牛亲自讲解

android 面试 移动开发

代码检查规则背景及总体介绍

百度开发者中心

最佳实践 代码规则

2021年Java面经分享,程序员必备技能:时间复杂度与空间复杂度的计算

Java 面试 后端

2021年Java面试心得,整理出这份8万字Java性能优化实战解析

Java 面试 后端

数据库排行榜|当 DB-Engines 遇见墨天轮国产数据库排行

墨天轮

MySQL 数据库 oracle TiDB 国产数据库

2021年Android社招面试题,阿里蚂蚁金服五面

android 面试 移动开发

对比会声会影与剪映哪个制作转场效果更专业

懒得勤快

2021年Java工作或更难找,springboot源码解读与原理分析

Java 面试 后端

2021年Java工作或更难找,华为Java面试社招

Java 面试 后端

2021年Java技术下半场在哪,35岁技术人如何转型做管理

Java 面试 后端

基于 TensorFlow 2.0 的长短期记忆网络进行多类文本分类_AI&大模型_Susan Li_InfoQ精选文章