Skip to content

Instantly share code, notes, and snippets.

@amitpuri
Created January 3, 2026 06:13
Show Gist options
  • Select an option

  • Save amitpuri/5a21b7cdc6086990230a6878c5ad2000 to your computer and use it in GitHub Desktop.

Select an option

Save amitpuri/5a21b7cdc6086990230a6878c5ad2000 to your computer and use it in GitHub Desktop.
"""
Transformer Architecture v6 (NLTK Tokenization)
==================================================================
Improved version of v5 using NLTK for word-level tokenization.
Includes interactive menu for Training, Saving, Loading, and Generation.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import asyncio
import signal
import sys
import time
import glob
import nltk
from collections import Counter
from typing import Tuple, List, Dict
# Ensure NLTK data is available: 'punkt' (and its newer 'punkt_tab' split)
# is required by nltk.word_tokenize. Failure here (e.g. no network) is
# non-fatal because tokenization code falls back to str.split() on
# LookupError later.
try:
    nltk.download('punkt', quiet=True)
    nltk.download('punkt_tab', quiet=True)
except Exception as e:
    print(f"Warning: Failed to download NLTK data: {e}")
# Global flag for graceful shutdown: set by signal_handler, read via
# is_interrupted(), cleared via reset_interrupt().
_interrupted = False
def signal_handler(signum, frame):
    """Record that a SIGINT arrived so long-running loops can stop safely.

    Sets the module-wide interrupt flag, informs the user, and restores
    the default SIGINT disposition so a second Ctrl+C terminates the
    process immediately instead of being queued again.
    """
    global _interrupted
    _interrupted = True
    print("\n\n[!] Interrupt received. Finishing current operation gracefully...")
    print(" Press Ctrl+C again to force exit.")
    # Hand SIGINT back to the default handler: next Ctrl+C force-exits.
    signal.signal(signal.SIGINT, signal.SIG_DFL)
def reset_interrupt():
    """Clear any pending interrupt request (call before a new menu action)."""
    global _interrupted
    _interrupted = False
def is_interrupted():
    """Report whether a graceful-shutdown request is currently pending."""
    return bool(_interrupted)
# Install signal handler so Ctrl+C requests a graceful stop at a safe
# point instead of killing the process mid-batch.
signal.signal(signal.SIGINT, signal_handler)
# ============================================
# 0. HELPER FUNCTIONS
# ============================================
def safe_input(prompt, default_value):
    """Prompt the user; return *default_value* on blank or EOF input.

    EOF happens when stdin is piped or closed (non-interactive runs),
    in which case we behave as if the user accepted the default.
    """
    try:
        answer = input(prompt)
    except EOFError:
        return default_value
    return default_value if not answer.strip() else answer
def batch_encode(tokenizer, text, batch_size=1000):
    """Encode *text* with *tokenizer* line by line, reporting progress.

    The corpus is split on newlines (one story/chunk per line in
    TinyStories) and processed in groups of *batch_size* lines so the
    user sees periodic progress output. Blank lines are skipped.
    Returns the concatenated list of token ids.
    """
    lines = text.split('\n')
    total_lines = len(lines)
    print(f" Encoding {total_lines} lines/chunks in batches of {batch_size}...")
    all_tokens = []
    for batch_no, start in enumerate(range(0, total_lines, batch_size)):
        # Tokenize line-by-line: safest for a simple word-level tokenizer.
        for line in lines[start:start + batch_size]:
            if line.strip():
                all_tokens.extend(tokenizer.encode(line))
        # Only print every 10th batch to keep console churn low.
        if batch_no % 10 == 0:
            print(f" Processed line {min(start + batch_size, total_lines)}/{total_lines}...", end='\r')
    print(f" Finished encoding. Total tokens: {len(all_tokens)}")
    return all_tokens
def build_vocab_batched(text, max_vocab_size=10000, batch_size=1000):
    """Count word frequencies over *text* and return the top words.

    Processes the corpus in *batch_size*-line groups so progress can be
    shown and the work can be abandoned cleanly when the user presses
    Ctrl+C (an empty list is returned in that case). Falls back to
    whitespace splitting when the NLTK 'punkt' data is unavailable.
    """
    lines = text.split('\n')
    total_lines = len(lines)
    print(f" Tokenizing {total_lines} lines/chunks in batches of {batch_size} to build vocab...")
    counts = Counter()
    for batch_no, start in enumerate(range(0, total_lines, batch_size)):
        if is_interrupted():
            print("\n [!] Interrupted during vocabulary building.")
            return []
        for line in lines[start:start + batch_size]:
            if not line.strip():
                continue
            try:
                counts.update(nltk.word_tokenize(line))
            except LookupError:
                # punkt data missing -- degrade to naive splitting.
                counts.update(line.split())
        if batch_no % 10 == 0:
            print(f" Processed line {min(start + batch_size, total_lines)}/{total_lines}...", end='\r')
    most_common = counts.most_common(max_vocab_size)
    vocab = [word for word, _count in most_common]
    print(f"\n Vocabulary built with {len(vocab)} words (top frequency: {most_common[0][1] if most_common else 0})")
    return vocab
# ============================================
# 1. ENHANCED TOKENIZATION (NLTK)
# ============================================
class NLTKTokenizer:
    """Word-level tokenizer using NLTK with a dynamic vocabulary.

    Token ids 0-5 are reserved for SPECIAL_TOKENS; user-supplied vocab
    words get consecutive ids starting at 6, in the order given
    (duplicates are ignored).
    """
    SPECIAL_TOKENS = {
        '<pad>': 0,
        '<unk>': 1,
        '<story>': 2,     # Story generation mode
        '<summary>': 3,   # Summarization mode
        '<instruct>': 4,  # Instruction following mode
        '<eos>': 5,       # End of sequence
    }

    def __init__(self, vocab=None):
        """Build the word<->id maps.

        Args:
            vocab: iterable of words to append after the special tokens,
                or None for a minimal special-tokens-only vocabulary.
        """
        self.special_token_list = list(self.SPECIAL_TOKENS.keys())
        # Both the default and the vocab case start from the special
        # tokens; the original duplicated this logic in two branches.
        self.word2id = self.SPECIAL_TOKENS.copy()
        if vocab is not None:
            next_id = len(self.SPECIAL_TOKENS)
            for word in vocab:
                if word not in self.word2id:
                    self.word2id[word] = next_id
                    next_id += 1
        self.id2word = {v: k for k, v in self.word2id.items()}
        self.vocab_size = len(self.word2id)

    def encode(self, text):
        """Convert text to token IDs; unknown words map to <unk>."""
        try:
            tokens = nltk.word_tokenize(text)
        except LookupError:
            # Fallback if punkt data is missing.
            tokens = text.split()
        unk_id = self.word2id['<unk>']
        return [self.word2id.get(token, unk_id) for token in tokens]

    def decode(self, ids):
        """Convert token IDs (list or tensor) back to a text string."""
        if torch.is_tensor(ids):
            ids = ids.tolist()
        words = [self.id2word.get(i, "<unk>") for i in ids]
        from nltk.tokenize.treebank import TreebankWordDetokenizer
        try:
            return TreebankWordDetokenizer().detokenize(words)
        except Exception:
            # BUGFIX: was a bare `except:` which also swallowed
            # KeyboardInterrupt/SystemExit during detokenization.
            return " ".join(words)
def build_vocab_from_text(text, max_vocab_size=10000):
    """Build a frequency-ranked vocabulary list from a text corpus.

    Tokenizes the whole corpus at once (see build_vocab_batched for the
    progress-reporting variant) and returns up to *max_vocab_size*
    words, most common first. Returns an empty list for an empty corpus.
    """
    print(" Tokenizing corpus to build vocabulary...")
    try:
        tokens = nltk.word_tokenize(text)
    except LookupError:
        print(" Warning: NLTK punkt not found, using split()")
        tokens = text.split()
    print(f" Found {len(tokens)} tokens. Counting frequencies...")
    counter = Counter(tokens)
    # Get most common words
    most_common = counter.most_common(max_vocab_size)
    vocab = [word for word, count in most_common]
    # BUGFIX: `most_common[0][1]` raised IndexError on an empty corpus;
    # the batched variant already guarded this case.
    top_freq = most_common[0][1] if most_common else 0
    print(f" Vocabulary built with {len(vocab)} words (top frequency: {top_freq})")
    return vocab
# ============================================
# 2. TRANSFORMER COMPONENTS
# ============================================
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al., 2017).

    PE[pos, 2i]   = sin(pos / 10000^(2i/d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model))
    """

    def __init__(self, d_model, max_seq_len=5000, dropout=0.1):
        """Precompute a [1, max_seq_len, d_model] encoding table."""
        super().__init__()
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        # BUGFIX: the sine half was missing entirely, leaving every even
        # dimension zero and halving the positional signal.
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Registered as a buffer: moves with .to(device), not a parameter.
        self.register_buffer('pe', pe.unsqueeze(0))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, verbose=False):
        """Add positional encodings to x: [batch_size, seq_len, d_model]."""
        return self.dropout(x + self.pe[:, :x.size(1)])
class LayerNorm(nn.Module):
    """Layer normalization with learnable per-feature scale and shift."""

    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        # gamma/beta are the learnable affine parameters (init: identity).
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps  # numerical-stability floor inside the sqrt

    def forward(self, x):
        """Normalize over the last dimension, then apply gamma/beta."""
        mu = x.mean(dim=-1, keepdim=True)
        # Biased variance (unbiased=False), matching standard LayerNorm.
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        normalized = (x - mu) / torch.sqrt(var + self.eps)
        return self.gamma * normalized + self.beta
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Projects the input into per-head query/key/value spaces, attends
    with softmax(QK^T / sqrt(d_k)), then mixes the heads back together
    with an output projection.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head feature size
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, batch, seq):
        """Reshape [batch, seq, d_model] -> [batch, heads, seq, d_k]."""
        return projected.view(batch, seq, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None, verbose=False):
        """Self-attend over x [batch, seq, d_model]; mask==0 hides positions."""
        batch, seq, _ = x.shape
        queries = self._split_heads(self.W_q(x), batch, seq)
        keys = self._split_heads(self.W_k(x), batch, seq)
        values = self._split_heads(self.W_v(x), batch, seq)
        # Scaled dot-product scores: [batch, heads, seq, seq].
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask.to(scores.device) == 0, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))
        mixed = torch.matmul(weights, values)
        # Merge heads back into [batch, seq, d_model].
        mixed = mixed.transpose(1, 2).contiguous().view(batch, seq, self.d_model)
        return self.W_o(mixed)
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x, verbose=False):
        """Apply the two-layer network independently at each position."""
        return self.fc2(self.dropout(F.relu(self.fc1(x))))
class TransformerBlockV6(nn.Module):
    """Pre-LN Transformer block: x + Attn(LN(x)), then x + FFN(LN(x)).

    Pre-layer-norm ordering (normalize before each sublayer) is the
    arrangement used by most modern LLMs; it trains more stably than
    the original post-LN layout.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.norm1 = LayerNorm(d_model)
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm2 = LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None, verbose=False):
        """Run both sublayers with residual connections."""
        # Attention sublayer (pre-normalized input, residual add).
        x = x + self.dropout(self.attention(self.norm1(x), mask, verbose=verbose))
        # Feed-forward sublayer.
        x = x + self.dropout(self.ffn(self.norm2(x), verbose=verbose))
        return x
class TransformerV6(nn.Module):
    """Decoder-only Pre-LN Transformer language model.

    Pipeline: token embedding + sinusoidal positions -> num_layers
    pre-LN blocks under a causal mask -> final LayerNorm -> linear
    head over the vocabulary.
    """

    def __init__(self, vocab_size, d_model=256, num_heads=8, num_layers=6,
                 d_ff=1024, max_seq_len=5000, dropout=0.1):
        super().__init__()
        # Hyperparameters kept as attributes so checkpoints can record them.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.d_ff = d_ff
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        self.dropout = nn.Dropout(dropout)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlockV6(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.final_norm = LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)
        self._init_parameters()

    def _init_parameters(self):
        """Xavier-initialize every parameter tensor with rank > 1."""
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def generate_causal_mask(self, seq_len, device):
        """Lower-triangular [seq_len, seq_len] mask; 0 marks future positions."""
        return torch.tril(torch.ones(seq_len, seq_len, device=device))

    def forward(self, token_ids, verbose=False):
        """Return logits [batch, seq, vocab] for token_ids [batch, seq]."""
        _, seq_len = token_ids.shape
        hidden = self.token_embedding(token_ids)
        hidden = self.positional_encoding(hidden, verbose=verbose)
        hidden = self.dropout(hidden)
        causal_mask = self.generate_causal_mask(seq_len, token_ids.device)
        for block in self.transformer_blocks:
            hidden = block(hidden, mask=causal_mask, verbose=verbose)
        return self.lm_head(self.final_norm(hidden))

    def get_num_params(self):
        """Total number of scalar parameters in the model."""
        return sum(p.numel() for p in self.parameters())

    def predict_next_token_topp(self, token_ids, p=0.9, temperature=1.0, verbose=False):
        """Sample one next token with nucleus (top-p) sampling.

        Returns (next_token [batch, 1], probs [batch, vocab]) where probs
        is the renormalized distribution restricted to the nucleus.
        """
        with torch.no_grad():
            logits = self.forward(token_ids, verbose=verbose)
            last_logits = logits[:, -1, :]
            # Temperature scaling (clamped to avoid division by zero).
            last_logits = last_logits / max(temperature, 1e-5)
            # Sort descending so a prefix of the list forms the nucleus.
            sorted_logits, sorted_indices = torch.sort(last_logits, descending=True, dim=-1)
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
            # Mask tokens past the p-mass cutoff; shifting the mask right
            # keeps the first token past the threshold, so the nucleus is
            # never empty.
            cutoff_mask = cumulative_probs > p
            cutoff_mask[..., 1:] = cutoff_mask[..., :-1].clone()
            cutoff_mask[..., 0] = False
            sorted_logits[cutoff_mask] = float('-inf')
            # Scatter filtered logits back into vocabulary order.
            unsorted_logits = torch.full_like(last_logits, float('-inf'))
            unsorted_logits.scatter_(1, sorted_indices, sorted_logits)
            probs = F.softmax(unsorted_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        return next_token, probs
# ============================================
# 3. DATA LOADING & TRAINING
# ============================================
def load_tiny_stories(percent=1.0):
    """Stream roughly *percent* percent of the TinyStories train split.

    Uses a streaming HuggingFace dataset so only the requested fraction
    is downloaded; the full split is approximated at 2M stories.
    Returns all fetched stories joined with newlines.
    """
    from datasets import load_dataset
    print(f"Loading {percent}% of TinyStories dataset...")
    dataset = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
    total_count = 2000000  # approximate size of the full split
    target_count = max(1, int(total_count * (percent / 100.0)))
    print(f"Fetching approx {target_count} examples...")
    stories = []
    for i, entry in enumerate(dataset):
        if i >= target_count:
            break
        stories.append(entry['text'])
        if i % 100 == 0:
            print(f" Fetched {i} stories...", end='\r')
    print(f"\nLoaded {len(stories)} stories.")
    return "\n".join(stories)
def train_epoch(model, tokenizer, data_ids, batch_size=32, seq_len=64, optimizer=None, criterion=None):
    """Run one epoch of next-token training over *data_ids*.

    Slices the flat token-id list into non-overlapping windows of
    seq_len + 1 tokens (inputs = chunk[:-1], targets = chunk[1:]) and
    performs one gradient step per batch. Stops early when
    is_interrupted() is set. Returns the mean batch loss (0.0 when no
    batch ran).
    """
    model.train()
    device = model.token_embedding.weight.device  # hoisted out of the loop
    total_loss = 0
    num_batches = 0
    loss = None
    n_tokens = len(data_ids)
    # Simple sliding window batching
    for i in range(0, n_tokens - batch_size * (seq_len + 1), batch_size * seq_len):
        if is_interrupted():
            break
        inputs_list = []
        targets_list = []
        for b in range(batch_size):
            start_idx = i + b * seq_len
            # Stop assembling the batch when a window would run past the data.
            if start_idx + seq_len + 1 >= n_tokens:
                break
            chunk = data_ids[start_idx : start_idx + seq_len + 1]
            inputs_list.append(chunk[:-1])
            targets_list.append(chunk[1:])
        if len(inputs_list) == 0:
            break
        inputs = torch.tensor(inputs_list).to(device)
        targets = torch.tensor(targets_list).to(device)
        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits.view(-1, model.vocab_size), targets.view(-1))
        loss.backward()
        # Clip gradients to stabilize training.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
        if num_batches % 10 == 0:
            print(f" Batch {num_batches}, Loss: {loss.item():.4f}", end='\r')
    # BUGFIX: the final report crashed with NameError when zero batches
    # ran (tiny dataset or immediate interrupt) because `loss` was unbound.
    if loss is not None:
        print(f" Batch {num_batches}, Final Loss: {loss.item():.4f}")
    return total_loss / max(1, num_batches)
def generate_text(model, tokenizer, prompt="Once upon a time", max_len=100, temperature=0.8):
    """Autoregressively sample up to *max_len* tokens after *prompt*.

    Uses plain temperature sampling over the full vocabulary and stops
    early when the <eos> token is produced. Returns the decoded text,
    prompt included.
    """
    model.eval()
    device = model.token_embedding.weight.device
    input_ids = torch.tensor([tokenizer.encode(prompt)]).to(device)
    print(f"Generating (Prompt: '{prompt}')...")
    eos_id = tokenizer.SPECIAL_TOKENS['<eos>']
    for _ in range(max_len):
        with torch.no_grad():
            logits = model(input_ids)
        scaled = logits[:, -1, :] / temperature
        next_token = torch.multinomial(F.softmax(scaled, dim=-1), num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=1)
        if next_token.item() == eos_id:
            break
    return tokenizer.decode(input_ids[0])
# ============================================
# 4. CHECKPOINTING
# ============================================
def save_checkpoint(model, tokenizer, path):
    """Serialize model weights, config, and vocabulary to *path*.

    Only non-special vocabulary words are stored, in id order. Because
    special tokens occupy ids 0-5 and NLTKTokenizer(vocab=...) re-adds
    them on load, the original id assignment is reproduced exactly.
    """
    specials = set(tokenizer.SPECIAL_TOKENS.keys())
    # Walk ids in ascending order so the saved list preserves assignment.
    saved_vocab = [
        tokenizer.id2word[i]
        for i in range(tokenizer.vocab_size)
        if tokenizer.id2word[i] not in specials
    ]
    checkpoint = {
        'model_state': model.state_dict(),
        'vocab_words': saved_vocab,
        'config': {
            'd_model': model.d_model,
            'num_heads': model.num_heads,
            'num_layers': model.num_layers,
            'd_ff': model.d_ff,
            'vocab_size': model.vocab_size,
        }
    }
    torch.save(checkpoint, path)
    print(f"Model and vocabulary saved to {path}")
def load_checkpoint(path):
    """Load a checkpoint written by save_checkpoint.

    Returns (model, tokenizer), or (None, None) when *path* does not
    exist. The model is returned on the CPU in eval mode; the caller is
    responsible for moving it to the desired device.
    """
    if not os.path.exists(path):
        return None, None
    print(f"Loading checkpoint from {path}...")
    # BUGFIX: map_location lets a checkpoint saved on a CUDA machine
    # load on a CPU-only one. NOTE(security): torch.load unpickles the
    # file -- only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location='cpu')
    vocab_words = checkpoint['vocab_words']
    config = checkpoint['config']
    # Reconstruct tokenizer; special tokens are re-added automatically.
    tokenizer = NLTKTokenizer(vocab=vocab_words)
    # Initialize model from the saved hyperparameters.
    model = TransformerV6(
        vocab_size=tokenizer.vocab_size,
        d_model=config['d_model'],
        num_heads=config['num_heads'],
        num_layers=config['num_layers'],
        d_ff=config.get('d_ff', 1024),  # default for older checkpoints
    )
    model.load_state_dict(checkpoint['model_state'])
    model.eval()
    return model, tokenizer
def select_model_file(default_path=None):
    """Interactively pick a saved v6 model file, newest first.

    Returns the chosen path, or *default_path* when no files exist or
    the user cancels / enters an out-of-range or non-numeric choice.
    """
    os.makedirs("models_cache", exist_ok=True)
    files = glob.glob("models_cache/model_v6*.pth")
    files.sort(key=os.path.getmtime, reverse=True)
    if not files:
        print("No existing v6 model files found.")
        return default_path
    print("\nAvailable Models:")
    for i, f in enumerate(files):
        size_mb = os.path.getsize(f) / (1024 * 1024)
        print(f"{i+1}. {f} ({size_mb:.2f} MB)")
    # The extra last entry lets the user fall back to a fresh filename.
    print(f"{len(files)+1}. Cancel / Use New Name ({default_path})")
    choice = safe_input(f"Select model (1-{len(files)+1}, default 1): ", "1")
    try:
        idx = int(choice) - 1
    except ValueError:
        return default_path
    return files[idx] if 0 <= idx < len(files) else default_path
# ============================================
# 5. MAIN EXECUTION MENU
# ============================================
def main():
    """Interactive entry point: train, resume, generate, or exit via a menu."""
    print("Transformer Architecture v6 (NLTK Tokenization)")
    print("=============================================")
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    os.makedirs("models_cache", exist_ok=True)
    # Default save path; menu option 2 repoints this at the loaded file.
    MODEL_PATH = f"models_cache/model_v6_{timestamp}.pth"
    # State variables shared across menu iterations
    model = None
    tokenizer = None
    while True:
        # Clear any Ctrl+C flag left over from the previous action.
        reset_interrupt()
        print("\n\nMAIN MENU")
        print("1. Train New Model (TinyStories)")
        print("2. Load Model & Resume Training")
        print("3. Generate Text")
        print("4. Exit")
        choice = safe_input("\nEnter choice (1-4): ", "4")
        if choice == '1':
            # --- TRAIN NEW ---
            percent_str = safe_input("Enter percentage of TinyStories to load (0.01-100, default 0.5): ", "0.5")
            try:
                percent = float(percent_str)
            # NOTE(review): bare except -- consider narrowing to ValueError.
            except:
                percent = 0.5
            print("\n[1] Loading Data...")
            raw_text = load_tiny_stories(percent=percent)
            print("\n[2] Building Vocabulary...")
            vocab_size_str = safe_input("Max vocab size (default 10000): ", "10000")
            # Use batched vocab builder (shows progress, honors Ctrl+C).
            vocab_words = build_vocab_batched(raw_text, max_vocab_size=int(vocab_size_str), batch_size=1000)
            if not vocab_words:
                print("Vocabulary is empty or interrupted. Exiting training setup.")
                continue
            tokenizer = NLTKTokenizer(vocab=vocab_words)
            print(f"Tokenizer ready. Vocab size: {tokenizer.vocab_size}")
            print("\n[3] Initializing Model...")
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            print(f"Training on: {device}")
            model = TransformerV6(
                vocab_size=tokenizer.vocab_size,
                d_model=256,
                num_heads=8,
                num_layers=4,
                d_ff=1024
            ).to(device)
            print(f"Params: {model.get_num_params():,}")
            print("\n[4] Encoding Data...")
            # Use batch encoding for progress
            train_ids = batch_encode(tokenizer, raw_text, batch_size=1000)
            print("\n[5] Training Loop (Press Ctrl+C to stop & save)")
            optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
            # <pad> positions contribute no loss.
            criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.SPECIAL_TOKENS['<pad>'])
            epochs = int(safe_input("Epochs (default 5): ", "5"))
            try:
                for epoch in range(1, epochs + 1):
                    if is_interrupted(): break
                    print(f"\n--- Epoch {epoch}/{epochs} ---")
                    avg_loss = train_epoch(model, tokenizer, train_ids, batch_size=32, seq_len=64,
                                           optimizer=optimizer, criterion=criterion)
                    print(f"Avg Loss: {avg_loss:.4f}")
                    # Periodic checkpoint every other epoch.
                    if epoch % 2 == 0:
                        save_checkpoint(model, tokenizer, MODEL_PATH)
            except KeyboardInterrupt:
                print("\nInterrupted.")
            if is_interrupted():
                print("\n[!] Loop interrupted. Saving checkpoint...")
            # Save at the end
            save_checkpoint(model, tokenizer, MODEL_PATH)
        elif choice == '2':
            # --- LOAD & RESUME ---
            path = select_model_file()
            if not path:
                continue
            model, tokenizer = load_checkpoint(path)
            if not model:
                print("Failed to load model.")
                continue
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            model = model.to(device)
            MODEL_PATH = path # Continue saving to same file
            print("\nNote: To resume training, we need data.")
            percent_str = safe_input("Enter percentage of TinyStories to load for training (default 0.5): ", "0.5")
            raw_text = load_tiny_stories(percent=float(percent_str))
            print("Encoding data with loaded tokenizer...")
            train_ids = batch_encode(tokenizer, raw_text, batch_size=1000)
            print("\nResuming Training (Press Ctrl+C to stop & save)")
            # NOTE(review): optimizer state is not persisted in checkpoints,
            # so resuming restarts Adam's moment estimates from scratch.
            optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
            criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.SPECIAL_TOKENS['<pad>'])
            epochs = int(safe_input("Additional Epochs (default 5): ", "5"))
            try:
                for epoch in range(1, epochs + 1):
                    if is_interrupted(): break
                    print(f"\n--- Epoch {epoch}/{epochs} ---")
                    avg_loss = train_epoch(model, tokenizer, train_ids, batch_size=32, seq_len=64,
                                           optimizer=optimizer, criterion=criterion)
                    print(f"Avg Loss: {avg_loss:.4f}")
            except KeyboardInterrupt:
                print("\nInterrupted.")
            if is_interrupted():
                print("\n[!] Loop interrupted. Saving checkpoint...")
            save_checkpoint(model, tokenizer, MODEL_PATH)
        elif choice == '3':
            # --- GENERATE ---
            # Lazily load a model if none is in memory yet.
            if model is None:
                path = select_model_file()
                if path:
                    model, tokenizer = load_checkpoint(path)
                    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
                    # NOTE(review): load_checkpoint can return (None, None)
                    # for a vanished file; model.to would then raise.
                    model = model.to(device)
            if model is None:
                print("No model loaded. Please Train or Load first.")
                continue
            print("\n--- Text Generation ---")
            print("1. Custom Prompt")
            print("2. Preset: 'Once upon a time'")
            print("3. Preset: 'The little dog'")
            p_choice = safe_input("Choice (1-3): ", "2")
            if p_choice == '1':
                prompt = safe_input("Enter prompt: ", "Once upon a time")
            elif p_choice == '3':
                prompt = "The little dog"
            else:
                prompt = "Once upon a time"
            length = int(safe_input("Length (default 100): ", "100"))
            temp = float(safe_input("Temperature (default 0.8): ", "0.8"))
            generated = generate_text(model, tokenizer, prompt, max_len=length, temperature=temp)
            print(f"\n[OUTPUT]\n{generated}\n")
        elif choice == '4':
            print("Exiting.")
            break
        else:
            print("Invalid choice.")
# Run the interactive menu only when executed as a script.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment