论文解读:
常见的并行训练的模式有两种:
同步,各个 worker 并行训练,每次更新梯度时都等待所有 worker 完成本次迭代的计算,然后一起开始下一次迭代。最简单的同步就是每个 worker 的梯度求和求平均,然后更新参数,目前比较流行的同步更新方法则是 ring-all-ruduce 方法。可以保证梯度的正确性。但是速度较慢。
完全异步,各个 worker 并行训练,各自处理各自的数据不等待其他任何 worker。速度比较快,但是梯度有损失。
存在的问题:传统的 SGD 是基于 batch 更新的,并行训练时各个 worker 计算当前 batch 的梯度,然后反向传播之后 push 梯度,然后 pull 最新的参数再处理下一个 batch。这个时候如果当一个 worker 更新速度特别慢,这个 worker push 的梯度是使用一个非常旧的参数计算出来的,这个梯度可能已经不适合当下的参数,甚至有时候会起到反作用。
本文提出的 SSP 方法来让 worker 在效率和正确性上做一个良好的权衡。
核心思想:各个 worker 并行训练,每次进行下一次迭代时判断一下自己的迭代比整个系统中最慢的节点的迭代快多少个 step,如果达到一个阈值就进入等待状态直到 step 小于阈值开始下一次计算。
代码实现
参考:https://blog.csdn.net/li57681522/article/details/87920210
# -*- coding:utf-8 -*-
# python dis_tf_ssp.py --job_name=ps --task_index=0
# python dis_tf_ssp.py --job_name=worker --task_index=0
# python dis_tf_ssp.py --job_name=worker --task_index=1
import time
import numpy as np
import tensorflow as tf
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.ops import state_ops, variables, variable_scope
from tensorflow.python.training import session_run_hook
# Define parameters
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_float('learning_rate', 0.00003, 'Initial learning rate.')
tf.app.flags.DEFINE_integer('steps_to_validate', 1000,
'Steps to validate and print loss')
# For distributed
tf.app.flags.DEFINE_string("ps_hosts", "172.20.181.16:2222",
"Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("worker_hosts", "172.20.181.16:2224",
"Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
# Hyperparameters
learning_rate = FLAGS.learning_rate
steps_to_validate = FLAGS.steps_to_validate
# @tf_export("train.SemiSyncRunHook")
# class SemiSyncRunHook(session_run_hook.SessionRunHook):
class SemiSyncRunHook(tf.train.SessionRunHook):
"""Run by SSP."""
def __init__(self, index, worker_count, staleness=10):
"""Initializes a `SemiSyncRunHook`.
Args:
index: work index
worker_count: number of workers
staleness:
"""
if index >= worker_count:
print("worker index {} is bigger than worker_count {}".format(index, worker_count))
return
self._const_max_test_step = 10000
self._last_step = 0 # 上一次wait的步骤数
self._last_time = self._now_time() # 上一次wait的时间
self._index = index
self._staleness = staleness
self._wait_time = 0.01 # 等待时间,单位:秒;这个时间不能设置的太长,跟worker的训练速度和staleness相关
self._worker_steps = [] # 记录worker训练步骤数的变量列表
for i in range(worker_count):
worker_step = variable_scope.variable(0, trainable=False, name="worker_step_" + str(i))
self._worker_steps.append(worker_step)
if i == index:
self._my_step_update_op = state_ops.assign_add(worker_step, 1)
self._initialize_op = variables.variables_initializer(self._worker_steps)
def _now_time(self):
return time.time()
def after_create_session(self, session, coord):
session.run(self._initialize_op) # 初始化记录worker训练步骤数的变量
def before_run(self, run_context):
run_context.session.run(self._my_step_update_op) # 更新本worker的训练步骤数
return None
def after_run(self, run_context, run_values):
while True:
# 1.获取所有worker的训练步骤数
all_worker_steps = run_context.session.run(self._worker_steps)
# print("all worker steps={}. my work id={}".format(all_worker_steps, self._index))
# 2.如果训练当前worker的训练步骤数 > 最小worker训练步骤数 + staleness,sleep(10ms); 否则 break;
if all_worker_steps[self._index] > min(all_worker_steps) + self._staleness:
diff_step = all_worker_steps[self._index] - self._last_step
if diff_step / self._const_max_test_step > 1:
self._wait_time = (self._now_time() - self._last_time) / diff_step * self._staleness * 0.7
# 更新
self._last_step = all_worker_steps[self._index]
self._last_time = self._now_time()
time.sleep(self._wait_time) # 等待慢worker执行
# print("all worker steps={}, my work id={}. waiting {}s...".format(all_worker_steps, self._index, self._wait_time))
else:
break
def main(_):
ps_hosts = FLAGS.ps_hosts.split(",")
worker_hosts = FLAGS.worker_hosts.split(",")
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
worker_count = len(worker_hosts)
if FLAGS.job_name == "ps":
server.join()
elif FLAGS.job_name == "worker":
with tf.device(tf.train.replica_device_setter(
worker_device="/job:worker/task:%d" % FLAGS.task_index,
cluster=cluster)):
global_step = tf.Variable(0, name='global_step', trainable=False)
X = tf.placeholder(tf.float32)
Y = tf.placeholder(tf.float32)
w = tf.Variable(0.0, name="weight")
b = tf.Variable(0.0, name="reminder")
y = w * X + b
loss = tf.reduce_mean(tf.square(y - Y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# 更新梯度
train_op = optimizer.minimize(loss, global_step=global_step)
hooks = [tf.train.StopAtStepHook(last_step=1000000)]
semiSyncRunHook = SemiSyncRunHook(FLAGS.task_index, worker_count=worker_count, staleness=10)
hooks.append(semiSyncRunHook)
with tf.train.MonitoredTrainingSession(
master=server.target, is_chief=(FLAGS.task_index == 0),
checkpoint_dir="./ssp_saved_model",
hooks=hooks) as mon_sess:
while not mon_sess.should_stop():
train_x = np.random.randn(1)
train_y = 2 * train_x + np.random.randn(1) * 0.33 + 10
_, loss_v, step = mon_sess.run([train_op, loss, global_step], feed_dict={X: train_x, Y: train_y})
if step % steps_to_validate == 0:
w_, b_ = mon_sess.run([w, b])
print("step: %d, weight: %f, biase: %f, loss: %f" % (step, w_, b_, loss_v))
if __name__ == "__main__":
tf.app.run()
参考文献:
More Effective Distributed ML via a Stale Synchronous Parallel Parameter Server
本文转载自 Alex-zhai 知乎账号。
原文链接:https://zhuanlan.zhihu.com/p/69665236
更多内容推荐
毕设:设计电商秒杀系统
商品:每个品类不超过 20 个商品,目前做了 10 个品类,挑选选品各大电商平台上畅销和好评的商品进行销售;
2022-04-19
百度百舸·AI 异构计算平台,加速自动驾驶模型迭代
数据访问性能提升了 5 倍,自动驾驶典型模型训练性能最高提升 391%,典型模型推理延迟最高降低了 90%,模型仿真成本降低了 60%。
2023-01-05
10、JVM 调优实战 - 堆栈优化、吞吐量与响应时间优先策略
2023-09-26
iOS MachineLearning 系列(18)—— PoseNet,DeeplabV3 与 FCRN-DepthPrediction 模型
本篇文章将再介绍三个官方的CoreML模型:PoseNet,DeeplabV3和FCRN-DepthPrediction。
2023-05-25
LeetCode 题解:104. 二叉树的最大深度,递归,JavaScript,详细注释
原题链接:https://leetcode-cn.com/problems/maximum-depth-of-binary-tree/
2020-10-10
2. 递归题目实战
2023-09-26
新年红包封面来了,3000 万份红包封面来啦!到点直领!(1)
麦吉丽 秦岚
2021-11-12
4. 分片集群架构设计技巧
2023-09-26
Alibaba 最新神作!耗时 182 天肝出来的 1015 页分布式全栈手册太香了
到底什么是分布式?这个话题一直以来就在各大平台论坛上被热议。一千个读者里面就有一千个哈姆雷特。官方这边给出的结论是:分布式就是将相同或相关的程序运行在多台计算机上,从而实现特定目标的一种计算方式。而从分布式技术的起源来看,随之诞生的分布式系
2021-10-14
一致性协议算法
一致性算法整理:raft、zab、paxos
2020-07-14
【愚公系列】2022 年 05 月 二十三种设计模式 (十六)- 迭代器模式 (Iterator Pattern)
设计模式(Design pattern)是一套被反复使用、多数人知晓的、经过分类编目的、代码设计经验的总结。使用设计模式是为了可重用代码、让代码更容易被他人理解、保证代码可靠性。 毫无疑问,设计模式于己于他人于系统都是多赢的,设计模式使代码编制真正工程化
2022-05-17
实用机器学习笔记十三:随机梯度下降
本文是个人在 B 站自学李沐老师的实用机器学习课程【斯坦福 2021 秋季中文同步】的学习笔记,感觉沐神讲解的非常棒 yyds。
2021-12-14
Android 布局阴影实现,移动开发框架 2019
android:gravity="center"android:textSize="14sp"android:textColor="@color/colorBlack"android:layout_width="100dp"android:elevation="3dp"android:layout_height="80dp"/
2021-11-05
1. 调度器 kube-scheduler
2023-09-26
12.1 大数据技术发展史
12.1大数据技术发展史
2020-12-14
7. 计数器累加器、分布式缓存和 Task 数据传输策略
2023-09-08
6. gRPC:etcd 服务发现实现
2023-09-27
加速云原生应用落地,焱融 YRCloudFile 与天翼云完成兼容性认证
近日,焱融 YRCloudFile 分布式存储系统完成与天翼云适配认证!
2021-11-25
第二模块作业
第二模块作业第二模块作业第二模块作业第二模块作业第二模块作业第二模块作业第二模块作业第二模块作业第二模块作业第二模块作业第二模块作业第二模块作业第二模块作业第二模块作业第二模块作业第二模块作业第二模块作业第二模块作业第二模块作业第二模块作业
2021-10-31
我翻遍整个牛客网,整理出了全网最全的 Java 面试八股文大合集,整整 6000 多页
大家从 Boss 直聘上或者其他招聘网站上都可以看到 Java 岗位众多,Java 岗位的招聘薪酬天差地别,人才要求也是五花八门。而很多 Java 工程师求职过程中,也是冷暖自知。很多时候技术有,但是面试的时候就是过不了!
2023-05-29
推荐阅读
4、一主多从,互为主从
2023-09-27
6、数据库调优 - 慢查询日志、最大连接数、线程缓存
2023-09-27
2023 上海国际智慧停车展览会
2023-08-21
华院计算宣晓华:未来十年,基于数据与知识融合的模型将大放异彩
2023-06-21
Apache Doris 2.0.2 版本正式发布!
2023-10-13
小间距 LED 显示屏的技术优势有哪些?
2023-10-20
3. 负载均衡:加权轮询实现
2023-09-27
电子书
大厂实战PPT下载
换一换 姚远 | 面壁智能 研究员 & 清华大学计算机系博士后
崔红保 | DCloud CTO
Scott Shaw | Thoughtworks 亚太区 CTO
评论