Last active
November 5, 2024 21:33
-
-
Save crowsonkb/7df88ec63ea19ac335aa8b6c8f530769 to your computer and use it in GitHub Desktop.
Revisions
-
crowsonkb revised this gist
Nov 5, 2024 . 1 changed file with 41 additions and 68 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -12,7 +12,7 @@ def plackett_luce_loss( denominator_mask: torch.Tensor, weights: torch.Tensor, *, eps: float = 0.0, ) -> torch.Tensor: """Plackett-Luce loss function for ranking tasks. @@ -28,19 +28,17 @@ def plackett_luce_loss( denominator_mask (torch.Tensor): From `make_inputs_for_plackett_luce_loss()`. Shape (n_groups, n_scores). weights: From `make_inputs_for_plackett_luce_loss()`. Shape (n_groups). eps (float, optional): Epsilon for conservative Plackett-Luce. Defaults to `0.0`. Returns: torch.Tensor: Scalar loss value. """ n1 = torch.logsumexp(torch.where(numerator_mask, scores, float("-inf")), dim=-1) d1 = torch.logsumexp(torch.where(denominator_mask, scores, float("-inf")), dim=-1) n2 = torch.logsumexp(torch.where(numerator_mask, -scores, float("-inf")), dim=-1) d2 = torch.logsumexp(torch.where(denominator_mask, -scores, float("-inf")), dim=-1) log_likelihood_parts = torch.where(n1 == float("-inf"), 0.0, torch.lerp(n1 - d1, n2 - d2, eps)) return -torch.sum(weights * log_likelihood_parts) def make_inputs_for_plackett_luce_loss( @@ -67,22 +65,22 @@ def make_inputs_for_plackett_luce_loss( Args: rankings (List[List[List[int]]]): List of rankings. n_scores (int): Number of scores. weights (List[float], optional): Weights for the rankings. Defaults to `1 / len(rankings)` for all rankings. device (torch.device, optional): Device for the inputs. Returns: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of inputs. 
""" n_groups = sum(len(ranking) - 1 for ranking in rankings) weights = [1 / len(rankings)] * len(rankings) if weights is None else weights numerator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool) denominator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool) weights_out = torch.empty(n_groups, device="cpu") i = 0 for ranking, weight in zip(rankings, weights): remaining_items = set(chain.from_iterable(ranking)) for group in ranking[:-1]: items = set(group) for item in items: numerator_mask[i, item] = True @@ -98,7 +96,7 @@ def simple_plackett_luce_loss( scores: torch.Tensor, rankings: torch.Tensor, weights: Optional[torch.Tensor] = None, eps: float = 0.0, ) -> torch.Tensor: """Simple Plackett-Luce loss function for ranking tasks. @@ -126,20 +124,21 @@ def simple_plackett_luce_loss( Args: scores (torch.Tensor): Model output. Shape (n_rankings, n_scores_per_ranking). rankings (torch.Tensor): Rankings. Shape (n_rankings, n_scores_per_ranking). weights (torch.Tensor, optional): Weights for the rankings. Defaults to `1 / n_rankings` for all rankings. eps (float, optional): Epsilon for conservative Plackett-Luce. Defaults to `0.0`. Returns: torch.Tensor: Scalar loss value. 
""" if weights is None: weights = scores.new_full(rankings.shape[:-1], 1 / rankings.shape[-2]) n1 = torch.gather(scores, -1, rankings) d1 = n1.flip(-1).logcumsumexp(dim=-1).flip(-1) n2 = -n1 d2 = n2.flip(-1).logcumsumexp(dim=-1).flip(-1) log_likelihoods = torch.lerp(torch.sum(n1 - d1, dim=-1), torch.sum(n2 - d2, dim=-1), eps) return -torch.sum(weights * log_likelihoods, dim=-1) def sample_plackett_luce(scores: torch.Tensor, shape: Tuple[int] = ()) -> torch.Tensor: @@ -186,10 +185,10 @@ def main(): [[2], [0, 1]], [[0], [1]], ] weights = [0.5, 0.2, 0.3] inputs = make_inputs_for_plackett_luce_loss(rankings, 3, weights) scores = torch.zeros(3, requires_grad=True) opt = torch.optim.SGD([scores], lr=1) for i in range(20): loss = plackett_luce_loss(scores, *inputs) @@ -244,60 +243,34 @@ def main(): print(f"expected: {expected:.4f}") print(f" sampled: {sampled:.4f}") # Check that conservative Plackett-Luce does the right thing. scores = torch.randn(3) rankings = [[[0], [1], [2]]] inputs = make_inputs_for_plackett_luce_loss(rankings, 3) loss = plackett_luce_loss(scores, *inputs, eps=0.1) expected_1 = plackett_luce_loss(scores, *inputs) expected_2 = plackett_luce_loss(-scores, *inputs) expected = 0.9 * expected_1 + 0.1 * expected_2 if torch.allclose(loss, expected): print("Conservative Plackett-Luce test passed.") else: print("Conservative Plackett-Luce test failed.") # Check the simple Plackett-Luce loss function. scores = torch.randn(3, 3) rankings_simple = torch.tensor([[0, 1, 2], [1, 0, 2], [2, 0, 1]]) loss = simple_plackett_luce_loss(scores, rankings_simple, eps=0.1) rankings = [ [[0], [1], [2]], [[4], [3], [5]], [[8], [6], [7]], ] inputs = make_inputs_for_plackett_luce_loss(rankings, 9) expected = plackett_luce_loss(scores.flatten(), *inputs, eps=0.1) if torch.allclose(loss, expected): print("Simple Plackett-Luce loss function test passed.") else: print("Simple Plackett-Luce loss function test failed.") if __name__ == "__main__": -
crowsonkb revised this gist
Nov 5, 2024 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -134,7 +134,7 @@ def simple_plackett_luce_loss( torch.Tensor: Scalar loss value. """ if weights is None: weights = scores.new_ones(rankings.shape[:-1]) numerator = torch.gather(scores, -1, rankings) denominator = numerator.flip(-1).logcumsumexp(dim=-1).flip(-1) nll = -torch.sum(weights * torch.sum(numerator - denominator, dim=-1)) -
crowsonkb revised this gist
Nov 5, 2024 . 1 changed file with 65 additions and 13 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -11,6 +11,8 @@ def plackett_luce_loss( numerator_mask: torch.Tensor, denominator_mask: torch.Tensor, weights: torch.Tensor, *, prior_weight: float = 0.0, ) -> torch.Tensor: """Plackett-Luce loss function for ranking tasks. @@ -26,6 +28,7 @@ def plackett_luce_loss( denominator_mask (torch.Tensor): From `make_inputs_for_plackett_luce_loss()`. Shape (n_groups, n_scores). weights: From `make_inputs_for_plackett_luce_loss()`. Shape (n_groups). prior_weight (float, optional): Weight for the prior. Returns: torch.Tensor: Scalar loss value. @@ -35,7 +38,9 @@ def plackett_luce_loss( numerator = torch.logsumexp(scores_n, dim=-1) denominator = torch.logsumexp(scores_d, dim=-1) log_likelihood_parts = torch.where(numerator == float("-inf"), 0.0, numerator - denominator) nll = -torch.sum(weights * log_likelihood_parts) prior = 2 * torch.sum(torch.logaddexp(0.5 * scores, -0.5 * scores)) return nll + prior_weight * prior def make_inputs_for_plackett_luce_loss( @@ -62,15 +67,15 @@ def make_inputs_for_plackett_luce_loss( Args: rankings (List[List[List[int]]]): List of rankings. n_scores (int): Number of scores. weights (List[float], optional): Weights for the rankings. Defaults to `1.0` for all rankings. Note that the default does sum reduction. device (torch.device, optional): Device for the inputs. Returns: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of inputs. 
""" n_groups = sum(len(ranking) for ranking in rankings) weights = [1.0] * len(rankings) if weights is None else weights numerator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool) denominator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool) weights_out = torch.empty(n_groups, device="cpu") @@ -90,7 +95,10 @@ def make_inputs_for_plackett_luce_loss( def simple_plackett_luce_loss( scores: torch.Tensor, rankings: torch.Tensor, weights: Optional[torch.Tensor] = None, prior_weight: float = 0.0, ) -> torch.Tensor: """Simple Plackett-Luce loss function for ranking tasks. @@ -118,17 +126,20 @@ def simple_plackett_luce_loss( Args: scores (torch.Tensor): Model output. Shape (n_rankings, n_scores_per_ranking). rankings (torch.Tensor): Rankings. Shape (n_rankings, n_scores_per_ranking). weights (torch.Tensor, optional): Weights for the rankings. Defaults to `1.0` for all rankings. Note that the default does sum reduction. prior_weight (float, optional): Weight for the prior. Returns: torch.Tensor: Scalar loss value. 
""" if weights is None: weights = scores.new_full(rankings.shape[:-1], 1.0) numerator = torch.gather(scores, -1, rankings) denominator = numerator.flip(-1).logcumsumexp(dim=-1).flip(-1) nll = -torch.sum(weights * torch.sum(numerator - denominator, dim=-1)) prior = 2 * torch.sum(torch.logaddexp(0.5 * scores, -0.5 * scores)) return nll + prior_weight * prior def sample_plackett_luce(scores: torch.Tensor, shape: Tuple[int] = ()) -> torch.Tensor: @@ -175,12 +186,12 @@ def main(): [[2], [0, 1]], [[0], [1]], ] weights = [1.0, 1.0, 2.0] inputs = make_inputs_for_plackett_luce_loss(rankings, 3, weights) scores = torch.zeros(3, requires_grad=True) opt = torch.optim.SGD([scores], lr=0.5) for i in range(20): loss = plackett_luce_loss(scores, *inputs) print(f"step: {i}, loss: {loss:.6f}") loss.backward() @@ -245,7 +256,48 @@ def main(): inputs = make_inputs_for_plackett_luce_loss(rankings, 9) expected = plackett_luce_loss(scores.flatten(), *inputs) if torch.allclose(loss, expected): print("Simple Plackett-Luce loss function test passed.") else: print("Simple Plackett-Luce loss function test failed.") # Check that the prior does the same thing as adding pseudo-rankings. scores = torch.randn(3) rankings = [ [[1], [2], [0]], [[2], [0, 1]], [[0], [1]], ] weights = [1.0] * 3 inputs = make_inputs_for_plackett_luce_loss(rankings, 3, weights) loss = plackett_luce_loss(scores, *inputs, prior_weight=0.5) for i in range(3): rankings.append([[i], [3]]) weights.append(0.5) rankings.append([[3], [i]]) weights.append(0.5) inputs = make_inputs_for_plackett_luce_loss(rankings, 4, weights) scores_pseudo = torch.cat((scores, torch.zeros(1))) expected = plackett_luce_loss(scores_pseudo, *inputs) if torch.allclose(loss, expected): print("Prior test passed.") else: print("Prior test failed.") # Check that the prior does the same thing in both versions of the loss function. 
scores = torch.randn(3, 3) rankings_simple = torch.tensor([[0, 1, 2], [1, 0, 2], [2, 0, 1]]) loss_simple = simple_plackett_luce_loss(scores, rankings_simple, prior_weight=0.5) rankings = [ [[0], [1], [2]], [[4], [3], [5]], [[8], [6], [7]], ] inputs = make_inputs_for_plackett_luce_loss(rankings, 9) expected = plackett_luce_loss(scores.flatten(), *inputs, prior_weight=0.5) if torch.allclose(loss_simple, expected): print("Simple loss function prior weight test passed.") else: print("Simple loss function prior weight test failed.") if __name__ == "__main__": -
crowsonkb revised this gist
Nov 5, 2024 . 1 changed file with 3 additions and 3 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -116,9 +116,9 @@ def simple_plackett_luce_loss( with a shape (3, 3) tensor of scores. Args: scores (torch.Tensor): Model output. Shape (n_rankings, n_scores_per_ranking). rankings (torch.Tensor): Rankings. Shape (n_rankings, n_scores_per_ranking). weights (torch.Tensor, optional): Weights for the rankings. Defaults to `1 / n_rankings` for all rankings. Returns: -
crowsonkb revised this gist
Nov 5, 2024 . 2 changed files with 252 additions and 154 deletions.There are no files selected for viewing
"""Plackett-Luce loss function for ranking tasks. The rankings may be partial and include ties."""

from itertools import chain
from typing import List, Optional, Tuple

import torch


def plackett_luce_loss(
    scores: torch.Tensor,
    numerator_mask: torch.Tensor,
    denominator_mask: torch.Tensor,
    weights: torch.Tensor,
) -> torch.Tensor:
    """Plackett-Luce loss function for ranking tasks.

    If `beta * (logp - logp_ref)` is input as `scores`, where `logp` is the log probability of a
    completion given its prompt, and `logp_ref` is the log probability of the completion given its
    prompt under a reference model, then this function computes the Plackett-Luce DPO loss
    (Appendix A.3 of https://arxiv.org/abs/2305.18290).

    Args:
        scores (torch.Tensor): Model output. Shape (n_scores).
        numerator_mask (torch.Tensor): From `make_inputs_for_plackett_luce_loss()`.
            Shape (n_groups, n_scores).
        denominator_mask (torch.Tensor): From `make_inputs_for_plackett_luce_loss()`.
            Shape (n_groups, n_scores).
        weights: From `make_inputs_for_plackett_luce_loss()`. Shape (n_groups).

    Returns:
        torch.Tensor: Scalar loss value.
    """
    # Scores outside each mask are sent to -inf so they contribute nothing to the logsumexp.
    scores_n = torch.where(numerator_mask, scores, float("-inf"))
    scores_d = torch.where(denominator_mask, scores, float("-inf"))
    numerator = torch.logsumexp(scores_n, dim=-1)
    denominator = torch.logsumexp(scores_d, dim=-1)
    # Guard groups with an empty numerator: -inf - denominator would poison the sum with nan.
    log_likelihood_parts = torch.where(numerator == float("-inf"), 0.0, numerator - denominator)
    return -torch.sum(weights * log_likelihood_parts)


def make_inputs_for_plackett_luce_loss(
    rankings: List[List[List[int]]],
    n_scores: int,
    weights: Optional[List[float]] = None,
    device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # NOTE(review): return annotation fixed — the function returns three tensors
    # (numerator mask, denominator mask, weights), not two.
    """Make inputs for the Plackett-Luce loss function.

    This function accepts a list of rankings in the following format:

    ```
    rankings = [
        [[1], [2], [0]],
        [[2], [0]],
        [[0, 2], [1]],
    ]
    ```

    The above example has 3 rankings. The first ranking prefers 1 to 2 to 0. The second ranking
    prefers 2 to 0 and is indifferent to 1. The third ranking prefers either 0 or 2 to 1.

    Args:
        rankings (List[List[List[int]]]): List of rankings.
        n_scores (int): Number of scores.
        weights (List[float], optional): Weights for the rankings. Defaults to
            `1 / len(rankings)` for all rankings.
        device (torch.device, optional): Device for the inputs.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of inputs.
    """
    n_groups = sum(len(ranking) for ranking in rankings)
    weights = [1 / len(rankings)] * len(rankings) if weights is None else weights
    # Build on CPU (cheap scalar writes), then move to the target device once at the end.
    numerator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool)
    denominator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool)
    weights_out = torch.empty(n_groups, device="cpu")
    i = 0
    for ranking, weight in zip(rankings, weights):
        remaining_items = set(chain.from_iterable(ranking))
        for group in ranking:
            items = set(group)
            for item in items:
                numerator_mask[i, item] = True
            # The denominator for a group is everything not yet chosen by earlier groups.
            for item in remaining_items:
                denominator_mask[i, item] = True
            weights_out[i] = weight
            remaining_items -= items
            i += 1
    return numerator_mask.to(device), denominator_mask.to(device), weights_out.to(device)


def simple_plackett_luce_loss(
    scores: torch.Tensor, rankings: torch.Tensor, weights: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Simple Plackett-Luce loss function for ranking tasks.

    This function is an easier to use version of `plackett_luce_loss()`. It accepts a batched
    tensor of scores, a batched tensor of rankings, and an optional tensor of weights. Ties are
    not supported, all rankings must rank the same number of items, and all rankings are
    indifferent to items not in the ranking.

    For example, if the rankings for `plackett_luce_loss()` are:

    ```
    rankings = [
        [[0], [1], [2]],
        [[4], [3], [5]],
        [[8], [6], [7]],
    ]
    ```

    with a shape (9) tensor of scores, then the rankings for `simple_plackett_luce_loss()` are:

    ```
    rankings = torch.tensor([[0, 1, 2], [1, 0, 2], [2, 0, 1]])
    ```

    with a shape (3, 3) tensor of scores.

    Args:
        scores (torch.Tensor): Model output. Shape (n_groups, n_scores_per_group).
        rankings (torch.Tensor): Rankings. Shape (n_groups, n_scores_per_group).
        weights (torch.Tensor, optional): Weights for the rankings. Defaults to
            `1 / n_groups` for all rankings.

    Returns:
        torch.Tensor: Scalar loss value.
    """
    if weights is None:
        weights = scores.new_full(rankings.shape[:-1], 1 / rankings.shape[-2])
    numerator = torch.gather(scores, -1, rankings)
    # Reverse cumulative logsumexp: denominator[i] = logsumexp of the scores ranked at i or later.
    denominator = numerator.flip(-1).logcumsumexp(dim=-1).flip(-1)
    return -torch.sum(weights * torch.sum(numerator - denominator, dim=-1))


def sample_plackett_luce(scores: torch.Tensor, shape: Tuple[int, ...] = ()) -> torch.Tensor:
    """Sample from a Plackett-Luce model.

    Args:
        scores (torch.Tensor): Model output. Shape (n_scores).
        shape (Tuple[int, ...], optional): Shape of the samples.

    Returns:
        torch.Tensor: Samples. Shape (*shape, n_scores).
    """
    # Gumbel-max trick: argsort of scores + Gumbel noise is a Plackett-Luce sample.
    gumbel = scores.new_empty(*shape, *scores.shape).exponential_().log_().neg_()
    _, indices = torch.sort(scores + gumbel, dim=-1, descending=True)
    return indices


def is_before(permutation: torch.Tensor, a: int, b: int) -> torch.Tensor:
    """Check if a is before b in a permutation.

    Args:
        permutation (torch.Tensor): Permutation. Shape (*, n_scores).
        a (int): The first element.
        b (int): The second element.

    Returns:
        torch.Tensor: Boolean tensor. Shape (*).
    """
    a_mask = permutation == a
    b_mask = permutation == b
    a_present = torch.any(a_mask, dim=-1)
    b_present = torch.any(b_mask, dim=-1)
    pos_a = torch.argmax(a_mask.byte(), dim=-1)
    pos_b = torch.argmax(b_mask.byte(), dim=-1)
    # Both elements must actually be present, otherwise argmax returns a bogus position 0.
    return a_present & b_present & (pos_a < pos_b)


def main():
    """Run tests."""
    # Fit a model.
    rankings = [
        [[1], [2], [0]],
        [[2], [0, 1]],
        [[0], [1]],
    ]
    weights = [0.5, 0.3, 0.2]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 3, weights)
    scores = torch.zeros(3, requires_grad=True)
    opt = torch.optim.SGD([scores], lr=1)
    for i in range(25):
        loss = plackett_luce_loss(scores, *inputs)
        print(f"step: {i}, loss: {loss:.6f}")
        loss.backward()
        opt.step()
        opt.zero_grad()
    scores = scores.detach()

    # Sample from the model.
    samples = sample_plackett_luce(scores, (1_000_000,))

    # Check the probability of a specific ranking.
    rankings = [[[2], [0], [1]]]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *inputs))
    cond = torch.all(samples == torch.tensor([2, 0, 1]), dim=-1)
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")

    # Check the probability of a ranking where we are indifferent to one value.
    rankings = [[[2], [0]]]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *inputs))
    cond = is_before(samples, 2, 0)
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")

    # Check the probability of a ranking where we are indifferent between two values.
    rankings = [[[1], [0, 2]]]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *inputs))
    cond = is_before(samples, 1, 0) & is_before(samples, 1, 2)
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")

    # Check a more complicated example. Make up a Plackett-Luce model for this one.
    scores = torch.linspace(0, 2, 5)
    rankings = [[[0, 1], [2], [3]]]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 5)
    expected = torch.exp(-plackett_luce_loss(scores, *inputs))
    samples = sample_plackett_luce(scores, (1_000_000,))
    cond1 = is_before(samples, 0, 2) & is_before(samples, 0, 3)
    cond2 = is_before(samples, 1, 2) & is_before(samples, 1, 3)
    cond3 = is_before(samples, 2, 3)
    cond = (cond1 | cond2) & cond3
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")

    # Check the simple Plackett-Luce loss function.
    scores = torch.randn(3, 3)
    rankings_simple = torch.tensor([[0, 1, 2], [1, 0, 2], [2, 0, 1]])
    loss = simple_plackett_luce_loss(scores, rankings_simple)
    rankings = [
        [[0], [1], [2]],
        [[4], [3], [5]],
        [[8], [6], [7]],
    ]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 9)
    expected = plackett_luce_loss(scores.flatten(), *inputs)
    if torch.allclose(loss, expected):
        print("Simple Plackett-Luce loss function passed.")
    else:
        # Previously this branch was missing, so a regression would fail silently.
        print("Simple Plackett-Luce loss function failed.")


if __name__ == "__main__":
    main()
crowsonkb revised this gist
Nov 5, 2024 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -53,7 +53,7 @@ def make_masks_for_plackett_luce_loss( ``` The above example has 3 rankings. The first ranking prefers 1 to 2 to 0. The second ranking prefers 2 to 0 and is indifferent to 1. The third ranking prefers either 0 or 2 to 1. Args: rankings (List[List[List[int]]]): List of rankings. -
crowsonkb revised this gist
Nov 5, 2024 . 1 changed file with 1 addition and 2 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -53,8 +53,7 @@ def make_masks_for_plackett_luce_loss( ``` The above example has 3 rankings. The first ranking prefers 1 to 2 to 0. The second ranking prefers 2 to 0 and is indifferent to 1. The third ranking prefers either of 0 and 2 to 1. Args: rankings (List[List[List[int]]]): List of rankings. -
crowsonkb created this gist
Nov 4, 2024 .There are no files selected for viewing
"""Plackett-Luce loss function for ranking tasks. The rankings may be partial and include ties."""

from itertools import chain
from typing import List, Optional, Tuple

import torch


def plackett_luce_loss(
    scores: torch.Tensor, numerator_mask: torch.Tensor, denominator_mask: torch.Tensor
) -> torch.Tensor:
    """Plackett-Luce loss function for ranking tasks.

    This function uses sum reduction for the loss value. Divide the sum by `len(rankings)` to get
    the mean loss value.

    If `beta * (logp - logp_ref)` is input as `scores`, where `logp` is the log probability of a
    completion given its prompt, and `logp_ref` is the log probability of the completion given its
    prompt under a reference model, then this function computes the Plackett-Luce DPO loss
    (Appendix A.3 of https://arxiv.org/abs/2305.18290).

    Args:
        scores (torch.Tensor): Model output. Shape (n_scores).
        numerator_mask (torch.Tensor): From `make_masks_for_plackett_luce_loss()`.
            Shape (n_groups, n_scores).
        denominator_mask (torch.Tensor): From `make_masks_for_plackett_luce_loss()`.
            Shape (n_groups, n_scores).

    Returns:
        torch.Tensor: Scalar loss value.
    """
    # Scores outside each mask are sent to -inf so they contribute nothing to the logsumexp.
    scores_n = torch.where(numerator_mask, scores, float("-inf"))
    scores_d = torch.where(denominator_mask, scores, float("-inf"))
    numerator = torch.logsumexp(scores_n, dim=-1)
    denominator = torch.logsumexp(scores_d, dim=-1)
    # Guard groups with an empty numerator: -inf - denominator would poison the sum with nan.
    log_likelihood = torch.where(numerator == float("-inf"), 0.0, numerator - denominator)
    return -torch.sum(log_likelihood)


def make_masks_for_plackett_luce_loss(
    rankings: List[List[List[int]]], n_scores: int, device: Optional[torch.device] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    # NOTE(review): annotation fixed — `device` defaults to None, so it is Optional.
    """Make masks for the Plackett-Luce loss function.

    This function accepts a list of rankings in the following format:

    ```
    rankings = [
        [[1], [2], [0]],
        [[2], [0]],
        [[0, 2], [1]],
    ]
    ```

    The above example has 3 rankings. The first ranking prefers 1 to 2 to 0. The second ranking
    prefers 2 to 0 and is indifferent to 1. The third ranking prefers 0 and 2 to 1 and is
    indifferent between 0 and 2.

    Args:
        rankings (List[List[List[int]]]): List of rankings.
        n_scores (int): Number of scores.
        device (torch.device, optional): Device for the masks.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of masks.
    """
    n_groups = sum(len(ranking) for ranking in rankings)
    # Build on CPU (cheap scalar writes), then move to the target device once at the end.
    numerator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool)
    denominator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool)
    i = 0
    for ranking in rankings:
        remaining_items = set(chain.from_iterable(ranking))
        for group in ranking:
            items = set(group)
            for item in items:
                numerator_mask[i, item] = True
            # The denominator for a group is everything not yet chosen by earlier groups.
            for item in remaining_items:
                denominator_mask[i, item] = True
            remaining_items -= items
            i += 1
    return numerator_mask.to(device), denominator_mask.to(device)


def sample_plackett_luce(scores: torch.Tensor, shape: Tuple[int, ...] = ()) -> torch.Tensor:
    """Sample from a Plackett-Luce model.

    Args:
        scores (torch.Tensor): Model output. Shape (n_scores).
        shape (Tuple[int, ...], optional): Shape of the samples.

    Returns:
        torch.Tensor: Samples. Shape (*shape, n_scores).
    """
    # Gumbel-max trick: argsort of scores + Gumbel noise is a Plackett-Luce sample.
    gumbel = scores.new_empty(*shape, *scores.shape).exponential_().log_().neg_()
    _, indices = torch.sort(scores + gumbel, dim=-1, descending=True)
    return indices


def main():
    """Run tests."""
    # Fit a model.
    rankings = [
        [[1], [2], [0]],
        [[2], [0, 1]],
        [[0], [1]],
    ]
    masks = make_masks_for_plackett_luce_loss(rankings, 3)
    scores = torch.zeros(3, requires_grad=True)
    opt = torch.optim.SGD([scores], lr=1)
    for i in range(20):
        loss = plackett_luce_loss(scores, *masks) / len(rankings)
        print(f"step: {i}, loss: {loss:.6f}")
        loss.backward()
        opt.step()
        opt.zero_grad()
    scores = scores.detach()

    # Sample from the model.
    samples = sample_plackett_luce(scores, (1_000_000,))

    # Check the probability of a specific ranking.
    rankings = [[[2], [0], [1]]]
    masks = make_masks_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *masks))
    cond = torch.all(samples == torch.tensor([2, 0, 1]), dim=-1)
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")

    # Check the probability of a ranking where we are indifferent to one value.
    rankings = [[[2], [0]]]
    masks = make_masks_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *masks))
    cond = torch.argmax((samples == 2).int(), dim=-1) < torch.argmax((samples == 0).int(), dim=-1)
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")

    # Check the probability of a ranking where we are indifferent between two values.
    rankings = [[[1], [0, 2]]]
    masks = make_masks_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *masks))
    cond1 = torch.argmax((samples == 1).int(), dim=-1) < torch.argmax((samples == 0).int(), dim=-1)
    cond2 = torch.argmax((samples == 1).int(), dim=-1) < torch.argmax((samples == 2).int(), dim=-1)
    cond = cond1 & cond2
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")


if __name__ == "__main__":
    main()