
Distributed TensorFlow Source Code Walkthrough 2: MonitoredTrainingSession

  • 2019-11-28

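MonitoredTrainingSession is a widely used TensorFlow API for managing distributed training. It bundles the components that supervise a training run: variable initialization, resuming from an existing checkpoint, and saving summaries, logs, and checkpoints. Early TensorFlow versions generally used tf.train.Supervisor to manage the session; after later framework upgrades, MonitoredTrainingSession became the officially recommended choice. Through its hooks it covers logging, training visualization, checkpoint saving, early stopping, and training-performance tuning.


Let's get straight to the point. Below is the source of MonitoredTrainingSession. Its docstring sums up what it does: on the chief node, it initializes the session or restores it from an existing checkpoint, and it creates the hooks that save checkpoints and summaries. A non-chief worker node has to wait until the chief has finished initializing or restoring the session before it can set up its own session.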


@tf_export(v1=['train.MonitoredTrainingSession'])
def MonitoredTrainingSession(
    master='',  # pylint: disable=invalid-name
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=USE_DEFAULT,
    save_summaries_steps=USE_DEFAULT,
    save_summaries_secs=USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=USE_DEFAULT,
    summary_dir=None):
  """Creates a `MonitoredSession` for training.

  Returns:
    A `MonitoredSession` object.
  """
  scaffold = scaffold or Scaffold()
  worker_context = distribute_coordinator_context.get_current_worker_context()

  if worker_context:
    return _create_monitored_session_with_worker_context(
        worker_context,
        scaffold,
        checkpoint_dir=checkpoint_dir,
        hooks=hooks,
        chief_only_hooks=chief_only_hooks,
        save_checkpoint_secs=save_checkpoint_secs,
        save_summaries_steps=save_summaries_steps,
        save_summaries_secs=save_summaries_secs,
        config=config,
        stop_grace_period_secs=stop_grace_period_secs,
        log_step_count_steps=log_step_count_steps,
        max_wait_secs=max_wait_secs,
        save_checkpoint_steps=save_checkpoint_steps,
        summary_dir=summary_dir)

  if not is_chief:
    session_creator = WorkerSessionCreator(
        scaffold=scaffold,
        master=master,
        config=config,
        max_wait_secs=max_wait_secs)
    return MonitoredSession(
        session_creator=session_creator,
        hooks=hooks or [],
        stop_grace_period_secs=stop_grace_period_secs)

  all_hooks = []
  # (Elided in this excerpt: the various hooks are appended to all_hooks here,
  # and the chief's session_creator is created in the elided part as well.)
  if hooks:
    all_hooks.extend(hooks)

  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)
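The branching above can be condensed into a toy model. The sketch below is plain Python, not TensorFlow code: `choose_session_setup` is a hypothetical helper, strings stand in for the real creator classes, and the default checkpoint/summary hooks added on the chief are left out. It only shows which creator each role ends up with and which hooks it runs.

```python
def choose_session_setup(worker_context, is_chief, hooks=None, chief_only_hooks=None):
    """Toy model of the three branches in MonitoredTrainingSession.

    Returns (creator_name, hook_list). Strings stand in for the real
    creator objects; hypothetical helper, not part of TensorFlow.
    """
    hooks = list(hooks or [])
    if worker_context is not None:
        # A distribute-coordinator context decides the role by itself
        # (its own hook handling is elided in this sketch).
        return "from_worker_context", hooks
    if not is_chief:
        # Non-chief workers wait for the chief and run only the user hooks.
        return "WorkerSessionCreator", hooks
    # The chief also runs chief_only_hooks; the default saver and summary
    # hooks it would add are omitted here.
    return "ChiefSessionCreator", list(chief_only_hooks or []) + hooks


if __name__ == "__main__":
    print(choose_session_setup(None, True, hooks=["stop"], chief_only_hooks=["save"]))
    print(choose_session_setup(None, False, hooks=["stop"], chief_only_hooks=["save"]))
```

Note that a non-chief worker silently ignores `chief_only_hooks`, so anything that must run on every node belongs in `hooks`.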


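First, the parameters:


is_chief: used in distributed setups to indicate whether this task is the chief. If True, the task is responsible for initializing or restoring the underlying TensorFlow session. If False, it waits for the chief to initialize or restore the session.


checkpoint_dir: a string giving the path of the directory from which variables are restored.


scaffold: a Scaffold used to gather or build supporting ops. If none is given, a default scaffold is created. It is used to finalize the graph.


hooks: an optional list of SessionRunHook objects. You can define your own SessionRunHook or use a predefined one, for example tf.train.StopAtStepHook(), which sets the condition for stopping training, or tf.train.NanTensorHook(loss), which stops training when the loss becomes NaN.


chief_only_hooks: a list of SessionRunHook objects. These hooks are activated if is_chief == True and ignored otherwise.


save_checkpoint_secs: how often, in seconds, the default checkpoint saver saves a checkpoint. If save_checkpoint_secs is set to None (and save_checkpoint_steps is not set either), no checkpoint is saved.


save_summaries_steps: how often, in global steps, the default summary saver writes summaries to disk. If both save_summaries_steps and save_summaries_secs are set to None, the default summary saver is not used. Defaults to 100.


save_summaries_secs: how often, in seconds, the default summary saver writes summaries to disk. If both save_summaries_steps and save_summaries_secs are set to None, the default summary saver is not used. Not enabled by default.


config: an instance of tf.ConfigProto used to configure the session. It is the config argument of the tf.Session constructor.


stop_grace_period_secs: the number of seconds threads are given to stop after close() has been called.


log_step_count_steps: how often, in global steps, the global steps per second are logged.


Instantiating it returns a MonitoredSession object, which can be used like an ordinary session.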
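The interplay between the two summary arguments is easier to see in code. The following is a minimal sketch of how the USE_DEFAULT sentinel could be resolved, assuming that setting only one of the two arguments leaves the other disabled. `resolve_summary_settings` is a hypothetical helper written for this walkthrough, not a TensorFlow function.

```python
USE_DEFAULT = object()  # sentinel meaning "the caller did not set this argument"


def resolve_summary_settings(save_summaries_steps=USE_DEFAULT,
                             save_summaries_secs=USE_DEFAULT):
    """Returns the effective (steps, secs) pair for the default summary saver.

    Hypothetical helper; the one-argument cases are an assumption.
    """
    if save_summaries_steps is USE_DEFAULT and save_summaries_secs is USE_DEFAULT:
        # Nothing set: save every 100 global steps, no time-based saving.
        return 100, None
    if save_summaries_secs is USE_DEFAULT:
        # Only steps was given (possibly None).
        return save_summaries_steps, None
    if save_summaries_steps is USE_DEFAULT:
        # Only secs was given (possibly None).
        return None, save_summaries_secs
    return save_summaries_steps, save_summaries_secs


if __name__ == "__main__":
    print(resolve_summary_settings())
    print(resolve_summary_settings(save_summaries_secs=30))
```

A sentinel object is used instead of None because None already has a meaning here: it explicitly turns the feature off.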


Next, let's break the code down in detail:


def _create_monitored_session_with_worker_context(
    worker_context,  # pylint: disable=missing-docstring
    scaffold,
    checkpoint_dir=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=None,
    save_summaries_steps=None,
    save_summaries_secs=None,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=None,
    summary_dir=None):
  all_hooks = []
  # (Elided in this excerpt: the various hooks are appended to all_hooks here.)
  logging.info('all_hooks %r', all_hooks)

  # Create the session creator.
  session_creator = worker_context.session_creator(
      scaffold,
      config=config,
      checkpoint_dir=checkpoint_dir,
      max_wait_secs=max_wait_secs)

  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)


# Body of the session_creator method
def session_creator(self,
                    scaffold=None,
                    config=None,
                    checkpoint_dir=None,
                    checkpoint_filename_with_path=None,
                    max_wait_secs=7200):
  """Returns a session creator built from the correct master target and session config."""
  if config:
    session_config = copy.deepcopy(config)
    session_config.MergeFrom(self._session_config)
  else:
    session_config = self._session_config

  # Create the session creator that matches this node's role.
  if not self._strategy or self._strategy.extended.experimental_should_init:
    logging.info("Creating chief session creator with config: %r", config)
    return monitored_session.ChiefSessionCreator(
        scaffold,
        master=self.master_target,
        config=session_config,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path)
  else:
    logging.info("Creating worker session creator with config: %r", config)
    return monitored_session.WorkerSessionCreator(
        scaffold,
        master=self.master_target,
        config=session_config,
        max_wait_secs=max_wait_secs)


# ChiefSessionCreator
@tf_export(v1=['train.ChiefSessionCreator'])
class ChiefSessionCreator(SessionCreator):
  """Creates a tf.compat.v1.Session for a chief."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               checkpoint_dir=None,
               checkpoint_filename_with_path=None):
    self._checkpoint_dir = checkpoint_dir
    self._checkpoint_filename_with_path = checkpoint_filename_with_path
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config

  def _get_session_manager(self):
    if self._session_manager:
      return self._session_manager

    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    self._scaffold.finalize()
    return self._get_session_manager().prepare_session(
        self._master,
        saver=self._scaffold.saver,
        checkpoint_dir=self._checkpoint_dir,
        checkpoint_filename_with_path=self._checkpoint_filename_with_path,
        config=self._config,
        init_op=self._scaffold.init_op,
        init_feed_dict=self._scaffold.init_feed_dict,
        init_fn=self._scaffold.init_fn)


# WorkerSessionCreator
@tf_export(v1=['train.WorkerSessionCreator'])
class WorkerSessionCreator(SessionCreator):
  """Creates a tf.compat.v1.Session for a worker."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               max_wait_secs=30 * 60):
    """Initializes a worker session creator.

    Args:
      max_wait_secs: Maximum time to wait for the session to become available.
    """
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config
    self._max_wait_secs = max_wait_secs

  def _get_session_manager(self):
    if self._session_manager:
      return self._session_manager

    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    self._scaffold.finalize()
    return self._get_session_manager().wait_for_session(
        self._master, config=self._config, max_wait_secs=self._max_wait_secs)
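The config handling at the top of session_creator (deep-copy the user's config, then merge the strategy's config into it) can be illustrated with plain dictionaries. This is only a rough stand-in under a stated assumption: a dict `update` mimics how fields that are set in the merged-in message replace scalar fields in the copy, and it ignores protobuf details such as repeated fields being appended. `merge_session_config` is a hypothetical name.

```python
import copy


def merge_session_config(user_config, strategy_config):
    """Dict-based stand-in for copy.deepcopy(config) followed by MergeFrom.

    Hypothetical helper; real ConfigProto merging has more rules than this.
    """
    if user_config:
        merged = copy.deepcopy(user_config)  # never mutate the caller's object
        merged.update(strategy_config)       # fields set by the strategy win
        return merged
    # No user config: use the strategy's config as-is.
    return strategy_config


if __name__ == "__main__":
    user = {"allow_soft_placement": True, "intra_op_parallelism_threads": 2}
    strategy = {"intra_op_parallelism_threads": 8}
    print(merge_session_config(user, strategy))
    print(user)  # unchanged
```

The deep copy matters because the same config object may be reused by the caller after the session has been created.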


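From the source above we can see that MonitoredTrainingSession creates a different kind of session depending on the node's role. On the chief node the session is created by the ChiefSessionCreator class; on non-chief worker nodes it is created by the WorkerSessionCreator class. What makes the worker special is that it calls wait_for_session(): roughly speaking, it waits until the chief's session has been created before creating a session of its own. In both cases the session is created by a method of the SessionManager class, which we look at in detail next.


The official documentation has a simple example for SessionManager that makes this quite clear: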


# prepare_session can initialize or restore a model; it needs `init_op` and `saver`.
with tf.Graph().as_default():
  # ...add operations to the graph...
  # Create a SessionManager that will checkpoint the model in '/tmp/mydir'.
  sm = SessionManager()
  sess = sm.prepare_session(master, init_op, saver, checkpoint_dir)
  # Use the session to train the graph.
  while True:
    sess.run(<my_train_op>)

A second process can start running ops as shown below. wait_for_session() means that it waits until the session above has been created and only then creates its own session:

with tf.Graph().as_default():
  # ...add operations to the graph...
  # Create a SessionManager that will wait for the model to become ready.
  sm = SessionManager()
  sess = sm.wait_for_session(master)
  # Use the session to train the graph.
  while True:
    sess.run(<my_train_op>)
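The "restore if a checkpoint exists, otherwise initialize" behavior of prepare_session can be sketched without TensorFlow. In the toy below a dict plays the role of the model state and `prepare_state` is a hypothetical helper; it mirrors only the decision, not the real session handling.

```python
def prepare_state(checkpoint, init_fn=None):
    """Toy version of the restore-or-initialize decision.

    checkpoint: a dict of saved values, or None if no checkpoint exists.
    init_fn: callable returning freshly initialized values.
    Returns (state, restored), where restored tells which path was taken.
    """
    if checkpoint is not None:
        # A checkpoint exists: restore from it and skip initialization.
        return dict(checkpoint), True
    if init_fn is None:
        raise RuntimeError("Model is not initialized and no init_fn was given")
    # No checkpoint: fall back to running the initializer.
    return init_fn(), False


if __name__ == "__main__":
    print(prepare_state({"w": 1.5}))
    print(prepare_state(None, init_fn=lambda: {"w": 0.0}))
```

This is why restarting a chief with the same checkpoint directory resumes training instead of starting from scratch.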


Next, let's focus on the two functions prepare_session and wait_for_session:


@tf_export(v1=["train.SessionManager"])
class SessionManager(object):

  def __init__(self,
               local_init_op=None,
               ready_op=None,
               ready_for_local_init_op=None,
               graph=None,
               recovery_wait_secs=30,
               local_init_run_options=None):
    """
    local_init_op: an op that is run every time a new session is created.
    ready_op: an op that checks whether the model is ready.
    ready_for_local_init_op: an op that checks whether the model is ready
      to run local_init_op.
    """
    # Sets default values of arguments.
    if graph is None:
      graph = ops.get_default_graph()
    self._local_init_op = local_init_op
    self._ready_op = ready_op
    self._ready_for_local_init_op = ready_for_local_init_op
    self._graph = graph
    self._recovery_wait_secs = recovery_wait_secs
    self._target = None
    self._local_init_run_options = local_init_run_options
    if ready_for_local_init_op is not None and local_init_op is None:
      raise ValueError("If you pass a ready_for_local_init_op "
                       "you must also pass a local_init_op "
                       ", ready_for_local_init_op [%s]" %
                       ready_for_local_init_op)

  def prepare_session(self,
                      master,
                      init_op=None,
                      saver=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      wait_for_checkpoint=False,
                      max_wait_secs=7200,
                      config=None,
                      init_feed_dict=None,
                      init_fn=None):
    """
    If a checkpoint exists, the session is restored from it. If there is no
    checkpoint, the session is initialized by running the given `init_op`
    and calling `init_fn`.
    """
    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)
    if not is_loaded_from_checkpoint:
      if init_op is None and not init_fn and self._local_init_op is None:
        raise RuntimeError("Model is not initialized and no init_op or "
                           "init_fn or local_init_op was given")
      if init_op is not None:
        sess.run(init_op, feed_dict=init_feed_dict)
      if init_fn:
        init_fn(sess)
    # ..... (elided in this excerpt)
    return sess

  def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")):
    """Creates a new `Session` and waits for model to be ready."""
    self._target = master
    if max_wait_secs is None:
      max_wait_secs = float("Inf")
    timer = _CountDownTimer(max_wait_secs)

    while True:
      sess = session.Session(self._target, graph=self._graph, config=config)
      not_ready_msg = None
      not_ready_local_msg = None
      local_init_success, not_ready_local_msg = self._try_run_local_init_op(
          sess)
      if local_init_success:
        # Successful if local_init_op is None, or ready_for_local_init_op passes
        is_ready, not_ready_msg = self._model_ready(sess)
        if is_ready:
          return sess

      self._safe_close(sess)

      # Do we have enough time left to try again?
      remaining_ms_after_wait = (
          timer.secs_remaining() - self._recovery_wait_secs)
      if remaining_ms_after_wait < 0:
        raise errors.DeadlineExceededError(
            None, None,
            "Session was not ready after waiting %d secs." % (max_wait_secs,))

      logging.info("Waiting for model to be ready.  "
                   "Ready_for_local_init_op:  %s, ready: %s",
                   not_ready_local_msg, not_ready_msg)
      time.sleep(self._recovery_wait_secs)
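The loop in wait_for_session is a poll-with-deadline pattern: try, and if the model is not ready, sleep for recovery_wait_secs unless that sleep would run past max_wait_secs. Here is a minimal, TensorFlow-free sketch of the same pattern. `wait_until_ready` and its `clock` and `sleep` parameters are hypothetical; they are injected so the loop can be exercised without actually waiting.

```python
import time


def wait_until_ready(is_ready, max_wait_secs=float("inf"), retry_secs=30.0,
                     clock=time.monotonic, sleep=time.sleep):
    """Polls is_ready() until it returns True or the deadline would be missed.

    Returns the number of attempts made. Raises TimeoutError when another
    sleep of retry_secs would overshoot max_wait_secs.
    """
    deadline = clock() + max_wait_secs
    attempts = 0
    while True:
        attempts += 1
        if is_ready():
            return attempts
        # Same check as in the excerpt: is there time left for one more wait?
        if deadline - clock() - retry_secs < 0:
            raise TimeoutError("not ready after waiting %s secs" % max_wait_secs)
        sleep(retry_secs)


if __name__ == "__main__":
    state = {"calls": 0}

    def ready_on_second_call():
        state["calls"] += 1
        return state["calls"] >= 2

    # A tiny retry interval keeps this demo fast.
    print(wait_until_ready(ready_on_second_call, max_wait_secs=1.0, retry_secs=0.01))
```

Because the check subtracts the retry interval before sleeping, the loop gives up early rather than sleeping through the deadline.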


Once the session has been created, it is wrapped and returned as the final MonitoredSession object.


At creation time, a complete monitored session does the following, in order:


  • Calls hook.begin() for each hook

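  • Finalizes the graph by calling scaffold.finalize()

  • Creates the session

  • Initializes the model parameters through the Scaffold

  • Restores the parameters from a checkpoint if one exists

  • Launches the queue runners

  • Calls hook.after_create_session()

When run() is called, the monitored session does the following:

  • Calls hook.before_run()

  • Calls TensorFlow's session.run() with the merged fetches and feed_dict

  • Calls hook.after_run()

  • Returns the result of session.run()

  • If an AbortedError or UnavailableError occurs, recovers or reinitializes the session before executing run() again

When close() is called, the monitored session does the following:

  • Calls hook.end()

  • Closes the queue runners and the session

  • Suppresses the OutOfRange error that signals all input data has been consumed (when the monitored session is used as a context manager)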
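The callback order in the three lists above can be reproduced with a small toy. This is a simplified model, not the TensorFlow implementation: `RecordingHook` and `ToyMonitoredSession` are hypothetical classes, the callbacks take no arguments here (the real SessionRunHook methods receive session and run-context objects), and error recovery is left out.

```python
class RecordingHook:
    """Hypothetical hook that records the order of its lifecycle callbacks."""

    def __init__(self):
        self.events = []

    def begin(self):
        self.events.append("begin")

    def after_create_session(self):
        self.events.append("after_create_session")

    def before_run(self):
        self.events.append("before_run")

    def after_run(self):
        self.events.append("after_run")

    def end(self):
        self.events.append("end")


class ToyMonitoredSession:
    """Toy wrapper that calls hooks in the documented order."""

    def __init__(self, run_fn, hooks):
        self._run_fn = run_fn
        self._hooks = list(hooks)
        for hook in self._hooks:
            hook.begin()
        # Graph finalization, session creation and init/restore would go here.
        for hook in self._hooks:
            hook.after_create_session()

    def run(self, fetches):
        for hook in self._hooks:
            hook.before_run()
        result = self._run_fn(fetches)
        for hook in self._hooks:
            hook.after_run()
        return result

    def close(self):
        for hook in self._hooks:
            hook.end()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # do not swallow exceptions in this toy


if __name__ == "__main__":
    hook = RecordingHook()
    with ToyMonitoredSession(lambda x: x * 2, [hook]) as sess:
        sess.run(21)
    print(hook.events)
```

Every call to run() is bracketed by before_run() and after_run(), which is what lets hooks such as a step counter or a checkpoint saver piggyback on the training loop.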

Finally, here is an example of distributed training that uses MonitoredTrainingSession:


from __future__ import print_function, absolute_import, division

import tensorflow as tf

tf.app.flags.DEFINE_string("ps_hosts", "localhost:2222", "ps hosts")
tf.app.flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224", "worker hosts")
tf.app.flags.DEFINE_string("job_name", "worker", "'ps' or 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
tf.app.flags.DEFINE_integer("num_workers", 2, "Number of workers")
tf.app.flags.DEFINE_boolean("is_sync", False, "using synchronous training or not")

FLAGS = tf.app.flags.FLAGS


def model(images):
    """Define a simple mnist classifier"""
    net = tf.layers.dense(images, 500, activation=tf.nn.relu)
    net = tf.layers.dense(net, 500, activation=tf.nn.relu)
    net = tf.layers.dense(net, 10, activation=None)
    return net


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32')
x_test = x_test.reshape(-1, 784).astype('float32')
x_train /= 255
x_test /= 255


def get_batch(image, label, batch_size=32, training=True):
    df = tf.data.Dataset.from_tensor_slices((image, label))
    if training:
        df = df.repeat(10).shuffle(buffer_size=1000)
    df = df.batch(batch_size).prefetch(batch_size)
    iterator = df.make_one_shot_iterator()
    batch_x, batch_y = iterator.get_next()
    return batch_x, batch_y


def main(_):
    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")

    # create the cluster configured by 'ps_hosts' and 'worker_hosts'
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    # create a server for local task
    server = tf.train.Server(cluster, job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)

    train_batch_x, train_batch_y = get_batch(x_train, y_train)
    test_batch_x, test_batch_y = get_batch(x_test, y_test, training=False)

    if FLAGS.job_name == "ps":
        server.join()  # ps hosts only join
    elif FLAGS.job_name == "worker":
        # workers perform the operation
        # ps_strategy = tf.contrib.training.GreedyLoadBalancingStrategy(FLAGS.num_ps)

        # Note: tf.train.replica_device_setter automatically places the parameters
        # (Variables) on the ps hosts (default placement strategy: round-robin over
        # all ps hosts), and also places copies of the operations on each worker host.
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % FLAGS.task_index,
                cluster=cluster)):

            logits = model(train_batch_x)
            loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(
                    logits=logits, labels=tf.one_hot(train_batch_y, 10)))

            # The StopAtStepHook handles stopping after running given steps.
            hooks = [tf.train.StopAtStepHook(last_step=10000)]

            global_step = tf.train.get_or_create_global_step()
            optimizer = tf.train.AdamOptimizer(learning_rate=1e-04)

            if FLAGS.is_sync:
                # synchronous training
                # use tf.train.SyncReplicasOptimizer to wrap the optimizer
                # ref: https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer
                optimizer = tf.train.SyncReplicasOptimizer(
                    optimizer,
                    replicas_to_aggregate=FLAGS.num_workers,
                    total_num_replicas=FLAGS.num_workers)
                # create the hook which handles initialization and queues
                hooks.append(optimizer.make_session_run_hook((FLAGS.task_index == 0)))

            train_op = optimizer.minimize(loss, global_step=global_step)

        # The MonitoredTrainingSession takes care of session initialization,
        # restoring from a checkpoint, saving to a checkpoint, and closing when done
        # or an error occurs.
        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=(FLAGS.task_index == 0),
                                               checkpoint_dir="./checkpoint_dir",
                                               hooks=hooks) as mon_sess:
            while not mon_sess.should_stop():
                # mon_sess.run handles AbortedError in case of preempted PS.
                _, ls, step = mon_sess.run([train_op, loss, global_step])
                if step % 100 == 0:
                    print("Train step %d, loss: %f" % (step, ls))


if __name__ == "__main__":
    tf.app.run()




This article is reposted from Alex-zhai's Zhihu account.


Original link: https://zhuanlan.zhihu.com/p/91608555

