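"""
Toy Byte Latent Transformer (BLT)-style model: bytes are grouped into fixed-size
patches, a small local encoder pools each patch into one vector, a larger global
transformer models the patch sequence, and a small local decoder predicts the
bytes of the next patch.

Example invocation (single process; the filename is a placeholder):
    python blt_toy.py --batch_size 8 --seq_len 128 --patch_size 8 --steps 200
"""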
import argparse
import math
import os

import torch
import torch.distributed
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


#####################################################################
# Dataset
#####################################################################


class SequentialByteDataset(Dataset):
    """
    A simple dataset that produces sequential byte values in a loop [0, 1, ..., 255, 0, 1, ...].
    """

    def __init__(self, length=10000, seq_len=1024):
        self.length = length    # Number of sequences
        self.seq_len = seq_len  # Length of each sequence

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Create a sequence [0, 1, ..., 255] repeated enough times to fill seq_len
        sequence = torch.arange(256, dtype=torch.long).repeat((self.seq_len // 256) + 1)
        return sequence[:self.seq_len]

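# Quick sanity check (illustrative only): `idx` is ignored, so every item is the
# same cycling byte sequence.
#   >>> ds = SequentialByteDataset(length=4, seq_len=10)
#   >>> ds[0]
#   tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
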
#####################################################################
# Patching Utilities
#####################################################################


def fixed_stride_patching(byte_seq, patch_size=8):
    """
    Simple fixed-stride patching:
    given a byte sequence of length N, form patches of size `patch_size`.
    The last patch may be shorter if N is not divisible by `patch_size`.
    """
    seq_len = byte_seq.size(0)
    patches = []
    for i in range(0, seq_len, patch_size):
        patches.append(byte_seq[i:i + patch_size])
    return patches

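# Example (illustrative only): a length-10 sequence with patch_size=4 yields two
# full patches and one short tail patch.
#   >>> fixed_stride_patching(torch.arange(10), patch_size=4)
#   [tensor([0, 1, 2, 3]), tensor([4, 5, 6, 7]), tensor([8, 9])]
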
#####################################################################
# Model Components
#####################################################################


class ByteEmbedding(nn.Module):
    def __init__(self, vocab_size=256, embed_dim=256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)

    def forward(self, x):
        return self.embed(x)


class LocalByteEncoder(nn.Module):
    """
    A small transformer that converts bytes into patch representations.
    A cross-attention layer pools the bytes of each patch into a single patch representation.
    """

    def __init__(self, embed_dim=256, num_layers=2, num_heads=4, ff_mult=4):
        super().__init__()
        self.embed_dim = embed_dim
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=embed_dim * ff_mult, batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Learned query for cross-attention pooling (one query per patch)
        self.query_init = nn.Parameter(torch.randn(embed_dim))
        self.attn_proj_q = nn.Linear(embed_dim, embed_dim)
        self.attn_proj_k = nn.Linear(embed_dim, embed_dim)
        self.attn_proj_v = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, patch_bytes_emb):
        # patch_bytes_emb: [B, patch_len, embed_dim]
        # Encode bytes locally
        h = self.encoder(patch_bytes_emb)

        # Cross-attention pooling: a single learned query attends over the patch bytes
        B, L, E = h.size()
        query = self.query_init.unsqueeze(0).unsqueeze(1).expand(B, 1, E)  # [B, 1, E]
        Q = self.attn_proj_q(query)
        K = self.attn_proj_k(h)
        V = self.attn_proj_v(h)

        attn_scores = (Q @ K.transpose(-1, -2)) / math.sqrt(E)  # [B, 1, L]
        attn_weights = torch.softmax(attn_scores, dim=-1)       # [B, 1, L]
        pooled = attn_weights @ V                                # [B, 1, E]
        pooled = self.out_proj(pooled)                           # [B, 1, E]
        return pooled.squeeze(1)                                 # [B, E]

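# Shape sanity check (illustrative only, default sizes assumed):
#   enc = LocalByteEncoder(embed_dim=256)
#   enc(torch.randn(2, 8, 256)).shape  # -> torch.Size([2, 256]): one vector per patch
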
class GlobalLatentTransformer(nn.Module):
    """
    Large transformer over patch representations.
    """

    def __init__(self, embed_dim=256, num_layers=6, num_heads=8, ff_mult=4):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=embed_dim * ff_mult, batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, patch_reps):
        # patch_reps: [B, n_patches, E]
        return self.transformer(patch_reps)


class LocalByteDecoder(nn.Module):
    """
    Decode a global patch representation back into bytes.
    Plays the reverse role of the encoder's cross-attention: each byte position
    reads from the single patch representation.
    The next patch's bytes are generated autoregressively; during training we use
    teacher forcing, i.e. the decoder receives the gold bytes shifted right by one
    position and a causal mask prevents attending to future bytes.
    """

    def __init__(self, embed_dim=256, num_layers=2, num_heads=4, ff_mult=4):
        super().__init__()
        decoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=embed_dim * ff_mult, batch_first=True,
        )
        self.decoder = nn.TransformerEncoder(decoder_layer, num_layers=num_layers)

        self.attn_proj_q = nn.Linear(embed_dim, embed_dim)
        self.attn_proj_k = nn.Linear(embed_dim, embed_dim)
        self.attn_proj_v = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, 256)  # vocab size fixed at 256

    def forward(self, patch_emb, global_rep):
        # patch_emb: [B, patch_len, E] shifted (teacher-forced) byte embeddings
        # global_rep: [B, E] single patch representation from the global model
        # Cross-attend from bytes to the patch representation
        B, L, E = patch_emb.size()
        Q = self.attn_proj_q(patch_emb)
        K = self.attn_proj_k(global_rep.unsqueeze(1))  # [B, 1, E]
        V = self.attn_proj_v(global_rep.unsqueeze(1))  # [B, 1, E]

        attn_scores = (Q @ K.transpose(-1, -2)) / math.sqrt(E)  # [B, L, 1]
        # With a single key, a softmax over keys would be trivially 1, so we
        # normalize over byte positions instead, giving a learned per-position gate.
        attn_weights = torch.softmax(attn_scores, dim=-2)  # [B, L, 1]
        fused = attn_weights * V    # [B, L, E] via broadcasting
        fused = fused + patch_emb   # residual connection
        # Causal self-attention mask: position t may not attend to positions > t
        causal_mask = torch.triu(
            torch.ones(L, L, dtype=torch.bool, device=patch_emb.device), diagonal=1
        )
        h = self.decoder(fused, mask=causal_mask)
        logits = self.out_proj(h)  # [B, L, 256]
        return logits

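# Shape sanity check (illustrative only): one patch of 8 shifted byte embeddings plus
# one global patch vector yields per-byte logits over the 256 byte values.
#   dec = LocalByteDecoder(embed_dim=256)
#   dec(torch.randn(2, 8, 256), torch.randn(2, 256)).shape  # -> torch.Size([2, 8, 256])
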
#####################################################################
# Combined BLT Model
#####################################################################


class BLTModel(nn.Module):
    def __init__(self, byte_embed_dim=256, patch_size=8):
        super().__init__()
        self.byte_embed = ByteEmbedding(vocab_size=256, embed_dim=byte_embed_dim)
        self.encoder = LocalByteEncoder(embed_dim=byte_embed_dim, num_layers=2, num_heads=4, ff_mult=4)
        self.global_model = GlobalLatentTransformer(embed_dim=byte_embed_dim, num_layers=6, num_heads=8, ff_mult=4)
        self.decoder = LocalByteDecoder(embed_dim=byte_embed_dim, num_layers=2, num_heads=4, ff_mult=4)
        self.patch_size = patch_size

    def forward(self, x):
        # x: [B, seq_len] bytes
        B, N = x.size()

        # Split each sequence into fixed-stride patches.
        # All sequences in the batch are assumed to have the same length.
        all_patches = []
        for b in range(B):
            p = fixed_stride_patching(x[b], self.patch_size)
            # Zero-pad the last patch if N is not divisible by patch_size,
            # so patches can be stacked into a rectangular tensor.
            if p[-1].size(0) < self.patch_size:
                pad_size = self.patch_size - p[-1].size(0)
                p[-1] = torch.cat(
                    [p[-1], torch.zeros(pad_size, dtype=torch.long, device=x.device)], dim=0
                )
            all_patches.append(torch.stack(p, dim=0))  # [num_patches, patch_size]
        all_patches = torch.stack(all_patches, dim=0)   # [B, num_patches, patch_size]

        # Encode patches: embed the bytes of all patches at once
        patch_emb = self.byte_embed(all_patches)  # [B, n_patches, patch_size, E]

        # Encode each patch into a single vector with the local encoder
        B, M, L, E = patch_emb.size()
        patch_reps = []
        for i in range(M):
            rep = self.encoder(patch_emb[:, i])  # [B, E]
            patch_reps.append(rep)
        patch_reps = torch.stack(patch_reps, dim=1)  # [B, n_patches, E]

        # Global latent transformer over the patch sequence
        global_out = self.global_model(patch_reps)  # [B, n_patches, E]

        # Decode the next patch:
        # given the global representation of patch i, predict the bytes of patch i+1.
        # Teacher forcing: the decoder sees the gold bytes of patch i+1 shifted right by
        # one position (zero vector first), so position t only conditions on bytes < t.
        logits = []
        for i in range(M - 1):
            tgt_emb = patch_emb[:, i + 1]  # [B, patch_size, E] gold bytes of patch i+1
            dec_inp = torch.cat([torch.zeros_like(tgt_emb[:, :1]), tgt_emb[:, :-1]], dim=1)
            # global_out[:, i] is the representation of patch i
            logit = self.decoder(dec_inp, global_out[:, i])
            logits.append(logit)  # [B, patch_size, 256]

        # Concatenate per-patch logits: [B, (M-1)*patch_size, 256]
        if len(logits) > 0:
            logits = torch.cat(logits, dim=1)
        else:
            logits = torch.zeros(B, 0, 256, device=x.device)

        # Targets are the bytes of every patch except the first, flattened per sequence
        targets = all_patches[:, 1:].reshape(B, -1)  # [B, (M-1)*patch_size]
        return logits, targets

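# Worked example (illustrative only): with seq_len=128 and patch_size=8 there are
# M=16 patches, so the model emits logits of shape [B, (16-1)*8, 256] = [B, 120, 256]
# and targets of shape [B, 120].
#   model = BLTModel(byte_embed_dim=256, patch_size=8)
#   logits, targets = model(torch.randint(0, 256, (2, 128)))
#   logits.shape, targets.shape  # -> (torch.Size([2, 120, 256]), torch.Size([2, 120]))
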
#####################################################################
# Training Loop
#####################################################################


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--seq_len", type=int, default=128)
    parser.add_argument("--patch_size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--steps", type=int, default=200)
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--world_size", type=int, default=1)
    args = parser.parse_args()

    # Optional: initialize distributed training if world_size > 1
    if args.world_size > 1:
        torch.distributed.init_process_group(backend="nccl", world_size=args.world_size, rank=args.local_rank)
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

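    # NOTE (assumption): with world_size > 1, init_process_group uses the default
    # "env://" rendezvous, so MASTER_ADDR/MASTER_PORT must be set by the launcher,
    # which also has to pass matching --local_rank/--world_size to each process.
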
    dataset = SequentialByteDataset(length=1000, seq_len=args.seq_len)
    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if args.world_size > 1 else None
    dataloader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler, shuffle=(sampler is None))

    model = BLTModel(byte_embed_dim=256, patch_size=args.patch_size).to(device)
    if args.world_size > 1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    model.train()
    step = 0
    for epoch in range(args.epochs):
        if sampler is not None:
            sampler.set_epoch(epoch)
        for batch in dataloader:
            batch = batch.to(device)
            optimizer.zero_grad()
            logits, targets = model(batch)
            # logits:  [B, (M-1)*patch_size, 256]
            # targets: [B, (M-1)*patch_size]
            if targets.numel() == 0:
                # No targets to train on if the sequence is shorter than two patches
                continue
            loss = F.cross_entropy(logits.reshape(-1, 256), targets.reshape(-1))
            loss.backward()
            optimizer.step()
            if step % 10 == 0 and (args.world_size == 1 or torch.distributed.get_rank() == 0):
                print(f"Step {step}, Loss: {loss.item():.4f}")
            step += 1
            if step >= args.steps:
                break
        if step >= args.steps:
            break

    if args.world_size == 1 or torch.distributed.get_rank() == 0:
        print("Training complete.")

        # Test the model after training
        model.eval()
        with torch.no_grad():
            test_seq = dataset[0].unsqueeze(0).to(device)  # Single sequence as input
            logits, _ = model(test_seq)
            preds = torch.argmax(logits, dim=-1)  # Predicted bytes
        print("\nInput Bytes:")
        print(test_seq[0, :args.patch_size * 2].cpu().numpy())  # Print first two patches
        print("Predicted Bytes:")
        print(preds[0, :args.patch_size].cpu().numpy())         # Print prediction for the next patch


if __name__ == "__main__":
    main()