Last active
May 10, 2019 01:14
-
-
Save benoitdescamps/543bc3d68187dd2a2b15832c47a6f25a to your computer and use it in GitHub Desktop.
Revisions
-
benoitdescamps revised this gist
Oct 28, 2018 . 1 changed file with 5 additions and 18 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -6,8 +6,6 @@ def map_fun(args, ctx): if job_name == "ps": server.join() elif job_name == "worker": is_chiefing = (task_index == 0) with tf.device(tf.train.replica_device_setter( worker_device="/job:worker/task:%d" % task_index, @@ -16,31 +14,20 @@ def map_fun(args, ctx): def build_model(): pass hooks=[...] with tf.train.MonitoredTrainingSession(master=server.target,\ is_chief=is_chiefing, checkpoint_dir=arsg['save_dir'],\ hooks=hooks,\ save_checkpoint_secs=600.) as mon_sess: tf_feed = ctx.get_data_feed(train_mode=True) while not mon_sess.should_stop() and not tf_feed.should_stop(): batch_data = tf_feed.next_batch(args['batch_size'])) #apply what you need to be done here _ = mon_sess.run(...) if mon_sess.should_stop(): tf_feed.terminate()
-
benoitdescamps revised this gist
Oct 28, 2018 . 1 changed file with 39 additions and 75 deletions.There are no files selected for viewing
def map_fun(args, ctx):
    """Per-executor training task for TensorFlowOnSpark.

    Runs as either a parameter server ("ps") or a worker, based on the role
    TensorFlowOnSpark assigned to this executor via ``ctx``.

    Args:
        args: dict-like config; reads 'save_dir', 'steps', 'batch_size'.
        ctx: TensorFlowOnSpark TFNodeContext — provides worker_num, job_name,
            task_index, start_cluster_server() and get_data_feed().

    Returns:
        None. Side effects: joins the PS server, or trains the model and
        writes checkpoints under args['save_dir'].
    """
    # This revision dropped the in-function imports; restore them so tf and
    # datetime resolve (function-local, matching the gist's original style).
    import tensorflow as tf
    from datetime import datetime

    worker_num = ctx.worker_num
    job_name = ctx.job_name
    task_index = ctx.task_index

    # Build/join the TF cluster for this Spark executor (1 GPU requested).
    cluster, server = ctx.start_cluster_server(1)

    if job_name == "ps":
        # Parameter servers block here for the lifetime of the job.
        server.join()
    elif job_name == "worker":
        # Exactly one task must act as chief: it handles initialization,
        # checkpointing and summaries for the whole cluster.
        is_chiefing = (task_index == 0)

        # Pin variables to the PS tasks and ops to this worker.
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % task_index,
                cluster=cluster)):

            def build_model():
                # Placeholder in the gist — must return the graph endpoints
                # unpacked below.
                pass

            model_input, model_labels, tf_optimizer, tf_loss, tf_global_step = build_model()

        hooks = [...]  # placeholder in the gist — supply real SessionRunHooks

        # MonitoredTrainingSession handles chief/non-chief init, checkpoint
        # save/restore (every 600 s) and hook dispatch.
        # BUG FIX: was `arsg['save_dir']` (NameError) — corrected to `args`.
        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=is_chiefing,
                                               checkpoint_dir=args['save_dir'],
                                               hooks=hooks,
                                               save_checkpoint_secs=600.) as mon_sess:
            start_time = datetime.now()
            tf.logging.info("{0} session ready".format(start_time.isoformat()))

            # Spark-fed data queue; see
            # https://github.com/yahoo/TensorFlowOnSpark/blob/master/tensorflowonspark/TFNode.py
            tf_feed = ctx.get_data_feed(train_mode=True)

            step = 0
            while not mon_sess.should_stop() and not tf_feed.should_stop() and step < args['steps']:
                # NOTE(review): get_next_batch is not defined in this
                # revision of the gist (it exists in the original revision);
                # it must convert the Spark feed into (data, labels) arrays.
                batch_data, batch_labels = get_next_batch(tf_feed.next_batch(args['batch_size']))
                if len(batch_data) > 0:
                    feed = {model_input: batch_data, model_labels: batch_labels}
                    _, logloss, step = mon_sess.run([tf_optimizer, tf_loss, tf_global_step],
                                                    feed_dict=feed)
                if mon_sess.should_stop() or step >= args['steps']:
                    # Tell the Spark feeder to stop queuing partitions.
                    tf_feed.terminate()
-
benoitdescamps renamed this gist
Oct 26, 2018 . 1 changed file with 0 additions and 0 deletions.There are no files selected for viewing
File renamed without changes. -
benoitdescamps created this gist
Oct 26, 2018 .There are no files selected for viewing
def map_fun(args, ctx):
    """Per-executor training task for TensorFlowOnSpark (original revision).

    Runs as either a parameter server ("ps") or a worker, depending on the
    role TensorFlowOnSpark assigned to this executor via ``ctx``.

    Args:
        args: dict-like config; reads 'save_dir', 'steps', 'batch_size'.
        ctx: TensorFlowOnSpark TFNodeContext — provides worker_num, job_name,
            task_index, start_cluster_server() and get_data_feed().

    Returns:
        None. Side effects: joins the PS server, or trains the model and
        writes checkpoints under args['save_dir'].
    """
    # Use the logging module directly so the except-handler below works even
    # if an import inside `try` fails.
    # BUG FIX: the original bound `logger` inside the try-block and then used
    # it in the except-handler — a failed `import tensorflow` would raise
    # NameError inside the handler itself.
    import logging
    try:
        import tensorflow as tf
        # utils
        from datetime import datetime
        import numpy as np

        logger = logging.getLogger()
        tf.logging.set_verbosity(tf.logging.DEBUG)

        worker_num = ctx.worker_num
        job_name = ctx.job_name
        task_index = ctx.task_index
        # cluster_spec = ctx.cluster_spec
        # num_workers = len(cluster_spec['worker'])

        # Build/join the TF cluster for this Spark executor (1 GPU requested).
        cluster, server = ctx.start_cluster_server(1)  # TFNode.start_cluster_server(ctx)

        def get_next_batch(batch):
            """Convert a Spark-fed batch into (features, one-hot labels).

            NOTE(review): `timesteps`, `num_features`, `num_classes` and
            `to_categorical` (keras.utils) are not defined anywhere in this
            gist — they must come from the enclosing closure; confirm before
            running.
            """
            batch = np.array(batch)
            # Columns [2:-1] are the features, reshaped to a sequence; the
            # last column is the integer class label.
            data = batch[:, 2:-1].reshape((batch.shape[0], timesteps, num_features))
            labels = batch[:, -1].astype(int)
            return data, to_categorical(labels, num_classes=num_classes)

        if job_name == "ps":
            # Parameter servers block here for the lifetime of the job.
            server.join()
        elif job_name == "worker":
            # https://www.tensorflow.org/api_docs/python/tf/train/Supervisor
            # Exactly one task must act as chief: it handles initialization,
            # checkpointing and summaries for the whole cluster.
            is_chiefing = (task_index == 0)

            # Pin variables to the PS tasks and ops to this worker.
            with tf.device(tf.train.replica_device_setter(
                    worker_device="/job:worker/task:%d" % task_index,
                    cluster=cluster)):

                def build_model():
                    # Placeholder in the gist — must return the graph
                    # endpoints unpacked below.
                    pass

                (model_input,
                 model_labels,
                 model_output,
                 tf_global_step,
                 tf_loss,
                 tf_optimizer,
                 tf_metrics) = build_model()

            hooks = [tf.train.StepCounterHook()]

            # MonitoredTrainingSession handles chief/non-chief init,
            # checkpoint save/restore (every 600 s) and hook dispatch.
            # BUG FIX: was `arsg['save_dir']` (NameError) — corrected to `args`.
            with tf.train.MonitoredTrainingSession(master=server.target,
                                                   is_chief=is_chiefing,
                                                   checkpoint_dir=args['save_dir'],
                                                   hooks=hooks,
                                                   save_checkpoint_secs=600.) as mon_sess:
                start_time = datetime.now()
                tf.logging.info("{0} session ready".format(start_time.isoformat()))

                # https://github.com/yahoo/TensorFlowOnSpark/blob/master/tensorflowonspark/TFSparkNode.py
                # see TFNODE https://github.com/yahoo/TensorFlowOnSpark/blob/master/tensorflowonspark/TFNode.py
                tf_feed = ctx.get_data_feed(train_mode=True)

                step = 0
                while not mon_sess.should_stop() and not tf_feed.should_stop() and step < args['steps']:
                    batch_data, batch_labels = get_next_batch(tf_feed.next_batch(args['batch_size']))
                    if len(batch_data) > 0:
                        feed = {model_input: batch_data, model_labels: batch_labels}
                        _, logloss, step = mon_sess.run([tf_optimizer, tf_loss, tf_global_step],
                                                        feed_dict=feed)
                    if mon_sess.should_stop() or step >= args['steps']:
                        # Tell the Spark feeder to stop queuing partitions.
                        tf_feed.terminate()

                logger.info("{0} stopping supervisor".format(datetime.now().isoformat()))
    except Exception as e:
        # logging.exception keeps the traceback, unlike logger.error(e).
        logging.getLogger().exception(e)