MonitoredTrainingSession 是 TensorFlow 中管理分布式训练时使用很广泛的一个 API,集成了一些用于监控训练的组件,如变量的初始化、从已有 checkpoint 恢复训练、summary、log 和 checkpoint 的保存等。在早期的 TF 版本中,一般使用 tf.train.Supervisor 来管理 session;框架升级后,官方推荐使用 MonitoredTrainingSession。MonitoredTrainingSession 具有记录日志、训练可视化、checkpoint 保存、early-stop、训练效率调优等功能。
我们直接进入主题。下面是 MonitoredTrainingSession 的源码,从注释中可了解到,MonitoredTrainingSession 的作用可用一句话来概括:如果是 chief 节点,则负责 session 的初始化或者从已有 checkpoint 恢复 session,并且创建一些用于保存 checkpoint 和 summary 的 hooks;如果是非 chief 的 worker 节点,则需要等待 chief 节点完成初始化或恢复 session 这些操作后,才能创建属于自己的 session。
@tf_export(v1=['train.MonitoredTrainingSession'])
def MonitoredTrainingSession(
        master='',  # pylint: disable=invalid-name
        is_chief=True,
        checkpoint_dir=None,
        scaffold=None,
        hooks=None,
        chief_only_hooks=None,
        save_checkpoint_secs=USE_DEFAULT,
        save_summaries_steps=USE_DEFAULT,
        save_summaries_secs=USE_DEFAULT,
        config=None,
        stop_grace_period_secs=120,
        log_step_count_steps=100,
        max_wait_secs=7200,
        save_checkpoint_steps=USE_DEFAULT,
        summary_dir=None):
    """Creates a `MonitoredSession` for training.

    Abridged excerpt of the TensorFlow source. There are three paths:

    * Inside a distribute-coordinator worker context, creation is delegated
      to `_create_monitored_session_with_worker_context`.
    * On a non-chief worker (`is_chief=False`) a `WorkerSessionCreator` is
      used, which waits until the chief has initialized or restored the
      model; only the caller's `hooks` are attached.
    * On the chief, a chief session creator is built and the default hooks
      are collected together with the caller's `hooks`.

    Returns:
      A `MonitoredSession` object.
    """
    scaffold = scaffold or Scaffold()
    worker_context = distribute_coordinator_context.get_current_worker_context()
    if worker_context:
        return _create_monitored_session_with_worker_context(
            worker_context,
            scaffold,
            checkpoint_dir=checkpoint_dir,
            hooks=hooks,
            chief_only_hooks=chief_only_hooks,
            save_checkpoint_secs=save_checkpoint_secs,
            save_summaries_steps=save_summaries_steps,
            save_summaries_secs=save_summaries_secs,
            config=config,
            stop_grace_period_secs=stop_grace_period_secs,
            log_step_count_steps=log_step_count_steps,
            max_wait_secs=max_wait_secs,
            save_checkpoint_steps=save_checkpoint_steps,
            summary_dir=summary_dir)

    if not is_chief:
        # Non-chief workers wait for the chief instead of initializing.
        session_creator = WorkerSessionCreator(
            scaffold=scaffold,
            master=master,
            config=config,
            max_wait_secs=max_wait_secs)
        return MonitoredSession(
            session_creator=session_creator,
            hooks=hooks or [],
            stop_grace_period_secs=stop_grace_period_secs)

    all_hooks = []
    # ... (elided in this excerpt) Here the chief builds
    # `session_creator = ChiefSessionCreator(...)` and appends its hooks
    # (chief-only hooks plus the default checkpoint/summary hooks) to
    # `all_hooks`. The `session_creator` used below is defined in that part.
    if hooks:
        all_hooks.extend(hooks)
    return MonitoredSession(
        session_creator=session_creator,
        hooks=all_hooks,
        stop_grace_period_secs=stop_grace_period_secs)
我们首先解释下参数:
is_chief:用于分布式训练中,判断当前节点是否为 chief。如果为 True,它将负责初始化或恢复底层的 TensorFlow 会话;如果为 False,它将等待 chief 完成初始化或恢复 TensorFlow 会话。
checkpoint_dir:一个字符串,可选。指定用于恢复变量的 checkpoint 目录路径(注意是目录,而不是单个 checkpoint 文件);默认的 checkpoint 也会保存到该目录。
scaffold:用于收集或构建辅助 op 的脚手架(Scaffold)。如果未指定,则会创建一个默认的 scaffold。它用于完成(finalize)计算图的构建。
hooks:SessionRunHook 对象的可选列表。可自己定义 SessionRunHook 对象,也可使用预定义好的 SessionRunHook,如:tf.train.StopAtStepHook() 用于设置停止训练的条件;tf.train.NanTensorHook(loss) 在 loss 的值为 NaN 时停止训练。
chief_only_hooks:SessionRunHook 对象列表。如果 is_chief == True,则激活这些 hook,否则忽略。
save_checkpoint_secs:使用默认 checkpoint saver 保存 checkpoint 的频率(以秒为单位)。如果 save_checkpoint_secs 和 save_checkpoint_steps 都设置为 None,则不使用默认的 checkpoint saver 保存 checkpoint。默认为 600 秒。
save_summaries_steps:使用默认 summaries saver 将摘要写入磁盘的频率(以全局步数表示)。如果 save_summaries_steps 和 save_summaries_secs 都设置为 None,则不使用默认的 summaries saver 保存 summaries。默认为 100。
save_summaries_secs:使用默认 summaries saver 将摘要写入磁盘的频率(以秒为单位)。如果 save_summaries_steps 和 save_summaries_secs 都设置为 None,则不使用默认的 summaries saver 保存摘要。默认未启用。
config:用于配置会话的 tf.ConfigProto 实例,即 tf.Session 构造函数的 config 参数。
stop_grace_period_secs:调用 close() 后,等待线程停止的宽限时间(秒)。
log_step_count_steps:记录 global_step/sec(每秒全局步数)日志的频率,以全局步数计。
调用该函数后会返回一个 MonitoredSession 对象,可以像普通 session 一样使用。
然后我们仔细分解下代码:
def _create_monitored_session_with_worker_context(
        worker_context,
        scaffold,
        checkpoint_dir=None,
        hooks=None,
        chief_only_hooks=None,
        save_checkpoint_secs=None,
        save_summaries_steps=None,
        save_summaries_secs=None,
        config=None,
        stop_grace_period_secs=120,
        log_step_count_steps=100,
        max_wait_secs=7200,
        save_checkpoint_steps=None,
        summary_dir=None):
    """Builds a `MonitoredSession` for a distribute-coordinator worker.

    Abridged excerpt. The session creator is obtained from
    `worker_context.session_creator`, which picks a chief or a worker session
    creator depending on this worker's role.
    """
    all_hooks = []
    # ... (elided in this excerpt) The hooks for this worker are collected
    # into `all_hooks` here.
    logging.info('all_hooks %r', all_hooks)
    # Create the session creator for this worker.
    session_creator = worker_context.session_creator(
        scaffold,
        config=config,
        checkpoint_dir=checkpoint_dir,
        max_wait_secs=max_wait_secs)
    return MonitoredSession(
        session_creator=session_creator,
        hooks=all_hooks,
        stop_grace_period_secs=stop_grace_period_secs)
# session_creator 函数主体
def session_creator(self,
                    scaffold=None,
                    config=None,
                    checkpoint_dir=None,
                    checkpoint_filename_with_path=None,
                    max_wait_secs=7200):
    """Returns a session creator for the right master target and config.

    If `config` is given it is deep-copied and this context's own session
    config is merged into the copy, so the caller's proto is not mutated.

    When there is no strategy, or the strategy says this worker should
    initialize (`experimental_should_init`), a `ChiefSessionCreator` is
    returned. Otherwise a `WorkerSessionCreator` is returned, which waits for
    the model to become ready.
    """
    if config:
        session_config = copy.deepcopy(config)
        session_config.MergeFrom(self._session_config)
    else:
        session_config = self._session_config

    # Choose the creator according to this worker's role.
    if not self._strategy or self._strategy.extended.experimental_should_init:
        logging.info("Creating chief session creator with config: %r", config)
        return monitored_session.ChiefSessionCreator(
            scaffold,
            master=self.master_target,
            config=session_config,
            checkpoint_dir=checkpoint_dir,
            checkpoint_filename_with_path=checkpoint_filename_with_path)
    else:
        logging.info("Creating worker session creator with config: %r", config)
        return monitored_session.WorkerSessionCreator(
            scaffold,
            master=self.master_target,
            config=session_config,
            max_wait_secs=max_wait_secs)
# ChiefSessionCreator
@tf_export(v1=['train.ChiefSessionCreator'])
class ChiefSessionCreator(SessionCreator):
    """Creates a tf.compat.v1.Session for a chief."""

    def __init__(self,
                 scaffold=None,
                 master='',
                 config=None,
                 checkpoint_dir=None,
                 checkpoint_filename_with_path=None):
        """Initializes a chief session creator.

        Args:
          scaffold: A `Scaffold`; a default one is created when None.
          master: Session target passed to `prepare_session`.
          config: Session config passed to `prepare_session`.
          checkpoint_dir: Directory from which a checkpoint may be restored.
          checkpoint_filename_with_path: Explicit checkpoint file to restore.
        """
        self._checkpoint_dir = checkpoint_dir
        self._checkpoint_filename_with_path = checkpoint_filename_with_path
        self._scaffold = scaffold or Scaffold()
        self._session_manager = None
        self._master = master
        self._config = config

    def _get_session_manager(self):
        # Build the SessionManager lazily, wired to the scaffold's ops, and
        # cache it for later calls.
        if self._session_manager:
            return self._session_manager

        self._session_manager = sm.SessionManager(
            local_init_op=self._scaffold.local_init_op,
            ready_op=self._scaffold.ready_op,
            ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
            graph=ops.get_default_graph())
        return self._session_manager

    def create_session(self):
        # Finalize the graph, then restore from a checkpoint or run the
        # scaffold's init op / init_fn (see SessionManager.prepare_session).
        self._scaffold.finalize()
        return self._get_session_manager().prepare_session(
            self._master,
            saver=self._scaffold.saver,
            checkpoint_dir=self._checkpoint_dir,
            checkpoint_filename_with_path=self._checkpoint_filename_with_path,
            config=self._config,
            init_op=self._scaffold.init_op,
            init_feed_dict=self._scaffold.init_feed_dict,
            init_fn=self._scaffold.init_fn)
# WorkerSessionCreator
@tf_export(v1=['train.WorkerSessionCreator'])
class WorkerSessionCreator(SessionCreator):
    """Creates a tf.compat.v1.Session for a worker."""

    def __init__(self,
                 scaffold=None,
                 master='',
                 config=None,
                 max_wait_secs=30 * 60):
        """Initializes a worker session creator.

        Args:
          scaffold: A `Scaffold`; a default one is created when None.
          master: Session target passed to `wait_for_session`.
          config: Session config passed to `wait_for_session`.
          max_wait_secs: Maximum time to wait for the session to become
            available.
        """
        self._scaffold = scaffold or Scaffold()
        self._session_manager = None
        self._master = master
        self._config = config
        self._max_wait_secs = max_wait_secs

    def _get_session_manager(self):
        # Build the SessionManager lazily, wired to the scaffold's ops, and
        # cache it for later calls.
        if self._session_manager:
            return self._session_manager

        self._session_manager = sm.SessionManager(
            local_init_op=self._scaffold.local_init_op,
            ready_op=self._scaffold.ready_op,
            ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
            graph=ops.get_default_graph())
        return self._session_manager

    def create_session(self):
        # Unlike the chief, a worker does not initialize the model: it
        # blocks in wait_for_session until the model is ready.
        self._scaffold.finalize()
        return self._get_session_manager().wait_for_session(
            self._master, config=self._config, max_wait_secs=self._max_wait_secs)
从上面的源码中分析得到,MonitoredTrainingSession 可根据不同的角色去创建不同种类的 session:chief 节点由 ChiefSessionCreator 类创建 session,而非 chief 的 worker 节点由 WorkerSessionCreator 类创建。后者的特殊之处在于创建时调用的是 wait_for_session(),大致意思是需要等待 chief 节点完成模型的初始化(或恢复)之后,才去创建属于自己节点的 session。这两个类真正创建 session 时调用的 prepare_session() 和 wait_for_session() 都是 SessionManager 类的方法,下面我们具体分析一下 SessionManager 类:
官方针对 SessionManager 类有一个简单的例子,感觉很清楚:
# prepare_session() can either initialize or restore a model; it needs
# `init_op` and `saver`.
with tf.Graph().as_default():
    # ...add operations to the graph...
    # Create a SessionManager. prepare_session() restores the model from
    # `checkpoint_dir` if a checkpoint exists, otherwise it runs `init_op`.
    sm = SessionManager()
    sess = sm.prepare_session(master, init_op, saver, checkpoint_dir)
    # Use the session to train the graph.
    while True:
        sess.run(my_train_op)  # `my_train_op` stands for your training op.
第二个进程可以用以下方式等待模型就绪:wait_for_session() 的意思是需要等上面那个进程的 session 完成模型初始化之后,再创建自己的 session。
with tf.Graph().as_default():
    # ...add operations to the graph...
    # Create a SessionManager that will wait for the model to become ready.
    sm = SessionManager()
    sess = sm.wait_for_session(master)
    # Use the session to train the graph.
    while True:
        sess.run(my_train_op)  # `my_train_op` stands for your training op.
然后我们可以重点关注下 prepare_session 和 wait_for_session 这两个函数:
@tf_export(v1=["train.SessionManager"])
class SessionManager(object):
    """Creates sessions and initializes or restores the model.

    Abridged excerpt: helper methods such as `_restore_checkpoint`,
    `_try_run_local_init_op`, `_model_ready` and `_safe_close` are not shown.
    """

    def __init__(self,
                 local_init_op=None,
                 ready_op=None,
                 ready_for_local_init_op=None,
                 graph=None,
                 recovery_wait_secs=30,
                 local_init_run_options=None):
        """Creates a SessionManager.

        Args:
          local_init_op: Op run every time a new session is created.
          ready_op: Op used to check whether the model is ready.
          ready_for_local_init_op: Op used to check whether the model is
            ready to run `local_init_op`.
          graph: Graph to use; defaults to the current default graph.
          recovery_wait_secs: Seconds slept between readiness checks in
            `wait_for_session`.
          local_init_run_options: Stored for running `local_init_op` (its use
            is not shown in this excerpt).

        Raises:
          ValueError: If `ready_for_local_init_op` is given without
            `local_init_op`.
        """
        # Sets default values of arguments.
        if graph is None:
            graph = ops.get_default_graph()
        self._local_init_op = local_init_op
        self._ready_op = ready_op
        self._ready_for_local_init_op = ready_for_local_init_op
        self._graph = graph
        self._recovery_wait_secs = recovery_wait_secs
        self._target = None
        self._local_init_run_options = local_init_run_options
        if ready_for_local_init_op is not None and local_init_op is None:
            raise ValueError("If you pass a ready_for_local_init_op "
                             "you must also pass a local_init_op "
                             ", ready_for_local_init_op [%s]" %
                             ready_for_local_init_op)

    def prepare_session(self,
                        master,
                        init_op=None,
                        saver=None,
                        checkpoint_dir=None,
                        checkpoint_filename_with_path=None,
                        wait_for_checkpoint=False,
                        max_wait_secs=7200,
                        config=None,
                        init_feed_dict=None,
                        init_fn=None):
        """Creates a session, restoring from a checkpoint when one exists.

        If no checkpoint was restored, the model is initialized by running
        `init_op` (fed with `init_feed_dict`) and/or calling `init_fn(sess)`.

        Raises:
          RuntimeError: If nothing was restored and none of `init_op`,
            `init_fn` or `local_init_op` was provided.
        """
        sess, is_loaded_from_checkpoint = self._restore_checkpoint(
            master,
            saver,
            checkpoint_dir=checkpoint_dir,
            checkpoint_filename_with_path=checkpoint_filename_with_path,
            wait_for_checkpoint=wait_for_checkpoint,
            max_wait_secs=max_wait_secs,
            config=config)
        if not is_loaded_from_checkpoint:
            if init_op is None and not init_fn and self._local_init_op is None:
                raise RuntimeError("Model is not initialized and no init_op or "
                                   "init_fn or local_init_op was given")
            if init_op is not None:
                sess.run(init_op, feed_dict=init_feed_dict)
            if init_fn:
                init_fn(sess)

        # ... (elided in this excerpt) Upstream then runs local_init_op and
        # checks that the model is ready before returning the session.
        return sess

    def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")):
        """Creates a new `Session` and waits for model to be ready.

        Repeatedly opens a session, tries `local_init_op` and checks
        `ready_op`. A session that is not ready is closed, and a new one is
        tried after `recovery_wait_secs`. `max_wait_secs=None` means wait
        forever.

        Raises:
          DeadlineExceededError: If the model is not ready within
            `max_wait_secs`.
        """
        self._target = master

        if max_wait_secs is None:
            max_wait_secs = float("Inf")
        timer = _CountDownTimer(max_wait_secs)

        while True:
            sess = session.Session(self._target, graph=self._graph, config=config)
            not_ready_msg = None
            not_ready_local_msg = None
            local_init_success, not_ready_local_msg = self._try_run_local_init_op(
                sess)
            if local_init_success:
                # Successful if local_init_op is None, or ready_for_local_init_op passes
                is_ready, not_ready_msg = self._model_ready(sess)
                if is_ready:
                    return sess

            self._safe_close(sess)

            # Do we have enough time left to try again?
            # (The value is in seconds despite the `_ms` in its name.)
            remaining_ms_after_wait = (
                timer.secs_remaining() - self._recovery_wait_secs)
            if remaining_ms_after_wait < 0:
                raise errors.DeadlineExceededError(
                    None, None,
                    "Session was not ready after waiting %d secs." % (max_wait_secs,))

            logging.info("Waiting for model to be ready. "
                         "Ready_for_local_init_op: %s, ready: %s",
                         not_ready_local_msg, not_ready_msg)
            time.sleep(self._recovery_wait_secs)
创建完 session 之后,再包装一下返回最终的 MonitoredSession 类,
一个完整的 monitored session 在创建时间内可做的事情(按顺序):
我们直接进入主题。下面是 MonitoredTrainingSession 的源码,从注释中可了解到,MonitoredTrainingSession 的作用可用一句话来概括:如果是 chief 节点,则负责 session 的初始化或者从已有 checkpoint 恢复 session,并且创建一些用于保存 checkpoint 和 summary 的 hooks;如果是非 chief 的 worker 节点,则需要等待 chief 节点完成初始化或恢复 session 这些操作后,才能创建属于自己的 session。
@tf_export(v1=['train.MonitoredTrainingSession'])
def MonitoredTrainingSession(
        master='',  # pylint: disable=invalid-name
        is_chief=True,
        checkpoint_dir=None,
        scaffold=None,
        hooks=None,
        chief_only_hooks=None,
        save_checkpoint_secs=USE_DEFAULT,
        save_summaries_steps=USE_DEFAULT,
        save_summaries_secs=USE_DEFAULT,
        config=None,
        stop_grace_period_secs=120,
        log_step_count_steps=100,
        max_wait_secs=7200,
        save_checkpoint_steps=USE_DEFAULT,
        summary_dir=None):
    """Creates a `MonitoredSession` for training.

    Abridged excerpt of the TensorFlow source. There are three paths:

    * Inside a distribute-coordinator worker context, creation is delegated
      to `_create_monitored_session_with_worker_context`.
    * On a non-chief worker (`is_chief=False`) a `WorkerSessionCreator` is
      used, which waits until the chief has initialized or restored the
      model; only the caller's `hooks` are attached.
    * On the chief, a chief session creator is built and the default hooks
      are collected together with the caller's `hooks`.

    Returns:
      A `MonitoredSession` object.
    """
    scaffold = scaffold or Scaffold()
    worker_context = distribute_coordinator_context.get_current_worker_context()
    if worker_context:
        return _create_monitored_session_with_worker_context(
            worker_context,
            scaffold,
            checkpoint_dir=checkpoint_dir,
            hooks=hooks,
            chief_only_hooks=chief_only_hooks,
            save_checkpoint_secs=save_checkpoint_secs,
            save_summaries_steps=save_summaries_steps,
            save_summaries_secs=save_summaries_secs,
            config=config,
            stop_grace_period_secs=stop_grace_period_secs,
            log_step_count_steps=log_step_count_steps,
            max_wait_secs=max_wait_secs,
            save_checkpoint_steps=save_checkpoint_steps,
            summary_dir=summary_dir)

    if not is_chief:
        # Non-chief workers wait for the chief instead of initializing.
        session_creator = WorkerSessionCreator(
            scaffold=scaffold,
            master=master,
            config=config,
            max_wait_secs=max_wait_secs)
        return MonitoredSession(
            session_creator=session_creator,
            hooks=hooks or [],
            stop_grace_period_secs=stop_grace_period_secs)

    all_hooks = []
    # ... (elided in this excerpt) Here the chief builds
    # `session_creator = ChiefSessionCreator(...)` and appends its hooks
    # (chief-only hooks plus the default checkpoint/summary hooks) to
    # `all_hooks`. The `session_creator` used below is defined in that part.
    if hooks:
        all_hooks.extend(hooks)
    return MonitoredSession(
        session_creator=session_creator,
        hooks=all_hooks,
        stop_grace_period_secs=stop_grace_period_secs)
我们首先解释下参数:
is_chief:用于分布式训练中,判断当前节点是否为 chief。如果为 True,它将负责初始化或恢复底层的 TensorFlow 会话;如果为 False,它将等待 chief 完成初始化或恢复 TensorFlow 会话。
checkpoint_dir:一个字符串,可选。指定用于恢复变量的 checkpoint 目录路径(注意是目录,而不是单个 checkpoint 文件);默认的 checkpoint 也会保存到该目录。
scaffold:用于收集或构建辅助 op 的脚手架(Scaffold)。如果未指定,则会创建一个默认的 scaffold。它用于完成(finalize)计算图的构建。
hooks:SessionRunHook 对象的可选列表。可自己定义 SessionRunHook 对象,也可使用预定义好的 SessionRunHook,如:tf.train.StopAtStepHook() 用于设置停止训练的条件;tf.train.NanTensorHook(loss) 在 loss 的值为 NaN 时停止训练。
chief_only_hooks:SessionRunHook 对象列表。如果 is_chief == True,则激活这些 hook,否则忽略。
save_checkpoint_secs:使用默认 checkpoint saver 保存 checkpoint 的频率(以秒为单位)。如果 save_checkpoint_secs 和 save_checkpoint_steps 都设置为 None,则不使用默认的 checkpoint saver 保存 checkpoint。默认为 600 秒。
save_summaries_steps:使用默认 summaries saver 将摘要写入磁盘的频率(以全局步数表示)。如果 save_summaries_steps 和 save_summaries_secs 都设置为 None,则不使用默认的 summaries saver 保存 summaries。默认为 100。
save_summaries_secs:使用默认 summaries saver 将摘要写入磁盘的频率(以秒为单位)。如果 save_summaries_steps 和 save_summaries_secs 都设置为 None,则不使用默认的 summaries saver 保存摘要。默认未启用。
config:用于配置会话的 tf.ConfigProto 实例,即 tf.Session 构造函数的 config 参数。
stop_grace_period_secs:调用 close() 后,等待线程停止的宽限时间(秒)。
log_step_count_steps:记录 global_step/sec(每秒全局步数)日志的频率,以全局步数计。
调用该函数后会返回一个 MonitoredSession 对象,可以像普通 session 一样使用。
然后我们仔细分解下代码:
def _create_monitored_session_with_worker_context(
        worker_context,
        scaffold,
        checkpoint_dir=None,
        hooks=None,
        chief_only_hooks=None,
        save_checkpoint_secs=None,
        save_summaries_steps=None,
        save_summaries_secs=None,
        config=None,
        stop_grace_period_secs=120,
        log_step_count_steps=100,
        max_wait_secs=7200,
        save_checkpoint_steps=None,
        summary_dir=None):
    """Builds a `MonitoredSession` for a distribute-coordinator worker.

    Abridged excerpt. The session creator is obtained from
    `worker_context.session_creator`, which picks a chief or a worker session
    creator depending on this worker's role.
    """
    all_hooks = []
    # ... (elided in this excerpt) The hooks for this worker are collected
    # into `all_hooks` here.
    logging.info('all_hooks %r', all_hooks)
    # Create the session creator for this worker.
    session_creator = worker_context.session_creator(
        scaffold,
        config=config,
        checkpoint_dir=checkpoint_dir,
        max_wait_secs=max_wait_secs)
    return MonitoredSession(
        session_creator=session_creator,
        hooks=all_hooks,
        stop_grace_period_secs=stop_grace_period_secs)
session_creator 函数主体
def session_creator(self,
                    scaffold=None,
                    config=None,
                    checkpoint_dir=None,
                    checkpoint_filename_with_path=None,
                    max_wait_secs=7200):
    """Returns a session creator for the right master target and config.

    If `config` is given it is deep-copied and this context's own session
    config is merged into the copy, so the caller's proto is not mutated.

    When there is no strategy, or the strategy says this worker should
    initialize (`experimental_should_init`), a `ChiefSessionCreator` is
    returned. Otherwise a `WorkerSessionCreator` is returned, which waits for
    the model to become ready.
    """
    if config:
        session_config = copy.deepcopy(config)
        session_config.MergeFrom(self._session_config)
    else:
        session_config = self._session_config

    # Create a chief or a worker session creator depending on the role.
    if not self._strategy or self._strategy.extended.experimental_should_init:
        logging.info("Creating chief session creator with config: %r", config)
        return monitored_session.ChiefSessionCreator(
            scaffold,
            master=self.master_target,
            config=session_config,
            checkpoint_dir=checkpoint_dir,
            checkpoint_filename_with_path=checkpoint_filename_with_path)
    else:
        logging.info("Creating worker session creator with config: %r", config)
        return monitored_session.WorkerSessionCreator(
            scaffold,
            master=self.master_target,
            config=session_config,
            max_wait_secs=max_wait_secs)
ChiefSessionCreator
@tf_export(v1=['train.ChiefSessionCreator'])
class ChiefSessionCreator(SessionCreator):
    """Creates a tf.compat.v1.Session for a chief."""

    def __init__(self,
                 scaffold=None,
                 master='',
                 config=None,
                 checkpoint_dir=None,
                 checkpoint_filename_with_path=None):
        """Initializes a chief session creator.

        Args:
          scaffold: A `Scaffold`; a default one is created when None.
          master: Session target passed to `prepare_session`.
          config: Session config passed to `prepare_session`.
          checkpoint_dir: Directory from which a checkpoint may be restored.
          checkpoint_filename_with_path: Explicit checkpoint file to restore.
        """
        self._checkpoint_dir = checkpoint_dir
        self._checkpoint_filename_with_path = checkpoint_filename_with_path
        self._scaffold = scaffold or Scaffold()
        self._session_manager = None
        self._master = master
        self._config = config

    def _get_session_manager(self):
        # Build the SessionManager lazily, wired to the scaffold's ops, and
        # cache it for later calls.
        if self._session_manager:
            return self._session_manager

        self._session_manager = sm.SessionManager(
            local_init_op=self._scaffold.local_init_op,
            ready_op=self._scaffold.ready_op,
            ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
            graph=ops.get_default_graph())
        return self._session_manager

    def create_session(self):
        # Finalize the graph, then restore from a checkpoint or run the
        # scaffold's init op / init_fn (see SessionManager.prepare_session).
        self._scaffold.finalize()
        return self._get_session_manager().prepare_session(
            self._master,
            saver=self._scaffold.saver,
            checkpoint_dir=self._checkpoint_dir,
            checkpoint_filename_with_path=self._checkpoint_filename_with_path,
            config=self._config,
            init_op=self._scaffold.init_op,
            init_feed_dict=self._scaffold.init_feed_dict,
            init_fn=self._scaffold.init_fn)
WorkerSessionCreator
@tf_export(v1=['train.WorkerSessionCreator'])
class WorkerSessionCreator(SessionCreator):
    """Creates a tf.compat.v1.Session for a worker."""

    def __init__(self,
                 scaffold=None,
                 master='',
                 config=None,
                 max_wait_secs=30 * 60):
        """Initializes a worker session creator.

        Args:
          scaffold: A `Scaffold`; a default one is created when None.
          master: Session target passed to `wait_for_session`.
          config: Session config passed to `wait_for_session`.
          max_wait_secs: Maximum time to wait for the session to become
            available.
        """
        self._scaffold = scaffold or Scaffold()
        self._session_manager = None
        self._master = master
        self._config = config
        self._max_wait_secs = max_wait_secs

    def _get_session_manager(self):
        # Build the SessionManager lazily, wired to the scaffold's ops, and
        # cache it for later calls.
        if self._session_manager:
            return self._session_manager

        self._session_manager = sm.SessionManager(
            local_init_op=self._scaffold.local_init_op,
            ready_op=self._scaffold.ready_op,
            ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
            graph=ops.get_default_graph())
        return self._session_manager

    def create_session(self):
        # Unlike the chief, a worker does not initialize the model: it
        # blocks in wait_for_session until the model is ready.
        self._scaffold.finalize()
        return self._get_session_manager().wait_for_session(
            self._master, config=self._config, max_wait_secs=self._max_wait_secs)
从上面的源码中分析得到,MonitoredTrainingSession 可根据不同的角色去创建不同种类的 session:chief 节点由 ChiefSessionCreator 类创建 session,而非 chief 的 worker 节点由 WorkerSessionCreator 类创建。后者的特殊之处在于创建时调用的是 wait_for_session(),大致意思是需要等待 chief 节点完成模型的初始化(或恢复)之后,才去创建属于自己节点的 session。这两个类真正创建 session 时调用的 prepare_session() 和 wait_for_session() 都是 SessionManager 类的方法,下面我们具体分析一下 SessionManager 类:
官方针对 SessionManager 类有一个简单的例子,感觉很清楚:
prepare_session 函数可以初始化或者 restore 一个模型,需要传入 `init_op` 和 `saver`:
with tf.Graph().as_default():
    # ...add operations to the graph...
    # Create a SessionManager. prepare_session() restores the model from
    # `checkpoint_dir` if a checkpoint exists, otherwise it runs `init_op`.
    sm = SessionManager()
    sess = sm.prepare_session(master, init_op, saver, checkpoint_dir)
    # Use the session to train the graph.
    while True:
        sess.run(my_train_op)  # `my_train_op` stands for your training op.
第二个进程可以用以下方式等待模型就绪:wait_for_session() 的意思是需要等上面那个进程的 session 完成模型初始化之后,再创建自己的 session。
with tf.Graph().as_default():
    # ...add operations to the graph...
    # Create a SessionManager that will wait for the model to become ready.
    sm = SessionManager()
    sess = sm.wait_for_session(master)
    # Use the session to train the graph.
    while True:
        sess.run(my_train_op)  # `my_train_op` stands for your training op.
然后我们可以重点关注下 prepare_session 和 wait_for_session 这两个函数:
@tf_export(v1=["train.SessionManager"])
class SessionManager(object):
    """Creates sessions and initializes or restores the model.

    Abridged excerpt: helper methods such as `_restore_checkpoint`,
    `_try_run_local_init_op`, `_model_ready` and `_safe_close` are not shown.
    """

    def __init__(self,
                 local_init_op=None,
                 ready_op=None,
                 ready_for_local_init_op=None,
                 graph=None,
                 recovery_wait_secs=30,
                 local_init_run_options=None):
        """Creates a SessionManager.

        Args:
          local_init_op: Op run every time a new session is created.
          ready_op: Op used to check whether the model is ready.
          ready_for_local_init_op: Op used to check whether the model is
            ready to run `local_init_op`.
          graph: Graph to use; defaults to the current default graph.
          recovery_wait_secs: Seconds slept between readiness checks in
            `wait_for_session`.
          local_init_run_options: Stored for running `local_init_op` (its use
            is not shown in this excerpt).

        Raises:
          ValueError: If `ready_for_local_init_op` is given without
            `local_init_op`.
        """
        # Sets default values of arguments.
        if graph is None:
            graph = ops.get_default_graph()
        self._local_init_op = local_init_op
        self._ready_op = ready_op
        self._ready_for_local_init_op = ready_for_local_init_op
        self._graph = graph
        self._recovery_wait_secs = recovery_wait_secs
        self._target = None
        self._local_init_run_options = local_init_run_options
        if ready_for_local_init_op is not None and local_init_op is None:
            raise ValueError("If you pass a ready_for_local_init_op "
                             "you must also pass a local_init_op "
                             ", ready_for_local_init_op [%s]" %
                             ready_for_local_init_op)

    def prepare_session(self,
                        master,
                        init_op=None,
                        saver=None,
                        checkpoint_dir=None,
                        checkpoint_filename_with_path=None,
                        wait_for_checkpoint=False,
                        max_wait_secs=7200,
                        config=None,
                        init_feed_dict=None,
                        init_fn=None):
        """Creates a session, restoring from a checkpoint when one exists.

        If no checkpoint was restored, the model is initialized by running
        `init_op` (fed with `init_feed_dict`) and/or calling `init_fn(sess)`.

        Raises:
          RuntimeError: If nothing was restored and none of `init_op`,
            `init_fn` or `local_init_op` was provided.
        """
        sess, is_loaded_from_checkpoint = self._restore_checkpoint(
            master,
            saver,
            checkpoint_dir=checkpoint_dir,
            checkpoint_filename_with_path=checkpoint_filename_with_path,
            wait_for_checkpoint=wait_for_checkpoint,
            max_wait_secs=max_wait_secs,
            config=config)
        if not is_loaded_from_checkpoint:
            if init_op is None and not init_fn and self._local_init_op is None:
                raise RuntimeError("Model is not initialized and no init_op or "
                                   "init_fn or local_init_op was given")
            if init_op is not None:
                sess.run(init_op, feed_dict=init_feed_dict)
            if init_fn:
                init_fn(sess)

        # ... (elided in this excerpt) Upstream then runs local_init_op and
        # checks that the model is ready before returning the session.
        return sess

    def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")):
        """Creates a new `Session` and waits for model to be ready.

        Repeatedly opens a session, tries `local_init_op` and checks
        `ready_op`. A session that is not ready is closed, and a new one is
        tried after `recovery_wait_secs`. `max_wait_secs=None` means wait
        forever.

        Raises:
          DeadlineExceededError: If the model is not ready within
            `max_wait_secs`.
        """
        self._target = master

        if max_wait_secs is None:
            max_wait_secs = float("Inf")
        timer = _CountDownTimer(max_wait_secs)

        while True:
            sess = session.Session(self._target, graph=self._graph, config=config)
            not_ready_msg = None
            not_ready_local_msg = None
            local_init_success, not_ready_local_msg = self._try_run_local_init_op(
                sess)
            if local_init_success:
                # Successful if local_init_op is None, or ready_for_local_init_op passes
                is_ready, not_ready_msg = self._model_ready(sess)
                if is_ready:
                    return sess

            self._safe_close(sess)

            # Do we have enough time left to try again?
            # (The value is in seconds despite the `_ms` in its name.)
            remaining_ms_after_wait = (
                timer.secs_remaining() - self._recovery_wait_secs)
            if remaining_ms_after_wait < 0:
                raise errors.DeadlineExceededError(
                    None, None,
                    "Session was not ready after waiting %d secs." % (max_wait_secs,))

            logging.info("Waiting for model to be ready. "
                         "Ready_for_local_init_op: %s, ready: %s",
                         not_ready_local_msg, not_ready_msg)
            time.sleep(self._recovery_wait_secs)
创建完 session 之后,再包装一下返回最终的 MonitoredSession 类,
一个完整的 monitored session 在创建时间内可做的事情(按顺序):
为每个 hook 调用 hook.begin()
调用 scaffold.finalize()完成 graph
创建 session
通过 Scaffold 初始化模型参数
如果存在 checkpoint,则从 checkpoint 恢复参数
启动 queue runners
调用 hook.after_create_session()函数
当 run 函数调用时,monitored session 做的事情:
调用 hook.before_run()
调用 TensorFlow 的 session.run(),传入合并后的 fetches 和 feed_dict(包含各 hook 在 before_run() 中请求的内容)
调用 hook.after_run()
返回 session.run()的结果
如果发生 AbortedError 或者 UnavailableError,则在再次执行 run()之前恢复或者重新初始化会话
当 close()函数调用时,monitored session 做的事情:
调用 hook.end()
关闭 queue runners 和 session
如果 monitored session 作为上下文管理器使用(with 语句),则所有输入数据被消耗完时产生的 OutOfRange 异常会被抑制,不会向外抛出。
最后,给大家贴一个使用 MonitoredSession 类进行分布式训练的 example:
from __future__ import print_function, absolute_import, division
import tensorflow as tf
# Command-line flags describing the cluster layout and this task's role.
tf.app.flags.DEFINE_string("ps_hosts", "localhost:2222", "ps hosts")
tf.app.flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224", "worker hosts")
tf.app.flags.DEFINE_string("job_name", "worker", "'ps' or 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
# Should match the number of entries in --worker_hosts: main() uses it as the
# replica count for SyncReplicasOptimizer when --is_sync is set.
tf.app.flags.DEFINE_integer("num_workers", 2, "Number of workers")
tf.app.flags.DEFINE_boolean("is_sync", False, "using synchronous training or not")
FLAGS = tf.app.flags.FLAGS
def model(images):
    """Simple MNIST classifier: two 500-unit ReLU layers, then 10 logits."""
    hidden = images
    for units in (500, 500):
        hidden = tf.layers.dense(hidden, units, activation=tf.nn.relu)
    # No activation on the output layer: the loss applies softmax itself.
    return tf.layers.dense(hidden, 10, activation=None)
# Load MNIST, flatten every image into a 784-element float32 vector and
# divide the pixel values by 255.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255
x_test = x_test.reshape(-1, 784).astype('float32') / 255
def get_batch(image, label, batch_size=32, training=True):
    """Return `(batch_x, batch_y)` tensors drawn from in-memory arrays.

    When `training` is True the data is repeated 10 times and shuffled with a
    1000-element buffer. Batches are prefetched; the prefetch buffer holds
    `batch_size` batches, because prefetch is applied after batching.
    """
    dataset = tf.data.Dataset.from_tensor_slices((image, label))
    if training:
        dataset = dataset.repeat(10)
        dataset = dataset.shuffle(buffer_size=1000)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(batch_size)
    next_element = dataset.make_one_shot_iterator().get_next()
    batch_x, batch_y = next_element
    return batch_x, batch_y
def main(_):
    """Start this task's server: ps tasks serve variables, workers train.

    Workers build the model under `replica_device_setter`, so variables live
    on the ps tasks. They train inside a `MonitoredTrainingSession`, and task
    0 is the chief.
    """
    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")
    # Create the cluster configured by `ps_hosts` and `worker_hosts`.
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
    # Create a server for the local task.
    server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        server.join()  # ps hosts only join
    elif FLAGS.job_name == "worker":
        # Build the input pipeline only on workers. A ps task blocks in
        # join() above and would never use it.
        train_batch_x, train_batch_y = get_batch(x_train, y_train)

        # ps_strategy = tf.contrib.training.GreedyLoadBalancingStrategy(FLAGS.num_ps)
        # Note: tf.train.replica_device_setter automatically places the
        # parameters (Variables) on the ps hosts (default placement strategy:
        # round-robin over all ps hosts) and places the other operations on
        # this worker host.
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % FLAGS.task_index,
                cluster=cluster)):
            logits = model(train_batch_x)
            loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(
                    logits=logits, labels=tf.one_hot(train_batch_y, 10)))

            # The StopAtStepHook handles stopping after running given steps.
            hooks = [tf.train.StopAtStepHook(last_step=10000)]

            global_step = tf.train.get_or_create_global_step()
            optimizer = tf.train.AdamOptimizer(learning_rate=1e-04)

            if FLAGS.is_sync:
                # Synchronous training: wrap the optimizer with
                # tf.train.SyncReplicasOptimizer so that gradients from
                # `replicas_to_aggregate` workers are aggregated per update.
                # ref: https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer
                optimizer = tf.train.SyncReplicasOptimizer(
                    optimizer,
                    replicas_to_aggregate=FLAGS.num_workers,
                    total_num_replicas=FLAGS.num_workers)
                # Create the hook which handles initialization and queues;
                # task 0 gets the chief version of the hook.
                hooks.append(optimizer.make_session_run_hook(FLAGS.task_index == 0))

            train_op = optimizer.minimize(loss, global_step=global_step)

            # The MonitoredTrainingSession takes care of session initialization,
            # restoring from a checkpoint, saving to a checkpoint, and closing
            # when done or an error occurs.
            with tf.train.MonitoredTrainingSession(master=server.target,
                                                   is_chief=(FLAGS.task_index == 0),
                                                   checkpoint_dir="./checkpoint_dir",
                                                   hooks=hooks) as mon_sess:
                while not mon_sess.should_stop():
                    # mon_sess.run handles AbortedError in case of preempted PS.
                    _, ls, step = mon_sess.run([train_op, loss, global_step])
                    if step % 100 == 0:
                        print("Train step %d, loss: %f" % (step, ls))
if __name__ == "__main__":
    # Parse the flags, then call main() with the remaining argv.
    tf.app.run(main=main)
参考文献:
https://www.cnblogs.com/estragon/p/10034511.html
https://zhuanlan.zhihu.com/p/88876923
本文转载自 Alex-zhai 知乎账号。
原文链接:https://zhuanlan.zhihu.com/p/91608555
评论