@bantmen
Created December 29, 2025 09:33
Make torch.compile rmsnorm faster
import torch
from triton.testing import do_bench


def bench(M, N=128, dtype=torch.bfloat16):
    x = torch.randn(M, N, device="cuda", dtype=dtype)
    w = torch.randn(N, device="cuda", dtype=dtype)
    eps = 1e-6
    with torch.inference_mode():
        # Compile the torch kernel
        compiled_torch = torch.compile(
            torch.nn.functional.rms_norm, mode="default", fullgraph=True
        )
        torch.cuda.synchronize()
        compiled_torch_time = do_bench(
            lambda: compiled_torch(x, (N,), weight=w, eps=eps)
        )
        # Effective bandwidth: read x + write output = 2*m*n*itemsize bytes,
        # divided by time in ms, then by 1e6 to convert bytes/ms to GB/s.
        bwd = lambda m, n, time: m * n * dtype.itemsize * 2 / time / 1e6
        print(
            f"M={M:7d} | "
            f"compiled_torch: {bwd(M, N, compiled_torch_time):4.1f} GB/s"
        )


if __name__ == "__main__":
    print("N=128, dtype=bfloat16")
    print("-" * 85)
    # Sweep M over powers of two from 1024 to 8388608
    M = 1024
    for _ in range(14):
        bench(M)
        M *= 2
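For reference, `torch.nn.functional.rms_norm` normalizes each row by its root mean square: y = x / sqrt(mean(x²) + eps) * weight over the last dimension. A minimal pure-Python sketch of that computation (CPU-only, no torch; the function name `rms_norm_ref` is illustrative, not from the script above):

```python
import math

def rms_norm_ref(x, weight, eps=1e-6):
    """Reference RMSNorm over one row: x / sqrt(mean(x^2) + eps) * weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]

# A row whose mean square is 1 is left (almost) unchanged by unit weights.
row = [1.0, -1.0, 1.0, -1.0]
out = rms_norm_ref(row, [1.0] * 4)
```

Unlike LayerNorm, there is no mean subtraction or bias, which is part of why the kernel is purely bandwidth-bound at large M.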
bantmen commented Dec 29, 2025
N=128, dtype=bfloat16
-------------------------------------------------------------------------------------
M=   1024 | compiled_torch: 16.6 GB/s
M=   2048 | compiled_torch: 22.7 GB/s
M=   4096 | compiled_torch: 47.6 GB/s
M=   8192 | compiled_torch: 93.8 GB/s
M=  16384 | compiled_torch: 186.6 GB/s
M=  32768 | compiled_torch: 372.5 GB/s
M=  65536 | compiled_torch: 749.3 GB/s
M= 131072 | compiled_torch: 1482.4 GB/s
M= 262144 | compiled_torch: 2880.2 GB/s
M= 524288 | compiled_torch: 5507.9 GB/s
M=1048576 | compiled_torch: 6298.7 GB/s
M=2097152 | compiled_torch: 6629.1 GB/s
M=4194304 | compiled_torch: 6789.2 GB/s
M=8388608 | compiled_torch: 6899.4 GB/s
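The GB/s column treats the kernel as memory-bound: 2·M·N·itemsize bytes moved (read x, write the output; the weight is negligible), divided by the `do_bench` time in ms, then by 1e6 to convert bytes/ms to GB/s. A quick sanity check of the largest row, where the ~0.62 ms timing below is back-derived from the reported figure, not measured:

```python
# Recompute the bandwidth figure for the M=8388608 row.
M, N, itemsize = 8388608, 128, 2      # bfloat16 is 2 bytes/element
time_ms = 0.6225                       # assumed: roughly what do_bench returned
bytes_moved = 2 * M * N * itemsize     # read x + write output
gbps = bytes_moved / time_ms / 1e6     # bytes/ms == GB/s
```

At ~6.9 TB/s the curve has flattened, so the large-M runs are saturating HBM bandwidth, while the small-M runs are dominated by launch and scheduling overhead.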
