论文解读:
常见的并行训练的模式有两种:
同步,各个 worker 并行训练,每次更新梯度时都等待所有 worker 完成本次迭代的计算,然后一起开始下一次迭代。最简单的同步就是每个 worker 的梯度求和求平均,然后更新参数,目前比较流行的同步更新方法则是 ring-all-ruduce 方法。可以保证梯度的正确性。但是速度较慢。
完全异步,各个 worker 并行训练,各自处理各自的数据不等待其他任何 worker。速度比较快,但是梯度有损失。
存在的问题:传统的 SGD 是基于 batch 更新的,并行训练时各个 worker 计算当前 batch 的梯度,然后反向传播之后 push 梯度,然后 pull 最新的参数再处理下一个 batch。这个时候如果当一个 worker 更新速度特别慢,这个 worker push 的梯度是使用一个非常旧的参数计算出来的,这个梯度可能已经不适合当下的参数,甚至有时候会起到反作用。
本文提出的 SSP 方法来让 worker 在效率和正确性上做一个良好的权衡。
核心思想:各个 worker 并行训练,每次进行下一次迭代时判断一下自己的迭代比整个系统中最慢的节点的迭代快多少个 step,如果达到一个阈值就进入等待状态直到 step 小于阈值开始下一次计算。
代码实现
参考:https://blog.csdn.net/li57681522/article/details/87920210
# -*- coding:utf-8 -*-
# python dis_tf_ssp.py --job_name=ps --task_index=0
# python dis_tf_ssp.py --job_name=worker --task_index=0
# python dis_tf_ssp.py --job_name=worker --task_index=1
import time
import numpy as np
import tensorflow as tf
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.ops import state_ops, variables, variable_scope
from tensorflow.python.training import session_run_hook
# Define parameters
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_float('learning_rate', 0.00003, 'Initial learning rate.')
tf.app.flags.DEFINE_integer('steps_to_validate', 1000,
'Steps to validate and print loss')
# For distributed
tf.app.flags.DEFINE_string("ps_hosts", "172.20.181.16:2222",
"Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("worker_hosts", "172.20.181.16:2224",
"Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
# Hyperparameters
learning_rate = FLAGS.learning_rate
steps_to_validate = FLAGS.steps_to_validate
# @tf_export("train.SemiSyncRunHook")
# class SemiSyncRunHook(session_run_hook.SessionRunHook):
class SemiSyncRunHook(tf.train.SessionRunHook):
"""Run by SSP."""
def __init__(self, index, worker_count, staleness=10):
"""Initializes a `SemiSyncRunHook`.
Args:
index: work index
worker_count: number of workers
staleness:
"""
if index >= worker_count:
print("worker index {} is bigger than worker_count {}".format(index, worker_count))
return
self._const_max_test_step = 10000
self._last_step = 0 # 上一次wait的步骤数
self._last_time = self._now_time() # 上一次wait的时间
self._index = index
self._staleness = staleness
self._wait_time = 0.01 # 等待时间,单位:秒;这个时间不能设置的太长,跟worker的训练速度和staleness相关
self._worker_steps = [] # 记录worker训练步骤数的变量列表
for i in range(worker_count):
worker_step = variable_scope.variable(0, trainable=False, name="worker_step_" + str(i))
self._worker_steps.append(worker_step)
if i == index:
self._my_step_update_op = state_ops.assign_add(worker_step, 1)
self._initialize_op = variables.variables_initializer(self._worker_steps)
def _now_time(self):
return time.time()
def after_create_session(self, session, coord):
session.run(self._initialize_op) # 初始化记录worker训练步骤数的变量
def before_run(self, run_context):
run_context.session.run(self._my_step_update_op) # 更新本worker的训练步骤数
return None
def after_run(self, run_context, run_values):
while True:
# 1.获取所有worker的训练步骤数
all_worker_steps = run_context.session.run(self._worker_steps)
# print("all worker steps={}. my work id={}".format(all_worker_steps, self._index))
# 2.如果训练当前worker的训练步骤数 > 最小worker训练步骤数 + staleness,sleep(10ms); 否则 break;
if all_worker_steps[self._index] > min(all_worker_steps) + self._staleness:
diff_step = all_worker_steps[self._index] - self._last_step
if diff_step / self._const_max_test_step > 1:
self._wait_time = (self._now_time() - self._last_time) / diff_step * self._staleness * 0.7
# 更新
self._last_step = all_worker_steps[self._index]
self._last_time = self._now_time()
time.sleep(self._wait_time) # 等待慢worker执行
# print("all worker steps={}, my work id={}. waiting {}s...".format(all_worker_steps, self._index, self._wait_time))
else:
break
def main(_):
ps_hosts = FLAGS.ps_hosts.split(",")
worker_hosts = FLAGS.worker_hosts.split(",")
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
worker_count = len(worker_hosts)
if FLAGS.job_name == "ps":
server.join()
elif FLAGS.job_name == "worker":
with tf.device(tf.train.replica_device_setter(
worker_device="/job:worker/task:%d" % FLAGS.task_index,
cluster=cluster)):
global_step = tf.Variable(0, name='global_step', trainable=False)
X = tf.placeholder(tf.float32)
Y = tf.placeholder(tf.float32)
w = tf.Variable(0.0, name="weight")
b = tf.Variable(0.0, name="reminder")
y = w * X + b
loss = tf.reduce_mean(tf.square(y - Y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# 更新梯度
train_op = optimizer.minimize(loss, global_step=global_step)
hooks = [tf.train.StopAtStepHook(last_step=1000000)]
semiSyncRunHook = SemiSyncRunHook(FLAGS.task_index, worker_count=worker_count, staleness=10)
hooks.append(semiSyncRunHook)
with tf.train.MonitoredTrainingSession(
master=server.target, is_chief=(FLAGS.task_index == 0),
checkpoint_dir="./ssp_saved_model",
hooks=hooks) as mon_sess:
while not mon_sess.should_stop():
train_x = np.random.randn(1)
train_y = 2 * train_x + np.random.randn(1) * 0.33 + 10
_, loss_v, step = mon_sess.run([train_op, loss, global_step], feed_dict={X: train_x, Y: train_y})
if step % steps_to_validate == 0:
w_, b_ = mon_sess.run([w, b])
print("step: %d, weight: %f, biase: %f, loss: %f" % (step, w_, b_, loss_v))
if __name__ == "__main__":
tf.app.run()
参考文献:
More Effective Distributed ML via a Stale Synchronous Parallel Parameter Server
本文转载自 Alex-zhai 知乎账号。
原文链接:https://zhuanlan.zhihu.com/p/69665236
更多内容推荐
将博客搬至 CSDN
将博客搬至CSDN将博客搬至CSDN将博客搬至CSDN将博客搬至CSDN将博客搬至CSDN将博客搬至CSDN将博客搬至CSDN将博客搬至CSDN
2022-04-24
FL Studio23 最新永久版水果软件下载教程
FL Studio21首先提供了音符编辑器,编辑器可以针对音乐创作人的要求编辑出不同音律的节奏,例如鼓,镲,锣,钢琴,笛,大提琴,筝,扬琴等等任何乐器在音乐中的配乐。其次提供了音效编辑器,音效编辑器可以编辑出各类声音针对在不同音乐中所要求的音效,例如
2023-01-15
Alibaba 最新神作!耗时 182 天肝出来的 1015 页分布式全栈手册太香了
到底什么是分布式?这个话题一直以来就在各大平台论坛上被热议。一千个读者里面就有一千个哈姆雷特。官方这边给出的结论是:分布式就是将相同或相关的程序运行在多台计算机上,从而实现特定目标的一种计算方式。而从分布式技术的起源来看,随之诞生的分布式系
2021-10-14
过去一年对我帮助最大的三本书
书犹药也,善读之可以医愚
2022-05-23
一个关于 += 的谜题
原文链接: 一个关于 += 的谜题
2022-02-28
10. SQL 编程:Prepare Statement
2023-09-26
再说绩效考核
绩效考核 自评 非物质奖励
2021-12-24
GitHub 标星 3,Android 面试
Android Binder机制及AIDL使用
2021-11-05
7. 计数器累加器、分布式缓存和 Task 数据传输策略
2023-09-08
有迹可循之 CheckList
我们经常在Code Review的时候经常不知道怎么CodeReview,或者说写代码的时候怎么写出很棒的代码,更多的是靠我们的经验和感觉。怎么才能做到知其然知其所以然呢?
2021-09-01
4. 分片集群架构设计技巧
2023-09-26
7. Spark 的广播变量和累加器
2023-09-08
Nacos 配置中心之环境准备
Nacos配置中心的工作流程是怎么样的呢?首先启动SpringBoot项目,在启动项目之后,需要把远程服务器的配置文件加载到Spring容器中
2022-07-21
10、串行与并行收集器 - 吞吐量优先 PS、PO
2023-09-26
Android 动画之补间动画
2.位移动画:TranslateAnimation:四个参数:X轴开始位置,X轴结束位置,Y轴开始位置,Y轴结束位置
2021-11-07
当你面试的时候,被问到关于 Fragment 的种种,Android 开发教程
public int getBreadCrumbShortTitleRes();public CharSequence getBreadCrumbTitle();public CharSequence getBreadCrumbShortTitle();}
2021-11-03
4、一主多从,互为主从
2023-09-27
元宇宙大热,是风口还是虎口
说起当下科技界和投资界的热门词汇,元宇宙应当榜上有名。尽管元宇宙概念最早出现在1992年美国作家尼尔·斯蒂芬森的科幻小说《雪崩》中,但2021年元宇宙却突然火爆各界,因此,2021年也被称为“元宇宙元年”。2021年3月,元宇宙游戏公司Roblox以300亿美元市值
2022-04-10
【C 语言深度剖析】深入理解 C 语言中函数的递归算法
一个函数在它的函数体内调用它自身,这种调用过程称为递归,这种函数称为递归函数
2022-09-10
防抖 & 节流
防抖、节流
2021-11-03
推荐阅读
27|模型工程(三):低成本领域模型方案,小团队怎么做大模型?
2023-10-20
如何用 Xcode 安装 ipa
2023-05-11
26|模型工程(二):算力受限,如何为“无米之炊”?
2023-10-18
Studio One 6 for mac(专业音乐制作软件)
2024-12-18
14. 如何做一名教练型的乘法领导者
2023-10-17
Lakehouse is ALL you need
2024-12-17
父母、离别
2023-02-24
电子书
大厂实战PPT下载
换一换 孟红伦(云际) | 阿里巴巴 高级前端技术专家
蒋志伟 | next.ai 创始人
马腾 | 微软(中国)有限公司 全渠道事业部架构师团队/资深云解决方案架构师
评论