def map_fun(args, ctx):
    """Per-executor entry point for a TensorFlowOnSpark distributed job.

    Runs as either a parameter server ("ps") or a worker, based on the role
    assigned in ``ctx``. Workers build the model under a replica device setter
    and train inside a ``MonitoredTrainingSession`` fed by the Spark data feed.

    Args:
        args: dict-like job configuration; reads ``args['save_dir']``
            (checkpoint directory) and ``args['batch_size']``.
        ctx: TensorFlowOnSpark TFNode context providing ``worker_num``,
            ``job_name``, ``task_index``, ``start_cluster_server`` and
            ``get_data_feed``.

    Returns:
        None. Parameter servers block forever in ``server.join()``; workers
        return after the session stops and the data feed is terminated.
    """
    worker_num = ctx.worker_num
    job_name = ctx.job_name
    task_index = ctx.task_index

    # Start this executor's TF server; argument 1 is passed through to
    # start_cluster_server (number of GPUs requested per the TFoS API).
    cluster, server = ctx.start_cluster_server(1)

    if job_name == "ps":
        # Parameter servers only serve variables; they never return.
        server.join()
    elif job_name == "worker":
        # Task 0 is the chief: it owns checkpointing and session recovery.
        is_chiefing = (task_index == 0)

        # Pin variables to PS tasks and ops to this worker.
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % task_index,
                cluster=cluster)):
            def build_model():
                # Placeholder: model graph construction goes here.
                pass

        # Placeholder hook list — fill in e.g. StopAtStepHook, LoggingTensorHook.
        hooks = [...]

        # NOTE(review): session placed after the device scope, matching the
        # canonical TFoS example — confirm against the original layout.
        with tf.train.MonitoredTrainingSession(
                master=server.target,
                is_chief=is_chiefing,
                checkpoint_dir=args['save_dir'],  # fixed typo: was arsg['save_dir']
                hooks=hooks,
                save_checkpoint_secs=600.0) as mon_sess:
            tf_feed = ctx.get_data_feed(train_mode=True)
            while not mon_sess.should_stop() and not tf_feed.should_stop():
                # Fixed syntax error: original had an extra closing paren here.
                batch_data = tf_feed.next_batch(args['batch_size'])
                # apply what you need to be done here
                _ = mon_sess.run(...)
            if mon_sess.should_stop():
                # Tell Spark the feed is done so the RDD partition can finish.
                tf_feed.terminate()