HarmonyOS开发者限时福利来啦!最高10w+现金激励等你拿~ 了解详情
写点什么

运用计算图搭建递归神经网络(RNN)

  • 2019-09-17
  • 本文字数:5544 字

    阅读完需:约 18 分钟

运用计算图搭建递归神经网络(RNN)

继续玩我们的计算图框架。这一次我们运用计算图搭建递归神经网络(RNN,Recursive Neural Network)。RNN 处理前后有承接关系的序列状数据,例如时序数据。当然,前后的承接也不一定是时间上的,但总之是有前后关系的序列。

RNN

RNN 的思想是:网络也分步,每步以输入序列的该步数据(向量)和上一步数据(第一步没有)为输入,进行变换,得到这一步的输出(向量)。这样的话,序列的每一步就会对下一步产生影响。RNN 用变换的参数把握序列每一步之间的关系。最后一步的输出可以送给全连接层,最终用于分类或回归。RNN 有很多种,有一些复杂的变体,本文搭建一种最简单的 RNN ,它的结构是这样的:



蓝色长条表示 m 维输入向量,一共 n 个。这表示数据是长度为 n 的序列,每一步是一个 m 维向量。绿色的矩形就是每一步的变换。yi 是每一步的 k 维输出向量。每一步用 k x k 的权值矩阵 Y 去乘前一步的输出向量(第一步没有),用 k x m 的权值矩阵 W 去乘这一步的输入向量,加和后再加上 k 维偏置向量 b ,施加激活函数 ϕ (我们取 ReLU),就得到这一步的输出。


最后一步的输出也是 k 维向量,把它送给全连接层,最后施加 SoftMax 后得到各个类别的概率,再接上一个交叉熵损失就可以用来训练分类问题了。用我们的计算图框架可以这样搭建这个简单的 RNN(代码):


seq_len = 96  # 序列长度dimension = 16  # 序列每一步的向量维度hidden_dim = 12  # RNN 时间单元的输出维度
# 时间序列变量,每一步一个 dimension 维向量(Variable 节点),保存在数组 input 中input_vectors = []for i in range(seq_len): input_vectors.append(Variable(dim=(dimension, 1), init=False, trainable=False)) # 对于本步输入的权值矩阵W = Variable(dim=(hidden_dim, dimension), init=True, trainable=True)
# 对于上步输入的权值矩阵Y = Variable(dim=(hidden_dim, hidden_dim), init=True, trainable=True)
# 偏置向量b = Variable(dim=(hidden_dim, 1), init=True, trainable=True)
# 构造 RNNlast_step = None # 上一步的输出,第一步没有上一步,先将其置为 Nonefor iv in input_vectors: y = Add(MatMul(W, iv), b)
if last_step is not None: y = Add(MatMul(Y, last_step), y)
y = ReLU(y)
last_step = y

fc1 = fc(y, hidden_dim, 6, "ReLU") # 第一全连接层fc2 = fc(fc1, 6, 2, "None") # 第二全连接层
# 分类概率prob = SoftMax(fc2)
# 训练标签label = Variable((2, 1), trainable=False)
# 交叉熵损失loss = CrossEntropyWithSoftMax(fc2, label)
复制代码


这就是构造 RNN 以及交叉熵损失的计算图的代码,很简单,right ?有了计算图以及自动求导,我们只管搭建网络即可,网络的训练就交给计算图去做了。否则你可以想象,按照示意图表示的计算,推导交叉熵损失对 RNN 的各个权值矩阵和偏置的梯度是多么困难。

时间序列问题

我们构造一份数据,它包含两类时间序列,一类是方波,一类是正弦波,代码如下:


def get_sequence_data(number_of_classes=2, dimension=10, length=10, number_of_examples=1000, train_set_ratio=0.7, seed=42):    """    生成两类序列数据。    """    xx = []    xx.append(np.sin(np.arange(0, 10, 10 / length)))  # 正弦波    xx.append(np.array(signal.square(np.arange(0, 10, 10 / length))))  # 方波

data = [] for i in range(number_of_classes): x = xx[i] for j in range(number_of_examples): sequence = x + np.random.normal(0, 1.0, (dimension, len(x))) # 加入高斯噪声 label = np.array([int(i == j) for j in range(number_of_classes)])
data.append(np.c_[sequence.reshape(1, -1), label.reshape(1, -1)])
# 把各个类别的样本合在一起 data = np.concatenate(data, axis=0)
# 随机打乱样本顺序 np.random.shuffle(data)
# 计算训练样本数量 train_set_size = int(number_of_examples * train_set_ratio) # 训练集样本数量
# 将训练集和测试集、特征和标签分开 return (data[:train_set_size, :-number_of_classes], data[:train_set_size, -number_of_classes:], data[train_set_size:, :-number_of_classes], data[train_set_size:, -number_of_classes:])
复制代码


我们用这一行代码获取长度为 96 ,维度为 16 的两类(各 1000 个)序列:


# 获取两类时间序列:正弦波和方波train_x, train_y, test_x, test_y = get_sequence_data(length=seq_len, dimension=dimension)
复制代码


看一看时间序列样本,先看正弦波:



正弦波序列


这是一个正弦波时间序列样本,它包含 16 条曲线,每一条都是 sin 曲线加噪声。之所以包含 16 条曲线,因为我们的时间序列的每一步是一个 16 维向量,按时间列起来就有了 16 条正弦曲线。正弦波时间序列是我们的正样本。方波时间序列是负样本:



方波序列


一个方波时间序列先维持 +1 一段时间,变为 -1 维持一段时间,再回到 +1 ,循环往复。由于我们的高斯噪声加得较大,可以看到正弦波和方波还是有可能混淆的,但也能看出它们之间的差异。

训练

现在就用我们构造的 RNN 训练一个分类模型,分类正弦波和方波,代码如下:


from sklearn.metrics import accuracy_score
from layer import *from node import *from optimizer import *
seq_len = 96 # 序列长度dimension = 16 # 序列每一步的向量维度hidden_dim = 12 # RNN 时间单元的输出维度
# 获取两类时间序列:正弦波和方波train_x, train_y, test_x, test_y = get_sequence_data(length=seq_len, dimension=dimension)
# 时间序列变量,每一步一个 dimension 维向量(Variable 节点),保存在数组 input 中input_vectors = []for i in range(seq_len): input_vectors.append(Variable(dim=(dimension, 1), init=False, trainable=False)) # 对于本步输入的权值矩阵W = Variable(dim=(hidden_dim, dimension), init=True, trainable=True)
# 对于上步输入的权值矩阵Y = Variable(dim=(hidden_dim, hidden_dim), init=True, trainable=True)
# 偏置向量b = Variable(dim=(hidden_dim, 1), init=True, trainable=True)
# 构造 RNNlast_step = None # 上一步的输出,第一步没有上一步,先将其置为 Nonefor iv in input_vectors: y = Add(MatMul(W, iv), b)
if last_step is not None: y = Add(MatMul(Y, last_step), y)
y = ReLU(y)
last_step = y

fc1 = fc(y, hidden_dim, 6, "ReLU") # 第一全连接层fc2 = fc(fc1, 6, 2, "None") # 第二全连接层
# 分类概率prob = SoftMax(fc2)
# 训练标签label = Variable((2, 1), trainable=False)
# 交叉熵损失loss = CrossEntropyWithSoftMax(fc2, label)
# Adam 优化器optimizer = Adam(default_graph, loss, 0.005, batch_size=16)
# 训练print("start training", flush=True)for e in range(10):
for i in range(len(train_x)): x = np.mat(train_x[i, :]).reshape(dimension, seq_len) for j in range(seq_len): input_vectors[j].set_value(x[:, j]) label.set_value(np.mat(train_y[i, :]).T)
# 执行一步优化 optimizer.one_step()
if i > 1 and (i + 1) % 100 == 0:
# 在测试集上评估模型正确率 probs = [] losses = [] for j in range(len(test_x)): # x = test_x[j, :].reshape(dimension, seq_len) x = np.mat(test_x[j, :]).reshape(dimension, seq_len) for k in range(seq_len): input_vectors[k].set_value(x[:, k]) label.set_value(np.mat(test_y[j, :]).T)
# 前向传播计算概率 prob.forward() probs.append(prob.value.A1)
# 计算损失值 loss.forward() losses.append(loss.value[0, 0])
# print("test instance: {:d}".format(j))
# 取概率最大的类别为预测类别 pred = np.argmax(np.array(probs), axis=1) truth = np.argmax(test_y, axis=1) accuracy = accuracy_score(truth, pred)
default_graph.draw() print("epoch: {:d}, iter: {:d}, loss: {:.3f}, accuracy: {:.2f}%".format(e + 1, i + 1, np.mean(losses), accuracy * 100), flush=True)
复制代码


训练 10 个 epoch 后,测试集上的正确率达到了 99% :


epoch: 1, iter: 100, loss: 0.693, accuracy: 51.08%epoch: 1, iter: 200, loss: 0.692, accuracy: 51.08%epoch: 1, iter: 300, loss: 0.677, accuracy: 78.31%epoch: 1, iter: 400, loss: 0.573, accuracy: 49.31%epoch: 1, iter: 500, loss: 0.520, accuracy: 53.92%epoch: 1, iter: 600, loss: 0.599, accuracy: 97.08%epoch: 1, iter: 700, loss: 0.617, accuracy: 99.00%epoch: 2, iter: 100, loss: 0.601, accuracy: 94.46%epoch: 2, iter: 200, loss: 0.579, accuracy: 82.08%epoch: 2, iter: 300, loss: 0.558, accuracy: 76.15%epoch: 2, iter: 400, loss: 0.531, accuracy: 67.85%epoch: 2, iter: 500, loss: 0.507, accuracy: 63.77%epoch: 2, iter: 600, loss: 0.493, accuracy: 61.15%epoch: 2, iter: 700, loss: 0.479, accuracy: 62.23%epoch: 3, iter: 100, loss: 0.443, accuracy: 69.92%epoch: 3, iter: 200, loss: 0.393, accuracy: 85.85%epoch: 3, iter: 300, loss: 0.365, accuracy: 97.69%epoch: 3, iter: 400, loss: 0.284, accuracy: 95.08%epoch: 3, iter: 500, loss: 0.199, accuracy: 95.69%epoch: 3, iter: 600, loss: 0.490, accuracy: 80.62%epoch: 3, iter: 700, loss: 0.264, accuracy: 94.31%epoch: 4, iter: 100, loss: 0.320, accuracy: 83.46%epoch: 4, iter: 200, loss: 0.333, accuracy: 80.92%epoch: 4, iter: 300, loss: 0.276, accuracy: 90.15%epoch: 4, iter: 400, loss: 0.242, accuracy: 95.00%epoch: 4, iter: 500, loss: 0.217, accuracy: 96.38%epoch: 4, iter: 600, loss: 0.191, accuracy: 95.31%epoch: 4, iter: 700, loss: 0.167, accuracy: 94.00%epoch: 5, iter: 100, loss: 0.142, accuracy: 94.62%epoch: 5, iter: 200, loss: 0.111, accuracy: 96.85%epoch: 5, iter: 300, loss: 0.116, accuracy: 96.85%epoch: 5, iter: 400, loss: 0.080, accuracy: 96.77%epoch: 5, iter: 500, loss: 0.059, accuracy: 98.54%epoch: 5, iter: 600, loss: 0.054, accuracy: 98.54%epoch: 5, iter: 700, loss: 0.042, accuracy: 99.00%epoch: 6, iter: 100, loss: 0.047, accuracy: 98.46%epoch: 6, iter: 200, loss: 0.049, accuracy: 98.08%epoch: 6, iter: 300, loss: 0.030, accuracy: 99.15%epoch: 6, iter: 400, loss: 0.029, accuracy: 99.23%epoch: 6, iter: 500, loss: 0.028, accuracy: 99.08%epoch: 6, iter: 600, loss: 0.029, accuracy: 99.08%epoch: 6, iter: 700, loss: 0.024, accuracy: 99.15%epoch: 7, iter: 100, loss: 0.023, accuracy: 99.15%epoch: 7, iter: 200, loss: 0.031, accuracy: 98.85%epoch: 7, iter: 300, loss: 0.023, accuracy: 99.46%epoch: 7, iter: 400, loss: 0.022, accuracy: 99.54%epoch: 7, iter: 500, loss: 0.022, accuracy: 99.38%epoch: 7, iter: 600, loss: 0.027, accuracy: 98.77%epoch: 7, iter: 700, loss: 0.019, accuracy: 99.46%epoch: 8, iter: 100, loss: 0.018, accuracy: 99.54%epoch: 8, iter: 200, loss: 0.018, accuracy: 99.46%epoch: 8, iter: 300, loss: 0.018, accuracy: 99.54%epoch: 8, iter: 400, loss: 0.018, accuracy: 99.62%epoch: 8, iter: 500, loss: 0.017, accuracy: 99.54%epoch: 8, iter: 600, loss: 0.026, accuracy: 99.00%epoch: 8, iter: 700, loss: 0.021, accuracy: 99.23%epoch: 9, iter: 100, loss: 0.017, accuracy: 99.62%epoch: 9, iter: 200, loss: 0.016, accuracy: 99.54%epoch: 9, iter: 300, loss: 0.015, accuracy: 99.54%epoch: 9, iter: 400, loss: 0.014, accuracy: 99.69%epoch: 9, iter: 500, loss: 0.014, accuracy: 99.62%epoch: 9, iter: 600, loss: 0.014, accuracy: 99.69%epoch: 9, iter: 700, loss: 0.014, accuracy: 99.62%epoch: 10, iter: 100, loss: 0.014, accuracy: 99.54%epoch: 10, iter: 200, loss: 0.014, accuracy: 99.54%epoch: 10, iter: 300, loss: 0.015, accuracy: 99.69%epoch: 10, iter: 400, loss: 0.014, accuracy: 99.69%epoch: 10, iter: 500, loss: 0.013, accuracy: 99.62%epoch: 10, iter: 600, loss: 0.016, accuracy: 99.38%epoch: 10, iter: 700, loss: 0.017, accuracy: 99.38%
复制代码


这就是我们的简单 RNN ,以后有机会我们再尝试搭建类似 LSTM 这种更复杂的 RNN 。


作者介绍


张觉非,本科毕业于复旦大学,硕士毕业于中国科学院大学,先后任职于新浪微博、阿里,目前就职于奇虎 360,任机器学习技术专家。


本文来自 DataFun 社区


原文链接


https://mp.weixin.qq.com/s?__biz=MzU1NTMyOTI4Mw==&mid=2247493606&idx=1&sn=bf89adb739302688e6b837084bff911a&chksm=fbd7558acca0dc9c6a9754975ee796239b5fa2c26a38f604c56d19a5189f8a0febd75698ddd7&scene=27#wechat_redirect


2019-09-17 08:001100

评论

发布
暂无评论
发现更多内容

神级编程网站,堪称程序员的充电站,我给你找好了不能错过

C语言与CPP编程

编程 程序员 网站 计算机 编程语言‘

云资源管理概述

阿泽🧸

云资源 9月月更

YAML管理Kubernetes应用

CTO技术共享

致敬经典!这款华为主题能让你的手机秒变“历代Mate”

最新动态

自动化测试如何管理测试数据

老张

自动化测试

库调多了,都忘了最基础的概念-《方法篇》

知识浅谈

9月月更

信息安全之我见(45/100)

hackstoic

信息安全

云原生的学习心得

Geek_e8bfe4

《小米创业思考》之一:小米历程

郭明

读书笔记

一款开源的电商框架介绍:Spartacus

汪子熙

typescript angular SAP Spartacus 9月月更

深入学习SAP UI5框架代码系列之二:UI5 Module的懒加载机制

汪子熙

JavaScript 前端框架 SAP ui5 9月月更

【CSS】 position : static | absolute | relative | fixed | sticky

翼同学

CSS 前端 9月月更

云原生(三十六) | Kubernetes篇之Harbor入门和安装

Lansonli

云原生 9月月更

[教你做小游戏] 滑动选中!PC端+移动端适配!完美用户体验!斗地主手牌交互示范

HullQin

CSS JavaScript html 前端 9月月更

1分钟了解什么是数据湖?标准的数据湖什么样?

雨果

数据湖

你真的理解C语言中的 “ 数组 ” 吗?(初阶篇)

Albert Edison

数组 C语言 开发语言 二维数组 9月月更

使用 VUE 和 Go 触摸 WebAssembly

devpoint

Go Vue webassembly 9月月更

2022-09-04:以下go语言代码输出什么?A:不能编译;B:45;C:45.2;D:45.0。 package main import ( “fmt“ ) func main() {

福大大架构师每日一题

golang 福大大 选择题

Java进阶(五)Junit测试

No Silver Bullet

JUnit 测试 单元测试 9月月更

NFT数字藏品介绍:NFT数字藏品(交易平台)系统开发

开源直播系统源码

区块链 NFT 数字藏品 数字馆藏

「趣学前端」SVG,边学边做

叶一一

JavaScript 前端 9月月更

你猜 1 行Python代码能干什么呢?神奇的单行 Python 代码

梦想橡皮擦

Python Python. 9月月更

Containerd ctr、crictl、nerdctl 实战

CTO技术共享

数字化转型和信息化的区别是什么?

雨果

数字化转型 企业信息化

浅述AIOps与DevOps的区别在哪里

穿过生命散发芬芳

DevOps AIOPS 9月月更

生产环境中使用 Linkerd

CTO技术共享

查看k8s的etcd数据

程序员欣宸

Kubernetes 9月月更

计算机网络的组成

StackOverflow

编程 计算机网络 9月月更

【精通内核】Linux内核并发控制原理信号量与P-V原语源码解析

小明Java问道之路

Linux 并发控制 内核 Linux内核 9月月更

真的破防了!在华为主题熄屏显示找到我的第一台Mate

最新动态

随机生成也是需要有效控制的

zxhtom

9月月更

运用计算图搭建递归神经网络(RNN)_文化 & 方法_DataFunTalk_InfoQ精选文章