写点什么

浅谈 Tensorflow 分布式架构:parameter server 及优化策略

  • 2019-12-02
  • 本文字数:3380 字

    阅读完需:约 11 分钟

浅谈Tensorflow分布式架构:parameter server及优化策略

当我们想将一个单机的 tensorflow 训练程序改写成分布式训练(多机多卡)的时候,一般有两个大方向的选择:1.完全异步的梯度更新策略,其代表方法是 parameter server 架构。2.同步的梯度更新策略,代表方法有:百度的 ring all-reduce 策略。本文首先介绍 parameter server 架构。

parameter server 策略:

parameter server 异步更新策略是指每个 GPU 或者 CPU 计算完梯度后,无需等待其他 GPU 或 CPU 的梯度计算(有时可以设置需要等待的梯度个数),就可立即更新整体的权值,然后同步此权值,即可进行下一轮计算。



parameter server 的架构


而 Tensorflow 一开始支持分布式的时候,便是这种 parameter server 架构。TensorFlow 一般将任务分为两类 job:一类叫参数服务器,parameter server,简称为 ps,用于存储可训练的参数变量 tf.Variable;一类就是普通任务,称为 worker,用于执行具体的计算。


Tensorflow 支持两种方式实现 parameter server:低阶 API 创建 parameter server 集群方式和 tf.distribute.Strategy 中的 ParameterServerStrategy。

低阶 API 创建 parameter server 集群

完整案例 dist_tf.py:


import tensorflow as tfimport numpy as np
# 创建集群信息,包括ps和worker两种角色。# 集群有两类任务,ps和worker;ps由2个任务组成(一般一个任务是一个机器或者一个分配单元),worker由3个任务组成。ps_hosts = ["xx.xxx.xx.xxxx:oooo", "xx.xxx.xx.xxxx:oooo"]worker_hosts = ["xx.xxx.xx.xxxx:oooo", "xx.xxx.xx.xxxx:oooo", "xx.xxx.xx.xxxx:oooo"]cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")FLAGS = tf.app.flags.FLAGS
def main(_): server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) if FLAGS.job_name == "ps": server.join() else: # 会根据job名,将with内的Variable op放到ps tasks,将其他计算op放到worker tasks。默认分配策略是轮询 with tf.device(tf.train.replica_device_setter( worker_device="/job:worker/task:%d" % FLAGS.task_index, cluster=cluster)):
x_data = tf.placeholder(tf.float32, [100]) y_data = tf.placeholder(tf.float32, [100])
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0)) b = tf.Variable(tf.zeros([1])) y = W * x_data + b loss = tf.reduce_mean(tf.square(y - y_data))
global_step = tf.Variable(0, name="global_step", trainable=False) optimizer = tf.train.GradientDescentOptimizer(0.1) train_op = optimizer.minimize(loss, global_step=global_step)
# The StopAtStepHook handles stopping after running given steps. hooks = [tf.train.StopAtStepHook(last_step=1000000)] # The MonitoredTrainingSession takes care of session initialization, # restoring from a checkpoint, saving to a checkpoint, and closing when done # or an error occurs. with tf.train.MonitoredTrainingSession(master=server.target, is_chief=(FLAGS.task_index == 0), # 我们制定task_index为0的任务为主任务,用于负责变量初始化、做checkpoint、保存summary和复原 checkpoint_dir="/tmp/tf_train_logs", save_checkpoint_secs=None, hooks=hooks) as mon_sess: while not mon_sess.should_stop(): # Run a training step asynchronously. # See `tf.train.SyncReplicasOptimizer` for additional details on how to # perform *synchronous* training. # mon_sess.run handles AbortedError in case of preempted PS. train_x = np.random.rand(100).astype(np.float32) train_y = train_x * 0.1 + 0.3 _, step, loss_v, weight, biase = mon_sess.run([train_op, global_step, loss, W, b], feed_dict={x_data: train_x, y_data: train_y}) if step % 100 == 0: print("step: %d, weight: %f, biase: %f, loss: %f" % (step, weight, biase, loss_v)) print("Optimization finished.")

if __name__ == "__main__": tf.app.run()
复制代码


对于本例而言,我们需要在对应的 5 台机器上分别运行每个任务,共需执行五次代码,生成五个任务。


python dist_tf.py --job_name=ps --task_index=0python dist_tf.py --job_name=ps --task_index=1python dist_tf.py --job_name=worker --task_index=0python dist_tf.py --job_name=worker --task_index=1python dist_tf.py --job_name=worker --task_index=2
复制代码


低阶 API 创建 parameter server 集群缺点:


概念多,学习曲线陡峭。


单机代码到多机修改的代码量大。


需要多台机子跑不同的脚本,当然这可以通过 k8s 集群管理工具来解决。


PS 和 Worker 的比例不好选取。(建议选取偶数个的 ps,我的经验是 ps 和 worker 的比例是 1:3)


训练速度性能损失较大。(通信代价较高)


parameter server 常见的优化点:


如果有参数量较大的 embedding 变量时,可选择使用 embedding_lookup_sparse_with_distributed_aggregation 函数替代 tf.nn.embedding_lookup_sparse 函数。该函数可将 embedding 的聚合计算都放在变量所在的 PS 端,计算后转成稠密张量再传送到 Worker 上继续网络模型的计算。


tf.device 函数中有一个参数是设置变量在 ps 端放置策略的,可使用 tf.contrib.training.GreedyLoadBalancingStrategy 来替代默认的轮循。优点是:可根据参数的内存字节来完成类似在线垃圾收集的工作。根据 weight 和 bias 的字节数来放置到内存合适的 task 中,带来更好的负载平衡。


当参数有超大量级时(比如 embedding 参数),可在创建变量的时候使用分割变量策略:partitioner=tf.fixed_size_partitioner(ps_nums)


优化 input pipeline。链接:https://www.tensorflow.org/guide/performance/datasets


bandwidth 高带宽范亲和策略,保证多个 ps 分布在不同的物理机上。


Estimator 中的 ParameterServerStrategy 策略


# https://stackoverflow.com/questions/55003279/parameter-server-strategy-with-estimatorstensorflowimport tensorflow as tfimport osimport json
NUM_WORKERS = 1IP_ADDRS = ['localhost']PORTS = [12345]
def model_fn(...): .....
def input_fn(...): .....
复制代码

需要每个机器配置 TF_CONFIG 环境变量

os.environ['TF_CONFIG'] = json.dumps({    'cluster': {        'worker': ['%s:%d' % (IP_ADDRS[w], PORTS[w]) for w in range(NUM_WORKERS)],        'ps': ['%s:%d' % (IP_ADDRS[w], PORTS[w]) for w in range(NUM_WORKERS)]    },    'task': {'type': 'worker', 'index': 0}})
# Method for using ParamterServerStrategystrategy = tf.distribute.experimental.ParameterServerStrategy()
config = tf.estimator.RunConfig(train_distribute=strategy)
classifier = tf.estimator.Estimator( model_fn=model_fn, model_dir='/tmp/multiworker', config=config)tf.estimator.train_and_evaluate( classifier, train_spec=tf.estimator.TrainSpec(input_fn=input_fn), eval_spec=tf.estimator.EvalSpec(input_fn=input_fn))
复制代码


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/69010949


2019-12-02 16:234458

评论

发布
暂无评论
发现更多内容

Golden Gate (GGX) 启动公测,下一代创新DeFi和跨链 dApps 征程开始

股市老人

《好好学习》:如何管理知识?

郭明

广西高等教育学会高校教育技术委员会莅临瑞云科技考察交流

3DCAT实时渲染

虚拟仿真 元宇宙 实时渲染云

如何使用Go语言实现迪米特法则

Jack

AIGC背后的技术分析 | K均值聚类算法Python实现

TiAmo

Python AIGC K值算法

图数据库 NebulaGraph 的内存管理实践之 Memory Tracker

NebulaGraph

数据库 内存管理 图数据库

生产环境质量保障的重要性

老张

质量保障 稳定性保障

10个提高工作效率的Cinema 4D小技巧

Finovy Cloud

C4D

Orillusion引擎正式开源!AIGC时代下的WebGPU轻量级3D渲染引擎!

Orillusion

开源 3D 渲染引擎 webgpu AIGC

Spring中@NotEmpty、@NotBlank、@NotNull 区别和使用

Java你猿哥

Java spring Spring Boot string ssm

ChatGPT 科普(65/100)

hackstoic

ChatGPT

Java 把一个 List 转换为字符串

HoneyMoose

什么是Auto-GPT?如何使用、部署Auto-GPT?

炜娓道来程序人生

人工智能 AI ChatGPT

亚马逊云科技 一周回顾 – 2022 年 7 月 18 日

亚马逊云科技 (Amazon Web Services)

Amazon

Django笔记三十之log日志记录详解

Hunter熊

Python django 日志 log

开源赋能 普惠未来|浪潮集团寄语2023开放原子全球开源峰会

开放原子开源基金会

开源赋能 普惠未来|360集团寄语2023开放原子全球开源峰会

开放原子开源基金会

C语言编程—变量的构成

芯动大师

用友BIP新零售产品发布,与零售企业共创新未来

用友BIP

新零售 数字营销

IoTLink版本更新V1.25.0

山东云则信息科技

Java 物联网平台

什么是反射?它有什么用?

javacn.site

Java 面试

2023企业数智化财务创新峰会 · 成都站圆满举办!

用友BIP

智能会计 价值财务

软件测试/测试开发丨学习笔记之列表、元组、集合

测试人

Python 软件测试 自动化测试 列表 测试开发

浅谈中小企业为何放弃自媒体营销:定位不准、期望值过高、缺乏专业团队

石头IT视角

实现园林梦想尽在GardenPlanner 激活~

真大的脸盆

Mac Mac 软件 园林设计

2023 年度中国 DevOps 现状调查|有奖问卷

CODING DevOps

DevOps 云端IDE cloudstudio

软件测试 | spyne开发接口

测吧(北京)科技有限公司

测试

“伙伴+华为”体系,数字时代的新航标

脑极体

伙伴 体系

2023-05-17:一个正整数如果能被 a 或 b 整除,那么它是神奇的。 给定三个整数 n , a , b ,返回第 n 个神奇的数字。 因为答案可能很大,所以返回答案 对 10^9 + 7 取模

福大大架构师每日一题

Go 算法 rust 福大大

智聚北京!相约全球人力资源数智化峰会

用友BIP

人力资源

LLMs 诸神之战:LangChain ,以【奥德赛】之名

Zilliz

Milvus AIGC LLM langchain

浅谈Tensorflow分布式架构:parameter server及优化策略_语言 & 开发_Alex-zhai_InfoQ精选文章