写点什么

Stale Synchronous Parallel Parameter Server 解读和代码实现

  • 2019-12-02
  • 本文字数:3544 字

    阅读完需:约 12 分钟

Stale Synchronous Parallel Parameter Server解读和代码实现

论文解读:


常见的并行训练的模式有两种:


  • 同步,各个 worker 并行训练,每次更新梯度时都等待所有 worker 完成本次迭代的计算,然后一起开始下一次迭代。最简单的同步就是每个 worker 的梯度求和求平均,然后更新参数,目前比较流行的同步更新方法则是 ring-all-ruduce 方法。可以保证梯度的正确性。但是速度较慢。

  • 完全异步,各个 worker 并行训练,各自处理各自的数据不等待其他任何 worker。速度比较快,但是梯度有损失。

  • 存在的问题:传统的 SGD 是基于 batch 更新的,并行训练时各个 worker 计算当前 batch 的梯度,然后反向传播之后 push 梯度,然后 pull 最新的参数再处理下一个 batch。这个时候如果当一个 worker 更新速度特别慢,这个 worker push 的梯度是使用一个非常旧的参数计算出来的,这个梯度可能已经不适合当下的参数,甚至有时候会起到反作用。


本文提出的 SSP 方法来让 worker 在效率和正确性上做一个良好的权衡。


核心思想:各个 worker 并行训练,每次进行下一次迭代时判断一下自己的迭代比整个系统中最慢的节点的迭代快多少个 step,如果达到一个阈值就进入等待状态直到 step 小于阈值开始下一次计算。

代码实现

参考:https://blog.csdn.net/li57681522/article/details/87920210


# -*- coding:utf-8 -*-
# python dis_tf_ssp.py --job_name=ps --task_index=0# python dis_tf_ssp.py --job_name=worker --task_index=0# python dis_tf_ssp.py --job_name=worker --task_index=1
import timeimport numpy as npimport tensorflow as tf
from tensorflow.python.util.tf_export import tf_exportfrom tensorflow.python.ops import state_ops, variables, variable_scopefrom tensorflow.python.training import session_run_hook
# Define parametersFLAGS = tf.app.flags.FLAGStf.app.flags.DEFINE_float('learning_rate', 0.00003, 'Initial learning rate.')tf.app.flags.DEFINE_integer('steps_to_validate', 1000, 'Steps to validate and print loss')
# For distributedtf.app.flags.DEFINE_string("ps_hosts", "172.20.181.16:2222", "Comma-separated list of hostname:port pairs")tf.app.flags.DEFINE_string("worker_hosts", "172.20.181.16:2224", "Comma-separated list of hostname:port pairs")tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
# Hyperparameterslearning_rate = FLAGS.learning_ratesteps_to_validate = FLAGS.steps_to_validate

# @tf_export("train.SemiSyncRunHook")# class SemiSyncRunHook(session_run_hook.SessionRunHook):class SemiSyncRunHook(tf.train.SessionRunHook): """Run by SSP."""
def __init__(self, index, worker_count, staleness=10): """Initializes a `SemiSyncRunHook`. Args: index: work index worker_count: number of workers staleness: """
if index >= worker_count: print("worker index {} is bigger than worker_count {}".format(index, worker_count)) return
self._const_max_test_step = 10000 self._last_step = 0 # 上一次wait的步骤数 self._last_time = self._now_time() # 上一次wait的时间
self._index = index self._staleness = staleness self._wait_time = 0.01 # 等待时间,单位:秒;这个时间不能设置的太长,跟worker的训练速度和staleness相关 self._worker_steps = [] # 记录worker训练步骤数的变量列表
for i in range(worker_count): worker_step = variable_scope.variable(0, trainable=False, name="worker_step_" + str(i)) self._worker_steps.append(worker_step) if i == index: self._my_step_update_op = state_ops.assign_add(worker_step, 1)
self._initialize_op = variables.variables_initializer(self._worker_steps)
def _now_time(self): return time.time()
def after_create_session(self, session, coord): session.run(self._initialize_op) # 初始化记录worker训练步骤数的变量
def before_run(self, run_context): run_context.session.run(self._my_step_update_op) # 更新本worker的训练步骤数 return None
def after_run(self, run_context, run_values): while True: # 1.获取所有worker的训练步骤数 all_worker_steps = run_context.session.run(self._worker_steps) # print("all worker steps={}. my work id={}".format(all_worker_steps, self._index))
# 2.如果训练当前worker的训练步骤数 > 最小worker训练步骤数 + staleness,sleep(10ms); 否则 break; if all_worker_steps[self._index] > min(all_worker_steps) + self._staleness: diff_step = all_worker_steps[self._index] - self._last_step if diff_step / self._const_max_test_step > 1: self._wait_time = (self._now_time() - self._last_time) / diff_step * self._staleness * 0.7
# 更新 self._last_step = all_worker_steps[self._index] self._last_time = self._now_time()
time.sleep(self._wait_time) # 等待慢worker执行 # print("all worker steps={}, my work id={}. waiting {}s...".format(all_worker_steps, self._index, self._wait_time)) else: break

def main(_): ps_hosts = FLAGS.ps_hosts.split(",") worker_hosts = FLAGS.worker_hosts.split(",") cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts}) server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
worker_count = len(worker_hosts)
if FLAGS.job_name == "ps": server.join() elif FLAGS.job_name == "worker": with tf.device(tf.train.replica_device_setter( worker_device="/job:worker/task:%d" % FLAGS.task_index, cluster=cluster)): global_step = tf.Variable(0, name='global_step', trainable=False)
X = tf.placeholder(tf.float32) Y = tf.placeholder(tf.float32) w = tf.Variable(0.0, name="weight") b = tf.Variable(0.0, name="reminder") y = w * X + b
loss = tf.reduce_mean(tf.square(y - Y)) optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# 更新梯度 train_op = optimizer.minimize(loss, global_step=global_step)
hooks = [tf.train.StopAtStepHook(last_step=1000000)]
semiSyncRunHook = SemiSyncRunHook(FLAGS.task_index, worker_count=worker_count, staleness=10) hooks.append(semiSyncRunHook)
with tf.train.MonitoredTrainingSession( master=server.target, is_chief=(FLAGS.task_index == 0), checkpoint_dir="./ssp_saved_model", hooks=hooks) as mon_sess: while not mon_sess.should_stop(): train_x = np.random.randn(1) train_y = 2 * train_x + np.random.randn(1) * 0.33 + 10 _, loss_v, step = mon_sess.run([train_op, loss, global_step], feed_dict={X: train_x, Y: train_y}) if step % steps_to_validate == 0: w_, b_ = mon_sess.run([w, b]) print("step: %d, weight: %f, biase: %f, loss: %f" % (step, w_, b_, loss_v))

if __name__ == "__main__": tf.app.run()
复制代码


参考文献:


More Effective Distributed ML via a Stale Synchronous Parallel Parameter Server


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/69665236


2019-12-02 16:25966

评论

发布
暂无评论
发现更多内容

「ChatGPT最强竞品」爆火:不限量不要钱免注册!一手实测体验在此

Openlab_cosmoplat

人工智能 开源社区 openai ChatGPT

为什么FTP会随着时间的过去而变慢?

镭速

过去的90天,ODC 发生了哪些新的改变?

OceanBase 数据库

数据库 oceanbase

Chrome 浏览器的更新导致 jQuery 反复发版,只因 :has() 这个伪类

茶无味的一天

CSS jquery chrome 前端 浏览器

第五元素奏鸣曲:企业的新数据之道

脑极体

数据

支撑百万商户、千亿级调用:微盟如何通过链路设计降本40%?

TakinTalks稳定性社区

阿里云计算巢产品负责人何川:计算巢,通过数字化工具加速企业数字原生

云布道师

云计算 计算巢

使用appuploader工具发布证书和描述性文件教程

雪奈椰子

一文掌握 Go 文件的写入操作

陈明勇

Go golang 后端 文件写入 三周年连更

小程序生命周期

程序员海军

三周年连更

“930大促”日活增速超40% ,哈啰如何用预案高效应急?

TakinTalks稳定性社区

低代码起势,程序员闷头开发的日子结束了

引迈信息

低代码 快速开发 JNPF

杨志丰:一文详解,什么是单机分布式一体化?

OceanBase 数据库

数据库 oceanbase

求助 iOS 分发的最佳实践

雪奈椰子

Java Stream常见用法汇总,开发效率大幅提升

程序员大彬

Java java8

探究光明源智慧公厕系统的科技创新与管理优势

光明源智慧厕所

智慧城市

华为云全流程等保服务,帮助企业守护信息安全

科技怪授

一篇文章了解SoapUI接口测试的全部流程

Liam

测试 接口测试 测试工具 API 测试

Django笔记九之model查询filter、exclude、annotate、order_by

Hunter熊

Python django alias annotate order_by

为企业发展赋能,华为云网站安全解决方案,保护企业网络安全

科技怪授

未来源码|什么是数据集成?超全的SeaTunnel 集成工具介绍

MobTech袤博科技

华为云网站安全解决方案,助力企业安心稳步发展

科技说

软件测试/测试开发丨Python 算法与数据结构面试题

测试人

软件测试 面试题 测试开发

阿凡达Sun4.0众筹开发系统技术搭建

薇電13242772558

NFT

多云之下,京东云的降本增效之道

人称T客

推平“知识高峰”,AI将如何影响我们的学习?

Alter

糟了,生产环境数据竟然不一致,人麻了!

冰河

MySQL 数据库 数据一致性 数据存储

一文读懂注解的底层原理

老周聊架构

三周年连更

瑞云科技副总经理黄金进受邀出席2023广东超聚变生态伙伴大会并作主题演讲

3DCAT实时渲染

元宇宙 实时渲染 云流化 3D实时云渲染 云化XR

我决定给 ChatGPT 做个缓存层 >>> Hello GPTCache

Zilliz

Zilliz ChatGPT LLM gptcache

华为云网站安全方案为企业数据保驾护航

科技说

Stale Synchronous Parallel Parameter Server解读和代码实现_语言 & 开发_Rick_InfoQ精选文章