写点什么

Stale Synchronous Parallel Parameter Server 解读和代码实现

  • 2019-12-02
  • 本文字数:3544 字

    阅读完需:约 12 分钟

Stale Synchronous Parallel Parameter Server解读和代码实现

论文解读:


常见的并行训练的模式有两种:


  • 同步,各个 worker 并行训练,每次更新梯度时都等待所有 worker 完成本次迭代的计算,然后一起开始下一次迭代。最简单的同步就是每个 worker 的梯度求和求平均,然后更新参数,目前比较流行的同步更新方法则是 ring-all-ruduce 方法。可以保证梯度的正确性。但是速度较慢。

  • 完全异步,各个 worker 并行训练,各自处理各自的数据不等待其他任何 worker。速度比较快,但是梯度有损失。

  • 存在的问题:传统的 SGD 是基于 batch 更新的,并行训练时各个 worker 计算当前 batch 的梯度,然后反向传播之后 push 梯度,然后 pull 最新的参数再处理下一个 batch。这个时候如果当一个 worker 更新速度特别慢,这个 worker push 的梯度是使用一个非常旧的参数计算出来的,这个梯度可能已经不适合当下的参数,甚至有时候会起到反作用。


本文提出的 SSP 方法来让 worker 在效率和正确性上做一个良好的权衡。


核心思想:各个 worker 并行训练,每次进行下一次迭代时判断一下自己的迭代比整个系统中最慢的节点的迭代快多少个 step,如果达到一个阈值就进入等待状态直到 step 小于阈值开始下一次计算。

代码实现

参考:https://blog.csdn.net/li57681522/article/details/87920210


# -*- coding:utf-8 -*-
# python dis_tf_ssp.py --job_name=ps --task_index=0# python dis_tf_ssp.py --job_name=worker --task_index=0# python dis_tf_ssp.py --job_name=worker --task_index=1
import timeimport numpy as npimport tensorflow as tf
from tensorflow.python.util.tf_export import tf_exportfrom tensorflow.python.ops import state_ops, variables, variable_scopefrom tensorflow.python.training import session_run_hook
# Define parametersFLAGS = tf.app.flags.FLAGStf.app.flags.DEFINE_float('learning_rate', 0.00003, 'Initial learning rate.')tf.app.flags.DEFINE_integer('steps_to_validate', 1000, 'Steps to validate and print loss')
# For distributedtf.app.flags.DEFINE_string("ps_hosts", "172.20.181.16:2222", "Comma-separated list of hostname:port pairs")tf.app.flags.DEFINE_string("worker_hosts", "172.20.181.16:2224", "Comma-separated list of hostname:port pairs")tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
# Hyperparameterslearning_rate = FLAGS.learning_ratesteps_to_validate = FLAGS.steps_to_validate

# @tf_export("train.SemiSyncRunHook")# class SemiSyncRunHook(session_run_hook.SessionRunHook):class SemiSyncRunHook(tf.train.SessionRunHook): """Run by SSP."""
def __init__(self, index, worker_count, staleness=10): """Initializes a `SemiSyncRunHook`. Args: index: work index worker_count: number of workers staleness: """
if index >= worker_count: print("worker index {} is bigger than worker_count {}".format(index, worker_count)) return
self._const_max_test_step = 10000 self._last_step = 0 # 上一次wait的步骤数 self._last_time = self._now_time() # 上一次wait的时间
self._index = index self._staleness = staleness self._wait_time = 0.01 # 等待时间,单位:秒;这个时间不能设置的太长,跟worker的训练速度和staleness相关 self._worker_steps = [] # 记录worker训练步骤数的变量列表
for i in range(worker_count): worker_step = variable_scope.variable(0, trainable=False, name="worker_step_" + str(i)) self._worker_steps.append(worker_step) if i == index: self._my_step_update_op = state_ops.assign_add(worker_step, 1)
self._initialize_op = variables.variables_initializer(self._worker_steps)
def _now_time(self): return time.time()
def after_create_session(self, session, coord): session.run(self._initialize_op) # 初始化记录worker训练步骤数的变量
def before_run(self, run_context): run_context.session.run(self._my_step_update_op) # 更新本worker的训练步骤数 return None
def after_run(self, run_context, run_values): while True: # 1.获取所有worker的训练步骤数 all_worker_steps = run_context.session.run(self._worker_steps) # print("all worker steps={}. my work id={}".format(all_worker_steps, self._index))
# 2.如果训练当前worker的训练步骤数 > 最小worker训练步骤数 + staleness,sleep(10ms); 否则 break; if all_worker_steps[self._index] > min(all_worker_steps) + self._staleness: diff_step = all_worker_steps[self._index] - self._last_step if diff_step / self._const_max_test_step > 1: self._wait_time = (self._now_time() - self._last_time) / diff_step * self._staleness * 0.7
# 更新 self._last_step = all_worker_steps[self._index] self._last_time = self._now_time()
time.sleep(self._wait_time) # 等待慢worker执行 # print("all worker steps={}, my work id={}. waiting {}s...".format(all_worker_steps, self._index, self._wait_time)) else: break

def main(_): ps_hosts = FLAGS.ps_hosts.split(",") worker_hosts = FLAGS.worker_hosts.split(",") cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts}) server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
worker_count = len(worker_hosts)
if FLAGS.job_name == "ps": server.join() elif FLAGS.job_name == "worker": with tf.device(tf.train.replica_device_setter( worker_device="/job:worker/task:%d" % FLAGS.task_index, cluster=cluster)): global_step = tf.Variable(0, name='global_step', trainable=False)
X = tf.placeholder(tf.float32) Y = tf.placeholder(tf.float32) w = tf.Variable(0.0, name="weight") b = tf.Variable(0.0, name="reminder") y = w * X + b
loss = tf.reduce_mean(tf.square(y - Y)) optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# 更新梯度 train_op = optimizer.minimize(loss, global_step=global_step)
hooks = [tf.train.StopAtStepHook(last_step=1000000)]
semiSyncRunHook = SemiSyncRunHook(FLAGS.task_index, worker_count=worker_count, staleness=10) hooks.append(semiSyncRunHook)
with tf.train.MonitoredTrainingSession( master=server.target, is_chief=(FLAGS.task_index == 0), checkpoint_dir="./ssp_saved_model", hooks=hooks) as mon_sess: while not mon_sess.should_stop(): train_x = np.random.randn(1) train_y = 2 * train_x + np.random.randn(1) * 0.33 + 10 _, loss_v, step = mon_sess.run([train_op, loss, global_step], feed_dict={X: train_x, Y: train_y}) if step % steps_to_validate == 0: w_, b_ = mon_sess.run([w, b]) print("step: %d, weight: %f, biase: %f, loss: %f" % (step, w_, b_, loss_v))

if __name__ == "__main__": tf.app.run()
复制代码


参考文献:


More Effective Distributed ML via a Stale Synchronous Parallel Parameter Server


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/69665236


2019-12-02 16:251003

评论

发布
暂无评论
发现更多内容

金三银四跳槽季,美团、字节、阿里、腾讯Java面经,终入字节

Java 程序员 架构 面试

Windows PowerShell ISE 是什么和 PowerShell 有什么区别

HoneyMoose

并发王者课-黄金1:两败俱伤-互不相让的线程如何导致了死锁僵局

MetaThoughts

Java 多线程 并发

如何拆分大型单体系统为微服务

Zhang

微服务

系统设计系列之任务队列

看山

MQ 6月日更

网络攻防学习笔记 Day42

穿过生命散发芬芳

网络攻防 6月日更

职场礼仪之坐车礼仪

石云升

6月日更 职场礼仪

Windows 使用 PowerShell 来管理另外一台 Windows 机器

HoneyMoose

22 图 |M1 和 Docker 谈了个恋爱

悟空聊架构

Mac M M1 Dock 6月日更 dokcer

栈和队列没想象中那么难

北游学Java

Java 数据结构 队列

Python——列表元素的排序

在即

6月日更

深入SpringBoot的异常处理(一)

卢卡多多

异常 SpringBoot 2 全局异常 6月日更

区块链行业的《高考志愿填报指南》

CECBC

🌏【架构师指南】带你分析认识缓存穿透/雪崩/击穿

洛神灬殇

缓存穿透 缓存击穿 缓存雪崩 6月日更

Kubernetes手记(9)- Ingress 控制器

雪雷

k8s 6月日更

JAVA对象直接输出的打印结果是什么?

加百利

Java 后端 字符串 6月日更

kubelet分析-pvc扩容源码分析

良凯尔

Kubernetes 源码分析 kubelet Ceph CSI

2021年最新阿里巴巴Java面试权威指南(泰山版)震撼来袭

Java 程序员 架构 面试 计算机

颠覆与创新,区块链将成音乐产业的下一个风口

CECBC

GitHub已霸榜!阿里技术官肝了3个月才完成的20万字Java面试手册

Java 程序员 架构 面试

给dubbo贡献源码,做梦都在修bug

捉虫大师

dubbo

读深入ES6记[五]

蛋先生DX

ES6 6月日更

小型电商微服务架构拆分

Simon

架构实战营

马丁策略量化交易系统搭建,网格量化策略系统

三步教你编写一个Neumorphism风格的小时钟

空城机

JavaScript Vue 大前端 6月日更

图解 SQL,这也太形象了吧!

xcbeyond

MySQL 6月日更

NQI国家质量基础设施“一站式”公共服务平台开发建设

源中瑞-龙先生

开发 NQI 质量基础设施“一站式”

Java Shutdown Hook 场景使用和源码分析

陈皮的JavaLib

Java 线程安全 Thread

【Flutter 专题】106 图解 AnimatedWidget & AnimatedBuilder 动画应用

阿策小和尚

Flutter 小菜 0 基础学习 Flutter Android 小菜鸟 6月日更

🌏【架构师指南】教你如何设计和规划系统架构(13条)

洛神灬殇

架构设计 架构设计原则 架构师技能 6月日更

Django组队学习Task0

IT蜗壳-Tango

IT蜗壳教学 6月日更 Datawhale

Stale Synchronous Parallel Parameter Server解读和代码实现_语言 & 开发_Rick_InfoQ精选文章