50万奖金+官方证书,深圳国际金融科技大赛正式启动,点击报名 了解详情
写点什么

Tensorflow2.0 实现 Deep-Q-Network

  • 2019-12-02
  • 本文字数:4042 字

    阅读完需:约 13 分钟

Tensorflow2.0实现Deep-Q-Network

深度 Q 网络(Deep - Q - Network) 介绍


在 Q-learning 算法中,当状态和动作空间是离散且维数不高时,可使用 Q-table 储存每个状态动作对的 Q 值,然后通过贝尔曼方差迭代求得每个状态动作对收敛的 Q 值,然后选择最优的动作当做策略。但是而当状态和动作空间是高维连续时,比如(游戏的状态动作对数目就很大)使用 Q-table 存储每个状态动作对就显得很不现实。


所以可以将 Q-Table 的更新问题变成一个函数拟合问题,相近的状态得到相近的输出动作。DQN 就是要设计一个神经网络结构,通过函数来拟合 Q 值。


下面引用一下自己写的一篇综述里面的 DQN 训练流程图,贴自己的图,不算侵犯版权吧,哈哈。知网上可以下载到这篇文章:http://kns.cnki.net/kcms/detail/detail.aspx?dbcode=CJFD&&filename=JSJX201801001。当时3个月大概看了百余篇DRL方向的论文,才写出来的,哈哈。



DQN 的亮点:


  • 通过 experience replay(经验池)的方法来解决相关性及非静态分布问题,在训练深度网络时,通常要求样本之间是相互独立的,所以通过这种随机采样的方式,大大降低了样本之间的关联性,从而提升了算法的稳定性。

  • 使用一个神经网络产生当前 Q 值,使用另外一个神经网络产生 Target Q 值。

  • DQN 损失函数和参数更新:


损失函数:



其中 yi 表示值函数的优化目标即目标网络的 Q 值:



参数更新的梯度为:



Tensorflow 2.0 实现 DQN


整体的代码是借鉴的莫烦大神,只不过现在用的接口都是 Tensorflow 2.0,所以代码显得很简单,风格很像 keras。


# -*- coding:utf-8 -*-# Author : zhaijianwei# Date : 2019/6/19 19:48
import tensorflow as tfimport numpy as npfrom tensorflow.python.keras import layersfrom tensorflow.python.keras.optimizers import RMSprop
from DQN.maze_env import Maze

class Eval_Model(tf.keras.Model): def __init__(self, num_actions): super().__init__('mlp_q_network') self.layer1 = layers.Dense(10, activation='relu') self.logits = layers.Dense(num_actions, activation=None)
def call(self, inputs): x = tf.convert_to_tensor(inputs) layer1 = self.layer1(x) logits = self.logits(layer1) return logits

class Target_Model(tf.keras.Model): def __init__(self, num_actions): super().__init__('mlp_q_network_1') self.layer1 = layers.Dense(10, trainable=False, activation='relu') self.logits = layers.Dense(num_actions, trainable=False, activation=None)
def call(self, inputs): x = tf.convert_to_tensor(inputs) layer1 = self.layer1(x) logits = self.logits(layer1) return logits

class DeepQNetwork: def __init__(self, n_actions, n_features, eval_model, target_model):
self.params = { 'n_actions': n_actions, 'n_features': n_features, 'learning_rate': 0.01, 'reward_decay': 0.9, 'e_greedy': 0.9, 'replace_target_iter': 300, 'memory_size': 500, 'batch_size': 32, 'e_greedy_increment': None }
# total learning step
self.learn_step_counter = 0
# initialize zero memory [s, a, r, s_] self.epsilon = 0 if self.params['e_greedy_increment'] is not None else self.params['e_greedy'] self.memory = np.zeros((self.params['memory_size'], self.params['n_features'] * 2 + 2))
self.eval_model = eval_model self.target_model = target_model
self.eval_model.compile( optimizer=RMSprop(lr=self.params['learning_rate']), loss='mse' ) self.cost_his = []
def store_transition(self, s, a, r, s_): if not hasattr(self, 'memory_counter'): self.memory_counter = 0
transition = np.hstack((s, [a, r], s_))
# replace the old memory with new memory index = self.memory_counter % self.params['memory_size'] self.memory[index, :] = transition
self.memory_counter += 1
def choose_action(self, observation): # to have batch dimension when feed into tf placeholder observation = observation[np.newaxis, :]
if np.random.uniform() < self.epsilon: # forward feed the observation and get q value for every actions actions_value = self.eval_model.predict(observation) print(actions_value) action = np.argmax(actions_value) else: action = np.random.randint(0, self.params['n_actions']) return action
def learn(self): # sample batch memory from all memory if self.memory_counter > self.params['memory_size']: sample_index = np.random.choice(self.params['memory_size'], size=self.params['batch_size']) else: sample_index = np.random.choice(self.memory_counter, size=self.params['batch_size'])
batch_memory = self.memory[sample_index, :]
q_next = self.target_model.predict(batch_memory[:, -self.params['n_features']:]) q_eval = self.eval_model.predict(batch_memory[:, :self.params['n_features']])
# change q_target w.r.t q_eval's action q_target = q_eval.copy()
batch_index = np.arange(self.params['batch_size'], dtype=np.int32) eval_act_index = batch_memory[:, self.params['n_features']].astype(int) reward = batch_memory[:, self.params['n_features'] + 1]
q_target[batch_index, eval_act_index] = reward + self.params['reward_decay'] * np.max(q_next, axis=1)
# check to replace target parameters if self.learn_step_counter % self.params['replace_target_iter'] == 0: for eval_layer, target_layer in zip(self.eval_model.layers, self.target_model.layers): target_layer.set_weights(eval_layer.get_weights()) print('\ntarget_params_replaced\n')
""" For example in this batch I have 2 samples and 3 actions: q_eval = [[1, 2, 3], [4, 5, 6]] q_target = q_eval = [[1, 2, 3], [4, 5, 6]] Then change q_target with the real q_target value w.r.t the q_eval's action. For example in: sample 0, I took action 0, and the max q_target value is -1; sample 1, I took action 2, and the max q_target value is -2: q_target = [[-1, 2, 3], [4, 5, -2]] So the (q_target - q_eval) becomes: [[(-1)-(1), 0, 0], [0, 0, (-2)-(6)]] We then backpropagate this error w.r.t the corresponding action to network, leave other action as error=0 cause we didn't choose it. """
# train eval network
self.cost = self.eval_model.train_on_batch(batch_memory[:, :self.params['n_features']], q_target)
self.cost_his.append(self.cost)
# increasing epsilon self.epsilon = self.epsilon + self.params['e_greedy_increment'] if self.epsilon < self.params['e_greedy'] \ else self.params['e_greedy'] self.learn_step_counter += 1
def plot_cost(self): import matplotlib.pyplot as plt plt.plot(np.arange(len(self.cost_his)), self.cost_his) plt.ylabel('Cost') plt.xlabel('training steps') plt.show()

def run_maze(): step = 0 for episode in range(300): # initial observation observation = env.reset()
while True: # fresh env env.render() # RL choose action based on observation action = RL.choose_action(observation) # RL take action and get next observation and reward observation_, reward, done = env.step(action) RL.store_transition(observation, action, reward, observation_) if (step > 200) and (step % 5 == 0): RL.learn() # swap observation observation = observation_ # break while loop when end of this episode if done: break step += 1 # end of game print('game over') env.destroy()

if __name__ == "__main__": # maze game env = Maze() eval_model = Eval_Model(num_actions=env.n_actions) target_model = Target_Model(num_actions=env.n_actions) RL = DeepQNetwork(env.n_actions, env.n_features, eval_model, target_model) env.after(100, run_maze) env.mainloop() RL.plot_cost()
复制代码


参考文献:


https://www.jianshu.com/p/10930c371cac


https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow


http://inoryy.com/post/tensorf


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/70009692


2019-12-02 16:221419

评论

发布
暂无评论
发现更多内容

2022年购买服务器运维管理软件选择哪家好?

行云管家

IT运维 服务器运维

硬核化解ISV四大痛点,华为云智联生活行业加速器助力伙伴实现商业成功

华为云开发者联盟

华为云 HarmonyOS 智联生活 华为云IoTDA 云云协同

OpenMLDB 12月会议纪要

第四范式开发者社区

人工智能 机器学习 第四范式 OpenMLDB

探索SaaS产业发展新机遇|鲁班会贵安首秀圆满收官

华为云开发者联盟

SaaS 华为云 应用构建

MongoDB基本介绍与安装(1)

Tom弹架构

Java mongodb

回顾2021,展望2022 | TDengine一年“成绩”汇总

TDengine

数据库 tdengine 2021年终总结

SpringMVC框架基础知识(01)

海拥(haiyong.site)

28天写作 12月日更

性能提升40%!阿里云神龙大数据加速引擎获TPCx-BB世界排名第一

阿里云弹性计算

阿里云 神龙

OceanBase 通过工信部电子标准院首批开源项目成熟度评估

OceanBase 数据库

数据库 工信部 OceanBase 开源

iOS 编译器__Attribute__的入门指南

37手游iOS技术运营团队

xcode LLVM Clang编译 Clang Attribute

云堡垒机和普通堡垒机的三大区别分析-行云管家

行云管家

云计算 网络安全 数据安全 堡垒机 云堡垒机

Greenplum内核源码分析-分布式事务(二)

王凤刚(ginobiliwang)

源码分析 分布式事务 greenplum

谁编写了区块链的规则?

CECBC

盘点 2021|一个新的开始

IT蜗壳-Tango

28天写作 12月日更 盘点2021 盘点 2021

“千言”开源数据集项目全面升级:数据驱动AI技术进步

百度开发者中心

千言

CRM系统为什么被认为是企业的重要资产?

低代码小观

企业管理 资产管理 CRM 企业管理系统 CRM系统

《国产分布式数据库选型及满意度调查报告》出炉,OceanBase获得双料第一

OceanBase 数据库

分布式数据库 OceanBase 开源 OceanBase 社区版

Veritas:2022年数据安全及合规领域行业预测

BeeWorks

全国首个!OceanBase 助力江西省养老保险全国统筹信息系统上线

OceanBase 数据库

OceanBase 开源 OceanBase 社区版 核心系统

Xcode 配置多套 App 图标的方法 --- AppStore 图标 A/B Test 实践

37手游iOS技术运营团队

ios xcode appstore 产品页优化 自定产品页

你设备中的木马藏在哪里?为什么查杀困难?

喀拉峻

黑客 网络安全 安全 信息安全 木马病毒

2021年末总结

编号94530

工作 架构设计 心得 2021 项目经验

数字化转型失败,有哪些原因?

禅道项目管理

数字化转型

大型购物平台的系统设计与架构

恒生LIGHT云社区

平台搭建 构架 平台架构

区块链赋能生猪养殖,让“猪”事有迹可循

CECBC

2021MongoDB技术实践与应用案例征集活动获奖通知

MongoDB中文社区

重塑企业创新方式 Serverless让云“开箱即用”

BeeWorks

HTTPDNS 快速入门

37手游iOS技术运营团队

DNS httpdns

COSCL开源评选名单公布!OceanBase 社区版荣获2021优秀开源项目奖

OceanBase 数据库

OceanBase 开源 OceanBase 社区版

Greenplum内核源码分析-分布式事务(三)

王凤刚(ginobiliwang)

源码分析 分布式事务 greenplum

DTC 2021 | 一体化架构的原生分布式数据库正在成为核心系统首选

OceanBase 数据库

数据库 OceanBase 开源 OceanBase 社区版

Tensorflow2.0实现Deep-Q-Network_语言 & 开发_Alex-zhai_InfoQ精选文章