"""Plackett-Luce loss function for ranking tasks. The rankings may be partial and include ties."""
from itertools import chain
from typing import List, Optional, Tuple
import torch


def plackett_luce_loss(
    scores: torch.Tensor, numerator_mask: torch.Tensor, denominator_mask: torch.Tensor
) -> torch.Tensor:
    """Plackett-Luce loss function for ranking tasks.

    This function uses sum reduction for the loss value. Divide the sum by `len(rankings)`
    to get the mean loss value.

    If `beta * (logp - logp_ref)` is input as `scores`, where `logp` is the log probability
    of a completion given its prompt, and `logp_ref` is the log probability of the completion
    given its prompt under a reference model, then this function computes the Plackett-Luce DPO
    loss (Appendix A.3 of https://arxiv.org/abs/2305.18290).

    Args:
        scores (torch.Tensor): Model output. Shape (n_scores).
        numerator_mask (torch.Tensor): From `make_masks_for_plackett_luce_loss()`.
            Shape (n_groups, n_scores).
        denominator_mask (torch.Tensor): From `make_masks_for_plackett_luce_loss()`.
            Shape (n_groups, n_scores).

    Returns:
        torch.Tensor: Scalar loss value.
    """
    scores_n = torch.where(numerator_mask, scores, float("-inf"))
    scores_d = torch.where(denominator_mask, scores, float("-inf"))
    numerator = torch.logsumexp(scores_n, dim=-1)
    denominator = torch.logsumexp(scores_d, dim=-1)
    log_likelihood = torch.where(numerator == float("-inf"), 0.0, numerator - denominator)
    return -torch.sum(log_likelihood)
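

# Hedged sketch (not part of the original gist): how `scores` might be formed for the
# Plackett-Luce DPO loss mentioned in the docstring above. The names `logp`, `logp_ref`,
# and `beta` are assumptions for illustration, not defined elsewhere in this gist.
def dpo_scores(logp: torch.Tensor, logp_ref: torch.Tensor, beta: float = 0.1) -> torch.Tensor:
    """Form DPO-style scores from per-completion log probabilities.

    `logp` is the policy model's log probability of each completion given its prompt,
    `logp_ref` is the same quantity under a frozen reference model (both shape
    (n_scores,)), and `beta` is the DPO temperature. The result can be passed as the
    `scores` argument of `plackett_luce_loss()`.
    """
    return beta * (logp - logp_ref)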


def make_masks_for_plackett_luce_loss(
    rankings: List[List[List[int]]], n_scores: int, device: Optional[torch.device] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Make masks for the Plackett-Luce loss function.

    This function accepts a list of rankings in the following format:

    ```
    rankings = [
        [[1], [2], [0]],
        [[2], [0]],
        [[0, 2], [1]],
    ]
    ```

    The above example has 3 rankings. The first ranking prefers 1 to 2 to 0. The second ranking
    prefers 2 to 0 and is indifferent to 1. The third ranking prefers 0 and 2 to 1 and is
    indifferent between 0 and 2.

    Args:
        rankings (List[List[List[int]]]): List of rankings.
        n_scores (int): Number of scores.
        device (torch.device, optional): Device for the masks.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of masks.
    """
    n_groups = sum(len(ranking) for ranking in rankings)
    numerator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool)
    denominator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool)
    i = 0
    for ranking in rankings:
        remaining_items = set(chain.from_iterable(ranking))
        for group in ranking:
            items = set(group)
            for item in items:
                numerator_mask[i, item] = True
            for item in remaining_items:
                denominator_mask[i, item] = True
            remaining_items -= items
            i += 1
    return numerator_mask.to(device), denominator_mask.to(device)
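

# Illustration (not part of the original gist): for a single full ranking that prefers
# item 1 to item 2 to item 0, i.e. rankings = [[[1], [2], [0]]] with n_scores = 3, the
# masks have one row per group. Each numerator row marks the item chosen at that stage,
# and each denominator row marks the items still in contention (shown as 0/1 booleans):
#
#     numerator_mask   = [[0, 1, 0],    denominator_mask = [[1, 1, 1],
#                         [0, 0, 1],                        [1, 0, 1],
#                         [1, 0, 0]]                        [1, 0, 0]]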


def sample_plackett_luce(scores: torch.Tensor, shape: Tuple[int, ...] = ()) -> torch.Tensor:
    """Sample from a Plackett-Luce model.

    Args:
        scores (torch.Tensor): Model output. Shape (n_scores).
        shape (Tuple[int, ...], optional): Shape of the samples.

    Returns:
        torch.Tensor: Samples. Shape (*shape, n_scores).
    """
    # Gumbel-max trick: -log(Exponential(1)) is a standard Gumbel sample, and sorting
    # the scores plus i.i.d. Gumbel noise in descending order draws a ranking from the
    # Plackett-Luce model defined by the scores.
    gumbel = scores.new_empty(*shape, *scores.shape).exponential_().log_().neg_()
    _, indices = torch.sort(scores + gumbel, dim=-1, descending=True)
    return indices
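

# Usage note (illustration only, not part of the original gist): with uniform scores
# every ranking is equally likely, so sample_plackett_luce(torch.zeros(3), (4,)) returns
# a (4, 3) tensor of uniformly random permutations of [0, 1, 2].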


def main():
    """Run tests."""
    # Fit a model.
    rankings = [
        [[1], [2], [0]],
        [[2], [0, 1]],
        [[0], [1]],
    ]
    masks = make_masks_for_plackett_luce_loss(rankings, 3)
    scores = torch.zeros(3, requires_grad=True)
    opt = torch.optim.SGD([scores], lr=1)
    for i in range(20):
        loss = plackett_luce_loss(scores, *masks) / len(rankings)
        print(f"step: {i}, loss: {loss:.6f}")
        loss.backward()
        opt.step()
        opt.zero_grad()
    scores = scores.detach()

    # Sample from the model.
    samples = sample_plackett_luce(scores, (1_000_000,))

    # Check the probability of a specific ranking.
    rankings = [[[2], [0], [1]]]
    masks = make_masks_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *masks))
    cond = torch.all(samples == torch.tensor([2, 0, 1]), dim=-1)
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")

    # Check the probability of a ranking where we are indifferent to one value.
    rankings = [[[2], [0]]]
    masks = make_masks_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *masks))
    cond = torch.argmax((samples == 2).int(), dim=-1) < torch.argmax((samples == 0).int(), dim=-1)
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")

    # Check the probability of a ranking where we are indifferent between two values.
    rankings = [[[1], [0, 2]]]
    masks = make_masks_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *masks))
    cond1 = torch.argmax((samples == 1).int(), dim=-1) < torch.argmax((samples == 0).int(), dim=-1)
    cond2 = torch.argmax((samples == 1).int(), dim=-1) < torch.argmax((samples == 2).int(), dim=-1)
    cond = cond1 & cond2
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")


if __name__ == "__main__":
    main()