QCon北京「鸿蒙专场」火热来袭!即刻报名,与创新同行~ 了解详情
写点什么

无痛的增强学习入门:Q-Learning

  • 2017-11-19
  • 本文字数:4428 字

    阅读完需:约 15 分钟

系列导读:《无痛的增强学习入门》系列文章旨在为大家介绍增强学习相关的入门知识,为大家后续的深入学习打下基础。其中包括增强学习的基本思想,MDP 框架,几种基本的学习算法介绍,以及一些简单的实际案例。

作为机器学习中十分重要的一支,增强学习在这些年取得了十分令人惊喜的成绩,这也使得越来越多的人加入到学习增强学习的队伍当中。增强学习的知识和内容与经典监督学习、非监督学习相比并不容易,而且可解释的小例子比较少,本系列将向各位读者简单介绍其中的基本知识,并以一个小例子贯穿其中。

8 Q-Learning

8.1 Q-Learning

上一节我们介绍了 TD 的 SARSA 算法,它的核心公式为:

\(q_t(s,a)=q_{t-1}(s,a)+\frac{1}{N}[R(s’)+q(s’,a’)-q_{t-1}(s,a)]\)

接下来我们要看的另一种 TD 的算法叫做 Q-Learning,它的基本公式为:

\(q_t(s,a)=q_{t-1}(s,a)+\frac{1}{N}[R(s’)+max_{a’} q(s’,a’)-q_{t-1}(s,a)]\)

两个算法的差别只在其中的一个项目,一个使用了当前 episode 中的状态 - 行动序列,另一个并没有,而是选择了数值最大的那个。这就涉及到两种不同的思路了。我们先暂时不管这个思路,来看看这个算法的效果。

首先还是实现代码:

复制代码
def q_learning(self):
iteration = 0
while True:
iteration += 1
self.q_learn_eval()
ret = self.policy_improve()
if not ret:
break

对应的策略评估代码为:

复制代码
def q_learn_eval(self):
episode_num = 1000
env = self.snake
for i in range(episode_num):
env.start()
state = env.pos
prev_act = -1
while True:
act = self.policy_act(state)
reward, state = env.action(act)
if prev_act != -1:
return_val = reward + (0 if state == -1 else np.max(self.value_q[state,:]))
self.value_n[prev_state][prev_act] += 1
self.value_q[prev_state][prev_act] += (return_val - \
self.value_q[prev_state][prev_act]) / \
self.value_n[prev_state][prev_act]
prev_act = act
prev_state = state
if state == -1:
break

实际的运行代码省略,最终的结果为:

复制代码
Timer Temporal Difference Iter COST:4.24033594131
return_pi=81
[0 0 0 0 1 1 1 1 1 1 0 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
0 1 1 1 1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0]
policy evaluation proceed 94 iters.
policy evaluation proceed 62 iters.
policy evaluation proceed 46 iters.
Iter 3 rounds converge
Timer PolicyIter COST:0.318824052811
return_pi=84
[0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 1 1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 0 0
0 0 0 0 1 1 1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0]

可以看出,Q-Learning 的方法和策略迭代比还是有点差距,当然通过一些方法是可以把数字提高的,但是这不是本文的重点了。不过从结果来看,Q-Learning 比 SARSA 要好一些。

8.2 On-Policy 与 Off-Policy

实际上这两种算法代表了两种思考问题的方式,分别是 On-Policy 和 Off-Policy。On-Policy 的代表是 SARSA,它的价值函数更新是完全根据 episode 的进展进行的,相当于在之前的基础上不断计算做改进;Off-Policy 的代表是 Q-Learning,它的价值函数中包含了之前的精华,所以可以算是战得更高,看得更远。当然也不能说 Off-Policy 就绝对好,Off-Policy 还会带来一些自己的问题,总之我们需要具体问题具体分析。

这里最后再做一点对两大类算法的总结。其实如果游戏有终点,模型不复杂(像我们的蛇棋),蒙特卡罗法还是有绝对的优势,但是蒙特卡罗的软肋也比较明显,它需要完整的 episode,产生的结果方差大,这对于一些大型游戏并不适合,所以在真正的产品级规模的增强学习上,TD 的身影还是更多一些。

8.3 展望

由于篇幅的限制,到这里我们实际上就完成了对增强学习的一些基础内容的介绍,但是了解了这些内容,还不足以完成更加复杂的任务,下面就让我们简单了解一些更为高级的内容。

8.3.1 Function Approximation

前面我们提到的所有的增强学习算法都是基于表格的,表格的好处在于我们可以独立地考虑每一个状态、每一个行动的价值,不需要把很多状态汇集在一起考虑,然而在很多实际问题中,状态的数目非常多,甚至状态本身就不是连续的,那么采用表格来进行计算就显得不太现实,于是研究人员就开始研究一些更为合适的表达方式,这时候,我们的机器学习就登场了,我们可以建立一些模型,并用一个模型来表示状态、行动到价值函数的关系。

我们令状态为\(s \in S\),行动为\(a \in A\),最终的价值函数为\(v \in R\),那么我们要建立这样一个映射:

\(S \times A \rightarrow R\)

这样我们就把增强学习中的一个子问题转换成了监督学习,而监督学习是大家熟悉的套路,所以做起来就更加的得心应手了。实际上这就是一个回归问题,于是所有可以用于回归问题的模型都可以被我们用上。比方说线性回归,支持向量回归,决策树,神经网络。因为现在深度学习十分火,于是我们可以用一个深层的神经网络完成这个映射。

模型函数和监督学习使我们又了从另一个角度观察增强学习的可能。此时的模型要考虑 bias 和 variance 的权衡,要考虑模型的泛化性,这些问题最终都会映射到增强学习上。模型的表示形式也有很多种,前面给出的\(S \times A \rightarrow R\) 只是模型表示的一种形式,我们还可以表示成\(S \rightarrow A,R\) 这样的形式。对于第一种表示形式,由于不同的行动将作用在同一个状态下(或者不同的行动被同一种行动操纵),模型中的参数表示必然会存在重复,那么为了更好地共享参数,我们可以使用后面一种形式来表示。

说完了模型的形式,那么接下来就来看看模型的目标函数。最简单的目标函数自然是平方损失函数:

\(obj=\frac{1}{2}\sum_i^N(v’_i(s,a;w) - v_i)^2\)

其中的\(v_i’\) 表示模型估计的价值,而\(v_i\) 表示当前的真实价值,定义了目标函数,下面我们就可以利用机器学习经典的梯度下降法求解了:

\(\frac{\partial obj}{\partial w}=\sum_i^N (v_i’ -v_i) \frac{\partial v_i’}{\partial w}\)

这个公式是不是看上去很眼熟?实际上如果

\(\frac{\partial v_i’}{\partial w}=1\)

,那么模型的最优解就等于

\(v_i’=\frac{1}{N}\sum_i^N v_i\)

也就是所有训练数据的平均值,这个结果和表格版的计算公式是一致的,也就是说表格版的算法实际上也是一种模型,只不过它的梯度处处为 1。

关于模型更多的讨论,我们也可以取阅读各种论文,论文中对这里面的各个问题都有深入的讨论。

8.3.2 Policy Gradient

另外一个方向则是跳出已有的思维框架,朝着另一种运算方式前进。这种方法被称为 Policy Gradient。在这种方法中,我们不再采用先求价值函数,后更新策略的方式,而是直接针对策略进行建模,也就是\(\pi(a|s)\) 建模。

我们回到增强学习问题的源头,最初我们希望找到一种策略使得 Agent 的长期回报最大化,也就是:

\(max E_{\pi}[v_{\pi}(s_0)]\)

这个公式求解梯度可以展开为:

\(\nabla v_{\pi}(s_0)=E_{\pi}[\gamma^t G_t \nabla log \pi(a_t|s_t;w)]\)

更新公式为:

\(\theta_{t+1}=\theta_t + \alpha G_t \nabla log\pi(a_t|s_t;w)\)

这其中的推导过程就省略不谈了。得到了目标函数的梯度,那么我们就可以直接根据梯度去求极值了。我们真正关心的部分实际上是里面的那个求梯度的部分,所以当整体目标达到最大时,策略也就达到了最大值,因此目标函数的梯度可以回传给模型以供使用。由于 log 函数不改变函数的单调性,对最终的最优策略步影响,于是我们可以直接对\(log \pi(a|s;w)\) 进行建模。这个方法被称为 REINFORCE。

当然,这个算法里面还存在着一些问题。比方说里面的\(G_t\) 同样是用蒙特卡罗的方法得到的。它和蒙特卡罗方法一样,存在着高方差的问题,为了解决高方差的问题,有人提出构建一个 BaseLine 模型,让每一个\(G_t\) 减掉 BaseLine 数字,从而降低了模型的方差。更新公式为:

\(\theta_{t+1}=\theta_t + \alpha (G_t-b(S_t)) \nabla log\pi(a_t|s_t;w)\)

8.3.3 Actor-Critic

既然有基于 Policy-Gradient 的蒙特卡罗方法,那么就应该有 TD 的方法。这种方法被称为 Actor-Critic 方法。这个方法由两个模型组成,其中 Actor 负责 Policy 的建模,Critic 负责价值函数的建模,于是上面基于 BaseLine 的 REINFORCE 算法的公式就变成了:

\(\theta_{t+1}=\theta_t + \alpha (R_{t+1}+\gamma \hat{v}(s_{t+1};w)-\hat{v}(S_t;w)) \nabla log\pi(a_t|s_t;w)\)

这样每一轮的优化也就变成了先根据模拟的 episode 信息优化价值函数,然后再优化策略的过程。

一般来说,上面公式中的\((R_{t+1}+\gamma \hat{v}(s_{t+1};w)-\hat{v}(S_t;w))\) 表示向前看一步的价值和就在当前分析的价值的差。下棋的人们都知道“下棋看三步”的道理,也就是说有时面对一些游戏的状态,我们在有条件的情况下可以多向前看看,再做出决定,所以一般认为多看一步的价值函数的值会更高更好一些,所以上面的那一项通常被称为优势项(Advantage)。

在前面的实验中我们也发现,对于无模型的问题,我们要进行大量的采样才能获得比较精确的结果,对于更大的问题来说,需要的采样模拟也就更多。因此,对于大量的计算量,如何更快地完成计算也成为了一个问题。有一个比较知名且效果不错的算法,被称为 A3C(Asynchronous Advantage Actor-Critic) 的方法,主要是采用了异步的方式进行并行采样,更新参数,这个方法也得到了比较好的效果。

以上就是《无痛的增强学习》入门篇,我们以蛇棋为例介绍了增强学习基础框架 MDP,介绍了模型已知的几种方法——策略迭代、价值迭代、泛化迭代法,介绍了模型未知的几种方法——蒙特卡罗、SARSA、Q-Learning,还简单介绍了一些更高级的计算方法,希望大家能从中有所收获。由于作者才疏学浅,行文中难免有疏漏之处,还请各位谅解。

作者介绍

冯超,毕业于中国科学院大学,猿辅导研究团队视觉研究负责人,小猿搜题拍照搜题负责人之一。2017 年独立撰写《深度学习轻松学:核心算法与视觉实践》一书,以轻松幽默的语言深入详细地介绍了深度学习的基本结构,模型优化和参数设置细节,视觉领域应用等内容。自 2016 年起在知乎开设了自己的专栏:《无痛的机器学习》,发表机器学习与深度学习相关文章,收到了不错的反响,并被多家媒体转载。曾多次参与社区技术分享活动。

2017-11-19 17:083816

评论

发布
暂无评论
发现更多内容

TDengine在弘源泰平量化投资中的实践

TDengine

数据库 tdengine 开源 时序数据库

ApacheCon Asia 2022 强势来袭!16 大专题等你投稿!

阿里巴巴云原生

开源 云原生 活动

Docker学习记录

ZuccRoger

5月月更

架构7期模块1作业

Elvis FAN

架构实战营

场景实践 | 如何使用融云超级群构建游戏社区

融云 RongCloud

携手 TDengine,释普科技升级实验室仪器、监控智能方案

TDengine

数据库 tdengine 开源 物联网

争夺存量用户关键战,助力企业构建完美标签体系丨01期直播回顾

袋鼠云数栈

大数据 数据中台

互联网公司目标管理OKR实践落地与反思

laofo

互联网 OKR 研发效能 绩效管理 快手

ShardingSphere 在东南亚|与科技保险公司 Fuse 的技术融合

SphereEx

Apache 开源 ShardingSphere SphereEx 数据库·

时间序列化数据库选型?时序数据库的选择?

TDengine

数据库 tdengine

[Day41]-[回溯]-全排列

方勇(gopher)

LeetCode 回溯算法 数据结构算法

直播预约|数据指标体系如何搭建才最有效,从0到1带你快速入门

袋鼠云数栈

大数据 数据中台

为什么说 MongoDB 和 HBase 不适用于汽车行业的时序数据处理?

TDengine

数据库 tdengine 开源 时序数据库

TDengine 在酷哞哞的应用

TDengine

数据库 tdengine 开源 物联网

互联网公司目标管理OKR和绩效考核误区

laofo

OKR 研发效能 互联网公司 快手 绩效考核 GRAD

学生管理系统架构设计图

Justin1024

B站S11破亿直播在线稳定性保障秘籍——演讲实录

TakinTalks稳定性社区

混沌工程 系统稳定性 全链路压测 安全生产

【刷题第12天】58. 最后一个单词的长度

白日梦

5月月更

携手数字人、数字空间、XR平台,阿里云与伙伴共同建设“新视界”

阿里云弹性计算

XR 数字人 视觉计算 瑶台

面试突击49:说一下 JUC 中的 Exchange 交换器?

王磊

Java java面试

万字长文:手把手教你实现一套高效的IM长连接自适应心跳保活机制

JackJiang

TCP 网络编程 即时通讯 im开发 心跳保活

要做研发高手,就是必须能看英文、写英文

TDengine

数据库 tdengine 开源

私有化的IM即时通讯平台,企业首选的沟通工具

BeeWorks

Reactor百万连接的并发

C++后台开发

reactor 高并发 epoll Linux服务器开发 C++后台开发

加入MOVE,一起体验Move2Earn的运动乐趣

BlockChain先知

时序数据库的集群方案?

TDengine

数据库 tdengine 开源

敏捷已死

方云AI研发绩效

netty系列之:在netty中实现线程和CPU绑定

程序那些事

Java Netty 程序那些事 5月月更

互联网大厂研发效能团队的需求管理

laofo

互联网 DevOps cicd 研发效能 CI/CD

通过 Amazon API Gateway 和 Amazon Lambda 实现基于 Restful API 的 CloudFront Distribution 复制/克隆功能

亚马逊云科技 (Amazon Web Services)

Lambda Gateway

开启分布式应用性能观测(APM)

观测云

可观测性 可观测

无痛的增强学习入门:Q-Learning_语言 & 开发_冯超_InfoQ精选文章