Skip to content

Instantly share code, notes, and snippets.

@dendenxu
Last active October 8, 2022 14:42
Show Gist options
  • Select an option

  • Save dendenxu/d23b3a2946790cf9d58297a8f2c75fbe to your computer and use it in GitHub Desktop.

Select an option

Save dendenxu/d23b3a2946790cf9d58297a8f2c75fbe to your computer and use it in GitHub Desktop.
A faster PyTorch inverse for batches of small matrices than torch.inverse (which sometimes falls back to the CPU and does not batch well), plus a triangular-decomposition variant and a faster index-table version.
import torch
def torch_inverse_2x2(A: torch.Tensor, eps=torch.finfo(torch.float).eps):
    """Batched inverse of 2x2 matrices via the closed-form adjugate formula.

    Args:
        A: (..., 2, 2) tensor of matrices to invert.
        eps: small constant added to the determinant so (near-)singular
            inputs do not produce inf/nan.

    Returns:
        (..., 2, 2) tensor holding the (regularized) inverse of each matrix.
    """
    B = torch.zeros_like(A)
    # Name the four entries for readability.
    a = A[..., 0, 0]
    b = A[..., 0, 1]
    c = A[..., 1, 0]
    d = A[..., 1, 1]
    # FIX: the original first bound `det = B[..., 1, 1]` (intending to store
    # the determinant in-place to save memory), but the very next rebinding
    # `det = (a*d - b*c)` discarded that view, so the aliasing line was dead
    # code and its space-saving comment was false. Removed.
    det = a * d - b * c + eps  # regularized determinant
    B[..., 0, 0] = d / det
    B[..., 0, 1] = -b / det
    B[..., 1, 0] = -c / det
    B[..., 1, 1] = a / det
    return B
def torch_inverse_3x3(R: torch.Tensor, eps=torch.finfo(torch.float).eps):
    """Batched 3x3 inverse using the precomputed minor index tables.

    Entry layout reminder:
        a, b, c | m00, m01, m02
        d, e, f | m10, m11, m12
        g, h, i | m20, m21, m22

    Args:
        R: (B, N, 3, 3) tensor of matrices.
        eps: regularizer added to the determinant before dividing.

    Returns:
        (B, N, 3, 3) tensor of (regularized) inverses.
    """
    # return R.inverse() # FIXME: possible bug when performing inverse
    n_batch, n_elem, _, _ = R.shape
    # Move the shared index/sign tables onto R's device (usually a no-op).
    row_idx = g_idx_i.to(R.device)
    col_idx = g_idx_j.to(R.device)
    sign_tab = g_signs.to(R.device)
    # Gather the 2x2 submatrix that defines each of the nine minors.
    sub = R.new_zeros(n_batch, n_elem, 3, 3, 2, 2)
    for r in range(3):
        for c in range(3):
            sub[:, :, r, c, :, :] = R[:, :, row_idx[r, c], col_idx[r, c]]
    # 2x2 determinants -> matrix of minors.
    dets = sub[..., 0, 0] * sub[..., 1, 1] - sub[..., 0, 1] * sub[..., 1, 0]
    # Checkerboard signs give the cofactors; transposing yields the adjugate.
    adjugate = (dets * sign_tab[None, None]).transpose(-2, -1)  # B, N, 3, 3
    # Laplace expansion along the first row gives the determinant.  # B, N
    det = (R[:, :, 0, 0] * dets[:, :, 0, 0]
           - R[:, :, 0, 1] * dets[:, :, 0, 1]
           + R[:, :, 0, 2] * dets[:, :, 0, 2])
    return adjugate / (det[:, :, None, None] + eps)
# Index tables used by torch_inverse_3x3: for each output position (i, j),
# g_idx_i/g_idx_j select the 2x2 submatrix of the corresponding minor, and
# g_signs holds the cofactor checkerboard signs.
#
# FIX: these were created with device='cuda', which made importing this
# module crash on CPU-only machines. They are now built on the CPU (the
# default device); torch_inverse_3x3 already moves them to the input's
# device with .to(R.device), so CUDA callers are unaffected.
g_idx_i = torch.tensor(
    [
        [
            [[1, 1], [2, 2]],
            [[1, 1], [2, 2]],
            [[1, 1], [2, 2]],
        ],
        [
            [[0, 0], [2, 2]],
            [[0, 0], [2, 2]],
            [[0, 0], [2, 2]],
        ],
        [
            [[0, 0], [1, 1]],
            [[0, 0], [1, 1]],
            [[0, 0], [1, 1]],
        ],
    ], dtype=torch.long)
g_idx_j = torch.tensor(
    [
        [
            [[1, 2], [1, 2]],
            [[0, 2], [0, 2]],
            [[0, 1], [0, 1]],
        ],
        [
            [[1, 2], [1, 2]],
            [[0, 2], [0, 2]],
            [[0, 1], [0, 1]],
        ],
        [
            [[1, 2], [1, 2]],
            [[0, 2], [0, 2]],
            [[0, 1], [0, 1]],
        ],
    ], dtype=torch.long)
g_signs = torch.tensor([
    [+1, -1, +1],
    [-1, +1, -1],
    [+1, -1, +1],
], dtype=torch.long)
def torch_inverse_3x3_naive(R, eps=1e-10):
    """Closed-form batched 3x3 inverse: adjugate divided by determinant.

    Entry layout reminder:
        a, b, c | m00, m01, m02
        d, e, f | m10, m11, m12
        g, h, i | m20, m21, m22

    Args:
        R: (..., 3, 3) tensor of matrices (e.g. n_batch, n_bones, 3, 3).
        eps: regularizer added to the determinant before dividing.

    Returns:
        (..., 3, 3) tensor of (regularized) inverses.
    """
    # Unpack the nine entries once for readability.
    a, b, c = R[..., 0, 0], R[..., 0, 1], R[..., 0, 2]
    d, e, f = R[..., 1, 0], R[..., 1, 1], R[..., 1, 2]
    g, h, i = R[..., 2, 0], R[..., 2, 1], R[..., 2, 2]
    # Signed minors (cofactors), row by row.
    c00 = +e * i - h * f
    c01 = -d * i + g * f
    c02 = +d * h - g * e
    c10 = -b * i + h * c
    c11 = +a * i - g * c
    c12 = -a * h + g * b
    c20 = +b * f - e * c
    c21 = -a * f + d * c
    c22 = +a * e - d * b
    # Adjugate = transpose of the cofactor matrix. Stacking each cofactor
    # row along the LAST axis places it as a column, transposing implicitly.
    adj = torch.stack([
        torch.stack([c00, c01, c02], dim=-1),
        torch.stack([c10, c11, c12], dim=-1),
        torch.stack([c20, c21, c22], dim=-1),
    ], dim=-1)
    # Laplace expansion along the first row.
    det = a * c00 + b * c01 + c * c02
    return adj / (det[..., None, None] + eps)
def torch_inverse_decomp(L: torch.Tensor, eps=1e-10):
    """Batched inverse of lower-triangular matrices by forward substitution.

    Args:
        L: (..., n, n) lower-triangular matrices (entries above the diagonal
            are never read and are assumed zero).
        eps: regularizer added to diagonal entries before dividing, so a
            singular diagonal does not produce inf/nan.

    Returns:
        (..., n, n) lower-triangular (regularized) inverses.
    """
    n = L.shape[-1]
    invL = torch.zeros_like(L)
    for j in range(n):
        invL[..., j, j] = 1.0 / (L[..., j, j] + eps)
        for i in range(j + 1, n):
            # FIX: the original summed over k in range(i + 1), but every
            # term with k < j is exactly zero (the inverse of a lower-
            # triangular matrix is lower-triangular, so invL[k, j] == 0)
            # and the k == i term reads invL[i, j] before it is written
            # (also zero). Restricting to k in [j, i) halves the work and
            # leaves the result bit-identical.
            S = 0.0
            for k in range(j, i):
                # .clone() decouples the read from the in-place writes so
                # autograd stays correct.
                S = S - L[..., i, k] * invL[..., k, j].clone()
            invL[..., i, j] = S / (L[..., i, i] + eps)
    return invL
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment