@ProExpertProg / mirage-example-fx-graph.py
Last active September 22, 2025

from typing import Optional, Callable, Sequence, Any

import torch
from torch import nn, fx
from torch.library import Library
import torch.nn.functional as F
import torch._inductor
import torch._inductor.compile_fx

# "FRAGMENT" extends the "mirage" op namespace without claiming exclusive
# ownership of it (as "DEF" would), so ops can be added from multiple places.
mirage_lib = Library("mirage", "FRAGMENT")  # noqa


def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: list[str] = [],
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
    dispatch_key: str = "CUDA",
    tags: tuple[torch.Tag, ...] = (),
):
    """
    `torch.library.custom_op` can have significant overhead because it
    needs to consider complicated dispatching logic. This function
    directly registers a custom op and dispatches it to the CUDA backend.
    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
    for more details.

    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
    library object. If you want to bind the operator to a different library,
    make sure the library object is alive when the operator is used.
    """
    import torch.library
    schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    my_lib = target_lib or mirage_lib
    my_lib.define(op_name + schema_str, tags=tags)
    my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
    if fake_impl is not None:
        my_lib._register_fake(op_name, fake_impl)
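
# Usage note: an op registered this way is exposed under the torch.ops
# namespace, e.g. direct_register_custom_op("foo", foo, fake_impl=foo_fake)
# (a hypothetical op) becomes callable as torch.ops.mirage.foo(...). The
# fake_impl supplies output shapes/dtypes during tracing, which is what lets
# torch.compile trace through these ops without running their CUDA impls.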

# ============================================================
# Mirage placeholder op registration
# ============================================================

def rms_norm(input: torch.Tensor,
             weight: torch.Tensor,
             residual: Optional[torch.Tensor],
             epsilon: float = 1e-5) -> tuple[torch.Tensor, torch.Tensor]:
    # Never actually called
    print("rms_norm")
    if residual is None:
        residual = input
    return torch.zeros_like(input), residual


def rms_norm_fake(input: torch.Tensor,
                  weight: torch.Tensor,
                  residual: Optional[torch.Tensor],
                  epsilon: float = 1e-5) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(input), torch.empty_like(input)


direct_register_custom_op("rms_norm", rms_norm, fake_impl=rms_norm_fake)


def silu_mul(input: torch.Tensor) -> torch.Tensor:
    # Never actually called
    print("silu_mul")
    return torch.zeros_like(input[..., 0:input.shape[1] // 2])


def silu_mul_fake(input: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(input[..., 0:input.shape[1] // 2])


direct_register_custom_op("silu_mul", silu_mul, fake_impl=silu_mul_fake)


def rope(q: torch.Tensor,
         k: torch.Tensor,
         positions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Never actually called
    print("rope")
    return torch.zeros_like(q), torch.zeros_like(k)


def rope_fake(q: torch.Tensor,
              k: torch.Tensor,
              positions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(q), torch.empty_like(k)


direct_register_custom_op("rope", rope, fake_impl=rope_fake)


def quantize(input: torch.Tensor,
             scale: Optional[torch.Tensor],
             dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
    # Never actually called
    print("quantize")
    return torch.zeros_like(input, dtype=dtype), scale


def quantize_fake(input: torch.Tensor,
                  scale: Optional[torch.Tensor],
                  dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(input, dtype=dtype), scale


direct_register_custom_op("quantize", quantize, fake_impl=quantize_fake)


def attention(q: torch.Tensor,
              k: torch.Tensor,
              v: torch.Tensor) -> torch.Tensor:
    # Never actually called
    print("attention")
    return torch.zeros_like(q)


def attention_fake(q: torch.Tensor,
                   k: torch.Tensor,
                   v: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(q)


direct_register_custom_op("attention", attention, fake_impl=attention_fake)


# ============================================================
# Example PyTorch model
# ============================================================

class SimpleLlamaLayer(nn.Module):
    def __init__(self,
                 hidden_dim: int = 4096,
                 num_heads: int = 32,
                 num_kv_heads: int = 8,
                 head_size: int = 128,
                 dtype: torch.dtype = torch.float16,
                 qdtype: Optional[torch.dtype] = None,
                 ):
        super().__init__()
        if qdtype is None:
            qdtype = dtype

        self.hidden_dim = hidden_dim
        self.head_size = head_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads

        self.dtype = dtype
        self.qdtype = qdtype
        self.quantized = qdtype != dtype

        rand_w = lambda *dims, **kwargs: torch.randn(*dims, **kwargs, dtype=dtype, device="cuda")
        rand_wq = lambda *dims, **kwargs: rand_w(*dims, **kwargs).to(dtype=qdtype).t().contiguous().t()  # column-major for scaled-mm

        self.weights = {
            "qkv_proj": rand_wq(hidden_dim, (num_heads + num_kv_heads * 2) * head_size),
            "o_proj": rand_wq(hidden_dim, hidden_dim),
            "gate_up_proj": rand_wq(hidden_dim, 2 * hidden_dim),
            "down_proj": rand_wq(hidden_dim, hidden_dim),
            "input_norm": rand_w(hidden_dim),
            "post_attn_norm": rand_w(hidden_dim),
        }
        if self.quantized:
            self.scales = {k: torch.ones(1, 1, dtype=torch.float32) for k in self.weights}
            self.wscales = {k: torch.ones(1, 1, dtype=torch.float32) for k in self.weights}

    def _linear(self, input: torch.Tensor, name: str) -> torch.Tensor:
        weight = self.weights[name]
        if not self.quantized:
            return input @ weight

        scale_a, scale_b = self.scales[name], self.wscales[name]
        qinput, scale_a = torch.ops.mirage.quantize(input, scale_a, dtype=self.qdtype)
        return torch._scaled_mm(qinput, weight, scale_a=scale_a, scale_b=scale_b)

    def forward(self, input: torch.Tensor, residual: torch.Tensor, positions: torch.Tensor) \
            -> tuple[torch.Tensor, torch.Tensor]:
        input_norm, residual = torch.ops.mirage.rms_norm(input, self.weights["input_norm"], residual)

        qkv = self._linear(input_norm, "qkv_proj")
        q, k, v = qkv.split_with_sizes([
            self.num_heads * self.head_size,
            self.num_kv_heads * self.head_size,
            self.num_kv_heads * self.head_size
        ], dim=-1)

        q, k = torch.ops.mirage.rope(q, k, positions)

        out = torch.ops.mirage.attention(q, k, v)
        out2 = self._linear(out, "o_proj")

        out_norm, residual = torch.ops.mirage.rms_norm(out2, self.weights["post_attn_norm"], residual)

        # MLP
        up_gate = self._linear(out_norm, "gate_up_proj")
        silu = torch.ops.mirage.silu_mul(up_gate)
        down = self._linear(silu, "down_proj")

        return down, residual
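
# Worked shape example with the defaults (hidden_dim=4096, num_heads=32,
# num_kv_heads=8, head_size=128): for an input of shape (tokens, 4096),
# qkv is (tokens, (32 + 2*8) * 128) = (tokens, 6144), which splits into
# q: (tokens, 4096) and k, v: (tokens, 1024) each; silu_mul halves
# up_gate's last dim from 8192 back to 4096.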


class SimpleLlama(nn.Module):
    def __init__(self,
                 num_layers: int = 32,
                 vocab_size: int = 128256,
                 hidden_dim: int = 4096,
                 num_heads: int = 32,
                 num_kv_heads: int = 8,
                 head_size: int = 128,
                 dtype: torch.dtype = torch.float16,
                 qdtype: Optional[torch.dtype] = None,
                 ):
        super().__init__()

        rand_w = lambda *dims, **kwargs: torch.randn(*dims, **kwargs, dtype=dtype, device="cuda")

        self.weights = {
            "embedding": rand_w(vocab_size, hidden_dim),
            "out_norm": rand_w(hidden_dim),
        }

        self.layers = nn.ModuleList([SimpleLlamaLayer(
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            dtype=dtype,
            qdtype=qdtype,
        ) for _ in range(num_layers)])

    def forward(self, input: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        x_emb = F.embedding(input, self.weights["embedding"])
        x, residual = x_emb, None
        for layer in self.layers:
            x, residual = layer(x, residual, positions)

        x, _ = torch.ops.mirage.rms_norm(x, self.weights["out_norm"], residual)

        return x

# ============================================================
# Backends
# ============================================================

class AotBackend:
    """Boilerplate that runs the wrapped backend through aot_autograd so it
    receives a normalized, decomposed (post-grad) fx graph."""
    def __init__(self, compile_fn: Callable[[fx.GraphModule, Sequence], Callable[[Sequence], Any]]):
        self.compile_fn = compile_fn

    def __call__(self, graph: fx.GraphModule, example_inputs: Sequence):
        from torch._dynamo.backends.common import aot_autograd
        return aot_autograd(
            fw_compiler=self.compile_fn,
            decompositions=torch._inductor.compile_fx.select_decomp_table(),
        )(graph, example_inputs)
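
# Note: select_decomp_table() decomposes standard ATen ops into Inductor's
# post-grad opset, but the custom torch.ops.mirage.* ops have no
# decompositions registered, so they reach the backend intact.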

# ============================================================
# Skeleton for the actual Mirage backend that takes a Mirage-friendly fx graph and compiles it.
# ============================================================
class MirageBackend:
    def __call__(self, graph: fx.GraphModule, example_inputs: Sequence):
        """
        Receives normalized (post-grad) IR.
        """
        print(graph.graph.python_code(root_module="self").src)
        return self.run

    def run(self, *args, **kwargs):
        print(f"Forward called with {len(args)=} args and {len(kwargs)=} kwargs.")
        return torch.empty_like(args[0], dtype=torch.float16)
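
# Sketch of where real compilation would go (hypothetical, not the actual
# Mirage API): a real backend would walk the post-grad graph, match the
# mirage ops, and lower them to fused kernels, e.g.:
#
#   for node in graph.graph.nodes:
#       if node.op == "call_function" and node.target is torch.ops.mirage.attention.default:
#           ...  # lower to a generated attention kernel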



torch.set_default_device("cuda")
model = SimpleLlama()
inputs = [torch.randint(0, 4096, (5,)), torch.arange(0, 4096)]
model(*inputs)

compiled_model = torch.compile(model, backend=AotBackend(MirageBackend()), fullgraph=True)
compiled_model(*inputs)

qmodel = SimpleLlama(qdtype=torch.float8_e4m3fn)
qmodel(*inputs)
compiled_qmodel = torch.compile(qmodel, backend=AotBackend(MirageBackend()), fullgraph=True)
compiled_qmodel(*inputs)
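
# Expected behavior: the eager calls print each placeholder op's name as it
# runs; each compiled call first prints the captured post-grad graph source
# (from MirageBackend.__call__) and then the "Forward called with ..." line
# from MirageBackend.run.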