@dimartinot
Created June 28, 2020 21:44
Quadruplet loss function
import torch
import torch.nn.functional as F


class QuadrupletLoss(torch.nn.Module):
    """
    Quadruplet loss function.

    Builds on the triplet loss and takes four inputs: one anchor, one positive
    and two negative examples. The negative examples must not match the anchor,
    the positive, or each other.
    """

    def __init__(self, margin1=2.0, margin2=1.0):
        super().__init__()
        self.margin1 = margin1  # margin for the anchor/negative1 term
        self.margin2 = margin2  # margin for the negative1/negative2 term

    def forward(self, anchor, positive, negative1, negative2):
        # Squared Euclidean distances, one value per row of a (batch, dim) input
        squared_distance_pos = (anchor - positive).pow(2).sum(1)
        squared_distance_neg = (anchor - negative1).pow(2).sum(1)
        squared_distance_neg_b = (negative1 - negative2).pow(2).sum(1)

        # Triplet-style hinge term plus a second hinge term on the two negatives
        quadruplet_loss = (
            F.relu(self.margin1 + squared_distance_pos - squared_distance_neg)
            + F.relu(self.margin2 + squared_distance_pos - squared_distance_neg_b)
        )
        return quadruplet_loss.mean()
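To make the arithmetic of `forward` easy to check by hand, here is a minimal pure-Python sketch of the same computation. For each (anchor, positive, negative1, negative2) row it takes squared Euclidean distances, applies the two hinge terms with `margin1` and `margin2`, and averages over the batch. The helper names `squared_distance` and `quadruplet_loss_reference` are hypothetical and not part of the gist. The sketch also assumes each embedding is a plain sequence of floats rather than a tensor.

```python
from typing import List, Sequence, Tuple

Vector = Sequence[float]
Quadruplet = Tuple[Vector, Vector, Vector, Vector]


def squared_distance(x: Vector, y: Vector) -> float:
    """Squared Euclidean distance; mirrors (x - y).pow(2).sum(1) for a single row."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


def quadruplet_loss_reference(
    quadruplets: List[Quadruplet],
    margin1: float = 2.0,
    margin2: float = 1.0,
) -> float:
    """Mean quadruplet loss over (anchor, positive, negative1, negative2) tuples."""
    if not quadruplets:
        raise ValueError("need at least one quadruplet")
    total = 0.0
    for anchor, positive, negative1, negative2 in quadruplets:
        d_pos = squared_distance(anchor, positive)
        d_neg = squared_distance(anchor, negative1)
        d_neg_b = squared_distance(negative1, negative2)
        # max(0, .) plays the role of F.relu in the PyTorch version
        total += max(0.0, margin1 + d_pos - d_neg)
        total += max(0.0, margin2 + d_pos - d_neg_b)
    # Averaging over rows mirrors quadruplet_loss.mean()
    return total / len(quadruplets)


# Hypothetical 2-D embeddings: both hinge terms are active for "hard", neither for "easy"
hard = ((0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0))
easy = ((0.0, 0.0), (1.0, 0.0), (3.0, 0.0), (3.0, 4.0))

print(quadruplet_loss_reference([hard]))        # 2.0
print(quadruplet_loss_reference([hard, easy]))  # 1.0
```

For `hard`, d(anchor, positive) = 1, d(anchor, negative1) = 2 and d(negative1, negative2) = 1, so each hinge term contributes 1. The second term compares the positive distance against a pair of negatives that does not involve the anchor at all. This is what the quadruplet loss adds on top of the triplet loss.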