写点什么

运用计算图搭建递归神经网络(RNN)

  • 2019-09-17
  • 本文字数:5544 字

    阅读完需:约 18 分钟

运用计算图搭建递归神经网络(RNN)

继续玩我们的计算图框架。这一次我们运用计算图搭建递归神经网络(RNN,Recursive Neural Network)。RNN 处理前后有承接关系的序列状数据,例如时序数据。当然,前后的承接也不一定是时间上的,但总之是有前后关系的序列。

RNN

RNN 的思想是:网络也分步,每步以输入序列的该步数据(向量)和上一步数据(第一步没有)为输入,进行变换,得到这一步的输出(向量)。这样的话,序列的每一步就会对下一步产生影响。RNN 用变换的参数把握序列每一步之间的关系。最后一步的输出可以送给全连接层,最终用于分类或回归。RNN 有很多种,有一些复杂的变体,本文搭建一种最简单的 RNN ,它的结构是这样的:



蓝色长条表示 m 维输入向量,一共 n 个。这表示数据是长度为 n 的序列,每一步是一个 m 维向量。绿色的矩形就是每一步的变换。yi 是每一步的 k 维输出向量。每一步用 k x k 的权值矩阵 Y 去乘前一步的输出向量(第一步没有),用 k x m 的权值矩阵 W 去乘这一步的输入向量,加和后再加上 k 维偏置向量 b ,施加激活函数 ϕ (我们取 ReLU),就得到这一步的输出。


最后一步的输出也是 k 维向量,把它送给全连接层,最后施加 SoftMax 后得到各个类别的概率,再接上一个交叉熵损失就可以用来训练分类问题了。用我们的计算图框架可以这样搭建这个简单的 RNN(代码):


seq_len = 96  # 序列长度dimension = 16  # 序列每一步的向量维度hidden_dim = 12  # RNN 时间单元的输出维度
# 时间序列变量,每一步一个 dimension 维向量(Variable 节点),保存在数组 input 中input_vectors = []for i in range(seq_len): input_vectors.append(Variable(dim=(dimension, 1), init=False, trainable=False)) # 对于本步输入的权值矩阵W = Variable(dim=(hidden_dim, dimension), init=True, trainable=True)
# 对于上步输入的权值矩阵Y = Variable(dim=(hidden_dim, hidden_dim), init=True, trainable=True)
# 偏置向量b = Variable(dim=(hidden_dim, 1), init=True, trainable=True)
# 构造 RNNlast_step = None # 上一步的输出,第一步没有上一步,先将其置为 Nonefor iv in input_vectors: y = Add(MatMul(W, iv), b)
if last_step is not None: y = Add(MatMul(Y, last_step), y)
y = ReLU(y)
last_step = y

fc1 = fc(y, hidden_dim, 6, "ReLU") # 第一全连接层fc2 = fc(fc1, 6, 2, "None") # 第二全连接层
# 分类概率prob = SoftMax(fc2)
# 训练标签label = Variable((2, 1), trainable=False)
# 交叉熵损失loss = CrossEntropyWithSoftMax(fc2, label)
复制代码


这就是构造 RNN 以及交叉熵损失的计算图的代码,很简单,right ?有了计算图以及自动求导,我们只管搭建网络即可,网络的训练就交给计算图去做了。否则你可以想象,按照示意图表示的计算,推导交叉熵损失对 RNN 的各个权值矩阵和偏置的梯度是多么困难。

时间序列问题

我们构造一份数据,它包含两类时间序列,一类是方波,一类是正弦波,代码如下:


def get_sequence_data(number_of_classes=2, dimension=10, length=10, number_of_examples=1000, train_set_ratio=0.7, seed=42):    """    生成两类序列数据。    """    xx = []    xx.append(np.sin(np.arange(0, 10, 10 / length)))  # 正弦波    xx.append(np.array(signal.square(np.arange(0, 10, 10 / length))))  # 方波

data = [] for i in range(number_of_classes): x = xx[i] for j in range(number_of_examples): sequence = x + np.random.normal(0, 1.0, (dimension, len(x))) # 加入高斯噪声 label = np.array([int(i == j) for j in range(number_of_classes)])
data.append(np.c_[sequence.reshape(1, -1), label.reshape(1, -1)])
# 把各个类别的样本合在一起 data = np.concatenate(data, axis=0)
# 随机打乱样本顺序 np.random.shuffle(data)
# 计算训练样本数量 train_set_size = int(number_of_examples * train_set_ratio) # 训练集样本数量
# 将训练集和测试集、特征和标签分开 return (data[:train_set_size, :-number_of_classes], data[:train_set_size, -number_of_classes:], data[train_set_size:, :-number_of_classes], data[train_set_size:, -number_of_classes:])
复制代码


我们用这一行代码获取长度为 96 ,维度为 16 的两类(各 1000 个)序列:


# 获取两类时间序列:正弦波和方波train_x, train_y, test_x, test_y = get_sequence_data(length=seq_len, dimension=dimension)
复制代码


看一看时间序列样本,先看正弦波:



正弦波序列


这是一个正弦波时间序列样本,它包含 16 条曲线,每一条都是 sin 曲线加噪声。之所以包含 16 条曲线,因为我们的时间序列的每一步是一个 16 维向量,按时间列起来就有了 16 条正弦曲线。正弦波时间序列是我们的正样本。方波时间序列是负样本:



方波序列


一个方波时间序列先维持 +1 一段时间,变为 -1 维持一段时间,再回到 +1 ,循环往复。由于我们的高斯噪声加得较大,可以看到正弦波和方波还是有可能混淆的,但也能看出它们之间的差异。

训练

现在就用我们构造的 RNN 训练一个分类模型,分类正弦波和方波,代码如下:


from sklearn.metrics import accuracy_score
from layer import *from node import *from optimizer import *
seq_len = 96 # 序列长度dimension = 16 # 序列每一步的向量维度hidden_dim = 12 # RNN 时间单元的输出维度
# 获取两类时间序列:正弦波和方波train_x, train_y, test_x, test_y = get_sequence_data(length=seq_len, dimension=dimension)
# 时间序列变量,每一步一个 dimension 维向量(Variable 节点),保存在数组 input 中input_vectors = []for i in range(seq_len): input_vectors.append(Variable(dim=(dimension, 1), init=False, trainable=False)) # 对于本步输入的权值矩阵W = Variable(dim=(hidden_dim, dimension), init=True, trainable=True)
# 对于上步输入的权值矩阵Y = Variable(dim=(hidden_dim, hidden_dim), init=True, trainable=True)
# 偏置向量b = Variable(dim=(hidden_dim, 1), init=True, trainable=True)
# 构造 RNNlast_step = None # 上一步的输出,第一步没有上一步,先将其置为 Nonefor iv in input_vectors: y = Add(MatMul(W, iv), b)
if last_step is not None: y = Add(MatMul(Y, last_step), y)
y = ReLU(y)
last_step = y

fc1 = fc(y, hidden_dim, 6, "ReLU") # 第一全连接层fc2 = fc(fc1, 6, 2, "None") # 第二全连接层
# 分类概率prob = SoftMax(fc2)
# 训练标签label = Variable((2, 1), trainable=False)
# 交叉熵损失loss = CrossEntropyWithSoftMax(fc2, label)
# Adam 优化器optimizer = Adam(default_graph, loss, 0.005, batch_size=16)
# 训练print("start training", flush=True)for e in range(10):
for i in range(len(train_x)): x = np.mat(train_x[i, :]).reshape(dimension, seq_len) for j in range(seq_len): input_vectors[j].set_value(x[:, j]) label.set_value(np.mat(train_y[i, :]).T)
# 执行一步优化 optimizer.one_step()
if i > 1 and (i + 1) % 100 == 0:
# 在测试集上评估模型正确率 probs = [] losses = [] for j in range(len(test_x)): # x = test_x[j, :].reshape(dimension, seq_len) x = np.mat(test_x[j, :]).reshape(dimension, seq_len) for k in range(seq_len): input_vectors[k].set_value(x[:, k]) label.set_value(np.mat(test_y[j, :]).T)
# 前向传播计算概率 prob.forward() probs.append(prob.value.A1)
# 计算损失值 loss.forward() losses.append(loss.value[0, 0])
# print("test instance: {:d}".format(j))
# 取概率最大的类别为预测类别 pred = np.argmax(np.array(probs), axis=1) truth = np.argmax(test_y, axis=1) accuracy = accuracy_score(truth, pred)
default_graph.draw() print("epoch: {:d}, iter: {:d}, loss: {:.3f}, accuracy: {:.2f}%".format(e + 1, i + 1, np.mean(losses), accuracy * 100), flush=True)
复制代码


训练 10 个 epoch 后,测试集上的正确率达到了 99% :


epoch: 1, iter: 100, loss: 0.693, accuracy: 51.08%epoch: 1, iter: 200, loss: 0.692, accuracy: 51.08%epoch: 1, iter: 300, loss: 0.677, accuracy: 78.31%epoch: 1, iter: 400, loss: 0.573, accuracy: 49.31%epoch: 1, iter: 500, loss: 0.520, accuracy: 53.92%epoch: 1, iter: 600, loss: 0.599, accuracy: 97.08%epoch: 1, iter: 700, loss: 0.617, accuracy: 99.00%epoch: 2, iter: 100, loss: 0.601, accuracy: 94.46%epoch: 2, iter: 200, loss: 0.579, accuracy: 82.08%epoch: 2, iter: 300, loss: 0.558, accuracy: 76.15%epoch: 2, iter: 400, loss: 0.531, accuracy: 67.85%epoch: 2, iter: 500, loss: 0.507, accuracy: 63.77%epoch: 2, iter: 600, loss: 0.493, accuracy: 61.15%epoch: 2, iter: 700, loss: 0.479, accuracy: 62.23%epoch: 3, iter: 100, loss: 0.443, accuracy: 69.92%epoch: 3, iter: 200, loss: 0.393, accuracy: 85.85%epoch: 3, iter: 300, loss: 0.365, accuracy: 97.69%epoch: 3, iter: 400, loss: 0.284, accuracy: 95.08%epoch: 3, iter: 500, loss: 0.199, accuracy: 95.69%epoch: 3, iter: 600, loss: 0.490, accuracy: 80.62%epoch: 3, iter: 700, loss: 0.264, accuracy: 94.31%epoch: 4, iter: 100, loss: 0.320, accuracy: 83.46%epoch: 4, iter: 200, loss: 0.333, accuracy: 80.92%epoch: 4, iter: 300, loss: 0.276, accuracy: 90.15%epoch: 4, iter: 400, loss: 0.242, accuracy: 95.00%epoch: 4, iter: 500, loss: 0.217, accuracy: 96.38%epoch: 4, iter: 600, loss: 0.191, accuracy: 95.31%epoch: 4, iter: 700, loss: 0.167, accuracy: 94.00%epoch: 5, iter: 100, loss: 0.142, accuracy: 94.62%epoch: 5, iter: 200, loss: 0.111, accuracy: 96.85%epoch: 5, iter: 300, loss: 0.116, accuracy: 96.85%epoch: 5, iter: 400, loss: 0.080, accuracy: 96.77%epoch: 5, iter: 500, loss: 0.059, accuracy: 98.54%epoch: 5, iter: 600, loss: 0.054, accuracy: 98.54%epoch: 5, iter: 700, loss: 0.042, accuracy: 99.00%epoch: 6, iter: 100, loss: 0.047, accuracy: 98.46%epoch: 6, iter: 200, loss: 0.049, accuracy: 98.08%epoch: 6, iter: 300, loss: 0.030, accuracy: 99.15%epoch: 6, iter: 400, loss: 0.029, accuracy: 99.23%epoch: 6, iter: 500, loss: 0.028, accuracy: 99.08%epoch: 6, iter: 600, loss: 0.029, accuracy: 99.08%epoch: 6, iter: 700, loss: 0.024, accuracy: 99.15%epoch: 7, iter: 100, loss: 0.023, accuracy: 99.15%epoch: 7, iter: 200, loss: 0.031, accuracy: 98.85%epoch: 7, iter: 300, loss: 0.023, accuracy: 99.46%epoch: 7, iter: 400, loss: 0.022, accuracy: 99.54%epoch: 7, iter: 500, loss: 0.022, accuracy: 99.38%epoch: 7, iter: 600, loss: 0.027, accuracy: 98.77%epoch: 7, iter: 700, loss: 0.019, accuracy: 99.46%epoch: 8, iter: 100, loss: 0.018, accuracy: 99.54%epoch: 8, iter: 200, loss: 0.018, accuracy: 99.46%epoch: 8, iter: 300, loss: 0.018, accuracy: 99.54%epoch: 8, iter: 400, loss: 0.018, accuracy: 99.62%epoch: 8, iter: 500, loss: 0.017, accuracy: 99.54%epoch: 8, iter: 600, loss: 0.026, accuracy: 99.00%epoch: 8, iter: 700, loss: 0.021, accuracy: 99.23%epoch: 9, iter: 100, loss: 0.017, accuracy: 99.62%epoch: 9, iter: 200, loss: 0.016, accuracy: 99.54%epoch: 9, iter: 300, loss: 0.015, accuracy: 99.54%epoch: 9, iter: 400, loss: 0.014, accuracy: 99.69%epoch: 9, iter: 500, loss: 0.014, accuracy: 99.62%epoch: 9, iter: 600, loss: 0.014, accuracy: 99.69%epoch: 9, iter: 700, loss: 0.014, accuracy: 99.62%epoch: 10, iter: 100, loss: 0.014, accuracy: 99.54%epoch: 10, iter: 200, loss: 0.014, accuracy: 99.54%epoch: 10, iter: 300, loss: 0.015, accuracy: 99.69%epoch: 10, iter: 400, loss: 0.014, accuracy: 99.69%epoch: 10, iter: 500, loss: 0.013, accuracy: 99.62%epoch: 10, iter: 600, loss: 0.016, accuracy: 99.38%epoch: 10, iter: 700, loss: 0.017, accuracy: 99.38%
复制代码


这就是我们的简单 RNN ,以后有机会我们再尝试搭建类似 LSTM 这种更复杂的 RNN 。


作者介绍


张觉非,本科毕业于复旦大学,硕士毕业于中国科学院大学,先后任职于新浪微博、阿里,目前就职于奇虎 360,任机器学习技术专家。


本文来自 DataFun 社区


原文链接


https://mp.weixin.qq.com/s?__biz=MzU1NTMyOTI4Mw==&mid=2247493606&idx=1&sn=bf89adb739302688e6b837084bff911a&chksm=fbd7558acca0dc9c6a9754975ee796239b5fa2c26a38f604c56d19a5189f8a0febd75698ddd7&scene=27#wechat_redirect


2019-09-17 08:001497

评论

发布
暂无评论
发现更多内容

一文说明白Context Engineering:AI智能体的动态语境构建术

蔡超

AI Agent Agentic AI Context Engineering

智源全面开源RoboBrain 2.0与RoboOS 2.0:刷新10项评测基准,多机协作加速群体智能

智源研究院

人工智能 具身智能

淘宝图片搜索接口技术解析与Python实现

tbapi

淘宝图片搜索接口 淘宝拍立淘接口 天猫图片搜索接口 天猫拍立淘接口

中烟创新灯塔大模型应用开发平台入选“2024年度百大AI产品”

中烟创新

冲进腾讯!太不容易了

王中阳Go

Go 腾讯 面试 后端

如何在 Elasticsearch 中构建你的智能 AI 助手?

阿里云大数据AI技术

人工智能 elasticsearch 运维 数据分析 数据库 大数据

多模态AI,敏感数据识别的终结者

权说安全

AI 零信任 数据防泄漏

DeepSeek部署实战:模型对比、部署优化与应用场景解析

中烟创新

LambdaQueryWrapper遇上@Async

京东科技开发者

EMQX + Amazon S3 Tables:从实时物联网数据到数据湖仓

EMQ映云科技

mqtt Amazon S3

Web3支付App的技术框架

北京木奇移动技术有限公司

区块链开发 软件外包公司 web3开发

MyEMS 开源能源管理系统与同类系统的全方位对比分析

开源能源管理系统

开源 安全生产 绿色生产 能源管理系统

黑龙江密码测评的实施流程

等保测评

AI时代需要什么样的园区网络?答案藏在四个新技术里

Alter

KWDB 时序引擎核心能力——存储与读写

KaiwuDB

数据库 时序数据库

设备维修不是单纯的修机器,这五个方面一定要清楚!

积木链小链

数字化转型 智能制造 设备维修

智能网联 + AI:EMQX 5.10.0 大模型集成功能介绍

EMQ映云科技

人工智能 mqtt

大龄青年失业,可以在哪里寻找新的工作机会

Y11

求职 找工作 招聘 转行

苹果电脑装机必备软件推荐,Mac圈超实用软件列表

阿拉灯神丁

实用工具 苹果软件 Tuxera NTFS教程 CleanMyMac X中文版 mac装机必备

20250713动词ing,ed尾字母双写规则

codists

Python

基于业务知识和代码库增强的大模型生成代码实践

京东科技开发者

kimi2实测:5分钟造3D游戏+个人网站,真·国产Claude级编程体验,含Cline教程

阿星AI工作室

AI 产品经理 kimi

新能源锂电池制造执行系统(MES)全面解决方案

万界星空科技

mes 新能源行业 制造业工厂 新能源电池 锂电池mes

MyEMS 4G 网关:打造高效协同的能源管理中枢

开源能源管理系统

开源 能源管理系统 4G网关

AI背单词App的技术方案

北京木奇移动技术有限公司

软件外包公司 AI英语学习 AI背单词

MyEMS:ISO 50006 标准下的开源能源管理利器

开源能源管理系统

开源 ISO 50006 能源管理系统

黑龙江等保测评流程的注意事项

等保测评

天猫商品详情API接口技术解析与Python实现

tbapi

天猫商品详情接口 天猫API 天猫商品数据采集

三级等保测评流程五步走

等保测评

为什么你的 App 需要一个“超级大脑”?

Speedoooo

APP开发 小程序容器 小程序技术 小程序容器技术

京东携手HarmonyOS SDK首发家电AR高精摆放功能

京东科技开发者

运用计算图搭建递归神经网络(RNN)_文化 & 方法_DataFunTalk_InfoQ精选文章