Skip to content

Instantly share code, notes, and snippets.

@benoitdescamps
Last active May 10, 2019 01:14
Show Gist options
  • Select an option

  • Save benoitdescamps/543bc3d68187dd2a2b15832c47a6f25a to your computer and use it in GitHub Desktop.

Select an option

Save benoitdescamps/543bc3d68187dd2a2b15832c47a6f25a to your computer and use it in GitHub Desktop.

Revisions

  1. benoitdescamps revised this gist Oct 28, 2018. 1 changed file with 5 additions and 18 deletions.
    23 changes: 5 additions & 18 deletions blog-distributed-tensorflow-map-0.py
    Original file line number Diff line number Diff line change
    @@ -6,8 +6,6 @@ def map_fun(args, ctx):
    if job_name == "ps":
    server.join()
    elif job_name == "worker":
    #https://www.tensorflow.org/api_docs/python/tf/train/Supervisor
    #one task should be identified as chief. This is necessary to handle for example initialization
    is_chiefing = (task_index == 0)
    with tf.device(tf.train.replica_device_setter(
    worker_device="/job:worker/task:%d" % task_index,
    @@ -16,31 +14,20 @@ def map_fun(args, ctx):
    def build_model():
    pass

    model_input, model_labels,tf_optimizer, tf_loss,tf_global_step = build_model()


    hooks=[...]

    with tf.train.MonitoredTrainingSession(master=server.target,\
    is_chief=is_chiefing,
    checkpoint_dir=arsg['save_dir'],\
    hooks=hooks,\
    save_checkpoint_secs=600.) as mon_sess:
    start_time = datetime.now()
    tf.logging.info("{0} session ready".format(start_time.isoformat()))

    #https://github.com/yahoo/TensorFlowOnSpark/blob/master/tensorflowonspark/TFSparkNode.py
    # see TFNODE https://github.com/yahoo/TensorFlowOnSpark/blob/master/tensorflowonspark/TFNode.py
    tf_feed = ctx.get_data_feed(train_mode=True)

    step = 0
    while not mon_sess.should_stop() and not tf_feed.should_stop() and step < args['steps']:
    batch_data, batch_labels = get_next_batch(tf_feed.next_batch(args['batch_size']))

    if len(batch_data) > 0:
    feed = {model_input: batch_data, model_labels: batch_labels}
    _, logloss, step = mon_sess.run([tf_optimizer, tf_loss,tf_global_step],feed_dict=feed)
    while not mon_sess.should_stop() and not tf_feed.should_stop():
    batch_data = tf_feed.next_batch(args['batch_size']))
    #apply what you need to be done here
    _ = mon_sess.run(...)

    if mon_sess.should_stop() or step >= args['steps']:
    if mon_sess.should_stop():
    tf_feed.terminate()

  2. benoitdescamps revised this gist Oct 28, 2018. 1 changed file with 39 additions and 75 deletions.
    114 changes: 39 additions & 75 deletions blog-distributed-tensorflow-map-0.py
    Original file line number Diff line number Diff line change
    @@ -1,82 +1,46 @@
    def map_fun(args, ctx):
    try:
    import tensorflow as tf

    #utils
    from datetime import datetime
    import time
    import logging
    import numpy as np

    logger = logging.getLogger()
    tf.logging.set_verbosity(tf.logging.DEBUG)
    worker_num = ctx.worker_num
    job_name = ctx.job_name
    task_index = ctx.task_index
    cluster, server = ctx.start_cluster_server(1)
    if job_name == "ps":
    server.join()
    elif job_name == "worker":
    #https://www.tensorflow.org/api_docs/python/tf/train/Supervisor
    #one task should be identified as chief. This is necessary to handle for example initialization
    is_chiefing = (task_index == 0)
    with tf.device(tf.train.replica_device_setter(
    worker_device="/job:worker/task:%d" % task_index,
    cluster=cluster)):

    worker_num = ctx.worker_num
    job_name = ctx.job_name
    task_index = ctx.task_index
    #cluster_spec = ctx.cluster_spec
    #num_workers = len(cluster_spec['worker'])

    cluster, server = ctx.start_cluster_server(1)
    #TFNode.start_cluster_server(ctx)

    def get_next_batch(batch):
    batch = np.array(batch)
    data = batch[:,2:-1].reshape((batch.shape[0],timesteps,num_features))
    labels = batch[:,-1].astype(int)
    return data,to_categorical(labels,num_classes=num_classes)

    if job_name == "ps":
    server.join()
    elif job_name == "worker":
    #https://www.tensorflow.org/api_docs/python/tf/train/Supervisor
    #one task should be identified as chief. This is necessary to handle for example initialization

    is_chiefing = (task_index == 0)
    with tf.device(tf.train.replica_device_setter(
    worker_device="/job:worker/task:%d" % task_index,
    cluster=cluster)):
    def build_model():
    pass

    def build_model():
    pass
    model_input, model_labels,tf_optimizer, tf_loss,tf_global_step = build_model()

    model_input,\
    model_labels,\
    model_output,\
    tf_global_step,\
    tf_loss,\
    tf_optimizer,\
    tf_metrics = build_model()


    hooks=[tf.train.StepCounterHook()]

    with tf.train.MonitoredTrainingSession(master=server.target,\
    is_chief=is_chiefing,
    checkpoint_dir=arsg['save_dir'],\
    hooks=hooks,\
    save_checkpoint_secs=600.) as mon_sess:
    start_time = datetime.now()
    tf.logging.info("{0} session ready".format(start_time.isoformat()))

    #https://github.com/yahoo/TensorFlowOnSpark/blob/master/tensorflowonspark/TFSparkNode.py
    # see TFNODE https://github.com/yahoo/TensorFlowOnSpark/blob/master/tensorflowonspark/TFNode.py
    tf_feed = ctx.get_data_feed(train_mode=True)

    hooks=[...]

    step = 0
    while not mon_sess.should_stop() and not tf_feed.should_stop() and step < args['steps']:
    batch_data, batch_labels = get_next_batch(tf_feed.next_batch(args['batch_size']))
    with tf.train.MonitoredTrainingSession(master=server.target,\
    is_chief=is_chiefing,
    checkpoint_dir=arsg['save_dir'],\
    hooks=hooks,\
    save_checkpoint_secs=600.) as mon_sess:
    start_time = datetime.now()
    tf.logging.info("{0} session ready".format(start_time.isoformat()))

    if len(batch_data) > 0:
    feed = {model_input: batch_data, model_labels: batch_labels}
    _, logloss, step = mon_sess.run([tf_optimizer, tf_loss,tf_global_step],feed_dict=feed)

    if mon_sess.should_stop() or step >= args['steps']:
    tf_feed.terminate()
    logger.info("{0} stopping supervisor".format(datetime.now().isoformat()))



    except Exception as e:
    logger.error(e)
    #https://github.com/yahoo/TensorFlowOnSpark/blob/master/tensorflowonspark/TFSparkNode.py
    # see TFNODE https://github.com/yahoo/TensorFlowOnSpark/blob/master/tensorflowonspark/TFNode.py
    tf_feed = ctx.get_data_feed(train_mode=True)

    step = 0
    while not mon_sess.should_stop() and not tf_feed.should_stop() and step < args['steps']:
    batch_data, batch_labels = get_next_batch(tf_feed.next_batch(args['batch_size']))

    if len(batch_data) > 0:
    feed = {model_input: batch_data, model_labels: batch_labels}
    _, logloss, step = mon_sess.run([tf_optimizer, tf_loss,tf_global_step],feed_dict=feed)

    if mon_sess.should_stop() or step >= args['steps']:
    tf_feed.terminate()

  3. benoitdescamps renamed this gist Oct 26, 2018. 1 changed file with 0 additions and 0 deletions.
  4. benoitdescamps created this gist Oct 26, 2018.
    82 changes: 82 additions & 0 deletions blog-distributed-tensorflow-map-0
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,82 @@
    def map_fun(args, ctx):
    try:
    import tensorflow as tf

    #utils
    from datetime import datetime
    import time
    import logging
    import numpy as np

    logger = logging.getLogger()
    tf.logging.set_verbosity(tf.logging.DEBUG)

    worker_num = ctx.worker_num
    job_name = ctx.job_name
    task_index = ctx.task_index
    #cluster_spec = ctx.cluster_spec
    #num_workers = len(cluster_spec['worker'])

    cluster, server = ctx.start_cluster_server(1)
    #TFNode.start_cluster_server(ctx)

    def get_next_batch(batch):
    batch = np.array(batch)
    data = batch[:,2:-1].reshape((batch.shape[0],timesteps,num_features))
    labels = batch[:,-1].astype(int)
    return data,to_categorical(labels,num_classes=num_classes)

    if job_name == "ps":
    server.join()
    elif job_name == "worker":
    #https://www.tensorflow.org/api_docs/python/tf/train/Supervisor
    #one task should be identified as chief. This is necessary to handle for exmaple initialization

    is_chiefing = (task_index == 0)
    with tf.device(tf.train.replica_device_setter(
    worker_device="/job:worker/task:%d" % task_index,
    cluster=cluster)):

    def build_model():
    pass

    model_input,\
    model_labels,\
    model_output,\
    tf_global_step,\
    tf_loss,\
    tf_optimizer,\
    tf_metrics = build_model()


    hooks=[tf.train.StepCounterHook()]

    with tf.train.MonitoredTrainingSession(master=server.target,\
    is_chief=is_chiefing,
    checkpoint_dir=arsg['save_dir'],\
    hooks=hooks,\
    save_checkpoint_secs=600.) as mon_sess:
    start_time = datetime.now()
    tf.logging.info("{0} session ready".format(start_time.isoformat()))

    #https://github.com/yahoo/TensorFlowOnSpark/blob/master/tensorflowonspark/TFSparkNode.py
    # see TFNODE https://github.com/yahoo/TensorFlowOnSpark/blob/master/tensorflowonspark/TFNode.py
    tf_feed = ctx.get_data_feed(train_mode=True)


    step = 0
    while not mon_sess.should_stop() and not tf_feed.should_stop() and step < args['steps']:
    batch_data, batch_labels = get_next_batch(tf_feed.next_batch(args['batch_size']))

    if len(batch_data) > 0:
    feed = {model_input: batch_data, model_labels: batch_labels}
    _, logloss, step = mon_sess.run([tf_optimizer, tf_loss,tf_global_step],feed_dict=feed)

    if mon_sess.should_stop() or step >= args['steps']:
    tf_feed.terminate()
    logger.info("{0} stopping supervisor".format(datetime.now().isoformat()))



    except Exception as e:
    logger.error(e)