Toy Byte Latent Transformer

* warning * This is non-rigorous and probably wrong

Simplest framework needed to mess around with a BLT

o1 says:

Below is a self-contained, minimal example showing how you might implement and train a very simplified Byte Latent Transformer (BLT)-style model on a single machine with 1-2 GPUs. This is a toy implementation for experimentation and does not replicate the scale or exact performance of the official implementation. It focuses on the main architectural concepts: a local byte encoder, a global latent transformer operating on patches, and a local decoder. The code uses a simple fixed-stride patching scheme and trains on a dummy dataset of repeating byte sequences. You can adapt it to use your own dataset and incorporate entropy-based patching later.

What this example does:

Implements a basic BLT model with:
    A local byte encoder (a small Transformer with cross-attention to produce patch representations).
    A global latent transformer over patches.
    A local byte decoder.
Uses a fixed-size patching function for simplicity.
Runs a few training steps on simple repeating byte sequences.
Supports multi-GPU training via PyTorch’s torch.distributed (optional).

Prerequisites:

Python 3.9+
PyTorch 2.0+ with CUDA support
(Optional) For multi-GPU, set the distributed environment variables yourself or launch via torch.distributed.launch, as shown below.
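
For example, a single-node two-GPU run might look like this (a sketch, assuming the script is saved as blt_toy.py; torch.distributed.launch appends --local_rank to each process's arguments, which matches this script's flags, though newer PyTorch versions prefer torchrun):

python -m torch.distributed.launch --nproc_per_node=2 blt_toy.py --world_size 2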

If you need to adapt this code for a real dataset:

Replace the SequentialByteDataset with a TextByteDataset that reads real text files, converts them to bytes, and splits them into fixed-length sequences (see the sketch after this list).
Implement entropy patching by loading a small entropy model, computing per-byte entropies, and deciding patch boundaries accordingly (a sketch sits next to fixed_stride_patching in the code below).
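
A minimal sketch of such a dataset (hypothetical: TextByteDataset, its constructor arguments, and the file path are placeholders, and it assumes the whole file fits in memory):

import torch
from torch.utils.data import Dataset

class TextByteDataset(Dataset):
    # Serve fixed-length byte sequences sliced from a raw text file.
    def __init__(self, path, seq_len=1024):
        with open(path, "rb") as f:
            raw = f.read()  # raw bytes of the file
        self.data = torch.tensor(list(raw), dtype=torch.long)  # values in [0, 255]
        self.seq_len = seq_len

    def __len__(self):
        return max(0, self.data.size(0) - self.seq_len)

    def __getitem__(self, idx):
        return self.data[idx:idx + self.seq_len]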
import math
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
#####################################################################
# Dataset
#####################################################################
class SequentialByteDataset(Dataset):
    """
    A simple dataset that produces sequential byte values in a loop
    [0, 1, ..., 255, 0, 1, ...].
    """
    def __init__(self, length=10000, seq_len=1024):
        self.length = length    # number of sequences
        self.seq_len = seq_len  # length of each sequence

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Create a sequence [0, 1, ..., 255] repeated to fill seq_len
        sequence = torch.arange(256, dtype=torch.long).repeat((self.seq_len // 256) + 1)
        return sequence[:self.seq_len]
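
# Note: every item is the same deterministic looping pattern, e.g.
# dataset[0][:6] == tensor([0, 1, 2, 3, 4, 5]), so each next patch is fully
# predictable from the previous one -- a sanity check that training can converge.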
#####################################################################
# Patching Utilities
#####################################################################
def fixed_stride_patching(byte_seq, patch_size=8):
    """
    Simple fixed-stride patching:
    Given a byte sequence of length N, form patches of size `patch_size`.
    The last patch may be shorter if N is not divisible by patch_size.
    """
    seq_len = byte_seq.size(0)
    patches = []
    for i in range(0, seq_len, patch_size):
        patches.append(byte_seq[i:i + patch_size])
    return patches
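
# --- Hypothetical alternative: entropy-based patching (a sketch, not used below) ---
# The BLT paper derives patch boundaries from a small byte-level LM: a new patch
# starts wherever the model's next-byte entropy exceeds a threshold. Assuming you
# already have a per-byte entropy tensor (same length as byte_seq), the boundary
# logic could look like this; the threshold value is an arbitrary placeholder.
# (The rest of this toy assumes fixed-size patches, so actually using this
# function would also require changes to the padding logic.)
def entropy_patching(byte_seq, entropies, threshold=2.0):
    patches, start = [], 0
    for i in range(1, byte_seq.size(0)):
        if entropies[i] > threshold:   # high surprise -> start a new patch here
            patches.append(byte_seq[start:i])
            start = i
    patches.append(byte_seq[start:])   # final patch
    return patches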
#####################################################################
# Model Components
#####################################################################
class ByteEmbedding(nn.Module):
    def __init__(self, vocab_size=256, embed_dim=256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)

    def forward(self, x):
        return self.embed(x)


class LocalByteEncoder(nn.Module):
    """
    A small transformer that converts bytes into patch representations.
    A cross-attention layer pools the bytes into a single patch representation.
    """
    def __init__(self, embed_dim=256, num_layers=2, num_heads=4, ff_mult=4):
        super().__init__()
        self.embed_dim = embed_dim
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=embed_dim * ff_mult, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Learned query for cross-attention pooling (one query per patch)
        self.query_init = nn.Parameter(torch.randn(embed_dim))
        self.attn_proj_q = nn.Linear(embed_dim, embed_dim)
        self.attn_proj_k = nn.Linear(embed_dim, embed_dim)
        self.attn_proj_v = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, patch_bytes_emb):
        # patch_bytes_emb: [B, patch_len, embed_dim]
        # Encode bytes locally
        h = self.encoder(patch_bytes_emb)
        # Cross-attention pooling: a single learned query attends over the bytes
        B, L, E = h.size()
        query = self.query_init.unsqueeze(0).unsqueeze(1).expand(B, 1, E)  # [B, 1, E]
        Q = self.attn_proj_q(query)
        K = self.attn_proj_k(h)
        V = self.attn_proj_v(h)
        attn_scores = (Q @ K.transpose(-1, -2)) / math.sqrt(E)  # [B, 1, L]
        attn_weights = torch.softmax(attn_scores, dim=-1)       # [B, 1, L]
        pooled = attn_weights @ V       # [B, 1, E]
        pooled = self.out_proj(pooled)  # [B, 1, E]
        return pooled.squeeze(1)        # [B, E]
class GlobalLatentTransformer(nn.Module):
    """
    Large transformer over patch representations.
    """
    def __init__(self, embed_dim=256, num_layers=6, num_heads=8, ff_mult=4):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=embed_dim * ff_mult, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, patch_reps):
        # patch_reps: [B, n_patches, E]
        return self.transformer(patch_reps)
class LocalByteDecoder(nn.Module):
    """
    Decode a global patch representation back to bytes.
    Similar to the encoder but with the cross-attention direction reversed:
    the bytes attend to the patch representation.
    For simplicity, we assume the gold bytes are available and train with
    teacher forcing rather than generating autoregressively.
    """
    def __init__(self, embed_dim=256, num_layers=2, num_heads=4, ff_mult=4):
        super().__init__()
        decoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=embed_dim * ff_mult, batch_first=True)
        self.decoder = nn.TransformerEncoder(decoder_layer, num_layers=num_layers)
        self.attn_proj_q = nn.Linear(embed_dim, embed_dim)
        self.attn_proj_k = nn.Linear(embed_dim, embed_dim)
        self.attn_proj_v = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, 256)  # vocab size fixed at 256

    def forward(self, patch_emb, global_rep):
        # patch_emb: [B, patch_len, E] byte embeddings of the target patch
        # global_rep: [B, E] single patch representation
        # Cross-attend from bytes to the patch representation
        B, L, E = patch_emb.size()
        Q = self.attn_proj_q(patch_emb)
        K = self.attn_proj_k(global_rep.unsqueeze(1))  # [B, 1, E]
        V = self.attn_proj_v(global_rep.unsqueeze(1))
        attn_scores = (Q @ K.transpose(-1, -2)) / math.sqrt(E)  # [B, L, 1]
        # Note: with a single key, a softmax over keys would be all ones, so the
        # scores are normalized over byte positions instead, acting as a gate.
        attn_weights = torch.softmax(attn_scores, dim=-2)  # [B, L, 1]
        fused = attn_weights * V   # [B, L, E] (broadcast over E)
        fused = fused + patch_emb  # residual connection
        h = self.decoder(fused)
        logits = self.out_proj(h)  # [B, L, 256]
        return logits
#####################################################################
# Combined BLT Model
#####################################################################
class BLTModel(nn.Module):
    def __init__(self, byte_embed_dim=256, patch_size=8):
        super().__init__()
        self.byte_embed = ByteEmbedding(vocab_size=256, embed_dim=byte_embed_dim)
        self.encoder = LocalByteEncoder(embed_dim=byte_embed_dim, num_layers=2, num_heads=4, ff_mult=4)
        self.global_model = GlobalLatentTransformer(embed_dim=byte_embed_dim, num_layers=6, num_heads=8, ff_mult=4)
        self.decoder = LocalByteDecoder(embed_dim=byte_embed_dim, num_layers=2, num_heads=4, ff_mult=4)
        self.patch_size = patch_size

    def forward(self, x):
        # x: [B, seq_len] bytes
        B, N = x.size()
        # Split each sequence into fixed-stride patches
        # (assumes all sequences in the batch have the same length)
        all_patches = []
        for b in range(B):
            p = fixed_stride_patching(x[b], self.patch_size)
            # Zero-pad the last patch so every patch has length patch_size
            if p[-1].size(0) < self.patch_size:
                pad_size = self.patch_size - p[-1].size(0)
                p[-1] = torch.cat([p[-1], torch.zeros(pad_size, dtype=torch.long, device=x.device)], dim=0)
            all_patches.append(torch.stack(p, dim=0))  # [n_patches, patch_size]
        all_patches = torch.stack(all_patches, dim=0)  # [B, n_patches, patch_size]
        # Embed every byte in every patch
        patch_emb = self.byte_embed(all_patches)  # [B, n_patches, patch_size, E]
        # Encode each patch to a single vector
        B, M, L, E = patch_emb.size()
        patch_reps = []
        for i in range(M):
            rep = self.encoder(patch_emb[:, i])  # [B, E]
            patch_reps.append(rep)
        patch_reps = torch.stack(patch_reps, dim=1)  # [B, n_patches, E]
        # Global latent transformer over patch representations
        global_out = self.global_model(patch_reps)  # [B, n_patches, E]
        # Decode the next patch: given patch[i], predict the bytes of patch[i+1].
        # Note: with this teacher-forcing setup the decoder input is the gold
        # next patch itself, so the target leaks into the input; a real
        # implementation would decode bytes autoregressively.
        logits = []
        for i in range(M - 1):
            dec_inp = patch_emb[:, i + 1]  # teacher-forced bytes of the next patch
            # global_out[:, i] is the representation of patch i
            logit = self.decoder(dec_inp, global_out[:, i])
            logits.append(logit)  # [B, patch_size, 256]
        # Stack logits, shape: [B, (M-1)*patch_size, 256]
        if len(logits) > 0:
            logits = torch.cat(logits, dim=1)
        else:
            logits = torch.zeros(B, 0, 256, device=x.device)
        # Targets are the bytes of every patch except the first
        targets = all_patches[:, 1:].reshape(B, -1)  # [B, (M-1)*patch_size]
        return logits, targets
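
# A quick shape sanity check (hypothetical usage, not part of the script):
#   model = BLTModel(byte_embed_dim=256, patch_size=8)
#   x = torch.randint(0, 256, (2, 64))   # batch of 2 sequences, 64 bytes each
#   logits, targets = model(x)           # 64 bytes -> 8 patches, 7 of them predicted
#   print(logits.shape, targets.shape)   # torch.Size([2, 56, 256]) torch.Size([2, 56])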
#####################################################################
# Training Loop
#####################################################################
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--seq_len", type=int, default=128)
    parser.add_argument("--patch_size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--steps", type=int, default=200)
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--world_size", type=int, default=1)
    args = parser.parse_args()

    # Optional: initialize distributed training if world_size > 1
    if args.world_size > 1:
        torch.distributed.init_process_group(backend="nccl", world_size=args.world_size, rank=args.local_rank)
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = SequentialByteDataset(length=1000, seq_len=args.seq_len)
    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if args.world_size > 1 else None
    dataloader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler, shuffle=(sampler is None))

    model = BLTModel(byte_embed_dim=256, patch_size=args.patch_size).to(device)
    if args.world_size > 1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    model.train()
    step = 0
    for epoch in range(args.epochs):
        if sampler is not None:
            sampler.set_epoch(epoch)
        for batch in dataloader:
            batch = batch.to(device)
            optimizer.zero_grad()
            logits, targets = model(batch)
            # logits: [B, (M-1)*patch_size, 256], targets: [B, (M-1)*patch_size]
            if targets.numel() == 0:
                # Nothing to train on if the sequence is shorter than two patches
                continue
            loss = F.cross_entropy(logits.reshape(-1, 256), targets.reshape(-1))
            loss.backward()
            optimizer.step()
            if step % 10 == 0 and (args.world_size == 1 or torch.distributed.get_rank() == 0):
                print(f"Step {step}, Loss: {loss.item():.4f}")
            step += 1
            if step >= args.steps:
                break
        if step >= args.steps:
            break

    if args.world_size == 1 or torch.distributed.get_rank() == 0:
        print("Training complete.")
        # Test the model after training (use the bare module if wrapped in DDP)
        eval_model = model.module if args.world_size > 1 else model
        eval_model.eval()
        test_seq = dataset[0].unsqueeze(0).to(device)  # single sequence as input
        with torch.no_grad():
            logits, _ = eval_model(test_seq)
        preds = torch.argmax(logits, dim=-1)  # predicted bytes
        print("\nInput Bytes:")
        print(test_seq[0, :args.patch_size * 2].cpu().numpy())  # first two patches
        print("Predicted Bytes:")
        print(preds[0, :args.patch_size].cpu().numpy())  # prediction for the second patch

    if args.world_size > 1:
        torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()
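
To try it out on one GPU (or CPU), a run might look like this (assuming the script is saved as blt_toy.py):

python blt_toy.py --batch_size 8 --seq_len 128 --patch_size 8 --steps 200

Since the toy dataset is deterministic, the loss should fall quickly; swap in a TextByteDataset as sketched above for anything more interesting.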