Tune MPP matmul2d tile sizes for MPS F.linear
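A standalone benchmark script. It assumes an Apple-silicon MPS device and a recent PyTorch build that exposes torch.mps.compile_shader, plus a Metal toolchain that ships the MetalPerformancePrimitives tensor_ops API; exact version requirements are an assumption, not pinned here.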
| """Tune MPP matmul2d tile sizes across dtypes and shapes.""" | |
| import time | |
| import torch | |
| import torch.nn.functional as F | |
| torch.set_grad_enabled(False) | |
| WARMUP = 20 | |
| REPEAT = 200 | |
| BATCH = 10 | |

SHAPES = [
    # (label, M, K, N)
    # Aligned to (128, 64) — the LLM/transformer norm
    ("llama-7b qkv", 2048, 4096, 4096),
    ("llama-7b ffn-up", 2048, 4096, 11008),
    ("llama-7b ffn-dn", 2048, 11008, 4096),
    ("gpt2-sm qkv", 1024, 768, 2304),
    ("gpt2-sm ffn", 1024, 768, 3072),
    ("bert proj", 4096, 768, 768),
    ("batched-sm", 4096, 512, 512),
    ("small-2d", 64, 256, 128),
    # Unaligned — exercise the edge-tile path of NEEDS_PADDING=true
    ("vit patch", 197, 768, 768),
    ("clip-img", 257, 1024, 4096),
    ("seq-33", 33, 4096, 4096),
    ("seq-1023", 1023, 4096, 4096),
    ("vocab-32001", 1024, 4096, 32001),
    ("both-odd", 197, 513, 195),
    ("tiny", 1, 64, 32),
]
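# Unaligned shapes can only run the NEEDS_PADDING=true variant; bench_mpp()
# silently skips the pad=F configs for them.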

DTYPES = [torch.float32, torch.float16, torch.bfloat16]
DTYPE_LABELS = {torch.float32: "fp32", torch.float16: "fp16", torch.bfloat16: "bf16"}
METAL_TYPES = {torch.float32: "float", torch.float16: "half", torch.bfloat16: "bfloat"}

CONFIGS = [
    # (tile_m, tile_n, n_simdgroups)
    (32, 32, 4),
    (32, 64, 4),
    (64, 32, 4),
    (64, 64, 4),
    (128, 32, 4),
    (128, 64, 4),
    (128, 128, 4),
    (32, 32, 2),
    (32, 64, 2),
    (64, 32, 2),
    (64, 64, 2),
]
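# Each config maps one threadgroup to a tile_m x tile_n tile of the output;
# n_simdgroups SIMD groups cooperate on that tile, so a threadgroup runs
# thread_execution_width * n_simdgroups threads (see bench_mpp below).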

# Cache compiled shaders by (metal_type, tile_m, tile_n, n_simdgroups, needs_padding)
_shader_cache = {}


def make_mpp_shader(metal_type, tile_m, tile_n, n_simdgroups, needs_padding):
    pad = "true" if needs_padding else "false"
    return f"""
#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
using namespace metal;
using namespace mpp::tensor_ops;

constant constexpr int TILE_M = {tile_m};
constant constexpr int TILE_N = {tile_n};
constant constexpr bool NEEDS_PADDING = {pad};

kernel void mpp_gemm(
    device {metal_type}* A [[buffer(0)]],
    device {metal_type}* B [[buffer(1)]],
    device {metal_type}* C [[buffer(2)]],
    constant uint3& sizes [[buffer(3)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]])
{{
    const uint M = sizes.x;
    const uint K = sizes.y;
    const uint N = sizes.z;
    const uint m_off = tgid.y * TILE_M;
    const uint n_off = tgid.x * TILE_N;

    // Extents are given innermost-first: A is M x K row-major, B holds the
    // transposed weight as N x K row-major, C is M x N.
    using ext2d = dextents<int32_t, 2>;
    tensor<device {metal_type}, ext2d, tensor_inline> mA_full(A, ext2d((int)K, (int)M));
    tensor<device {metal_type}, ext2d, tensor_inline> mB_full(B, ext2d((int)K, (int)N));
    tensor<device {metal_type}, ext2d, tensor_inline> mC_full(C, ext2d((int)N, (int)M));

    // Per-tile problem: TILE_M x TILE_N output with dynamic K; the final
    // 'true' flags the right operand as transposed (B is stored N x K).
    constexpr auto desc = matmul2d_descriptor(
        TILE_M, TILE_N, static_cast<int>(dynamic_extent), false, true);
    matmul2d<desc, execution_simdgroups<{n_simdgroups}>> op;

    if (!NEEDS_PADDING || (m_off + TILE_M <= M && n_off + TILE_N <= N)) {{
        // Interior tile: statically-sized slices, no bounds handling needed.
        auto mA = mA_full.template slice<dynamic_extent, TILE_M>(0, (int)m_off);
        auto mB = mB_full.template slice<dynamic_extent, TILE_N>(0, (int)n_off);
        auto mC = mC_full.template slice<TILE_N, TILE_M>((int)n_off, (int)m_off);
        op.run(mA, mB, mC);
    }} else {{
        // Edge tile: dynamically-sized slices cover the partial remainder.
        auto mA = mA_full.slice(0, (int)m_off);
        auto mB = mB_full.slice(0, (int)n_off);
        auto mC = mC_full.slice((int)n_off, (int)m_off);
        op.run(mA, mB, mC);
    }}
}}
"""

def get_shader(metal_type, tile_m, tile_n, n_simdgroups, needs_padding):
    key = (metal_type, tile_m, tile_n, n_simdgroups, needs_padding)
    if key not in _shader_cache:
        src = make_mpp_shader(metal_type, tile_m, tile_n, n_simdgroups, needs_padding)
        _shader_cache[key] = torch.mps.compile_shader(src)
    return _shader_cache[key]
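# Hypothetical REPL check (assumes an MPS device and a PyTorch build that
# exposes torch.mps.compile_shader):
#   lib = get_shader("half", 64, 64, 4, needs_padding=True)
#   lib.mpp_gemm.thread_execution_width  # typically 32 on Apple GPUs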

def bench_mpp(M, K, N, dtype, tile_m, tile_n, n_simdgroups, needs_padding):
    aligned = (M % tile_m == 0) and (N % tile_n == 0)
    if not needs_padding and not aligned:
        return None, "shape not aligned"
    metal_type = METAL_TYPES[dtype]
    try:
        lib = get_shader(metal_type, tile_m, tile_n, n_simdgroups, needs_padding)
    except Exception as e:
        return None, str(e)[:80]
    kernel = lib.mpp_gemm
    simd_w = kernel.thread_execution_width
    threads_per_tg = simd_w * n_simdgroups
    A = torch.randn(M, K, device="mps", dtype=dtype)
    B = torch.randn(N, K, device="mps", dtype=dtype)  # weight layout, as in F.linear
    C = torch.zeros(M, N, device="mps", dtype=dtype)
    sizes = torch.tensor([M, K, N], device="mps", dtype=torch.int32)
    # One threadgroup per output tile
    num_tg_x = (N + tile_n - 1) // tile_n
    num_tg_y = (M + tile_m - 1) // tile_m

    def run():
        kernel(A, B, C, sizes,
               threads=[num_tg_x * threads_per_tg, num_tg_y, 1],
               group_size=[threads_per_tg, 1, 1])

    # Validate against a float32 reference before timing
    run()
    torch.mps.synchronize()
    ref = (A.float() @ B.float().T).to(dtype=dtype)
    maxdiff = (C - ref).abs().max().item()
    if maxdiff > 1.0:
        return None, f"bad accuracy: maxdiff={maxdiff}"
    # Warmup
    for _ in range(WARMUP):
        run()
    torch.mps.synchronize()
    # Timed: synchronize once per BATCH launches to amortize sync overhead
    num_batches = REPEAT // BATCH
    start = time.perf_counter()
    for _ in range(num_batches):
        for _ in range(BATCH):
            run()
        torch.mps.synchronize()
    elapsed = time.perf_counter() - start
    total = num_batches * BATCH
    avg_us = elapsed / total * 1e6
    gflops = 2 * M * K * N / (elapsed / total) * 1e-9
    return avg_us, gflops
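# Hypothetical single-config call:
#   us, gf = bench_mpp(2048, 4096, 4096, torch.float16, 64, 64, 4, False)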

def bench_baseline(M, K, N, dtype):
    x = torch.randn(M, K, device="mps", dtype=dtype)
    w = torch.randn(N, K, device="mps", dtype=dtype)
    for _ in range(WARMUP):
        F.linear(x, w)
    torch.mps.synchronize()
    num_batches = REPEAT // BATCH
    start = time.perf_counter()
    for _ in range(num_batches):
        for _ in range(BATCH):
            F.linear(x, w)
        torch.mps.synchronize()
    elapsed = time.perf_counter() - start
    total = num_batches * BATCH
    avg_us = elapsed / total * 1e6
    gflops = 2 * M * K * N / (elapsed / total) * 1e-9
    return avg_us, gflops
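# The baseline uses the same warm-up and per-BATCH synchronization scheme as
# bench_mpp, so the timings are directly comparable.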

def main():
    print(f"Device: {torch.backends.mps.get_name()} torch={torch.__version__}")
    print()
    for dtype in DTYPES:
        dl = DTYPE_LABELS[dtype]
        print(f"===== dtype={dl} =====")
        for label, M, K, N in SHAPES:
            print(f"  --- {label}: M={M}, K={K}, N={N} ---")
            avg_us, gflops = bench_baseline(M, K, N, dtype)
            print(f"  F.linear: {avg_us:8.1f} us {gflops:8.1f} GFLOP/s")
            print(f"  {'tile_m':>6} {'tile_n':>6} {'sg':>3} {'pad':>3} {'time(us)':>10} {'GFLOP/s':>10}")
            for tile_m, tile_n, sg in CONFIGS:
                for needs_padding in (False, True):
                    pad_label = "T" if needs_padding else "F"
                    result, info = bench_mpp(M, K, N, dtype, tile_m, tile_n, sg, needs_padding)
                    if result is None:
                        if info == "shape not aligned":
                            continue  # unaligned shapes only run with pad=T
                        print(f"  {tile_m:>6} {tile_n:>6} {sg:>3} {pad_label:>3}  FAILED: {info}")
                    else:
                        print(f"  {tile_m:>6} {tile_n:>6} {sg:>3} {pad_label:>3} {result:>10.1f} {info:>10.1f}")
            print()
        print()


if __name__ == "__main__":
    main()