Last active
April 8, 2020 08:55
-
-
Save sile/1158ba37ff5b8c290f8953acebffed80 to your computer and use it in GitHub Desktop.
Revisions
-
sile revised this gist
Apr 8, 2020 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -12,7 +12,7 @@ def objective(trial): # Replace config entries with the suggested parameters. config = open("config.gin").read() for name, value in trial.params.items(): config = re.sub("(?<=" + name + ") *=.*", "=" + str(value), config) # FIXME: Escape `name` # Create a temporary config file. temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8") -
sile created this gist
Apr 8, 2020 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,2 @@ train.batch_size = 10 train.learning_rate = 0.1 This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,34 @@ import re import optuna import subprocess import tempfile def objective(trial): # Suggest parameters. trial.suggest_int("train.batch_size", 4, 100) trial.suggest_loguniform("train.learning_rate", 0.0001, 1.0) # Replace config entries with the suggested parameters. config = open("config.gin").read() for name, value in trial.params.items(): config = re.sub("(?<=" + name + ") *=.*", "=" + str(value), config) # Create a temporary config file. temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8") temp.write(config) temp.flush() # Run train script with the temporary config. result = subprocess.run( ["python3", "train.py", "--config-path", temp.name], stdout=subprocess.PIPE, encoding="utf-8", ) # Parse the script output to get the objective value. return float(result.stdout) study = optuna.create_study() study.optimize(objective, n_trials=10) This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,18 @@ import argparse import gin @gin.configurable def train(batch_size, learning_rate): value = batch_size * learning_rate # TODO: Replace with a real training code. print("{}".format(value)) parser = argparse.ArgumentParser() parser.add_argument("--config-path", default="config.gin") args = parser.parse_args() gin.parse_config_file(args.config_path) train()