Skip to content

Instantly share code, notes, and snippets.

@benoitdescamps
Last active May 10, 2019 01:14
Show Gist options
  • Select an option

  • Save benoitdescamps/543bc3d68187dd2a2b15832c47a6f25a to your computer and use it in GitHub Desktop.

Select an option

Save benoitdescamps/543bc3d68187dd2a2b15832c47a6f25a to your computer and use it in GitHub Desktop.
def map_fun(args, ctx):
    """TensorFlowOnSpark worker entry point for distributed training.

    Runs on every Spark executor. Parameter-server tasks block serving
    variables; worker tasks build the graph, open a MonitoredTrainingSession
    and train from the Spark RDD feed until the step budget is exhausted.

    Args:
        args: dict-like config; reads 'save_dir' (checkpoint directory),
            'steps' (max global steps) and 'batch_size'.
        ctx: TensorFlowOnSpark TFNodeContext; provides worker_num, job_name,
            task_index, start_cluster_server() and get_data_feed().

    Returns:
        None. Side effects: trains the model, writes checkpoints to
        args['save_dir'], and terminates the Spark data feed when done.
    """
    worker_num = ctx.worker_num
    job_name = ctx.job_name
    task_index = ctx.task_index
    # Build the tf.train.ClusterSpec / Server from the TFoS context
    # (1 = number of GPUs requested per node).
    cluster, server = ctx.start_cluster_server(1)

    if job_name == "ps":
        # Parameter servers never return: they block, serving variables.
        server.join()
    elif job_name == "worker":
        # https://www.tensorflow.org/api_docs/python/tf/train/Supervisor
        # One task must be identified as chief; it is responsible for e.g.
        # variable initialization and checkpoint saving.
        is_chiefing = (task_index == 0)

        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % task_index,
                cluster=cluster)):
            def build_model():
                # TODO: placeholder — must return
                # (model_input, model_labels, tf_optimizer, tf_loss, tf_global_step).
                pass

            model_input, model_labels, tf_optimizer, tf_loss, tf_global_step = build_model()

        hooks = [...]  # TODO: placeholder — supply real SessionRunHooks.
        # BUG FIX: was `arsg['save_dir']` (NameError at runtime).
        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=is_chiefing,
                                               checkpoint_dir=args['save_dir'],
                                               hooks=hooks,
                                               save_checkpoint_secs=600.) as mon_sess:
            start_time = datetime.now()
            tf.logging.info("{0} session ready".format(start_time.isoformat()))
            # https://github.com/yahoo/TensorFlowOnSpark/blob/master/tensorflowonspark/TFSparkNode.py
            # see TFNODE https://github.com/yahoo/TensorFlowOnSpark/blob/master/tensorflowonspark/TFNode.py
            tf_feed = ctx.get_data_feed(train_mode=True)
            step = 0
            # Stop on: session end (hook/chief decision), feed exhaustion,
            # or reaching the configured step budget.
            while not mon_sess.should_stop() and not tf_feed.should_stop() and step < args['steps']:
                # NOTE(review): get_next_batch is defined elsewhere in this
                # file — presumably unpacks the TFoS batch into (data, labels).
                batch_data, batch_labels = get_next_batch(tf_feed.next_batch(args['batch_size']))
                if len(batch_data) > 0:
                    feed = {model_input: batch_data, model_labels: batch_labels}
                    _, logloss, step = mon_sess.run([tf_optimizer, tf_loss, tf_global_step],
                                                    feed_dict=feed)
            # Tell Spark the feed is done so the RDD partition iterator closes.
            if mon_sess.should_stop() or step >= args['steps']:
                tf_feed.terminate()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment