"""Optuna example that optimizes the hyperparameters of a reinforcement learning agent
using the PPO implementation from Stable-Baselines3 on a Gymnasium environment.

This is a simplified version of what can be found in https://github.com/DLR-RM/rl-baselines3-zoo.

You can run this example as follows:
    $ python optimize_ppo.py
"""
from typing import Any

import gymnasium
import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_util import make_vec_env
import torch
import torch.nn as nn

N_TRIALS = 500
N_STARTUP_TRIALS = 10
N_EVALUATIONS = 2
N_TIMESTEPS = 40_000
EVAL_FREQ = int(N_TIMESTEPS / N_EVALUATIONS)
N_EVAL_EPISODES = 10

ENV_ID = "Pendulum-v1"
N_ENVS = 5

DEFAULT_HYPERPARAMS = {
    "policy": "MlpPolicy",
}


def sample_ppo_params(trial: optuna.Trial) -> dict[str, Any]:
    """Sampler for PPO hyperparameters."""
    # From 2**5=32 to 2**12=4096
    n_steps_pow = trial.suggest_int("n_steps_pow", 5, 12)
    gamma = trial.suggest_float("gamma", 0.97, 0.9999)
    learning_rate = trial.suggest_float("learning_rate", 3e-5, 3e-3, log=True)
    activation_fn = trial.suggest_categorical("activation_fn", ["tanh", "relu"])

    n_steps = 2**n_steps_pow
    # Display true values
    trial.set_user_attr("n_steps", n_steps)
    # Convert to PyTorch objects
    activation_fn = {"tanh": nn.Tanh, "relu": nn.ReLU}[activation_fn]

    return {
        "n_steps": n_steps,
        "gamma": gamma,
        "learning_rate": learning_rate,
        "policy_kwargs": {
            "activation_fn": activation_fn,
        },
    }


class TrialEvalCallback(EvalCallback):
    """Callback used for evaluating and reporting a trial."""

    def __init__(
        self,
        eval_env: gymnasium.Env,
        trial: optuna.Trial,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        deterministic: bool = True,
        verbose: int = 0,
    ):
        super().__init__(
            eval_env=eval_env,
            n_eval_episodes=n_eval_episodes,
            eval_freq=eval_freq,
            deterministic=deterministic,
            verbose=verbose,
        )
        self.trial = trial
        self.eval_idx = 0
        self.is_pruned = False

    def _on_step(self) -> bool:
        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
            super()._on_step()
            self.eval_idx += 1
            self.trial.report(self.last_mean_reward, self.eval_idx)
            # Prune trial if needed.
            if self.trial.should_prune():
                self.is_pruned = True
                return False
        return True


def objective(trial: optuna.Trial) -> float:
    vec_env = make_vec_env(ENV_ID, n_envs=N_ENVS)

    kwargs = DEFAULT_HYPERPARAMS.copy()
    # Sample hyperparameters.
    kwargs.update(sample_ppo_params(trial))
    # Create the RL model.
    model = PPO(env=vec_env, **kwargs)

    # Create env used for evaluation.
    eval_env = make_vec_env(ENV_ID, n_envs=N_ENVS)
    # Create the callback that will periodically evaluate and report the performance.
    eval_callback = TrialEvalCallback(
        eval_env,
        trial,
        n_eval_episodes=N_EVAL_EPISODES,
        eval_freq=max(EVAL_FREQ // N_ENVS, 1),
        deterministic=True,
    )

    nan_encountered = False
    try:
        model.learn(N_TIMESTEPS, callback=eval_callback)
    except AssertionError as e:
        # Sometimes, random hyperparams can generate NaN.
        print(e)
        nan_encountered = True
    finally:
        # Free memory.
        model.env.close()
        eval_env.close()

    # Tell the optimizer that the trial failed.
    if nan_encountered:
        return float("nan")

    if eval_callback.is_pruned:
        raise optuna.exceptions.TrialPruned()

    return eval_callback.last_mean_reward


if __name__ == "__main__":
    # Set pytorch num threads to 1 for faster training.
    torch.set_num_threads(1)

    sampler = TPESampler(n_startup_trials=N_STARTUP_TRIALS, multivariate=True)
    # Do not prune before 1/3 of the max budget is used.
    pruner = MedianPruner(
        n_startup_trials=N_STARTUP_TRIALS, n_warmup_steps=N_EVALUATIONS // 3
    )

    study = optuna.create_study(sampler=sampler, pruner=pruner, direction="maximize")
    try:
        study.optimize(objective, n_trials=N_TRIALS, timeout=600)
    except KeyboardInterrupt:
        pass

    print(f"Number of finished trials: {len(study.trials)}")

    print("Best trial:")
    trial = study.best_trial

    print("  Value: ", trial.value)

    print("  Params: ")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")

    print("  User attrs:")
    for key, value in trial.user_attrs.items():
        print(f"    {key}: {value}")
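
    # Optional follow-up: a minimal sketch of how one could retrain a final
    # agent with the best hyperparameters and inspect the search. The kwargs
    # conversion mirrors sample_ppo_params above; the save path is an
    # arbitrary example name, and the optuna.visualization helpers assume
    # plotly is installed.
    best = study.best_trial.params
    final_model = PPO(
        "MlpPolicy",
        make_vec_env(ENV_ID, n_envs=N_ENVS),
        n_steps=2 ** best["n_steps_pow"],
        gamma=best["gamma"],
        learning_rate=best["learning_rate"],
        policy_kwargs={
            "activation_fn": {"tanh": nn.Tanh, "relu": nn.ReLU}[best["activation_fn"]]
        },
    )
    final_model.learn(N_TIMESTEPS)
    final_model.save("ppo_pendulum_best")

    # Plot how the objective evolved and which hyperparameters mattered most.
    optuna.visualization.plot_optimization_history(study).show()
    optuna.visualization.plot_param_importances(study).show()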