Last active
April 13, 2026 10:58
-
-
Save egorsmkv/ded71fb20a1e878a4f923f88b7b8948f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| # ----------------------------------------------------------------------------- | |
| # HiF4 reference implementation | |
| # Paper assumptions used here: | |
| # - 1 HiF4 block = 64 values | |
| # - metadata = 1x E6M2 + 8x E1_8 + 16x E1_16 | |
| # - payload = 64x S1P2 | |
| # - rounding mode below is round-half-away-from-zero (allowed by the paper) | |
| # | |
| # This is a reference/correctness implementation. | |
| # It stores weights in HiF4-packed form, then dequantizes in forward(). | |
| # It is not a fused high-performance kernel. | |
| # ----------------------------------------------------------------------------- | |
| _HIF4_BLOCK = 64 | |
| _BF16_ONE_OVER_SEVEN = torch.tensor(1.0 / 7.0, dtype=torch.bfloat16).float() | |
| def round_half_away_from_zero(x: torch.Tensor) -> torch.Tensor: | |
| return torch.sign(x) * torch.floor(torch.abs(x) + 0.5) | |
def pack_bits(bits: torch.Tensor) -> torch.Tensor:
    """
    Pack 0/1 bits along the last dimension into uint8 bytes.
    Bit order inside each byte is LSB-first.
    """
    bits = bits.to(torch.uint8)
    remainder = bits.shape[-1] % 8
    if remainder:
        # Zero-pad up to a whole byte so the reshape below is exact.
        bits = F.pad(bits, (0, 8 - remainder))
    grouped = bits.reshape(*bits.shape[:-1], -1, 8).to(torch.int64)
    weights = 1 << torch.arange(8, device=bits.device, dtype=torch.int64)
    weights = weights.view(*([1] * (grouped.ndim - 1)), 8)
    return (grouped * weights).sum(dim=-1).to(torch.uint8)


def unpack_bits(packed: torch.Tensor, nbits: int) -> torch.Tensor:
    """Inverse of pack_bits: expand bytes into the first `nbits` LSB-first bits."""
    positions = torch.arange(8, device=packed.device, dtype=torch.int64)
    expanded = (packed.to(torch.int64).unsqueeze(-1) >> positions) & 1
    flat = expanded.to(torch.uint8).reshape(*packed.shape[:-1], -1)
    return flat[..., :nbits]


def pack_nibbles(codes: torch.Tensor) -> torch.Tensor:
    """
    Pack 4-bit codes along the last dimension.
    codes[..., 0] goes into low nibble, codes[..., 1] into high nibble.
    """
    codes = codes.to(torch.uint8)
    assert codes.shape[-1] % 2 == 0
    low, high = codes[..., 0::2], codes[..., 1::2]
    return ((high << 4) | low).to(torch.uint8)


def unpack_nibbles(packed: torch.Tensor, count: int) -> torch.Tensor:
    """Inverse of pack_nibbles: expand bytes back into `count` 4-bit codes."""
    result = torch.empty(
        *packed.shape[:-1],
        2 * packed.shape[-1],
        dtype=torch.uint8,
        device=packed.device,
    )
    result[..., 0::2] = packed & 0x0F
    result[..., 1::2] = (packed >> 4) & 0x0F
    return result[..., :count]
def decode_e6m2(code: torch.Tensor) -> torch.Tensor:
    """
    E6M2 is an unsigned 8-bit float with:
        exponent bits = 6, bias = 48
        mantissa bits = 2, hidden leading 1
        0xFF reserved for NaN
    """
    raw = code.to(torch.int64)
    exponent = (raw >> 2) & 0x3F
    mantissa = raw & 0x03
    decoded = torch.ldexp(
        1.0 + mantissa.to(torch.float32) / 4.0,
        (exponent - 48).to(torch.int32),
    )
    nan = torch.full_like(decoded, float("nan"))
    return torch.where(raw == 0xFF, nan, decoded)


def quantize_e6m2(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize positive scales to E6M2.

    Returns:
        code: uint8 E6M2 code
        value: decoded float32 E6M2 value
    """
    x = x.to(torch.float32)
    is_zero = x <= 0
    # E6M2 has no zero. For all-zero blocks we keep S1P2=0 and park the scale
    # at the minimum normal value so dequantization still returns exact zeros.
    safe = torch.where(is_zero, torch.full_like(x, 2.0 ** -48), x)
    e = torch.floor(torch.log2(safe)).to(torch.int64)
    base = torch.tensor(2.0, device=x.device)
    frac = safe / torch.pow(base, e.to(torch.float32))  # in [1, 2)
    m = round_half_away_from_zero((frac - 1.0) * 4.0).to(torch.int64)
    # A mantissa of 4 means the value rounded up to the next power of two.
    carried = m == 4
    e = e + carried.to(torch.int64)
    m = torch.where(carried, torch.zeros_like(m), m)
    # Clamp below the representable range.
    too_small = e < -48
    e = torch.where(too_small, torch.full_like(e, -48), e)
    m = torch.where(too_small, torch.zeros_like(m), m)
    # Clamp above the representable range.
    too_big = e > 15
    e = torch.where(too_big, torch.full_like(e, 15), e)
    m = torch.where(too_big, torch.full_like(m, 2), m)
    # 0xFF is NaN, so clip (exp=15, mant=3) down to max finite (mant=2).
    m = torch.where((e == 15) & (m > 2), torch.full_like(m, 2), m)
    code = ((((e + 48).clamp(0, 63)) << 2) | m).to(torch.uint8)
    # All-zero block special case: park at the minimum code.
    code = torch.where(is_zero, torch.zeros_like(code), code)
    return code, decode_e6m2(code)
def decode_s1p2(code: torch.Tensor) -> torch.Tensor:
    """
    S1P2 sign-magnitude decoding:
        sign bit = bit 3
        magnitude = bits [2:0] interpreted in quarter steps
        representable set = {0.00, 0.25, ..., 1.75} with sign
    """
    raw = code.to(torch.int64)
    negative = ((raw >> 3) & 1) == 1
    sign = torch.where(negative, -1.0, 1.0)
    return sign * ((raw & 0x7).to(torch.float32) / 4.0)


def quantize_s1p2(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize to the S1P2 grid; returns (uint8 codes, decoded float32 values)."""
    clamped = x.to(torch.float32).clamp(-1.75, 1.75)
    negative = (clamped < 0).to(torch.uint8)
    magnitude = round_half_away_from_zero(clamped.abs() * 4.0).clamp_(0, 7).to(torch.uint8)
    # Canonicalize both +0 and -0 to +0 in software.
    negative = torch.where(magnitude == 0, torch.zeros_like(negative), negative)
    codes = ((negative << 3) | magnitude).to(torch.uint8)
    return codes, decode_s1p2(codes)
def _reshape_into_64_blocks(x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
    """
    Groups the last dimension into contiguous 64-value blocks.
    Pads with zeros if needed.

    Returns (blocked tensor, original last-dim size, pad amount).
    """
    last_dim = x.shape[-1]
    pad = (-last_dim) % _HIF4_BLOCK
    if pad:
        x = F.pad(x, (0, pad))
    n_blocks = x.shape[-1] // _HIF4_BLOCK
    return x.reshape(*x.shape[:-1], n_blocks, _HIF4_BLOCK), last_dim, pad
@torch.no_grad()
def quantize_to_hif4_blocks(x: torch.Tensor) -> dict[str, torch.Tensor | int | tuple]:
    """
    Direct implementation of Algorithm 1 on the last dimension.

    Returns a packed HiF4 representation:
        e6m2:         (..., n_blocks) uint8
        e1_8_packed:  (..., n_blocks, 1) uint8
        e1_16_packed: (..., n_blocks, 2) uint8
        s1p2_packed:  (..., n_blocks, 32) uint8
    """
    # The paper starts from BF16. If x is not BF16 already, cast first.
    x = x.to(torch.bfloat16).to(torch.float32)
    blocks, original_last_dim, pad = _reshape_into_64_blocks(x)

    # Stage 1: 64 -> 16 -> 8 -> 1 peak tree reduction.
    magnitudes = blocks.abs()
    v16 = magnitudes.reshape(*blocks.shape[:-1], 16, 4).amax(dim=-1)
    v8 = v16.reshape(*v16.shape[:-1], 8, 2).amax(dim=-1)
    vmax = v8.amax(dim=-1)

    # Stage 2: hierarchical scales.
    sf = vmax.to(torch.bfloat16) * _BF16_ONE_OVER_SEVEN.to(vmax.device)
    sf = sf.to(torch.bfloat16).to(torch.float32)
    e6m2_code, e6m2_value = quantize_e6m2(sf)
    # The paper uses an E6M2_REC_to_BF16 instruction. In software we take the
    # reciprocal of the quantized E6M2 value and cast it to BF16.
    recip = (1.0 / e6m2_value).to(torch.bfloat16).to(torch.float32)

    half = torch.tensor(0.5, device=x.device, dtype=torch.float32)
    one = torch.tensor(1.0, device=x.device, dtype=torch.float32)

    e1_8 = (v8 * recip.unsqueeze(-1) >= 4.0).to(torch.uint8)
    lvl2_downscale = torch.where(e1_8.repeat_interleave(2, dim=-1).bool(), half, one)
    e1_16 = (v16 * recip.unsqueeze(-1) * lvl2_downscale >= 2.0).to(torch.uint8)

    # Stage 3: scale original values down, then quantize the payload to S1P2.
    scaled = blocks * recip.unsqueeze(-1)
    scaled = scaled * torch.where(e1_8.repeat_interleave(8, dim=-1).bool(), half, one)
    scaled = scaled * torch.where(e1_16.repeat_interleave(4, dim=-1).bool(), half, one)
    s1p2_code, _ = quantize_s1p2(scaled)

    return {
        "e6m2": e6m2_code.contiguous(),
        "e1_8_packed": pack_bits(e1_8).contiguous(),          # 8 bits -> 1 byte
        "e1_16_packed": pack_bits(e1_16).contiguous(),        # 16 bits -> 2 bytes
        "s1p2_packed": pack_nibbles(s1p2_code).contiguous(),  # 64 x 4b -> 32 bytes
        "original_last_dim": original_last_dim,
        "padded_last_dim": original_last_dim + pad,
        "shape_prefix": tuple(x.shape[:-1]),
    }
@torch.no_grad()
def dequantize_from_hif4_blocks(pack: dict[str, torch.Tensor | int | tuple]) -> torch.Tensor:
    """Inverse of quantize_to_hif4_blocks: reconstruct float32 values."""
    scale = decode_e6m2(pack["e6m2"]).to(torch.float32)
    lvl1 = unpack_bits(pack["e1_8_packed"], 8)
    lvl2 = unpack_bits(pack["e1_16_packed"], 16)
    payload = decode_s1p2(unpack_nibbles(pack["s1p2_packed"], 64))
    # Each set flag doubled the dequantization scale (the encoder halved it).
    exponents = (
        lvl1.repeat_interleave(8, dim=-1).to(torch.float32)
        + lvl2.repeat_interleave(4, dim=-1).to(torch.float32)
    )
    blocks = scale.unsqueeze(-1) * torch.pow(2.0, exponents) * payload
    flat = blocks.reshape(*pack["shape_prefix"], pack["padded_last_dim"])
    return flat[..., : pack["original_last_dim"]]
| class HiF4Linear(nn.Module): | |
| """ | |
| Reference nn.Linear replacement that stores its weight in packed HiF4 form. | |
| """ | |
| def __init__(self, in_features: int, out_features: int, bias: bool = True): | |
| super().__init__() | |
| self.in_features = in_features | |
| self.out_features = out_features | |
| self.original_in_features = in_features | |
| self.padded_in_features = in_features | |
| if bias: | |
| self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=False) | |
| else: | |
| self.register_parameter("bias", None) | |
| @classmethod | |
| @torch.no_grad() | |
| def from_float(cls, linear: nn.Linear) -> "HiF4Linear": | |
| mod = cls(linear.in_features, linear.out_features, bias=linear.bias is not None) | |
| q = quantize_to_hif4_blocks(linear.weight.detach()) | |
| mod.original_in_features = q["original_last_dim"] | |
| mod.padded_in_features = q["padded_last_dim"] | |
| mod.register_buffer("e6m2", q["e6m2"]) | |
| mod.register_buffer("e1_8_packed", q["e1_8_packed"]) | |
| mod.register_buffer("e1_16_packed", q["e1_16_packed"]) | |
| mod.register_buffer("s1p2_packed", q["s1p2_packed"]) | |
| if linear.bias is not None: | |
| mod.bias = nn.Parameter(linear.bias.detach().clone(), requires_grad=False) | |
| return mod | |
| @torch.no_grad() | |
| def dequantize_weight(self, dtype: torch.dtype | None = None) -> torch.Tensor: | |
| pack = { | |
| "e6m2": self.e6m2, | |
| "e1_8_packed": self.e1_8_packed, | |
| "e1_16_packed": self.e1_16_packed, | |
| "s1p2_packed": self.s1p2_packed, | |
| "original_last_dim": self.original_in_features, | |
| "padded_last_dim": self.padded_in_features, | |
| "shape_prefix": (self.out_features,), | |
| } | |
| weight = dequantize_from_hif4_blocks(pack) | |
| if dtype is not None: | |
| weight = weight.to(dtype) | |
| return weight | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| weight = self.dequantize_weight(dtype=x.dtype if x.is_floating_point() else torch.float32) | |
| bias = self.bias | |
| if bias is not None and bias.dtype != weight.dtype: | |
| bias = bias.to(weight.dtype) | |
| return F.linear(x, weight, bias) | |
| def extra_repr(self) -> str: | |
| return ( | |
| f"in_features={self.in_features}, out_features={self.out_features}, " | |
| f"bias={self.bias is not None}, format=HiF4" | |
| ) | |
| def _should_convert_linear( | |
| full_name: str, | |
| module: nn.Module, | |
| skip_name_keywords: tuple[str, ...], | |
| ) -> bool: | |
| if not isinstance(module, nn.Linear): | |
| return False | |
| lname = full_name.lower() | |
| return not any(keyword in lname for keyword in skip_name_keywords) | |
@torch.no_grad()
def convert_hf_model_to_hif4(
    model: nn.Module,
    skip_name_keywords: tuple[str, ...] = (
        "embed",
        "embedding",
        "lm_head",
        "gate",
        "gating",
        "router",
    ),
    verbose: bool = True,
) -> nn.Module:
    """
    Replaces eligible nn.Linear modules with HiF4Linear, in place.

    Default skips:
        - embeddings / embedding
        - lm_head
        - gate / gating / router

    That matches the paper's setup for ordinary LLMs
    (skip embedding + output head) and is also safe for MoE models
    where gating/router layers should usually remain higher precision.
    """
    replaced: list[str] = []

    def _walk(parent: nn.Module, prefix: str = "") -> None:
        # Snapshot children so replacement during iteration is safe.
        for child_name, child in list(parent.named_children()):
            full_name = child_name if not prefix else f"{prefix}.{child_name}"
            if not _should_convert_linear(full_name, child, skip_name_keywords):
                _walk(child, full_name)
                continue
            setattr(parent, child_name, HiF4Linear.from_float(child))
            replaced.append(full_name)

    _walk(model)
    if verbose:
        print(f"[HiF4] replaced {len(replaced)} Linear layers")
        for name in replaced[:20]:
            print(f"  - {name}")
        if len(replaced) > 20:
            print(f"  ... and {len(replaced) - 20} more")
    return model
@torch.no_grad()
def load_hf_model_and_convert_to_hif4(
    model_name_or_path: str,
    device: str = "cuda",
    torch_dtype: torch.dtype = torch.bfloat16,
    trust_remote_code: bool = False,
):
    """
    Load a Hugging Face causal LM, then replace eligible Linear layers with HiF4Linear.

    Returns (model, tokenizer).
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    common = {"trust_remote_code": trust_remote_code}
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **common)
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch_dtype,
        **common,
    )
    model.eval()
    converted = convert_hf_model_to_hif4(model)
    if device:
        converted = converted.to(device)
    return converted, tokenizer
# -----------------------------------------------------------------------------
# Example
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Small smoke test without downloading a model:
    fp_layer = nn.Linear(130, 16, bias=True).eval()
    q_layer = HiF4Linear.from_float(fp_layer)
    batch = torch.randn(4, 130)
    err = (fp_layer(batch) - q_layer(batch)).pow(2).mean().item()
    print("smoke-test MSE:", err)
    # Hugging Face usage:
    #
    # model_id = "meta-llama/Llama-3.1-8B"
    # model, tokenizer = load_hf_model_and_convert_to_hif4(
    #     model_id,
    #     device="cuda",
    #     torch_dtype=torch.bfloat16,
    # )
    #
    # prompt = "Explain block floating point in one paragraph."
    # inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # with torch.no_grad():
    #     out = model.generate(**inputs, max_new_tokens=64)
    # print(tokenizer.decode(out[0], skip_special_tokens=True))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import json | |
| from pathlib import Path | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| # ----------------------------------------------------------------------------- | |
| # HiF4 reference implementation | |
| # Paper assumptions used here: | |
| # - 1 HiF4 block = 64 values | |
| # - metadata = 1x E6M2 + 8x E1_8 + 16x E1_16 | |
| # - payload = 64x S1P2 | |
| # - rounding mode below is round-half-away-from-zero (allowed by the paper) | |
| # | |
| # This is a reference/correctness implementation. | |
| # It stores weights in HiF4-packed form, then dequantizes in forward(). | |
| # It is not a fused high-performance kernel. | |
| # ----------------------------------------------------------------------------- | |
| _HIF4_BLOCK = 64 | |
| _BF16_ONE_OVER_SEVEN = torch.tensor(1.0 / 7.0, dtype=torch.bfloat16).float() | |
| def round_half_away_from_zero(x: torch.Tensor) -> torch.Tensor: | |
| return torch.sign(x) * torch.floor(torch.abs(x) + 0.5) | |
def pack_bits(bits: torch.Tensor) -> torch.Tensor:
    """
    Pack 0/1 bits along the last dimension into uint8 bytes.
    Bit order inside each byte is LSB-first.
    """
    bits = bits.to(torch.uint8)
    remainder = bits.shape[-1] % 8
    if remainder:
        # Zero-pad up to a whole byte so the reshape below is exact.
        bits = F.pad(bits, (0, 8 - remainder))
    grouped = bits.reshape(*bits.shape[:-1], -1, 8).to(torch.int64)
    weights = 1 << torch.arange(8, device=bits.device, dtype=torch.int64)
    weights = weights.view(*([1] * (grouped.ndim - 1)), 8)
    return (grouped * weights).sum(dim=-1).to(torch.uint8)


def unpack_bits(packed: torch.Tensor, nbits: int) -> torch.Tensor:
    """Inverse of pack_bits: expand bytes into the first `nbits` LSB-first bits."""
    positions = torch.arange(8, device=packed.device, dtype=torch.int64)
    expanded = (packed.to(torch.int64).unsqueeze(-1) >> positions) & 1
    flat = expanded.to(torch.uint8).reshape(*packed.shape[:-1], -1)
    return flat[..., :nbits]


def pack_nibbles(codes: torch.Tensor) -> torch.Tensor:
    """
    Pack 4-bit codes along the last dimension.
    codes[..., 0] goes into low nibble, codes[..., 1] into high nibble.
    """
    codes = codes.to(torch.uint8)
    assert codes.shape[-1] % 2 == 0
    low, high = codes[..., 0::2], codes[..., 1::2]
    return ((high << 4) | low).to(torch.uint8)


def unpack_nibbles(packed: torch.Tensor, count: int) -> torch.Tensor:
    """Inverse of pack_nibbles: expand bytes back into `count` 4-bit codes."""
    result = torch.empty(
        *packed.shape[:-1],
        2 * packed.shape[-1],
        dtype=torch.uint8,
        device=packed.device,
    )
    result[..., 0::2] = packed & 0x0F
    result[..., 1::2] = (packed >> 4) & 0x0F
    return result[..., :count]
def decode_e6m2(code: torch.Tensor) -> torch.Tensor:
    """
    E6M2 is an unsigned 8-bit float with:
        exponent bits = 6, bias = 48
        mantissa bits = 2, hidden leading 1
        0xFF reserved for NaN
    """
    raw = code.to(torch.int64)
    exponent = (raw >> 2) & 0x3F
    mantissa = raw & 0x03
    decoded = torch.ldexp(
        1.0 + mantissa.to(torch.float32) / 4.0,
        (exponent - 48).to(torch.int32),
    )
    nan = torch.full_like(decoded, float("nan"))
    return torch.where(raw == 0xFF, nan, decoded)


def quantize_e6m2(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize positive scales to E6M2.

    Returns:
        code: uint8 E6M2 code
        value: decoded float32 E6M2 value
    """
    x = x.to(torch.float32)
    is_zero = x <= 0
    # E6M2 has no zero. For all-zero blocks we keep S1P2=0 and park the scale
    # at the minimum normal value so dequantization still returns exact zeros.
    safe = torch.where(is_zero, torch.full_like(x, 2.0 ** -48), x)
    e = torch.floor(torch.log2(safe)).to(torch.int64)
    base = torch.tensor(2.0, device=x.device)
    frac = safe / torch.pow(base, e.to(torch.float32))  # in [1, 2)
    m = round_half_away_from_zero((frac - 1.0) * 4.0).to(torch.int64)
    # A mantissa of 4 means the value rounded up to the next power of two.
    carried = m == 4
    e = e + carried.to(torch.int64)
    m = torch.where(carried, torch.zeros_like(m), m)
    # Clamp below the representable range.
    too_small = e < -48
    e = torch.where(too_small, torch.full_like(e, -48), e)
    m = torch.where(too_small, torch.zeros_like(m), m)
    # Clamp above the representable range.
    too_big = e > 15
    e = torch.where(too_big, torch.full_like(e, 15), e)
    m = torch.where(too_big, torch.full_like(m, 2), m)
    # 0xFF is NaN, so clip (exp=15, mant=3) down to max finite (mant=2).
    m = torch.where((e == 15) & (m > 2), torch.full_like(m, 2), m)
    code = ((((e + 48).clamp(0, 63)) << 2) | m).to(torch.uint8)
    # All-zero block special case: park at the minimum code.
    code = torch.where(is_zero, torch.zeros_like(code), code)
    return code, decode_e6m2(code)
def decode_s1p2(code: torch.Tensor) -> torch.Tensor:
    """
    S1P2 sign-magnitude decoding:
        sign bit = bit 3
        magnitude = bits [2:0] interpreted in quarter steps
        representable set = {0.00, 0.25, ..., 1.75} with sign
    """
    raw = code.to(torch.int64)
    negative = ((raw >> 3) & 1) == 1
    sign = torch.where(negative, -1.0, 1.0)
    return sign * ((raw & 0x7).to(torch.float32) / 4.0)


def quantize_s1p2(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize to the S1P2 grid; returns (uint8 codes, decoded float32 values)."""
    clamped = x.to(torch.float32).clamp(-1.75, 1.75)
    negative = (clamped < 0).to(torch.uint8)
    magnitude = round_half_away_from_zero(clamped.abs() * 4.0).clamp_(0, 7).to(torch.uint8)
    # Canonicalize both +0 and -0 to +0 in software.
    negative = torch.where(magnitude == 0, torch.zeros_like(negative), negative)
    codes = ((negative << 3) | magnitude).to(torch.uint8)
    return codes, decode_s1p2(codes)
def _reshape_into_64_blocks(x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
    """
    Groups the last dimension into contiguous 64-value blocks.
    Pads with zeros if needed.

    Returns (blocked tensor, original last-dim size, pad amount).
    """
    last_dim = x.shape[-1]
    pad = (-last_dim) % _HIF4_BLOCK
    if pad:
        x = F.pad(x, (0, pad))
    n_blocks = x.shape[-1] // _HIF4_BLOCK
    return x.reshape(*x.shape[:-1], n_blocks, _HIF4_BLOCK), last_dim, pad
@torch.no_grad()
def quantize_to_hif4_blocks(x: torch.Tensor) -> dict[str, torch.Tensor | int | tuple]:
    """
    Direct implementation of Algorithm 1 on the last dimension.

    Returns a packed HiF4 representation:
        e6m2:         (..., n_blocks) uint8
        e1_8_packed:  (..., n_blocks, 1) uint8
        e1_16_packed: (..., n_blocks, 2) uint8
        s1p2_packed:  (..., n_blocks, 32) uint8
    """
    # The paper starts from BF16. If x is not BF16 already, cast first.
    x = x.to(torch.bfloat16).to(torch.float32)
    blocks, original_last_dim, pad = _reshape_into_64_blocks(x)

    # Stage 1: 64 -> 16 -> 8 -> 1 peak tree reduction.
    magnitudes = blocks.abs()
    v16 = magnitudes.reshape(*blocks.shape[:-1], 16, 4).amax(dim=-1)
    v8 = v16.reshape(*v16.shape[:-1], 8, 2).amax(dim=-1)
    vmax = v8.amax(dim=-1)

    # Stage 2: hierarchical scales.
    sf = vmax.to(torch.bfloat16) * _BF16_ONE_OVER_SEVEN.to(vmax.device)
    sf = sf.to(torch.bfloat16).to(torch.float32)
    e6m2_code, e6m2_value = quantize_e6m2(sf)
    # The paper uses an E6M2_REC_to_BF16 instruction. In software we take the
    # reciprocal of the quantized E6M2 value and cast it to BF16.
    recip = (1.0 / e6m2_value).to(torch.bfloat16).to(torch.float32)

    half = torch.tensor(0.5, device=x.device, dtype=torch.float32)
    one = torch.tensor(1.0, device=x.device, dtype=torch.float32)

    e1_8 = (v8 * recip.unsqueeze(-1) >= 4.0).to(torch.uint8)
    lvl2_downscale = torch.where(e1_8.repeat_interleave(2, dim=-1).bool(), half, one)
    e1_16 = (v16 * recip.unsqueeze(-1) * lvl2_downscale >= 2.0).to(torch.uint8)

    # Stage 3: scale original values down, then quantize the payload to S1P2.
    scaled = blocks * recip.unsqueeze(-1)
    scaled = scaled * torch.where(e1_8.repeat_interleave(8, dim=-1).bool(), half, one)
    scaled = scaled * torch.where(e1_16.repeat_interleave(4, dim=-1).bool(), half, one)
    s1p2_code, _ = quantize_s1p2(scaled)

    return {
        "e6m2": e6m2_code.contiguous(),
        "e1_8_packed": pack_bits(e1_8).contiguous(),          # 8 bits -> 1 byte
        "e1_16_packed": pack_bits(e1_16).contiguous(),        # 16 bits -> 2 bytes
        "s1p2_packed": pack_nibbles(s1p2_code).contiguous(),  # 64 x 4b -> 32 bytes
        "original_last_dim": original_last_dim,
        "padded_last_dim": original_last_dim + pad,
        "shape_prefix": tuple(x.shape[:-1]),
    }
@torch.no_grad()
def dequantize_from_hif4_blocks(pack: dict[str, torch.Tensor | int | tuple]) -> torch.Tensor:
    """Inverse of quantize_to_hif4_blocks: reconstruct float32 values."""
    scale = decode_e6m2(pack["e6m2"]).to(torch.float32)
    lvl1 = unpack_bits(pack["e1_8_packed"], 8)
    lvl2 = unpack_bits(pack["e1_16_packed"], 16)
    payload = decode_s1p2(unpack_nibbles(pack["s1p2_packed"], 64))
    # Each set flag doubled the dequantization scale (the encoder halved it).
    exponents = (
        lvl1.repeat_interleave(8, dim=-1).to(torch.float32)
        + lvl2.repeat_interleave(4, dim=-1).to(torch.float32)
    )
    blocks = scale.unsqueeze(-1) * torch.pow(2.0, exponents) * payload
    flat = blocks.reshape(*pack["shape_prefix"], pack["padded_last_dim"])
    return flat[..., : pack["original_last_dim"]]
| class HiF4Linear(nn.Module): | |
| """ | |
| Reference nn.Linear replacement that stores its weight in packed HiF4 form. | |
| """ | |
| def __init__(self, in_features: int, out_features: int, bias: bool = True): | |
| super().__init__() | |
| self.in_features = in_features | |
| self.out_features = out_features | |
| self.original_in_features = in_features | |
| self.padded_in_features = in_features | |
| if bias: | |
| self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=False) | |
| else: | |
| self.register_parameter("bias", None) | |
| @classmethod | |
| @torch.no_grad() | |
| def empty( | |
| cls, | |
| in_features: int, | |
| out_features: int, | |
| bias: bool = True, | |
| device: torch.device | str | None = None, | |
| ) -> "HiF4Linear": | |
| mod = cls(in_features, out_features, bias=bias) | |
| n_blocks = (in_features + _HIF4_BLOCK - 1) // _HIF4_BLOCK | |
| padded_in_features = n_blocks * _HIF4_BLOCK | |
| mod.original_in_features = in_features | |
| mod.padded_in_features = padded_in_features | |
| kwargs = {} | |
| if device is not None: | |
| kwargs["device"] = device | |
| mod.register_buffer("e6m2", torch.zeros(out_features, n_blocks, dtype=torch.uint8, **kwargs)) | |
| mod.register_buffer("e1_8_packed", torch.zeros(out_features, n_blocks, 1, dtype=torch.uint8, **kwargs)) | |
| mod.register_buffer("e1_16_packed", torch.zeros(out_features, n_blocks, 2, dtype=torch.uint8, **kwargs)) | |
| mod.register_buffer("s1p2_packed", torch.zeros(out_features, n_blocks, 32, dtype=torch.uint8, **kwargs)) | |
| if bias: | |
| mod.bias = nn.Parameter(torch.zeros(out_features, **kwargs), requires_grad=False) | |
| return mod | |
| @classmethod | |
| @torch.no_grad() | |
| def from_float(cls, linear: nn.Linear) -> "HiF4Linear": | |
| mod = cls(linear.in_features, linear.out_features, bias=linear.bias is not None) | |
| q = quantize_to_hif4_blocks(linear.weight.detach()) | |
| mod.original_in_features = q["original_last_dim"] | |
| mod.padded_in_features = q["padded_last_dim"] | |
| mod.register_buffer("e6m2", q["e6m2"]) | |
| mod.register_buffer("e1_8_packed", q["e1_8_packed"]) | |
| mod.register_buffer("e1_16_packed", q["e1_16_packed"]) | |
| mod.register_buffer("s1p2_packed", q["s1p2_packed"]) | |
| if linear.bias is not None: | |
| mod.bias = nn.Parameter(linear.bias.detach().clone(), requires_grad=False) | |
| return mod | |
| @torch.no_grad() | |
| def dequantize_weight(self, dtype: torch.dtype | None = None) -> torch.Tensor: | |
| pack = { | |
| "e6m2": self.e6m2, | |
| "e1_8_packed": self.e1_8_packed, | |
| "e1_16_packed": self.e1_16_packed, | |
| "s1p2_packed": self.s1p2_packed, | |
| "original_last_dim": self.original_in_features, | |
| "padded_last_dim": self.padded_in_features, | |
| "shape_prefix": (self.out_features,), | |
| } | |
| weight = dequantize_from_hif4_blocks(pack) | |
| if dtype is not None: | |
| weight = weight.to(dtype) | |
| return weight | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| weight = self.dequantize_weight(dtype=x.dtype if x.is_floating_point() else torch.float32) | |
| bias = self.bias | |
| if bias is not None and bias.dtype != weight.dtype: | |
| bias = bias.to(weight.dtype) | |
| return F.linear(x, weight, bias) | |
| def extra_repr(self) -> str: | |
| return ( | |
| f"in_features={self.in_features}, out_features={self.out_features}, " | |
| f"bias={self.bias is not None}, format=HiF4" | |
| ) | |
| def _should_convert_linear( | |
| full_name: str, | |
| module: nn.Module, | |
| skip_name_keywords: tuple[str, ...], | |
| ) -> bool: | |
| if not isinstance(module, nn.Linear): | |
| return False | |
| lname = full_name.lower() | |
| return not any(keyword in lname for keyword in skip_name_keywords) | |
@torch.no_grad()
def convert_hf_model_to_hif4(
    model: nn.Module,
    skip_name_keywords: tuple[str, ...] = (
        "embed",
        "embedding",
        "lm_head",
        "gate",
        "gating",
        "router",
    ),
    verbose: bool = True,
    structural_only: bool = False,
) -> nn.Module:
    """
    Replaces eligible nn.Linear modules with HiF4Linear, in place.

    If structural_only=True, the module topology is converted without
    quantizing the current weights. That is useful when reconstructing a
    model structure before loading a previously saved HiF4 safetensors file.

    Default skips:
        - embeddings / embedding
        - lm_head
        - gate / gating / router

    That matches the paper's setup for ordinary LLMs
    (skip embedding + output head) and is also safe for MoE models
    where gating/router layers should usually remain higher precision.
    """
    replaced: list[str] = []

    def _walk(parent: nn.Module, prefix: str = "") -> None:
        # Snapshot children so replacement during iteration is safe.
        for child_name, child in list(parent.named_children()):
            full_name = child_name if not prefix else f"{prefix}.{child_name}"
            if not _should_convert_linear(full_name, child, skip_name_keywords):
                _walk(child, full_name)
                continue
            if structural_only:
                new_module = HiF4Linear.empty(
                    child.in_features,
                    child.out_features,
                    bias=child.bias is not None,
                    device=child.weight.device,
                )
            else:
                new_module = HiF4Linear.from_float(child)
            setattr(parent, child_name, new_module)
            replaced.append(full_name)

    _walk(model)
    if verbose:
        print(f"[HiF4] replaced {len(replaced)} Linear layers")
        for name in replaced[:20]:
            print(f"  - {name}")
        if len(replaced) > 20:
            print(f"  ... and {len(replaced) - 20} more")
    return model
@torch.no_grad()
def load_hf_model_and_convert_to_hif4(
    model_name_or_path: str,
    device: str = "cuda",
    torch_dtype: torch.dtype = torch.bfloat16,
    trust_remote_code: bool = False,
):
    """
    Load a Hugging Face causal LM, then replace eligible Linear layers with HiF4Linear.

    Returns (model, tokenizer).
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    common = {"trust_remote_code": trust_remote_code}
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **common)
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch_dtype,
        **common,
    )
    model.eval()
    converted = convert_hf_model_to_hif4(model)
    if device:
        converted = converted.to(device)
    return converted, tokenizer
| def _clone_shared_tensors_for_safetensors( | |
| state_dict: dict[str, torch.Tensor], | |
| ) -> dict[str, torch.Tensor]: | |
| """ | |
| safetensors stores plain tensors only. Some HF models tie weights | |
| (for example embeddings <-> lm_head), which means multiple state_dict | |
| entries may share the same underlying storage. We clone later aliases | |
| so save_file() can serialize them safely as independent tensors. | |
| """ | |
| out: dict[str, torch.Tensor] = {} | |
| seen_storages: set[tuple[int, int]] = set() | |
| for key, value in state_dict.items(): | |
| tensor = value.detach() | |
| storage = tensor.untyped_storage() | |
| storage_key = (storage.data_ptr(), storage.nbytes()) | |
| tensor = tensor.cpu().contiguous() | |
| if storage_key in seen_storages: | |
| tensor = tensor.clone() | |
| seen_storages.add(storage_key) | |
| out[key] = tensor | |
| return out | |
@torch.no_grad()
def save_hif4_model_as_safetensors(
    model: nn.Module,
    save_directory: str | Path,
    tokenizer=None,
    skip_name_keywords: tuple[str, ...] = (
        "embed",
        "embedding",
        "lm_head",
        "gate",
        "gating",
        "router",
    ),
    filename: str = "model.safetensors",
) -> Path:
    """
    Save a converted HiF4 model into a Hugging Face-style directory.

    Files written:
        - model.safetensors
        - config.json (if the model has a .config)
        - tokenizer files (if a tokenizer is given)
        - hif4_config.json (HiF4-specific reconstruction metadata)

    Note:
        HiF4Linear is a custom module, so loading requires rebuilding the
        model structure via convert_hf_model_to_hif4(..., structural_only=True)
        before the safetensors weights are restored.

    Returns:
        Path of the written safetensors file.
    """
    from safetensors.torch import save_file

    target_dir = Path(save_directory)
    target_dir.mkdir(parents=True, exist_ok=True)

    # De-alias tied weights first; save_file() rejects shared storages.
    tensors = _clone_shared_tensors_for_safetensors(model.state_dict())
    weights_path = target_dir / filename
    save_file(
        tensors,
        str(weights_path),
        metadata={
            "format": "pt",
            "quantization": "HiF4",
        },
    )

    config = getattr(model, "config", None)
    if config is not None:
        config.save_pretrained(target_dir)
    if tokenizer is not None:
        tokenizer.save_pretrained(target_dir)

    # Manifest used by load_hif4_model_from_safetensors() to rebuild the
    # exact module topology before loading weights.
    manifest = {
        "format": "HiF4",
        "filename": filename,
        "skip_name_keywords": list(skip_name_keywords),
    }
    with open(target_dir / "hif4_config.json", "w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2)
    return weights_path
@torch.no_grad()
def load_hif4_model_from_safetensors(
    model_directory: str | Path,
    device: str = "cuda",
    torch_dtype: torch.dtype | None = torch.bfloat16,
    trust_remote_code: bool = False,
):
    """
    Reconstruct a HF causal LM with HiF4Linear modules, then load the
    previously saved safetensors checkpoint.

    Args:
        model_directory: Directory produced by save_hif4_model_as_safetensors().
        device: Target device; an empty string skips the move.
        torch_dtype: Optional dtype cast applied after loading; None skips it.
        trust_remote_code: Forwarded to the transformers loaders.

    Returns:
        (model, tokenizer) — tokenizer is None when none could be loaded.

    Raises:
        RuntimeError: If the checkpoint keys do not match the rebuilt model.
    """
    from safetensors.torch import load_file
    from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

    model_directory = Path(model_directory)

    # Read the HiF4 manifest when present; otherwise fall back to the same
    # defaults save_hif4_model_as_safetensors() uses.
    manifest_path = model_directory / "hif4_config.json"
    if manifest_path.exists():
        with open(manifest_path, "r", encoding="utf-8") as f:
            manifest = json.load(f)
        skip_name_keywords = tuple(
            manifest.get(
                "skip_name_keywords",
                ["embed", "embedding", "lm_head", "gate", "gating", "router"],
            )
        )
        filename = manifest.get("filename", "model.safetensors")
    else:
        skip_name_keywords = ("embed", "embedding", "lm_head", "gate", "gating", "router")
        filename = "model.safetensors"

    # The tokenizer is optional; some checkpoints are saved without one.
    tokenizer = None
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_directory,
            trust_remote_code=trust_remote_code,
        )
    except Exception:
        pass

    # Build the architecture from config only (no pretrained weights), then
    # convert the topology so HiF4Linear parameters exist before loading.
    config = AutoConfig.from_pretrained(
        model_directory,
        trust_remote_code=trust_remote_code,
    )
    model = AutoModelForCausalLM.from_config(
        config,
        trust_remote_code=trust_remote_code,
    )
    model = convert_hf_model_to_hif4(
        model,
        skip_name_keywords=skip_name_keywords,
        verbose=False,
        structural_only=True,
    )

    state_dict = load_file(str(model_directory / filename))
    # BUGFIX: with strict=True, load_state_dict raises torch's own error on
    # any mismatch, so the custom RuntimeError below was dead code. Using
    # strict=False lets us surface the more informative HiF4 message while
    # keeping the same "raise on mismatch" behavior.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing or unexpected:
        raise RuntimeError(
            f"State-dict mismatch while loading HiF4 safetensors. "
            f"Missing={missing}, unexpected={unexpected}"
        )
    model.eval()
    if torch_dtype is not None:
        model = model.to(dtype=torch_dtype)
    if device:
        model = model.to(device)
    return model, tokenizer
# -----------------------------------------------------------------------------
# Example
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Quick local smoke test (no model download): quantize one tiny Linear
    # layer and report how far its outputs drift from the float reference.
    reference = nn.Linear(130, 16, bias=True).eval()
    quantized = HiF4Linear.from_float(reference)

    batch = torch.randn(4, 130)
    y_ref = reference(batch)
    y_hif4 = quantized(batch)

    # F.mse_loss with default "mean" reduction == (diff ** 2).mean().
    mse = F.mse_loss(y_hif4, y_ref).item()
    print("smoke-test MSE:", mse)

    # Hugging Face usage:
    #
    # model_id = "meta-llama/Llama-3.1-8B"
    # model, tokenizer = load_hf_model_and_convert_to_hif4(
    #     model_id,
    #     device="cuda",
    #     torch_dtype=torch.bfloat16,
    # )
    #
    # prompt = "Explain block floating point in one paragraph."
    # inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # with torch.no_grad():
    #     out = model.generate(**inputs, max_new_tokens=64)
    # print(tokenizer.decode(out[0], skip_special_tokens=True))
    #
    # save_hif4_model_as_safetensors(model, "./llama_hif4", tokenizer=tokenizer)
    # model2, tokenizer2 = load_hif4_model_from_safetensors("./llama_hif4")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment