import json
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
# -----------------------------------------------------------------------------
# HiF4 reference implementation
# Paper assumptions used here:
# - 1 HiF4 block = 64 values
# - metadata = 1x E6M2 + 8x E1_8 + 16x E1_16
# - payload = 64x S1P2
# - rounding mode below is round-half-away-from-zero (allowed by the paper)
#
# This is a reference/correctness implementation.
# It stores weights in HiF4-packed form, then dequantizes in forward().
# It is not a fused high-performance kernel.
# -----------------------------------------------------------------------------
_HIF4_BLOCK = 64
_BF16_ONE_OVER_SEVEN = torch.tensor(1.0 / 7.0, dtype=torch.bfloat16).float()
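# Storage budget implied by the block layout in the header comment (simple arithmetic,
# just as a sanity check):
#   1 x E6M2 (8 bits) + 8 x E1_8 (8 bits) + 16 x E1_16 (16 bits) + 64 x S1P2 (256 bits)
#   = 288 bits per 64-value block  ->  4.5 bits per value (36 bytes vs 128 bytes in BF16).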
def round_half_away_from_zero(x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * torch.floor(torch.abs(x) + 0.5)
def pack_bits(bits: torch.Tensor) -> torch.Tensor:
"""
Pack 0/1 bits along the last dimension into uint8 bytes.
Bit order inside each byte is LSB-first.
"""
bits = bits.to(torch.uint8)
pad = (-bits.shape[-1]) % 8
if pad:
bits = F.pad(bits, (0, pad))
bits = bits.reshape(*bits.shape[:-1], -1, 8)
shifts = (1 << torch.arange(8, device=bits.device, dtype=torch.int64)).view(
*([1] * (bits.ndim - 1)), 8
)
return (bits.to(torch.int64) * shifts).sum(dim=-1).to(torch.uint8)
def unpack_bits(packed: torch.Tensor, nbits: int) -> torch.Tensor:
shifts = torch.arange(8, device=packed.device, dtype=torch.int64)
bits = ((packed.to(torch.int64).unsqueeze(-1) >> shifts) & 1).to(torch.uint8)
return bits.reshape(*packed.shape[:-1], -1)[..., :nbits]
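# Worked example of the LSB-first packing (values chosen only for illustration):
#   bits [1, 0, 1, 1, 0, 0, 0, 0] -> 1*1 + 1*4 + 1*8 = 0x0D
#   unpack_bits(torch.tensor([0x0D], dtype=torch.uint8), 8) returns the same 8 bits.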
def pack_nibbles(codes: torch.Tensor) -> torch.Tensor:
"""
Pack 4-bit codes along the last dimension.
codes[..., 0] goes into low nibble, codes[..., 1] into high nibble.
"""
codes = codes.to(torch.uint8)
assert codes.shape[-1] % 2 == 0
lo = codes[..., 0::2]
hi = codes[..., 1::2]
return (lo | (hi << 4)).to(torch.uint8)
def unpack_nibbles(packed: torch.Tensor, count: int) -> torch.Tensor:
lo = packed & 0x0F
hi = (packed >> 4) & 0x0F
out = torch.empty(
*packed.shape[:-1],
packed.shape[-1] * 2,
dtype=torch.uint8,
device=packed.device,
)
out[..., 0::2] = lo
out[..., 1::2] = hi
return out[..., :count]
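# Worked example of the nibble packing (illustrative values):
#   codes [0x3, 0xA] -> byte 0xA3 (low nibble first)
#   unpack_nibbles(torch.tensor([0xA3], dtype=torch.uint8), 2) -> [0x3, 0xA]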
def decode_e6m2(code: torch.Tensor) -> torch.Tensor:
"""
E6M2 is an unsigned 8-bit float with:
exponent bits = 6, bias = 48
mantissa bits = 2, hidden leading 1
0xFF reserved for NaN
"""
code_i = code.to(torch.int64)
exp = (code_i >> 2) & 0x3F
mant = code_i & 0x03
value = torch.ldexp(
1.0 + mant.to(torch.float32) / 4.0,
(exp - 48).to(torch.int32),
)
value = torch.where(code_i == 0xFF, torch.full_like(value, float("nan")), value)
return value
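# Worked decode examples (derived from the bit layout above):
#   0x00 -> exp 0,  mant 0 -> 1.00 * 2**-48          (smallest value; also used for all-zero blocks)
#   0xC5 -> exp 49, mant 1 -> 1.25 * 2**1  = 2.5
#   0xFE -> exp 63, mant 2 -> 1.50 * 2**15 = 49152.0  (largest finite value; 0xFF is NaN)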
def quantize_e6m2(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize positive scales to E6M2.
Returns:
code: uint8 E6M2 code
value: decoded float32 E6M2 value
"""
x = x.to(torch.float32)
zero_mask = x <= 0
# E6M2 has no zero. For all-zero blocks we keep S1P2=0 and park the scale
# at the minimum normal value so dequantization still returns exact zeros.
x_safe = torch.where(zero_mask, torch.full_like(x, 2.0 ** -48), x)
exp_unbiased = torch.floor(torch.log2(x_safe)).to(torch.int64)
significand = x_safe / torch.pow(
torch.tensor(2.0, device=x.device), exp_unbiased.to(torch.float32)
) # in [1, 2)
mant = round_half_away_from_zero((significand - 1.0) * 4.0).to(torch.int64)
# carry if mant rounded to 4
carry = mant == 4
exp_unbiased = exp_unbiased + carry.to(torch.int64)
mant = torch.where(carry, torch.zeros_like(mant), mant)
# clamp to representable range
under = exp_unbiased < -48
exp_unbiased = torch.where(under, torch.full_like(exp_unbiased, -48), exp_unbiased)
mant = torch.where(under, torch.zeros_like(mant), mant)
over = exp_unbiased > 15
exp_unbiased = torch.where(over, torch.full_like(exp_unbiased, 15), exp_unbiased)
mant = torch.where(over, torch.full_like(mant, 2), mant)
# 0xFF is NaN, so clip (exp=15, mant=3) down to max finite (mant=2)
mant = torch.where(
(exp_unbiased == 15) & (mant > 2),
torch.full_like(mant, 2),
mant,
)
exp_code = (exp_unbiased + 48).clamp(0, 63)
code = ((exp_code << 2) | mant).to(torch.uint8)
# all-zero block special case
code = torch.where(zero_mask, torch.zeros_like(code), code)
return code, decode_e6m2(code)
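# Worked round trip (illustrative): quantize_e6m2(torch.tensor([2.4])) gives
#   exp_unbiased = 1, significand = 1.2, mant = round(0.8) = 1
#   -> code 0xC5, decoded value 2.5 (the nearest representable E6M2 value to 2.4).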
def decode_s1p2(code: torch.Tensor) -> torch.Tensor:
"""
S1P2 sign-magnitude decoding:
sign bit = bit 3
magnitude = bits [2:0] interpreted in quarter steps
representable set = {0.00, 0.25, ..., 1.75} with sign
"""
code_i = code.to(torch.int64)
sign = torch.where(((code_i >> 3) & 1) == 1, -1.0, 1.0)
mag = (code_i & 0x7).to(torch.float32) / 4.0
return sign * mag
def quantize_s1p2(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
x = x.to(torch.float32).clamp(-1.75, 1.75)
sign_bit = (x < 0).to(torch.uint8)
mag = round_half_away_from_zero(torch.abs(x) * 4.0).clamp_(0, 7).to(torch.uint8)
# Canonicalize both +0 and -0 to +0 in software.
sign_bit = torch.where(mag == 0, torch.zeros_like(sign_bit), sign_bit)
code = ((sign_bit << 3) | mag).to(torch.uint8)
return code, decode_s1p2(code)
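# Worked S1P2 examples (illustrative):
#   decode: code 0b1011 -> sign -, magnitude 3/4       -> -0.75
#   encode: 0.6  -> round(0.6 * 4) = 2 -> +0.50          (code 0b0010)
#           -1.9 -> clamped to -1.75   -> magnitude 7    (code 0b1111)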
def _reshape_into_64_blocks(x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
"""
Groups the last dimension into contiguous 64-value blocks.
Pads with zeros if needed.
"""
original_last_dim = x.shape[-1]
pad = (-original_last_dim) % _HIF4_BLOCK
if pad:
x = F.pad(x, (0, pad))
x = x.reshape(*x.shape[:-1], x.shape[-1] // _HIF4_BLOCK, _HIF4_BLOCK)
return x, original_last_dim, pad
@torch.no_grad()
def quantize_to_hif4_blocks(x: torch.Tensor) -> dict[str, torch.Tensor | int | tuple]:
"""
Direct implementation of Algorithm 1 on the last dimension.
Returns a packed HiF4 representation:
e6m2: (..., n_blocks) uint8
e1_8_packed: (..., n_blocks, 1) uint8
e1_16_packed: (..., n_blocks, 2) uint8
s1p2_packed: (..., n_blocks, 32) uint8
"""
# The paper starts from BF16. If x is not BF16 already, cast first.
x = x.to(torch.bfloat16).to(torch.float32)
blocks, original_last_dim, pad = _reshape_into_64_blocks(x)
abs_blocks = blocks.abs()
# Stage 1: 64 -> 16 -> 8 -> 1 tree of running maxima over absolute values
v16 = abs_blocks.reshape(*blocks.shape[:-1], 16, 4).amax(dim=-1)
v8 = v16.reshape(*v16.shape[:-1], 8, 2).amax(dim=-1)
vmax = v8.amax(dim=-1)
# Stage 2: hierarchical scales
sf_bf16 = (vmax.to(torch.bfloat16) * _BF16_ONE_OVER_SEVEN.to(vmax.device)).to(torch.bfloat16).to(torch.float32)
e6m2_code, e6m2_value = quantize_e6m2(sf_bf16)
# The paper uses an E6M2_REC_to_BF16 instruction.
# In software we use the reciprocal of the quantized E6M2 value and cast to BF16.
e6m2_recip = (1.0 / e6m2_value).to(torch.bfloat16).to(torch.float32)
e1_8 = (v8 * e6m2_recip.unsqueeze(-1) >= 4.0).to(torch.uint8)
e1_8_for_v16 = e1_8.repeat_interleave(2, dim=-1)
lvl2_downscale = torch.where(
e1_8_for_v16.bool(),
torch.tensor(0.5, device=x.device, dtype=torch.float32),
torch.tensor(1.0, device=x.device, dtype=torch.float32),
)
e1_16 = (v16 * e6m2_recip.unsqueeze(-1) * lvl2_downscale >= 2.0).to(torch.uint8)
# Stage 3: scale original values down, then quantize payload to S1P2
e1_8_for_v64 = e1_8.repeat_interleave(8, dim=-1)
e1_16_for_v64 = e1_16.repeat_interleave(4, dim=-1)
scaled = blocks * e6m2_recip.unsqueeze(-1)
scaled = scaled * torch.where(
e1_8_for_v64.bool(),
torch.tensor(0.5, device=x.device, dtype=torch.float32),
torch.tensor(1.0, device=x.device, dtype=torch.float32),
)
scaled = scaled * torch.where(
e1_16_for_v64.bool(),
torch.tensor(0.5, device=x.device, dtype=torch.float32),
torch.tensor(1.0, device=x.device, dtype=torch.float32),
)
s1p2_code, _ = quantize_s1p2(scaled)
return {
"e6m2": e6m2_code.contiguous(),
"e1_8_packed": pack_bits(e1_8).contiguous(), # 8 bits -> 1 byte
"e1_16_packed": pack_bits(e1_16).contiguous(), # 16 bits -> 2 bytes
"s1p2_packed": pack_nibbles(s1p2_code).contiguous(), # 64 x 4b -> 32 bytes
"original_last_dim": original_last_dim,
"padded_last_dim": original_last_dim + pad,
"shape_prefix": tuple(x.shape[:-1]),
}
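# Worked example of the hierarchical scaling above (illustrative numbers, not taken
# from the paper; assumes 9.0 is also the maximum of its 8-value and 4-value subgroups):
#   block max vmax = 14.0 -> per-block scale 14/7, stored as 2.0 after BF16/E6M2 rounding
#   value 14.0: 14/2 = 7.0 >= 4 -> E1_8 set; 7.0/2 = 3.5  >= 2 -> E1_16 set;
#               payload 14/2/2/2 = 1.75  -> S1P2 1.75 -> dequant 2.0 * 2**2 * 1.75 = 14.0
#   value  9.0:  9/2 = 4.5 >= 4 -> E1_8 set; 4.5/2 = 2.25 >= 2 -> E1_16 set;
#               payload  9/2/2/2 = 1.125 -> S1P2 1.25 -> dequant 2.0 * 2**2 * 1.25 = 10.0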
@torch.no_grad()
def dequantize_from_hif4_blocks(pack: dict[str, torch.Tensor | int | tuple]) -> torch.Tensor:
e6m2 = decode_e6m2(pack["e6m2"]).to(torch.float32)
e1_8 = unpack_bits(pack["e1_8_packed"], 8)
e1_16 = unpack_bits(pack["e1_16_packed"], 16)
s1p2_codes = unpack_nibbles(pack["s1p2_packed"], 64)
s1p2 = decode_s1p2(s1p2_codes)
e1_8_full = e1_8.repeat_interleave(8, dim=-1).to(torch.float32)
e1_16_full = e1_16.repeat_interleave(4, dim=-1).to(torch.float32)
blocks = e6m2.unsqueeze(-1) * torch.pow(2.0, e1_8_full + e1_16_full) * s1p2
out = blocks.reshape(*pack["shape_prefix"], pack["padded_last_dim"])
return out[..., : pack["original_last_dim"]]
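# Minimal round-trip sketch (illustrative shapes):
#   w = torch.randn(8, 256)
#   pack  = quantize_to_hif4_blocks(w)           # pack["s1p2_packed"] has shape (8, 4, 32)
#   w_hat = dequantize_from_hif4_blocks(pack)    # float32, shape (8, 256), lossy copy of w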
class HiF4Linear(nn.Module):
"""
Reference nn.Linear replacement that stores its weight in packed HiF4 form.
"""
def __init__(self, in_features: int, out_features: int, bias: bool = True):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.original_in_features = in_features
self.padded_in_features = in_features
if bias:
self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=False)
else:
self.register_parameter("bias", None)
@classmethod
@torch.no_grad()
def empty(
cls,
in_features: int,
out_features: int,
bias: bool = True,
device: torch.device | str | None = None,
) -> "HiF4Linear":
mod = cls(in_features, out_features, bias=bias)
n_blocks = (in_features + _HIF4_BLOCK - 1) // _HIF4_BLOCK
padded_in_features = n_blocks * _HIF4_BLOCK
mod.original_in_features = in_features
mod.padded_in_features = padded_in_features
kwargs = {}
if device is not None:
kwargs["device"] = device
mod.register_buffer("e6m2", torch.zeros(out_features, n_blocks, dtype=torch.uint8, **kwargs))
mod.register_buffer("e1_8_packed", torch.zeros(out_features, n_blocks, 1, dtype=torch.uint8, **kwargs))
mod.register_buffer("e1_16_packed", torch.zeros(out_features, n_blocks, 2, dtype=torch.uint8, **kwargs))
mod.register_buffer("s1p2_packed", torch.zeros(out_features, n_blocks, 32, dtype=torch.uint8, **kwargs))
if bias:
mod.bias = nn.Parameter(torch.zeros(out_features, **kwargs), requires_grad=False)
return mod
@classmethod
@torch.no_grad()
def from_float(cls, linear: nn.Linear) -> "HiF4Linear":
mod = cls(linear.in_features, linear.out_features, bias=linear.bias is not None)
q = quantize_to_hif4_blocks(linear.weight.detach())
mod.original_in_features = q["original_last_dim"]
mod.padded_in_features = q["padded_last_dim"]
mod.register_buffer("e6m2", q["e6m2"])
mod.register_buffer("e1_8_packed", q["e1_8_packed"])
mod.register_buffer("e1_16_packed", q["e1_16_packed"])
mod.register_buffer("s1p2_packed", q["s1p2_packed"])
if linear.bias is not None:
mod.bias = nn.Parameter(linear.bias.detach().clone(), requires_grad=False)
return mod
@torch.no_grad()
def dequantize_weight(self, dtype: torch.dtype | None = None) -> torch.Tensor:
pack = {
"e6m2": self.e6m2,
"e1_8_packed": self.e1_8_packed,
"e1_16_packed": self.e1_16_packed,
"s1p2_packed": self.s1p2_packed,
"original_last_dim": self.original_in_features,
"padded_last_dim": self.padded_in_features,
"shape_prefix": (self.out_features,),
}
weight = dequantize_from_hif4_blocks(pack)
if dtype is not None:
weight = weight.to(dtype)
return weight
def forward(self, x: torch.Tensor) -> torch.Tensor:
weight = self.dequantize_weight(dtype=x.dtype if x.is_floating_point() else torch.float32)
bias = self.bias
if bias is not None and bias.dtype != weight.dtype:
bias = bias.to(weight.dtype)
return F.linear(x, weight, bias)
def extra_repr(self) -> str:
return (
f"in_features={self.in_features}, out_features={self.out_features}, "
f"bias={self.bias is not None}, format=HiF4"
)
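# Usage sketch for a single layer (hypothetical sizes):
#   fp_layer = nn.Linear(4096, 11008)
#   q_layer  = HiF4Linear.from_float(fp_layer)
#   w_hat    = q_layer.dequantize_weight(torch.bfloat16)   # lossy, shape (11008, 4096)
# The packed buffers cost about 4.5 bits per weight, versus 16 bits for the BF16 weight.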
def _should_convert_linear(
full_name: str,
module: nn.Module,
skip_name_keywords: tuple[str, ...],
) -> bool:
if not isinstance(module, nn.Linear):
return False
lname = full_name.lower()
return not any(keyword in lname for keyword in skip_name_keywords)
@torch.no_grad()
def convert_hf_model_to_hif4(
model: nn.Module,
skip_name_keywords: tuple[str, ...] = (
"embed",
"embedding",
"lm_head",
"gate",
"gating",
"router",
),
verbose: bool = True,
structural_only: bool = False,
) -> nn.Module:
"""
Replaces eligible nn.Linear modules with HiF4Linear.
If structural_only=True, the module topology is converted without
quantizing the current weights. That is useful when reconstructing a
model structure before loading a previously saved HiF4 safetensors file.
Default skips (matched as substrings of the full module name):
- embed / embedding
- lm_head
- gate / gating / router
That matches the paper's setup for ordinary LLMs
(skip embedding + output head) and is also safe for MoE models
where gating/router layers should usually remain higher precision.
"""
replaced = []
def _convert(parent: nn.Module, prefix: str = "") -> None:
for child_name, child in list(parent.named_children()):
full_name = f"{prefix}.{child_name}" if prefix else child_name
if _should_convert_linear(full_name, child, skip_name_keywords):
if structural_only:
replacement = HiF4Linear.empty(
child.in_features,
child.out_features,
bias=child.bias is not None,
device=child.weight.device,
)
else:
replacement = HiF4Linear.from_float(child)
setattr(parent, child_name, replacement)
replaced.append(full_name)
else:
_convert(child, full_name)
_convert(model)
if verbose:
print(f"[HiF4] replaced {len(replaced)} Linear layers")
for name in replaced[:20]:
print(f" - {name}")
if len(replaced) > 20:
print(f" ... and {len(replaced) - 20} more")
return model
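# Usage sketch on a toy model (hypothetical; any module tree containing nn.Linear works):
#   toy = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 128))
#   toy = convert_hf_model_to_hif4(toy, verbose=False)   # both Linear children become HiF4Linear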
@torch.no_grad()
def load_hf_model_and_convert_to_hif4(
model_name_or_path: str,
device: str = "cuda",
torch_dtype: torch.dtype = torch.bfloat16,
trust_remote_code: bool = False,
):
"""
Load a Hugging Face causal LM, then replace eligible Linear layers with HiF4Linear.
"""
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
trust_remote_code=trust_remote_code,
)
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
)
model.eval()
model = convert_hf_model_to_hif4(model)
if device:
model = model.to(device)
return model, tokenizer
def _clone_shared_tensors_for_safetensors(
state_dict: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
"""
safetensors stores plain tensors only. Some HF models tie weights
(for example embeddings <-> lm_head), which means multiple state_dict
entries may share the same underlying storage. We clone later aliases
so save_file() can serialize them safely as independent tensors.
"""
out: dict[str, torch.Tensor] = {}
seen_storages: set[tuple[int, int]] = set()
for key, value in state_dict.items():
tensor = value.detach()
storage = tensor.untyped_storage()
storage_key = (storage.data_ptr(), storage.nbytes())
tensor = tensor.cpu().contiguous()
if storage_key in seen_storages:
tensor = tensor.clone()
seen_storages.add(storage_key)
out[key] = tensor
return out
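# Without this deduplication, safetensors.torch.save_file() typically raises an error
# about tensors sharing memory, which happens whenever weights are tied (for example
# an lm_head tied to the input embedding), so aliased entries must be cloned first.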
@torch.no_grad()
def save_hif4_model_as_safetensors(
model: nn.Module,
save_directory: str | Path,
tokenizer=None,
skip_name_keywords: tuple[str, ...] = (
"embed",
"embedding",
"lm_head",
"gate",
"gating",
"router",
),
filename: str = "model.safetensors",
) -> Path:
"""
Save a converted HiF4 model into a Hugging Face-style directory.
Saved files:
- model.safetensors
- config.json (if model has .config)
- tokenizer files (if tokenizer is given)
- hif4_config.json (stores HiF4-specific reconstruction metadata)
Note:
Because HiF4Linear is a custom module, loading requires rebuilding
the model structure with convert_hf_model_to_hif4(..., structural_only=True)
before loading the safetensors weights.
"""
from safetensors.torch import save_file
save_directory = Path(save_directory)
save_directory.mkdir(parents=True, exist_ok=True)
state_dict = _clone_shared_tensors_for_safetensors(model.state_dict())
out_path = save_directory / filename
save_file(
state_dict,
str(out_path),
metadata={
"format": "pt",
"quantization": "HiF4",
},
)
if hasattr(model, "config") and model.config is not None:
model.config.save_pretrained(save_directory)
if tokenizer is not None:
tokenizer.save_pretrained(save_directory)
with open(save_directory / "hif4_config.json", "w", encoding="utf-8") as f:
json.dump(
{
"format": "HiF4",
"filename": filename,
"skip_name_keywords": list(skip_name_keywords),
},
f,
indent=2,
)
return out_path
@torch.no_grad()
def load_hif4_model_from_safetensors(
model_directory: str | Path,
device: str = "cuda",
torch_dtype: torch.dtype | None = torch.bfloat16,
trust_remote_code: bool = False,
):
"""
Reconstruct a HF causal LM with HiF4Linear modules, then load the
previously saved safetensors checkpoint.
"""
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
model_directory = Path(model_directory)
manifest_path = model_directory / "hif4_config.json"
if manifest_path.exists():
with open(manifest_path, "r", encoding="utf-8") as f:
manifest = json.load(f)
skip_name_keywords = tuple(
manifest.get(
"skip_name_keywords",
["embed", "embedding", "lm_head", "gate", "gating", "router"],
)
)
filename = manifest.get("filename", "model.safetensors")
else:
skip_name_keywords = ("embed", "embedding", "lm_head", "gate", "gating", "router")
filename = "model.safetensors"
tokenizer = None
try:
tokenizer = AutoTokenizer.from_pretrained(
model_directory,
trust_remote_code=trust_remote_code,
)
except Exception:
pass
config = AutoConfig.from_pretrained(
model_directory,
trust_remote_code=trust_remote_code,
)
model = AutoModelForCausalLM.from_config(
config,
trust_remote_code=trust_remote_code,
)
model = convert_hf_model_to_hif4(
model,
skip_name_keywords=skip_name_keywords,
verbose=False,
structural_only=True,
)
state_dict = load_file(str(model_directory / filename))
# strict=False so the explicit check below can raise a single descriptive error
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing or unexpected:
raise RuntimeError(
f"State-dict mismatch while loading HiF4 safetensors. "
f"Missing={missing}, unexpected={unexpected}"
)
model.eval()
if torch_dtype is not None:
model = model.to(dtype=torch_dtype)
if device:
model = model.to(device)
return model, tokenizer
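# Rough on-disk footprint (ignoring biases and the non-converted layers): each
# converted weight costs 4.5 bits, so the HiF4-converted portion of the checkpoint
# is roughly 3.5x smaller than its BF16 equivalent.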
# -----------------------------------------------------------------------------
# Example
# -----------------------------------------------------------------------------
if __name__ == "__main__":
# Small smoke test without downloading a model:
linear = nn.Linear(130, 16, bias=True).eval()
hif4_linear = HiF4Linear.from_float(linear)
x = torch.randn(4, 130)
y_ref = linear(x)
y_hif4 = hif4_linear(x)
mse = (y_ref - y_hif4).pow(2).mean().item()
print("smoke-test MSE:", mse)
# Hugging Face usage:
#
# model_id = "meta-llama/Llama-3.1-8B"
# model, tokenizer = load_hf_model_and_convert_to_hif4(
# model_id,
# device="cuda",
# torch_dtype=torch.bfloat16,
# )
#
# prompt = "Explain block floating point in one paragraph."
# inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# with torch.no_grad():
# out = model.generate(**inputs, max_new_tokens=64)
# print(tokenizer.decode(out[0], skip_special_tokens=True))
#
# save_hif4_model_as_safetensors(model, "./llama_hif4", tokenizer=tokenizer)
# model2, tokenizer2 = load_hif4_model_from_safetensors("./llama_hif4")