Last active: February 1, 2024
# PyTorch Cheatsheet: Low-effort from-scratch layers

* `b`: Batch size
* `i`: Input features (Linear)
* `o`: Output features
* `c_in`: Input channels
* `c_out`: Output channels
* `n_in`: Input length (1D sequence)
* `n_out`: Output length (1D sequence)
* `h_in`: Input height
* `w_in`: Input width
* `h_out`: Output height
* `w_out`: Output width
* `k`: Kernel size (1D)
* `kh`: Kernel height
* `kw`: Kernel width

## [`Linear`](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html#torch.nn.Linear) layer

```python
X = rearrange(X, 'b i -> b i 1')
W = rearrange(W, 'o i -> 1 i o')  # nn.Linear stores W as (o, i)
Y = (X * W).sum(dim=1)  # + bias
```

## [`Conv1d`](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d) layer

```python
unfolded_X = X.unfold(dimension=2, size=kernel_size, step=stride)  # b c_in n_out k
unfolded_X = rearrange(unfolded_X, 'b c_in n_out k -> b 1 c_in n_out k')
W = rearrange(W, 'c_out c_in k -> 1 c_out c_in 1 k')
Y = (unfolded_X * W).sum(dim=(2, 4))
```

### Including dilation

```python
# Unfold with size=dilation * (kernel_size - 1) + 1 so each window spans the
# dilated extent, then keep every dilation-th tap:
Y = (unfolded_X[..., ::dilation] * W).sum(dim=(2, 4))
```

### Including groups

```python
# The literal 1 inserts a singleton axis that broadcasts against c_out in W:
unfolded_X = rearrange(unfolded_X, 'b (groups c_in) n_out k -> b groups 1 c_in n_out k', groups=groups)
W = rearrange(W, '(groups c_out) c_in k -> 1 groups c_out c_in 1 k', groups=groups)
Y = (unfolded_X * W).sum(dim=(3, 5))
Y = rearrange(Y, 'b groups c_out n_out -> b (groups c_out) n_out')  # fold groups back into channels
```

## [`Conv2d`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d) layer

```python
unfolded_X = X.unfold(2, kernel_size[0], stride).unfold(3, kernel_size[1], stride)  # b c_in h_out w_out kh kw
unfolded_X = rearrange(unfolded_X, 'b c_in h_out w_out kh kw -> b 1 c_in h_out w_out kh kw')
W = rearrange(W, 'c_out c_in kh kw -> 1 c_out c_in 1 1 kh kw')
Y = (unfolded_X * W).sum(dim=(2, 5, 6))
```

### Including dilation

```python
Y = (unfolded_X[..., ::dilation, ::dilation] * W).sum(dim=(2, 5, 6))
```

### Including groups

```python
unfolded_X = rearrange(unfolded_X, 'b (groups c_in) h_out w_out kh kw -> b groups 1 c_in h_out w_out kh kw', groups=groups)
W = rearrange(W, '(groups c_out) c_in kh kw -> 1 groups c_out c_in 1 1 kh kw', groups=groups)
Y = (unfolded_X * W).sum(dim=(3, 6, 7))
Y = rearrange(Y, 'b groups c_out h_out w_out -> b (groups c_out) h_out w_out')  # fold groups back into channels
```

### Side-note: [`pad`](https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html#torch.nn.functional.pad)

```python
# Pads the last dim by (left, right) and the second-to-last dim by (top, bottom)
X_padded = nn.functional.pad(X, (left, right, top, bottom))
```

## Appendix: Matrix multiplication common variants

### a. [`matmul`](https://pytorch.org/docs/stable/generated/torch.matmul.html#torch.matmul) or [`@`](https://docs.python.org/3/library/operator.html#mapping-operators-to-functions)

Assuming that both arguments have the same number of dims and dims ≥ 2:

```python
A = rearrange(A, '... m n -> ... m n 1')
B = rearrange(B, '... n p -> ... 1 n p')
C = (A * B).sum(dim=-2)  # ... m p
```

### b. [`bmm`](https://pytorch.org/docs/stable/generated/torch.bmm.html#torch.bmm)

```python
A = rearrange(A, 'b m n -> b m n 1')
B = rearrange(B, 'b n p -> b 1 n p')
C = (A * B).sum(dim=-2)  # b m p
```

### c. [`mm`](https://pytorch.org/docs/stable/generated/torch.mm.html#torch.mm)

```python
A = rearrange(A, 'm n -> m n 1')
B = rearrange(B, 'n p -> 1 n p')
C = (A * B).sum(dim=-2)  # m p
```

## Appendix: Reshaping common variants

### a. Flatten dimensions

```python
# X = rearrange(X, 'b c1 c2 h w -> b (c1 c2) h w')
X = X.view(X.size(0), -1, *X.shape[-2:])
```

### b. Unflatten dimensions

```python
# X = rearrange(X, 'b (c1 c2) h w -> b c1 c2 h w', c1=c1, c2=c2)
X = X.view(X.size(0), c1, c2, *X.shape[-2:])
```

### c. Permute dimensions

```python
# X = rearrange(X, 'b c h w -> b w h c')
X = X.permute(0, 3, 2, 1)
```

### d. Squeeze dimension

```python
# X = rearrange(X, 'b i 1 -> b i')
X = X.squeeze(-1)
```

### e. Unsqueeze dimension

```python
# X = rearrange(X, 'b i -> b i 1')
X = X.unsqueeze(-1)
X = X[..., None]
```