如何应用TensorFlow实现常用循环神经网络

2018 年 11 月 15 日

如何应用TensorFlow实现常用循环神经网络

前言


循环神经网络(Recurrent Neural Network,RNN)是用于处理序列化数据的神经网络。常用的序列化数据有股票市场价格、音频和视频数据、DNA 序列、传感器数据、自然语言文本等。在循环神经网络中,数据的输出结果与数据的序列关系有关。具体的表现形式为网络会对前面的信息进行记忆,并应用于当前输出结果。例如,预测句子中的下一个单词是什么时,由于句子中前后单词并不是相互独立的,所以需要用到前面的多个序列单词。本文主要讲解 RNN 结构及实现、双向 RNN 结构及实现、长短时记忆(Long Short-Term Memory,LTSM)结构及实现、门控循环单元(Gated Recurrent Unit,GRU)结构及实现。


RNN 结构及实现


循环神经网络结构如图-1 所示。



图1 循环神经网络结构


#循环神经网络搭建的基本思想#设定循环神经网络单元的初始状态,设定神经元的初始值为0state = cell.zero_state(...)#定义循环神经网络的输出outputs = []#构建循环神经网络,inputs为输入的序列数据for input_ in inputs:    #执行循环神经网络单元,返回单元的输出结果和单元状态    output, state = cell(input_, state)    outputs.append(output)#返回循环神经网络的运行结果和单元的最终状态return (outputs, state)
复制代码


循环神经网络结构主要有一对多关系、多对一关系、多对多关系,如图 2 所示。One to many 结构用于图像生成文字场景,根据输入图像内容输出该图像的文字描述。Many to one 结构用来处理序列数据的分类问题,例如,输入一个句子判断其情感倾向。Many to many 结构用来处理输入序列数据中每个数据的分类,例如中文分词,中文实体识别等场景。如果 many to many 结构中不限制输入序列和输出序列的长度,通常用来处理机器翻译、文本摘要生成、阅读理解、语音识别等场景。


图2 输入序列与输出序列的对应关系


在 TensorFlow 中,循环神经网络单元可用 BasicRNNCell 定义实现,如下所示。其中,hidden_units 为隐藏神经元的数量。


#定义循环神经网络单元hidden_units = 20tf.nn.rnn_cell.BasicRNNCell(num_units = hidden_units)
复制代码


TensorFlow 中循环神经网络的实现有两类,一类是静态 RNN(Static RNN),另一类是动态 RNN(Dynamic RNN)。Static RNN 是按照时间序列长度展开后的图,序列长度需要和图的拓扑结构保持一致。这样,每个 batch data 中的序列长度是一致,为最大的序列长度。Dynamic RNN 主要思想为通过循环语句,动态生成网络的拓扑结构。这样,不同 batch data 中序列的最大长度可以是不一样的,但同一个 batch data 内部的序列长度仍然是一样的。在动态 RNN 结构中,不需要让所有 batch data 的序列长度都填充到序列的最大长度,减少了训练数据的存储空间。Dynamic RNN 的 TensorFlow 模型程序用例,如下所示。


import tensorflow as tf
hidden_units = 20 #定义循环神经网络单元中隐藏神经元的数量rnnLayerNum = 1 #定义循环神经网络的层数rnnCells = [] #定义存储多层RNN单元#构建多层神经网络单元multiRnnCellfor i in range(rnnLayerNum): rnnCells.append(tf.nn.rnn_cell.BasicRNNCell(num_units=hidden_units))multiRnnCell = tf.nn.rnn_cell.MultiRNNCell(rnnCells)
timesteps = 5 #定义循环神经网络的序列步长batch_size = 2 #定义batch data大小#定义循环神经网络的输入数据的占位符,#维度信息为[batch_size, timesteps, 1],1表示序列中每个数据的编码大小input = tf.placeholder(tf.float32, [batch_size, timesteps, 1], name='input_x')sequence_length = [2, 5] #定义batch data中每个序列的有效长度
#定义多层循环神经网络的初始化状态initial_state = multiRnnCell.zero_state(batch_size=2, dtype=tf.float32)#定义动态RNN网络outputs, final_state = tf.nn.dynamic_rnn(multiRnnCell, input, sequence_length=sequence_length, initial_state=initial_state, dtype=tf.float32, time_major=False)
复制代码


双向 RNN 结构及实现


双向 RNN(Bi-directional RNN,BRNN)用于解决序列数据中,序列的输出结果不仅与前面的序列数据有关,还于后面的序列数据有关。例如,根据上下文信息预测一个句子中间缺失的单词是什么;根据上下文信息预测一个句子中每个单词的类别。双向 RNN 结构将两个 RNN 上下叠加在一起,如图 3 所示。


图3 双向RNN网络结构


#单层、双向RNN模型的TensorFlow实现import tensorflow as tf
hidden_units = 20 #RNN单元隐藏神经元的数量#定义Forward Layer的单元forwardCell = tf.nn.rnn_cell.BasicRNNCell(num_units=hidden_units)#定义Backward Layer的单元backwardCell = tf.nn.rnn_cell.BasicRNNCell(num_units=hidden_units)timesteps = 5 #序列步长batch_size = 2 #batch data大小#定义Input Layer输入数据的占位符input = tf.placeholder(tf.float32, [batch_size, timesteps, 1], name='input_x')
#生成bidirectional dynamic rnn的网络结构outputs, output_states = tf.nn.bidirectional_dynamic_rnn(inputs=input, cell_fw=forwardCell, cell_bw=backwardCell, dtype=tf.float32, time_major=False)
复制代码


LSTM 结构及实现


在反向传播算法中,BasicRNNCell 神经元的梯度会呈指数倍数的衰减,梯度倾向于消失。这样,BasicRNNCell 很难处理数据长期依赖的问题,很难处理长度超过 10 的序列数据。为了解决梯度消失问题,提出了长短期记忆(Long Short Term Memory,LSTM)单元,通过门的开关实现序列数据上的记忆功能。当误差从输出层反向传播回来时,使用 LSTM 单元记忆下来,LSTM 可以记住较长时间内的序列信息。LSTM 模块架构,如图 4 所示。



图4 LSTM结构


LSTM 通过门的结构选择性地让神经元中的信息通过,实现数据信息的记忆或遗忘功能。门结构由 sigmoid 激活函数实现。sigmoid 激活函数的输出值为[0, 1]间的数字,表示神经元可以通过多少数据信息。sigmoid 输出值为 0 时,表示门完全关闭,神经元数据信息全部不能通过。sigmoid 输出值为 1 时,表示门全部开启,神经元所有信息都可以通过。LSTM RNN 的 TensorFlow 程序用例如下所示。


import tensorflow as tf
hidden_units = 20 #隐藏神经元数量#定义LSTM单元cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_units)
timesteps = 5; #序列长度batch_size = 2 #batch data大小input = tf.placeholder(tf.float32, [batch_size, timesteps, 1], name='input_x')outputs, states = tf.nn.dynamic_rnn(cell=cell, inputs=input, dtype=tf.float32)
复制代码


GRU 结构及实现


门控循环单元(Gated Recurrent Unit,GRU)同样用于解决循环神经网络中梯度消失问题,属于 LSTM 的一种变体结构。LSTM 有三个门控制结构(输入门、遗忘门和输出门),结构比较复杂。GRU 结构中只有两个门,分为更新门 z_t 和重置门 r_t,直接将单元的隐藏状态传递给下个单元,比 LSTM 模型更加简单、参数更少及更容易收敛。重置门主要用于处理序列数据中短期的记忆关系。更新门主要用于处理序列数据中长期的记忆关系。GRU 门控循环单元结构,如图 5 所示。


图5 GRU门控循环单元结构


######基于GRU单元的RNN######import tensorflow as tf
hidden_units = 20 #隐藏神经元数量#定义GRU单元cell = tf.nn.rnn_cell.GRUCell(num_units=hidden_units)
timesteps = 5; #序列长度batch_size = 2 #batch data大小input = tf.placeholder(tf.float32, [batch_size, timesteps, 1], name='input_x')outputs, states = tf.nn.dynamic_rnn(cell=cell, inputs=input, dtype=tf.float32)
复制代码


总结


本文主要讲解处理序列数据的循环神经网络模型,包含 RNN 架构及实现、双向 RNN 架构及实现、LSTM 架构及实现、GRU 架构及实现。RNN 的网络结构主要有多对多关系、多对一关系和一对多关系。RNN 单元用 BasicRNNCell 定义,RNN 可用静态 RNN 接口和动态 RNN 接口实现。双向 RNN 结构将两个 RNN 上下叠加在一起,主要用于解决序列数据中,序列的输出结果与上下文环境信息有关。


LSTM 单元主要用于解决循环神经网络中梯度消失的问题,通过门的开关实现长序列数据上的记忆功能,主要包含输入门、遗忘门和输出门。GRU 属于 LSTM 的一种变体结构,可直接将单元的隐藏状态传递给下个单元,比 LSTM 模型更加简单、更容易收敛。


作者简介:武维(微信:allawnweiwu),博士,现为 IBM 架构师。主要从事深度学习模型、平台的研发工作。


另外本文作者曾在 InfoQ 平台上发表了一系列的 TensorFlow 文章,我们将这些文章集结成书,具体可以翻阅:《深度学习利器:TensorFlow程序设计》。


2018 年 11 月 15 日 15:551055

评论

发布
暂无评论
发现更多内容

每周学习总结 - 架构师培训 6期

Damon

进程、线程基础知识全家桶,30 张图一套带走

小林coding

Linux 操作系统 计算机基础 进程 进程线程区别

编程核心能力之抽象

顿晓

抽象 编程日课

Cache解决算法 Charles断点调试breakpoint John 易筋 ARTS 打卡 Week 08

John(易筋)

ARTS 打卡计划

MySQL实战45讲总结

`

MySQL

SpringBoot 入门:03 - 统一请求返回

封不羁

Java spring springboot

vue项目发布时去除console语句

网站,小程序,APP开发定制

架构师训练营第六周学习总结

CATTY

昆明市成立两大“高端”中心,区块链赋能生物医药和高原特色农业

CECBC区块链专委会

数据驱动 vs 关键字驱动:对UI自动化测试框架搭建的探索

Winfield

DevOps 敏捷 自动化测试

极客时间 - 架构师培训 - 6 期作业

Damon

redis系列之——高可用(主从、哨兵、集群)

诸葛小猿

redis redis集群 redis哨兵 redis主从

简述CAP理论

lei Shi

手把手整合SSM框架

JavaPub

每周学习总结 - 架构师培训 5 期

Damon

【计算机网络】为什么要三次握手四次挥手?

烫烫烫个喵啊

TCP 计算机网络

ARTS WEEK5

紫枫

ARTS 打卡计划

负载均衡方式

羽球

负载均衡

程序的机器级表示-程序的编码

引花眠

计算机基础

架构师课程第六周 作业

杉松壁

观智能化浪潮如何改变产业链创新

CECBC区块链专委会

ARTS-WEEK6

一周思进

ARTS 打卡计划

Go:Stringer命令,通过代码生成提高效率

陈思敏捷

go golang stringer

ARTS打卡 第7周

引花眠

ARTS 打卡计划

低代码与无代码

lidaobing

低代码 无代码开发

智慧4S店解决方案发布,看英特尔如何引领汽车销售行业变革

飞天鱼2017

一致性hash算法及标准差验证

Damon

【计算机网络】如何实现可靠数据传输?

烫烫烫个喵啊

ARTS打卡-06

Geek_yansheng25

ARTS打卡 - Week 07

teoking

分布式系统设计理念这么难学?

架构师修行之路

架构 分布式

如何应用TensorFlow实现常用循环神经网络
-InfoQ