写点什么

Tensorflow2.0 实现 Deep-Q-Network

  • 2019-12-02
  • 本文字数:4042 字

    阅读完需:约 13 分钟

Tensorflow2.0实现Deep-Q-Network

深度 Q 网络(Deep - Q - Network) 介绍


在 Q-learning 算法中,当状态和动作空间是离散且维数不高时,可使用 Q-table 储存每个状态动作对的 Q 值,然后通过贝尔曼方差迭代求得每个状态动作对收敛的 Q 值,然后选择最优的动作当做策略。但是而当状态和动作空间是高维连续时,比如(游戏的状态动作对数目就很大)使用 Q-table 存储每个状态动作对就显得很不现实。


所以可以将 Q-Table 的更新问题变成一个函数拟合问题,相近的状态得到相近的输出动作。DQN 就是要设计一个神经网络结构,通过函数来拟合 Q 值。


下面引用一下自己写的一篇综述里面的 DQN 训练流程图,贴自己的图,不算侵犯版权吧,哈哈。知网上可以下载到这篇文章:http://kns.cnki.net/kcms/detail/detail.aspx?dbcode=CJFD&&filename=JSJX201801001。当时3个月大概看了百余篇DRL方向的论文,才写出来的,哈哈。



DQN 的亮点:


  • 通过 experience replay(经验池)的方法来解决相关性及非静态分布问题,在训练深度网络时,通常要求样本之间是相互独立的,所以通过这种随机采样的方式,大大降低了样本之间的关联性,从而提升了算法的稳定性。

  • 使用一个神经网络产生当前 Q 值,使用另外一个神经网络产生 Target Q 值。

  • DQN 损失函数和参数更新:


损失函数:



其中 yi 表示值函数的优化目标即目标网络的 Q 值:



参数更新的梯度为:



Tensorflow 2.0 实现 DQN


整体的代码是借鉴的莫烦大神,只不过现在用的接口都是 Tensorflow 2.0,所以代码显得很简单,风格很像 keras。


# -*- coding:utf-8 -*-# Author : zhaijianwei# Date : 2019/6/19 19:48
import tensorflow as tfimport numpy as npfrom tensorflow.python.keras import layersfrom tensorflow.python.keras.optimizers import RMSprop
from DQN.maze_env import Maze

class Eval_Model(tf.keras.Model): def __init__(self, num_actions): super().__init__('mlp_q_network') self.layer1 = layers.Dense(10, activation='relu') self.logits = layers.Dense(num_actions, activation=None)
def call(self, inputs): x = tf.convert_to_tensor(inputs) layer1 = self.layer1(x) logits = self.logits(layer1) return logits

class Target_Model(tf.keras.Model): def __init__(self, num_actions): super().__init__('mlp_q_network_1') self.layer1 = layers.Dense(10, trainable=False, activation='relu') self.logits = layers.Dense(num_actions, trainable=False, activation=None)
def call(self, inputs): x = tf.convert_to_tensor(inputs) layer1 = self.layer1(x) logits = self.logits(layer1) return logits

class DeepQNetwork: def __init__(self, n_actions, n_features, eval_model, target_model):
self.params = { 'n_actions': n_actions, 'n_features': n_features, 'learning_rate': 0.01, 'reward_decay': 0.9, 'e_greedy': 0.9, 'replace_target_iter': 300, 'memory_size': 500, 'batch_size': 32, 'e_greedy_increment': None }
# total learning step
self.learn_step_counter = 0
# initialize zero memory [s, a, r, s_] self.epsilon = 0 if self.params['e_greedy_increment'] is not None else self.params['e_greedy'] self.memory = np.zeros((self.params['memory_size'], self.params['n_features'] * 2 + 2))
self.eval_model = eval_model self.target_model = target_model
self.eval_model.compile( optimizer=RMSprop(lr=self.params['learning_rate']), loss='mse' ) self.cost_his = []
def store_transition(self, s, a, r, s_): if not hasattr(self, 'memory_counter'): self.memory_counter = 0
transition = np.hstack((s, [a, r], s_))
# replace the old memory with new memory index = self.memory_counter % self.params['memory_size'] self.memory[index, :] = transition
self.memory_counter += 1
def choose_action(self, observation): # to have batch dimension when feed into tf placeholder observation = observation[np.newaxis, :]
if np.random.uniform() < self.epsilon: # forward feed the observation and get q value for every actions actions_value = self.eval_model.predict(observation) print(actions_value) action = np.argmax(actions_value) else: action = np.random.randint(0, self.params['n_actions']) return action
def learn(self): # sample batch memory from all memory if self.memory_counter > self.params['memory_size']: sample_index = np.random.choice(self.params['memory_size'], size=self.params['batch_size']) else: sample_index = np.random.choice(self.memory_counter, size=self.params['batch_size'])
batch_memory = self.memory[sample_index, :]
q_next = self.target_model.predict(batch_memory[:, -self.params['n_features']:]) q_eval = self.eval_model.predict(batch_memory[:, :self.params['n_features']])
# change q_target w.r.t q_eval's action q_target = q_eval.copy()
batch_index = np.arange(self.params['batch_size'], dtype=np.int32) eval_act_index = batch_memory[:, self.params['n_features']].astype(int) reward = batch_memory[:, self.params['n_features'] + 1]
q_target[batch_index, eval_act_index] = reward + self.params['reward_decay'] * np.max(q_next, axis=1)
# check to replace target parameters if self.learn_step_counter % self.params['replace_target_iter'] == 0: for eval_layer, target_layer in zip(self.eval_model.layers, self.target_model.layers): target_layer.set_weights(eval_layer.get_weights()) print('\ntarget_params_replaced\n')
""" For example in this batch I have 2 samples and 3 actions: q_eval = [[1, 2, 3], [4, 5, 6]] q_target = q_eval = [[1, 2, 3], [4, 5, 6]] Then change q_target with the real q_target value w.r.t the q_eval's action. For example in: sample 0, I took action 0, and the max q_target value is -1; sample 1, I took action 2, and the max q_target value is -2: q_target = [[-1, 2, 3], [4, 5, -2]] So the (q_target - q_eval) becomes: [[(-1)-(1), 0, 0], [0, 0, (-2)-(6)]] We then backpropagate this error w.r.t the corresponding action to network, leave other action as error=0 cause we didn't choose it. """
# train eval network
self.cost = self.eval_model.train_on_batch(batch_memory[:, :self.params['n_features']], q_target)
self.cost_his.append(self.cost)
# increasing epsilon self.epsilon = self.epsilon + self.params['e_greedy_increment'] if self.epsilon < self.params['e_greedy'] \ else self.params['e_greedy'] self.learn_step_counter += 1
def plot_cost(self): import matplotlib.pyplot as plt plt.plot(np.arange(len(self.cost_his)), self.cost_his) plt.ylabel('Cost') plt.xlabel('training steps') plt.show()

def run_maze(): step = 0 for episode in range(300): # initial observation observation = env.reset()
while True: # fresh env env.render() # RL choose action based on observation action = RL.choose_action(observation) # RL take action and get next observation and reward observation_, reward, done = env.step(action) RL.store_transition(observation, action, reward, observation_) if (step > 200) and (step % 5 == 0): RL.learn() # swap observation observation = observation_ # break while loop when end of this episode if done: break step += 1 # end of game print('game over') env.destroy()

if __name__ == "__main__": # maze game env = Maze() eval_model = Eval_Model(num_actions=env.n_actions) target_model = Target_Model(num_actions=env.n_actions) RL = DeepQNetwork(env.n_actions, env.n_features, eval_model, target_model) env.after(100, run_maze) env.mainloop() RL.plot_cost()
复制代码


参考文献:


https://www.jianshu.com/p/10930c371cac


https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow


http://inoryy.com/post/tensorf


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/70009692


2019-12-02 16:221203

评论

发布
暂无评论
发现更多内容

数字政府建设如火如荼 区块链保证数据真实安全

CECBC

【得物技术】常用注册中心原理及比较

得物技术

zookeeper nacos Consul Eureka 注册中心

数据仓库的基本要求

奔向架构师

数据仓库 数据架构 7月日更

为什么搞一个副业项目如此之难?

张理查

模块一作业

君子意如何

「架构师训练营第 1 期」

🏆「作者推荐」【JVM 性能分析】精心准备了一套 JVM 分析工具的锦囊(中部)

洛神灬殇

JVM 性能分析 7月日更

Redisson 分布式锁源码 08:MultiLock 加锁与锁释放

程序员小航

Java 源码 分布式锁 redisson redison

hive 与传统数据库对比

五分钟学大数据

hive 7月日更

你有多少密码是123456

MySQL从删库到跑路

密码管理

都说数仓是面向主题建设的,那数仓的主题和主题域又应该怎么划分呢?

白贺BaiHe

数仓 7月日更 数仓主题 主题域 数仓建设

phpExcel:Excel数据导入导出最佳实战

devpoint

php Excel thinkphp 7月日更

利用 Vector 从日志创建指标来提高系统的可观测性

哈德韦

日志 可观测性 Prometheus SRE vector

免费分享Java Web 开发的优秀图书

Java入门到架构

Java Java书籍推荐

解读区块链在制药和物流管理中具备的优势

CECBC

🏆「作者推荐」【JVM性能分析】精心准备了一套JVM分析工具的锦囊(上部)

洛神灬殇

JVM 性能分析 jvm调优 7月日更

使用 Open Policy Agent 实现可信镜像仓库检查

张晓辉

Kubernetes 安全 OPA

Spring源码解析 -- SpringWeb请求映射Map初始化

Java spring 源码解析

吃药吗?AI造的!

脑极体

幸福来敲门

卢卡多多

幸福 7月日更

jTDS 驱动导致 cpu 100%

顾五木

cpu占用100% 线上程序问题

架构实战营 - 模块 8- 作业

泄矢的呼啦圈

架构实战营

话题讨论| 帮朋友拼多多助力会导致银行卡被盗刷?

石云升

拼多多 话题讨论 7月日更

Apollo配置中心如何实现配置热发布

慕枫技术笔记

微服务 后端 配置中心

查找——HASH

若尘

数据结构 hash

区块链技术在“三资”监管领域的应用

CECBC

设计消息队列存储消息数据的MySQL表格

Vincent

架构训练营

在线ASCII艺术字生成工具,SpringBoot banner生成工具

入门小站

工具

【Flutter 专题】91图解 Dart 单线程实现异步处理之 Future (二)

阿策小和尚

Flutter 小菜 0 基础学习 Flutter Android 小菜鸟 7月日更

PowerShell 哈希表

耳东@Erdong

PowerShell 7月日更

Linux之find xargs

入门小站

Linux

external-attacher源码分析(2)-核心处理逻辑分析

良凯尔

Kubernetes 源码分析 Ceph CSI Kubernetes Plugin

Tensorflow2.0实现Deep-Q-Network_语言 & 开发_Alex-zhai_InfoQ精选文章