"""
Transformer Architecture v7
==================================================================
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import signal
import sys
import time
import glob
import tiktoken
from collections import Counter
from typing import Tuple, List, Dict
# Global flag for graceful shutdown
_interrupted = False
def signal_handler(signum, frame):
    """Handle interrupt signals gracefully"""
    global _interrupted
    _interrupted = True
    print("\n\n[!] Interrupt received. Finishing current operation gracefully...")
    print(" Press Ctrl+C again to force exit.")
    signal.signal(signal.SIGINT, signal.SIG_DFL)  # Reset to default on next interrupt
def reset_interrupt():
    """Reset the interrupt flag"""
    global _interrupted
    _interrupted = False
def is_interrupted():
    """Check if an interrupt was requested"""
    return _interrupted
# Install signal handler
signal.signal(signal.SIGINT, signal_handler)
import array
# ============================================
# 0. HELPER FUNCTIONS
# ============================================
def safe_input(prompt, default_value):
    """Safe input with default value"""
    try:
        user_input = input(prompt)
        if not user_input.strip():
            return default_value
        return user_input
    except EOFError:
        return default_value
def batch_encode(tokenizer, story_generator, total_target, batch_size=1000):
    """
    Encode stories from a generator, reporting progress every `batch_size` stories
    """
    print(f" Encoding approx {total_target} stories in batches of {batch_size}...")
    # Use array.array 'H' (unsigned short, 2 bytes) for memory efficiency.
    # GPT-2 vocab is 50257, which fits in 0-65535.
    all_tokens = array.array('H')
    for i, story in enumerate(story_generator):
        if is_interrupted():
            print("\n [!] Interrupted during encoding.")
            break
        if story.strip():
            tokens = tokenizer.encode(story)
            all_tokens.extend(tokens)
            # Append EOS token so the model learns to stop
            all_tokens.append(tokenizer.eos_token_id)
        if (i + 1) % batch_size == 0:
            print(f" Processed {i + 1}/{total_target} stories...", end='\r')
    print(f"\n Finished encoding. Total tokens: {len(all_tokens)}")
    return all_tokens
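# Illustrative sketch (not called by the script): shows how the uint16 token
# stream produced by batch_encode() round-trips through the .bin cache files
# used later in main(). The function name and the temporary file are ours,
# purely for demonstration.
def _demo_token_cache_roundtrip():
    import tempfile
    tokens = array.array('H', [464, 2068, 7586, 21831, 50256])  # any ids < 65536
    with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as f:
        tokens.tofile(f)
        path = f.name
    loaded = array.array('H')
    with open(path, 'rb') as f:
        loaded.fromfile(f, os.path.getsize(path) // 2)  # 2 bytes per token
    os.remove(path)
    assert loaded == tokens
    return loaded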
# ============================================
# 1. ENHANCED TOKENIZATION (Tiktoken)
# ============================================
class TiktokenTokenizer:
    """Subword tokenizer using tiktoken (gpt2 by default)"""
    def __init__(self, encoding_name="gpt2"):
        self.encoding_name = encoding_name
        self.encoder = tiktoken.get_encoding(encoding_name)
        self.vocab_size = self.encoder.n_vocab
        # Tiktoken's gpt2 encoding has <|endoftext|> as token 50256.
        # Subword tokenizers handle their own special tokens, so we only map a
        # few labels for compatibility with v6's SPECIAL_TOKENS interface.
        self.eos_token_id = self.encoder.eot_token
        # Note: gpt2 has no dedicated pad token; id 0 is a real token ("!"),
        # so it should not be masked out of the loss. The packed training
        # stream never contains padding anyway.
        self.SPECIAL_TOKENS = {
            '<pad>': 0,
            '<eos>': self.eos_token_id
        }
    def encode(self, text):
        """Convert text to token IDs"""
        return self.encoder.encode(text, allowed_special={'<|endoftext|>'})
    def decode(self, ids):
        """Convert token IDs back to text"""
        if torch.is_tensor(ids):
            ids = ids.tolist()
        # Handle array.array or list
        return self.encoder.decode(ids)
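# Illustrative sketch (not called by the script): encode/decode round-trip with
# the tokenizer above. Requires network access the first time tiktoken fetches
# the gpt2 encoding; the function name is ours.
def _demo_tokenizer_roundtrip():
    tok = TiktokenTokenizer("gpt2")
    text = "Once upon a time<|endoftext|>"
    ids = tok.encode(text)              # list of ints; <|endoftext|> -> 50256
    assert ids[-1] == tok.eos_token_id
    assert tok.decode(ids) == text
    return ids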
# ============================================
# 2. TRANSFORMER COMPONENTS
# ============================================
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_len=5000, dropout=0.1):
        """Sinusoidal positional encoding"""
        super().__init__()
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))
        self.dropout = nn.Dropout(p=dropout)
    def forward(self, x):
        """x: [batch_size, seq_len, d_model]"""
        return self.dropout(x + self.pe[:, :x.size(1)])
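# Illustrative sketch (not called by the script): checks the buffer shape and the
# closed form at position 0 (sin(0)=0, cos(0)=1). Shapes follow the
# [batch, seq_len, d_model] convention used throughout; the function name is ours.
def _demo_positional_encoding():
    pe = PositionalEncoding(d_model=16, max_seq_len=32, dropout=0.0)
    x = torch.zeros(2, 10, 16)              # [batch=2, seq_len=10, d_model=16]
    out = pe(x)
    assert out.shape == (2, 10, 16)
    assert torch.allclose(out[0, 0, 0::2], torch.zeros(8))  # sin(0) = 0
    assert torch.allclose(out[0, 0, 1::2], torch.ones(8))   # cos(0) = 1
    return out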
class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps
    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta
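# Illustrative sketch (not called by the script): the hand-rolled LayerNorm above
# should match torch.nn.LayerNorm at initialization (gamma=1, beta=0), since both
# normalize with the biased variance; the tolerance and name are ours.
def _demo_layernorm_matches_builtin():
    d_model = 32
    custom = LayerNorm(d_model, eps=1e-6)
    builtin = nn.LayerNorm(d_model, eps=1e-6)
    x = torch.randn(4, 10, d_model)
    assert torch.allclose(custom(x), builtin(x), atol=1e-5)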
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            mask = mask.to(scores.device)
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        context = torch.matmul(attention_weights, V)
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_len, self.d_model)
        output = self.W_o(context)
        return output
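# Illustrative sketch (not called by the script): with the lower-triangular mask
# used by TransformerV7, the attention output at a position must not depend on
# later tokens. We perturb the future tokens and check that the early outputs
# are unchanged; the function name is ours.
def _demo_causal_attention():
    torch.manual_seed(0)
    attn = MultiHeadAttention(d_model=32, num_heads=4, dropout=0.0).eval()
    mask = torch.tril(torch.ones(8, 8))      # same construction as generate_causal_mask
    x = torch.randn(1, 8, 32)
    x_future_changed = x.clone()
    x_future_changed[:, 4:, :] = torch.randn(1, 4, 32)
    with torch.no_grad():
        out_a = attn(x, mask=mask)
        out_b = attn(x_future_changed, mask=mask)
    assert torch.allclose(out_a[:, :4, :], out_b[:, :4, :], atol=1e-6)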
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(d_ff, d_model)
    def forward(self, x):
        return self.fc2(self.dropout(F.relu(self.fc1(x))))
class TransformerBlockV7(nn.Module):
    """Pre-LN Transformer Block"""
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.norm1 = LayerNorm(d_model)
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm2 = LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x, mask=None):
        x = x + self.dropout(self.attention(self.norm1(x), mask))
        x = x + self.dropout(self.ffn(self.norm2(x)))
        return x
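# Illustrative sketch (not called by the script): recomputes the Pre-LN block by
# hand to show that each sublayer normalizes *before* its transformation and then
# adds the result back onto the residual stream. dropout=0.0 makes the dropout
# layers identities; the function name is ours.
def _demo_pre_ln_block():
    torch.manual_seed(0)
    block = TransformerBlockV7(d_model=32, num_heads=4, d_ff=64, dropout=0.0).eval()
    x = torch.randn(2, 8, 32)
    mask = torch.tril(torch.ones(8, 8))
    with torch.no_grad():
        out = block(x, mask=mask)
        mid = x + block.attention(block.norm1(x), mask)   # attention residual update
        manual = mid + block.ffn(block.norm2(mid))        # feed-forward residual update
    assert out.shape == x.shape
    assert torch.allclose(out, manual, atol=1e-6)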
class TransformerV7(nn.Module):
    def __init__(self, vocab_size, d_model=256, num_heads=8, num_layers=6,
                 d_ff=1024, max_seq_len=5000, dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.d_ff = d_ff
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        self.dropout = nn.Dropout(dropout)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlockV7(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.final_norm = LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)
        self._init_parameters()
    def _init_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    def generate_causal_mask(self, seq_len, device):
        return torch.tril(torch.ones(seq_len, seq_len, device=device))
    def forward(self, token_ids):
        batch_size, seq_len = token_ids.shape
        x = self.token_embedding(token_ids)
        x = self.positional_encoding(x)
        x = self.dropout(x)
        causal_mask = self.generate_causal_mask(seq_len, token_ids.device)
        for block in self.transformer_blocks:
            x = block(x, mask=causal_mask)
        x = self.final_norm(x)
        logits = self.lm_head(x)
        return logits
    def get_num_params(self):
        return sum(p.numel() for p in self.parameters())
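# Illustrative sketch (not called by the script): a tiny forward pass checking
# that the model maps [batch, seq_len] token ids to [batch, seq_len, vocab_size]
# logits; the small hyperparameters and the function name are ours.
def _demo_transformer_forward():
    torch.manual_seed(0)
    tiny = TransformerV7(vocab_size=100, d_model=32, num_heads=4,
                         num_layers=2, d_ff=64, max_seq_len=64, dropout=0.0).eval()
    ids = torch.randint(0, 100, (2, 16))     # [batch=2, seq_len=16]
    with torch.no_grad():
        logits = tiny(ids)
    assert logits.shape == (2, 16, 100)
    print(f"tiny model params: {tiny.get_num_params():,}")
    return logits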
# ============================================
# 3. DATA LOADING & TRAINING
# ============================================
def get_tiny_stories_generator(percent=1.0):
    """
    Generator that yields stories one by one from TinyStories
    """
    from datasets import load_dataset
    print(f"Loading {percent}% of TinyStories dataset (Streaming mode)...")
    dataset = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
    # Estimate target count (approx 2.1M total in train split)
    total_count = 2119719
    target_count = max(1, int(total_count * (percent / 100.0)))
    def story_yield():
        for i, entry in enumerate(dataset):
            if i >= target_count:
                break
            yield entry['text']
    return story_yield(), target_count
def train_epoch(model, tokenizer, data_ids, batch_size=32, seq_len=64, optimizer=None, criterion=None):
    model.train()
    total_loss = 0
    num_batches = 0
    n_tokens = len(data_ids)
    # data_ids is an array.array (or list) of token ids
    for i in range(0, n_tokens - batch_size * (seq_len + 1), batch_size * seq_len):
        if is_interrupted():
            break
        inputs_list = []
        targets_list = []
        for b in range(batch_size):
            start_idx = i + b * seq_len
            if start_idx + seq_len + 1 >= n_tokens:
                break
            chunk = data_ids[start_idx : start_idx + seq_len + 1]
            inputs_list.append(list(chunk[:-1]))
            targets_list.append(list(chunk[1:]))
        if len(inputs_list) == 0:
            break
        inputs = torch.tensor(inputs_list).to(model.token_embedding.weight.device)
        targets = torch.tensor(targets_list).to(model.token_embedding.weight.device)
        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits.view(-1, model.vocab_size), targets.view(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
        if num_batches % 10 == 0:
            print(f" Batch {num_batches}, Loss: {loss.item():.4f}", end='\r')
    if num_batches > 0:
        print(f" Batch {num_batches}, Final Loss: {loss.item():.4f}")
    return total_loss / max(1, num_batches)
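# Illustrative sketch (not called by the script): the next-token objective used
# in train_epoch() pairs each position with the token that follows it, i.e.
# inputs = chunk[:-1] and targets = chunk[1:]; the toy numbers and name are ours.
def _demo_input_target_shift():
    chunk = array.array('H', [10, 11, 12, 13, 14])   # seq_len + 1 = 5 tokens
    inputs = list(chunk[:-1])    # [10, 11, 12, 13]
    targets = list(chunk[1:])    # [11, 12, 13, 14]
    assert targets[:-1] == inputs[1:]
    return inputs, targets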
def generate_text(model, tokenizer, prompt="Once upon a time", max_len=100, temperature=0.8):
    model.eval()
    tokens = tokenizer.encode(prompt)
    input_ids = torch.tensor([tokens]).to(model.token_embedding.weight.device)
    print(f"Generating (Prompt: '{prompt}')...")
    start_len = input_ids.shape[1]
    max_model_len = model.positional_encoding.pe.size(1)
    # Truncate input if it exceeds model capacity
    if start_len > max_model_len:
        print(f" [!] Prompt length ({start_len}) exceeds model limit ({max_model_len}). Truncating.")
        input_ids = input_ids[:, -max_model_len:]
    for _ in range(max_len):
        if is_interrupted(): break
        # Safety check for context window
        if input_ids.size(1) >= max_model_len:
            print(f"\n [!] Context window reached ({max_model_len}). Stopping generation.")
            break
        with torch.no_grad():
            logits = model(input_ids)
        last_logits = logits[:, -1, :] / max(temperature, 1e-5)
        probs = F.softmax(last_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=1)
        if next_token.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(input_ids[0])
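# Illustrative sketch (not called by the script): dividing logits by the
# temperature before softmax, as generate_text() does, sharpens the distribution
# when temperature < 1 and flattens it when temperature > 1; the numbers and the
# function name are ours.
def _demo_temperature_sampling():
    logits = torch.tensor([[2.0, 1.0, 0.5, 0.0]])
    cold = F.softmax(logits / 0.5, dim=-1)   # sharper: more mass on the top logit
    base = F.softmax(logits, dim=-1)
    warm = F.softmax(logits / 2.0, dim=-1)   # flatter: mass spread out
    assert cold[0, 0] > base[0, 0] > warm[0, 0]
    return cold, base, warm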
# ============================================
# 4. CHECKPOINTING
# ============================================
def save_checkpoint(model, tokenizer, path):
    checkpoint = {
        'model_state': model.state_dict(),
        'encoding_name': tokenizer.encoding_name,
        'config': {
            'd_model': model.d_model,
            'num_heads': model.num_heads,
            'num_layers': model.num_layers,
            'd_ff': model.d_ff,
            'vocab_size': model.vocab_size,
        }
    }
    torch.save(checkpoint, path)
    print(f"Model and configuration saved to {path}")
def load_checkpoint(path):
    if not os.path.exists(path):
        return None, None
    print(f"Loading checkpoint from {path}...")
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine;
    # callers move the model to the active device afterwards.
    checkpoint = torch.load(path, map_location='cpu')
    encoding_name = checkpoint.get('encoding_name', 'gpt2')
    config = checkpoint['config']
    tokenizer = TiktokenTokenizer(encoding_name=encoding_name)
    model = TransformerV7(
        vocab_size=config['vocab_size'],
        d_model=config['d_model'],
        num_heads=config['num_heads'],
        num_layers=config['num_layers'],
        d_ff=config.get('d_ff', 1024),
    )
    model.load_state_dict(checkpoint['model_state'])
    model.eval()
    return model, tokenizer
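# Illustrative sketch (not called by the script): saves a tiny model with
# save_checkpoint() and restores it with load_checkpoint(), checking that the
# weights survive the round trip. Needs tiktoken's gpt2 files and a writable
# temp directory; the path and the function name are ours.
def _demo_checkpoint_roundtrip():
    import tempfile
    tok = TiktokenTokenizer("gpt2")
    tiny = TransformerV7(vocab_size=tok.vocab_size, d_model=32, num_heads=4,
                         num_layers=1, d_ff=64)
    path = os.path.join(tempfile.gettempdir(), "model_v7_demo.pth")
    save_checkpoint(tiny, tok, path)
    restored, restored_tok = load_checkpoint(path)
    os.remove(path)
    assert restored_tok.encoding_name == tok.encoding_name
    assert torch.equal(restored.lm_head.weight, tiny.lm_head.weight)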
def select_model_file(default_path=None):
    os.makedirs("models_cache", exist_ok=True)
    files = glob.glob("models_cache/model_v7*.pth")
    files.sort(key=os.path.getmtime, reverse=True)
    if not files:
        print("No existing v7 model files found.")
        return default_path
    print("\nAvailable Models:")
    for i, f in enumerate(files):
        size_mb = os.path.getsize(f) / (1024 * 1024)
        print(f"{i+1}. {f} ({size_mb:.2f} MB)")
    print(f"{len(files)+1}. Cancel / Use New Name ({default_path})")
    choice = safe_input(f"Select model (1-{len(files)+1}, default 1): ", "1")
    try:
        idx = int(choice) - 1
        if 0 <= idx < len(files):
            return files[idx]
    except ValueError:
        pass
    return default_path
# ============================================
# 5. MAIN EXECUTION MENU
# ============================================
def main():
    print("Transformer Architecture v7")
    print("=============================================")
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    os.makedirs("models_cache", exist_ok=True)
    MODEL_PATH = f"models_cache/model_v7_{timestamp}.pth"
    model = None
    tokenizer = None
    while True:
        reset_interrupt()
        print("\n\nMAIN MENU")
        print("1. Train New Model (TinyStories)")
        print("2. Load Model & Resume Training")
        print("3. Generate Text")
        print("4. Exit")
        choice = safe_input("\nEnter choice (1-4): ", "4")
        if choice == '1':
            percent_str = safe_input("Enter percentage of TinyStories to load (0.01-100, default 0.1): ", "0.1")
            try:
                percent = float(percent_str)
            except ValueError:
                percent = 0.1
            print("\n[1] Preparing Data Stream...")
            story_gen, target_count = get_tiny_stories_generator(percent=percent)
            print("\n[2] Initializing Tokenizer...")
            tokenizer = TiktokenTokenizer("gpt2")
            print(f"Tokenizer ready. Vocab size: {tokenizer.vocab_size}")
            print("\n[3] Initializing Model...")
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            print(f"Training on: {device}")
            # Use smaller defaults than v6's word-level model to keep subword training fast
            model = TransformerV7(
                vocab_size=tokenizer.vocab_size,
                d_model=256,
                num_heads=8,
                num_layers=4,
                d_ff=1024
            ).to(device)
            print(f"Params: {model.get_num_params():,}")
            # --- TOKEN CACHING LOGIC ---
            os.makedirs("datasets_cache", exist_ok=True)
            cache_file = f"datasets_cache/tinystories_p{percent}.bin"
            if os.path.exists(cache_file):
                print(f"\n[4] Loading Encoded Data from Cache ({cache_file})...")
                train_ids = array.array('H')
                with open(cache_file, 'rb') as f:
                    train_ids.fromfile(f, os.path.getsize(cache_file) // 2)
                print(f" Loaded {len(train_ids)} tokens.")
            else:
                print("\n[4] Encoding Data (Streaming)...")
                train_ids = batch_encode(tokenizer, story_gen, total_target=target_count, batch_size=1000)
                if train_ids and not is_interrupted():
                    print(f" Saving tokens to cache: {cache_file}")
                    with open(cache_file, 'wb') as f:
                        train_ids.tofile(f)
            if not train_ids:
                print("No tokens found. Training cancelled.")
                continue
            print("\n[5] Training Loop (Press Ctrl+C to stop & save)")
            optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
            # The packed token stream contains no padding, and GPT-2 id 0 is a
            # real token ("!"), so no ignore_index is passed here.
            criterion = nn.CrossEntropyLoss()
            epochs = int(safe_input("Epochs (default 3): ", "3"))
            try:
                for epoch in range(1, epochs + 1):
                    if is_interrupted(): break
                    print(f"\n--- Epoch {epoch}/{epochs} ---")
                    avg_loss = train_epoch(model, tokenizer, train_ids, batch_size=32, seq_len=64,
                                           optimizer=optimizer, criterion=criterion)
                    print(f"Avg Loss: {avg_loss:.4f}")
                    # Save a checkpoint after every epoch
                    save_checkpoint(model, tokenizer, MODEL_PATH)
            except KeyboardInterrupt:
                print("\nInterrupted.")
            if is_interrupted():
                print("\n[!] Loop interrupted. Saving checkpoint...")
                save_checkpoint(model, tokenizer, MODEL_PATH)
        elif choice == '2':
            path = select_model_file()
            if not path:
                continue
            model, tokenizer = load_checkpoint(path)
            if not model:
                print("Failed to load model.")
                continue
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            model = model.to(device)
            MODEL_PATH = path
            print("\nNote: resuming training requires the training data again.")
            percent_str = safe_input("Enter percentage of TinyStories to load for training (default 0.1): ", "0.1")
            try:
                percent = float(percent_str)
            except ValueError:
                percent = 0.1
            # --- TOKEN CACHING LOGIC ---
            os.makedirs("datasets_cache", exist_ok=True)
            cache_file = f"datasets_cache/tinystories_p{percent}.bin"
            if os.path.exists(cache_file):
                print(f"Loading Encoded Data from Cache ({cache_file})...")
                train_ids = array.array('H')
                with open(cache_file, 'rb') as f:
                    train_ids.fromfile(f, os.path.getsize(cache_file) // 2)
            else:
                story_gen, target_count = get_tiny_stories_generator(percent=percent)
                print("Encoding data...")
                train_ids = batch_encode(tokenizer, story_gen, total_target=target_count, batch_size=1000)
                if train_ids and not is_interrupted():
                    print(f"Saving tokens to cache: {cache_file}")
                    with open(cache_file, 'wb') as f:
                        train_ids.tofile(f)
            if not train_ids:
                print("Data loading failed.")
                continue
            print("\nResuming Training (Press Ctrl+C to stop & save)")
            optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
            # As above: the stream has no padding, so no ignore_index is needed.
            criterion = nn.CrossEntropyLoss()
            epochs = int(safe_input("Additional Epochs (default 3): ", "3"))
            try:
                for epoch in range(1, epochs + 1):
                    if is_interrupted(): break
                    print(f"\n--- Epoch {epoch}/{epochs} ---")
                    avg_loss = train_epoch(model, tokenizer, train_ids, batch_size=32, seq_len=64,
                                           optimizer=optimizer, criterion=criterion)
                    print(f"Avg Loss: {avg_loss:.4f}")
            except KeyboardInterrupt:
                print("\nInterrupted.")
            if is_interrupted():
                print("\n[!] Loop interrupted. Saving checkpoint...")
            # Save after resumed training whether or not it was interrupted
            save_checkpoint(model, tokenizer, MODEL_PATH)
        elif choice == '3':
            if model is None:
                path = select_model_file()
                if path:
                    model, tokenizer = load_checkpoint(path)
                    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
                    model = model.to(device)
            if model is None:
                print("No model loaded. Please Train or Load first.")
                continue
            print("\n--- Text Generation ---")
            prompt = safe_input("Enter prompt (default: 'Once upon a time'): ", "Once upon a time")
            length = int(safe_input("Length (default 100): ", "100"))
            temp = float(safe_input("Temperature (default 0.8): ", "0.8"))
            generated = generate_text(model, tokenizer, prompt, max_len=length, temperature=temp)
            print(f"\n[OUTPUT]\n{generated}\n")
        elif choice == '4':
            print("Exiting.")
            break
        else:
            print("Invalid choice.")
if __name__ == "__main__":
    main()