# Run the training loop inside a monitored session. Checkpoints go to
# args['save_dir'] (save_checkpoint_secs=600.0).
# NOTE(review): `is_chiefing` is defined outside this block -- confirm it is
# the intended chief flag and not a misspelling of another name.
with tf.train.MonitoredTrainingSession(
        master=server.target,
        is_chief=is_chiefing,
        checkpoint_dir=args['save_dir'],
        hooks=hooks,
        save_checkpoint_secs=600.0) as mon_sess:
    tf_feed = ctx.get_data_feed(train_mode=True)
    step = 0

    # Keep training until the session asks to stop, the data feed asks to
    # stop, or the global step reaches the configured step budget.
    while (not mon_sess.should_stop()
           and not tf_feed.should_stop()
           and step < args['steps']):
        batch_data, batch_labels = get_next_batch(
            tf_feed.next_batch(args['batch_size']))

        # Only run a training step when the feed actually produced rows;
        # an empty batch leaves `step` untouched and the loop re-checks
        # the stop conditions.
        if len(batch_data) > 0:
            feed = {model_input: batch_data, model_labels: batch_labels}
            # `step` is refreshed from the global-step tensor on every run.
            _, logloss, step = mon_sess.run(
                [tf_optimizer, tf_loss, tf_global_step], feed_dict=feed)

    # If training ended because of the session or the step budget (rather
    # than the feed itself stopping), tell the feed to terminate.
    # NOTE(review): presumably this stops/drains any remaining input --
    # confirm against the data-feed implementation.
    if mon_sess.should_stop() or step >= args['steps']:
        tf_feed.terminate()