Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save binshengliu/b9972da2b93c60d98b4650a77e54e26f to your computer and use it in GitHub Desktop.

Select an option

Save binshengliu/b9972da2b93c60d98b4650a77e54e26f to your computer and use it in GitHub Desktop.

Revisions

  1. @erenon erenon revised this gist Oct 1, 2018. 1 changed file with 1 addition and 0 deletions.
    1 change: 1 addition & 0 deletions train_on_batch_with_tensorboard.py
    Original file line number Diff line number Diff line change
    @@ -27,6 +27,7 @@ def named_logs(model, logs):
    result = {}
    for l in zip(model.metrics_names, logs):
    result[l[0]] = l[1]
    return result

    # Run training batches, notify tensorboard at the end of each epoch
    for batch_id in range(1000):
  2. @erenon erenon created this gist Sep 30, 2018.
    38 changes: 38 additions & 0 deletions train_on_batch_with_tensorboard.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,38 @@
    # This example shows how to use keras TensorBoard callback
    # with model.train_on_batch

    import tensorflow.keras as keras

    # Setup the model
    model = keras.models.Sequential()
    model.add(...) # Add your layers
    model.compile(...) # Compile as usual

    batch_size=256

    # Create the TensorBoard callback,
    # which we will drive manually
    tensorboard = keras.callbacks.TensorBoard(
    log_dir='/tmp/my_tf_logs',
    histogram_freq=0,
    batch_size=batch_size,
    write_graph=True,
    write_grads=True
    )
    tensorboard.set_model(model)

    # Transform train_on_batch return value
    # to dict expected by on_batch_end callback
    def named_logs(model, logs):
    result = {}
    for l in zip(model.metrics_names, logs):
    result[l[0]] = l[1]

    # Run training batches, notify tensorboard at the end of each epoch
    for batch_id in range(1000):
    x_train,y_train = create_training_data(batch_size)
    logs = model.train_on_batch(x_train, y_train)
    tensorboard.on_epoch_end(batch_id, named_logs(model, logs))

    tensorboard.on_train_end(None)