写点什么

浅谈 Tensorflow 分布式架构:parameter server 及优化策略

  • 2019-12-02
  • 本文字数:3380 字

    阅读完需:约 11 分钟

浅谈Tensorflow分布式架构:parameter server及优化策略

当我们想将一个单机的 tensorflow 训练程序改写成分布式训练(多机多卡)的时候,一般有两个大方向的选择:1.完全异步的梯度更新策略,其代表方法是 parameter server 架构。2.同步的梯度更新策略,代表方法有:百度的 ring all-reduce 策略。本文首先介绍 parameter server 架构。

parameter server 策略:

parameter server 异步更新策略是指每个 GPU 或者 CPU 计算完梯度后,无需等待其他 GPU 或 CPU 的梯度计算(有时可以设置需要等待的梯度个数),就可立即更新整体的权值,然后同步此权值,即可进行下一轮计算。



parameter server 的架构


而 Tensorflow 一开始支持分布式的时候,便是这种 parameter server 架构。TensorFlow 一般将任务分为两类 job:一类叫参数服务器,parameter server,简称为 ps,用于存储可训练的参数变量 tf.Variable;一类就是普通任务,称为 worker,用于执行具体的计算。


Tensorflow 支持两种方式实现 parameter server:低阶 API 创建 parameter server 集群方式和 tf.distribute.Strategy 中的 ParameterServerStrategy。

低阶 API 创建 parameter server 集群

完整案例 dist_tf.py:


import tensorflow as tfimport numpy as np
# 创建集群信息,包括ps和worker两种角色。# 集群有两类任务,ps和worker;ps由2个任务组成(一般一个任务是一个机器或者一个分配单元),worker由3个任务组成。ps_hosts = ["xx.xxx.xx.xxxx:oooo", "xx.xxx.xx.xxxx:oooo"]worker_hosts = ["xx.xxx.xx.xxxx:oooo", "xx.xxx.xx.xxxx:oooo", "xx.xxx.xx.xxxx:oooo"]cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")FLAGS = tf.app.flags.FLAGS
def main(_): server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) if FLAGS.job_name == "ps": server.join() else: # 会根据job名,将with内的Variable op放到ps tasks,将其他计算op放到worker tasks。默认分配策略是轮询 with tf.device(tf.train.replica_device_setter( worker_device="/job:worker/task:%d" % FLAGS.task_index, cluster=cluster)):
x_data = tf.placeholder(tf.float32, [100]) y_data = tf.placeholder(tf.float32, [100])
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0)) b = tf.Variable(tf.zeros([1])) y = W * x_data + b loss = tf.reduce_mean(tf.square(y - y_data))
global_step = tf.Variable(0, name="global_step", trainable=False) optimizer = tf.train.GradientDescentOptimizer(0.1) train_op = optimizer.minimize(loss, global_step=global_step)
# The StopAtStepHook handles stopping after running given steps. hooks = [tf.train.StopAtStepHook(last_step=1000000)] # The MonitoredTrainingSession takes care of session initialization, # restoring from a checkpoint, saving to a checkpoint, and closing when done # or an error occurs. with tf.train.MonitoredTrainingSession(master=server.target, is_chief=(FLAGS.task_index == 0), # 我们制定task_index为0的任务为主任务,用于负责变量初始化、做checkpoint、保存summary和复原 checkpoint_dir="/tmp/tf_train_logs", save_checkpoint_secs=None, hooks=hooks) as mon_sess: while not mon_sess.should_stop(): # Run a training step asynchronously. # See `tf.train.SyncReplicasOptimizer` for additional details on how to # perform *synchronous* training. # mon_sess.run handles AbortedError in case of preempted PS. train_x = np.random.rand(100).astype(np.float32) train_y = train_x * 0.1 + 0.3 _, step, loss_v, weight, biase = mon_sess.run([train_op, global_step, loss, W, b], feed_dict={x_data: train_x, y_data: train_y}) if step % 100 == 0: print("step: %d, weight: %f, biase: %f, loss: %f" % (step, weight, biase, loss_v)) print("Optimization finished.")

if __name__ == "__main__": tf.app.run()
复制代码


对于本例而言,我们需要在对应的 5 台机器上分别运行每个任务,共需执行五次代码,生成五个任务。


python dist_tf.py --job_name=ps --task_index=0python dist_tf.py --job_name=ps --task_index=1python dist_tf.py --job_name=worker --task_index=0python dist_tf.py --job_name=worker --task_index=1python dist_tf.py --job_name=worker --task_index=2
复制代码


低阶 API 创建 parameter server 集群缺点:


概念多,学习曲线陡峭。


单机代码到多机修改的代码量大。


需要多台机子跑不同的脚本,当然这可以通过 k8s 集群管理工具来解决。


PS 和 Worker 的比例不好选取。(建议选取偶数个的 ps,我的经验是 ps 和 worker 的比例是 1:3)


训练速度性能损失较大。(通信代价较高)


parameter server 常见的优化点:


如果有参数量较大的 embedding 变量时,可选择使用 embedding_lookup_sparse_with_distributed_aggregation 函数替代 tf.nn.embedding_lookup_sparse 函数。该函数可将 embedding 的聚合计算都放在变量所在的 PS 端,计算后转成稠密张量再传送到 Worker 上继续网络模型的计算。


tf.device 函数中有一个参数是设置变量在 ps 端放置策略的,可使用 tf.contrib.training.GreedyLoadBalancingStrategy 来替代默认的轮循。优点是:可根据参数的内存字节来完成类似在线垃圾收集的工作。根据 weight 和 bias 的字节数来放置到内存合适的 task 中,带来更好的负载平衡。


当参数有超大量级时(比如 embedding 参数),可在创建变量的时候使用分割变量策略:partitioner=tf.fixed_size_partitioner(ps_nums)


优化 input pipeline。链接:https://www.tensorflow.org/guide/performance/datasets


bandwidth 高带宽范亲和策略,保证多个 ps 分布在不同的物理机上。


Estimator 中的 ParameterServerStrategy 策略


# https://stackoverflow.com/questions/55003279/parameter-server-strategy-with-estimatorstensorflowimport tensorflow as tfimport osimport json
NUM_WORKERS = 1IP_ADDRS = ['localhost']PORTS = [12345]
def model_fn(...): .....
def input_fn(...): .....
复制代码

需要每个机器配置 TF_CONFIG 环境变量

os.environ['TF_CONFIG'] = json.dumps({    'cluster': {        'worker': ['%s:%d' % (IP_ADDRS[w], PORTS[w]) for w in range(NUM_WORKERS)],        'ps': ['%s:%d' % (IP_ADDRS[w], PORTS[w]) for w in range(NUM_WORKERS)]    },    'task': {'type': 'worker', 'index': 0}})
# Method for using ParamterServerStrategystrategy = tf.distribute.experimental.ParameterServerStrategy()
config = tf.estimator.RunConfig(train_distribute=strategy)
classifier = tf.estimator.Estimator( model_fn=model_fn, model_dir='/tmp/multiworker', config=config)tf.estimator.train_and_evaluate( classifier, train_spec=tf.estimator.TrainSpec(input_fn=input_fn), eval_spec=tf.estimator.EvalSpec(input_fn=input_fn))
复制代码


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/69010949


2019-12-02 16:235174

评论

发布
暂无评论
发现更多内容

组件发布效率提升15倍是怎么做到的——基于 Gradle 调度机制深度研究与优化

字节跳动终端技术

字节跳动 Gradle Andriod

专有云运维如何更快、更准、更稳?丨智能运维

百度大脑

人工智能

Flink 在字节跳动数据流的实践

字节跳动数据平台

大数据 flink 字节跳动 埋点 数据流

JuiceFS 即将发布 1.0 并调整开源许可

Juicedata

开源 文件系统 云存储

Go语言gorm框架MySQL实践

FunTester

Go MySQL gorm FunTester

2022 起始篇

万万

龙蜥利器:系统运维工具 SysAK的云上应用性能诊断 | 龙蜥技术

OpenAnolis小助手

开源 运维

重磅功能!Apache APISIX 拥抱 WASM 生态

API7.ai 技术团队

开源 云原生 Wasm Apache APISIX

重庆华美:用宜搭实现全流程管理上云,节约超百万研发成本

一只大光圈

前端 阿里 低代码 数字化转型 钉钉宜搭

元宇宙时代:银行如何探索数字化转型

CECBC

「干货分享」如何做好应急响应工作?常见应急响应流程

H

网络安全 应急响应

金融小程序风险如何控制,WeTest小程序质量专项方案一步到位

WeTest

VuePress 博客优化之开启 HTTPS

冴羽

Vue 前端 博客 vuepress 博客搭建

网络安全——内网渗透完整流程

网络安全学海

黑客 网络安全 信息安全 渗透测试 安全漏洞

微软Office新增实用功能允许用户在不同设备上轻松送同步字体

淋雨

实时音视频入门学习:开源工程WebRTC的技术原理和使用浅析

JackJiang

音视频 WebRTC IM 即时通讯IM

不优雅的 React Hooks

CRMEB

关于A股投资--《香帅中国财富报告》摘录(3/100)

hackstoic

投资

2022中国低代码十大发展趋势,市场规模预计达42.6亿

J2PaaS低代码平台

低代码 低代码开发 J2PaaS

LabVIEW实现PCB电路板坐标定位(实战篇—2)

不脱发的程序猿

机器视觉 图像处理 LabVIEW PCB电路板坐标定位

Tengine + BabaSSL ,让国密更易用!

SOFAStack

密码学 tengine 国密 BABASSL

酒店资产管理系统解决方案

低代码小观

CRM 企业管理系统 CRM系统 企业管理工具 企业管理软件

一周信创舆情观察(1.4~1.9)

统小信uos

不是私密链接,如何继续前往?

BUG侦探

https HSTS 劫持

ORTC与SIP融合通信服务架构

安第斯智能云

音视频 RTC 流媒体

一个cpp协程库的前世今生(十七)带时限的锁

SkyFire

c++ cocpp

用 SwiftUI 实现一个开源的 App Store

37手游iOS技术运营团队

swift appstore SwiftUI App榜单 App免费榜

虎年前迎来脑科学新锐:脑虎科技的创生故事

脑极体

到底什么是云?其实云计算从业者也不懂!

Geek_f56666

云计算

金融云原生漫谈(五)|如何打造更适合云原生的数据存储方案?

York

云原生 数据存储

百度智能云以“3D+AI”技术,助力“三亿人上冰雪”

百度开发者中心

人工智能

浅谈Tensorflow分布式架构:parameter server及优化策略_语言 & 开发_Alex-zhai_InfoQ精选文章