写点什么

浅谈 Tensorflow 分布式架构:parameter server 及优化策略

  • 2019-12-02
  • 本文字数:3380 字

    阅读完需:约 11 分钟

浅谈Tensorflow分布式架构:parameter server及优化策略

当我们想将一个单机的 tensorflow 训练程序改写成分布式训练(多机多卡)的时候,一般有两个大方向的选择:1.完全异步的梯度更新策略,其代表方法是 parameter server 架构。2.同步的梯度更新策略,代表方法有:百度的 ring all-reduce 策略。本文首先介绍 parameter server 架构。

parameter server 策略:

parameter server 异步更新策略是指每个 GPU 或者 CPU 计算完梯度后,无需等待其他 GPU 或 CPU 的梯度计算(有时可以设置需要等待的梯度个数),就可立即更新整体的权值,然后同步此权值,即可进行下一轮计算。



parameter server 的架构


而 Tensorflow 一开始支持分布式的时候,便是这种 parameter server 架构。TensorFlow 一般将任务分为两类 job:一类叫参数服务器,parameter server,简称为 ps,用于存储可训练的参数变量 tf.Variable;一类就是普通任务,称为 worker,用于执行具体的计算。


Tensorflow 支持两种方式实现 parameter server:低阶 API 创建 parameter server 集群方式和 tf.distribute.Strategy 中的 ParameterServerStrategy。

低阶 API 创建 parameter server 集群

完整案例 dist_tf.py:


import tensorflow as tfimport numpy as np
# 创建集群信息,包括ps和worker两种角色。# 集群有两类任务,ps和worker;ps由2个任务组成(一般一个任务是一个机器或者一个分配单元),worker由3个任务组成。ps_hosts = ["xx.xxx.xx.xxxx:oooo", "xx.xxx.xx.xxxx:oooo"]worker_hosts = ["xx.xxx.xx.xxxx:oooo", "xx.xxx.xx.xxxx:oooo", "xx.xxx.xx.xxxx:oooo"]cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")FLAGS = tf.app.flags.FLAGS
def main(_): server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) if FLAGS.job_name == "ps": server.join() else: # 会根据job名,将with内的Variable op放到ps tasks,将其他计算op放到worker tasks。默认分配策略是轮询 with tf.device(tf.train.replica_device_setter( worker_device="/job:worker/task:%d" % FLAGS.task_index, cluster=cluster)):
x_data = tf.placeholder(tf.float32, [100]) y_data = tf.placeholder(tf.float32, [100])
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0)) b = tf.Variable(tf.zeros([1])) y = W * x_data + b loss = tf.reduce_mean(tf.square(y - y_data))
global_step = tf.Variable(0, name="global_step", trainable=False) optimizer = tf.train.GradientDescentOptimizer(0.1) train_op = optimizer.minimize(loss, global_step=global_step)
# The StopAtStepHook handles stopping after running given steps. hooks = [tf.train.StopAtStepHook(last_step=1000000)] # The MonitoredTrainingSession takes care of session initialization, # restoring from a checkpoint, saving to a checkpoint, and closing when done # or an error occurs. with tf.train.MonitoredTrainingSession(master=server.target, is_chief=(FLAGS.task_index == 0), # 我们制定task_index为0的任务为主任务,用于负责变量初始化、做checkpoint、保存summary和复原 checkpoint_dir="/tmp/tf_train_logs", save_checkpoint_secs=None, hooks=hooks) as mon_sess: while not mon_sess.should_stop(): # Run a training step asynchronously. # See `tf.train.SyncReplicasOptimizer` for additional details on how to # perform *synchronous* training. # mon_sess.run handles AbortedError in case of preempted PS. train_x = np.random.rand(100).astype(np.float32) train_y = train_x * 0.1 + 0.3 _, step, loss_v, weight, biase = mon_sess.run([train_op, global_step, loss, W, b], feed_dict={x_data: train_x, y_data: train_y}) if step % 100 == 0: print("step: %d, weight: %f, biase: %f, loss: %f" % (step, weight, biase, loss_v)) print("Optimization finished.")

if __name__ == "__main__": tf.app.run()
复制代码


对于本例而言,我们需要在对应的 5 台机器上分别运行每个任务,共需执行五次代码,生成五个任务。


python dist_tf.py --job_name=ps --task_index=0python dist_tf.py --job_name=ps --task_index=1python dist_tf.py --job_name=worker --task_index=0python dist_tf.py --job_name=worker --task_index=1python dist_tf.py --job_name=worker --task_index=2
复制代码


低阶 API 创建 parameter server 集群缺点:


概念多,学习曲线陡峭。


单机代码到多机修改的代码量大。


需要多台机子跑不同的脚本,当然这可以通过 k8s 集群管理工具来解决。


PS 和 Worker 的比例不好选取。(建议选取偶数个的 ps,我的经验是 ps 和 worker 的比例是 1:3)


训练速度性能损失较大。(通信代价较高)


parameter server 常见的优化点:


如果有参数量较大的 embedding 变量时,可选择使用 embedding_lookup_sparse_with_distributed_aggregation 函数替代 tf.nn.embedding_lookup_sparse 函数。该函数可将 embedding 的聚合计算都放在变量所在的 PS 端,计算后转成稠密张量再传送到 Worker 上继续网络模型的计算。


tf.device 函数中有一个参数是设置变量在 ps 端放置策略的,可使用 tf.contrib.training.GreedyLoadBalancingStrategy 来替代默认的轮循。优点是:可根据参数的内存字节来完成类似在线垃圾收集的工作。根据 weight 和 bias 的字节数来放置到内存合适的 task 中,带来更好的负载平衡。


当参数有超大量级时(比如 embedding 参数),可在创建变量的时候使用分割变量策略:partitioner=tf.fixed_size_partitioner(ps_nums)


优化 input pipeline。链接:https://www.tensorflow.org/guide/performance/datasets


bandwidth 高带宽范亲和策略,保证多个 ps 分布在不同的物理机上。


Estimator 中的 ParameterServerStrategy 策略


# https://stackoverflow.com/questions/55003279/parameter-server-strategy-with-estimatorstensorflowimport tensorflow as tfimport osimport json
NUM_WORKERS = 1IP_ADDRS = ['localhost']PORTS = [12345]
def model_fn(...): .....
def input_fn(...): .....
复制代码

需要每个机器配置 TF_CONFIG 环境变量

os.environ['TF_CONFIG'] = json.dumps({    'cluster': {        'worker': ['%s:%d' % (IP_ADDRS[w], PORTS[w]) for w in range(NUM_WORKERS)],        'ps': ['%s:%d' % (IP_ADDRS[w], PORTS[w]) for w in range(NUM_WORKERS)]    },    'task': {'type': 'worker', 'index': 0}})
# Method for using ParamterServerStrategystrategy = tf.distribute.experimental.ParameterServerStrategy()
config = tf.estimator.RunConfig(train_distribute=strategy)
classifier = tf.estimator.Estimator( model_fn=model_fn, model_dir='/tmp/multiworker', config=config)tf.estimator.train_and_evaluate( classifier, train_spec=tf.estimator.TrainSpec(input_fn=input_fn), eval_spec=tf.estimator.EvalSpec(input_fn=input_fn))
复制代码


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/69010949


2019-12-02 16:235208

评论

发布
暂无评论
发现更多内容

解决centos7.0安装mysql后出现access defind for user@'localhost'的错误

北桥苏

MySQL

DPDK与ScaleFlux CSD 3000:金融数据处理的创新组合

ScaleFlux

DPDK 存储技术 数据压缩 金融开源

技术同学如何提高职场话语权

老张

话语权 职场影响力

IaaS预留实例在线交易策略详解

天翼云开发者社区

云计算 大数据 云服务

长三角生物医药产业加速跑,飞桨螺旋桨为创新药企、医药技术伙伴装上AI大模型引擎

飞桨PaddlePaddle

飞桨 生物医药

统一门户的快速构建--基于小程序技术的一种可能

FinFish

统一门户 小程序容器 小程序化 小程序技术

您的数据可以压缩吗?

ScaleFlux

存储成本 存储技术 数据压缩

ChatGPT火了,客服产业怎么办?

创智荟

知识计算 客服 ChatGPT 数字员工

轻量级思维导图工具:iMap Builder 免激活版

真大的脸盆

Mac 思维导图 Mac 软件

百度王海峰团队荣获吴文俊人工智能科技进步奖特等奖,成果已应用于文心一言

飞桨PaddlePaddle

CCIG 2023 百度飞桨分论坛:大模型时代的图象图形技术变革与实践

飞桨PaddlePaddle

ScaleFlux压缩存储产品通过 PolarDB-PG社区版和PolarDB-X 开源版认证

ScaleFlux

开源数据库 数据压缩 数据库技术 企业数据

Spring Boot 单体应用一键升级成 Spring Cloud Alibaba

阿里巴巴云原生

阿里云 微服务 云原生 spring cloud alibaba

出海无从下手?看社交泛娱乐出海「第一趁手工具」怎么说

融云 RongCloud

互联网 社交 融云 泛娱乐 出海

应用在虚机和容器场景下如何优雅上下线

华为云开源

微服务 云原生

MobPush 创建应用

MobTech袤博科技

小程序:技术标准与生态的演变

没有用户名丶

【web 开发】快来给你的类定个标准 -PHP 的接口技术(64)

迷彩

php 接口 interface 三周年连更 类扩展

最高等级!Apache RocketMQ 入选可信开源项目星云象限领导型象限

阿里巴巴云原生

阿里云 云原生 Apache RocketMQ

分享:两年两度升级数据库,我们经历了什么

OceanBase 数据库

数据库 oceanbase

假期充电,用阿里云 Serverless K8s + AIGC 搭建私人代码助理

阿里巴巴云原生

阿里云 Serverless Kubernetes 云原生 AIGC

专访顶象CEO: 新一代AI如何增强验证码安全性

极客天地

精通Vue.js系列实例教程 │ Vue组件的数据监听

TiAmo

Vue Web Worker 监听 watche

极狐(GitLab)重磅发布新产品「极狐星」,让研发效能看得清,算得准,成就企业精英效能管理

极狐GitLab

DevOps 研发管理 研发效能 极狐GitLab 研发效能度量

走进南京邮电大学!龙蜥导师面对面分享如何通过开源经历获得实习/工作机会?| 开源之夏 2023

OpenAnolis小助手

操作系统 实习 龙蜥社区 开源之夏 南京邮电大学

flutter系列之:做一个修改组件属性的动画

程序那些事

flutter 大前端 程序那些事

openEuler 社区 2023 年 4 月运作报告

openEuler

Linux 开源 操作系统 openEuler 资讯

【自己更换模型】如何用 Serverless 一键部署 Stable Diffusion?

阿里巴巴云原生

阿里云 Serverless 云原生 动态模型

日常节省 30%计算资源:阿里云实时计算 Flink 自动调优实践

Apache Flink

大数据 flink 实时计算

前方高能!融云《社交泛娱乐出海作战地图》来袭,前 100 位免费领

融云 RongCloud

图片 社交 融云 泛娱乐 出海

浅谈Tensorflow分布式架构:parameter server及优化策略_语言 & 开发_Alex-zhai_InfoQ精选文章