Skip to content

Instantly share code, notes, and snippets.

@justinchuby
Created March 9, 2026 18:07
Show Gist options
  • Select an option

  • Save justinchuby/2ec93c1c46405ed405a70c1357c1b180 to your computer and use it in GitHub Desktop.

Select an option

Save justinchuby/2ec93c1c46405ed405a70c1357c1b180 to your computer and use it in GitHub Desktop.
Model Builder flags
# Mask-specific variables
# TODO: Reconcile differences between `seqlens_k` and `key_total_seq_lens` in the GroupQueryAttention
# and SparseAttention implementations. Ideally the same subgraph can be shared for both.
self.mask_attrs = dict(
    # Sum of each attention-mask row minus 1 (input to GroupQueryAttention); starts as an empty string
    seqlens_k="",
    # Total sequence length of the attention mask (input to GroupQueryAttention and SparseAttention);
    # starts as an empty string
    total_seq_len="",
    # Keys left commented out here (not initialized by this block):
    #   mask_name          - node that outputs the 4D causal attention mask (add_qk in MultiHeadAttention)
    #   block_row_indices  - row indices of the CSR-format block mask (SparseAttention input)
    #   block_col_indices  - column indices of the CSR-format block mask (SparseAttention input)
    #   key_total_seq_lens - sum of each attention-mask row (SparseAttention input)
)
# Embedding-specific variables
self.embed_attrs = dict(
    # Multiplier applied to the Embedding layer's output. Only the default of 1 is set here; the
    # original annotation says it comes from config (presumably overridden per model — TODO confirm).
    scale=1,
)
# LayerNorm-specific variables
epsilon = getattr(config, "rms_norm_eps", 1e-06)
self.layernorm_attrs = {
    # Variant: SimplifiedLayerNorm/SkipSimplifiedLayerNorm when True, LayerNorm/SkipLayerNorm otherwise
    "simple": True,
    # Position flags (these may be combined)
    "first_layernorm": True,  # First norm is a plain LayerNorm; every later one is a SkipLayerNorm
    "last_layernorm": False,  # Last norm is a SkipLayerNorm with only output 0 (no output 3)
    # Per-node state keys left commented out here (not initialized by this block):
    #   root_input - root input from the parent node for LayerNorm and SkipLayerNorm
    #   skip_input - skip input from the parent node for SkipLayerNorm
    #   output_0   - output 0 of LayerNorm and SkipLayerNorm
    #   output_3   - output 3 of SkipLayerNorm
    "add_offset": 0,  # Offset value for the LayerNorm weight
    "epsilon": epsilon,  # Avoids `sqrt(0)`; config.rms_norm_eps when present, else 1e-06
    # Casting options for LayerNorm inputs/outputs
    "cast": dict(
        use_fp32=False,  # Compute LayerNorm in float32 precision
        root_input=False,  # Cast root_input
        skip_input=False,  # Cast skip_input
        output_0=False,  # Cast output_0
        output_3=False,  # Cast output_3
    ),
}
# MatMul-specific variables
# The LoRA/QLoRA weight format is used only when the config declares peft_type == "LORA";
# a config without peft_type yields False.
is_lora = getattr(config, "peft_type", None) == "LORA"
self.matmul_attrs = dict(use_lora=is_lora)
# RotaryEmbedding-specific variables
position_scale = getattr(config, "rope_position_scale", 1)
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)

# rotary_embedding_dim stays 0 (the RotaryEmbedding kernel default) unless the rotary factor is partial.
if partial_rotary_factor == 1.0:
    rotemb_dim = 0
else:
    rotemb_dim = int(self.head_size * partial_rotary_factor)

# Base value: prefer config.rope_theta, then config.rope_embedding_base, else 10000.
if hasattr(config, "rope_theta"):
    rope_theta = config.rope_theta
elif hasattr(config, "rope_embedding_base"):
    rope_theta = config.rope_embedding_base
else:
    rope_theta = 10000

self.rope_attrs = dict(
    # Cos/sin cache handling (caches are shared if they are the same)
    create_caches=True,  # Create cos/sin caches for rotary embeddings
    save_caches=True,  # Auto-save cos/sin caches after creation
    # Values derived from config
    cache_length=self.context_length,  # Cache length used when creating cos/sin caches
    theta=rope_theta,  # Base value when calculating cos/sin caches from scratch
    partial_rotary_factor=partial_rotary_factor,  # Partial rotary factor (rotary dim = head_size * factor)
    interleaved=0,  # Interleave rotary embeddings, e.g. [0, 0, 0, 1, 1, 1] -> [0, 1, 0, 1, 0, 1]; kernel default is 0
    rotary_embedding_dim=rotemb_dim,  # Rotary dim for partial rotary embeddings; kernel default is 0
    rescale_factors=1,  # Rescale factors applied when calculating `inv_freq`
    t_dtype=torch.int64,  # Torch dtype used when calculating `t`
    position_scale=position_scale,  # Scale applied when calculating `t`
    mscale=1,  # Magnitude scaling factor for `emb.cos()` / `emb.sin()`
    mscale_policy="",  # Magnitude scaling policy for `emb.cos()` / `emb.sin()` (can differ per model)
)

# Configs that define a non-None rope_scaling go through additional rope initialization.
if getattr(config, "rope_scaling", None) is not None:
    self.make_rope_init(config)
# Attention-specific variables (MHA, GQA, GQA + Rot.Emb., etc.)
# Logit softcap from config; a missing or None value falls back to 0.0, the GroupQueryAttention kernel default.
attn_softcap = getattr(config, "attn_logit_softcapping", None)
if attn_softcap is None:
    attn_softcap = 0.0

# Block-sparse attention-specific variables; each falls back to 0 (or False) when absent from config.
sparse_block_size = getattr(config, "blocksparse_block_size", 0)
kernel_block_size = getattr(config, "blocksparse_triton_kernel_block_size", 0)
local_blocks = getattr(config, "blocksparse_num_local_blocks", 0)
vert_block_stride = getattr(config, "blocksparse_vert_stride", 0)
homo_head = getattr(config, "blocksparse_homo_head_pattern", False)

self.attention_attrs = {
    # Keys left commented out here (not initialized by this block): "q_path", "k_path", "v_path"
    # (Q/K/V paths into attention) and "op_type" (attention op to use, e.g. "MultiHeadAttention").
    "scale": 1 / np.sqrt(self.head_size),  # Scale applied after calculating Q x K'
    "softcap": attn_softcap,  # Softcap value to keep attention values from exploding
    # Fusion / layout choices
    "use_rope_in_attn": False,  # Apply rotary embeddings inside attention instead of a separate RotaryEmbedding op
    "use_packed_matmul": False,  # One packed MatMul for Q/K/V instead of three separate MatMuls
    # Block-sparse attention settings (annotated in the original as Phi-model specific)
    "block_sparse": dict(
        sparse_block_size=sparse_block_size,  # Sparse block size for the SparseAttention op
        kernel_block_size=kernel_block_size,  # Kernel block size for sparse attention
        local_blocks=local_blocks,  # Number of local blocks for sparse attention
        vert_stride=vert_block_stride,  # Vertical stride for sparse attention
        homo_head=homo_head,  # Use the homo head pattern for sparse attention
    ),
    "q_norm": False,  # LayerNorm after the MatMul in the Q path
    "k_norm": False,  # LayerNorm after the MatMul in the K path
    "sinks": False,  # Sink values for softmax in attention
}
self.make_attention_init()
# MLP-specific variables; "use_proj" and "use_fc" select mutually exclusive MLP styles.
self.mlp_attrs = dict(
    use_proj=True,  # Projection style: GateProj/UpProj/DownProj
    use_fc=False,  # Fully-connected style: FC1/FC2
    # "output_0" (output 0 of the MLP layer) is left commented out here and not initialized.
)
# MoE-specific variables
# The op type follows the ONNX precision: QMoE for INT4, plain MoE otherwise.
if self.onnx_dtype == ir.DataType.INT4:
    moe_op_type = "QMoE"
else:
    moe_op_type = "MoE"

# Values read from config / extra options (0, 4 or None when not provided)
num_experts = getattr(config, "num_local_experts", 0)
top_k_experts = getattr(config, "num_experts_per_tok", 0)
expert_weight_bits = 8 if extra_options.get("use_8bits_moe", False) else 4
swiglu_limit = getattr(config, "swiglu_limit", None)

self.moe_attrs = dict(
    op_type=moe_op_type,  # MoE op to use ("QMoE" or "MoE")
    num_experts=num_experts,  # Number of experts in the MoE layer
    top_k=top_k_experts,  # Number of experts to select in the MoE layer
    activation_alpha=1.0,  # Alpha parameter of the activation function
    activation_beta=0.0,  # Beta parameter of the activation function
    activation_type=self.activation,  # Activation function for the MoE layer
    expert_weight_bits=expert_weight_bits,  # Bits per quantized MoE weight (only INT4 or INT8 are supported)
    # Model-dependent settings (set by modeling logic)
    normalize_routing_weights=False,  # Normalize routing weights in the MoE layer
    swiglu_fusion=0,  # Fusion level for the SwiGLU activation function
    swiglu_limit=swiglu_limit,  # Value used to clamp SwiGLU results into a range
    use_sparse_mixer=False,  # Use SparseMixer in the MoE layer (Phi-3.5 MoE)
)
# LM head-specific variables
# Final-logit softcap from config; a missing or None value falls back to 0.0.
lm_head_softcap = getattr(config, "final_logit_softcapping", None)
if lm_head_softcap is None:
    lm_head_softcap = 0.0

self.lm_head_attrs = {
    "scale": 1,  # Scale value to multiply the LM head output by
    "mask": None,  # Optional boolean mask over vocabulary tokens; set below when config lists dummy tokens
    "softcap": lm_head_softcap,  # Softcap value to prevent values from exploding in the LM head
}

if hasattr(config, "dummy_token_indices"):
    # Boolean mask of length vocab_size that is True at every dummy-token index.
    # Built directly as bool to avoid allocating a float tensor and converting it.
    dummy_tokens_mask = torch.zeros(self.vocab_size, dtype=torch.bool)
    dummy_tokens_mask[config.dummy_token_indices] = True
    self.lm_head_attrs["mask"] = dummy_tokens_mask
# Quantization-specific variables (INT4, INT8, etc.)
int4_algo_config = self.make_int4_algo_config(extra_options.get("int4_algo_config", "default"))
self.int4_block_size = extra_options.get("int4_block_size", 32)

# CPU, WebGPU and TRT-RTX support block-wise quantization for QMoE.
# TRT-RTX defaults to a block size of 128; the others default to 32 for consistency with MatMulNBits.
supported_blockwise_eps = ["cpu", "webgpu", "trt-rtx"]
default_qmoe_block_size = 32
if self.ep == "trt-rtx":
    default_qmoe_block_size = 128
self.qmoe_block_size = int(extra_options.get("qmoe_block_size", default_qmoe_block_size))

# An explicit qmoe_block_size request is rejected for QMoE on EPs without block-wise support.
if moe_op_type == "QMoE" and "qmoe_block_size" in extra_options and self.ep not in supported_blockwise_eps:
    raise ValueError(
        f"The 'qmoe_block_size' option is not supported for {self.ep} execution provider with QMoE. "
        f"Block-wise quantization is only supported for: {', '.join(supported_blockwise_eps)}."
    )

self.quant_attrs = {
    "int4": dict(
        # Defaults to 4 on CPU/WebGPU and 0 on every other EP
        accuracy_level=int(extra_options.get("int4_accuracy_level", 4 if self.ep in ["cpu", "webgpu"] else 0)),
        qmoe_block_size=self.qmoe_block_size,  # already converted with int() above
        qdq_block_size=int(self.int4_block_size),
        is_symmetric=extra_options.get("int4_is_symmetric", True),
        op_types_to_quantize=extra_options.get("int4_op_types_to_quantize", ("MatMul",)),
        nodes_to_exclude=extra_options.get("int4_nodes_to_exclude", []),
        algo_config=int4_algo_config,
    ),
    "use_qdq": extra_options.get("use_qdq", False),
}

# QMoE on EPs with block-wise support receives its block size through the MoE op's
# 'block_size' attribute so the runtime kernels can honor it.
if self.moe_attrs.get("op_type") == "QMoE" and self.ep in supported_blockwise_eps:
    self.moe_attrs["block_size"] = self.qmoe_block_size
# When a quantization type is set, record the model's own quantization config.
if self.quant_type is not None:
    self.quant_attrs["config"] = config.quantization_config
    # use_g_idx mirrors the config's "desc_act" entry and is False when that entry is absent.
    self.quant_attrs["use_g_idx"] = False
    if "desc_act" in config.quantization_config:
        self.quant_attrs["use_g_idx"] = config.quantization_config["desc_act"]
# Embedding / lm_head weight sharing.
# The embeddings and lm_head often share weights and should be quantized together for consistency.
# The original rationale: if the lm_head stays unquantized, the embeddings should not be quantized
# either, even if the quantization config says to, since that would lead to a large accuracy drop.

# The lm_head counts as unquantized when "/lm_head/MatMul" is in the int4 nodes to exclude,
# or when the ONNX dtype is a floating-point type (FP models are always unquantized).
self.unquantized_lm_head = (
    "/lm_head/MatMul" in self.quant_attrs["int4"]["nodes_to_exclude"]
    or self.onnx_dtype in {ir.DataType.FLOAT, ir.DataType.FLOAT16, ir.DataType.BFLOAT16}
)

# Explicit extra_options["shared_embeddings"] wins; otherwise follow config.tie_word_embeddings,
# treating a missing or None value as False.
self.shared_embeddings = extra_options.get(
    "shared_embeddings",
    False if getattr(config, "tie_word_embeddings", None) is None else config.tie_word_embeddings,
)

# These int4 algorithm configs set int8_lm_head (presumably an int8-quantized lm_head, per the name).
self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {
    "k_quant_mixed",
    "k_quant_last",
    "rtn_last",
}

if self.shared_embeddings:
    if self.exclude_embeds or self.exclude_lm_head:
        # shared_embeddings conflicts with exclude_embeds and exclude_lm_head
        self.shared_embeddings = False
    elif not self.unquantized_lm_head:
        # With a quantized lm_head, sharing is kept only for int8 lm_head or the "rtn"/"k_quant" algorithms:
        # matmul_nbits_quantizer.py names default-quantized weights differently, so
        # lm_head.MatMul.weight_Q{}G{} does not match.
        self.shared_embeddings = self.int8_lm_head or extra_options.get("int4_algo_config", "default") in {
            "rtn",
            "k_quant",
        }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment