@araffin
Last active April 2, 2026 16:38
Revisions

  1. araffin revised this gist Apr 25, 2025. 1 changed file with 1 addition and 1 deletion.

    2 changes: 1 addition & 1 deletion optimize_ppo.py

    @@ -41,7 +41,7 @@ def sample_ppo_params(trial: optuna.Trial) -> dict[str, Any]:
         """Sampler for PPO hyperparameters."""
         # From 2**5=32 to 2**12=4096
         n_steps_pow = trial.suggest_int("n_steps_pow", 5, 12)
    -    gamma = trial.suggest_float("one_minus_gamma", 0.97, 0.9999)
    +    gamma = trial.suggest_float("gamma", 0.97, 0.9999)
         learning_rate = trial.suggest_float("learning_rate", 3e-5, 3e-3, log=True)
         activation_fn = trial.suggest_categorical("activation_fn", ["tanh", "relu"])
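
    The rename makes the recorded parameter name match what is actually sampled: the value in [0.97, 0.9999] is used directly as the discount factor. The old name hints at the common alternative of searching the complement 1 - gamma on a log scale, so that trials spread evenly across gamma = 0.97, 0.99, 0.999, ... A minimal sketch of that variant, assuming the same target range (illustration only, not part of the gist):

        # Hypothetical alternative: sample 1 - gamma log-uniformly.
        one_minus_gamma = trial.suggest_float("one_minus_gamma", 1e-4, 0.03, log=True)
        gamma = 1.0 - one_minus_gamma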

  2. araffin created this gist Apr 25, 2025.

    168 changes: 168 additions & 0 deletions optimize_ppo.py

    @@ -0,0 +1,168 @@
    """Optuna example that optimizes the hyperparameters of
    a reinforcement learning agent using PPO implementation from Stable-Baselines3
    on a Gymnasium environment.
    This is a simplified version of what can be found in https://github.com/DLR-RM/rl-baselines3-zoo.
    You can run this example as follows:
    $ python optimize_ppo.py
    """

    from typing import Any

    import gymnasium
    import optuna
    from optuna.pruners import MedianPruner
    from optuna.samplers import TPESampler
    from stable_baselines3 import PPO
    from stable_baselines3.common.callbacks import EvalCallback
    from stable_baselines3.common.env_util import make_vec_env
    import torch
    import torch.nn as nn


    N_TRIALS = 500
    N_STARTUP_TRIALS = 10
    N_EVALUATIONS = 2
    N_TIMESTEPS = 40_000
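    # With N_TIMESTEPS = 40_000 and N_EVALUATIONS = 2, evaluate every 20_000 steps.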
    EVAL_FREQ = int(N_TIMESTEPS / N_EVALUATIONS)
    N_EVAL_EPISODES = 10

    ENV_ID = "Pendulum-v1"
    N_ENVS = 5

    DEFAULT_HYPERPARAMS = {
        "policy": "MlpPolicy",
    }


    def sample_ppo_params(trial: optuna.Trial) -> dict[str, Any]:
        """Sampler for PPO hyperparameters."""
        # From 2**5=32 to 2**12=4096
        n_steps_pow = trial.suggest_int("n_steps_pow", 5, 12)
        gamma = trial.suggest_float("one_minus_gamma", 0.97, 0.9999)
        learning_rate = trial.suggest_float("learning_rate", 3e-5, 3e-3, log=True)
        activation_fn = trial.suggest_categorical("activation_fn", ["tanh", "relu"])

        n_steps = 2**n_steps_pow
        # Display true values
        trial.set_user_attr("n_steps", n_steps)
        # Convert to PyTorch objects
        activation_fn = {"tanh": nn.Tanh, "relu": nn.ReLU}[activation_fn]

        return {
            "n_steps": n_steps,
            "gamma": gamma,
            "learning_rate": learning_rate,
            "policy_kwargs": {
                "activation_fn": activation_fn,
            },
        }


    class TrialEvalCallback(EvalCallback):
        """Callback used for evaluating and reporting a trial."""

        def __init__(
            self,
            eval_env: gymnasium.Env,
            trial: optuna.Trial,
            n_eval_episodes: int = 5,
            eval_freq: int = 10000,
            deterministic: bool = True,
            verbose: int = 0,
        ):
            super().__init__(
                eval_env=eval_env,
                n_eval_episodes=n_eval_episodes,
                eval_freq=eval_freq,
                deterministic=deterministic,
                verbose=verbose,
            )
            self.trial = trial
            self.eval_idx = 0
            self.is_pruned = False

        def _on_step(self) -> bool:
            if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
                super()._on_step()
                self.eval_idx += 1
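                # Report the intermediate score so the pruner can act on it.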
                self.trial.report(self.last_mean_reward, self.eval_idx)
                # Prune trial if needed.
                if self.trial.should_prune():
                    self.is_pruned = True
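                    # Returning False stops training early.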
                    return False
            return True


    def objective(trial: optuna.Trial) -> float:
        vec_env = make_vec_env(ENV_ID, n_envs=N_ENVS)
        kwargs = DEFAULT_HYPERPARAMS.copy()
        # Sample hyperparameters.
        kwargs.update(sample_ppo_params(trial))
        # Create the RL model.
        model = PPO(env=vec_env, **kwargs)
        # Create env used for evaluation.
        eval_env = make_vec_env(ENV_ID, n_envs=N_ENVS)
        # Create the callback that will periodically evaluate and report the performance.
        eval_callback = TrialEvalCallback(
            eval_env,
            trial,
            n_eval_episodes=N_EVAL_EPISODES,
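            # EvalCallback counts calls, one per VecEnv step (= N_ENVS timesteps),
            # so divide by N_ENVS to evaluate every ~EVAL_FREQ total timesteps.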
            eval_freq=max(EVAL_FREQ // N_ENVS, 1),
            deterministic=True,
        )

        nan_encountered = False
        try:
            model.learn(N_TIMESTEPS, callback=eval_callback)
        except AssertionError as e:
            # Sometimes, random hyperparams can generate NaN.
            print(e)
            nan_encountered = True
        finally:
            # Free memory.
            model.env.close()
            eval_env.close()

        # Tell the optimizer that the trial failed.
        if nan_encountered:
            return float("nan")

        if eval_callback.is_pruned:
            raise optuna.exceptions.TrialPruned()

        return eval_callback.last_mean_reward


    if __name__ == "__main__":
        # Set pytorch num threads to 1 for faster training.
        torch.set_num_threads(1)

        sampler = TPESampler(n_startup_trials=N_STARTUP_TRIALS, multivariate=True)
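        # multivariate=True lets TPE model correlations between hyperparameters.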
        # Do not prune before 1/3 of the max budget is used
        # (with N_EVALUATIONS = 2, N_EVALUATIONS // 3 == 0, so pruning can start from the first report).
        pruner = MedianPruner(
            n_startup_trials=N_STARTUP_TRIALS, n_warmup_steps=N_EVALUATIONS // 3
        )

        study = optuna.create_study(sampler=sampler, pruner=pruner, direction="maximize")
        try:
            study.optimize(objective, n_trials=N_TRIALS, timeout=600)
        except KeyboardInterrupt:
            pass

        print(f"Number of finished trials: {len(study.trials)}")

        print("Best trial:")
        trial = study.best_trial

        print("  Value: ", trial.value)

        print("  Params: ")
        for key, value in trial.params.items():
            print(f"    {key}: {value}")

        print("  User attrs:")
        for key, value in trial.user_attrs.items():
            print(f"    {key}: {value}")