Skip to content

Instantly share code, notes, and snippets.

@sile
Last active April 8, 2020 08:55
Show Gist options
  • Select an option

  • Save sile/1158ba37ff5b8c290f8953acebffed80 to your computer and use it in GitHub Desktop.

Select an option

Save sile/1158ba37ff5b8c290f8953acebffed80 to your computer and use it in GitHub Desktop.

Revisions

  1. sile revised this gist Apr 8, 2020. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion optimize.py
    Original file line number Diff line number Diff line change
    @@ -12,7 +12,7 @@ def objective(trial):
    # Replace config entries with the suggested parameters.
    config = open("config.gin").read()
    for name, value in trial.params.items():
    config = re.sub("(?<=" + name + ") *=.*", "=" + str(value), config)
    config = re.sub("(?<=" + name + ") *=.*", "=" + str(value), config) # FIXME: Escape `name`

    # Create a temporary config file.
    temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8")
  2. sile created this gist Apr 8, 2020.
    2 changes: 2 additions & 0 deletions config.gin
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,2 @@
    train.batch_size = 10
    train.learning_rate = 0.1
    34 changes: 34 additions & 0 deletions optimize.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,34 @@
    import re
    import optuna
    import subprocess
    import tempfile


    def objective(trial):
    # Suggest parameters.
    trial.suggest_int("train.batch_size", 4, 100)
    trial.suggest_loguniform("train.learning_rate", 0.0001, 1.0)

    # Replace config entries with the suggested parameters.
    config = open("config.gin").read()
    for name, value in trial.params.items():
    config = re.sub("(?<=" + name + ") *=.*", "=" + str(value), config)

    # Create a temporary config file.
    temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8")
    temp.write(config)
    temp.flush()

    # Run train script with the temporary config.
    result = subprocess.run(
    ["python3", "train.py", "--config-path", temp.name],
    stdout=subprocess.PIPE,
    encoding="utf-8",
    )

    # Parse the script output to get the objective value.
    return float(result.stdout)


    study = optuna.create_study()
    study.optimize(objective, n_trials=10)
    18 changes: 18 additions & 0 deletions train.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,18 @@
    import argparse
    import gin


    @gin.configurable
    def train(batch_size, learning_rate):
    value = batch_size * learning_rate # TODO: Replace with a real training code.
    print("{}".format(value))


    parser = argparse.ArgumentParser()
    parser.add_argument("--config-path", default="config.gin")
    args = parser.parse_args()

    gin.parse_config_file(args.config_path)


    train()