Skip to content
Permalink
Browse files

TPUEstimator: Log status every 60 seconds to indicate progress.

PiperOrigin-RevId: 238738425
  • Loading branch information...
rjpower authored and tensorflower-gardener committed Mar 16, 2019
1 parent b3991e0 commit 2a2c8a967911a530e553d1f76b83a65bfe5bd6cb
Showing with 22 additions and 7 deletions.
  1. +22 −7 tensorflow/python/tpu/tpu_estimator.py
@@ -253,6 +253,18 @@ def _extract_key_names(tensor_or_dict):
return []


class PeriodicLogger(object):
  """Rate-limited wrapper around `logging.info`.

  A message is forwarded only when more than `seconds` have elapsed (per
  `time.time()`) since the last message that was forwarded; messages that
  arrive sooner are dropped.
  """

  def __init__(self, seconds):
    """Creates the logger.

    Args:
      seconds: Minimum gap, in seconds, between two emitted messages.
    """
    # The timestamp starts at 0, so the first `log` call is emitted whenever
    # the current `time.time()` value exceeds `seconds`.
    self._last_log_time = 0
    self._log_every_n_seconds = seconds

  def log(self, msg, *args, **kw):
    """Emits `msg` through `logging.info` if the interval has elapsed.

    Args:
      msg: Message (format string) handed to `logging.info`.
      *args: Positional arguments forwarded unchanged to `logging.info`.
      **kw: Keyword arguments forwarded unchanged to `logging.info`.
    """
    seconds_since_last = time.time() - self._last_log_time
    is_due = seconds_since_last > self._log_every_n_seconds
    if not is_due:
      # Still inside the quiet period: drop the message.
      return
    self._last_log_time = time.time()
    logging.info(msg, *args, **kw)


class _SIGNAL(object):
"""Signal used to control the thread of infeed/outfeed.
@@ -460,8 +472,6 @@ def __init__(self,
self._initial_infeed_sleep_secs = (
ctx.config.tpu_config.initial_infeed_sleep_secs)

self._feed_error = None
self._finished = False
# When using model parallelism, the TPU is pre-initialized at startup to
# fetch mesh information. We skip re-initializing it here to avoid
# suspected issues due to the mesh layout changing on the second
@@ -505,11 +515,13 @@ def _run_infeed(self, queue_ctx, session):

# Outfeed thread body: for each iteration count yielded by
# `queue_ctx.read_iteration_counts()`, runs `self._dequeue_ops` in `session`
# once per step, with start/finish messages logged at INFO level.
# NOTE(review): indentation is not visible in this view, so the nesting
# described in the comments below is presumed -- confirm against the real file.
def _run_outfeed(self, queue_ctx, session):
logging.info('Starting outfeed thread controller.')
# Limits the per-step status message below to at most one every 60 seconds.
status_logger = PeriodicLogger(seconds=60)
# Presumably routes errors raised in the block to the rendezvous, tagged with
# source 'outfeed' -- verify against `catch_errors`.
with self._rendezvous.catch_errors(source='outfeed', session=session):
# `count` is the index of the batch of steps; `steps` is how many dequeues
# that batch requires.
for count, steps in enumerate(queue_ctx.read_iteration_counts()):
for i in xrange(steps):
logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i)
session.run(self._dequeue_ops)
# Rate-limited INFO progress message; the DEBUG line above fires on every step.
status_logger.log('Outfeed finished for iteration (%d, %d)', count, i)
logging.info('Outfeed thread finished, shutting down.')

def _create_infeed_controller(self, name, target, args):
@@ -557,8 +569,6 @@ def after_create_session(self, session, coord):
shutdown_timeout=watchdog_timeout)

def before_run(self, run_context):
self._feed_error = None

iterations = run_context.session.run(self._iterations_per_loop_var)

logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations)
@@ -569,7 +579,6 @@ def before_run(self, run_context):
self._outfeed_controller.send_next_batch_signal(iterations)

def end(self, session):
self._finished = True
logging.info('Stop infeed thread controller')
self._infeed_controller.join()
self._rendezvous.record_done('infeed')
@@ -1490,6 +1499,14 @@ def train_step(loss):
and estimator_spec.host_call is not None):
host_call.record({'host_call': estimator_spec.host_call})
host_call_outfeed_ops = host_call.create_enqueue_op()
else:
# Create a dummy outfeed for the loss to track execution progress
host_call.record({
'host_call': (lambda loss_t: loss_t,
[array_ops.reshape(loss, [1])])
})
host_call_outfeed_ops = host_call.create_enqueue_op()

with ops.control_dependencies(host_call_outfeed_ops):
return array_ops.identity(loss)

@@ -2856,8 +2873,6 @@ def _model_fn(features, labels, mode, config, params):
tpu_init_ops.extend(embedding_variables_and_ops.load_ops())

host_ops = host_call.create_tpu_hostcall()
if host_ops is None:
host_ops = []

shutdown_hooks = []
shutdown_mode = os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN_MODE',

0 comments on commit 2a2c8a9

Please sign in to comment.
You can’t perform that action at this time.