Skip to content

Instantly share code, notes, and snippets.

@crowsonkb
Last active June 30, 2023 19:12
Show Gist options
  • Select an option

  • Save crowsonkb/7a9d6a852e47b4f8026947a08a47774c to your computer and use it in GitHub Desktop.

Select an option

Save crowsonkb/7a9d6a852e47b4f8026947a08a47774c to your computer and use it in GitHub Desktop.

Revisions

  1. crowsonkb revised this gist Jun 30, 2023. 1 changed file with 4 additions and 4 deletions.
    8 changes: 4 additions & 4 deletions reinforce.py
    Original file line number Diff line number Diff line change
    @@ -75,10 +75,10 @@ def prepare_losses(
    Returns:
    torch.Tensor: Prepared loss.
    """
    with torch.no_grad():
    self.beta_cumprod.mul_(self.beta)
    self.loss_mean_biased.mul_(self.beta).add_(losses.mean(), alpha=1 - self.beta)
    loss_mean = self.loss_mean_biased / (1 - self.beta_cumprod)
    loss_mean = self.loss_mean_biased / (1 - self.beta_cumprod)
    loss_mean.nan_to_num_()
    self.beta_cumprod.mul_(self.beta)
    self.loss_mean_biased.mul_(self.beta).add_(losses.detach().mean(), alpha=1 - self.beta)

    if baseline is not None:
    pass
  2. crowsonkb revised this gist Jun 30, 2023. 1 changed file with 15 additions and 35 deletions.
    50 changes: 15 additions & 35 deletions reinforce.py
    Original file line number Diff line number Diff line change
    @@ -13,7 +13,7 @@ class Reinforce(nn.Module):
    Differentiable Monte Carlo Estimator (https://arxiv.org/abs/1802.05098).
    Args:
    use_baseline (bool): Subtract baseline from loss. Defaults to True.
    use_baseline (bool): Subtract baseline from losses. Defaults to True.
    beta (float): Exponential moving average decay factor for baseline.
    Example:
    @@ -25,9 +25,10 @@ class Reinforce(nn.Module):
    >>> opt.zero_grad()
    >>> actions = estimator.sample_categorical(logits)
    Then, after you have computed your loss:
    Then, after you have computed a batch of losses:
    >>> estimator.backward(loss)
    >>> loss = estimator.prepare_losses(losses)
    >>> loss.backward()
    >>> opt.step()
    """

    @@ -60,24 +61,23 @@ def register_actions(self, logprobs: torch.Tensor, mask: Optional[torch.Tensor])
    """
    if mask is not None:
    logprobs = logprobs * mask
    logprob = logprobs.sum()
    self.logprobs.append(logprob)
    self.logprobs.append(logprobs)

    def prepare_loss(
    self, loss: torch.Tensor, baseline: Optional[Union[float, torch.Tensor]] = None
    def prepare_losses(
    self, losses: torch.Tensor, baseline: Optional[Union[float, torch.Tensor]] = None
    ) -> torch.Tensor:
    """Prepare loss for backward pass, subtracting the baseline and attaching grad paths.
    """Prepare a batch of losses for the backward pass.
    Args:
    loss (torch.Tensor): Loss to prepare.
    losses (torch.Tensor): Batch of losses to prepare.
    baseline (Optional[Union[float, torch.Tensor]], optional): Custom baseline to subtract.
    Returns:
    torch.Tensor: Prepared loss.
    """
    with torch.no_grad():
    self.beta_cumprod.mul_(self.beta)
    self.loss_mean_biased.mul_(self.beta).add_(loss, alpha=1 - self.beta)
    self.loss_mean_biased.mul_(self.beta).add_(losses.mean(), alpha=1 - self.beta)
    loss_mean = self.loss_mean_biased / (1 - self.beta_cumprod)

    if baseline is not None:
    @@ -87,17 +87,19 @@ def prepare_loss(
    else:
    baseline = 0.0

    logprob = sum(self.logprobs, loss.new_tensor(0.0))
    logprobs = [logprobs.flatten(losses.ndim).sum(losses.ndim) for logprobs in self.logprobs]
    logprobs = sum(logprobs, torch.zeros_like(losses))
    self.logprobs.clear()
    return loss * self.magic_box(logprob) + (1 - self.magic_box(logprob)) * baseline
    surrogates = losses * self.magic_box(logprobs) + (1 - self.magic_box(logprobs)) * baseline
    return surrogates.mean()

    def sample_categorical(
    self,
    logits: torch.Tensor,
    actions: Optional[torch.Tensor] = None,
    mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
    """Sample from categorical distribution and register the actions' grad paths for the
    """Sample from a categorical distribution and register the actions' grad paths for the
    backward pass. If actions are provided, register them instead of sampling.
    Args:
    @@ -127,25 +129,3 @@ def sample_categorical(
    logprobs = F.log_softmax(logits, dim=-1).gather(-1, actions[..., None])
    self.register_actions(logprobs, mask)
    return actions

    def backward(
    self,
    loss: torch.Tensor,
    retain_graph: Optional[bool] = None,
    create_graph: bool = False,
    baseline: Optional[Union[float, torch.Tensor]] = None,
    ) -> torch.Tensor:
    """Prepare the loss and perform the backward pass.
    Args:
    loss (torch.Tensor): Loss to prepare.
    retain_graph (bool, optional): Retain graph for backward pass. Defaults to None.
    create_graph (bool): Create graph for backward pass. Defaults to False.
    baseline (Optional[Union[float, torch.Tensor]], optional): Custom baseline to subtract.
    Returns:
    torch.Tensor: Prepared loss after backward.
    """
    loss = self.prepare_loss(loss, baseline)
    loss.backward(retain_graph=retain_graph, create_graph=create_graph)
    return loss
  3. crowsonkb revised this gist Jun 30, 2023. 1 changed file with 2 additions and 2 deletions.
    4 changes: 2 additions & 2 deletions reinforce.py
    Original file line number Diff line number Diff line change
    @@ -51,7 +51,7 @@ def magic_box(w: torch.Tensor) -> torch.Tensor:
    """
    return torch.exp(w - w.detach())

    def register_action(self, logprobs: torch.Tensor, mask: Optional[torch.Tensor]) -> None:
    def register_actions(self, logprobs: torch.Tensor, mask: Optional[torch.Tensor]) -> None:
    """Register logprobs of actions to attach their grad path before the backward pass.
    Args:
    @@ -125,7 +125,7 @@ def sample_categorical(
    g = torch.rand_like(logits).log_().neg_().log_().neg_()
    actions = torch.argmax(logits + g, dim=-1)
    logprobs = F.log_softmax(logits, dim=-1).gather(-1, actions[..., None])
    self.register_action(logprobs, mask)
    self.register_actions(logprobs, mask)
    return actions

    def backward(
  4. crowsonkb revised this gist Jun 29, 2023. 1 changed file with 22 additions and 15 deletions.
    37 changes: 22 additions & 15 deletions reinforce.py
    Original file line number Diff line number Diff line change
    @@ -1,4 +1,5 @@
    """REINFORCE with exponential moving average baseline."""
    """REINFORCE (DiCE) with exponential moving average baseline. Implements "DiCE: The Infinitely
    Differentiable Monte Carlo Estimator (https://arxiv.org/abs/1802.05098)."""

    import torch
    from torch import nn
    @@ -8,7 +9,8 @@


    class Reinforce(nn.Module):
    """REINFORCE with exponential moving average baseline.
    """REINFORCE (DiCE) with exponential moving average baseline. Implements "DiCE: The Infinitely
    Differentiable Monte Carlo Estimator (https://arxiv.org/abs/1802.05098).
    Args:
    use_baseline (bool): Subtract baseline from loss. Defaults to True.
    @@ -35,24 +37,31 @@ def __init__(self, use_baseline: bool = True, beta: float = 0.99):
    self.beta = beta
    self.register_buffer("beta_cumprod", torch.tensor(1.0))
    self.register_buffer("loss_mean_biased", torch.tensor(0.0))
    self.grad_paths = []
    self.logprobs = []

    def register_action(self, logprobs: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
    @staticmethod
    def magic_box(w: torch.Tensor) -> torch.Tensor:
    """MagicBox operator (see https://arxiv.org/abs/1802.05098).
    Args:
    w (torch.Tensor): Input tensor.
    Returns:
    torch.Tensor: The result of the MagicBox operator.
    """
    return torch.exp(w - w.detach())

    def register_action(self, logprobs: torch.Tensor, mask: Optional[torch.Tensor]) -> None:
    """Register logprobs of actions to attach their grad path before the backward pass.
    Args:
    logprobs (torch.Tensor): Logprobs of actions.
    mask (torch.Tensor, optional): Mask for actions. Defaults to None.
    Returns:
    torch.Tensor: Grad path.
    """
    if mask is not None:
    logprobs = logprobs * mask
    logprob = logprobs.sum()
    grad_path = torch.exp(logprob - logprob.detach())
    self.grad_paths.append(grad_path)
    return grad_path
    self.logprobs.append(logprob)

    def prepare_loss(
    self, loss: torch.Tensor, baseline: Optional[Union[float, torch.Tensor]] = None
    @@ -78,11 +87,9 @@ def prepare_loss(
    else:
    baseline = 0.0

    loss = loss - baseline
    for grad_path in self.grad_paths:
    loss = loss * grad_path
    self.grad_paths.clear()
    return loss + baseline
    logprob = sum(self.logprobs, loss.new_tensor(0.0))
    self.logprobs.clear()
    return loss * self.magic_box(logprob) + (1 - self.magic_box(logprob)) * baseline

    def sample_categorical(
    self,
  5. crowsonkb revised this gist Jun 29, 2023. 1 changed file with 22 additions and 7 deletions.
    29 changes: 22 additions & 7 deletions reinforce.py
    Original file line number Diff line number Diff line change
    @@ -4,7 +4,7 @@
    from torch import nn
    from torch.nn import functional as F

    from typing import Optional
    from typing import Optional, Union


    class Reinforce(nn.Module):
    @@ -54,11 +54,14 @@ def register_action(self, logprobs: torch.Tensor, mask: Optional[torch.Tensor])
    self.grad_paths.append(grad_path)
    return grad_path

    def prepare_loss(self, loss: torch.Tensor) -> torch.Tensor:
    def prepare_loss(
    self, loss: torch.Tensor, baseline: Optional[Union[float, torch.Tensor]] = None
    ) -> torch.Tensor:
    """Prepare loss for backward pass, subtracting the baseline and attaching grad paths.
    Args:
    loss (torch.Tensor): Loss to prepare.
    baseline (Optional[Union[float, torch.Tensor]], optional): Custom baseline to subtract.
    Returns:
    torch.Tensor: Prepared loss.
    @@ -67,12 +70,19 @@ def prepare_loss(self, loss: torch.Tensor) -> torch.Tensor:
    self.beta_cumprod.mul_(self.beta)
    self.loss_mean_biased.mul_(self.beta).add_(loss, alpha=1 - self.beta)
    loss_mean = self.loss_mean_biased / (1 - self.beta_cumprod)
    if self.use_baseline:
    loss = loss - loss_mean

    if baseline is not None:
    pass
    elif self.use_baseline:
    baseline = loss_mean
    else:
    baseline = 0.0

    loss = loss - baseline
    for grad_path in self.grad_paths:
    loss = loss * grad_path
    self.grad_paths.clear()
    return loss
    return loss + baseline

    def sample_categorical(
    self,
    @@ -112,18 +122,23 @@ def sample_categorical(
    return actions

    def backward(
    self, loss: torch.Tensor, retain_graph: Optional[bool] = None, create_graph: bool = False
    self,
    loss: torch.Tensor,
    retain_graph: Optional[bool] = None,
    create_graph: bool = False,
    baseline: Optional[Union[float, torch.Tensor]] = None,
    ) -> torch.Tensor:
    """Prepare the loss and perform the backward pass.
    Args:
    loss (torch.Tensor): Loss to prepare.
    retain_graph (bool, optional): Retain graph for backward pass. Defaults to None.
    create_graph (bool): Create graph for backward pass. Defaults to False.
    baseline (Optional[Union[float, torch.Tensor]], optional): Custom baseline to subtract.
    Returns:
    torch.Tensor: Prepared loss after backward.
    """
    loss = self.prepare_loss(loss)
    loss = self.prepare_loss(loss, baseline)
    loss.backward(retain_graph=retain_graph, create_graph=create_graph)
    return loss
  6. crowsonkb revised this gist Jun 29, 2023. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion reinforce.py
    Original file line number Diff line number Diff line change
    @@ -71,7 +71,7 @@ def prepare_loss(self, loss: torch.Tensor) -> torch.Tensor:
    loss = loss - loss_mean
    for grad_path in self.grad_paths:
    loss = loss * grad_path
    self.grad_paths = []
    self.grad_paths.clear()
    return loss

    def sample_categorical(
  7. crowsonkb revised this gist Jun 29, 2023. 1 changed file with 5 additions and 4 deletions.
    9 changes: 5 additions & 4 deletions reinforce.py
    Original file line number Diff line number Diff line change
    @@ -37,15 +37,18 @@ def __init__(self, use_baseline: bool = True, beta: float = 0.99):
    self.register_buffer("loss_mean_biased", torch.tensor(0.0))
    self.grad_paths = []

    def register_action(self, logprobs: torch.Tensor) -> torch.Tensor:
    def register_action(self, logprobs: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
    """Register logprobs of actions to attach their grad path before the backward pass.
    Args:
    logprobs (torch.Tensor): Logprobs of actions.
    mask (torch.Tensor, optional): Mask for actions. Defaults to None.
    Returns:
    torch.Tensor: Grad path.
    """
    if mask is not None:
    logprobs = logprobs * mask
    logprob = logprobs.sum()
    grad_path = torch.exp(logprob - logprob.detach())
    self.grad_paths.append(grad_path)
    @@ -105,9 +108,7 @@ def sample_categorical(
    g = torch.rand_like(logits).log_().neg_().log_().neg_()
    actions = torch.argmax(logits + g, dim=-1)
    logprobs = F.log_softmax(logits, dim=-1).gather(-1, actions[..., None])
    if mask is not None:
    logprobs = logprobs * mask
    self.register_action(logprobs)
    self.register_action(logprobs, mask)
    return actions

    def backward(
  8. crowsonkb revised this gist Jun 29, 2023. 1 changed file with 10 additions and 2 deletions.
    12 changes: 10 additions & 2 deletions reinforce.py
    Original file line number Diff line number Diff line change
    @@ -72,14 +72,18 @@ def prepare_loss(self, loss: torch.Tensor) -> torch.Tensor:
    return loss

    def sample_categorical(
    self, logits: torch.Tensor, actions: Optional[torch.Tensor] = None
    self,
    logits: torch.Tensor,
    actions: Optional[torch.Tensor] = None,
    mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
    """Sample from categorical distribution and register the actions' grad paths for the
    backward pass. If actions are provided, register them instead of sampling.
    Args:
    logits (torch.Tensor): Unnormalized logits of categorical distribution.
    actions (torch.Tensor, optional): Actions that were taken. Defaults to None.
    mask (torch.Tensor, optional): Mask for tokens. Defaults to None.
    Returns:
    torch.Tensor: Actions that were taken.
    @@ -93,12 +97,16 @@ def sample_categorical(
    >>> estimator.sample_categorical(logits[:, prompt_len - 1 : -1], tokens[:, prompt_len:])
    Notice how the tokens are shifted one position right from the logits they were sampled
    from and the prompt tokens aren't included.
    from and the prompt tokens aren't included. If you cannot exclude your prompt or
    padding tokens with simple slicing, you can provide a mask (1/True for token positions
    that grads should propagate through, 0/False to stop gradients).
    """
    if actions is None:
    g = torch.rand_like(logits).log_().neg_().log_().neg_()
    actions = torch.argmax(logits + g, dim=-1)
    logprobs = F.log_softmax(logits, dim=-1).gather(-1, actions[..., None])
    if mask is not None:
    logprobs = logprobs * mask
    self.register_action(logprobs)
    return actions

  9. crowsonkb revised this gist Jun 29, 2023. 1 changed file with 27 additions and 2 deletions.
    29 changes: 27 additions & 2 deletions reinforce.py
    Original file line number Diff line number Diff line change
    @@ -13,6 +13,20 @@ class Reinforce(nn.Module):
    Args:
    use_baseline (bool): Subtract baseline from loss. Defaults to True.
    beta (float): Exponential moving average decay factor for baseline.
    Example:
    >>> opt = optim.Adam(model.parameters(), lr=1e-3)
    >>> estimator = Reinforce().to(device)
    In your training loop:
    >>> opt.zero_grad()
    >>> actions = estimator.sample_categorical(logits)
    Then, after you have computed your loss:
    >>> estimator.backward(loss)
    >>> opt.step()
    """

    def __init__(self, use_baseline: bool = True, beta: float = 0.99):
    @@ -38,7 +52,7 @@ def register_action(self, logprobs: torch.Tensor) -> torch.Tensor:
    return grad_path

    def prepare_loss(self, loss: torch.Tensor) -> torch.Tensor:
    """Prepare loss for backward pass, subtracting baseline and attaching grad paths.
    """Prepare loss for backward pass, subtracting the baseline and attaching grad paths.
    Args:
    loss (torch.Tensor): Loss to prepare.
    @@ -69,6 +83,17 @@ def sample_categorical(
    Returns:
    torch.Tensor: Actions that were taken.
    Example:
    If you have sampled tokens from a HuggingFace model, you can use this method to
    register the grad paths of the sampled tokens. You need to obtain logits from the
    model that have a grad_fn:
    >>> logits = model(tokens).logits
    >>> estimator.sample_categorical(logits[:, prompt_len - 1 : -1], tokens[:, prompt_len:])
    Notice how the tokens are shifted one position right from the logits they were sampled
    from and the prompt tokens aren't included.
    """
    if actions is None:
    g = torch.rand_like(logits).log_().neg_().log_().neg_()
    @@ -80,7 +105,7 @@ def sample_categorical(
    def backward(
    self, loss: torch.Tensor, retain_graph: Optional[bool] = None, create_graph: bool = False
    ) -> torch.Tensor:
    """Attach grad paths to loss and call backward.
    """Prepare the loss and perform the backward pass.
    Args:
    loss (torch.Tensor): Loss to prepare.
  10. crowsonkb created this gist Jun 29, 2023.
    95 changes: 95 additions & 0 deletions reinforce.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,95 @@
    """REINFORCE with exponential moving average baseline."""

    import torch
    from torch import nn
    from torch.nn import functional as F

    from typing import Optional


    class Reinforce(nn.Module):
        """REINFORCE with exponential moving average baseline.

        Log-probabilities of sampled actions are registered as "grad paths"
        (``exp(logprob - logprob.detach())``, which evaluates to 1 but carries the
        gradient of ``logprob``). ``prepare_loss`` multiplies the baseline-subtracted
        loss by every registered grad path so that calling ``backward`` on the result
        yields the score-function gradient.

        Args:
            use_baseline (bool): Subtract baseline from loss. Defaults to True.
            beta (float): Exponential moving average decay factor for baseline.
        """

        def __init__(self, use_baseline: bool = True, beta: float = 0.99):
            super().__init__()
            self.use_baseline = use_baseline
            self.beta = beta
            # beta ** (number of losses seen); used to debias the moving average.
            self.register_buffer("beta_cumprod", torch.tensor(1.0))
            # Zero-initialized (hence biased) exponential moving average of the loss.
            self.register_buffer("loss_mean_biased", torch.tensor(0.0))
            # Grad paths registered since the last call to prepare_loss().
            self.grad_paths = []

        def register_action(self, logprobs: torch.Tensor) -> torch.Tensor:
            """Register logprobs of actions to attach their grad path before the backward pass.

            Args:
                logprobs (torch.Tensor): Logprobs of actions.

            Returns:
                torch.Tensor: Grad path.
            """
            logprob = logprobs.sum()
            # Value is exactly 1; the gradient equals the gradient of `logprob`.
            grad_path = torch.exp(logprob - logprob.detach())
            self.grad_paths.append(grad_path)
            return grad_path

        def prepare_loss(self, loss: torch.Tensor) -> torch.Tensor:
            """Prepare loss for backward pass, subtracting baseline and attaching grad paths.

            The baseline is the bias-corrected moving average of the losses seen on
            *previous* calls. The current loss is folded into the average only after
            the baseline has been read, so the baseline stays independent of the
            sample it is subtracted from. On the first call no history exists and the
            baseline is 0.

            Args:
                loss (torch.Tensor): Loss to prepare. It is accumulated into a
                    zero-dimensional buffer, so it is expected to be a scalar.

            Returns:
                torch.Tensor: Prepared loss.
            """
            with torch.no_grad():
                # Read the baseline BEFORE updating the statistics. On the first call
                # this is 0 / 0 = nan, which nan_to_num_ maps to 0.
                loss_mean = self.loss_mean_biased / (1 - self.beta_cumprod)
                loss_mean.nan_to_num_()
                self.beta_cumprod.mul_(self.beta)
                self.loss_mean_biased.mul_(self.beta).add_(loss.detach(), alpha=1 - self.beta)
            if self.use_baseline:
                loss = loss - loss_mean
            # Consume the registered grad paths; start a fresh list for the next step.
            grad_paths, self.grad_paths = self.grad_paths, []
            for grad_path in grad_paths:
                loss = loss * grad_path
            return loss

        def sample_categorical(
            self, logits: torch.Tensor, actions: Optional[torch.Tensor] = None
        ) -> torch.Tensor:
            """Sample from categorical distribution and register the actions' grad paths for the
            backward pass. If actions are provided, register them instead of sampling.

            Args:
                logits (torch.Tensor): Unnormalized logits of categorical distribution.
                actions (torch.Tensor, optional): Actions that were taken. Defaults to None.

            Returns:
                torch.Tensor: Actions that were taken.
            """
            if actions is None:
                # Gumbel-max trick: argmax(logits + Gumbel noise) is a categorical sample.
                g = torch.rand_like(logits).log_().neg_().log_().neg_()
                actions = torch.argmax(logits + g, dim=-1)
            logprobs = F.log_softmax(logits, dim=-1).gather(-1, actions[..., None])
            self.register_action(logprobs)
            return actions

        def backward(
            self, loss: torch.Tensor, retain_graph: Optional[bool] = None, create_graph: bool = False
        ) -> torch.Tensor:
            """Attach grad paths to loss and call backward.

            Args:
                loss (torch.Tensor): Loss to prepare.
                retain_graph (bool, optional): Retain graph for backward pass. Defaults to None.
                create_graph (bool): Create graph for backward pass. Defaults to False.

            Returns:
                torch.Tensor: Prepared loss after backward.
            """
            loss = self.prepare_loss(loss)
            loss.backward(retain_graph=retain_graph, create_graph=create_graph)
            return loss