写点什么

Stale Synchronous Parallel Parameter Server 解读和代码实现

  • 2019-12-02
  • 本文字数:3544 字

    阅读完需:约 12 分钟

Stale Synchronous Parallel Parameter Server解读和代码实现

论文解读:


常见的并行训练的模式有两种:


  • 同步,各个 worker 并行训练,每次更新梯度时都等待所有 worker 完成本次迭代的计算,然后一起开始下一次迭代。最简单的同步就是每个 worker 的梯度求和求平均,然后更新参数,目前比较流行的同步更新方法则是 ring-all-ruduce 方法。可以保证梯度的正确性。但是速度较慢。

  • 完全异步,各个 worker 并行训练,各自处理各自的数据不等待其他任何 worker。速度比较快,但是梯度有损失。

  • 存在的问题:传统的 SGD 是基于 batch 更新的,并行训练时各个 worker 计算当前 batch 的梯度,然后反向传播之后 push 梯度,然后 pull 最新的参数再处理下一个 batch。这个时候如果当一个 worker 更新速度特别慢,这个 worker push 的梯度是使用一个非常旧的参数计算出来的,这个梯度可能已经不适合当下的参数,甚至有时候会起到反作用。


本文提出的 SSP 方法来让 worker 在效率和正确性上做一个良好的权衡。


核心思想:各个 worker 并行训练,每次进行下一次迭代时判断一下自己的迭代比整个系统中最慢的节点的迭代快多少个 step,如果达到一个阈值就进入等待状态直到 step 小于阈值开始下一次计算。

代码实现

参考:https://blog.csdn.net/li57681522/article/details/87920210


# -*- coding:utf-8 -*-
# python dis_tf_ssp.py --job_name=ps --task_index=0# python dis_tf_ssp.py --job_name=worker --task_index=0# python dis_tf_ssp.py --job_name=worker --task_index=1
import timeimport numpy as npimport tensorflow as tf
from tensorflow.python.util.tf_export import tf_exportfrom tensorflow.python.ops import state_ops, variables, variable_scopefrom tensorflow.python.training import session_run_hook
# Define parametersFLAGS = tf.app.flags.FLAGStf.app.flags.DEFINE_float('learning_rate', 0.00003, 'Initial learning rate.')tf.app.flags.DEFINE_integer('steps_to_validate', 1000, 'Steps to validate and print loss')
# For distributedtf.app.flags.DEFINE_string("ps_hosts", "172.20.181.16:2222", "Comma-separated list of hostname:port pairs")tf.app.flags.DEFINE_string("worker_hosts", "172.20.181.16:2224", "Comma-separated list of hostname:port pairs")tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
# Hyperparameterslearning_rate = FLAGS.learning_ratesteps_to_validate = FLAGS.steps_to_validate

# @tf_export("train.SemiSyncRunHook")# class SemiSyncRunHook(session_run_hook.SessionRunHook):class SemiSyncRunHook(tf.train.SessionRunHook): """Run by SSP."""
def __init__(self, index, worker_count, staleness=10): """Initializes a `SemiSyncRunHook`. Args: index: work index worker_count: number of workers staleness: """
if index >= worker_count: print("worker index {} is bigger than worker_count {}".format(index, worker_count)) return
self._const_max_test_step = 10000 self._last_step = 0 # 上一次wait的步骤数 self._last_time = self._now_time() # 上一次wait的时间
self._index = index self._staleness = staleness self._wait_time = 0.01 # 等待时间,单位:秒;这个时间不能设置的太长,跟worker的训练速度和staleness相关 self._worker_steps = [] # 记录worker训练步骤数的变量列表
for i in range(worker_count): worker_step = variable_scope.variable(0, trainable=False, name="worker_step_" + str(i)) self._worker_steps.append(worker_step) if i == index: self._my_step_update_op = state_ops.assign_add(worker_step, 1)
self._initialize_op = variables.variables_initializer(self._worker_steps)
def _now_time(self): return time.time()
def after_create_session(self, session, coord): session.run(self._initialize_op) # 初始化记录worker训练步骤数的变量
def before_run(self, run_context): run_context.session.run(self._my_step_update_op) # 更新本worker的训练步骤数 return None
def after_run(self, run_context, run_values): while True: # 1.获取所有worker的训练步骤数 all_worker_steps = run_context.session.run(self._worker_steps) # print("all worker steps={}. my work id={}".format(all_worker_steps, self._index))
# 2.如果训练当前worker的训练步骤数 > 最小worker训练步骤数 + staleness,sleep(10ms); 否则 break; if all_worker_steps[self._index] > min(all_worker_steps) + self._staleness: diff_step = all_worker_steps[self._index] - self._last_step if diff_step / self._const_max_test_step > 1: self._wait_time = (self._now_time() - self._last_time) / diff_step * self._staleness * 0.7
# 更新 self._last_step = all_worker_steps[self._index] self._last_time = self._now_time()
time.sleep(self._wait_time) # 等待慢worker执行 # print("all worker steps={}, my work id={}. waiting {}s...".format(all_worker_steps, self._index, self._wait_time)) else: break

def main(_): ps_hosts = FLAGS.ps_hosts.split(",") worker_hosts = FLAGS.worker_hosts.split(",") cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts}) server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
worker_count = len(worker_hosts)
if FLAGS.job_name == "ps": server.join() elif FLAGS.job_name == "worker": with tf.device(tf.train.replica_device_setter( worker_device="/job:worker/task:%d" % FLAGS.task_index, cluster=cluster)): global_step = tf.Variable(0, name='global_step', trainable=False)
X = tf.placeholder(tf.float32) Y = tf.placeholder(tf.float32) w = tf.Variable(0.0, name="weight") b = tf.Variable(0.0, name="reminder") y = w * X + b
loss = tf.reduce_mean(tf.square(y - Y)) optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# 更新梯度 train_op = optimizer.minimize(loss, global_step=global_step)
hooks = [tf.train.StopAtStepHook(last_step=1000000)]
semiSyncRunHook = SemiSyncRunHook(FLAGS.task_index, worker_count=worker_count, staleness=10) hooks.append(semiSyncRunHook)
with tf.train.MonitoredTrainingSession( master=server.target, is_chief=(FLAGS.task_index == 0), checkpoint_dir="./ssp_saved_model", hooks=hooks) as mon_sess: while not mon_sess.should_stop(): train_x = np.random.randn(1) train_y = 2 * train_x + np.random.randn(1) * 0.33 + 10 _, loss_v, step = mon_sess.run([train_op, loss, global_step], feed_dict={X: train_x, Y: train_y}) if step % steps_to_validate == 0: w_, b_ = mon_sess.run([w, b]) print("step: %d, weight: %f, biase: %f, loss: %f" % (step, w_, b_, loss_v))

if __name__ == "__main__": tf.app.run()
复制代码


参考文献:


More Effective Distributed ML via a Stale Synchronous Parallel Parameter Server


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/69665236


2019-12-02 16:25983

评论

发布
暂无评论
发现更多内容

全新视觉,升维体验!全栈可观测中心嘉为鲸眼产品全新体验升级

嘉为蓝鲸

可观测 自动化运维 嘉为蓝鲸

兴业证券打造更“自然”的数字人,火山语音提供技术支持

科技热闻

火山引擎DataLeap:3个关键步骤,复制字节跳动一站式数据治理经验

字节跳动数据平台

大数据 数据治理 数据研发 实际应用 企业号 2 月 PK 榜

南宁等级测评机构有几家?分别是哪几家?

行云管家

等保 南京 等级保护 等级

嘉为蓝鲸携手麒麟软件共建国产化一站式DevOps解决方案

嘉为蓝鲸

DevOps 自动化运维 嘉为蓝鲸

状态机的概念与设计

timerring

FPGA

使用 NineData GUI 创建与修改 ClickHouse 表结构

NineData

MySQL 分布式数据库 Clickhouse Dbeaver NineData

直播预告丨 立即解锁 ALB Ingress 高级特性

阿里巴巴云原生

阿里云 容器

成熟的自动化运维平台是怎样练成的?

嘉为蓝鲸

自动化运维 嘉为蓝鲸

Apipost自动化测试功能概述

不想敲代码

自动化测试 测试自动化 apipost

GaussDB(DWS)现网案例:collation报错

华为云开发者联盟

数据库 后端 华为云 企业号 2 月 PK 榜 华为云开发者联盟

OpenInfra峰会议程已公布,特色主题演讲,百余场专题会议等你来参与!

Geek_2d6073

【等保小知识】过等保后可以收到哪些资料?

行云管家

等保 等级保护 过等保

使用 QuTrunk+Amazon Deep Learning AMI(TensorFlow2)构建量子神经网络

亚马逊云科技 (Amazon Web Services)

深度学习 量子计算

优秀实践案例征集火热开启,快来投稿!

Apache RocketMQ

消息列队

统一观测丨如何使用Prometheus 实现性能压测指标可观测

阿里巴巴云原生

阿里云 云原生 Prometheus 压测

【活动报名】re:Invent - AI 应用助力企业构建数字战略

亚马逊云科技 (Amazon Web Services)

新思科技解读2023年软件安全行业六大趋势

InfoQ_434670063458

新思科技 软件安全

嘉为科技蝉联信创工委会“卓越贡献成员”荣誉称号

嘉为蓝鲸

自动化运维 嘉为蓝鲸

MASA Stack 1.0 发布会讲稿 —— 产品篇

MASA技术团队

.net 云原生 MASA MASA Blazor

百度APP iOS端内存优化-原理篇

百度Geek说

ios 内存 企业号 2 月 PK 榜

基于 Kubernetes 的企业级大数据平台,EMR on ACK 技术初探

阿里巴巴云原生

阿里云 容器 云原生

OpenHarmony 3.2 Beta多媒体系列——视频录制

OpenHarmony开发者

OpenHarmony

聊一聊,我对DDD的关键理解

阿里技术

DDD

记录一次还算优雅的代码设计

京东科技开发者

线程 cpu 优雅 代码设计 企业号 2 月 PK 榜

给 Databend 添加 Aggregate 函数 | 函数开发系例二

Databend

我的快速调优线上服务器CPU利用率通用办法,震惊面试官

KINDLING

Java cpu 服务器 性能调优 ebpf

提升软件质量?为什么不试试华为云CodeArts Check

华为云开发者联盟

云计算 华为云 企业号 2 月 PK 榜 华为云开发者联盟

物联网平台选型葵花宝典:盘点开源、SaaS及通用型平台的优劣对比

AIRIOT

物联网 物联网平台选型 平台选型

直播预约|数据库掌门人论坛召开,共谋中国数据库生态发展新路径

镜舟科技

数据库 大数据 开源

从一个Demo说起Zookeeper服务端源码

宋小生

zookeeper

Stale Synchronous Parallel Parameter Server解读和代码实现_语言 & 开发_Rick_InfoQ精选文章