# 177 params
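# Tiny hand-weighted Qwen3 model, built with MLX / mlx_lm, that adds two numbers of up
# to ten digits each. Nothing is trained: every weight is written by hand, and each
# "variant" applies one more factorization trick, ending in what appears to be the
# 177-parameter model the title refers to. A self-test compares the model's output
# against Python's own arithmetic.
#
# Example invocation (the filename below is a placeholder; use whatever name you saved
# this gist under):
#   python tiny_adder.py --variant rank1+embed2+sparse_gate0+no_norm_weight+lowrank_head2 --quick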
import argparse
import random

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_map
from mlx_lm.models.qwen3 import Model, ModelArgs

MODEL_LAYERS = 2
MODEL_DIM = 5
ATTENTION_HEADS = 2
KEY_VALUE_HEADS = 1
HEAD_DIM = 2
INTERMEDIATE_SIZE = 3
VOCAB_SIZE = 10
OUTPUT_DIGITS = 11
MAX_ADDEND = 10**10 - 1
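# The constants above define a deliberately tiny Qwen3: 2 layers, hidden size 5,
# 2 attention heads (1 key/value head) with head dimension 2, MLP width 3, and a
# 10-token vocabulary (the digits 0-9). Addends are capped at 10 digits, so the sum
# fits in OUTPUT_DIGITS = 11 digits.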
VARIANT_CHOICES = (
    "baseline",
    "rank1",
    "rank1+embed2",
    "rank1+embed2+sparse_gate0",
    "rank1+embed2+sparse_gate0+no_norm_weight",
    "rank1+embed2+sparse_gate0+no_norm_weight+lowrank_head2",
)
VARIANT_FEATURES = {
    "baseline": frozenset(),
    "rank1": frozenset({"rank1"}),
    "rank1+embed2": frozenset({"rank1", "embed2"}),
    "rank1+embed2+sparse_gate0": frozenset({"rank1", "embed2", "sparse_gate0"}),
    "rank1+embed2+sparse_gate0+no_norm_weight": frozenset(
        {"rank1", "embed2", "sparse_gate0", "no_norm_weight"}
    ),
    "rank1+embed2+sparse_gate0+no_norm_weight+lowrank_head2": frozenset(
        {"rank1", "embed2", "sparse_gate0", "no_norm_weight", "lowrank_head2"}
    ),
}
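# Each variant stacks one more parameter-saving trick on top of the previous one:
# rank-1 attention/MLP projections, a rank-2 factored embedding, a sparse layer-0 gate
# projection, weight-free RMSNorms, and finally a rank-2 factored LM head. The last
# variant is the one that should report the 177-parameter count.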
LM_HEAD_WEIGHT = np.array(
    [
        [5.5779090e00, 3.1322198e00, -4.0438358e02, 6.2589108e01, 9.9358273e-01],
        [5.0814748e00, 2.4687927e00, -3.1444955e02, 4.8671352e01, 7.7272820e-01],
        [3.6916721e00, 1.7657869e00, -2.2455742e02, 3.4757641e01, 5.5075526e-01],
        [1.4084998e00, 1.0232025e00, -1.3470717e02, 2.0847967e01, 3.2766387e-01],
        [-1.7680415e00, 2.4103954e-01, -4.4898785e01, 6.9423370e00, 1.0345399e-01],
        [-5.8379521e00, -5.8070201e-01, 4.4867714e01, -6.9592528e00, -1.2187435e-01],
        [-1.0801232e01, -1.4420221e00, 1.3459233e02, -2.0856800e01, -3.4832114e-01],
        [-1.6657881e01, -2.3429208e00, 2.2427509e02, -3.4750309e01, -5.7588643e-01],
        [-2.3407900e01, -3.2833982e00, 3.1391595e02, -4.8639774e01, -8.0457014e-01],
        [-3.1051287e01, -4.2634540e00, 4.0351492e02, -6.2525200e01, -1.0343723e00],
    ],
    dtype=np.float32,
)
LAYER1_GATE_PROJ_WEIGHT = np.array(
    [
        [-4.3951669e-01, 5.6323919e00, 4.9838150e-01, 1.3435575e03, 6.0357680e04],
        [-1.2112466e02, 3.2923722e-01, -5.0313854e00, 1.3449166e03, 6.0357438e04],
        [-1.3453412e02, -2.6000220e-01, -5.6458039e00, 1.3450677e03, 6.0357410e04],
    ],
    dtype=np.float32,
)
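# LM_HEAD_WEIGHT and LAYER1_GATE_PROJ_WEIGHT are the two hand-derived dense matrices
# shared by every variant; all remaining weights are written (or factored) per variant
# in hand_set_weights_magic below.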

def build_model_args() -> ModelArgs:
    return ModelArgs(
        model_type="qwen3",
        hidden_size=MODEL_DIM,
        num_hidden_layers=MODEL_LAYERS,
        intermediate_size=INTERMEDIATE_SIZE,
        num_attention_heads=ATTENTION_HEADS,
        rms_norm_eps=1e-6,
        vocab_size=VOCAB_SIZE,
        tie_word_embeddings=False,
        num_key_value_heads=KEY_VALUE_HEADS,
        max_position_embeddings=2048,
        rope_theta=10000,
        head_dim=HEAD_DIM,
    )


def _validate_addends(a: int, b: int) -> None:
    if not isinstance(a, int) or not isinstance(b, int):
        raise ValueError("a and b must be ints")
    if a < 0 or a > MAX_ADDEND or b < 0 or b > MAX_ADDEND:
        raise ValueError(f"a and b must be in [0, {MAX_ADDEND}]")


def _encode_addends_internal(a: int, b: int) -> list[int]:
    _validate_addends(a, b)
    prompt = f"{a:010d}{b:010d}"
    a_digits = [int(c) for c in prompt[:10]]
    b_digits = [int(c) for c in prompt[10:]]
    return [0] + list(reversed(a_digits)) + [0] + [0] + list(reversed(b_digits)) + [0]


def _expected_output(a: int, b: int) -> str:
    _validate_addends(a, b)
    return str(a + b)[::-1].ljust(OUTPUT_DIGITS, "0")
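# Encoding: both addends are zero-padded to 10 digits, reversed so the least significant
# digit comes first, and delimited with 0 tokens. The expected answer is likewise
# least-significant-digit first, right-padded with "0" to 11 characters. For example,
# a=12, b=34 encodes to
#   [0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0]
# and the expected output string is "64000000000".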

class Rank1Linear(nn.Module):
    def __init__(self, out_features: int, in_features: int):
        super().__init__()
        self.u = mx.zeros((out_features,), dtype=mx.float32)
        self.v = mx.zeros((in_features,), dtype=mx.float32)

    def __call__(self, x: mx.array) -> mx.array:
        s = mx.sum(x * self.v, axis=-1, keepdims=True)
        return s * self.u


class FactorizedEmbedding(nn.Module):
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.A = mx.zeros((vocab_size, 2), dtype=mx.float32)
        self.B = mx.zeros((2, dim), dtype=mx.float32)

    def __call__(self, ids: mx.array) -> mx.array:
        return self.A[ids] @ self.B


class LowRankLMHead(nn.Module):
    def __init__(self, vocab_size: int, dim: int, rank: int):
        super().__init__()
        self.A = mx.zeros((vocab_size, rank), dtype=mx.float32)
        self.B = mx.zeros((rank, dim), dtype=mx.float32)

    def __call__(self, x: mx.array) -> mx.array:
        return (x @ self.B.T) @ self.A.T


class SparseGateProj0(nn.Module):
    def __init__(self):
        super().__init__()
        self.W23 = mx.zeros((2, 3), dtype=mx.float32)

    def __call__(self, x: mx.array) -> mx.array:
        x3 = x[..., :3]
        y2 = x3 @ self.W23.T
        pad = mx.zeros((*y2.shape[:-1], 1), dtype=y2.dtype)
        return mx.concatenate([y2, pad], axis=-1)


class RMSNormNoWeight(nn.Module):
    def __init__(self, dims: int, eps: float, scale: float):
        super().__init__()
        self.dims = dims
        self.eps = eps
        self.scale = scale

    def __call__(self, x: mx.array) -> mx.array:
        weight = mx.full((self.dims,), self.scale, dtype=x.dtype)
        return mx.fast.rms_norm(x, weight, self.eps)
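# Drop-in replacements for the stock Qwen3 modules:
#   Rank1Linear         - stores u (out) and v (in) instead of a full out*in matrix; W = u v^T.
#   FactorizedEmbedding / LowRankLMHead - rank-2 factorizations A @ B of the embedding / head.
#   SparseGateProj0     - a 2x3 block that reads only the first 3 input dims and zero-pads
#                         its output back to intermediate size 3.
#   RMSNormNoWeight     - RMSNorm with a constant scale, so it contributes no parameters.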

def _set_param(params, path: list, value) -> None:
    node = params
    for key in path[:-1]:
        node = node[key]
    node[path[-1]] = mx.array(value, dtype=mx.float32)


def _set_rank1_factors(params, layer_idx: int, module_path: str, u, v) -> None:
    node = params["model"]["layers"][layer_idx]
    for key in module_path.split("."):
        node = node[key]
    node["u"] = mx.array(u, dtype=mx.float32)
    node["v"] = mx.array(v, dtype=mx.float32)


def _set_rank1(
    params,
    layer_idx: int,
    module_path: str,
    u,
    v,
    w_original,
    verify_rank1: bool,
) -> None:
    u_arr = mx.array(u, dtype=mx.float32)
    v_arr = mx.array(v, dtype=mx.float32)
    _set_rank1_factors(params, layer_idx, module_path, u_arr, v_arr)
    if verify_rank1:
        w_hat = u_arr[:, None] * v_arr[None, :]
        w_orig_arr = mx.array(w_original, dtype=mx.float32)
        if not bool(np.array(mx.allclose(w_hat, w_orig_arr)).item()):
            raise AssertionError(
                f"Rank-1 reconstruction mismatch for layer {layer_idx} {module_path}"
            )


def _get_lowrank_head_rank(features: frozenset[str]) -> int | None:
    if "lowrank_head2" in features:
        return 2
    return None


def _factorize_lowrank(weight: np.ndarray, rank: int) -> tuple[np.ndarray, np.ndarray]:
    u, s, vt = np.linalg.svd(weight, full_matrices=False)
    a = u[:, :rank] * s[:rank]
    b = vt[:rank, :]
    return a.astype(np.float32), b.astype(np.float32)
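# _set_rank1 writes the (u, v) factors for a Rank1Linear and, when verify_rank1 is set,
# checks that u v^T reproduces the intended dense matrix. _factorize_lowrank builds the
# rank-2 LM head factors from a truncated SVD of LM_HEAD_WEIGHT; whether that truncation
# loses anything the adder actually needs is checked by the self-test, not here.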

def apply_variant_monkeypatches(model: Model, variant: str) -> None:
    features = VARIANT_FEATURES[variant]
    lowrank_head_rank = _get_lowrank_head_rank(features)
    if "rank1" in features:
        for layer in model.model.layers:
            layer.self_attn.q_proj = Rank1Linear(4, 5)
            layer.self_attn.k_proj = Rank1Linear(2, 5)
            layer.self_attn.v_proj = Rank1Linear(2, 5)
            layer.self_attn.o_proj = Rank1Linear(5, 4)
            layer.mlp.up_proj = Rank1Linear(3, 5)
            layer.mlp.down_proj = Rank1Linear(5, 3)
    if "embed2" in features:
        model.model.embed_tokens = FactorizedEmbedding(VOCAB_SIZE, MODEL_DIM)
    if "sparse_gate0" in features:
        model.model.layers[0].mlp.gate_proj = SparseGateProj0()
    if lowrank_head_rank is not None:
        model.lm_head = LowRankLMHead(VOCAB_SIZE, MODEL_DIM, lowrank_head_rank)
    if "no_norm_weight" in features:
        for layer in model.model.layers:
            layer.input_layernorm = RMSNormNoWeight(
                MODEL_DIM, eps=model.args.rms_norm_eps, scale=1.0
            )
            layer.post_attention_layernorm = RMSNormNoWeight(
                MODEL_DIM, eps=model.args.rms_norm_eps, scale=1.0
            )
            layer.self_attn.q_norm = RMSNormNoWeight(
                HEAD_DIM, eps=model.args.rms_norm_eps, scale=16.0
            )
            layer.self_attn.k_norm = RMSNormNoWeight(
                HEAD_DIM, eps=model.args.rms_norm_eps, scale=16.0
            )
        model.model.norm = RMSNormNoWeight(
            MODEL_DIM, eps=model.args.rms_norm_eps, scale=1.0
        )
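# The monkeypatches above swap modules in *before* any weights are written, so the
# parameter tree returned by model.parameters() already has the factored shapes that
# hand_set_weights_magic (below) fills in.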

def hand_set_weights_magic(
    model: Model, variant: str, verify_rank1: bool = False
) -> None:
    features = VARIANT_FEATURES[variant]
    use_rank1 = "rank1" in features
    use_embed2 = "embed2" in features
    use_sparse_gate0 = "sparse_gate0" in features
    use_no_norm = "no_norm_weight" in features
    lowrank_head_rank = _get_lowrank_head_rank(features)
    params = tree_map(lambda x: mx.zeros_like(x), model.parameters())
    if lowrank_head_rank is None:
        _set_param(params, ["lm_head", "weight"], LM_HEAD_WEIGHT)
    else:
        a, b = _factorize_lowrank(LM_HEAD_WEIGHT, rank=lowrank_head_rank)
        _set_param(params, ["lm_head", "A"], a)
        _set_param(params, ["lm_head", "B"], b)
    if use_embed2:
        _set_param(
            params,
            ["model", "embed_tokens", "A"],
            [[1.0, float(i)] for i in range(VOCAB_SIZE)],
        )
        _set_param(
            params,
            ["model", "embed_tokens", "B"],
            [[100.0, 0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 0.0]],
        )
    else:
        _set_param(
            params,
            ["model", "embed_tokens", "weight"],
            [[100.0, float(i), 0.0, 0.0, 0.0] for i in range(VOCAB_SIZE)],
        )
    if not use_no_norm:
        for layer_idx in range(MODEL_LAYERS):
            _set_param(
                params,
                ["model", "layers", layer_idx, "input_layernorm", "weight"],
                [1.0, 1.0, 1.0, 1.0, 1.0],
            )
            _set_param(
                params,
                ["model", "layers", layer_idx, "post_attention_layernorm", "weight"],
                [1.0, 1.0, 1.0, 1.0, 1.0],
            )
            _set_param(
                params,
                ["model", "layers", layer_idx, "self_attn", "q_norm", "weight"],
                [16.0, 16.0],
            )
            _set_param(
                params,
                ["model", "layers", layer_idx, "self_attn", "k_norm", "weight"],
                [16.0, 16.0],
            )
        _set_param(params, ["model", "norm", "weight"], [1.0, 1.0, 1.0, 1.0, 1.0])
    if use_sparse_gate0:
        _set_param(
            params,
            ["model", "layers", 0, "mlp", "gate_proj", "W23"],
            [
                [-3.3532020e-01, -1.3412670e03, 6.0353305e04],
                [-1.3743691e01, -1.3418693e03, 6.0353277e04],
            ],
        )
    else:
        _set_param(
            params,
            ["model", "layers", 0, "mlp", "gate_proj", "weight"],
            [
                [-3.3532020e-01, -1.3412670e03, 6.0353305e04, 0.0, 0.0],
                [-1.3743691e01, -1.3418693e03, 6.0353277e04, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
            ],
        )
    _set_param(
        params,
        ["model", "layers", 1, "mlp", "gate_proj", "weight"],
        LAYER1_GATE_PROJ_WEIGHT,
    )
    if use_rank1:
        _set_rank1(
            params,
            0,
            "self_attn.q_proj",
            [0.98502123, 0.17243294, 0.96630472, -0.25740093],
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [
                [0.98502123, 0.0, 0.0, 0.0, 0.0],
                [0.17243294, 0.0, 0.0, 0.0, 0.0],
                [0.96630472, 0.0, 0.0, 0.0, 0.0],
                [-0.25740093, 0.0, 0.0, 0.0, 0.0],
            ],
            verify_rank1,
        )
        _set_rank1(
            params,
            0,
            "self_attn.k_proj",
            [-0.31672141, -0.94851863],
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [
                [-0.31672141, 0.0, 0.0, 0.0, 0.0],
                [-0.94851863, 0.0, 0.0, 0.0, 0.0],
            ],
            verify_rank1,
        )
        _set_rank1(
            params,
            0,
            "self_attn.v_proj",
            [1.0, 0.0],
            [0.0, 1.0, 0.0, 0.0, 0.0],
            [[0.0, 1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]],
            verify_rank1,
        )
        _set_rank1(
            params,
            0,
            "self_attn.o_proj",
            [0.0, 0.0, 1.0, 0.0, 0.0],
            [1.0, 0.0, 1.0, 0.0],
            [
                [0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0],
                [1.0, 0.0, 1.0, 0.0],
                [0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0],
            ],
            verify_rank1,
        )
        _set_rank1(
            params,
            0,
            "mlp.up_proj",
            [1.0, 1.0, 0.0],
            [1.4898191e-02, 6.6922739e-04, 2.9977213e-05, 0.0, 0.0],
            [
                [1.4898191e-02, 6.6922739e-04, 2.9977213e-05, 0.0, 0.0],
                [1.4898191e-02, 6.6922739e-04, 2.9977213e-05, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
            ],
            verify_rank1,
        )
        _set_rank1(
            params,
            0,
            "mlp.down_proj",
            [0.0, 0.0, 0.0, 1.0, 0.0],
            [1.0, -1.0, 0.0],
            [
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
                [1.0, -1.0, 0.0],
                [0.0, 0.0, 0.0],
            ],
            verify_rank1,
        )
        _set_rank1(
            params,
            1,
            "self_attn.q_proj",
            [-0.25507239, 0.96692199, 0.17478994, 0.98460573],
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [
                [-0.25507239, 0.0, 0.0, 0.0, 0.0],
                [0.96692199, 0.0, 0.0, 0.0, 0.0],
                [0.17478994, 0.0, 0.0, 0.0, 0.0],
                [0.98460573, 0.0, 0.0, 0.0, 0.0],
            ],
            verify_rank1,
        )
        _set_rank1(
            params,
            1,
            "self_attn.k_proj",
            [0.32702553, -0.94501549],
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [[0.32702553, 0.0, 0.0, 0.0, 0.0], [-0.94501549, 0.0, 0.0, 0.0, 0.0]],
            verify_rank1,
        )
        _set_rank1(
            params,
            1,
            "self_attn.v_proj",
            [1.0, 0.0],
            [0.0, 1.0, 0.0, 0.0, 0.0],
            [[0.0, 1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]],
            verify_rank1,
        )
        _set_rank1(
            params,
            1,
            "self_attn.o_proj",
            [0.0, 0.0, 0.0, 0.0, 1.0],
            [1.0, 0.0, 1.0, 0.0],
            [
                [0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0],
                [1.0, 0.0, 1.0, 0.0],
            ],
            verify_rank1,
        )
        _set_rank1(
            params,
            1,
            "mlp.up_proj",
            [1.0, 1.0, 1.0],
            [
                1.4899401e-02,
                6.5471046e-04,
                6.8268733e-04,
                -1.6779384e-04,
                2.9817384e-05,
            ],
            [
                [
                    1.4899401e-02,
                    6.5471046e-04,
                    6.8268733e-04,
                    -1.6779384e-04,
                    2.9817384e-05,
                ],
                [
                    1.4899401e-02,
                    6.5471046e-04,
                    6.8268733e-04,
                    -1.6779384e-04,
                    2.9817384e-05,
                ],
                [
                    1.4899401e-02,
                    6.5471046e-04,
                    6.8268733e-04,
                    -1.6779384e-04,
                    2.9817384e-05,
                ],
            ],
            verify_rank1,
        )
        _set_rank1(
            params,
            1,
            "mlp.down_proj",
            [0.0, 0.0, 1.0, 0.0, 0.0],
            [1.0, -10.0, 10.0],
            [
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
                [1.0, -10.0, 10.0],
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
            ],
            verify_rank1,
        )
    else:
        _set_param(
            params,
            ["model", "layers", 0, "self_attn", "q_proj", "weight"],
            [
                [0.98502123, 0.0, 0.0, 0.0, 0.0],
                [0.17243294, 0.0, 0.0, 0.0, 0.0],
                [0.96630472, 0.0, 0.0, 0.0, 0.0],
                [-0.25740093, 0.0, 0.0, 0.0, 0.0],
            ],
        )
        _set_param(
            params,
            ["model", "layers", 0, "self_attn", "k_proj", "weight"],
            [[-0.31672141, 0.0, 0.0, 0.0, 0.0], [-0.94851863, 0.0, 0.0, 0.0, 0.0]],
        )
        _set_param(
            params,
            ["model", "layers", 0, "self_attn", "v_proj", "weight"],
            [[0.0, 1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]],
        )
        _set_param(
            params,
            ["model", "layers", 0, "self_attn", "o_proj", "weight"],
            [
                [0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0],
                [1.0, 0.0, 1.0, 0.0],
                [0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0],
            ],
        )
        _set_param(
            params,
            ["model", "layers", 0, "mlp", "up_proj", "weight"],
            [
                [1.4898191e-02, 6.6922739e-04, 2.9977213e-05, 0.0, 0.0],
                [1.4898191e-02, 6.6922739e-04, 2.9977213e-05, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
            ],
        )
        _set_param(
            params,
            ["model", "layers", 0, "mlp", "down_proj", "weight"],
            [
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
                [1.0, -1.0, 0.0],
                [0.0, 0.0, 0.0],
            ],
        )
        _set_param(
            params,
            ["model", "layers", 1, "self_attn", "q_proj", "weight"],
            [
                [-0.25507239, 0.0, 0.0, 0.0, 0.0],
                [0.96692199, 0.0, 0.0, 0.0, 0.0],
                [0.17478994, 0.0, 0.0, 0.0, 0.0],
                [0.98460573, 0.0, 0.0, 0.0, 0.0],
            ],
        )
        _set_param(
            params,
            ["model", "layers", 1, "self_attn", "k_proj", "weight"],
            [[0.32702553, 0.0, 0.0, 0.0, 0.0], [-0.94501549, 0.0, 0.0, 0.0, 0.0]],
        )
        _set_param(
            params,
            ["model", "layers", 1, "self_attn", "v_proj", "weight"],
            [[0.0, 1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]],
        )
        _set_param(
            params,
            ["model", "layers", 1, "self_attn", "o_proj", "weight"],
            [
                [0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0],
                [1.0, 0.0, 1.0, 0.0],
            ],
        )
        _set_param(
            params,
            ["model", "layers", 1, "mlp", "up_proj", "weight"],
            [
                [
                    1.4899401e-02,
                    6.5471046e-04,
                    6.8268733e-04,
                    -1.6779384e-04,
                    2.9817384e-05,
                ],
                [
                    1.4899401e-02,
                    6.5471046e-04,
                    6.8268733e-04,
                    -1.6779384e-04,
                    2.9817384e-05,
                ],
                [
                    1.4899401e-02,
                    6.5471046e-04,
                    6.8268733e-04,
                    -1.6779384e-04,
                    2.9817384e-05,
                ],
            ],
        )
        _set_param(
            params,
            ["model", "layers", 1, "mlp", "down_proj", "weight"],
            [
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
                [1.0, -10.0, 10.0],
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
            ],
        )
    model.update(params)
    mx.eval(model.parameters())

def build_magic_model(variant: str, verify_rank1: bool = False) -> Model:
    model = Model(build_model_args())
    apply_variant_monkeypatches(model, variant)
    hand_set_weights_magic(model, variant, verify_rank1=verify_rank1)
    return model
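# Greedy, digit-at-a-time decoding. The model is tiny, so the whole sequence is simply
# re-run every step instead of keeping a KV cache; the argmax over the final position
# gives the next (least-significant-first) digit of the sum.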

def _generate_output_batch(model: Model, addends: list[tuple[int, int]]) -> list[str]:
    internal = [_encode_addends_internal(a, b) for a, b in addends]
    for _ in range(OUTPUT_DIGITS):
        x = mx.array(internal, dtype=mx.int32)
        logits = model(x)
        next_digits = np.array(mx.argmax(logits[:, -1, :], axis=-1), dtype=np.int32)
        for seq, next_digit in zip(internal, next_digits):
            seq.append(int(next_digit))
    return ["".join(str(d) for d in seq[-OUTPUT_DIGITS:]) for seq in internal]
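# The self-test draws random 10-digit addend pairs, runs batched generation, and compares
# against the reversed, zero-padded true sums; any mismatch raises with the failing case.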

def run_self_test_batched(model: Model, num_tests: int, batch_size: int) -> None:
    rng = random.Random(123)
    tested = 0
    while tested < num_tests:
        cur_batch_size = min(batch_size, num_tests - tested)
        addends = []
        expected = []
        for _ in range(cur_batch_size):
            a = rng.randint(0, MAX_ADDEND)
            b = rng.randint(0, MAX_ADDEND)
            addends.append((a, b))
            expected.append(_expected_output(a, b))
        actual = _generate_output_batch(model, addends)
        for (a, b), exp, act in zip(addends, expected, actual):
            if act != exp:
                raise AssertionError(
                    f"Mismatch for a={a:010d}, b={b:010d}: expected {exp}, got {act}"
                )
        tested += cur_batch_size
        print(f"self-test progress: {tested}/{num_tests}")

def count_parameters(node) -> int:
    if isinstance(node, dict):
        return sum(count_parameters(v) for v in node.values())
    if isinstance(node, (list, tuple)):
        return sum(count_parameters(v) for v in node)
    if hasattr(node, "shape"):
        n = 1
        for dim in node.shape:
            n *= int(dim)
        return n
    return 0


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-tests", type=int, default=8192)
    parser.add_argument("--batch-size", type=int, default=1024)
    parser.add_argument("--variant", choices=VARIANT_CHOICES, default="baseline")
    parser.add_argument("--quick", action="store_true")
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    if args.quick:
        args.num_tests = 1024
        args.batch_size = 256
    if args.batch_size <= 0:
        raise ValueError("--batch-size must be > 0")
    if args.num_tests < 0:
        raise ValueError("--num-tests must be >= 0")
    model = build_magic_model(args.variant, verify_rank1=args.quick)
    print(f"variant: {args.variant}")
    print(f"parameter count: {count_parameters(model.parameters())}")
    try:
        run_self_test_batched(model, args.num_tests, args.batch_size)
    except AssertionError as e:
        print(f"self-test: FAIL ({e})")
        raise SystemExit(1)
    print(
        f"self-test: PASS ({args.num_tests} random cases, batch size {args.batch_size})"
    )


if __name__ == "__main__":
    main()

# Thanks to @N8Programs for the original script.