class Acc:
    """Count sequences whose predictions match the targets exactly.

    Positions where the target equals ``ignore_index`` are excluded from the
    comparison, so a sequence counts as correct when every non-ignored
    position of ``pred`` equals the corresponding position of ``tgt``.

    Note that the result is a raw count of correct sequences in the batch,
    not a ratio; divide by the batch size to obtain an accuracy.
    """

    def __init__(self, ignore_index=1):
        """Store the target value that marks positions to skip.

        Args:
            ignore_index: Target value whose positions are not compared.
        """
        self.ignore_index = ignore_index

    def __call__(self, pred: "torch.Tensor", tgt: "torch.Tensor") -> int:
        """Return the number of fully correct sequences in the batch.

        Args:
            pred: Predicted values; per the original comment, shape
                ``(bs, seq_len)``.
            tgt: Target values with the same shape as ``pred``.

        Returns:
            The count of rows in which all non-ignored positions match, as a
            Python scalar.
        """
        # Positions carrying the ignore marker never count against a row.
        ignored = tgt == self.ignore_index
        # Build the per-position verdict out of place: the inputs belong to
        # the caller and must not be modified by computing a metric.
        position_ok = (pred == tgt) | ignored
        # A row is correct only if every position along dim 1 is acceptable.
        return position_ok.all(1).sum().item()