Last active
November 16, 2020 12:09
-
-
Save enhuiz/ce6b0ee7c9ff49ba118ab9b487b7b27c to your computer and use it in GitHub Desktop.
multigpu-matrix-multiplication
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.nn as nn | |
| from einops import rearrange | |
class MatrixMultiplication(nn.Module):
    """Batched ``x @ y^T`` packaged as a module.

    Being an ``nn.Module`` lets the product be wrapped in ``nn.DataParallel``,
    in which case each replica only receives a slice of the batch axis.
    """

    def forward(self, x, y):
        """Multiply every matrix in ``x`` by the transpose of its partner in ``y``.

        Args:
            x: (bg' n chw)
            y: (bg' m chw)
            bg' is the batch size seen by this call; under ``nn.DataParallel``
            it is just the slice of the full bg handled on this device.
            Additional leading batch axes are accepted as well.

        Returns:
            (bg' n m)
        """
        # Report how many batch entries landed on this device.
        print(f"bg on gpu {x.device} is:", x.shape[0])
        # Swap the last two axes instead of hard-coding axes 1 and 2: same
        # result for 3-D inputs, and it also works with extra batch axes.
        return x @ y.transpose(-2, -1)
class Model(nn.Module):
    """Computes per-(b, g) matrix products through a DataParallel-wrapped matmul.

    The ``b`` and ``g`` axes are merged into a single leading batch axis before
    the product and split back into separate axes afterwards.
    """

    def __init__(self):
        super().__init__()
        self.parallel_matrix_multiplication = nn.DataParallel(MatrixMultiplication())

    def forward(self, x, y):
        """Return the product of the flattened ``x`` and ``y`` for every (b, g) pair.

        Args:
            x: (b g c h w n)
            y: (b g c h w m)

        Returns:
            (b g n m)

        Raises:
            ValueError: if ``x`` or ``y`` is not 6-D, or if their leading
                ``(b, g, c, h, w)`` dimensions differ.
        """
        if x.dim() != 6 or y.dim() != 6:
            raise ValueError(
                f"expected 6-D inputs (b g c h w n/m), got {x.dim()}-D and {y.dim()}-D"
            )
        # Without this check, inputs whose (b g) and (c h w) products happen to
        # match would be flattened and multiplied without any error.
        if x.shape[:5] != y.shape[:5]:
            raise ValueError(
                "x and y must share their (b, g, c, h, w) dimensions, got "
                f"{tuple(x.shape[:5])} and {tuple(y.shape[:5])}"
            )

        # Only b and g are needed later, to undo the batch flattening.
        b, g = x.shape[:2]

        # Fold (b, g) into one batch axis and (c, h, w) into the contracted axis.
        x = rearrange(x, "b g c h w n -> (b g) n (c h w)")
        y = rearrange(y, "b g c h w m -> (b g) m (c h w)")

        print("total bg is", b * g)

        z = self.parallel_matrix_multiplication(x, y)
        # Split the merged batch axis back into separate b and g axes.
        z = rearrange(z, "(b g) n m -> b g n m", b=b, g=g)
        return z
# Demo: push random inputs through the model and check the output shape.
model = Model()

# Problem sizes. The spatial size is 16 rather than 224, which is too large to run.
b, g, c = 4, 8, 8
h = w = 16
m, n = 49, 9

# x and y share every axis except the last one (n vs. m).
x = torch.randn(b, g, c, h, w, n)
y = torch.randn(b, g, c, h, w, m)

z = model(x, y)

# The (c, h, w) axes are contracted, leaving one (n, m) matrix per (b, g) pair.
assert z.shape == (b, g, n, m), "Not Expected."
print(z.shape)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment