linalg fusions
// Attention mask + max-reduce using named linalg ops.
// 3 kernels: broadcast mask → add to scores → max-reduce over j.
// After generalize + fuse: broadcast and add are inlined into the reduction.
func.func @fuse_mask_into_max_reduce(
    %scores: tensor<4x512x512xf32>,
    %mask: tensor<512x512xf32>,
    %neg_inf_init: tensor<4x512xf32>) -> tensor<4x512xf32> {
  %init3d = tensor.empty() : tensor<4x512x512xf32>
  // Kernel 1: broadcast mask [512,512] -> [4,512,512] (replicate over batch)
  %mask_3d = linalg.broadcast
      ins(%mask : tensor<512x512xf32>)
      outs(%init3d : tensor<4x512x512xf32>) dimensions = [0]
  // Kernel 2: elementwise add scores + mask
  %masked = linalg.add
      ins(%scores, %mask_3d : tensor<4x512x512xf32>, tensor<4x512x512xf32>)
      outs(%init3d : tensor<4x512x512xf32>) -> tensor<4x512x512xf32>
  // Kernel 3: max-reduce over j (dim 2)
  %max = linalg.reduce
      ins(%masked : tensor<4x512x512xf32>)
      outs(%neg_inf_init : tensor<4x512xf32>) dimensions = [2]
      (%in: f32, %acc: f32) {
        %m = arith.maximumf %in, %acc : f32
        linalg.yield %m : f32
      }
  return %max : tensor<4x512xf32>
}
// RUN: mlir-opt %s --linalg-generalize-named-ops --linalg-fuse-elementwise-ops --split-input-file | FileCheck %s
func.func @fuse_square_into_reduce(
    %x: tensor<4x1024xf32>) -> tensor<4xf32> {
  %cst = arith.constant 0.0 : f32
  %init = tensor.empty() : tensor<4x1024xf32>
  %zero = tensor.empty() : tensor<4xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%zero : tensor<4xf32>) -> tensor<4xf32>
  // Kernel 1: elementwise x^2
  %sq = linalg.square
      ins(%x : tensor<4x1024xf32>)
      outs(%init : tensor<4x1024xf32>) -> tensor<4x1024xf32>
  // Kernel 2: sum-reduce over dim 1
  %sum = linalg.reduce
      ins(%sq : tensor<4x1024xf32>)
      outs(%fill : tensor<4xf32>) dimensions = [1]
      (%in: f32, %acc: f32) {
        %add = arith.addf %in, %acc : f32
        linalg.yield %add : f32
      }
  return %sum : tensor<4xf32>
}
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1) -> (d0)>
// CHECK-LABEL: func.func @fuse_square_into_reduce(
// CHECK-SAME: %[[ARG0:.*]]: tensor<4x1024xf32>) -> tensor<4xf32> {
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<4xf32>
// CHECK: %[[GENERIC_0:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]]], iterator_types = ["parallel"]} outs(%[[EMPTY_0]] : tensor<4xf32>) {
// CHECK: ^bb0(%[[VAL_0:.*]]: f32):
// CHECK: linalg.yield %[[CONSTANT_0]] : f32
// CHECK: } -> tensor<4xf32>
// CHECK: %[[GENERIC_1:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "reduction"]} ins(%[[ARG0]] : tensor<4x1024xf32>) outs(%[[GENERIC_0]] : tensor<4xf32>) {
// CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32):
// CHECK: %[[MULF_0:.*]] = arith.mulf %[[VAL_1]], %[[VAL_1]] : f32
// CHECK: %[[ADDF_0:.*]] = arith.addf %[[MULF_0]], %[[VAL_2]] : f32
// CHECK: linalg.yield %[[ADDF_0]] : f32
// CHECK: } -> tensor<4xf32>
// CHECK: return %[[GENERIC_1]] : tensor<4xf32>
// CHECK: }
// -----
// NLL loss: loss = -Σ_c q[c] * log(p[c])
// Three kernels: log (elementwise) -> mul (elementwise) -> reduce (sum).
// The kernels compute the inner sum Σ_c q[c] * log(p[c]); the negation is left to the caller.
// After generalize + fuse: log and mul are inlined into the reduction body.
func.func @fuse_nll_into_reduce(
    %p: tensor<1024xf32>,
    %q: tensor<1024xf32>) -> tensor<f32> {
  %cst = arith.constant 0.0 : f32
  %init = tensor.empty() : tensor<1024xf32>
  %zero = tensor.empty() : tensor<f32>
  %fill = linalg.fill ins(%cst : f32) outs(%zero : tensor<f32>) -> tensor<f32>
  // Kernel 1: elementwise log(p)
  %log_p = linalg.log
      ins(%p : tensor<1024xf32>)
      outs(%init : tensor<1024xf32>) -> tensor<1024xf32>
  // Kernel 2: elementwise q * log(p)
  %qlogp = linalg.mul
      ins(%q, %log_p : tensor<1024xf32>, tensor<1024xf32>)
      outs(%init : tensor<1024xf32>) -> tensor<1024xf32>
  // Kernel 3: sum-reduce over classes
  %sum = linalg.reduce
      ins(%qlogp : tensor<1024xf32>)
      outs(%fill : tensor<f32>) dimensions = [0]
      (%in: f32, %acc: f32) {
        %add = arith.addf %in, %acc : f32
        linalg.yield %add : f32
      }
  return %sum : tensor<f32>
}
// CHECK: #[[$ATTR_3:.+]] = affine_map<() -> ()>
// CHECK: #[[$ATTR_4:.+]] = affine_map<(d0) -> ()>
// CHECK-LABEL: func.func @fuse_nll_into_reduce(
// CHECK-SAME: %[[ARG0:.*]]: tensor<1024xf32>,
// CHECK-SAME: %[[ARG1:.*]]: tensor<1024xf32>) -> tensor<f32> {
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<f32>
// CHECK: %[[GENERIC_0:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_3]]], iterator_types = []} outs(%[[EMPTY_0]] : tensor<f32>) {
// CHECK: ^bb0(%[[VAL_0:.*]]: f32):
// CHECK: linalg.yield %[[CONSTANT_0]] : f32
// CHECK: } -> tensor<f32>
// CHECK: %[[GENERIC_1:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_4]]], iterator_types = ["reduction"]} ins(%[[ARG1]], %[[ARG0]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[GENERIC_0]] : tensor<f32>) {
// CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
// CHECK: %[[LOG_0:.*]] = math.log %[[VAL_2]] : f32
// CHECK: %[[MULF_0:.*]] = arith.mulf %[[VAL_1]], %[[LOG_0]] : f32
// CHECK: %[[ADDF_0:.*]] = arith.addf %[[MULF_0]], %[[VAL_3]] : f32
// CHECK: linalg.yield %[[ADDF_0]] : f32
// CHECK: } -> tensor<f32>
// CHECK: return %[[GENERIC_1]] : tensor<f32>
// CHECK: }
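As a rough reference for what the three tests above compute, here is an approximate NumPy restatement (illustrative only; the function names are made up, and %neg_inf_init is assumed to be filled with -inf):

import numpy as np

# Test 1: broadcast mask over the batch, add to scores, max-reduce over j.
def masked_max_reduce(scores, mask):            # scores: [B,S,S], mask: [S,S]
    return (scores + mask[None, :, :]).max(axis=2)   # -> [B,S]

# Test 2: square elementwise, then sum-reduce over dim 1.
def sum_of_squares(x):                          # x: [B,N]
    return (x ** 2).sum(axis=1)                 # -> [B]

# Test 3: q * log(p) elementwise, then sum-reduce to a scalar.
def q_log_p(p, q):                              # p, q: [C]
    return (q * np.log(p)).sum()                # -> scalar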
@makslevental (Author) commented:

❯ what does this mean "Attention mask + max-reduce (first pass of softmax)."

⏺ Softmax is computed in two passes for numerical stability:

Pass 1 (this test): find the max across each row
m[b, i] = max_j(scores[b, i, j] + mask[i, j])

Pass 2: subtract the max, exponentiate, sum, divide
softmax[b, i, j] = exp(scores[b, i, j] + mask[i, j] - m[b, i]) / Σ_j exp(...)

The mask is typically a causal mask (upper triangle = -inf, lower triangle = 0.0) so that future tokens end up with zero weight after exponentiation. You add it before taking the max so that pass 1 and pass 2 see the same masked scores: positions at -inf can never become the max, and exp(masked - max) stays in a safe numerical range in pass 2.

This test shows pass 1: applying the mask and computing the row-wise max, fused into one kernel.
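For concreteness, here is a minimal NumPy sketch of both passes (an illustrative sketch, not from the gist; pass 1 is the part the fused kernel above computes):

import numpy as np

def safe_masked_softmax(scores, mask):
    # scores: [B,S,S]; mask: [S,S] with 0.0 = keep, -inf = mask out
    masked = scores + mask[None, :, :]
    # Pass 1: row-wise max over j (what the fused kernel computes)
    m = masked.max(axis=-1, keepdims=True)
    # Pass 2: subtract the max, exponentiate, normalize
    e = np.exp(masked - m)
    return e / e.sum(axis=-1, keepdims=True)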

❯ "Softmax is computed in two passes for numerical stability" where does this usually happen?

⏺ In a transformer's self-attention block, right after computing the attention scores Q @ K^T / sqrt(d) and before multiplying by V:

scores = Q @ K^T / sqrt(d_k) # [B, heads, S, S]
scores = scores + mask # causal mask
attn_weights = softmax(scores) # along last dim
output = attn_weights @ V

The two-pass softmax (usually called "safe softmax"; the streaming single-pass variant is what's called "online softmax") happens inside that softmax call. Every deep learning framework does it: PyTorch's F.softmax, JAX's jax.nn.softmax, etc. If you computed exp(x) / sum(exp(x)) naively, large values in x would overflow exp() to inf. Subtracting the max first keeps everything in a safe numerical range:

pass 1: m = max(x)
pass 2: softmax = exp(x - m) / Σ exp(x - m)

Flash Attention (Dao et al.) fuses both passes and the matmuls into a single tiled kernel by maintaining running max/sum statistics — that's the "online" trick that avoids materializing the full S x S attention matrix.
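A toy single-row version of that running-statistics trick, just to make the idea concrete (an illustrative sketch, not Flash Attention itself):

import numpy as np

def online_softmax(row):
    # One pass over the row, keeping a running max m and a running sum s of
    # exp(x - m); s is rescaled whenever the running max grows.
    row = np.asarray(row, dtype=np.float64)
    m, s = -np.inf, 0.0
    for x in row:
        m_new = max(m, x)
        s = s * np.exp(m - m_new) + np.exp(x - m_new)
        m = m_new
    return np.exp(row - m) / s   # normalization uses the final m and s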

❯ why is the mask added instead of multiplied - i thought attention was multiplicative

⏺ The attention mechanism itself is multiplicative (Q times K^T), but the mask is additive. They work in different spaces:

  • The mask is applied in log space (before softmax)
  • Adding -inf in log space is equivalent to multiplying by 0 in probability space

softmax(scores + (-inf)) = exp(-inf) / Σ... = 0
softmax(scores + 0) = exp(scores) / Σ... = unchanged

So mask = 0 means "keep" and mask = -inf means "zero out". If you multiplied the masked scores by 0 instead, those positions would land at exp(0) = 1 in the numerator, so they would still receive attention weight (in the extreme case where every score is masked you'd get the uniform 1/S) rather than being driven to 0.
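A quick NumPy check of the difference (the numbers are arbitrary):

import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

scores = np.array([2.0, 1.0, 3.0])
additive = softmax(scores + np.array([0.0, 0.0, -np.inf]))
# -> ~[0.73, 0.27, 0.00]: the masked position gets exactly zero weight
multiplicative = softmax(scores * np.array([1.0, 1.0, 0.0]))
# -> ~[0.67, 0.24, 0.09]: the zeroed score still gets exp(0)/Z of the weight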
