HarmonyOS开发者限时福利来啦!最高10w+现金激励等你拿~ 了解详情
写点什么

Tensorflow2.0 实现 Deep-Q-Network

  • 2019-12-02
  • 本文字数:4042 字

    阅读完需:约 13 分钟

Tensorflow2.0实现Deep-Q-Network

深度 Q 网络(Deep - Q - Network) 介绍


在 Q-learning 算法中,当状态和动作空间是离散且维数不高时,可使用 Q-table 储存每个状态动作对的 Q 值,然后通过贝尔曼方差迭代求得每个状态动作对收敛的 Q 值,然后选择最优的动作当做策略。但是而当状态和动作空间是高维连续时,比如(游戏的状态动作对数目就很大)使用 Q-table 存储每个状态动作对就显得很不现实。


所以可以将 Q-Table 的更新问题变成一个函数拟合问题,相近的状态得到相近的输出动作。DQN 就是要设计一个神经网络结构,通过函数来拟合 Q 值。


下面引用一下自己写的一篇综述里面的 DQN 训练流程图,贴自己的图,不算侵犯版权吧,哈哈。知网上可以下载到这篇文章:http://kns.cnki.net/kcms/detail/detail.aspx?dbcode=CJFD&&filename=JSJX201801001。当时3个月大概看了百余篇DRL方向的论文,才写出来的,哈哈。



DQN 的亮点:


  • 通过 experience replay(经验池)的方法来解决相关性及非静态分布问题,在训练深度网络时,通常要求样本之间是相互独立的,所以通过这种随机采样的方式,大大降低了样本之间的关联性,从而提升了算法的稳定性。

  • 使用一个神经网络产生当前 Q 值,使用另外一个神经网络产生 Target Q 值。

  • DQN 损失函数和参数更新:


损失函数:



其中 yi 表示值函数的优化目标即目标网络的 Q 值:



参数更新的梯度为:



Tensorflow 2.0 实现 DQN


整体的代码是借鉴的莫烦大神,只不过现在用的接口都是 Tensorflow 2.0,所以代码显得很简单,风格很像 keras。


# -*- coding:utf-8 -*-# Author : zhaijianwei# Date : 2019/6/19 19:48
import tensorflow as tfimport numpy as npfrom tensorflow.python.keras import layersfrom tensorflow.python.keras.optimizers import RMSprop
from DQN.maze_env import Maze

class Eval_Model(tf.keras.Model): def __init__(self, num_actions): super().__init__('mlp_q_network') self.layer1 = layers.Dense(10, activation='relu') self.logits = layers.Dense(num_actions, activation=None)
def call(self, inputs): x = tf.convert_to_tensor(inputs) layer1 = self.layer1(x) logits = self.logits(layer1) return logits

class Target_Model(tf.keras.Model): def __init__(self, num_actions): super().__init__('mlp_q_network_1') self.layer1 = layers.Dense(10, trainable=False, activation='relu') self.logits = layers.Dense(num_actions, trainable=False, activation=None)
def call(self, inputs): x = tf.convert_to_tensor(inputs) layer1 = self.layer1(x) logits = self.logits(layer1) return logits

class DeepQNetwork: def __init__(self, n_actions, n_features, eval_model, target_model):
self.params = { 'n_actions': n_actions, 'n_features': n_features, 'learning_rate': 0.01, 'reward_decay': 0.9, 'e_greedy': 0.9, 'replace_target_iter': 300, 'memory_size': 500, 'batch_size': 32, 'e_greedy_increment': None }
# total learning step
self.learn_step_counter = 0
# initialize zero memory [s, a, r, s_] self.epsilon = 0 if self.params['e_greedy_increment'] is not None else self.params['e_greedy'] self.memory = np.zeros((self.params['memory_size'], self.params['n_features'] * 2 + 2))
self.eval_model = eval_model self.target_model = target_model
self.eval_model.compile( optimizer=RMSprop(lr=self.params['learning_rate']), loss='mse' ) self.cost_his = []
def store_transition(self, s, a, r, s_): if not hasattr(self, 'memory_counter'): self.memory_counter = 0
transition = np.hstack((s, [a, r], s_))
# replace the old memory with new memory index = self.memory_counter % self.params['memory_size'] self.memory[index, :] = transition
self.memory_counter += 1
def choose_action(self, observation): # to have batch dimension when feed into tf placeholder observation = observation[np.newaxis, :]
if np.random.uniform() < self.epsilon: # forward feed the observation and get q value for every actions actions_value = self.eval_model.predict(observation) print(actions_value) action = np.argmax(actions_value) else: action = np.random.randint(0, self.params['n_actions']) return action
def learn(self): # sample batch memory from all memory if self.memory_counter > self.params['memory_size']: sample_index = np.random.choice(self.params['memory_size'], size=self.params['batch_size']) else: sample_index = np.random.choice(self.memory_counter, size=self.params['batch_size'])
batch_memory = self.memory[sample_index, :]
q_next = self.target_model.predict(batch_memory[:, -self.params['n_features']:]) q_eval = self.eval_model.predict(batch_memory[:, :self.params['n_features']])
# change q_target w.r.t q_eval's action q_target = q_eval.copy()
batch_index = np.arange(self.params['batch_size'], dtype=np.int32) eval_act_index = batch_memory[:, self.params['n_features']].astype(int) reward = batch_memory[:, self.params['n_features'] + 1]
q_target[batch_index, eval_act_index] = reward + self.params['reward_decay'] * np.max(q_next, axis=1)
# check to replace target parameters if self.learn_step_counter % self.params['replace_target_iter'] == 0: for eval_layer, target_layer in zip(self.eval_model.layers, self.target_model.layers): target_layer.set_weights(eval_layer.get_weights()) print('\ntarget_params_replaced\n')
""" For example in this batch I have 2 samples and 3 actions: q_eval = [[1, 2, 3], [4, 5, 6]] q_target = q_eval = [[1, 2, 3], [4, 5, 6]] Then change q_target with the real q_target value w.r.t the q_eval's action. For example in: sample 0, I took action 0, and the max q_target value is -1; sample 1, I took action 2, and the max q_target value is -2: q_target = [[-1, 2, 3], [4, 5, -2]] So the (q_target - q_eval) becomes: [[(-1)-(1), 0, 0], [0, 0, (-2)-(6)]] We then backpropagate this error w.r.t the corresponding action to network, leave other action as error=0 cause we didn't choose it. """
# train eval network
self.cost = self.eval_model.train_on_batch(batch_memory[:, :self.params['n_features']], q_target)
self.cost_his.append(self.cost)
# increasing epsilon self.epsilon = self.epsilon + self.params['e_greedy_increment'] if self.epsilon < self.params['e_greedy'] \ else self.params['e_greedy'] self.learn_step_counter += 1
def plot_cost(self): import matplotlib.pyplot as plt plt.plot(np.arange(len(self.cost_his)), self.cost_his) plt.ylabel('Cost') plt.xlabel('training steps') plt.show()

def run_maze(): step = 0 for episode in range(300): # initial observation observation = env.reset()
while True: # fresh env env.render() # RL choose action based on observation action = RL.choose_action(observation) # RL take action and get next observation and reward observation_, reward, done = env.step(action) RL.store_transition(observation, action, reward, observation_) if (step > 200) and (step % 5 == 0): RL.learn() # swap observation observation = observation_ # break while loop when end of this episode if done: break step += 1 # end of game print('game over') env.destroy()

if __name__ == "__main__": # maze game env = Maze() eval_model = Eval_Model(num_actions=env.n_actions) target_model = Target_Model(num_actions=env.n_actions) RL = DeepQNetwork(env.n_actions, env.n_features, eval_model, target_model) env.after(100, run_maze) env.mainloop() RL.plot_cost()
复制代码


参考文献:


https://www.jianshu.com/p/10930c371cac


https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow


http://inoryy.com/post/tensorf


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/70009692


2019-12-02 16:221097

评论

发布
暂无评论
发现更多内容

实用fcpx插件:Photo Montage(轻松制作照片动画)

南屿

fcpx fcpx插件

5分钟带您了解DRS录制回放

华为云开发者联盟

数据库 后端 华为云 华为云开发者联盟

电子签章接口调用,以契约锁为例

Geek_2a38d5

电子签章 契约锁

PS磨皮滤镜降噪插件Imagenomic Professional 支持ps2024 兼容M1

南屿

磨皮插件 ps滤镜下载 Imagenomic Imagenomic Professional

微店获得微店商品详情 API(micro.item_get)在电商中的发展

技术冰糖葫芦

API

AE蓝宝石插件BorisFX Sapphire 2024 for Mac破解版 及新功能介绍

南屿

如何定位和优化程序CPU、内存等性能之巅

雪奈椰子

SD-WAN服务简介及挑选服务商指南

Ogcloud

SD-WAN SD-WAN组网 SD-WAN服务商

eBPF运行时安全

统信软件

安全 ebpf 运行时

ps一键磨皮插件Delicious Retouch 5怎么安装 支持M芯片

南屿

磨皮插件 Photoshop 插件

Sketch Measure for Mac中文破解版 sketch标注插件下载

南屿

Sketch Measure mac中文版 sketch标注插件

荣耀开发者大会2023 · 一张图读懂设计分论坛

荣耀开发者服务平台

AI 设计 开发者大会 honor

如何利用 APM 追踪完整的类函数调用

心有千千结

APM Datadog OpenTelemetry 系统可观测性 DDTrace

外贸自建站推广为何首选谷歌广告?谷歌广告的优势在哪?

九凌网络

软件测试/测试开发/全日制/测试管理丨兼容性测试

测试人

软件测试

photoshop色轮插件Coolorus怎么安装 附Coolorus 许可证

南屿

Coolorus mac版 PS调色插件 Coolorus许可证 Coolorus安装教程

软件测试/测试开发/全日制/测试管理丨iOS 自动化相关工具

测试人

软件测试

App加固:不同类型和费用对比

NFTScan | 01.08~01.14 NFT 市场热点汇总

NFT Research

NFT NFT\ NFTScan

软件测试/测试开发/全日制/测试管理丨CSS Selector

测试人

软件测试

LED透明显示屏前景发展怎么样?

Dylan

LED显示屏 全彩LED显示屏 led显示屏厂家 市场 #研发

ScaleUp插件使用方法 附ScaleUp for Mac破解版资源

南屿

高级视频增强工具 ScaleUp插件下载 ScaleUp mac破解版 AE/PR插件

Lightroom预设资源-高级食物lr预设 附lr预设导入教程

南屿

高级食物lr预设 Lightroom预设下载 lr预设怎么导入

Authing 入选中国信通院《 2023 高质量数字化转型产品及服务全景图》

Authing

中国信通院 信通院 Authing

云联接:揭开SD-WAN神秘面纱,颠覆你对网络的认知!

博文视点Broadview

30款绚彩天空背景特效PS渐变-Photoshop天空渐变

南屿

ps渐变 天空背景特效 Photoshop素材

软件测试/测试开发/全日制/测试管理丨Android WebView 技术原理

测试人

软件测试

QCN9024: The future of wireless communications, five major advantages over competitors

wallysSK

堡垒机和数据库防水坝的区别一二

行云管家

数据库 网络安全 堡垒机 数据库防水坝

FCPX插件-动态视频运动模糊视觉特效 mMotion Blur 支持Intel和Apple M芯片

南屿

fcpx动态视频 运动模糊视觉特效 fcpx插件下载 fcpx特效

Tensorflow2.0实现Deep-Q-Network_语言 & 开发_Alex-zhai_InfoQ精选文章