Last active
October 8, 2022 14:42
-
-
Save dendenxu/d23b3a2946790cf9d58297a8f2c75fbe to your computer and use it in GitHub Desktop.
Batched PyTorch inverses of small matrices that are faster than torch.inverse (which sometimes uses the CPU and does not batch properly), together with some decomposition-based methods and an additional, faster version.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
def torch_inverse_2x2(A: torch.Tensor, eps=torch.finfo(torch.float).eps):
    """Invert a batch of 2x2 matrices with the closed-form adjugate formula.

    For ``[[a, b], [c, d]]`` the result is ``[[d, -b], [-c, a]] / det`` with
    ``det = a*d - b*c + eps``.

    Args:
        A: tensor whose last two dimensions index the 2x2 matrix; any number
            of leading batch dimensions is accepted.
        eps: constant added to the determinant before dividing, so an exactly
            singular matrix does not divide by zero.  NOTE: it is added
            regardless of the determinant's sign, so a determinant equal to
            ``-eps`` still produces a division by zero.

    Returns:
        A new tensor shaped like ``A`` holding the (eps-regularised) inverses.
        ``A`` itself is not modified.
    """
    # for readability
    a = A[..., 0, 0]
    b = A[..., 0, 1]
    c = A[..., 1, 0]
    d = A[..., 1, 1]

    # Regularised determinant, shared by all four output entries.
    det = (a * d - b * c) + eps

    B = torch.zeros_like(A)
    B[..., 0, 0] = d / det
    B[..., 0, 1] = -b / det
    B[..., 1, 0] = -c / det
    B[..., 1, 1] = a / det
    return B
def torch_inverse_3x3(R: torch.Tensor, eps=torch.finfo(torch.float).eps):
    """Invert a batch of 3x3 matrices via the adjugate (cofactor) formula.

    Element naming used for the 3x3 matrix:

        a, b, c | m00, m01, m02
        d, e, f | m10, m11, m12
        g, h, i | m20, m21, m22

    Args:
        R: tensor of shape ``(..., 3, 3)``.  Any number of leading batch
            dimensions is accepted, including the ``(B, N, 3, 3)`` layout.
        eps: constant added to the determinant before dividing, so an exactly
            singular matrix does not divide by zero.

    Returns:
        Tensor of shape ``(..., 3, 3)``: transposed cofactor matrix divided by
        ``determinant + eps``.
    """
    # The built-in R.inverse() is deliberately not used here
    # (original FIXME: possible bug when performing inverse).

    # Lookup tables live at module level; ``.to`` returns the same tensor when
    # it is already on R's device, so this almost never copies.
    idx_i = g_idx_i.to(R.device)
    idx_j = g_idx_j.to(R.device)
    signs = g_signs.to(R.device)

    # Gather all nine 2x2 sub-matrices in one indexing op:
    # (..., 3, 3) -> (..., 3, 3, 2, 2), where [..., i, j, :, :] is R with
    # row i and column j removed.
    sub = R[..., idx_i, idx_j]

    # Determinant of each 2x2 sub-matrix -> (..., 3, 3) matrix of minors.
    minors = sub[..., 0, 0] * sub[..., 1, 1] - sub[..., 0, 1] * sub[..., 1, 0]

    cofactors = minors * signs  # (3, 3) checkerboard broadcast over the batch
    cofactors_t = cofactors.transpose(-2, -1)  # adjugate, (..., 3, 3)

    # Laplace expansion along the first row (signs applied explicitly).
    determinant = (
        R[..., 0, 0] * minors[..., 0, 0]
        - R[..., 0, 1] * minors[..., 0, 1]
        + R[..., 0, 2] * minors[..., 0, 2]
    )

    return cofactors_t / (determinant[..., None, None] + eps)


# Device for the lookup tables below: the GPU when one is available (matching
# the original behaviour), otherwise the CPU so that importing this module
# does not fail on machines without CUDA.
_g_table_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# g_idx_i[i, j] is the 2x2 grid of ROW indices of the sub-matrix obtained by
# deleting row i (and column j) from a 3x3 matrix.
g_idx_i = torch.tensor(
    [
        [
            [[1, 1], [2, 2]],
            [[1, 1], [2, 2]],
            [[1, 1], [2, 2]],
        ],
        [
            [[0, 0], [2, 2]],
            [[0, 0], [2, 2]],
            [[0, 0], [2, 2]],
        ],
        [
            [[0, 0], [1, 1]],
            [[0, 0], [1, 1]],
            [[0, 0], [1, 1]],
        ],
    ], device=_g_table_device, dtype=torch.long)

# g_idx_j[i, j] is the matching 2x2 grid of COLUMN indices (column j deleted).
g_idx_j = torch.tensor(
    [
        [
            [[1, 2], [1, 2]],
            [[0, 2], [0, 2]],
            [[0, 1], [0, 1]],
        ],
        [
            [[1, 2], [1, 2]],
            [[0, 2], [0, 2]],
            [[0, 1], [0, 1]],
        ],
        [
            [[1, 2], [1, 2]],
            [[0, 2], [0, 2]],
            [[0, 1], [0, 1]],
        ],
    ], device=_g_table_device, dtype=torch.long)

# Checkerboard of cofactor signs, (-1) ** (i + j).
g_signs = torch.tensor([
    [+1, -1, +1],
    [-1, +1, -1],
    [+1, -1, +1],
], device=_g_table_device, dtype=torch.long)
def torch_inverse_3x3_naive(R, eps=1e-10):
    """Invert batched 3x3 matrices with hand-expanded cofactors.

    Element naming (the last two dimensions of ``R`` index the matrix):

        a, b, c
        d, e, f
        g, h, i

    ``eps`` is added to the determinant before the division.  The result has
    the same leading dimensions as ``R`` followed by ``(3, 3)``.
    """
    # Pull the nine entries out as separate (batched) scalars.
    (a, b, c), (d, e, f), (g, h, i) = (
        [R[..., row, col] for col in range(3)] for row in range(3)
    )

    # Signed cofactors, named by (row, column) of the entry they belong to.
    cof_a = e * i - h * f
    cof_b = g * f - d * i
    cof_c = d * h - g * e
    cof_d = h * c - b * i
    cof_e = a * i - g * c
    cof_f = g * b - a * h
    cof_g = b * f - e * c
    cof_h = d * c - a * f
    cof_i = a * e - d * b

    # Adjugate = transposed cofactor matrix, assembled one row at a time.
    adjugate = torch.stack(
        [
            torch.stack([cof_a, cof_d, cof_g], dim=-1),
            torch.stack([cof_b, cof_e, cof_h], dim=-1),
            torch.stack([cof_c, cof_f, cof_i], dim=-1),
        ],
        dim=-2,
    )

    # Laplace expansion along the first row.
    det = a * cof_a + b * cof_b + c * cof_c

    return adjugate / (det[..., None, None] + eps)
def torch_inverse_decomp(L: torch.Tensor, eps=1e-10):
    """Invert batched lower-triangular matrices by forward substitution.

    Only the diagonal and the strictly-lower triangle of ``L`` are read; the
    upper triangle is ignored.  Presumably ``L`` is a triangular factor from a
    decomposition such as Cholesky -- confirm against callers.

    Args:
        L: tensor of shape ``(..., n, n)``.
        eps: constant added to each diagonal entry before dividing by it.

    Returns:
        Tensor shaped like ``L`` holding the inverse; entries above the
        diagonal are zero.
    """
    n = L.shape[-1]
    invL = torch.zeros_like(L)
    for j in range(n):
        # Diagonal entry of the inverse is the reciprocal of L's diagonal.
        invL[..., j, j] = 1.0 / (L[..., j, j] + eps)
        for i in range(j + 1, n):
            # Only k in [j, i) contributes: invL[k, j] is zero for k < j
            # (upper triangle) and invL[i, j] has not been written yet.
            # ``clone`` keeps autograd's saved tensor independent of the
            # in-place writes to ``invL`` that follow.
            partial = (L[..., i, j:i] * invL[..., j:i, j].clone()).sum(dim=-1)
            invL[..., i, j] = -partial / (L[..., i, i] + eps)
    return invL
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment