写点什么

Tensorflow2.0 实现 Deep-Q-Network

  • 2019-12-02
  • 本文字数:4042 字

    阅读完需:约 13 分钟

Tensorflow2.0实现Deep-Q-Network

深度 Q 网络(Deep - Q - Network) 介绍


在 Q-learning 算法中,当状态和动作空间是离散且维数不高时,可使用 Q-table 储存每个状态动作对的 Q 值,然后通过贝尔曼方差迭代求得每个状态动作对收敛的 Q 值,然后选择最优的动作当做策略。但是而当状态和动作空间是高维连续时,比如(游戏的状态动作对数目就很大)使用 Q-table 存储每个状态动作对就显得很不现实。


所以可以将 Q-Table 的更新问题变成一个函数拟合问题,相近的状态得到相近的输出动作。DQN 就是要设计一个神经网络结构,通过函数来拟合 Q 值。


下面引用一下自己写的一篇综述里面的 DQN 训练流程图,贴自己的图,不算侵犯版权吧,哈哈。知网上可以下载到这篇文章:http://kns.cnki.net/kcms/detail/detail.aspx?dbcode=CJFD&&filename=JSJX201801001。当时3个月大概看了百余篇DRL方向的论文,才写出来的,哈哈。



DQN 的亮点:


  • 通过 experience replay(经验池)的方法来解决相关性及非静态分布问题,在训练深度网络时,通常要求样本之间是相互独立的,所以通过这种随机采样的方式,大大降低了样本之间的关联性,从而提升了算法的稳定性。

  • 使用一个神经网络产生当前 Q 值,使用另外一个神经网络产生 Target Q 值。

  • DQN 损失函数和参数更新:


损失函数:



其中 yi 表示值函数的优化目标即目标网络的 Q 值:



参数更新的梯度为:



Tensorflow 2.0 实现 DQN


整体的代码是借鉴的莫烦大神,只不过现在用的接口都是 Tensorflow 2.0,所以代码显得很简单,风格很像 keras。


# -*- coding:utf-8 -*-# Author : zhaijianwei# Date : 2019/6/19 19:48
import tensorflow as tfimport numpy as npfrom tensorflow.python.keras import layersfrom tensorflow.python.keras.optimizers import RMSprop
from DQN.maze_env import Maze

class Eval_Model(tf.keras.Model): def __init__(self, num_actions): super().__init__('mlp_q_network') self.layer1 = layers.Dense(10, activation='relu') self.logits = layers.Dense(num_actions, activation=None)
def call(self, inputs): x = tf.convert_to_tensor(inputs) layer1 = self.layer1(x) logits = self.logits(layer1) return logits

class Target_Model(tf.keras.Model): def __init__(self, num_actions): super().__init__('mlp_q_network_1') self.layer1 = layers.Dense(10, trainable=False, activation='relu') self.logits = layers.Dense(num_actions, trainable=False, activation=None)
def call(self, inputs): x = tf.convert_to_tensor(inputs) layer1 = self.layer1(x) logits = self.logits(layer1) return logits

class DeepQNetwork: def __init__(self, n_actions, n_features, eval_model, target_model):
self.params = { 'n_actions': n_actions, 'n_features': n_features, 'learning_rate': 0.01, 'reward_decay': 0.9, 'e_greedy': 0.9, 'replace_target_iter': 300, 'memory_size': 500, 'batch_size': 32, 'e_greedy_increment': None }
# total learning step
self.learn_step_counter = 0
# initialize zero memory [s, a, r, s_] self.epsilon = 0 if self.params['e_greedy_increment'] is not None else self.params['e_greedy'] self.memory = np.zeros((self.params['memory_size'], self.params['n_features'] * 2 + 2))
self.eval_model = eval_model self.target_model = target_model
self.eval_model.compile( optimizer=RMSprop(lr=self.params['learning_rate']), loss='mse' ) self.cost_his = []
def store_transition(self, s, a, r, s_): if not hasattr(self, 'memory_counter'): self.memory_counter = 0
transition = np.hstack((s, [a, r], s_))
# replace the old memory with new memory index = self.memory_counter % self.params['memory_size'] self.memory[index, :] = transition
self.memory_counter += 1
def choose_action(self, observation): # to have batch dimension when feed into tf placeholder observation = observation[np.newaxis, :]
if np.random.uniform() < self.epsilon: # forward feed the observation and get q value for every actions actions_value = self.eval_model.predict(observation) print(actions_value) action = np.argmax(actions_value) else: action = np.random.randint(0, self.params['n_actions']) return action
def learn(self): # sample batch memory from all memory if self.memory_counter > self.params['memory_size']: sample_index = np.random.choice(self.params['memory_size'], size=self.params['batch_size']) else: sample_index = np.random.choice(self.memory_counter, size=self.params['batch_size'])
batch_memory = self.memory[sample_index, :]
q_next = self.target_model.predict(batch_memory[:, -self.params['n_features']:]) q_eval = self.eval_model.predict(batch_memory[:, :self.params['n_features']])
# change q_target w.r.t q_eval's action q_target = q_eval.copy()
batch_index = np.arange(self.params['batch_size'], dtype=np.int32) eval_act_index = batch_memory[:, self.params['n_features']].astype(int) reward = batch_memory[:, self.params['n_features'] + 1]
q_target[batch_index, eval_act_index] = reward + self.params['reward_decay'] * np.max(q_next, axis=1)
# check to replace target parameters if self.learn_step_counter % self.params['replace_target_iter'] == 0: for eval_layer, target_layer in zip(self.eval_model.layers, self.target_model.layers): target_layer.set_weights(eval_layer.get_weights()) print('\ntarget_params_replaced\n')
""" For example in this batch I have 2 samples and 3 actions: q_eval = [[1, 2, 3], [4, 5, 6]] q_target = q_eval = [[1, 2, 3], [4, 5, 6]] Then change q_target with the real q_target value w.r.t the q_eval's action. For example in: sample 0, I took action 0, and the max q_target value is -1; sample 1, I took action 2, and the max q_target value is -2: q_target = [[-1, 2, 3], [4, 5, -2]] So the (q_target - q_eval) becomes: [[(-1)-(1), 0, 0], [0, 0, (-2)-(6)]] We then backpropagate this error w.r.t the corresponding action to network, leave other action as error=0 cause we didn't choose it. """
# train eval network
self.cost = self.eval_model.train_on_batch(batch_memory[:, :self.params['n_features']], q_target)
self.cost_his.append(self.cost)
# increasing epsilon self.epsilon = self.epsilon + self.params['e_greedy_increment'] if self.epsilon < self.params['e_greedy'] \ else self.params['e_greedy'] self.learn_step_counter += 1
def plot_cost(self): import matplotlib.pyplot as plt plt.plot(np.arange(len(self.cost_his)), self.cost_his) plt.ylabel('Cost') plt.xlabel('training steps') plt.show()

def run_maze(): step = 0 for episode in range(300): # initial observation observation = env.reset()
while True: # fresh env env.render() # RL choose action based on observation action = RL.choose_action(observation) # RL take action and get next observation and reward observation_, reward, done = env.step(action) RL.store_transition(observation, action, reward, observation_) if (step > 200) and (step % 5 == 0): RL.learn() # swap observation observation = observation_ # break while loop when end of this episode if done: break step += 1 # end of game print('game over') env.destroy()

if __name__ == "__main__": # maze game env = Maze() eval_model = Eval_Model(num_actions=env.n_actions) target_model = Target_Model(num_actions=env.n_actions) RL = DeepQNetwork(env.n_actions, env.n_features, eval_model, target_model) env.after(100, run_maze) env.mainloop() RL.plot_cost()
复制代码


参考文献:


https://www.jianshu.com/p/10930c371cac


https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow


http://inoryy.com/post/tensorf


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/70009692


2019-12-02 16:221486

评论

发布
暂无评论
发现更多内容

5 月 28 日 - 29 日阿里云峰会视频云专场直播预告

阿里云CloudImagine

阿里云 音视频

MySQL事务处理特性的实现原理

华为云开发者联盟

MySQL 数据库 innodb 事务 隔离

博睿数据2021战略发布巡展,开辟IT运维创新路径

博睿数据

博睿数据 数据链DNA 服务可达

🍃【SpringCloud基础使用】Nacos与Gateway实现动态路由

码界西柚

nacos SpringCloud Gateway 5月日更 自定义配置

我厂与张家港市达成全面战略合作,共推数据中心和城市智能化转型

百度大脑

数据中心 城市智能化

中国呼叫中心与卓越客服产业峰会,百度智能客服再提行业创新

百度大脑

解决方案 行业创新

并发王者课-青铜7:顺藤摸瓜-如何从synchronized中的锁认识Monitor

MetaThoughts

Java 多线程 并发

现在已经卷到需要问三色标记了吗?

艾小仙

获得业内一致好评!华山版Java性能优化全栈手册“登场”

Java架构追梦

Java 阿里巴巴 架构 性能优化 华山版

答应我,别再学Swing框架了好吗?

北游学Java

Java spring swing

从零开始学习ThingJS之创建App对象

ThingJS数字孪生引擎

可视化 3D可视化 数字孪生

编曲新手可以用什么编曲软件?

奈奈的杂社

编曲 编曲宿主 编曲软件

用图数据库可视化探索 Chia Network 区块链数据

古思为

区块链 可视化 图数据库

低代码实现传统装饰企业的管理跃迁

华为云开发者联盟

低代码 华为云 计算 低代码开发 AppCube

从源码角度研究Java动态代理

叫我阿柒啊

动态代理 代理模式 rmi

appium 入门参考

37手游iOS技术运营团队

ios 测试 自动化测试 iOS Developer

服务可达,达者为先,产品发布会嘉宾精彩观点分享!

博睿数据

博睿数据 数据链DNA 服务可达

1小时内被全网疯转 29.8w 次,最终被所有大V协力封杀!

Java架构师迁哥

鸿蒙轻内核M核源码分析:数据结构之任务排序链表

华为云开发者联盟

鸿蒙 数据结构 任务排序链表 双向链表数组 鸿蒙轻内核

活动预告 _ 即构×火山引擎:泛娱乐社交音视频技术实践沙龙

ZEGO即构

工业4.0加速实现“数物相合”,可视化工厂节省时效高达85%

一只数据鲸鱼

人工智能 数据可视化 工业互联网 智慧工厂 智能生产

Bugless 异常监控系统 (iOS端)

37手游iOS技术运营团队

ios iOS Developer 崩溃分析 bugless

【玩转PDF】贼稳,产品要做一个三方合同签署,我方了!

牧小农

JVM

2021 全球技术领导力峰会 融云布道技术领导力进阶之路

融云 RongCloud

用Python在树莓派上播放音乐

IT蜗壳-Tango

5月日更

英特尔院士斯旺:由外而内重塑芯片设计

E科讯

眼观六路耳听八方还不知疲倦?数仓智能运维服务体系是怎么做到的?

华为云开发者联盟

数据库 数据仓库 监控 智能运维 数据库监控

高可用DevHa实践,告诉你生产环境0性能故障是如何做到的!

TakinTalks稳定性社区

压测 性能调优 全链路压测 系统稳定高可用 性能压测

webRTC的标准与发展

anyRTC开发者

音视频 WebRTC RTC

量化网格策略交易软件,马丁倍投策略机器人

走向机器智能时代:移动机器人的困局与创新

晨山资本

机器人 移动机器人 AMR

Tensorflow2.0实现Deep-Q-Network_语言 & 开发_Alex-zhai_InfoQ精选文章