Tune MPP matmul2d tile sizes for MPS F.linear
"""Tune MPP matmul2d tile sizes across dtypes and shapes."""
import time
import torch
import torch.nn.functional as F
torch.set_grad_enabled(False)
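# Benchmark protocol: WARMUP untimed runs, then REPEAT timed runs submitted in
# batches of BATCH kernel launches before each host-side synchronize.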
WARMUP = 20
REPEAT = 200
BATCH = 10
SHAPES = [
    # (label, M, K, N)
    # Aligned to (128, 64) — the LLM/transformer norm
    ("llama-7b qkv", 2048, 4096, 4096),
    ("llama-7b ffn-up", 2048, 4096, 11008),
    ("llama-7b ffn-dn", 2048, 11008, 4096),
    ("gpt2-sm qkv", 1024, 768, 2304),
    ("gpt2-sm ffn", 1024, 768, 3072),
    ("bert proj", 4096, 768, 768),
    ("batched-sm", 4096, 512, 512),
    ("small-2d", 64, 256, 128),
    # Unaligned — exercise the edge-tile path of NEEDS_PADDING=true
    ("vit patch", 197, 768, 768),
    ("clip-img", 257, 1024, 4096),
    ("seq-33", 33, 4096, 4096),
    ("seq-1023", 1023, 4096, 4096),
    ("vocab-32001", 1024, 4096, 32001),
    ("both-odd", 197, 513, 195),
    ("tiny", 1, 64, 32),
]
DTYPES = [torch.float32, torch.float16, torch.bfloat16]
DTYPE_LABELS = {torch.float32: "fp32", torch.float16: "fp16", torch.bfloat16: "bf16"}
METAL_TYPES = {torch.float32: "float", torch.float16: "half", torch.bfloat16: "bfloat"}
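# Candidate kernel specializations swept for every shape/dtype: the
# per-threadgroup output tile (tile_m x tile_n) and the number of SIMD groups
# that cooperate on it.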
CONFIGS = [
    # (tile_m, tile_n, n_simdgroups)
    (32, 32, 4),
    (32, 64, 4),
    (64, 32, 4),
    (64, 64, 4),
    (128, 32, 4),
    (128, 64, 4),
    (128, 128, 4),
    (32, 32, 2),
    (32, 64, 2),
    (64, 32, 2),
    (64, 64, 2),
]
# Cache compiled shaders by (metal_type, tile_m, tile_n, n_simdgroups, needs_padding)
_shader_cache = {}
def make_mpp_shader(metal_type, tile_m, tile_n, n_simdgroups, needs_padding):
    pad = "true" if needs_padding else "false"
    return f"""
#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
using namespace metal;
using namespace mpp::tensor_ops;
constant constexpr int TILE_M = {tile_m};
constant constexpr int TILE_N = {tile_n};
constant constexpr bool NEEDS_PADDING = {pad};
kernel void mpp_gemm(
    device {metal_type}* A [[buffer(0)]],
    device {metal_type}* B [[buffer(1)]],
    device {metal_type}* C [[buffer(2)]],
    constant uint3& sizes [[buffer(3)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]])
{{
    const uint M = sizes.x;
    const uint K = sizes.y;
    const uint N = sizes.z;
    const uint m_off = tgid.y * TILE_M;
    const uint n_off = tgid.x * TILE_N;
    using ext2d = dextents<int32_t, 2>;
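    // Wrap the flat device buffers as 2-D tensors; extents are listed with the
    // innermost (contiguous) dimension first: A is M x K, B is N x K (the
    // transposed weight), and C is M x N.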
    tensor<device {metal_type}, ext2d, tensor_inline> mA_full(A, ext2d((int)K, (int)M));
    tensor<device {metal_type}, ext2d, tensor_inline> mB_full(B, ext2d((int)K, (int)N));
    tensor<device {metal_type}, ext2d, tensor_inline> mC_full(C, ext2d((int)N, (int)M));
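    // Describe a TILE_M x TILE_N output tile with a runtime-sized reduction
    // (K) dimension, executed cooperatively by the configured SIMD groups.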
    constexpr auto desc = matmul2d_descriptor(
        TILE_M, TILE_N, static_cast<int>(dynamic_extent), false, true);
    matmul2d<desc, execution_simdgroups<{n_simdgroups}>> op;
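    // Interior tiles take the fast path with compile-time-sized slices; edge
    // tiles (reachable only when NEEDS_PADDING is true) fall back to
    // dynamically-sized slices so the op can handle the partial tile.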
    if (!NEEDS_PADDING || (m_off + TILE_M <= M && n_off + TILE_N <= N)) {{
        auto mA = mA_full.template slice<dynamic_extent, TILE_M>(0, (int)m_off);
        auto mB = mB_full.template slice<dynamic_extent, TILE_N>(0, (int)n_off);
        auto mC = mC_full.template slice<TILE_N, TILE_M>((int)n_off, (int)m_off);
        op.run(mA, mB, mC);
    }} else {{
        auto mA = mA_full.slice(0, (int)m_off);
        auto mB = mB_full.slice(0, (int)n_off);
        auto mC = mC_full.slice((int)n_off, (int)m_off);
        op.run(mA, mB, mC);
    }}
}}
"""
def get_shader(metal_type, tile_m, tile_n, n_simdgroups, needs_padding):
    key = (metal_type, tile_m, tile_n, n_simdgroups, needs_padding)
    if key not in _shader_cache:
        src = make_mpp_shader(metal_type, tile_m, tile_n, n_simdgroups, needs_padding)
        _shader_cache[key] = torch.mps.compile_shader(src)
    return _shader_cache[key]
def bench_mpp(M, K, N, dtype, tile_m, tile_n, n_simdgroups, needs_padding):
    aligned = (M % tile_m == 0) and (N % tile_n == 0)
    if not needs_padding and not aligned:
        return None, "shape not aligned"
    metal_type = METAL_TYPES[dtype]
    try:
        lib = get_shader(metal_type, tile_m, tile_n, n_simdgroups, needs_padding)
    except Exception as e:
        return None, str(e)[:80]
    kernel = lib.mpp_gemm
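    # Threads per threadgroup: SIMD width (typically 32 on Apple GPUs) times
    # the number of SIMD groups the kernel was specialized for.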
    simd_w = kernel.thread_execution_width
    threads_per_tg = simd_w * n_simdgroups
    A = torch.randn(M, K, device="mps", dtype=dtype)
    B = torch.randn(N, K, device="mps", dtype=dtype)
    C = torch.zeros(M, N, device="mps", dtype=dtype)
    sizes = torch.tensor([M, K, N], device="mps", dtype=torch.int32)
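    # One threadgroup per (tile_m x tile_n) output tile; the dispatch takes
    # total thread counts, so x is (tiles along N) * threads_per_tg.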
    num_tg_x = (N + tile_n - 1) // tile_n
    num_tg_y = (M + tile_m - 1) // tile_m
    def run():
        kernel(A, B, C, sizes,
               threads=[num_tg_x * threads_per_tg, num_tg_y, 1],
               group_size=[threads_per_tg, 1, 1])
    # Validate
    run()
    torch.mps.synchronize()
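    # Reference computed in fp32 and cast back to the benchmark dtype; the
    # absolute tolerance is deliberately loose to absorb fp16/bf16
    # accumulation error over large K.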
    ref = (A.float() @ B.float().T).to(dtype=dtype)
    maxdiff = (C - ref).abs().max().item()
    if maxdiff > 1.0:
        return None, f"bad accuracy: maxdiff={maxdiff}"
    # Warmup
    for _ in range(WARMUP):
        run()
    torch.mps.synchronize()
    # Timed
    num_batches = REPEAT // BATCH
    start = time.perf_counter()
    for _ in range(num_batches):
        for _ in range(BATCH):
            run()
        torch.mps.synchronize()
    elapsed = time.perf_counter() - start
    total = num_batches * BATCH
    avg_us = elapsed / total * 1e6
    gflops = 2 * M * K * N / (elapsed / total) * 1e-9
    return avg_us, gflops
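# Baseline: PyTorch's built-in F.linear (x @ w.T) on the same shapes.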
def bench_baseline(M, K, N, dtype):
    x = torch.randn(M, K, device="mps", dtype=dtype)
    w = torch.randn(N, K, device="mps", dtype=dtype)
    for _ in range(WARMUP):
        F.linear(x, w)
    torch.mps.synchronize()
    num_batches = REPEAT // BATCH
    start = time.perf_counter()
    for _ in range(num_batches):
        for _ in range(BATCH):
            F.linear(x, w)
        torch.mps.synchronize()
    elapsed = time.perf_counter() - start
    total = num_batches * BATCH
    avg_us = elapsed / total * 1e6
    gflops = 2 * M * K * N / (elapsed / total) * 1e-9
    return avg_us, gflops
def main():
    print(f"Device: {torch.backends.mps.get_name()} torch={torch.__version__}")
    print()
    for dtype in DTYPES:
        dl = DTYPE_LABELS[dtype]
        print(f"===== dtype={dl} =====")
        for label, M, K, N in SHAPES:
            print(f" --- {label}: M={M}, K={K}, N={N} ---")
            avg_us, gflops = bench_baseline(M, K, N, dtype)
            print(f" F.linear: {avg_us:8.1f} us {gflops:8.1f} GFLOPs")
            print(f" {'tile_m':>6} {'tile_n':>6} {'sg':>3} {'pad':>3} {'time(us)':>10} {'GFLOPs':>10}")
            for tile_m, tile_n, sg in CONFIGS:
                for needs_padding in (False, True):
                    pad_label = "T" if needs_padding else "F"
                    result, info = bench_mpp(M, K, N, dtype, tile_m, tile_n, sg, needs_padding)
                    if result is None:
                        if info == "shape not aligned":
                            continue  # skip silently — only meaningful for pad=T
                        print(f" {tile_m:>6} {tile_n:>6} {sg:>3} {pad_label:>3} FAILED: {info}")
                    else:
                        print(f" {tile_m:>6} {tile_n:>6} {sg:>3} {pad_label:>3} {result:>10.1f} {info:>10.1f}")
            print()
        print()
if __name__ == "__main__":
    main()