REINFORCE with exponential moving average baseline
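How it works, in brief: register_action multiplies into the loss a factor exp(logprob - logprob.detach()) that equals exactly 1 in the forward pass but carries the gradient of logprob, so calling backward on the product yields the score-function (REINFORCE) gradient estimator, (loss - baseline) * grad of log pi(actions). The baseline is a bias-corrected exponential moving average of past losses, using the same correction as Adam: m_t = beta * m_{t-1} + (1 - beta) * L_t with baseline = m_t / (1 - beta^t), tracked by the loss_mean_biased and beta_cumprod buffers below.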
| """REINFORCE with exponential moving average baseline.""" | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from typing import Optional | |
| class Reinforce(nn.Module): | |
| """REINFORCE with exponential moving average baseline. | |
| Args: | |
| use_baseline (bool): Subtract baseline from loss. Defaults to True. | |
| beta (float): Exponential moving average decay factor for baseline. | |
| """ | |
| def __init__(self, use_baseline: bool = True, beta: float = 0.99): | |
| super().__init__() | |
| self.use_baseline = use_baseline | |
| self.beta = beta | |
| self.register_buffer("beta_cumprod", torch.tensor(1.0)) | |
| self.register_buffer("loss_mean_biased", torch.tensor(0.0)) | |
| self.grad_paths = [] | |
| def register_action(self, logprobs: torch.Tensor) -> torch.Tensor: | |
| """Register logprobs of actions to attach their grad path before the backward pass. | |
| Args: | |
| logprobs (torch.Tensor): Logprobs of actions. | |
| Returns: | |
| torch.Tensor: Grad path. | |
| """ | |
| logprob = logprobs.sum() | |
| grad_path = torch.exp(logprob - logprob.detach()) | |
| self.grad_paths.append(grad_path) | |
| return grad_path | |
| def prepare_loss(self, loss: torch.Tensor) -> torch.Tensor: | |
| """Prepare loss for backward pass, subtracting baseline and attaching grad paths. | |
| Args: | |
| loss (torch.Tensor): Loss to prepare. | |
| Returns: | |
| torch.Tensor: Prepared loss. | |
| """ | |
| with torch.no_grad(): | |
| self.beta_cumprod.mul_(self.beta) | |
| self.loss_mean_biased.mul_(self.beta).add_(loss, alpha=1 - self.beta) | |
| loss_mean = self.loss_mean_biased / (1 - self.beta_cumprod) | |
| if self.use_baseline: | |
| loss = loss - loss_mean | |
| for grad_path in self.grad_paths: | |
| loss = loss * grad_path | |
| self.grad_paths = [] | |
| return loss | |
| def sample_categorical( | |
| self, logits: torch.Tensor, actions: Optional[torch.Tensor] = None | |
| ) -> torch.Tensor: | |
| """Sample from categorical distribution and register the actions' grad paths for the | |
| backward pass. If actions are provided, register them instead of sampling. | |
| Args: | |
| logits (torch.Tensor): Unnormalized logits of categorical distribution. | |
| actions (torch.Tensor, optional): Actions that were taken. Defaults to None. | |
| Returns: | |
| torch.Tensor: Actions that were taken. | |
| """ | |
| if actions is None: | |
| g = torch.rand_like(logits).log_().neg_().log_().neg_() | |
| actions = torch.argmax(logits + g, dim=-1) | |
| logprobs = F.log_softmax(logits, dim=-1).gather(-1, actions[..., None]) | |
| self.register_action(logprobs) | |
| return actions | |
| def backward( | |
| self, loss: torch.Tensor, retain_graph: Optional[bool] = None, create_graph: bool = False | |
| ) -> torch.Tensor: | |
| """Attach grad paths to loss and call backward. | |
| Args: | |
| loss (torch.Tensor): Loss to prepare. | |
| retain_graph (bool, optional): Retain graph for backward pass. Defaults to None. | |
| create_graph (bool): Create graph for backward pass. Defaults to False. | |
| Returns: | |
| torch.Tensor: Prepared loss after backward. | |
| """ | |
| loss = self.prepare_loss(loss) | |
| loss.backward(retain_graph=retain_graph, create_graph=create_graph) | |
| return loss |
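
A minimal usage sketch, not part of the gist itself: the setup (a 4-arm bandit, arm 2 as the best arm, the learning rate) is an illustrative assumption. The loss is non-differentiable with respect to the logits; the gradient reaches them only through the registered grad path.

import torch
from torch import optim

reinforce = Reinforce(use_baseline=True, beta=0.99)
logits = torch.zeros(4, requires_grad=True)  # policy parameters for a 4-arm bandit
opt = optim.Adam([logits], lr=0.1)

for step in range(500):
    opt.zero_grad()
    action = reinforce.sample_categorical(logits)  # samples and registers the grad path
    loss = (action != 2).float()  # non-differentiable loss: arm 2 is the best arm
    reinforce.backward(loss)  # subtracts the EMA baseline, then backpropagates
    opt.step()

print(logits.softmax(dim=-1))  # most of the probability mass should sit on arm 2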