速来报名!AICon北京站鸿蒙专场~ 了解详情
写点什么

浅谈 Tensorflow 分布式架构:parameter server 及优化策略

  • 2019-12-02
  • 本文字数:3380 字

    阅读完需:约 11 分钟

浅谈Tensorflow分布式架构:parameter server及优化策略

当我们想将一个单机的 tensorflow 训练程序改写成分布式训练(多机多卡)的时候,一般有两个大方向的选择:1.完全异步的梯度更新策略,其代表方法是 parameter server 架构。2.同步的梯度更新策略,代表方法有:百度的 ring all-reduce 策略。本文首先介绍 parameter server 架构。

parameter server 策略:

parameter server 异步更新策略是指每个 GPU 或者 CPU 计算完梯度后,无需等待其他 GPU 或 CPU 的梯度计算(有时可以设置需要等待的梯度个数),就可立即更新整体的权值,然后同步此权值,即可进行下一轮计算。



parameter server 的架构


而 Tensorflow 一开始支持分布式的时候,便是这种 parameter server 架构。TensorFlow 一般将任务分为两类 job:一类叫参数服务器,parameter server,简称为 ps,用于存储可训练的参数变量 tf.Variable;一类就是普通任务,称为 worker,用于执行具体的计算。


Tensorflow 支持两种方式实现 parameter server:低阶 API 创建 parameter server 集群方式和 tf.distribute.Strategy 中的 ParameterServerStrategy。

低阶 API 创建 parameter server 集群

完整案例 dist_tf.py:


import tensorflow as tfimport numpy as np
# 创建集群信息,包括ps和worker两种角色。# 集群有两类任务,ps和worker;ps由2个任务组成(一般一个任务是一个机器或者一个分配单元),worker由3个任务组成。ps_hosts = ["xx.xxx.xx.xxxx:oooo", "xx.xxx.xx.xxxx:oooo"]worker_hosts = ["xx.xxx.xx.xxxx:oooo", "xx.xxx.xx.xxxx:oooo", "xx.xxx.xx.xxxx:oooo"]cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")FLAGS = tf.app.flags.FLAGS
def main(_): server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) if FLAGS.job_name == "ps": server.join() else: # 会根据job名,将with内的Variable op放到ps tasks,将其他计算op放到worker tasks。默认分配策略是轮询 with tf.device(tf.train.replica_device_setter( worker_device="/job:worker/task:%d" % FLAGS.task_index, cluster=cluster)):
x_data = tf.placeholder(tf.float32, [100]) y_data = tf.placeholder(tf.float32, [100])
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0)) b = tf.Variable(tf.zeros([1])) y = W * x_data + b loss = tf.reduce_mean(tf.square(y - y_data))
global_step = tf.Variable(0, name="global_step", trainable=False) optimizer = tf.train.GradientDescentOptimizer(0.1) train_op = optimizer.minimize(loss, global_step=global_step)
# The StopAtStepHook handles stopping after running given steps. hooks = [tf.train.StopAtStepHook(last_step=1000000)] # The MonitoredTrainingSession takes care of session initialization, # restoring from a checkpoint, saving to a checkpoint, and closing when done # or an error occurs. with tf.train.MonitoredTrainingSession(master=server.target, is_chief=(FLAGS.task_index == 0), # 我们制定task_index为0的任务为主任务,用于负责变量初始化、做checkpoint、保存summary和复原 checkpoint_dir="/tmp/tf_train_logs", save_checkpoint_secs=None, hooks=hooks) as mon_sess: while not mon_sess.should_stop(): # Run a training step asynchronously. # See `tf.train.SyncReplicasOptimizer` for additional details on how to # perform *synchronous* training. # mon_sess.run handles AbortedError in case of preempted PS. train_x = np.random.rand(100).astype(np.float32) train_y = train_x * 0.1 + 0.3 _, step, loss_v, weight, biase = mon_sess.run([train_op, global_step, loss, W, b], feed_dict={x_data: train_x, y_data: train_y}) if step % 100 == 0: print("step: %d, weight: %f, biase: %f, loss: %f" % (step, weight, biase, loss_v)) print("Optimization finished.")

if __name__ == "__main__": tf.app.run()
复制代码


对于本例而言,我们需要在对应的 5 台机器上分别运行每个任务,共需执行五次代码,生成五个任务。


python dist_tf.py --job_name=ps --task_index=0python dist_tf.py --job_name=ps --task_index=1python dist_tf.py --job_name=worker --task_index=0python dist_tf.py --job_name=worker --task_index=1python dist_tf.py --job_name=worker --task_index=2
复制代码


低阶 API 创建 parameter server 集群缺点:


概念多,学习曲线陡峭。


单机代码到多机修改的代码量大。


需要多台机子跑不同的脚本,当然这可以通过 k8s 集群管理工具来解决。


PS 和 Worker 的比例不好选取。(建议选取偶数个的 ps,我的经验是 ps 和 worker 的比例是 1:3)


训练速度性能损失较大。(通信代价较高)


parameter server 常见的优化点:


如果有参数量较大的 embedding 变量时,可选择使用 embedding_lookup_sparse_with_distributed_aggregation 函数替代 tf.nn.embedding_lookup_sparse 函数。该函数可将 embedding 的聚合计算都放在变量所在的 PS 端,计算后转成稠密张量再传送到 Worker 上继续网络模型的计算。


tf.device 函数中有一个参数是设置变量在 ps 端放置策略的,可使用 tf.contrib.training.GreedyLoadBalancingStrategy 来替代默认的轮循。优点是:可根据参数的内存字节来完成类似在线垃圾收集的工作。根据 weight 和 bias 的字节数来放置到内存合适的 task 中,带来更好的负载平衡。


当参数有超大量级时(比如 embedding 参数),可在创建变量的时候使用分割变量策略:partitioner=tf.fixed_size_partitioner(ps_nums)


优化 input pipeline。链接:https://www.tensorflow.org/guide/performance/datasets


bandwidth 高带宽范亲和策略,保证多个 ps 分布在不同的物理机上。


Estimator 中的 ParameterServerStrategy 策略


# https://stackoverflow.com/questions/55003279/parameter-server-strategy-with-estimatorstensorflowimport tensorflow as tfimport osimport json
NUM_WORKERS = 1IP_ADDRS = ['localhost']PORTS = [12345]
def model_fn(...): .....
def input_fn(...): .....
复制代码

需要每个机器配置 TF_CONFIG 环境变量

os.environ['TF_CONFIG'] = json.dumps({    'cluster': {        'worker': ['%s:%d' % (IP_ADDRS[w], PORTS[w]) for w in range(NUM_WORKERS)],        'ps': ['%s:%d' % (IP_ADDRS[w], PORTS[w]) for w in range(NUM_WORKERS)]    },    'task': {'type': 'worker', 'index': 0}})
# Method for using ParamterServerStrategystrategy = tf.distribute.experimental.ParameterServerStrategy()
config = tf.estimator.RunConfig(train_distribute=strategy)
classifier = tf.estimator.Estimator( model_fn=model_fn, model_dir='/tmp/multiworker', config=config)tf.estimator.train_and_evaluate( classifier, train_spec=tf.estimator.TrainSpec(input_fn=input_fn), eval_spec=tf.estimator.EvalSpec(input_fn=input_fn))
复制代码


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/69010949


2019-12-02 16:234215

评论

发布
暂无评论
发现更多内容

Java岗史上最全八股文面试真题汇总,堪称2022年面试天花板

Geek_0c76c3

Java 数据库 开源 程序员 开发

华为应用市场审核指南解读课程上线,面向开发者讲解应用审核2022年更新要点

最新动态

OptaPlanner快速入门-概述

积木思维

ESP32-C3 学习测试 蓝牙 篇(二、蓝牙调试APP、开发板手机连接初体验)

矜辰所致

ESP32-C3 9月月更 蓝牙APP

一加是OPPO的子品牌?我来说说我的看法

Geek_8a195c

产品经理必看的高效产品文档撰写指南

Baklib

产品 产品经理 文档

React 新提案 useEvent 已死?不,它将涅盘重生。

清秋

React useEvent RFC 提案

帮助中心案例分析|师爷,给我解释解释什么叫降本增效?

Baklib

降本增效 帮助中心

联通研究院霍龙社博士深度解析“AI项目到底适不适合开源”

OpenI启智社区

人工智能 OpenI启智社区 AI开源 CubeAI智立方

一文读懂TDengine的三种查询功能

TDengine

数据库 tdengine 时序数据库 企业号九月金秋榜

全网首发!马士兵内部共享—1658页《Java面试突击核心讲》

Geek_0c76c3

Java 数据库 开源 程序员 开发

全方位助力数据科学组织协同&个人研究:ModelWhale 产品功能介绍与版本选择指引

ModelWhale

云计算 科技 数据科学 编程建模 组织协同

为什么3D实时渲染很重要

3DCAT实时渲染

云计算 元宇宙 实时渲染 实时云渲染 云VR

Baklib+伙伴云+企微会话存档,打造伙伴云帮助中心运营体系

Baklib

面试整理的45W字Java真题和答案详解(含核心考点及6家大厂真题)

Geek_0c76c3

Java 数据库 开源 程序员 开发

开发者有话说|刚毕业的“00后”,歪打误撞进入了SAP行业

暮春零贰

个人成长 9月月更

好的,BFS,学会了

掘金安东尼

前端 9月月更

企业IT运维开发一体化解决方案

力软低代码开发平台

什么是实时渲染,3D实时渲染的优缺点

3DCAT实时渲染

云计算 元宇宙 实时渲染 实时云渲染 云VR

ESP32-C3 学习测试 蓝牙 篇(三、认识蓝牙 GATT 协议)

矜辰所致

蓝牙 ESP32-C3 9月月更 GATT

当下企业数字化转型,PaaS是基础解

ToB行业头条

Trending热榜关闭前,我把Github今年最火Java面试题汇总扒下来了

Geek_0c76c3

Java 数据库 开源 程序员 开发

借助iMazing工具重新安装或升级 iOS系统

淋雨

ios iphone

专访美象科技|中国数字孪生50强为何需要3DCAT实时渲染云的赋能?

3DCAT实时渲染

云计算 元宇宙 实时渲染 实时云渲染 云VR

哪7个场景影响研发效能?

LigaAI

敏捷 LigaAI 企业号九月金秋榜 #敏捷开发 #程序

为了进大厂!吃透了各大厂最新 3000+Java 面试题啃完面试肯定妥了

Geek_0c76c3

Java 开源 程序员 架构 开发

如何使用游戏引擎进行实时渲染和内容创建

3DCAT实时渲染

云计算 元宇宙 实时渲染 实时云渲染 云VR

盘点团队在线协作文档工具

Baklib

在线协作文档

Vue3入门指北(五)条件渲染

Augus

Vue 3 9月月更

Apache APISIX 集成 Elasticsearch 实现实时日志监控

API7.ai 技术团队

elasticsearch API网关 APISIX 网关

从新零售、物流到广告,搞定指标中台就这么简单!

Kyligence

数据分析 指标管理 指标中台

浅谈Tensorflow分布式架构:parameter server及优化策略_语言 & 开发_Alex-zhai_InfoQ精选文章