FocalLoss with LabelSmoothing
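Label smoothing replaces the one-hot target with a softened distribution (the true class gets 1 - smoothing, the remaining mass is spread evenly over the other classes), and focal loss scales the per-sample cross-entropy by (1 - pt)^gamma to down-weight examples the model already classifies confidently. The code below composes the two: the smoothing loss returns per-sample values so the focal factor can be applied before reduction.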
import torch
import torch.nn as nn


class LabelSmoothingLoss(nn.Module):
    """Cross-entropy with label smoothing, returned per sample (no reduction)."""

    def __init__(self, classes, smoothing=0.0, dim=-1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            # Build the smoothed target distribution: the true class gets
            # `confidence`, the rest of the mass is spread over the other classes.
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        # Per-sample smoothed cross-entropy: sum over classes, no batch
        # reduction, so a focal weight can be applied to each sample.
        return torch.sum(-true_dist * pred, dim=self.dim)
class FocalLossWithLabelSmoothing(nn.Module):
    """Focal loss (Lin et al., 2017) computed on top of label-smoothed cross-entropy."""

    def __init__(self, gamma=2., reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction
        self.CE = LabelSmoothingLoss(classes=2, smoothing=0.1)

    def forward(self, inputs, targets):
        CE_loss = self.CE(inputs, targets)
        # pt approximates the probability the model assigns to the (smoothed) target.
        pt = torch.exp(-CE_loss)
        # Down-weight easy examples: (1 - pt)^gamma shrinks the loss for
        # samples that are already classified confidently.
        F_loss = ((1 - pt) ** self.gamma) * CE_loss
        if self.reduction == 'sum':
            return F_loss.sum()
        elif self.reduction == 'mean':
            return F_loss.mean()
        # reduction == 'none': return the per-sample losses
        return F_loss
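A minimal usage sketch; the batch size and label values are illustrative, and the loss as written is hard-coded to two classes via LabelSmoothingLoss(classes=2):

logits = torch.randn(4, 2, requires_grad=True)  # raw scores for a batch of 4 binary samples
targets = torch.tensor([0, 1, 1, 0])            # integer class labels

criterion = FocalLossWithLabelSmoothing(gamma=2., reduction='mean')
loss = criterion(logits, targets)
loss.backward()  # gradients flow back through log_softmax to the logits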