| """Plackett-Luce loss function for ranking tasks. The rankings may be partial and include ties.""" | |
| from itertools import chain | |
| from typing import List, Tuple | |
| import torch | |


def plackett_luce_loss(
    scores: torch.Tensor, numerator_mask: torch.Tensor, denominator_mask: torch.Tensor
) -> torch.Tensor:
    """Plackett-Luce loss function for ranking tasks.

    This function uses sum reduction for the loss value. Divide the sum by `len(rankings)`
    to get the mean loss value.

    If `beta * (logp - logp_ref)` is input as `scores`, where `logp` is the log probability
    of a completion given its prompt, and `logp_ref` is the log probability of the completion
    given its prompt under a reference model, then this function computes the Plackett-Luce DPO
    loss (Appendix A.3 of https://arxiv.org/abs/2305.18290).

    Args:
        scores (torch.Tensor): Model output. Shape (n_scores).
        numerator_mask (torch.Tensor): From `make_masks_for_plackett_luce_loss()`.
            Shape (n_groups, n_scores).
        denominator_mask (torch.Tensor): From `make_masks_for_plackett_luce_loss()`.
            Shape (n_groups, n_scores).

    Returns:
        torch.Tensor: Scalar loss value.
    """
    # Mask out scores outside each group's numerator/denominator, then take logsumexp over
    # the score dimension, so each group contributes
    # log(sum(exp(group scores)) / sum(exp(remaining scores))) to the log likelihood.
    scores_n = torch.where(numerator_mask, scores, float("-inf"))
    scores_d = torch.where(denominator_mask, scores, float("-inf"))
    numerator = torch.logsumexp(scores_n, dim=-1)
    denominator = torch.logsumexp(scores_d, dim=-1)
    # Guard against all-False numerator rows (e.g. padding groups), whose logsumexp is -inf.
    log_likelihood = torch.where(numerator == float("-inf"), 0.0, numerator - denominator)
    return -torch.sum(log_likelihood)
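

# A minimal sketch (not part of the original gist) of the DPO usage described in the
# docstring above. `policy_logp`, `ref_logp`, and `beta` are hypothetical placeholders:
# per-completion log probabilities under the policy and the reference model, and the DPO
# coefficient.
#
#     scores = beta * (policy_logp - ref_logp)  # shape (n_scores,)
#     masks = make_masks_for_plackett_luce_loss(rankings, scores.shape[-1])
#     loss = plackett_luce_loss(scores, *masks) / len(rankings)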


def make_masks_for_plackett_luce_loss(
    rankings: List[List[List[int]]], n_scores: int, device: Optional[torch.device] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Make masks for the Plackett-Luce loss function.

    This function accepts a list of rankings in the following format:

    ```
    rankings = [
        [[1], [2], [0]],
        [[2], [0]],
        [[0, 2], [1]],
    ]
    ```

    The above example has 3 rankings. The first ranking prefers 1 to 2 to 0. The second ranking
    prefers 2 to 0 and is indifferent to 1. The third ranking prefers either 0 or 2 to 1.

    Args:
        rankings (List[List[List[int]]]): List of rankings.
        n_scores (int): Number of scores.
        device (torch.device, optional): Device for the masks.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of masks.
    """
    # Each group within each ranking gets its own row in the masks.
    n_groups = sum(len(ranking) for ranking in rankings)
    numerator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool)
    denominator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool)
    i = 0
    for ranking in rankings:
        remaining_items = set(chain.from_iterable(ranking))
        for group in ranking:
            items = set(group)
            # The numerator covers the items chosen at this stage; the denominator covers
            # every item that has not yet been chosen at an earlier stage.
            for item in items:
                numerator_mask[i, item] = True
            for item in remaining_items:
                denominator_mask[i, item] = True
            remaining_items -= items
            i += 1
    return numerator_mask.to(device), denominator_mask.to(device)
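

# For illustration (not in the original gist), the docstring's third ranking on its own,
# rankings = [[[0, 2], [1]]] with n_scores=3, yields one mask row per group:
#
#     numerator_mask   = [[True, False, True],    # the chosen group {0, 2}
#                         [False, True, False]]   # the chosen group {1}
#     denominator_mask = [[True, True, True],     # all items still in the running
#                         [False, True, False]]   # only item 1 remains
#
# so the first row contributes log((e^s0 + e^s2) / (e^s0 + e^s1 + e^s2)) to the log
# likelihood, and the second row contributes log(1) = 0.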


def sample_plackett_luce(scores: torch.Tensor, shape: Tuple[int, ...] = ()) -> torch.Tensor:
    """Sample from a Plackett-Luce model.

    Args:
        scores (torch.Tensor): Model output. Shape (n_scores).
        shape (Tuple[int, ...], optional): Shape of the samples.

    Returns:
        torch.Tensor: Samples. Shape (*shape, n_scores).
    """
    # Gumbel-max trick: -log(Exponential(1)) is Gumbel(0, 1) noise, and sorting the perturbed
    # scores in descending order draws a permutation from the Plackett-Luce model.
    gumbel = scores.new_empty(*shape, *scores.shape).exponential_().log_().neg_()
    _, indices = torch.sort(scores + gumbel, dim=-1, descending=True)
    return indices
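

# A hypothetical round trip (not in the original gist): a sampled permutation can be
# converted back into the `rankings` format that make_masks_for_plackett_luce_loss()
# expects, with one singleton group per position:
#
#     perm = sample_plackett_luce(scores)  # e.g. tensor([2, 0, 1])
#     ranking = [[int(i)] for i in perm]   # [[2], [0], [1]], a full ranking without ties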


def main():
    """Run tests."""
    # Fit a model.
    rankings = [
        [[1], [2], [0]],
        [[2], [0, 1]],
        [[0], [1]],
    ]
    masks = make_masks_for_plackett_luce_loss(rankings, 3)
    scores = torch.zeros(3, requires_grad=True)
    opt = torch.optim.SGD([scores], lr=1)
    for i in range(20):
        loss = plackett_luce_loss(scores, *masks) / len(rankings)
        print(f"step: {i}, loss: {loss:.6f}")
        loss.backward()
        opt.step()
        opt.zero_grad()
    scores = scores.detach()

    # Sample from the model.
    samples = sample_plackett_luce(scores, (1_000_000,))

    # Check the probability of a specific ranking.
    rankings = [[[2], [0], [1]]]
    masks = make_masks_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *masks))
    cond = torch.all(samples == torch.tensor([2, 0, 1]), dim=-1)
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")

    # Check the probability of a ranking where we are indifferent to one value.
    # argmax over a boolean mask returns the index of the first True here, i.e. the
    # item's position in the sampled permutation.
    rankings = [[[2], [0]]]
    masks = make_masks_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *masks))
    cond = torch.argmax((samples == 2).int(), dim=-1) < torch.argmax((samples == 0).int(), dim=-1)
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")

    # Check the probability of a ranking where we are indifferent between two values.
    rankings = [[[1], [0, 2]]]
    masks = make_masks_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *masks))
    cond1 = torch.argmax((samples == 1).int(), dim=-1) < torch.argmax((samples == 0).int(), dim=-1)
    cond2 = torch.argmax((samples == 1).int(), dim=-1) < torch.argmax((samples == 2).int(), dim=-1)
    cond = cond1 & cond2
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")


if __name__ == "__main__":
    main()