Skip to content

Instantly share code, notes, and snippets.

@crowsonkb
Last active June 30, 2023 19:12
Show Gist options
  • Select an option

  • Save crowsonkb/7a9d6a852e47b4f8026947a08a47774c to your computer and use it in GitHub Desktop.

Select an option

Save crowsonkb/7a9d6a852e47b4f8026947a08a47774c to your computer and use it in GitHub Desktop.
REINFORCE with exponential moving average baseline
"""REINFORCE (DiCE) with exponential moving average baseline. Implements "DiCE: The Infinitely
Differentiable Monte Carlo Estimator (https://arxiv.org/abs/1802.05098)."""
import torch
from torch import nn
from torch.nn import functional as F
from typing import Optional, Union
class Reinforce(nn.Module):
    """REINFORCE (DiCE) with exponential moving average baseline. Implements "DiCE: The
    Infinitely Differentiable Monte Carlo Estimator" (https://arxiv.org/abs/1802.05098).

    Log-probabilities of the actions taken are collected with :meth:`register_action` (usually
    via :meth:`sample_categorical`) and attached to the loss by :meth:`prepare_loss` through the
    MagicBox operator. The prepared loss has the same value as the original loss, but its
    gradient additionally contains the score-function term
    ``(loss - baseline) * grad(sum of registered logprobs)``.

    Args:
        use_baseline (bool): Subtract baseline from loss. Defaults to True.
        beta (float): Exponential moving average decay factor for baseline. Defaults to 0.99.

    Example:
        >>> opt = optim.Adam(model.parameters(), lr=1e-3)
        >>> estimator = Reinforce().to(device)

        In your training loop:

        >>> opt.zero_grad()
        >>> actions = estimator.sample_categorical(logits)

        Then, after you have computed your loss:

        >>> estimator.backward(loss)
        >>> opt.step()
    """

    def __init__(self, use_baseline: bool = True, beta: float = 0.99):
        super().__init__()
        self.use_baseline = use_baseline
        self.beta = beta
        # beta ** (number of losses recorded so far); used to bias-correct the EMA below.
        self.register_buffer("beta_cumprod", torch.tensor(1.0))
        # Zero-initialized (hence biased) EMA of the scalar losses recorded so far.
        self.register_buffer("loss_mean_biased", torch.tensor(0.0))
        # Summed logprobs registered since the last prepare_loss() call.
        self.logprobs = []

    @staticmethod
    def magic_box(w: torch.Tensor) -> torch.Tensor:
        """MagicBox operator (see https://arxiv.org/abs/1802.05098).

        Evaluates to 1 in the forward pass while its gradient is ``grad(w)`` (times 1).

        Args:
            w (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: The result of the MagicBox operator.
        """
        return torch.exp(w - w.detach())

    def register_action(
        self, logprobs: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> None:
        """Register logprobs of actions to attach their grad path before the backward pass.

        The (optionally masked) logprobs are summed to a scalar and stored until the next
        :meth:`prepare_loss` call.

        Args:
            logprobs (torch.Tensor): Logprobs of actions.
            mask (torch.Tensor, optional): Multiplied elementwise into ``logprobs`` (1/True keeps
                a position's grad path, 0/False cuts it). Should have the same shape as
                ``logprobs`` (or broadcast to it without enlarging it). Defaults to None.
        """
        if mask is not None:
            logprobs = logprobs * mask
        self.logprobs.append(logprobs.sum())

    def prepare_loss(
        self, loss: torch.Tensor, baseline: Optional[Union[float, torch.Tensor]] = None
    ) -> torch.Tensor:
        """Prepare loss for backward pass, subtracting the baseline and attaching grad paths.

        Consumes (and clears) all logprobs registered since the previous call and records
        ``loss`` in the exponential moving average used as the default baseline.

        The default baseline is the bias-corrected moving average of the *previous* losses only.
        Folding the current loss in first would correlate the baseline with the sampled actions
        and bias the estimator; on the very first call it would make the baseline equal to the
        loss and the REINFORCE term of the gradient vanish. When no loss has been recorded yet,
        a baseline of 0 is used.

        Args:
            loss (torch.Tensor): Scalar loss to prepare.
            baseline (Optional[Union[float, torch.Tensor]], optional): Custom baseline to subtract
                instead of the moving average. Defaults to None.

        Returns:
            torch.Tensor: Prepared loss. Its value equals ``loss``; its gradient is
            ``grad(loss) + (loss - baseline) * grad(sum of registered logprobs)``.
        """
        with torch.no_grad():
            # Bias-corrected EMA of earlier losses. beta_cumprod == 1 means nothing has been
            # recorded yet (the division would be 0 / 0), so fall back to 0 in that case.
            has_history = self.beta_cumprod < 1
            loss_mean = torch.where(
                has_history,
                self.loss_mean_biased / (1 - self.beta_cumprod),
                torch.zeros_like(self.loss_mean_biased),
            )
            if baseline is None:
                baseline = loss_mean if self.use_baseline else 0.0
            # Only now fold the current loss into the EMA, for use by later calls.
            self.beta_cumprod.mul_(self.beta)
            self.loss_mean_biased.mul_(self.beta).add_(loss, alpha=1 - self.beta)
        logprob = sum(self.logprobs, loss.new_tensor(0.0))
        self.logprobs.clear()
        # box is 1 in the forward pass, so the value is unchanged; its gradient is grad(logprob),
        # so loss * box contributes loss * grad(logprob) and (1 - box) * baseline contributes
        # -baseline * grad(logprob).
        box = self.magic_box(logprob)
        return loss * box + (1 - box) * baseline

    def sample_categorical(
        self,
        logits: torch.Tensor,
        actions: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Sample from categorical distribution and register the actions' grad paths for the
        backward pass. If actions are provided, register them instead of sampling.

        Args:
            logits (torch.Tensor): Unnormalized logits of categorical distribution (categories
                along the last dimension).
            actions (torch.Tensor, optional): Actions that were taken, shaped like ``logits``
                without its last dimension. Defaults to None.
            mask (torch.Tensor, optional): Mask for tokens, shaped like ``actions``.
                Defaults to None.

        Returns:
            torch.Tensor: Actions that were taken.

        Example:
            If you have sampled tokens from a HuggingFace model, you can use this method to
            register the grad paths of the sampled tokens. You need to obtain logits from the
            model that have a grad_fn:

            >>> logits = model(tokens).logits
            >>> estimator.sample_categorical(logits[:, prompt_len - 1 : -1], tokens[:, prompt_len:])

            Notice how the tokens are shifted one position right from the logits they were sampled
            from and the prompt tokens aren't included. If you cannot exclude your prompt or
            padding tokens with simple slicing, you can provide a mask (1/True for token positions
            that grads should propagate through, 0/False to stop gradients).
        """
        if actions is None:
            # Gumbel-max trick: argmax(logits + g) with g = -log(-log(u)), u ~ U[0, 1), is a
            # sample from softmax(logits). No gradient is needed through the argmax.
            g = torch.rand_like(logits).log_().neg_().log_().neg_()
            actions = torch.argmax(logits.detach() + g, dim=-1)
        # gather() needs an int64 index with the same rank as logits; squeeze the trailing
        # singleton afterwards so the logprobs have the same shape as actions (and the mask).
        # Without the squeeze, a mask shaped like actions would broadcast the product to a
        # larger tensor and corrupt the summed logprob.
        index = actions.long().unsqueeze(-1)
        logprobs = F.log_softmax(logits, dim=-1).gather(-1, index).squeeze(-1)
        self.register_action(logprobs, mask)
        return actions

    def backward(
        self,
        loss: torch.Tensor,
        retain_graph: Optional[bool] = None,
        create_graph: bool = False,
        baseline: Optional[Union[float, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Prepare the loss with :meth:`prepare_loss` and perform the backward pass.

        Args:
            loss (torch.Tensor): Scalar loss to prepare.
            retain_graph (bool, optional): Retain graph for backward pass. Defaults to None.
            create_graph (bool): Create graph for backward pass. Defaults to False.
            baseline (Optional[Union[float, torch.Tensor]], optional): Custom baseline to subtract.
                Defaults to None.

        Returns:
            torch.Tensor: Prepared loss after backward.
        """
        loss = self.prepare_loss(loss, baseline)
        loss.backward(retain_graph=retain_graph, create_graph=create_graph)
        return loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment