产品战略专家梁宁确认出席AICon北京站,分享AI时代下的商业逻辑与产品需求 了解详情
写点什么

Stale Synchronous Parallel Parameter Server 解读和代码实现

  • 2019-12-02
  • 本文字数:3544 字

    阅读完需:约 12 分钟

Stale Synchronous Parallel Parameter Server解读和代码实现

论文解读:


常见的并行训练的模式有两种:


  • 同步,各个 worker 并行训练,每次更新梯度时都等待所有 worker 完成本次迭代的计算,然后一起开始下一次迭代。最简单的同步就是每个 worker 的梯度求和求平均,然后更新参数,目前比较流行的同步更新方法则是 ring-all-ruduce 方法。可以保证梯度的正确性。但是速度较慢。

  • 完全异步,各个 worker 并行训练,各自处理各自的数据不等待其他任何 worker。速度比较快,但是梯度有损失。

  • 存在的问题:传统的 SGD 是基于 batch 更新的,并行训练时各个 worker 计算当前 batch 的梯度,然后反向传播之后 push 梯度,然后 pull 最新的参数再处理下一个 batch。这个时候如果当一个 worker 更新速度特别慢,这个 worker push 的梯度是使用一个非常旧的参数计算出来的,这个梯度可能已经不适合当下的参数,甚至有时候会起到反作用。


本文提出的 SSP 方法来让 worker 在效率和正确性上做一个良好的权衡。


核心思想:各个 worker 并行训练,每次进行下一次迭代时判断一下自己的迭代比整个系统中最慢的节点的迭代快多少个 step,如果达到一个阈值就进入等待状态直到 step 小于阈值开始下一次计算。

代码实现

参考:https://blog.csdn.net/li57681522/article/details/87920210


# -*- coding:utf-8 -*-
# python dis_tf_ssp.py --job_name=ps --task_index=0# python dis_tf_ssp.py --job_name=worker --task_index=0# python dis_tf_ssp.py --job_name=worker --task_index=1
import timeimport numpy as npimport tensorflow as tf
from tensorflow.python.util.tf_export import tf_exportfrom tensorflow.python.ops import state_ops, variables, variable_scopefrom tensorflow.python.training import session_run_hook
# Define parametersFLAGS = tf.app.flags.FLAGStf.app.flags.DEFINE_float('learning_rate', 0.00003, 'Initial learning rate.')tf.app.flags.DEFINE_integer('steps_to_validate', 1000, 'Steps to validate and print loss')
# For distributedtf.app.flags.DEFINE_string("ps_hosts", "172.20.181.16:2222", "Comma-separated list of hostname:port pairs")tf.app.flags.DEFINE_string("worker_hosts", "172.20.181.16:2224", "Comma-separated list of hostname:port pairs")tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
# Hyperparameterslearning_rate = FLAGS.learning_ratesteps_to_validate = FLAGS.steps_to_validate

# @tf_export("train.SemiSyncRunHook")# class SemiSyncRunHook(session_run_hook.SessionRunHook):class SemiSyncRunHook(tf.train.SessionRunHook): """Run by SSP."""
def __init__(self, index, worker_count, staleness=10): """Initializes a `SemiSyncRunHook`. Args: index: work index worker_count: number of workers staleness: """
if index >= worker_count: print("worker index {} is bigger than worker_count {}".format(index, worker_count)) return
self._const_max_test_step = 10000 self._last_step = 0 # 上一次wait的步骤数 self._last_time = self._now_time() # 上一次wait的时间
self._index = index self._staleness = staleness self._wait_time = 0.01 # 等待时间,单位:秒;这个时间不能设置的太长,跟worker的训练速度和staleness相关 self._worker_steps = [] # 记录worker训练步骤数的变量列表
for i in range(worker_count): worker_step = variable_scope.variable(0, trainable=False, name="worker_step_" + str(i)) self._worker_steps.append(worker_step) if i == index: self._my_step_update_op = state_ops.assign_add(worker_step, 1)
self._initialize_op = variables.variables_initializer(self._worker_steps)
def _now_time(self): return time.time()
def after_create_session(self, session, coord): session.run(self._initialize_op) # 初始化记录worker训练步骤数的变量
def before_run(self, run_context): run_context.session.run(self._my_step_update_op) # 更新本worker的训练步骤数 return None
def after_run(self, run_context, run_values): while True: # 1.获取所有worker的训练步骤数 all_worker_steps = run_context.session.run(self._worker_steps) # print("all worker steps={}. my work id={}".format(all_worker_steps, self._index))
# 2.如果训练当前worker的训练步骤数 > 最小worker训练步骤数 + staleness,sleep(10ms); 否则 break; if all_worker_steps[self._index] > min(all_worker_steps) + self._staleness: diff_step = all_worker_steps[self._index] - self._last_step if diff_step / self._const_max_test_step > 1: self._wait_time = (self._now_time() - self._last_time) / diff_step * self._staleness * 0.7
# 更新 self._last_step = all_worker_steps[self._index] self._last_time = self._now_time()
time.sleep(self._wait_time) # 等待慢worker执行 # print("all worker steps={}, my work id={}. waiting {}s...".format(all_worker_steps, self._index, self._wait_time)) else: break

def main(_): ps_hosts = FLAGS.ps_hosts.split(",") worker_hosts = FLAGS.worker_hosts.split(",") cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts}) server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
worker_count = len(worker_hosts)
if FLAGS.job_name == "ps": server.join() elif FLAGS.job_name == "worker": with tf.device(tf.train.replica_device_setter( worker_device="/job:worker/task:%d" % FLAGS.task_index, cluster=cluster)): global_step = tf.Variable(0, name='global_step', trainable=False)
X = tf.placeholder(tf.float32) Y = tf.placeholder(tf.float32) w = tf.Variable(0.0, name="weight") b = tf.Variable(0.0, name="reminder") y = w * X + b
loss = tf.reduce_mean(tf.square(y - Y)) optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# 更新梯度 train_op = optimizer.minimize(loss, global_step=global_step)
hooks = [tf.train.StopAtStepHook(last_step=1000000)]
semiSyncRunHook = SemiSyncRunHook(FLAGS.task_index, worker_count=worker_count, staleness=10) hooks.append(semiSyncRunHook)
with tf.train.MonitoredTrainingSession( master=server.target, is_chief=(FLAGS.task_index == 0), checkpoint_dir="./ssp_saved_model", hooks=hooks) as mon_sess: while not mon_sess.should_stop(): train_x = np.random.randn(1) train_y = 2 * train_x + np.random.randn(1) * 0.33 + 10 _, loss_v, step = mon_sess.run([train_op, loss, global_step], feed_dict={X: train_x, Y: train_y}) if step % steps_to_validate == 0: w_, b_ = mon_sess.run([w, b]) print("step: %d, weight: %f, biase: %f, loss: %f" % (step, w_, b_, loss_v))

if __name__ == "__main__": tf.app.run()
复制代码


参考文献:


More Effective Distributed ML via a Stale Synchronous Parallel Parameter Server


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/69665236


2019-12-02 16:25946

评论

发布
暂无评论
发现更多内容

【Python实战】Python对中国500强排行榜数据进行可视化分析

BROKEN

三周年连更

使用 Kubectl Patch 命令更新资源

Se7en

云原生

2023年3月用户体验GX评测:国有行及股份行持续领跑,城商行及农商行农信社积极探索实践用户体验体系搭建

易观分析

金融 银行

基于 Amazon SageMaker 构建细粒度情感分析应用

亚马逊云科技 (Amazon Web Services)

机器学习 Amazon SageMaker

【深入浅出Spring原理及实战】「源码调试分析」深入源码探索Spring底层框架的的refresh方法所出现的问题和异常

洛神灬殇

spring NPE 源码剖析 4月日更 问题分析

数字未来:世界正走向新的“破茧时刻”

脑极体

华为

Matlab实现光伏发电电池模型

Shine

三周年连更

python时间序列预测之Holt-Winters

AIWeker

Python 机器学习 时间序列 三周年连更

ChatGPT辅助编写自动化测试

QE_LAB

单元测试 自动化测试 接口测试 测试技术 ChatGPT

如何在 Linux 中查找文件所有者?

wljslmz

三周年连更

Django笔记十五之in查询及date日期相关过滤操作

Hunter熊

Python django 日期

百度平地起“雷”,突然爆出的QPS数据意味着什么?

脑极体

大模型

关于容器云的三种网络设计

穿过生命散发芬芳

容器云 三周年连更

音视频八股文(4)--ffmpeg常见命令(3)

福大大架构师每日一题

音视频 ffmpeg

新手如何学好Zbrush3D建模?

Finovy Cloud

3D软件

当⻉借⼒阿⾥云落地云原⽣架构转型,运维降本、效率稳定性双升

阿里巴巴云原生

阿里云 云原生 云原生架构

算法刷题-移除元素、分数到小数、整数转罗马数字

共饮一杯无

数据结构 算法 三周年连更

重构这件“小”事儿 | 得物技术

得物技术

2022-04-23:给定你一个整数数组 nums 我们要将 nums 数组中的每个元素移动到 A 集合 或者 B 集合中 使得 A 集合和 B 集合不为空,并且 average(A) == aver

福大大架构师每日一题

golang 算法 rust

如何使用 Java 将 JSON 文件读取为字符串?这三种方法很管用!

Java架构历程

三周年连更

各行业常见的业务指标汇总(数据分析常用数据指标)

Data 探险实验室

数据分析 数据分析师 数据指标 指标中台; 数据分析 指标洞察

缓存的处理步骤

阿泽🧸

缓存 三周年连更

Shell脚本从入门到精通

袁袁袁袁满

三周年连更

今天,飞桨公众号六岁啦!

飞桨PaddlePaddle

飞桨PaddlePaddle

TypeScript Module

程序员海军

三周年连更

Docgeni 2.1 正式发布

PingCode研发中心

软件开发 Docgeni

世界读书日|华为阅读联合40余家伙伴推出精品书单

最新动态

CnosDB成为首个支持sqllogictest的时序数据库,稳定性与可靠性再升级

CnosDB

数据库 开源 时序数据库 CnosDB

学会 Go select 语句,轻松实现高效并发

陈明勇

Go golang 高并发 select 三周年连更

测试需求平台8-Arco组件实现产品增改需求

MegaQi

测试平台开发 三周年连更 AcroVue

【已结束】直播预告|传统 PvE 游戏 ∕ 开房间 PvP 游戏的云原生架构升级

阿里巴巴云原生

阿里云 云原生 游戏

Stale Synchronous Parallel Parameter Server解读和代码实现_语言 & 开发_Rick_InfoQ精选文章