Created
June 28, 2020 21:44
-
-
Save dimartinot/6ab415ce215d1dd310e86a4a7079e43e to your computer and use it in GitHub Desktop.
Quadruplet loss function
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class QuadrupletLoss(torch.nn.Module):
    """Quadruplet loss function.

    Builds on the triplet loss and takes four inputs: one anchor, one
    positive and two negative examples. The negative examples must match
    neither the anchor, nor the positive, nor each other.

    With ``d(x, y)`` the squared difference of ``x`` and ``y`` summed over
    dimension 1, the per-sample loss is::

        relu(margin1 + d(anchor, positive) - d(anchor, negative1))
        + relu(margin2 + d(anchor, positive) - d(negative1, negative2))

    and the value returned by ``forward`` is the mean of that quantity.
    """

    def __init__(self, margin1: float = 2.0, margin2: float = 1.0) -> None:
        """Store the two margins used by the hinge terms.

        Args:
            margin1: Margin added in the anchor/positive vs. anchor/negative1
                term.
            margin2: Margin added in the anchor/positive vs.
                negative1/negative2 term.
        """
        super().__init__()
        self.margin1 = margin1
        self.margin2 = margin2

    def forward(
        self,
        anchor: torch.Tensor,
        positive: torch.Tensor,
        negative1: torch.Tensor,
        negative2: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the mean quadruplet loss for a set of embeddings.

        Args:
            anchor: Anchor embeddings.
            positive: Embeddings expected to be close to the anchor.
            negative1: First set of negative embeddings.
            negative2: Second set of negative embeddings.

        Returns:
            A scalar tensor holding the mean of the per-sample losses.

        NOTE(review): the reduction over dimension 1 presumably expects
        inputs shaped (batch, embedding_dim) -- confirm against the caller.
        """
        # Squared distances, reduced over dimension 1 (no square root taken).
        squared_distance_pos = (anchor - positive).pow(2).sum(1)
        squared_distance_neg = (anchor - negative1).pow(2).sum(1)
        squared_distance_neg_b = (negative1 - negative2).pow(2).sum(1)

        # First term compares against the anchor/negative1 distance, second
        # term against the distance between the two negatives.
        quadruplet_loss = (
            F.relu(self.margin1 + squared_distance_pos - squared_distance_neg)
            + F.relu(self.margin2 + squared_distance_pos - squared_distance_neg_b)
        )

        return quadruplet_loss.mean()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment