linalg fusions
// Attention mask + max-reduce using named linalg ops.
// 3 kernels: broadcast mask → add to scores → max-reduce over j.
// After generalize + fuse: broadcast and add are inlined into the reduction.
func.func @fuse_mask_into_max_reduce(
    %scores: tensor<4x512x512xf32>,
    %mask: tensor<512x512xf32>,
    %neg_inf_init: tensor<4x512xf32>) -> tensor<4x512xf32> {
  %init3d = tensor.empty() : tensor<4x512x512xf32>
  // Kernel 1: broadcast mask [512,512] -> [4,512,512] (replicate over batch)
  %mask_3d = linalg.broadcast
      ins(%mask : tensor<512x512xf32>)
      outs(%init3d : tensor<4x512x512xf32>) dimensions = [0]
  // Kernel 2: elementwise add scores + mask
  %masked = linalg.add
      ins(%scores, %mask_3d : tensor<4x512x512xf32>, tensor<4x512x512xf32>)
      outs(%init3d : tensor<4x512x512xf32>) -> tensor<4x512x512xf32>
  // Kernel 3: max-reduce over j (dim 2)
  %max = linalg.reduce ins(%masked : tensor<4x512x512xf32>)
      outs(%neg_inf_init : tensor<4x512xf32>) dimensions = [2]
      (%in: f32, %acc: f32) {
        %m = arith.maximumf %in, %acc : f32
        linalg.yield %m : f32
      }
  return %max : tensor<4x512xf32>
}
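The second file below drives the fusion from a lit RUN line. Assuming this first file is saved on its own (the file name here is illustrative), the same two passes can be applied directly:

mlir-opt mask_max_reduce.mlir --linalg-generalize-named-ops --linalg-fuse-elementwise-ops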
// RUN: mlir-opt %s --linalg-generalize-named-ops --linalg-fuse-elementwise-ops --split-input-file | FileCheck %s

func.func @fuse_square_into_reduce(
    %x: tensor<4x1024xf32>) -> tensor<4xf32> {
  %cst = arith.constant 0.0 : f32
  %init = tensor.empty() : tensor<4x1024xf32>
  %zero = tensor.empty() : tensor<4xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%zero : tensor<4xf32>) -> tensor<4xf32>
  // Kernel 1: elementwise x^2
  %sq = linalg.square
      ins(%x : tensor<4x1024xf32>)
      outs(%init : tensor<4x1024xf32>) -> tensor<4x1024xf32>
  // Kernel 2: sum-reduce over dim 1
  %sum = linalg.reduce ins(%sq : tensor<4x1024xf32>)
      outs(%fill : tensor<4xf32>) dimensions = [1]
      (%in: f32, %acc: f32) {
        %add = arith.addf %in, %acc : f32
        linalg.yield %add : f32
      }
  return %sum : tensor<4xf32>
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1) -> (d0)>

// CHECK-LABEL: func.func @fuse_square_into_reduce(
// CHECK-SAME: %[[ARG0:.*]]: tensor<4x1024xf32>) -> tensor<4xf32> {
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<4xf32>
// CHECK: %[[GENERIC_0:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]]], iterator_types = ["parallel"]} outs(%[[EMPTY_0]] : tensor<4xf32>) {
// CHECK: ^bb0(%[[VAL_0:.*]]: f32):
// CHECK: linalg.yield %[[CONSTANT_0]] : f32
// CHECK: } -> tensor<4xf32>
// CHECK: %[[GENERIC_1:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "reduction"]} ins(%[[ARG0]] : tensor<4x1024xf32>) outs(%[[GENERIC_0]] : tensor<4xf32>) {
// CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32):
// CHECK: %[[MULF_0:.*]] = arith.mulf %[[VAL_1]], %[[VAL_1]] : f32
// CHECK: %[[ADDF_0:.*]] = arith.addf %[[MULF_0]], %[[VAL_2]] : f32
// CHECK: linalg.yield %[[ADDF_0]] : f32
// CHECK: } -> tensor<4xf32>
// CHECK: return %[[GENERIC_1]] : tensor<4xf32>
// CHECK: }

// -----

// NLL loss: loss = -Σ_c q[c] * log(p[c])
// Three kernels: log (elementwise) -> mul (elementwise) -> reduce (sum)
// After generalize + fuse: log and mul are inlined into the reduction body.
func.func @fuse_nll_into_reduce(
    %p: tensor<1024xf32>,
    %q: tensor<1024xf32>) -> tensor<f32> {
  %cst = arith.constant 0.0 : f32
  %init = tensor.empty() : tensor<1024xf32>
  %zero = tensor.empty() : tensor<f32>
  %fill = linalg.fill ins(%cst : f32) outs(%zero : tensor<f32>) -> tensor<f32>
  // Kernel 1: elementwise log(p)
  %log_p = linalg.log
      ins(%p : tensor<1024xf32>)
      outs(%init : tensor<1024xf32>) -> tensor<1024xf32>
  // Kernel 2: elementwise q * log(p)
  %qlogp = linalg.mul
      ins(%q, %log_p : tensor<1024xf32>, tensor<1024xf32>)
      outs(%init : tensor<1024xf32>) -> tensor<1024xf32>
  // Kernel 3: sum-reduce over classes
  %sum = linalg.reduce ins(%qlogp : tensor<1024xf32>)
      outs(%fill : tensor<f32>) dimensions = [0]
      (%in: f32, %acc: f32) {
        %add = arith.addf %in, %acc : f32
        linalg.yield %add : f32
      }
  return %sum : tensor<f32>
}
// CHECK: #[[$ATTR_3:.+]] = affine_map<() -> ()>
// CHECK: #[[$ATTR_4:.+]] = affine_map<(d0) -> ()>

// CHECK-LABEL: func.func @fuse_nll_into_reduce(
// CHECK-SAME: %[[ARG0:.*]]: tensor<1024xf32>,
// CHECK-SAME: %[[ARG1:.*]]: tensor<1024xf32>) -> tensor<f32> {
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<f32>
// CHECK: %[[GENERIC_0:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_3]]], iterator_types = []} outs(%[[EMPTY_0]] : tensor<f32>) {
// CHECK: ^bb0(%[[VAL_0:.*]]: f32):
// CHECK: linalg.yield %[[CONSTANT_0]] : f32
// CHECK: } -> tensor<f32>
// CHECK: %[[GENERIC_1:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_4]]], iterator_types = ["reduction"]} ins(%[[ARG1]], %[[ARG0]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[GENERIC_0]] : tensor<f32>) {
// CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
// CHECK: %[[LOG_0:.*]] = math.log %[[VAL_2]] : f32
// CHECK: %[[MULF_0:.*]] = arith.mulf %[[VAL_1]], %[[LOG_0]] : f32
// CHECK: %[[ADDF_0:.*]] = arith.addf %[[MULF_0]], %[[VAL_3]] : f32
// CHECK: linalg.yield %[[ADDF_0]] : f32
// CHECK: } -> tensor<f32>
// CHECK: return %[[GENERIC_1]] : tensor<f32>
// CHECK: }
❯ what does this mean "Attention mask + max-reduce (first pass of softmax)."
⏺ Softmax is computed in two passes for numerical stability:
Pass 1 (this test): Find the max across each row
max_j = max_over_j(scores[b, i, j] + mask[i, j])
Pass 2: Subtract max, exponentiate, sum, divide
softmax[b,i,j] = exp(scores[b,i,j] + mask[i,j] - max_j[b,i]) / Σ_j exp(...)
The mask is typically a causal mask (upper triangle = -inf, lower triangle = 0.0) so that attention to future tokens is zeroed out after exponentiation. You add the mask before taking the max so that masked positions (now -inf) can never set the row max; the max subtracted in pass 2 then reflects only positions that actually contribute, which avoids precision problems when exponentiating.
This test shows pass 1: applying the mask and computing the row-wise max, fused into one kernel.
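A small NumPy sketch of the two passes, using the same shapes as the gist's first file; the random scores and the causal-mask construction are illustrative assumptions, not taken from the gist:

import numpy as np

B, S = 4, 512
scores = np.random.randn(B, S, S).astype(np.float32)           # stand-in for Q @ K^T / sqrt(d)
mask = np.triu(np.full((S, S), -np.inf, dtype=np.float32), 1)   # causal: -inf above the diagonal, 0 elsewhere

masked = scores + mask                          # kernel 2 in the gist
row_max = masked.max(axis=-1, keepdims=True)    # pass 1 / kernel 3: max-reduce over j
exps = np.exp(masked - row_max)                 # pass 2: subtract max, exponentiate
softmax = exps / exps.sum(axis=-1, keepdims=True)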
❯ "Softmax is computed in two passes for numerical stability" where does this usually happen?
⏺ In a transformer's self-attention block, right after computing the attention scores Q @ K^T / sqrt(d) and before multiplying by V:
scores = Q @ K^T / sqrt(d_k) # [B, heads, S, S]
scores = scores + mask # causal mask
attn_weights = softmax(scores) # along last dim
output = attn_weights @ V
The two-pass softmax (usually called "safe softmax") happens inside that softmax call. Every deep learning framework does it: PyTorch's F.softmax, JAX's jax.nn.softmax, etc. If you computed exp(x) / sum(exp(x)) naively, large values in x would overflow exp() to inf. Subtracting the max first keeps everything in a safe numerical range:
pass 1: m = max(x)
pass 2: softmax = exp(x - m) / Σ exp(x - m)
Flash Attention (Dao et al.) fuses both passes and the matmuls into a single tiled kernel by maintaining running max/sum statistics — that's the "online" trick that avoids materializing the full S x S attention matrix.
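A tiny numeric illustration of why the max is subtracted; the input values are made up for the example:

import numpy as np

x = np.array([1000.0, 1001.0, 1002.0], dtype=np.float32)

naive = np.exp(x) / np.exp(x).sum()               # exp overflows to inf, and inf/inf gives nan
safe = np.exp(x - x.max()) / np.exp(x - x.max()).sum()

print(naive)   # [nan nan nan]
print(safe)    # roughly [0.09, 0.245, 0.665]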
❯ why is the mask added instead of multiplied - i thought attention was multiplicative
⏺ The attention mechanism itself is multiplicative (Q times K^T), but the mask is additive. They work in different spaces:
softmax(scores + (-inf)) = exp(-inf) / Σ... = 0
softmax(scores + 0) = exp(scores) / Σ... = unchanged
So mask = 0 means "keep" and mask = -inf means "zero out". If you multiplied the scores by 0 instead, a masked score would become 0 and exponentiate to exp(0) = 1, so that position would still get nonzero attention weight (zeroing every score gives the uniform 1/S) rather than being excluded.
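A short NumPy comparison of the two masking styles; the score and mask values are assumptions chosen for the example:

import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

scores = np.array([2.0, 1.0, 3.0])

additive = softmax(scores + np.array([0.0, 0.0, -np.inf]))   # last position masked out
multiplied = softmax(scores * np.array([1.0, 1.0, 0.0]))     # last score forced to 0

print(additive)    # last weight is exactly 0.0
print(multiplied)  # last weight is exp(0)/Z > 0, so it still attends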