huynhthaihoa / midas_loss.py
Created June 19, 2023 06:42 — forked from dvdhfnr/midas_loss.py
Loss function of MiDaS
import torch
import torch.nn as nn


def compute_scale_and_shift(prediction, target, mask):
    # system matrix: A = [[a_00, a_01], [a_10, a_11]]
    # A is symmetric (a_10 == a_01), so only three entries are computed.
    # Summing over dims (1, 2) yields one value per batch element.
    a_00 = torch.sum(mask * prediction * prediction, (1, 2))
    a_01 = torch.sum(mask * prediction, (1, 2))
    a_11 = torch.sum(mask, (1, 2))