HarmonyOS开发者限时福利来啦!最高10w+现金激励等你拿~ 了解详情
写点什么

无痛的增强学习入门:差分时序法

  • 2017-11-02
  • 本文字数:3303 字

    阅读完需:约 11 分钟

系列导读:《无痛的增强学习入门》系列文章旨在为大家介绍增强学习相关的入门知识,为大家后续的深入学习打下基础。其中包括增强学习的基本思想,MDP 框架,几种基本的学习算法介绍,以及一些简单的实际案例。

作为机器学习中十分重要的一支,增强学习在这些年取得了十分令人惊喜的成绩,这也使得越来越多的人加入到学习增强学习的队伍当中。增强学习的知识和内容与经典监督学习、非监督学习相比并不容易,而且可解释的小例子比较少,本系列将向各位读者简单介绍其中的基本知识,并以一个小例子贯穿其中。

7 差分时序法

7.1 蒙特卡罗的方差问题

上一节我们介绍了蒙特卡罗法,它可以比较好地解决无模型场景下的蛇棋问题。当然,它也存在着一定的不完美,这些不完美体现在很多方面,接下来我们就来看一看其中的一个问题——估计值的方差问题。

前面我们已经知道,蒙特卡罗是用一系列的模拟游戏得到一些 episode 的回报信息,然后用这些回报信息求平均,就可以得到估计的价值函数。这个方法从理论上讲是没有问题的,它能够成立的根本在于概率论中的大数定理。大数定理也是没有问题的,可是大数定理只说明了收敛性,并没有说明收敛的速度。很显然,如果采样序列的方差比较大,那么想要让它收敛就需要更长的时间,如果采样序列的方差较小,那么收敛的速度也会相应地降低。

基于上面的分析,我们就来看看前面的蒙特卡罗法存不存在方差方面的问题。我们用上一节的方法进行计算,同时保存 state=50,action=1 的 return 值,将所有的值综合起来做成一个直方图,如图 7-1 所示。

图 7-1 蒙特卡罗法 every-visit 的 return 收集直方图

可以看出,return 的跨度还是比较大,这样模型就不太容易收敛。

当然,跨度大和其他问题也有关系,其中一个关系就是 episode 中的采样频率。上一节采用的方法叫做"every-viist",也就是说 episode 中的每一个 state-action 对都会参与到计算当中,这样就为统计带来了困难。因为在蛇棋这个游戏中,由于有梯子的存在,某一个位置可能被走过两次。比方说例子中的位置“50”,我们在计算的过程中没有区分第一次到达“50”和第二次到达“50”,而且将这两个 return 直接加在了一起,由于这两次的 return 一定有差异,因此这样的统计方式会让方差增大,从而难以收敛。

那么该如何应对这个问题呢?一个解决方案是把"every-visit"方法换掉,改成"first-visit"法。所谓的"first-visit"法就是只统计第一次出现的状态,对于后面出现的同一状态则不再统计。从方法的理念来说,同一个 episode 内因多次出现同一状态而造成方差增大的问题将得到解决。下面就来看看"first-visit"的结果:

图 7-2 蒙特卡罗法 first-visit 的 return 收集直方图

从结果来看,first-visit 的效果稍有提高,但是并没有明显的提高。因为即使只取一个不重复的状态,我们仍然会遇到不同情况的 first-visit,每种之间的数值差距依旧很大。

既然蒙特卡罗法存在着这样一些遗憾,那么我们就来看看接下来介绍的方法:差分时序法(Temporal Difference)。

7.2 差分时序

差分时序法是一种结合了蒙特卡罗法和动态规划法的方法。从算法的主体结构来看,它同蒙特卡罗法类似,同样通过模拟 episode 的方式进行求解;从算法的核心思想来说,它有用到了增强学习中的经典公式——Bellman 公式进行自迭代更新。

前面提到状态 - 行动价值函数的 Bellman 公式的形式为:

\(q(s,a)=\sum_{s’}p(s’|s,a)[R + \sum_{a’} \pi(a|s’)q(s’,a’)]\)

那么利用蒙特卡罗的方法,我们就可以将公式变为:

\(q(s,a)=\frac{1}{N} \sum_{i=1}^N [R(s’_i)+q(s’_i,a’_i)]\)

当然,这个公式还可以变成前面蒙特卡罗法中实际应用的形式:

\(q_t(s,a)=q_{t-1}(s,a)+\frac{1}{N}[R(s’)+q(s’,a’)-q_{t-1}(s,a)]\)

从这个公式我们可以看出这个方法和蒙特卡罗、动态规划的关系来。下面我们就来实现这个方法,我们首先要实现的方法被称为"SARSA",这个名字看上去有些奇怪,其实它来自于这个方法的五个关键因子:S(待求状态),A(待求行动),R(模拟得到的奖励),S(模拟进入的下一个状态),A(模拟中采取的下一个行动)。

7.3 On-policy:SARSA

接下来就来实现这个算法,其实它的实现过程也比较简单,这里只展现评估的过程:

复制代码
for i in range(episode_num):
env.start()
state = env.pos
prev_act = -1
while True:
act = self.policy_act(state)
reward, state = env.action(act)
if prev_act != -1:
return_val = reward + (0 if state == -1 else self.value_q[state][act])
self.value_n[prev_state][prev_act] += 1
self.value_q[prev_state][prev_act] += (return_val - \
self.value_q[prev_state][prev_act]) / \
self.value_n[prev_state][prev_act]
{1}
prev_act = act
prev_state = state
{1}
if state == -1:
break
{1}

那么这个代码的实现效果如何呢?我们同样写一段测试的代码。

复制代码
def td_demo():
np.random.seed(0)
env = Snake(10, [3,6])
agent = MonteCarloAgent(100, 2, env)
with timer('Timer Temporal Difference Iter'):
agent.td_opt()
print 'return_pi={}'.format(eval(env,agent))
print agent.policy
agent2 = TableAgent(env.state_transition_table(), env.reward_table())
with timer('Timer PolicyIter'):
agent2.policy_iteration()
print 'return_pi={}'.format(eval(env,agent2))
print agent2.policy

为了试验效果,我们将 SARSA 算法中的迭代轮数增加到 2000,最终的结果为:

复制代码
Timer Temporal Difference Iter COST:3.8632440567
return_pi=80
[0 0 0 1 1 1 1 1 1 1 0 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0]
policy evaluation proceed 94 iters.
policy evaluation proceed 62 iters.
policy evaluation proceed 46 iters.
Iter 3 rounds converge
Timer PolicyIter COST:0.330899000168
return_pi=84
[0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 1 1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 0 0
0 0 0 0 1 1 1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0]

从结果可以看出,SARSA 算法的效果并不够好,当然我们还没有加入更多的策略来提升效果。不过看上去它的结果还不及蒙特卡罗,那么它和蒙特卡罗相比有什么优势呢?我们来看看前面提到的方差问题,再来看一看 SARSA 算法的结果图:

图 7-3 TD 法的 return 收集直方图

可以看出,这个方法的数值变动和前面相比要更为稳定一些,跨度比之前要小一些。

那么为什么 TD 系列的方法会在方差方面控制得更好一些呢?答案就在它的更新公式上。蒙特卡罗的计算方法由于使用了精确的 return,所以在对价值的估计上更精确一些,但是同时它要一个序列的信息,而序列的信息存在更多的波动,所以方差会比较大;而 TD 方法只考虑了一步的计算,其余的计算均使用了之前的估计,所以当整体系统没有达到最优时,这样的估计都是存在偏差的,但是由于它只估计了一步,所以它在估计值方面受到的波动比较少,因此带来的方差也会相应减少许多。

所以前人们发现,蒙特卡罗法和 TD 法象征着两个极端——一个为了追求极小的误差而放松了方差,一个为了缩小方差而放松了误差。这个问题仿佛回到了机器学习的经典问题——bias 和 variance 的权衡问题。

那么,除了 SARSA 这一种方法,还有没有别的算法呢?最后一节我们再来看看 TD 的另一种算法。

作者介绍

冯超,毕业于中国科学院大学,猿辅导研究团队视觉研究负责人,小猿搜题拍照搜题负责人之一。2017 年独立撰写《深度学习轻松学:核心算法与视觉实践》一书,以轻松幽默的语言深入详细地介绍了深度学习的基本结构,模型优化和参数设置细节,视觉领域应用等内容。自 2016 年起在知乎开设了自己的专栏:《无痛的机器学习》,发表机器学习与深度学习相关文章,收到了不错的反响,并被多家媒体转载。曾多次参与社区技术分享活动。

2017-11-02 17:502809

评论

发布
暂无评论
发现更多内容

【云原生 | 从零开始学Kubernetes】一、kubernetes到底是个啥

泡泡

云计算 云原生 k8s 9月月更

监控系统的阶段建设

穿过生命散发芬芳

监控系统 9月月更

【数据结构】五分钟带你了解及自定义有向图

迷彩

数据结构 算法 无向图 9月月更 有向图

面对全新的编程语言,这些思路可以帮助你察觉漏洞

网络安全学海

黑客 网络安全 信息安全 渗透测试 漏洞利用

kube-prometheus 监控系统使用与总结

CTO技术共享

ESP32-C3入门教程 基础篇(四、I2C总线 — 与SHT21温湿度传感器通讯)

矜辰所致

I2C I2C协议 ESP32-C3 9月月更

4 分钟优化 Fetch 函数写法~

掘金安东尼

前端 9月月更

Ubuntu服务器上部署Kubernetes集群

CTO技术共享

2022-09-19:给定字符串 S and T,找出 S 中最短的(连续)子串 W ,使得 T 是 W 的 子序列 。 如果 S 中没有窗口可以包含 T 中的所有字符,返回空字符串 ““。 如果有不

福大大架构师每日一题

算法 rust 福大大

springboot搭建基于minio的高性能存储

CTO技术共享

SSM整合(功能模块的开发)

十八岁讨厌编程

Java ssm 后端开发 9月月更

Web3.0杂谈-#001(47/100)

hackstoic

Web3.0

开发者有话说|时间过得真快,我也是一个“奔三”的人了

武师叔

个人成长

2022-09-20:以下go语言代码输出什么?A:8 8;B:8 16;C:16 16;D:16 8。 package main import ( “unsafe“ “fmt“ )

福大大架构师每日一题

golang 福大大 选择题

通过爬虫爬取一些图片

吉师职业混子

9月月更

数据平台发展史-从数据仓库数据湖到数据湖仓

明哥的IT随笔

hadoop spark 数据仓库 数据湖 湖仓一体

开发者有话说|情分 or 本分

卷卷龙

个人成长 职场 PUA

概述构建应用智能运维系统的核心能力

阿泽🧸

智能运维 9月月更

[SSM]SSM整合①(整合配置)

十八岁讨厌编程

Java 后端开发 9月月更

面试突击85:为什么事务@Transactional会失效?

王磊

Java 面试

C++学习---cstdio的源码学习分析04-创建临时文件函数tmpfile

桑榆

c++ 源码阅读 9月月更

《简单记个笔记》之表单标签加CSS选择器

吉师职业混子

9月月更

开发者有话说 | 一个普通人的前端职业成长之路

范文杰

个人成长

跟着卷卷龙一起学Camera--CCM

卷卷龙

ISP 9月月更

RAID(独立冗余磁盘阵列)

阿柠xn

Linux 运维 操作系统 raid 9月月更

数字化转型新抓手:一看就懂的《企业应用现代化行动指南》(附下载)

York

容器 微服务 云原生 应用现代化

ESP32-C3入门教程 基础篇(三、UART模块 — 与Enocean无线模块串口通信)

矜辰所致

ESP32-C3 9月月更 UART

史上最详细vue的入门基础

楠羽

Vue 笔记 9月月更

40 岁程序员会有哪些肺腑之言?这篇文章告诉你

宇宙之一粟

学习 程序员 读书感悟 9月月更

SSM整合(接口测试)

十八岁讨厌编程

Java SSM框架 后端开发 9月月更

《简单记个笔记》之部分CSS选择器介绍

吉师职业混子

9月月更

无痛的增强学习入门:差分时序法_语言 & 开发_冯超_InfoQ精选文章