Skip to content

Instantly share code, notes, and snippets.

@crowsonkb
Last active November 5, 2024 21:33
Show Gist options
  • Select an option

  • Save crowsonkb/7df88ec63ea19ac335aa8b6c8f530769 to your computer and use it in GitHub Desktop.

Select an option

Save crowsonkb/7df88ec63ea19ac335aa8b6c8f530769 to your computer and use it in GitHub Desktop.

Revisions

  1. crowsonkb revised this gist Nov 5, 2024. 1 changed file with 41 additions and 68 deletions.
    109 changes: 41 additions & 68 deletions plackett_luce.py
    Original file line number Diff line number Diff line change
    @@ -12,7 +12,7 @@ def plackett_luce_loss(
    denominator_mask: torch.Tensor,
    weights: torch.Tensor,
    *,
    prior_weight: float = 0.0,
    eps: float = 0.0,
    ) -> torch.Tensor:
    """Plackett-Luce loss function for ranking tasks.
    @@ -28,19 +28,17 @@ def plackett_luce_loss(
    denominator_mask (torch.Tensor): From `make_inputs_for_plackett_luce_loss()`.
    Shape (n_groups, n_scores).
    weights: From `make_inputs_for_plackett_luce_loss()`. Shape (n_groups).
    prior_weight (float, optional): Weight for the prior.
    eps (float, optional): Epsilon for conservative Plackett-Luce. Defaults to `0.0`.
    Returns:
    torch.Tensor: Scalar loss value.
    """
    scores_n = torch.where(numerator_mask, scores, float("-inf"))
    scores_d = torch.where(denominator_mask, scores, float("-inf"))
    numerator = torch.logsumexp(scores_n, dim=-1)
    denominator = torch.logsumexp(scores_d, dim=-1)
    log_likelihood_parts = torch.where(numerator == float("-inf"), 0.0, numerator - denominator)
    nll = -torch.sum(weights * log_likelihood_parts)
    prior = 2 * torch.sum(torch.logaddexp(0.5 * scores, -0.5 * scores))
    return nll + prior_weight * prior
    n1 = torch.logsumexp(torch.where(numerator_mask, scores, float("-inf")), dim=-1)
    d1 = torch.logsumexp(torch.where(denominator_mask, scores, float("-inf")), dim=-1)
    n2 = torch.logsumexp(torch.where(numerator_mask, -scores, float("-inf")), dim=-1)
    d2 = torch.logsumexp(torch.where(denominator_mask, -scores, float("-inf")), dim=-1)
    log_likelihood_parts = torch.where(n1 == float("-inf"), 0.0, torch.lerp(n1 - d1, n2 - d2, eps))
    return -torch.sum(weights * log_likelihood_parts)


    def make_inputs_for_plackett_luce_loss(
    @@ -67,22 +65,22 @@ def make_inputs_for_plackett_luce_loss(
    Args:
    rankings (List[List[List[int]]]): List of rankings.
    n_scores (int): Number of scores.
    weights (List[float], optional): Weights for the rankings. Defaults to `1.0` for all
    rankings. Note that the default does sum reduction.
    weights (List[float], optional): Weights for the rankings. Defaults to `1 / len(rankings)`
    for all rankings.
    device (torch.device, optional): Device for the inputs.
    Returns:
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of inputs.
    """
    n_groups = sum(len(ranking) for ranking in rankings)
    weights = [1.0] * len(rankings) if weights is None else weights
    n_groups = sum(len(ranking) - 1 for ranking in rankings)
    weights = [1 / len(rankings)] * len(rankings) if weights is None else weights
    numerator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool)
    denominator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool)
    weights_out = torch.empty(n_groups, device="cpu")
    i = 0
    for ranking, weight in zip(rankings, weights):
    remaining_items = set(chain.from_iterable(ranking))
    for group in ranking:
    for group in ranking[:-1]:
    items = set(group)
    for item in items:
    numerator_mask[i, item] = True
    @@ -98,7 +96,7 @@ def simple_plackett_luce_loss(
    scores: torch.Tensor,
    rankings: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    prior_weight: float = 0.0,
    eps: float = 0.0,
    ) -> torch.Tensor:
    """Simple Plackett-Luce loss function for ranking tasks.
    @@ -126,20 +124,21 @@ def simple_plackett_luce_loss(
    Args:
    scores (torch.Tensor): Model output. Shape (n_rankings, n_scores_per_ranking).
    rankings (torch.Tensor): Rankings. Shape (n_rankings, n_scores_per_ranking).
    weights (torch.Tensor, optional): Weights for the rankings. Defaults to `1.0` for all
    rankings. Note that the default does sum reduction.
    prior_weight (float, optional): Weight for the prior.
    weights (torch.Tensor, optional): Weights for the rankings. Defaults to `1 / n_rankings` for
    all rankings.
    eps (float, optional): Epsilon for conservative Plackett-Luce. Defaults to `0.0`.
    Returns:
    torch.Tensor: Scalar loss value.
    """
    if weights is None:
    weights = scores.new_ones(rankings.shape[:-1])
    numerator = torch.gather(scores, -1, rankings)
    denominator = numerator.flip(-1).logcumsumexp(dim=-1).flip(-1)
    nll = -torch.sum(weights * torch.sum(numerator - denominator, dim=-1))
    prior = 2 * torch.sum(torch.logaddexp(0.5 * scores, -0.5 * scores))
    return nll + prior_weight * prior
    weights = scores.new_full(rankings.shape[:-1], 1 / rankings.shape[-2])
    n1 = torch.gather(scores, -1, rankings)
    d1 = n1.flip(-1).logcumsumexp(dim=-1).flip(-1)
    n2 = -n1
    d2 = n2.flip(-1).logcumsumexp(dim=-1).flip(-1)
    log_likelihoods = torch.lerp(torch.sum(n1 - d1, dim=-1), torch.sum(n2 - d2, dim=-1), eps)
    return -torch.sum(weights * log_likelihoods, dim=-1)


    def sample_plackett_luce(scores: torch.Tensor, shape: Tuple[int] = ()) -> torch.Tensor:
    @@ -186,10 +185,10 @@ def main():
    [[2], [0, 1]],
    [[0], [1]],
    ]
    weights = [1.0, 1.0, 2.0]
    weights = [0.5, 0.2, 0.3]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 3, weights)
    scores = torch.zeros(3, requires_grad=True)
    opt = torch.optim.SGD([scores], lr=0.5)
    opt = torch.optim.SGD([scores], lr=1)

    for i in range(20):
    loss = plackett_luce_loss(scores, *inputs)
    @@ -244,60 +243,34 @@ def main():
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")

    # Check the simple Plackett-Luce loss function.
    scores = torch.randn(3, 3)
    rankings_simple = torch.tensor([[0, 1, 2], [1, 0, 2], [2, 0, 1]])
    loss = simple_plackett_luce_loss(scores, rankings_simple)
    rankings = [
    [[0], [1], [2]],
    [[4], [3], [5]],
    [[8], [6], [7]],
    ]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 9)
    expected = plackett_luce_loss(scores.flatten(), *inputs)
    if torch.allclose(loss, expected):
    print("Simple Plackett-Luce loss function test passed.")
    else:
    print("Simple Plackett-Luce loss function test failed.")

    # Check that the prior does the same thing as adding pseudo-rankings.
    # Check that conservative Plackett-Luce does the right thing.
    scores = torch.randn(3)
    rankings = [
    [[1], [2], [0]],
    [[2], [0, 1]],
    [[0], [1]],
    ]
    weights = [1.0] * 3
    inputs = make_inputs_for_plackett_luce_loss(rankings, 3, weights)
    loss = plackett_luce_loss(scores, *inputs, prior_weight=0.5)
    for i in range(3):
    rankings.append([[i], [3]])
    weights.append(0.5)
    rankings.append([[3], [i]])
    weights.append(0.5)
    inputs = make_inputs_for_plackett_luce_loss(rankings, 4, weights)
    scores_pseudo = torch.cat((scores, torch.zeros(1)))
    expected = plackett_luce_loss(scores_pseudo, *inputs)
    rankings = [[[0], [1], [2]]]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 3)
    loss = plackett_luce_loss(scores, *inputs, eps=0.1)
    expected_1 = plackett_luce_loss(scores, *inputs)
    expected_2 = plackett_luce_loss(-scores, *inputs)
    expected = 0.9 * expected_1 + 0.1 * expected_2
    if torch.allclose(loss, expected):
    print("Prior test passed.")
    print("Conservative Plackett-Luce test passed.")
    else:
    print("Prior test failed.")
    print("Conservative Plackett-Luce test failed.")

    # Check that the prior does the same thing in both versions of the loss function.
    # Check the simple Plackett-Luce loss function.
    scores = torch.randn(3, 3)
    rankings_simple = torch.tensor([[0, 1, 2], [1, 0, 2], [2, 0, 1]])
    loss_simple = simple_plackett_luce_loss(scores, rankings_simple, prior_weight=0.5)
    loss = simple_plackett_luce_loss(scores, rankings_simple, eps=0.1)
    rankings = [
    [[0], [1], [2]],
    [[4], [3], [5]],
    [[8], [6], [7]],
    ]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 9)
    expected = plackett_luce_loss(scores.flatten(), *inputs, prior_weight=0.5)
    if torch.allclose(loss_simple, expected):
    print("Simple loss function prior weight test passed.")
    expected = plackett_luce_loss(scores.flatten(), *inputs, eps=0.1)
    if torch.allclose(loss, expected):
    print("Simple Plackett-Luce loss function test passed.")
    else:
    print("Simple loss function prior weight test failed.")
    print("Simple Plackett-Luce loss function test failed.")


    if __name__ == "__main__":
  2. crowsonkb revised this gist Nov 5, 2024. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion plackett_luce.py
    Original file line number Diff line number Diff line change
    @@ -134,7 +134,7 @@ def simple_plackett_luce_loss(
    torch.Tensor: Scalar loss value.
    """
    if weights is None:
    weights = scores.new_full(rankings.shape[:-1], 1.0)
    weights = scores.new_ones(rankings.shape[:-1])
    numerator = torch.gather(scores, -1, rankings)
    denominator = numerator.flip(-1).logcumsumexp(dim=-1).flip(-1)
    nll = -torch.sum(weights * torch.sum(numerator - denominator, dim=-1))
  3. crowsonkb revised this gist Nov 5, 2024. 1 changed file with 65 additions and 13 deletions.
    78 changes: 65 additions & 13 deletions plackett_luce.py
    Original file line number Diff line number Diff line change
    @@ -11,6 +11,8 @@ def plackett_luce_loss(
    numerator_mask: torch.Tensor,
    denominator_mask: torch.Tensor,
    weights: torch.Tensor,
    *,
    prior_weight: float = 0.0,
    ) -> torch.Tensor:
    """Plackett-Luce loss function for ranking tasks.
    @@ -26,6 +28,7 @@ def plackett_luce_loss(
    denominator_mask (torch.Tensor): From `make_inputs_for_plackett_luce_loss()`.
    Shape (n_groups, n_scores).
    weights: From `make_inputs_for_plackett_luce_loss()`. Shape (n_groups).
    prior_weight (float, optional): Weight for the prior.
    Returns:
    torch.Tensor: Scalar loss value.
    @@ -35,7 +38,9 @@ def plackett_luce_loss(
    numerator = torch.logsumexp(scores_n, dim=-1)
    denominator = torch.logsumexp(scores_d, dim=-1)
    log_likelihood_parts = torch.where(numerator == float("-inf"), 0.0, numerator - denominator)
    return -torch.sum(weights * log_likelihood_parts)
    nll = -torch.sum(weights * log_likelihood_parts)
    prior = 2 * torch.sum(torch.logaddexp(0.5 * scores, -0.5 * scores))
    return nll + prior_weight * prior


    def make_inputs_for_plackett_luce_loss(
    @@ -62,15 +67,15 @@ def make_inputs_for_plackett_luce_loss(
    Args:
    rankings (List[List[List[int]]]): List of rankings.
    n_scores (int): Number of scores.
    weights (List[float], optional): Weights for the rankings. Defaults to `1 / len(rankings)`
    for all rankings.
    weights (List[float], optional): Weights for the rankings. Defaults to `1.0` for all
    rankings. Note that the default does sum reduction.
    device (torch.device, optional): Device for the inputs.
    Returns:
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of inputs.
    """
    n_groups = sum(len(ranking) for ranking in rankings)
    weights = [1 / len(rankings)] * len(rankings) if weights is None else weights
    weights = [1.0] * len(rankings) if weights is None else weights
    numerator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool)
    denominator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool)
    weights_out = torch.empty(n_groups, device="cpu")
    @@ -90,7 +95,10 @@ def make_inputs_for_plackett_luce_loss(


    def simple_plackett_luce_loss(
    scores: torch.Tensor, rankings: torch.Tensor, weights: Optional[torch.Tensor] = None
    scores: torch.Tensor,
    rankings: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    prior_weight: float = 0.0,
    ) -> torch.Tensor:
    """Simple Plackett-Luce loss function for ranking tasks.
    @@ -118,17 +126,20 @@ def simple_plackett_luce_loss(
    Args:
    scores (torch.Tensor): Model output. Shape (n_rankings, n_scores_per_ranking).
    rankings (torch.Tensor): Rankings. Shape (n_rankings, n_scores_per_ranking).
    weights (torch.Tensor, optional): Weights for the rankings. Defaults to `1 / n_rankings` for
    all rankings.
    weights (torch.Tensor, optional): Weights for the rankings. Defaults to `1.0` for all
    rankings. Note that the default does sum reduction.
    prior_weight (float, optional): Weight for the prior.
    Returns:
    torch.Tensor: Scalar loss value.
    """
    if weights is None:
    weights = scores.new_full(rankings.shape[:-1], 1 / rankings.shape[-2])
    weights = scores.new_full(rankings.shape[:-1], 1.0)
    numerator = torch.gather(scores, -1, rankings)
    denominator = numerator.flip(-1).logcumsumexp(dim=-1).flip(-1)
    return -torch.sum(weights * torch.sum(numerator - denominator, dim=-1))
    nll = -torch.sum(weights * torch.sum(numerator - denominator, dim=-1))
    prior = 2 * torch.sum(torch.logaddexp(0.5 * scores, -0.5 * scores))
    return nll + prior_weight * prior


    def sample_plackett_luce(scores: torch.Tensor, shape: Tuple[int] = ()) -> torch.Tensor:
    @@ -175,12 +186,12 @@ def main():
    [[2], [0, 1]],
    [[0], [1]],
    ]
    weights = [0.5, 0.3, 0.2]
    weights = [1.0, 1.0, 2.0]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 3, weights)
    scores = torch.zeros(3, requires_grad=True)
    opt = torch.optim.SGD([scores], lr=1)
    opt = torch.optim.SGD([scores], lr=0.5)

    for i in range(25):
    for i in range(20):
    loss = plackett_luce_loss(scores, *inputs)
    print(f"step: {i}, loss: {loss:.6f}")
    loss.backward()
    @@ -245,7 +256,48 @@ def main():
    inputs = make_inputs_for_plackett_luce_loss(rankings, 9)
    expected = plackett_luce_loss(scores.flatten(), *inputs)
    if torch.allclose(loss, expected):
    print("Simple Plackett-Luce loss function passed.")
    print("Simple Plackett-Luce loss function test passed.")
    else:
    print("Simple Plackett-Luce loss function test failed.")

    # Check that the prior does the same thing as adding pseudo-rankings.
    scores = torch.randn(3)
    rankings = [
    [[1], [2], [0]],
    [[2], [0, 1]],
    [[0], [1]],
    ]
    weights = [1.0] * 3
    inputs = make_inputs_for_plackett_luce_loss(rankings, 3, weights)
    loss = plackett_luce_loss(scores, *inputs, prior_weight=0.5)
    for i in range(3):
    rankings.append([[i], [3]])
    weights.append(0.5)
    rankings.append([[3], [i]])
    weights.append(0.5)
    inputs = make_inputs_for_plackett_luce_loss(rankings, 4, weights)
    scores_pseudo = torch.cat((scores, torch.zeros(1)))
    expected = plackett_luce_loss(scores_pseudo, *inputs)
    if torch.allclose(loss, expected):
    print("Prior test passed.")
    else:
    print("Prior test failed.")

    # Check that the prior does the same thing in both versions of the loss function.
    scores = torch.randn(3, 3)
    rankings_simple = torch.tensor([[0, 1, 2], [1, 0, 2], [2, 0, 1]])
    loss_simple = simple_plackett_luce_loss(scores, rankings_simple, prior_weight=0.5)
    rankings = [
    [[0], [1], [2]],
    [[4], [3], [5]],
    [[8], [6], [7]],
    ]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 9)
    expected = plackett_luce_loss(scores.flatten(), *inputs, prior_weight=0.5)
    if torch.allclose(loss_simple, expected):
    print("Simple loss function prior weight test passed.")
    else:
    print("Simple loss function prior weight test failed.")


    if __name__ == "__main__":
  4. crowsonkb revised this gist Nov 5, 2024. 1 changed file with 3 additions and 3 deletions.
    6 changes: 3 additions & 3 deletions plackett_luce.py
    Original file line number Diff line number Diff line change
    @@ -116,9 +116,9 @@ def simple_plackett_luce_loss(
    with a shape (3, 3) tensor of scores.
    Args:
    scores (torch.Tensor): Model output. Shape (n_groups, n_scores_per_group).
    rankings (torch.Tensor): Rankings. Shape (n_groups, n_scores_per_group).
    weights (torch.Tensor, optional): Weights for the rankings. Defaults to `1 / n_groups` for
    scores (torch.Tensor): Model output. Shape (n_rankings, n_scores_per_ranking).
    rankings (torch.Tensor): Rankings. Shape (n_rankings, n_scores_per_ranking).
    weights (torch.Tensor, optional): Weights for the rankings. Defaults to `1 / n_rankings` for
    all rankings.
    Returns:
  5. crowsonkb revised this gist Nov 5, 2024. 2 changed files with 252 additions and 154 deletions.
    252 changes: 252 additions & 0 deletions plackett_luce.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,252 @@
    """Plackett-Luce loss function for ranking tasks. The rankings may be partial and include ties."""

    from itertools import chain
    from typing import List, Optional, Tuple

    import torch


    def plackett_luce_loss(
    scores: torch.Tensor,
    numerator_mask: torch.Tensor,
    denominator_mask: torch.Tensor,
    weights: torch.Tensor,
    ) -> torch.Tensor:
    """Plackett-Luce loss function for ranking tasks.
    If `beta * (logp - logp_ref)` is input as `scores`, where `logp` is the log probability
    of a completion given its prompt, and `logp_ref` is the log probability of the completion
    given its prompt under a reference model, then this function computes the Plackett-Luce DPO
    loss (Appendix A.3 of https://arxiv.org/abs/2305.18290).
    Args:
    scores (torch.Tensor): Model output. Shape (n_scores).
    numerator_mask (torch.Tensor): From `make_inputs_for_plackett_luce_loss()`.
    Shape (n_groups, n_scores).
    denominator_mask (torch.Tensor): From `make_inputs_for_plackett_luce_loss()`.
    Shape (n_groups, n_scores).
    weights: From `make_inputs_for_plackett_luce_loss()`. Shape (n_groups).
    Returns:
    torch.Tensor: Scalar loss value.
    """
    scores_n = torch.where(numerator_mask, scores, float("-inf"))
    scores_d = torch.where(denominator_mask, scores, float("-inf"))
    numerator = torch.logsumexp(scores_n, dim=-1)
    denominator = torch.logsumexp(scores_d, dim=-1)
    log_likelihood_parts = torch.where(numerator == float("-inf"), 0.0, numerator - denominator)
    return -torch.sum(weights * log_likelihood_parts)


    def make_inputs_for_plackett_luce_loss(
    rankings: List[List[List[int]]],
    n_scores: int,
    weights: Optional[List[float]] = None,
    device: Optional[torch.device] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Make inputs for the Plackett-Luce loss function.
    This function accepts a list of rankings in the following format:
    ```
    rankings = [
    [[1], [2], [0]],
    [[2], [0]],
    [[0, 2], [1]],
    ]
    ```
    The above example has 3 rankings. The first ranking prefers 1 to 2 to 0. The second ranking
    prefers 2 to 0 and is indifferent to 1. The third ranking prefers either 0 or 2 to 1.
    Args:
    rankings (List[List[List[int]]]): List of rankings.
    n_scores (int): Number of scores.
    weights (List[float], optional): Weights for the rankings. Defaults to `1 / len(rankings)`
    for all rankings.
    device (torch.device, optional): Device for the inputs.
    Returns:
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of inputs.
    """
    n_groups = sum(len(ranking) for ranking in rankings)
    weights = [1 / len(rankings)] * len(rankings) if weights is None else weights
    numerator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool)
    denominator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool)
    weights_out = torch.empty(n_groups, device="cpu")
    i = 0
    for ranking, weight in zip(rankings, weights):
    remaining_items = set(chain.from_iterable(ranking))
    for group in ranking:
    items = set(group)
    for item in items:
    numerator_mask[i, item] = True
    for item in remaining_items:
    denominator_mask[i, item] = True
    weights_out[i] = weight
    remaining_items -= items
    i += 1
    return numerator_mask.to(device), denominator_mask.to(device), weights_out.to(device)


    def simple_plackett_luce_loss(
    scores: torch.Tensor, rankings: torch.Tensor, weights: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
    """Simple Plackett-Luce loss function for ranking tasks.
    This function is an easier to use version of `plackett_luce_loss()`. It accepts a batched tensor
    of scores, a batched tensor of rankings, and an optional tensor of weights. Ties are not
    supported, all rankings must rank the same number of items, and all rankings are indifferent to
    items not in the ranking. For example, if the rankings for `plackett_luce_loss()` are:
    ```
    rankings = [
    [[0], [1], [2]],
    [[4], [3], [5]],
    [[8], [6], [7]],
    ]
    ```
    with a shape (9) tensor of scores, then the rankings for `simple_plackett_luce_loss()` are:
    ```
    rankings = torch.tensor([[0, 1, 2], [1, 0, 2], [2, 0, 1]])
    ```
    with a shape (3, 3) tensor of scores.
    Args:
    scores (torch.Tensor): Model output. Shape (n_groups, n_scores_per_group).
    rankings (torch.Tensor): Rankings. Shape (n_groups, n_scores_per_group).
    weights (torch.Tensor, optional): Weights for the rankings. Defaults to `1 / n_groups` for
    all rankings.
    Returns:
    torch.Tensor: Scalar loss value.
    """
    if weights is None:
    weights = scores.new_full(rankings.shape[:-1], 1 / rankings.shape[-2])
    numerator = torch.gather(scores, -1, rankings)
    denominator = numerator.flip(-1).logcumsumexp(dim=-1).flip(-1)
    return -torch.sum(weights * torch.sum(numerator - denominator, dim=-1))


    def sample_plackett_luce(scores: torch.Tensor, shape: Tuple[int] = ()) -> torch.Tensor:
    """Sample from a Plackett-Luce model.
    Args:
    scores (torch.Tensor): Model output. Shape (n_scores).
    shape (Tuple[int], optional): Shape of the samples.
    Returns:
    torch.Tensor: Samples. Shape (*shape, n_scores).
    """
    gumbel = scores.new_empty(*shape, *scores.shape).exponential_().log_().neg_()
    _, indices = torch.sort(scores + gumbel, dim=-1, descending=True)
    return indices


    def is_before(permutation: torch.Tensor, a: int, b: int) -> torch.Tensor:
    """Check if a is before b in a permutation.
    Args:
    permutation (torch.Tensor): Permutation. Shape (*, n_scores).
    a (int): The first element.
    b (int): The second element.
    Returns:
    torch.Tensor: Boolean tensor. Shape (*).
    """
    a_mask = permutation == a
    b_mask = permutation == b
    a_present = torch.any(a_mask, dim=-1)
    b_present = torch.any(b_mask, dim=-1)
    pos_a = torch.argmax(a_mask.byte(), dim=-1)
    pos_b = torch.argmax(b_mask.byte(), dim=-1)
    return a_present & b_present & (pos_a < pos_b)


def main():
    """Run tests.

    Fits a small Plackett-Luce model, then compares exact ranking
    probabilities (exp of the negative loss of a single ranking) against
    Monte Carlo estimates from the sampler. Results are printed, not
    asserted; expected and sampled values should agree to roughly three
    decimal places.
    """

    # Fit a model: three weighted partial rankings over 3 items.
    rankings = [
        [[1], [2], [0]],
        [[2], [0, 1]],
        [[0], [1]],
    ]
    weights = [0.5, 0.3, 0.2]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 3, weights)
    scores = torch.zeros(3, requires_grad=True)
    opt = torch.optim.SGD([scores], lr=1)

    # Plain gradient descent on the weighted negative log-likelihood.
    for i in range(25):
        loss = plackett_luce_loss(scores, *inputs)
        print(f"step: {i}, loss: {loss:.6f}")
        loss.backward()
        opt.step()
        opt.zero_grad()

    scores = scores.detach()

    # Sample from the model (one million Gumbel-max draws).
    samples = sample_plackett_luce(scores, (1_000_000,))

    # Check the probability of a specific ranking: for a single ranking,
    # exp(-loss) is its probability under the model.
    rankings = [[[2], [0], [1]]]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *inputs))
    cond = torch.all(samples == torch.tensor([2, 0, 1]), dim=-1)
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")

    # Check the probability of a ranking where we are indifferent to one value.
    rankings = [[[2], [0]]]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *inputs))
    cond = is_before(samples, 2, 0)
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")

    # Check the probability of a ranking where we are indifferent between two values.
    rankings = [[[1], [0, 2]]]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *inputs))
    cond = is_before(samples, 1, 0) & is_before(samples, 1, 2)
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")

    # Check a more complicated example (a tie group plus an unranked item).
    # Make up a Plackett-Luce model for this one.
    scores = torch.linspace(0, 2, 5)
    rankings = [[[0, 1], [2], [3]]]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 5)
    expected = torch.exp(-plackett_luce_loss(scores, *inputs))
    samples = sample_plackett_luce(scores, (1_000_000,))
    # The ranking holds if 0 or 1 beats both 2 and 3, and 2 beats 3.
    cond1 = is_before(samples, 0, 2) & is_before(samples, 0, 3)
    cond2 = is_before(samples, 1, 2) & is_before(samples, 1, 3)
    cond3 = is_before(samples, 2, 3)
    cond = (cond1 | cond2) & cond3
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")

    # Check the simple Plackett-Luce loss function against the general one:
    # three disjoint full rankings over 9 flattened scores must give the
    # same loss as the batched 3x3 formulation.
    scores = torch.randn(3, 3)
    rankings_simple = torch.tensor([[0, 1, 2], [1, 0, 2], [2, 0, 1]])
    loss = simple_plackett_luce_loss(scores, rankings_simple)
    rankings = [
        [[0], [1], [2]],
        [[4], [3], [5]],
        [[8], [6], [7]],
    ]
    inputs = make_inputs_for_plackett_luce_loss(rankings, 9)
    expected = plackett_luce_loss(scores.flatten(), *inputs)
    if torch.allclose(loss, expected):
        print("Simple Plackett-Luce loss function passed.")


    if __name__ == "__main__":
    main()
    154 changes: 0 additions & 154 deletions plackett_luce_loss.py
    Original file line number Diff line number Diff line change
    @@ -1,154 +0,0 @@
    """Plackett-Luce loss function for ranking tasks. The rankings may be partial and include ties."""

    from itertools import chain
    from typing import List, Tuple

    import torch


def plackett_luce_loss(
    scores: torch.Tensor, numerator_mask: torch.Tensor, denominator_mask: torch.Tensor
) -> torch.Tensor:
    """Plackett-Luce loss function for ranking tasks.

    This function uses sum reduction for the loss value. Divide the sum by `len(rankings)`
    to get the mean loss value.

    If `beta * (logp - logp_ref)` is input as `scores`, where `logp` is the log probability
    of a completion given its prompt, and `logp_ref` is the log probability of the completion
    given its prompt under a reference model, then this function computes the Plackett-Luce DPO
    loss (Appendix A.3 of https://arxiv.org/abs/2305.18290).

    Args:
        scores (torch.Tensor): Model output. Shape (n_scores).
        numerator_mask (torch.Tensor): From `make_masks_for_plackett_luce_loss()`.
            Shape (n_groups, n_scores).
        denominator_mask (torch.Tensor): From `make_masks_for_plackett_luce_loss()`.
            Shape (n_groups, n_scores).

    Returns:
        torch.Tensor: Scalar loss value.
    """
    # Scores outside each mask are set to -inf so they drop out of the
    # logsumexp over the score dimension.
    scores_n = torch.where(numerator_mask, scores, float("-inf"))
    scores_d = torch.where(denominator_mask, scores, float("-inf"))
    numerator = torch.logsumexp(scores_n, dim=-1)
    denominator = torch.logsumexp(scores_d, dim=-1)
    # A group with an all-false numerator mask would give -inf - (-inf) = NaN;
    # such groups contribute zero log-likelihood instead.
    log_likelihood = torch.where(numerator == float("-inf"), 0.0, numerator - denominator)
    return -torch.sum(log_likelihood)


def make_masks_for_plackett_luce_loss(
    rankings: List[List[List[int]]], n_scores: int, device: torch.device = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Make masks for the Plackett-Luce loss function.

    This function accepts a list of rankings in the following format:

    ```
    rankings = [
        [[1], [2], [0]],
        [[2], [0]],
        [[0, 2], [1]],
    ]
    ```

    The above example has 3 rankings. The first ranking prefers 1 to 2 to 0. The second ranking
    prefers 2 to 0 and is indifferent to 1. The third ranking prefers either 0 or 2 to 1.

    Args:
        rankings (List[List[List[int]]]): List of rankings.
        n_scores (int): Number of scores.
        device (torch.device, optional): Device for the masks.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of masks, each of shape
            (n_groups, n_scores).
    """
    # One mask row per choice step ("group") across all rankings.
    n_groups = sum(len(ranking) for ranking in rankings)
    # Build on CPU: construction is item-by-item Python indexing, which would
    # be very slow on an accelerator. Move to `device` once at the end.
    numerator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool)
    denominator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool)
    i = 0
    for ranking in rankings:
        # Items still in contention at the current choice step.
        remaining_items = set(chain.from_iterable(ranking))
        for group in ranking:
            items = set(group)
            for item in items:
                numerator_mask[i, item] = True
            for item in remaining_items:
                denominator_mask[i, item] = True
            remaining_items -= items
            i += 1
    return numerator_mask.to(device), denominator_mask.to(device)


def sample_plackett_luce(scores: torch.Tensor, shape: Tuple[int] = ()) -> torch.Tensor:
    """Sample from a Plackett-Luce model.

    Args:
        scores (torch.Tensor): Model output. Shape (n_scores).
        shape (Tuple[int], optional): Shape of the samples.

    Returns:
        torch.Tensor: Samples. Shape (*shape, n_scores).
    """
    # Gumbel-max trick: -log(Exponential(1)) is Gumbel(0, 1) noise, and
    # sorting scores + noise in descending order draws a ranking from the
    # Plackett-Luce distribution.
    gumbel = scores.new_empty(*shape, *scores.shape).exponential_().log_().neg_()
    _, indices = torch.sort(scores + gumbel, dim=-1, descending=True)
    return indices


def main():
    """Run tests.

    Fits a small Plackett-Luce model, then compares exact ranking
    probabilities against Monte Carlo estimates from the sampler. Results
    are printed, not asserted.
    """

    # Fit a model: three partial rankings over 3 items.
    rankings = [
        [[1], [2], [0]],
        [[2], [0, 1]],
        [[0], [1]],
    ]
    masks = make_masks_for_plackett_luce_loss(rankings, 3)
    scores = torch.zeros(3, requires_grad=True)
    opt = torch.optim.SGD([scores], lr=1)

    # Plain gradient descent on the mean negative log-likelihood (the loss
    # uses sum reduction, so divide by the number of rankings).
    for i in range(20):
        loss = plackett_luce_loss(scores, *masks) / len(rankings)
        print(f"step: {i}, loss: {loss:.6f}")
        loss.backward()
        opt.step()
        opt.zero_grad()

    scores = scores.detach()

    # Sample from the model (one million Gumbel-max draws).
    samples = sample_plackett_luce(scores, (1_000_000,))

    # Check the probability of a specific ranking: for a single ranking,
    # exp(-loss) is its probability under the model.
    rankings = [[[2], [0], [1]]]
    masks = make_masks_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *masks))
    cond = torch.all(samples == torch.tensor([2, 0, 1]), dim=-1)
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")

    # Check the probability of a ranking where we are indifferent to one value.
    # argmax over a boolean mask finds the first occurrence of each item.
    rankings = [[[2], [0]]]
    masks = make_masks_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *masks))
    cond = torch.argmax((samples == 2).int(), dim=-1) < torch.argmax((samples == 0).int(), dim=-1)
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")

    # Check the probability of a ranking where we are indifferent between two values.
    rankings = [[[1], [0, 2]]]
    masks = make_masks_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *masks))
    cond1 = torch.argmax((samples == 1).int(), dim=-1) < torch.argmax((samples == 0).int(), dim=-1)
    cond2 = torch.argmax((samples == 1).int(), dim=-1) < torch.argmax((samples == 2).int(), dim=-1)
    cond = cond1 & cond2
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")


    if __name__ == "__main__":
    main()
  6. crowsonkb revised this gist Nov 5, 2024. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion plackett_luce_loss.py
    Original file line number Diff line number Diff line change
    @@ -53,7 +53,7 @@ def make_masks_for_plackett_luce_loss(
    ```
    The above example has 3 rankings. The first ranking prefers 1 to 2 to 0. The second ranking
    prefers 2 to 0 and is indifferent to 1. The third ranking prefers either of 0 and 2 to 1.
    prefers 2 to 0 and is indifferent to 1. The third ranking prefers either 0 or 2 to 1.
    Args:
    rankings (List[List[List[int]]]): List of rankings.
  7. crowsonkb revised this gist Nov 5, 2024. 1 changed file with 1 addition and 2 deletions.
    3 changes: 1 addition & 2 deletions plackett_luce_loss.py
    Original file line number Diff line number Diff line change
    @@ -53,8 +53,7 @@ def make_masks_for_plackett_luce_loss(
    ```
    The above example has 3 rankings. The first ranking prefers 1 to 2 to 0. The second ranking
    prefers 2 to 0 and is indifferent to 1. The third ranking prefers 0 and 2 to 1 and is
    indifferent between 0 and 2.
    prefers 2 to 0 and is indifferent to 1. The third ranking prefers either of 0 and 2 to 1.
    Args:
    rankings (List[List[List[int]]]): List of rankings.
  8. crowsonkb created this gist Nov 4, 2024.
    155 changes: 155 additions & 0 deletions plackett_luce_loss.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,155 @@
    """Plackett-Luce loss function for ranking tasks. The rankings may be partial and include ties."""

from itertools import chain
from typing import List, Optional, Tuple

import torch


    def plackett_luce_loss(
    scores: torch.Tensor, numerator_mask: torch.Tensor, denominator_mask: torch.Tensor
    ) -> torch.Tensor:
    """Plackett-Luce loss function for ranking tasks.
    This function uses sum reduction for the loss value. Divide the sum by `len(rankings)`
    to get the mean loss value.
    If `beta * (logp - logp_ref)` is input as `scores`, where `logp` is the log probability
    of a completion given its prompt, and `logp_ref` is the log probability of the completion
    given its prompt under a reference model, then this function computes the Plackett-Luce DPO
    loss (Appendix A.3 of https://arxiv.org/abs/2305.18290).
    Args:
    scores (torch.Tensor): Model output. Shape (n_scores).
    numerator_mask (torch.Tensor): From `make_masks_for_plackett_luce_loss()`.
    Shape (n_groups, n_scores).
    denominator_mask (torch.Tensor): From `make_masks_for_plackett_luce_loss()`.
    Shape (n_groups, n_scores).
    Returns:
    torch.Tensor: Scalar loss value.
    """
    scores_n = torch.where(numerator_mask, scores, float("-inf"))
    scores_d = torch.where(denominator_mask, scores, float("-inf"))
    numerator = torch.logsumexp(scores_n, dim=-1)
    denominator = torch.logsumexp(scores_d, dim=-1)
    log_likelihood = torch.where(numerator == float("-inf"), 0.0, numerator - denominator)
    return -torch.sum(log_likelihood)


    def make_masks_for_plackett_luce_loss(
    rankings: List[List[List[int]]], n_scores: int, device: torch.device = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Make masks for the Plackett-Luce loss function.
    This function accepts a list of rankings in the following format:
    ```
    rankings = [
    [[1], [2], [0]],
    [[2], [0]],
    [[0, 2], [1]],
    ]
    ```
    The above example has 3 rankings. The first ranking prefers 1 to 2 to 0. The second ranking
    prefers 2 to 0 and is indifferent to 1. The third ranking prefers 0 and 2 to 1 and is
    indifferent between 0 and 2.
    Args:
    rankings (List[List[List[int]]]): List of rankings.
    n_scores (int): Number of scores.
    device (torch.device, optional): Device for the masks.
    Returns:
    Tuple[torch.Tensor, torch.Tensor]: Tuple of masks.
    """
    n_groups = sum(len(ranking) for ranking in rankings)
    numerator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool)
    denominator_mask = torch.zeros(n_groups, n_scores, device="cpu", dtype=torch.bool)
    i = 0
    for ranking in rankings:
    remaining_items = set(chain.from_iterable(ranking))
    for group in ranking:
    items = set(group)
    for item in items:
    numerator_mask[i, item] = True
    for item in remaining_items:
    denominator_mask[i, item] = True
    remaining_items -= items
    i += 1
    return numerator_mask.to(device), denominator_mask.to(device)


    def sample_plackett_luce(scores: torch.Tensor, shape: Tuple[int] = ()) -> torch.Tensor:
    """Sample from a Plackett-Luce model.
    Args:
    scores (torch.Tensor): Model output. Shape (n_scores).
    shape (Tuple[int], optional): Shape of the samples.
    Returns:
    torch.Tensor: Samples. Shape (*shape, n_scores).
    """
    gumbel = scores.new_empty(*shape, *scores.shape).exponential_().log_().neg_()
    _, indices = torch.sort(scores + gumbel, dim=-1, descending=True)
    return indices


    def main():
    """Run tests."""

    # Fit a model.
    rankings = [
    [[1], [2], [0]],
    [[2], [0, 1]],
    [[0], [1]],
    ]
    masks = make_masks_for_plackett_luce_loss(rankings, 3)
    scores = torch.zeros(3, requires_grad=True)
    opt = torch.optim.SGD([scores], lr=1)

    for i in range(20):
    loss = plackett_luce_loss(scores, *masks) / len(rankings)
    print(f"step: {i}, loss: {loss:.6f}")
    loss.backward()
    opt.step()
    opt.zero_grad()

    scores = scores.detach()

    # Sample from the model.
    samples = sample_plackett_luce(scores, (1_000_000,))

    # Check the probability of a specific ranking.
    rankings = [[[2], [0], [1]]]
    masks = make_masks_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *masks))
    cond = torch.all(samples == torch.tensor([2, 0, 1]), dim=-1)
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")

    # Check the probability of a ranking where we are indifferent to one value.
    rankings = [[[2], [0]]]
    masks = make_masks_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *masks))
    cond = torch.argmax((samples == 2).int(), dim=-1) < torch.argmax((samples == 0).int(), dim=-1)
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")

    # Check the probability of a ranking where we are indifferent between two values.
    rankings = [[[1], [0, 2]]]
    masks = make_masks_for_plackett_luce_loss(rankings, 3)
    expected = torch.exp(-plackett_luce_loss(scores, *masks))
    cond1 = torch.argmax((samples == 1).int(), dim=-1) < torch.argmax((samples == 0).int(), dim=-1)
    cond2 = torch.argmax((samples == 1).int(), dim=-1) < torch.argmax((samples == 2).int(), dim=-1)
    cond = cond1 & cond2
    sampled = torch.mean(cond.float())
    print(f"expected: {expected:.4f}")
    print(f" sampled: {sampled:.4f}")


    if __name__ == "__main__":
    main()