from typing import Optional, Callable, Sequence, Any

import torch
from torch import nn, fx
from torch.library import Library
import torch.nn.functional as F
import torch._inductor
import torch._inductor.compile_fx

mirage_lib = Library("mirage", "FRAGMENT")  # noqa


def direct_register_custom_op(
        op_name: str,
        op_func: Callable,
        mutates_args: list[str] = [],
        fake_impl: Optional[Callable] = None,
        target_lib: Optional[Library] = None,
        dispatch_key: str = "CUDA",
        tags: tuple[torch.Tag, ...] = (),
):
    """
    `torch.library.custom_op` can have significant overhead because it
    needs to consider complicated dispatching logic. This function
    directly registers a custom op and dispatches it to the CUDA backend.
    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
    for more details.

    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
    library object. If you want to bind the operator to a different library,
    make sure the library object is alive when the operator is used.
    """
    import torch.library
    schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    my_lib = target_lib or mirage_lib
    my_lib.define(op_name + schema_str, tags=tags)
    my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
    if fake_impl is not None:
        my_lib._register_fake(op_name, fake_impl)
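# Note (added sketch, not part of the original gist): each op below is registered with a
# real (CUDA) impl that is never meant to run and a fake impl that only allocates
# correctly-shaped outputs. torch.compile traces through the fake impls under
# FakeTensorMode, so the placeholders are enough for shape propagation even though no
# kernel exists yet. A hypothetical registration (the `noop` names are illustrative):
#
#   def noop(x: torch.Tensor) -> torch.Tensor:  # eager CUDA impl
#       return x.clone()
#
#   def noop_fake(x: torch.Tensor) -> torch.Tensor:  # shape-only impl for tracing
#       return torch.empty_like(x)
#
#   direct_register_custom_op("noop", noop, fake_impl=noop_fake)
#   y = torch.ops.mirage.noop(torch.randn(4, device="cuda"))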
# ============================================================
# Mirage placeholder op registration
# ============================================================

def rms_norm(input: torch.Tensor, weight: torch.Tensor,
             residual: Optional[torch.Tensor],
             epsilon: float = 1e-5) -> tuple[torch.Tensor, torch.Tensor]:
    # Never actually called
    print("rms_norm")
    if residual is None:
        residual = input
    return torch.zeros_like(input), residual


def rms_norm_fake(input: torch.Tensor, weight: torch.Tensor,
                  residual: Optional[torch.Tensor],
                  epsilon: float = 1e-5) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(input), torch.empty_like(input)


direct_register_custom_op("rms_norm", rms_norm, fake_impl=rms_norm_fake)


def silu_mul(input: torch.Tensor) -> torch.Tensor:
    # Never actually called
    print("silu_mul")
    return torch.zeros_like(input[..., 0:input.shape[1] // 2])


def silu_mul_fake(input: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(input[..., 0:input.shape[1] // 2])


direct_register_custom_op("silu_mul", silu_mul, fake_impl=silu_mul_fake)


def rope(q: torch.Tensor, k: torch.Tensor,
         positions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Never actually called
    print("rope")
    return torch.zeros_like(q), torch.zeros_like(k)


def rope_fake(q: torch.Tensor, k: torch.Tensor,
              positions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(q), torch.empty_like(k)


direct_register_custom_op("rope", rope, fake_impl=rope_fake)


def quantize(input: torch.Tensor, scale: Optional[torch.Tensor],
             dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
    # Never actually called
    print("quantize")
    return torch.zeros_like(input, dtype=dtype), scale


def quantize_fake(input: torch.Tensor, scale: Optional[torch.Tensor],
                  dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(input, dtype=dtype), scale


direct_register_custom_op("quantize", quantize, fake_impl=quantize_fake)


def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Never actually called
    print("attention")
    return torch.zeros_like(q)


def attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(q)


direct_register_custom_op("attention", attention, fake_impl=attention_fake)


# ============================================================
# Example PyTorch model
# ============================================================

class SimpleLlamaLayer(nn.Module):
    def __init__(self,
                 hidden_dim: int = 4096,
                 num_heads: int = 32,
                 num_kv_heads: int = 8,
                 head_size: int = 128,
                 dtype: torch.dtype = torch.float16,
                 qdtype: Optional[torch.dtype] = None,
                 ):
        super().__init__()
        if qdtype is None:
            qdtype = dtype

        self.hidden_dim = hidden_dim
        self.head_size = head_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.dtype = dtype
        self.qdtype = qdtype
        self.quantized = qdtype != dtype

        rand_w = lambda *dims, **kwargs: torch.randn(*dims, **kwargs, dtype=dtype, device="cuda")
        rand_wq = lambda *dims, **kwargs: rand_w(*dims, **kwargs).to(dtype=qdtype).t().contiguous().t()  # column-major for scaled-mm

        self.weights = {
            "qkv_proj": rand_wq(hidden_dim, (num_heads + num_kv_heads * 2) * head_size),
            "o_proj": rand_wq(hidden_dim, hidden_dim),
            "gate_up_proj": rand_wq(hidden_dim, 2 * hidden_dim),
            "down_proj": rand_wq(hidden_dim, hidden_dim),
            "input_norm": rand_w(hidden_dim),
            "post_attn_norm": rand_w(hidden_dim),
        }

        if self.quantized:
            self.scales = {k: torch.ones(1, 1, dtype=torch.float32) for k in self.weights}
            self.wscales = {k: torch.ones(1, 1, dtype=torch.float32) for k in self.weights}

    def _linear(self, input: torch.Tensor, name: str) -> torch.Tensor:
        weight = self.weights[name]
        if not self.quantized:
            return input @ weight

        scale_a, scale_b = self.scales[name], self.wscales[name]
        qinput, scale_a = torch.ops.mirage.quantize(input, scale_a, dtype=self.qdtype)
        return torch._scaled_mm(qinput, weight, scale_a=scale_a, scale_b=scale_b)

    def forward(self, input: torch.Tensor, residual: Optional[torch.Tensor],
                positions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        input_norm, residual = torch.ops.mirage.rms_norm(input, self.weights["input_norm"], residual)
        qkv = self._linear(input_norm, "qkv_proj")
        q, k, v = qkv.split_with_sizes([
            self.num_heads * self.head_size,
            self.num_kv_heads * self.head_size,
            self.num_kv_heads * self.head_size
        ], dim=-1)
        q, k = torch.ops.mirage.rope(q, k, positions)
        out = torch.ops.mirage.attention(q, k, v)
        out2 = self._linear(out, "o_proj")
        out_norm, residual = torch.ops.mirage.rms_norm(out2, self.weights["post_attn_norm"], residual)

        # mlp
        up_gate = self._linear(out_norm, "gate_up_proj")
        silu = torch.ops.mirage.silu_mul(up_gate)
        down = self._linear(silu, "down_proj")
        return down, residual
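# Note (added sketch, not part of the original gist): in the quantized path, _linear
# relies on torch._scaled_mm, which on CUDA expects a row-major fp8 input and a
# column-major fp8 weight; that is why rand_wq above materializes weights via
# .t().contiguous().t(). With per-tensor scales, input [tokens, hidden], and weight
# [hidden, out], the call reduces to:
#
#   qinput, scale_a = torch.ops.mirage.quantize(input, scale_a, dtype=torch.float8_e4m3fn)
#   out = torch._scaled_mm(qinput, weight, scale_a=scale_a, scale_b=scale_b)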
class SimpleLlama(nn.Module):
    def __init__(self,
                 num_layers: int = 32,
                 vocab_size: int = 128256,
                 hidden_dim: int = 4096,
                 num_heads: int = 32,
                 num_kv_heads: int = 8,
                 head_size: int = 128,
                 dtype: torch.dtype = torch.float16,
                 qdtype: Optional[torch.dtype] = None,
                 ):
        super().__init__()
        rand_w = lambda *dims, **kwargs: torch.randn(*dims, **kwargs, dtype=dtype, device="cuda")
        self.weights = {
            "embedding": rand_w(vocab_size, hidden_dim),
            "out_norm": rand_w(hidden_dim),
        }
        self.layers = nn.ModuleList([SimpleLlamaLayer(
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            dtype=dtype,
            qdtype=qdtype,
        ) for _ in range(num_layers)])

    def forward(self, input: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        x_emb = F.embedding(input, self.weights["embedding"])
        x, residual = x_emb, None
        for layer in self.layers:
            x, residual = layer(x, residual, positions)
        x, _ = torch.ops.mirage.rms_norm(x, self.weights["out_norm"], residual)
        return x


# ============================================================
# Backends
# ============================================================

class AotBackend:
    """Boilerplate to get the Mirage backend normalized (post-grad) IR via aot_autograd."""

    def __init__(self, compile_fn: Callable[[fx.GraphModule, Sequence], Callable[[Sequence], Any]]):
        self.compile_fn = compile_fn

    def __call__(self, graph: fx.GraphModule, example_inputs: Sequence):
        from torch._dynamo.backends.common import aot_autograd
        return aot_autograd(
            fw_compiler=self.compile_fn,
            decompositions=torch._inductor.compile_fx.select_decomp_table(),
        )(graph, example_inputs)


# ============================================================
# Skeleton for the actual Mirage backend that takes a
# Mirage-friendly fx graph and compiles it.
# ============================================================

class MirageBackend:
    def __call__(self, graph: fx.GraphModule, example_inputs: Sequence):
        """
        Receives normalized (post-grad) IR.
        """
        print(graph.graph.python_code(root_module="self").src)
        return self.run

    def run(self, *args, **kwargs):
        print(f"Forward called with {len(args)=} args and {len(kwargs)=} kwargs.")
        return torch.empty_like(args[0], dtype=torch.float16)


torch.set_default_device("cuda")

model = SimpleLlama()
inputs = [torch.randint(0, 4096, (5,)), torch.arange(0, 4096)]
model(*inputs)

compiled_model = torch.compile(model, backend=AotBackend(MirageBackend()), fullgraph=True)
compiled_model(*inputs)

qmodel = SimpleLlama(qdtype=torch.float8_e4m3fn)
qmodel(*inputs)

compiled_qmodel = torch.compile(qmodel, backend=AotBackend(MirageBackend()), fullgraph=True)
compiled_qmodel(*inputs)
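# Where a real implementation would go (a hedged sketch, not the actual Mirage
# integration): MirageBackend.__call__ receives the normalized post-grad graph, so a
# real backend would walk its nodes, lower each torch.ops.mirage.* call to a Mirage
# kernel, and return a callable running the fused result instead of self.run:
#
#   for node in graph.graph.nodes:
#       if node.op == "call_function" and str(node.target).startswith("mirage."):
#           ...  # map node.target with node.args/node.kwargs onto a Mirage kernel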