论文解读:
常见的并行训练的模式有两种:
同步,各个 worker 并行训练,每次更新梯度时都等待所有 worker 完成本次迭代的计算,然后一起开始下一次迭代。最简单的同步就是每个 worker 的梯度求和求平均,然后更新参数,目前比较流行的同步更新方法则是 ring-all-ruduce 方法。可以保证梯度的正确性。但是速度较慢。
完全异步,各个 worker 并行训练,各自处理各自的数据不等待其他任何 worker。速度比较快,但是梯度有损失。
存在的问题:传统的 SGD 是基于 batch 更新的,并行训练时各个 worker 计算当前 batch 的梯度,然后反向传播之后 push 梯度,然后 pull 最新的参数再处理下一个 batch。这个时候如果当一个 worker 更新速度特别慢,这个 worker push 的梯度是使用一个非常旧的参数计算出来的,这个梯度可能已经不适合当下的参数,甚至有时候会起到反作用。
本文提出的 SSP 方法来让 worker 在效率和正确性上做一个良好的权衡。
核心思想:各个 worker 并行训练,每次进行下一次迭代时判断一下自己的迭代比整个系统中最慢的节点的迭代快多少个 step,如果达到一个阈值就进入等待状态直到 step 小于阈值开始下一次计算。
代码实现
参考:https://blog.csdn.net/li57681522/article/details/87920210
# -*- coding:utf-8 -*-
# python dis_tf_ssp.py --job_name=ps --task_index=0
# python dis_tf_ssp.py --job_name=worker --task_index=0
# python dis_tf_ssp.py --job_name=worker --task_index=1
import time
import numpy as np
import tensorflow as tf
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.ops import state_ops, variables, variable_scope
from tensorflow.python.training import session_run_hook
# Define parameters
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_float('learning_rate', 0.00003, 'Initial learning rate.')
tf.app.flags.DEFINE_integer('steps_to_validate', 1000,
'Steps to validate and print loss')
# For distributed
tf.app.flags.DEFINE_string("ps_hosts", "172.20.181.16:2222",
"Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("worker_hosts", "172.20.181.16:2224",
"Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
# Hyperparameters
learning_rate = FLAGS.learning_rate
steps_to_validate = FLAGS.steps_to_validate
# @tf_export("train.SemiSyncRunHook")
# class SemiSyncRunHook(session_run_hook.SessionRunHook):
class SemiSyncRunHook(tf.train.SessionRunHook):
"""Run by SSP."""
def __init__(self, index, worker_count, staleness=10):
"""Initializes a `SemiSyncRunHook`.
Args:
index: work index
worker_count: number of workers
staleness:
"""
if index >= worker_count:
print("worker index {} is bigger than worker_count {}".format(index, worker_count))
return
self._const_max_test_step = 10000
self._last_step = 0 # 上一次wait的步骤数
self._last_time = self._now_time() # 上一次wait的时间
self._index = index
self._staleness = staleness
self._wait_time = 0.01 # 等待时间,单位:秒;这个时间不能设置的太长,跟worker的训练速度和staleness相关
self._worker_steps = [] # 记录worker训练步骤数的变量列表
for i in range(worker_count):
worker_step = variable_scope.variable(0, trainable=False, name="worker_step_" + str(i))
self._worker_steps.append(worker_step)
if i == index:
self._my_step_update_op = state_ops.assign_add(worker_step, 1)
self._initialize_op = variables.variables_initializer(self._worker_steps)
def _now_time(self):
return time.time()
def after_create_session(self, session, coord):
session.run(self._initialize_op) # 初始化记录worker训练步骤数的变量
def before_run(self, run_context):
run_context.session.run(self._my_step_update_op) # 更新本worker的训练步骤数
return None
def after_run(self, run_context, run_values):
while True:
# 1.获取所有worker的训练步骤数
all_worker_steps = run_context.session.run(self._worker_steps)
# print("all worker steps={}. my work id={}".format(all_worker_steps, self._index))
# 2.如果训练当前worker的训练步骤数 > 最小worker训练步骤数 + staleness,sleep(10ms); 否则 break;
if all_worker_steps[self._index] > min(all_worker_steps) + self._staleness:
diff_step = all_worker_steps[self._index] - self._last_step
if diff_step / self._const_max_test_step > 1:
self._wait_time = (self._now_time() - self._last_time) / diff_step * self._staleness * 0.7
# 更新
self._last_step = all_worker_steps[self._index]
self._last_time = self._now_time()
time.sleep(self._wait_time) # 等待慢worker执行
# print("all worker steps={}, my work id={}. waiting {}s...".format(all_worker_steps, self._index, self._wait_time))
else:
break
def main(_):
ps_hosts = FLAGS.ps_hosts.split(",")
worker_hosts = FLAGS.worker_hosts.split(",")
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
worker_count = len(worker_hosts)
if FLAGS.job_name == "ps":
server.join()
elif FLAGS.job_name == "worker":
with tf.device(tf.train.replica_device_setter(
worker_device="/job:worker/task:%d" % FLAGS.task_index,
cluster=cluster)):
global_step = tf.Variable(0, name='global_step', trainable=False)
X = tf.placeholder(tf.float32)
Y = tf.placeholder(tf.float32)
w = tf.Variable(0.0, name="weight")
b = tf.Variable(0.0, name="reminder")
y = w * X + b
loss = tf.reduce_mean(tf.square(y - Y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# 更新梯度
train_op = optimizer.minimize(loss, global_step=global_step)
hooks = [tf.train.StopAtStepHook(last_step=1000000)]
semiSyncRunHook = SemiSyncRunHook(FLAGS.task_index, worker_count=worker_count, staleness=10)
hooks.append(semiSyncRunHook)
with tf.train.MonitoredTrainingSession(
master=server.target, is_chief=(FLAGS.task_index == 0),
checkpoint_dir="./ssp_saved_model",
hooks=hooks) as mon_sess:
while not mon_sess.should_stop():
train_x = np.random.randn(1)
train_y = 2 * train_x + np.random.randn(1) * 0.33 + 10
_, loss_v, step = mon_sess.run([train_op, loss, global_step], feed_dict={X: train_x, Y: train_y})
if step % steps_to_validate == 0:
w_, b_ = mon_sess.run([w, b])
print("step: %d, weight: %f, biase: %f, loss: %f" % (step, w_, b_, loss_v))
if __name__ == "__main__":
tf.app.run()
参考文献:
More Effective Distributed ML via a Stale Synchronous Parallel Parameter Server
本文转载自 Alex-zhai 知乎账号。
原文链接:https://zhuanlan.zhihu.com/p/69665236
更多内容推荐
4、一主多从,互为主从
2023-09-27
10. Flink 的 watermark 机制和顺序数据的 watermark
2023-09-08
主成分分析 PCA 与奇异值分解 SVD-PCA 中的 SVD
svd_solver是奇异值分解器的意思,为什么PCA算法下面会有有关奇异值分解的参数?不是两种算法么?
2022-11-18
3. 负载均衡:加权轮询实现
2023-09-27
Boundless Hackathon @Stanford 主题黑客松活动闭幕,一文回顾
由 Stanford Blockchain Accelerator、Zebec Protocol、 Nautilus Chain、Rootz Lab 共同主办了“ Boundless Hackathon @Stanford ” 主题的黑客松活动在 7 月 1 日正式落下帷幕。
2023-07-11
读《Software Systems Architecture》(05)—— The Role of the Software Architect
读《Software Systems Architecture》(05)—— The Role of the Software Architect
2022-06-14
赣州有资质等保测评机构有几家?咨询电话多少?
赣州有资质等保测评机构有几家?咨询电话多少?最近看到不少人在问这两个问题,这里小编就给大家简单回答一下,希望有用哈!
2023-10-31
在线随机抛硬币正反面统计工具
在线随机抛硬币正反面统计工具
2022-07-18
Sentieon | 每周文献 -Epidemiology(流行病学)- 第五期
标题(英文):Rare Variants in Inborn Errors of Immunity Genes Associated With Covid-19 Severity
2023-08-24
第二届征文大赛开奖啦!速来领奖!
第二届有奖征文活动开奖啦!
2022-06-10
10- 并发工具类 -CyclicBarrier 循环栅栏
2023-09-26
Flink 核心组件
Flink五个核心组件
2022-12-12
前端 leetcde 算法面试套路之回溯
回溯,就是无脑冲,碰壁之后就回撤一步继续搞,属于一种暴力解题的思路;
2023-02-27
免费试用的云管平台哪里有?可以试用多久?
对于企业而言,购买云管平台需要先试用一下,这样才能靠谱,毕竟市面上云管平台厂商太多了。那你知道免费试用的云管平台哪里有?可以试用多久?
2022-12-08
12、串行与并行收集器 -G1、ZGC
2023-09-26
点云标注在自动驾驶中的优化策略
点云标注在自动驾驶中是非常关键的一部分,为了提高其准确性和效率,可以采用以下优化策略:
2023-07-25
13- 并发队列:阻塞、有界和无界
2023-09-26
Acrobat Pro DC 2023 for mac 完美激活版下载
Acrobat Pro DC 2023 for Mac 是一款功能强大的 PDF 编辑和管理软件。它提供了丰富的功能和直观的界面,帮助用户轻松地创建、编辑和共享 PDF 文件。
2023-11-13
版本发布|Orillusion 0.6.7 版本发布啦!
Orillusion 0.6.7版本发布
2023-08-16
软件测试 / 测试开发丨函数式编程学习笔记
一.高阶函数 高阶函数:既然变量可以指向函数,函数的参数能接收变量,那么一个函数就可以接收另一个函数作为参数,这种函数就称之为高阶函数。
2023-07-04
推荐阅读
Parallels Desktop 19 永久激活版下载 附最新破解教程
2023-12-01
PIRF418:Complaining – Why Can’t People Just Be Real
2024-12-17
Gitlab 配置 mirrorRepository 镜像仓库
2023-11-17
27|模型工程(三):低成本领域模型方案,小团队怎么做大模型?
2023-10-20
计算网络之 MSTP 协议与 VRRP 协议
2023-11-16
1、Nginx 概述及 web server 技术选型
2023-09-28
26|模型工程(二):算力受限,如何为“无米之炊”?
2023-10-18
电子书
大厂实战PPT下载
换一换 李谋 | 零一万物 资深算法专家
赵洋 | 蔚来 前端专家工程师
马学宁 | 国投瑞银基金 信息技术部/首席架构师&副总监
评论