multigpu-matrix-multiplication
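This gist demonstrates how to parallelize a batched matrix multiplication across GPUs with nn.DataParallel: the (b, g) batch dimensions are flattened into dim 0, which DataParallel scatters across the visible devices, so each replica multiplies only its slice of the batch.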
import torch
import torch.nn as nn
from einops import rearrange


class MatrixMultiplication(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        """Batched matrix multiplication over the flattened (c h w) axis.

        Args:
            x: (bg' n chw)
            y: (bg' m chw)
            bg' != bg; it is the slice of the flattened (b g) batch that
            DataParallel assigned to this GPU.
        Returns:
            (bg' n m)
        """
        print(f"bg on gpu {x.device} is:", len(x))
        return x @ y.transpose(1, 2)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # DataParallel replicates the module and scatters dim 0 of each input
        # across the visible GPUs, then gathers the outputs.
        self.parallel_matrix_multiplication = nn.DataParallel(MatrixMultiplication())

    def forward(self, x, y):
        """
        Args:
            x: (b g c h w n)
            y: (b g c h w m)
        """
        b, g, c, h, w, n, m = (*x.shape, y.shape[-1])
        # Flatten (b g) into dim 0 so DataParallel can split it across GPUs,
        # and flatten (c h w) into the contraction axis.
        x = rearrange(x, "b g c h w n -> (b g) n (c h w)")
        y = rearrange(y, "b g c h w m -> (b g) m (c h w)")
        print("total bg is", b * g)
        z = self.parallel_matrix_multiplication(x, y)
        z = rearrange(z, "(b g) n m -> b g n m", b=b, g=g)
        return z
model = Model()

b = 4
g = 8
c = 8
h = 16  # not 224, which would be too large for a quick test
w = 16
m = 49
n = 9

x = torch.randn(b, g, c, h, w, n)
y = torch.randn(b, g, c, h, w, m)
z = model(x, y)
assert z.shape == (b, g, n, m), "unexpected output shape"
print(z.shape)
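As a sanity check (an addition, not part of the original gist), the same contraction can be computed on a single device with torch.einsum and compared against the parallel result:

# Single-device reference: contract over (c, h, w) directly with einsum.
z_ref = torch.einsum("bgchwn,bgchwm->bgnm", x, y)
# Loose tolerance, since the 2048-element float32 reductions may differ
# slightly between the two computation orders.
assert torch.allclose(z, z_ref, atol=1e-4), "parallel result should match the einsum reference"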