写点什么

运用计算图搭建递归神经网络(RNN)

  • 2019-09-17
  • 本文字数:5544 字

    阅读完需:约 18 分钟

运用计算图搭建递归神经网络(RNN)

继续玩我们的计算图框架。这一次我们运用计算图搭建递归神经网络(RNN,Recursive Neural Network)。RNN 处理前后有承接关系的序列状数据,例如时序数据。当然,前后的承接也不一定是时间上的,但总之是有前后关系的序列。

RNN

RNN 的思想是:网络也分步,每步以输入序列的该步数据(向量)和上一步数据(第一步没有)为输入,进行变换,得到这一步的输出(向量)。这样的话,序列的每一步就会对下一步产生影响。RNN 用变换的参数把握序列每一步之间的关系。最后一步的输出可以送给全连接层,最终用于分类或回归。RNN 有很多种,有一些复杂的变体,本文搭建一种最简单的 RNN ,它的结构是这样的:



蓝色长条表示 m 维输入向量,一共 n 个。这表示数据是长度为 n 的序列,每一步是一个 m 维向量。绿色的矩形就是每一步的变换。yi 是每一步的 k 维输出向量。每一步用 k x k 的权值矩阵 Y 去乘前一步的输出向量(第一步没有),用 k x m 的权值矩阵 W 去乘这一步的输入向量,加和后再加上 k 维偏置向量 b ,施加激活函数 ϕ (我们取 ReLU),就得到这一步的输出。


最后一步的输出也是 k 维向量,把它送给全连接层,最后施加 SoftMax 后得到各个类别的概率,再接上一个交叉熵损失就可以用来训练分类问题了。用我们的计算图框架可以这样搭建这个简单的 RNN(代码):


seq_len = 96  # 序列长度dimension = 16  # 序列每一步的向量维度hidden_dim = 12  # RNN 时间单元的输出维度
# 时间序列变量,每一步一个 dimension 维向量(Variable 节点),保存在数组 input 中input_vectors = []for i in range(seq_len): input_vectors.append(Variable(dim=(dimension, 1), init=False, trainable=False)) # 对于本步输入的权值矩阵W = Variable(dim=(hidden_dim, dimension), init=True, trainable=True)
# 对于上步输入的权值矩阵Y = Variable(dim=(hidden_dim, hidden_dim), init=True, trainable=True)
# 偏置向量b = Variable(dim=(hidden_dim, 1), init=True, trainable=True)
# 构造 RNNlast_step = None # 上一步的输出,第一步没有上一步,先将其置为 Nonefor iv in input_vectors: y = Add(MatMul(W, iv), b)
if last_step is not None: y = Add(MatMul(Y, last_step), y)
y = ReLU(y)
last_step = y

fc1 = fc(y, hidden_dim, 6, "ReLU") # 第一全连接层fc2 = fc(fc1, 6, 2, "None") # 第二全连接层
# 分类概率prob = SoftMax(fc2)
# 训练标签label = Variable((2, 1), trainable=False)
# 交叉熵损失loss = CrossEntropyWithSoftMax(fc2, label)
复制代码


这就是构造 RNN 以及交叉熵损失的计算图的代码,很简单,right ?有了计算图以及自动求导,我们只管搭建网络即可,网络的训练就交给计算图去做了。否则你可以想象,按照示意图表示的计算,推导交叉熵损失对 RNN 的各个权值矩阵和偏置的梯度是多么困难。

时间序列问题

我们构造一份数据,它包含两类时间序列,一类是方波,一类是正弦波,代码如下:


def get_sequence_data(number_of_classes=2, dimension=10, length=10, number_of_examples=1000, train_set_ratio=0.7, seed=42):    """    生成两类序列数据。    """    xx = []    xx.append(np.sin(np.arange(0, 10, 10 / length)))  # 正弦波    xx.append(np.array(signal.square(np.arange(0, 10, 10 / length))))  # 方波

data = [] for i in range(number_of_classes): x = xx[i] for j in range(number_of_examples): sequence = x + np.random.normal(0, 1.0, (dimension, len(x))) # 加入高斯噪声 label = np.array([int(i == j) for j in range(number_of_classes)])
data.append(np.c_[sequence.reshape(1, -1), label.reshape(1, -1)])
# 把各个类别的样本合在一起 data = np.concatenate(data, axis=0)
# 随机打乱样本顺序 np.random.shuffle(data)
# 计算训练样本数量 train_set_size = int(number_of_examples * train_set_ratio) # 训练集样本数量
# 将训练集和测试集、特征和标签分开 return (data[:train_set_size, :-number_of_classes], data[:train_set_size, -number_of_classes:], data[train_set_size:, :-number_of_classes], data[train_set_size:, -number_of_classes:])
复制代码


我们用这一行代码获取长度为 96 ,维度为 16 的两类(各 1000 个)序列:


# 获取两类时间序列:正弦波和方波train_x, train_y, test_x, test_y = get_sequence_data(length=seq_len, dimension=dimension)
复制代码


看一看时间序列样本,先看正弦波:



正弦波序列


这是一个正弦波时间序列样本,它包含 16 条曲线,每一条都是 sin 曲线加噪声。之所以包含 16 条曲线,因为我们的时间序列的每一步是一个 16 维向量,按时间列起来就有了 16 条正弦曲线。正弦波时间序列是我们的正样本。方波时间序列是负样本:



方波序列


一个方波时间序列先维持 +1 一段时间,变为 -1 维持一段时间,再回到 +1 ,循环往复。由于我们的高斯噪声加得较大,可以看到正弦波和方波还是有可能混淆的,但也能看出它们之间的差异。

训练

现在就用我们构造的 RNN 训练一个分类模型,分类正弦波和方波,代码如下:


from sklearn.metrics import accuracy_score
from layer import *from node import *from optimizer import *
seq_len = 96 # 序列长度dimension = 16 # 序列每一步的向量维度hidden_dim = 12 # RNN 时间单元的输出维度
# 获取两类时间序列:正弦波和方波train_x, train_y, test_x, test_y = get_sequence_data(length=seq_len, dimension=dimension)
# 时间序列变量,每一步一个 dimension 维向量(Variable 节点),保存在数组 input 中input_vectors = []for i in range(seq_len): input_vectors.append(Variable(dim=(dimension, 1), init=False, trainable=False)) # 对于本步输入的权值矩阵W = Variable(dim=(hidden_dim, dimension), init=True, trainable=True)
# 对于上步输入的权值矩阵Y = Variable(dim=(hidden_dim, hidden_dim), init=True, trainable=True)
# 偏置向量b = Variable(dim=(hidden_dim, 1), init=True, trainable=True)
# 构造 RNNlast_step = None # 上一步的输出,第一步没有上一步,先将其置为 Nonefor iv in input_vectors: y = Add(MatMul(W, iv), b)
if last_step is not None: y = Add(MatMul(Y, last_step), y)
y = ReLU(y)
last_step = y

fc1 = fc(y, hidden_dim, 6, "ReLU") # 第一全连接层fc2 = fc(fc1, 6, 2, "None") # 第二全连接层
# 分类概率prob = SoftMax(fc2)
# 训练标签label = Variable((2, 1), trainable=False)
# 交叉熵损失loss = CrossEntropyWithSoftMax(fc2, label)
# Adam 优化器optimizer = Adam(default_graph, loss, 0.005, batch_size=16)
# 训练print("start training", flush=True)for e in range(10):
for i in range(len(train_x)): x = np.mat(train_x[i, :]).reshape(dimension, seq_len) for j in range(seq_len): input_vectors[j].set_value(x[:, j]) label.set_value(np.mat(train_y[i, :]).T)
# 执行一步优化 optimizer.one_step()
if i > 1 and (i + 1) % 100 == 0:
# 在测试集上评估模型正确率 probs = [] losses = [] for j in range(len(test_x)): # x = test_x[j, :].reshape(dimension, seq_len) x = np.mat(test_x[j, :]).reshape(dimension, seq_len) for k in range(seq_len): input_vectors[k].set_value(x[:, k]) label.set_value(np.mat(test_y[j, :]).T)
# 前向传播计算概率 prob.forward() probs.append(prob.value.A1)
# 计算损失值 loss.forward() losses.append(loss.value[0, 0])
# print("test instance: {:d}".format(j))
# 取概率最大的类别为预测类别 pred = np.argmax(np.array(probs), axis=1) truth = np.argmax(test_y, axis=1) accuracy = accuracy_score(truth, pred)
default_graph.draw() print("epoch: {:d}, iter: {:d}, loss: {:.3f}, accuracy: {:.2f}%".format(e + 1, i + 1, np.mean(losses), accuracy * 100), flush=True)
复制代码


训练 10 个 epoch 后,测试集上的正确率达到了 99% :


epoch: 1, iter: 100, loss: 0.693, accuracy: 51.08%epoch: 1, iter: 200, loss: 0.692, accuracy: 51.08%epoch: 1, iter: 300, loss: 0.677, accuracy: 78.31%epoch: 1, iter: 400, loss: 0.573, accuracy: 49.31%epoch: 1, iter: 500, loss: 0.520, accuracy: 53.92%epoch: 1, iter: 600, loss: 0.599, accuracy: 97.08%epoch: 1, iter: 700, loss: 0.617, accuracy: 99.00%epoch: 2, iter: 100, loss: 0.601, accuracy: 94.46%epoch: 2, iter: 200, loss: 0.579, accuracy: 82.08%epoch: 2, iter: 300, loss: 0.558, accuracy: 76.15%epoch: 2, iter: 400, loss: 0.531, accuracy: 67.85%epoch: 2, iter: 500, loss: 0.507, accuracy: 63.77%epoch: 2, iter: 600, loss: 0.493, accuracy: 61.15%epoch: 2, iter: 700, loss: 0.479, accuracy: 62.23%epoch: 3, iter: 100, loss: 0.443, accuracy: 69.92%epoch: 3, iter: 200, loss: 0.393, accuracy: 85.85%epoch: 3, iter: 300, loss: 0.365, accuracy: 97.69%epoch: 3, iter: 400, loss: 0.284, accuracy: 95.08%epoch: 3, iter: 500, loss: 0.199, accuracy: 95.69%epoch: 3, iter: 600, loss: 0.490, accuracy: 80.62%epoch: 3, iter: 700, loss: 0.264, accuracy: 94.31%epoch: 4, iter: 100, loss: 0.320, accuracy: 83.46%epoch: 4, iter: 200, loss: 0.333, accuracy: 80.92%epoch: 4, iter: 300, loss: 0.276, accuracy: 90.15%epoch: 4, iter: 400, loss: 0.242, accuracy: 95.00%epoch: 4, iter: 500, loss: 0.217, accuracy: 96.38%epoch: 4, iter: 600, loss: 0.191, accuracy: 95.31%epoch: 4, iter: 700, loss: 0.167, accuracy: 94.00%epoch: 5, iter: 100, loss: 0.142, accuracy: 94.62%epoch: 5, iter: 200, loss: 0.111, accuracy: 96.85%epoch: 5, iter: 300, loss: 0.116, accuracy: 96.85%epoch: 5, iter: 400, loss: 0.080, accuracy: 96.77%epoch: 5, iter: 500, loss: 0.059, accuracy: 98.54%epoch: 5, iter: 600, loss: 0.054, accuracy: 98.54%epoch: 5, iter: 700, loss: 0.042, accuracy: 99.00%epoch: 6, iter: 100, loss: 0.047, accuracy: 98.46%epoch: 6, iter: 200, loss: 0.049, accuracy: 98.08%epoch: 6, iter: 300, loss: 0.030, accuracy: 99.15%epoch: 6, iter: 400, loss: 0.029, accuracy: 99.23%epoch: 6, iter: 500, loss: 0.028, accuracy: 99.08%epoch: 6, iter: 600, loss: 0.029, accuracy: 99.08%epoch: 6, iter: 700, loss: 0.024, accuracy: 99.15%epoch: 7, iter: 100, loss: 0.023, accuracy: 99.15%epoch: 7, iter: 200, loss: 0.031, accuracy: 98.85%epoch: 7, iter: 300, loss: 0.023, accuracy: 99.46%epoch: 7, iter: 400, loss: 0.022, accuracy: 99.54%epoch: 7, iter: 500, loss: 0.022, accuracy: 99.38%epoch: 7, iter: 600, loss: 0.027, accuracy: 98.77%epoch: 7, iter: 700, loss: 0.019, accuracy: 99.46%epoch: 8, iter: 100, loss: 0.018, accuracy: 99.54%epoch: 8, iter: 200, loss: 0.018, accuracy: 99.46%epoch: 8, iter: 300, loss: 0.018, accuracy: 99.54%epoch: 8, iter: 400, loss: 0.018, accuracy: 99.62%epoch: 8, iter: 500, loss: 0.017, accuracy: 99.54%epoch: 8, iter: 600, loss: 0.026, accuracy: 99.00%epoch: 8, iter: 700, loss: 0.021, accuracy: 99.23%epoch: 9, iter: 100, loss: 0.017, accuracy: 99.62%epoch: 9, iter: 200, loss: 0.016, accuracy: 99.54%epoch: 9, iter: 300, loss: 0.015, accuracy: 99.54%epoch: 9, iter: 400, loss: 0.014, accuracy: 99.69%epoch: 9, iter: 500, loss: 0.014, accuracy: 99.62%epoch: 9, iter: 600, loss: 0.014, accuracy: 99.69%epoch: 9, iter: 700, loss: 0.014, accuracy: 99.62%epoch: 10, iter: 100, loss: 0.014, accuracy: 99.54%epoch: 10, iter: 200, loss: 0.014, accuracy: 99.54%epoch: 10, iter: 300, loss: 0.015, accuracy: 99.69%epoch: 10, iter: 400, loss: 0.014, accuracy: 99.69%epoch: 10, iter: 500, loss: 0.013, accuracy: 99.62%epoch: 10, iter: 600, loss: 0.016, accuracy: 99.38%epoch: 10, iter: 700, loss: 0.017, accuracy: 99.38%
复制代码


这就是我们的简单 RNN ,以后有机会我们再尝试搭建类似 LSTM 这种更复杂的 RNN 。


作者介绍


张觉非,本科毕业于复旦大学,硕士毕业于中国科学院大学,先后任职于新浪微博、阿里,目前就职于奇虎 360,任机器学习技术专家。


本文来自 DataFun 社区


原文链接


https://mp.weixin.qq.com/s?__biz=MzU1NTMyOTI4Mw==&mid=2247493606&idx=1&sn=bf89adb739302688e6b837084bff911a&chksm=fbd7558acca0dc9c6a9754975ee796239b5fa2c26a38f604c56d19a5189f8a0febd75698ddd7&scene=27#wechat_redirect


2019-09-17 08:001168

评论

发布
暂无评论
发现更多内容

企业实施知识管理建设的7条建议

小炮

【网易云商】TypeScript 进阶指南,突破基本类型

网易智企

typescript

GraphQL初探

RingCentral铃盛

JavaScript graphql

域名被劫持应该如何处理

源字节1号

软件开发

华为AppCube通过中国信通院“低代码开发平台通用能力要求”评估!

华为云开发者联盟

低代码 华为云 AppCube

详解SQL操作的窗口函数

华为云开发者联盟

sql 窗口函数 AP场景

低碳数据中心建设思路及未来趋势

H3C-Navigator

喜讯!「凡泰极客」中标「廊坊银行」小程序平台应用建设项目

FinClip

小程序 finclip 廊坊银行

敏捷领导力(CAL E+T+O)认证在线培训 | 2022年8月18-20日

ShineScrum

敏捷 敏捷领导力 CAL 世界级敏捷领导力大师

F5 NGINX 核心人员倾力打造,搞懂 NGINX 这一本就够了

图灵教育

nginx 程序员 服务器 计算机

一键式打造DAO,M-DAO或成Web3新宠儿

西柚子

元宇宙用户已准备就绪,但技术瓶颈仍制约其真正“落地”

CECBC

Flink 1.15 新功能架构解析:高效稳定的通用增量 Checkpoint

Apache Flink

大数据 flink 编程 流计算 实时计算

【等保】等保测评中双因素认证是什么意思?等于双因子认证吗?

行云管家

网络安全 等保 双因子认证 等级保护

AI简报-增强版GAN图像超分:ESRGAN

AIWeker

人工智能 深度学习 5月月更 AI简报

加码布局版式文档垂直赛道,福昕船舶图纸管理系统重磅发布

联营汇聚

回顾|Flink CDC Meetup(附 PPT 下载)

Apache Flink

大数据 flink 编程 流计算 实时计算

深度学习六十年简史

OneFlow

人工智能 机器学习 深度学习

揭秘亚马逊云科技软件开发工程师团队

亚马逊云科技 (Amazon Web Services)

软件开发 工程师

幸运哈希defi游戏系统开发方案(防作弊)

开发微hkkf5566

锅圈如何利用 Zadig 从容落地运维容器化建设

Zadig

DevOps 云原生 CI/CD 持续交付

二、KVM架构概述

穿过生命散发芬芳

kvm 5月月更

【云堡垒机】云堡垒机很贵吗?怎么收费?

行云管家

网络安全 数据安全 堡垒机 云堡垒机

【技术干货】代码示例:使用 Apache Flink 连接 TDengine

TDengine

数据库 tdengine

web前端培训复盘30+技术点(满满干货,建议收藏)

@零度

前端开发

监控系统报警级别设定

焦振清

监控系统 报警级别

数据库治理的云原生之道 —— Database Mesh 2.0

SphereEx

Apache 数据库 开源 ShardingSphere SphereEx

阿里云移动研发平台EMAS:4月产品动态更新

移动研发平台EMAS

阿里云 用户增长 研发工具 移动测试 移动推送

4种Springboot RestTemplate 服务里发送HTTP请求用法

华为云开发者联盟

Java Rest HTTP

带你学习MindSpore中算子使用方法

华为云开发者联盟

模型 mindspore 算子

大数据培训用SQL来实现用户行为漏斗分析

@零度

大数据开发

运用计算图搭建递归神经网络(RNN)_文化 & 方法_DataFunTalk_InfoQ精选文章