Skip to content

Instantly share code, notes, and snippets.

@amitpuri
Last active January 3, 2026 06:11
Show Gist options
  • Select an option

  • Save amitpuri/3390df3f802729975990520ff531a868 to your computer and use it in GitHub Desktop.

Select an option

Save amitpuri/3390df3f802729975990520ff531a868 to your computer and use it in GitHub Desktop.
"""
Transformer Architecture v4
==================================================================
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import asyncio
import signal
import sys
import time
import glob
from typing import Tuple, List, Dict
# Global flag for graceful shutdown
# Set True by signal_handler on the first SIGINT; polled via is_interrupted()
# inside long-running loops and cleared with reset_interrupt().
_interrupted = False
def signal_handler(signum, frame):
    """Flag a graceful-shutdown request on the first SIGINT.

    Sets the module-level ``_interrupted`` flag so long-running loops can
    finish their current step cleanly, then restores the default handler so
    a second Ctrl+C force-exits the process.
    """
    global _interrupted
    _interrupted = True
    print("\n\n[!] Interrupt received. Finishing current operation gracefully...")
    print(" Press Ctrl+C again to force exit.")
    # A second Ctrl+C now raises the default KeyboardInterrupt immediately.
    signal.signal(signal.SIGINT, signal.SIG_DFL) # Reset to default on next interrupt
def reset_interrupt():
    """Clear any pending interrupt flag (call before starting a new operation)."""
    global _interrupted
    _interrupted = False
def is_interrupted():
    """Return True if a SIGINT has been flagged and not yet reset."""
    return _interrupted
# Install signal handler
# Registered at import time so Ctrl+C is intercepted from the very first
# operation the script performs.
signal.signal(signal.SIGINT, signal_handler)
# ============================================
# 1. ENHANCED TOKENIZATION
# ============================================
class CharTokenizer:
    """Character-level tokenizer with special tokens for instructions"""
    # Special tokens always occupy the lowest IDs, in this declaration order.
    SPECIAL_TOKENS = {
        '<pad>': 0,
        '<unk>': 1,
        '<story>': 2,     # Story generation mode
        '<summary>': 3,   # Summarization mode
        '<dialog>': 4,    # Dialogue mode
        '<instruct>': 5,  # Instruction following mode
        '<eos>': 6,       # End of sequence
    }

    def __init__(self, text=None, chars=None):
        """Build the vocabulary from raw text or an explicit char list.

        Args:
            text: corpus to derive the character set from (sorted, deduped).
            chars: explicit character list (e.g. restored from a checkpoint).
        Raises:
            ValueError: if neither `text` nor `chars` is given.
        """
        if chars is not None:
            base_chars = chars
        elif text is not None:
            base_chars = sorted(set(text))
        else:
            raise ValueError("Must provide either text or chars")
        self.special_token_list = list(self.SPECIAL_TOKENS.keys())
        # Prepend specials, filtering duplicates when `chars` already contains them
        # (the checkpoint-restore path passes the full saved vocabulary back in).
        specials = set(self.special_token_list)
        self.chars = self.special_token_list + [c for c in base_chars if c not in specials]
        self.vocab_size = len(self.chars)
        self.char2id = {ch: i for i, ch in enumerate(self.chars)}
        self.id2char = dict(enumerate(self.chars))

    def encode(self, text):
        """Map each character to its ID; unknown characters become <unk>."""
        unk = self.char2id['<unk>']
        return [self.char2id.get(c, unk) for c in text]

    def decode(self, ids):
        """Map token IDs back to a string; unknown IDs render as '?'."""
        if torch.is_tensor(ids):
            ids = ids.tolist()
        return "".join(self.id2char.get(i, "?") for i in ids)

    def encode_with_task(self, text, task_token='<story>'):
        """Encode `text` prefixed by a task token (falls back to <story>)."""
        prefix = self.char2id.get(task_token, self.char2id['<story>'])
        return [prefix] + self.encode(text)

    def get_special_token_id(self, token_name):
        """Return the ID of a special token, or the <unk> ID if absent."""
        return self.char2id.get(token_name, self.char2id['<unk>'])
# ============================================
# 2. POSITIONAL ENCODING
# ============================================
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (Vaswani et al., 2017).

    BUG FIX: the original only filled the odd (cosine) dimensions; the even
    (sine) dimensions were left at zero, so positions were only half-encoded.
    """
    def __init__(self, d_model, max_seq_len=5000, dropout=0.1):
        """Sinusoidal positional encoding.

        Args:
            d_model: embedding dimension (assumed even — TODO confirm callers).
            max_seq_len: maximum sequence length to precompute.
            dropout: dropout applied after adding the positions.
        """
        super().__init__()
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        # Even dims carry the sine component (this line was missing), odd dims cosine.
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Buffer: moves with .to(device)/state_dict, but takes no gradients.
        self.register_buffer('pe', pe.unsqueeze(0))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, verbose=False):
        """x: [batch_size, seq_len, d_model]"""
        if verbose:
            print("\n [Positional Encoding] Adding sinusoidal position info to embeddings")
            print(" - Intuition: Transformer sees all words at once; this injects order info (like page numbers)")
            print(f" - Input shape: {x.shape}")
            print(f" - Formula: PE(pos, 2i) = sin(pos/10000^(2i/d_model))")
            print(f" - Formula: PE(pos, 2i+1) = cos(pos/10000^(2i/d_model))")
            print(f" - Sample PE (first 8 dims for pos 0): {self.pe[0, :8].tolist()}")
        # Add only the first seq_len precomputed positions, then apply dropout.
        return self.dropout(x + self.pe[:, :x.size(1)])
# ============================================
# 3. LAYER NORMALIZATION
# ============================================
class LayerNorm(nn.Module):
    """Layer normalization over the last (embedding) dimension with a
    learnable affine transform (gamma scale, beta shift)."""
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps  # numerical-stability floor inside the sqrt

    def forward(self, x):
        """Normalize each position to zero mean / unit variance, then rescale."""
        mu = x.mean(dim=-1, keepdim=True)
        # Biased (population) variance, matching the standard LayerNorm definition.
        sigma2 = x.var(dim=-1, keepdim=True, unbiased=False)
        normed = (x - mu) / torch.sqrt(sigma2 + self.eps)
        return self.gamma * normed + self.beta
# ============================================
# 4. MULTI-HEAD SELF-ATTENTION WITH CAUSAL MASK
# ============================================
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with optional masking."""
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        # Heads must evenly partition the model dimension.
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head feature size
        # Created in this exact order so seeded initialization is reproducible.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None, verbose=False):
        """Multi-head self-attention with optional causal masking"""
        batch_size, seq_len, _ = x.shape
        if verbose:
            print(f"\n [Multi-Head Attention] Splitting into {self.num_heads} heads (d_model={self.d_model} -> d_k={self.d_k})")
            print(" - Intuition: 'Looking around' the sentence to gather context (e.g., matching 'it' to 'dog')")
            print(" - Step 1: Linear projections for Queries (Q), Keys (K), Values (V)")

        def split_heads(projected):
            # [B, S, d_model] -> [B, heads, S, d_k]
            return projected.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        Q = split_heads(self.W_q(x))
        K = split_heads(self.W_k(x))
        V = split_heads(self.W_v(x))
        if verbose:
            print(f" - Q, K, V shapes: {Q.shape} (batch, heads, seq, d_k)")
            print(f" - Sample Q (head 0, first token, first 4 dims): {Q[0, 0, 0, :4].tolist()}")
            print(" - Step 2: Scaled Dot-Product Attention (Score = Q · K^T / √d_k)")
        # Similarity of every query against every key, scaled for stable gradients.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            if verbose:
                print(" - Masked Attention: Hiding future tokens (causal mask)")
                print(f" - Scores before mask (head 0, row 0, first 4 cols): {scores[0, 0, 0, :4].tolist()}")
            mask = mask.to(scores.device)
            # Zero entries in the mask become -inf so softmax assigns them no weight.
            scores = scores.masked_fill(mask == 0, float('-inf'))
            if verbose:
                print(f" - Scores after mask (head 0, row 0, first 4 cols): {scores[0, 0, 0, :4].tolist()}")
        if verbose:
            print(" - Step 4: Softmax -> Attention Weights (Probability distribution)")
            print(" - Intuition: Deciding 'how much' to focus on each word (sum of weights = 1.0)")
        attention_weights = F.softmax(scores, dim=-1)
        if verbose:
            print(f" - Sample Weights (head 0, row 0, first 4 cols): {attention_weights[0, 0, 0, :4].tolist()}")
        attention_weights = self.dropout(attention_weights)
        if verbose:
            print(" - Step 5: Weighted sum of Values (Context aggregation)")
            print(f" - Formula: Context = Sum(Weight_i * Value_i)")
        # Weighted sum, then merge heads back: [B, heads, S, d_k] -> [B, S, d_model].
        context = torch.matmul(attention_weights, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        if verbose:
            print(" - Concatenating heads and final linear projection")
        return self.W_o(context)
# ============================================
# 5. FEED-FORWARD NETWORK
# ============================================
class FeedForward(nn.Module):
    """Position-wise MLP: expand to d_ff, ReLU, dropout, project back."""
    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        # fc1 then fc2 — creation order preserved for reproducible seeded init.
        self.fc1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x, verbose=False):
        """Apply ReLU(x @ W1 + b1) @ W2 + b2 independently at each position."""
        if verbose:
            print(f"\n [Feed-Forward] Expansion (d_model={x.shape[-1]} -> d_ff={self.fc1.out_features} -> d_model={x.shape[-1]})")
            print(" - Intuition: The 'brain' processing the gathered context to extract complex features")
            print(f" - Formula: ReLU(xW_1 + b_1)W_2 + b_2")
        activated = F.relu(self.fc1(x))
        if verbose:
            print(f" - Sample intermediate (after ReLU, first 4 dims): {activated[0, 0, :4].tolist()}")
        return self.fc2(self.dropout(activated))
# ============================================
# 6. TRANSFORMER BLOCKS (POST-LN & PRE-LN)
# ============================================
class TransformerBlockPostLN(nn.Module):
    """Post-LN: Attention -> Residual -> LN -> FFN -> Residual -> LN"""
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # Sub-modules created in the original order for reproducible seeded init.
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm2 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None, verbose=False):
        """Run one post-norm transformer layer over x."""
        if verbose:
            print("\n[Transformer Block (Post-LN)]")
        attended = self.attention(x, mask, verbose=verbose)
        if verbose:
            print(" [Residual + Norm] Adding attention output to input and normalizing")
        # Post-LN ordering: residual add first, then normalize.
        x = self.norm1(x + self.dropout(attended))
        if verbose:
            print(" [Feed-Forward Network] Point-wise expansion and compression")
        transformed = self.ffn(x, verbose=verbose)
        if verbose:
            print(" [Residual + Norm] Adding FFN output and normalizing")
        return self.norm2(x + self.dropout(transformed))
class TransformerBlockPreLN(nn.Module):
    """Pre-LN: LN -> Attention -> Residual -> LN -> FFN -> Residual"""
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # Sub-modules created in the original order for reproducible seeded init.
        self.norm1 = LayerNorm(d_model)
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm2 = LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None, verbose=False):
        """Run one pre-norm transformer layer over x."""
        if verbose:
            print("\n[Transformer Block (Pre-LN)]")
            print(" [Layer Norm] Normalizing input before attention (Pre-LN benefit: stable gradients)")
        # Pre-LN ordering: normalize the branch input; residual path stays raw.
        attended = self.attention(self.norm1(x), mask, verbose=verbose)
        if verbose:
            print(" [Residual] Adding attention output to original input path")
        x = x + self.dropout(attended)
        if verbose:
            print(" [Layer Norm] Normalizing before FFN")
        normed = self.norm2(x)
        if verbose:
            print(" [Feed-Forward] Processing features")
        transformed = self.ffn(normed, verbose=verbose)
        if verbose:
            print(" [Residual] Final skip connection")
        return x + self.dropout(transformed)
# ============================================
# 7. SAMPLING METHODS
# ============================================
def top_k_sampling(logits, k=10, temperature=1.0):
    """Sample the next token from the k highest-scoring candidates.

    Args:
        logits: [batch, vocab] raw next-token scores.
        k: number of candidates to keep.
        temperature: softmax temperature (clamped away from zero).
    Returns:
        (sampled token ids [batch, 1], full probability distribution)
    """
    scaled = logits / max(temperature, 1e-5)
    top_vals, top_idx = torch.topk(scaled, k, dim=-1)
    # Everything outside the top-k becomes -inf so softmax gives it zero mass.
    filtered = torch.full_like(scaled, float('-inf'))
    filtered.scatter_(1, top_idx, top_vals)
    probs = F.softmax(filtered, dim=-1)
    return torch.multinomial(probs, num_samples=1), probs
def top_p_sampling(logits, p=0.9, temperature=1.0):
    """Nucleus sampling: keep the smallest candidate set with cumulative prob > p.

    Args:
        logits: [batch, vocab] raw next-token scores.
        p: cumulative probability threshold.
        temperature: softmax temperature (clamped away from zero).
    Returns:
        (sampled token ids [batch, 1], full probability distribution)
    """
    scaled = logits / max(temperature, 1e-5)
    sorted_logits, sorted_indices = torch.sort(scaled, descending=True, dim=-1)
    cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    # Drop tokens past the threshold, shifted right by one position so the
    # first token crossing p is still kept (the set must *exceed* p).
    drop = cumulative > p
    drop[..., 1:] = drop[..., :-1].clone()
    drop[..., 0] = False
    sorted_logits[drop] = float('-inf')
    # Scatter the filtered scores back into vocabulary order.
    restored = torch.full_like(scaled, float('-inf'))
    restored.scatter_(1, sorted_indices, sorted_logits)
    probs = F.softmax(restored, dim=-1)
    return torch.multinomial(probs, num_samples=1), probs
# ============================================
# 8. TRANSFORMER MODEL
# ============================================
class TransformerV4(nn.Module):
    """Decoder-only character transformer with switchable Pre-/Post-LN blocks."""
    def __init__(self, vocab_size, d_model=256, num_heads=8, num_layers=4,
                 d_ff=1024, max_seq_len=5000, dropout=0.1, use_pre_ln=False):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.d_ff = d_ff
        # Sub-modules created in the original order for reproducible seeded init.
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        self.dropout = nn.Dropout(dropout)
        block_cls = TransformerBlockPreLN if use_pre_ln else TransformerBlockPostLN
        self.transformer_blocks = nn.ModuleList(
            [block_cls(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        # Only Pre-LN stacks need a closing LayerNorm; Post-LN blocks end normalized.
        self.final_norm = LayerNorm(d_model) if use_pre_ln else None
        self.lm_head = nn.Linear(d_model, vocab_size)
        self._init_parameters()

    def _init_parameters(self):
        """Initialize parameters with Xavier uniform"""
        # Only weight matrices (dim > 1); biases and norm vectors keep defaults.
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def generate_causal_mask(self, seq_len, device):
        """Create lower triangular causal mask"""
        # Row i may attend to columns 0..i (1 = visible, 0 = hidden).
        return torch.tril(torch.ones(seq_len, seq_len, device=device))

    def forward(self, token_ids, verbose=False):
        """Map [batch, seq] token IDs to [batch, seq, vocab] next-token logits."""
        batch_size, seq_len = token_ids.shape
        if verbose:
            print(f"\n--- TRANSFORMER FORWARD PASS (Batch: {batch_size}, Seq Len: {seq_len}) ---")
            print("1. [Embedding] Converting token IDs to vectors")
            print(" - Intuition: Looking up the 'meaning' vector for each specific word ID")
        hidden = self.token_embedding(token_ids)
        hidden = self.positional_encoding(hidden, verbose=verbose)
        hidden = self.dropout(hidden)
        attn_mask = self.generate_causal_mask(seq_len, token_ids.device)
        if verbose:
            print(f"2. [Stacking] Passing through {self.num_layers} Transformer blocks")
        for i, block in enumerate(self.transformer_blocks):
            if verbose:
                print(f"\n--- Block {i+1}/{self.num_layers} ---")
            hidden = block(hidden, mask=attn_mask, verbose=verbose)
        if self.final_norm is not None:
            if verbose:
                print("\n3. [Final Norm] Stabilizing output features")
            hidden = self.final_norm(hidden)
        if verbose:
            print("\n4. [Output Head] Projecting to vocabulary size (Logits)")
        logits = self.lm_head(hidden)
        if verbose:
            print(f" Output shape: {logits.shape} (Batch, Seq, Vocab)")
            print("--- END FORWARD PASS ---\n")
        return logits

    def predict_next_token_greedy(self, token_ids, verbose=False):
        """Greedy decoding"""
        with torch.no_grad():
            last_logits = self.forward(token_ids, verbose=verbose)[:, -1, :]
            probs = F.softmax(last_logits, dim=-1)
            # argmax over probs equals argmax over logits; probs also returned.
            return torch.argmax(probs, dim=-1), probs

    def predict_next_token_topk(self, token_ids, k=10, temperature=1.0, verbose=False):
        """Top-k sampling"""
        with torch.no_grad():
            last_logits = self.forward(token_ids, verbose=verbose)[:, -1, :]
            next_token_ids, probs = top_k_sampling(last_logits, k=k, temperature=temperature)
            return next_token_ids.squeeze(-1), probs

    def predict_next_token_topp(self, token_ids, p=0.9, temperature=1.0, verbose=False):
        """Top-p sampling"""
        with torch.no_grad():
            last_logits = self.forward(token_ids, verbose=verbose)[:, -1, :]
            next_token_ids, probs = top_p_sampling(last_logits, p=p, temperature=temperature)
            return next_token_ids.squeeze(-1), probs
# ============================================
# 10. TRAINING UTILITIES
# ============================================
def train_model_basic(model, tokenizer, text, epochs=1000, batch_size=4, seq_len=64, lr=None, verbose=False):
    """Train with the "Attention Is All You Need" recipe: Adam(0.9, 0.98),
    label smoothing 0.1 and the Noam warmup schedule.

    Fixes vs. original: honors Ctrl+C via is_interrupted() — callers advertise
    "[Press Ctrl+C to stop training]" but the loop never checked the flag —
    and drops the unused `first_step_verbose` local.

    Args:
        model: TransformerV4 (or compatible) language model.
        tokenizer: object providing .encode(text) -> list[int].
        text: raw training corpus.
        epochs: number of optimization steps (one batch per step).
        batch_size: sequences per batch.
        seq_len: tokens per training window.
        lr: unused; the rate is fully determined by the warmup schedule
            (parameter kept for interface compatibility).
        verbose: kept for interface compatibility with callers.
    """
    # 1. Adam with specific betas and eps. Base lr=1.0: the LambdaLR
    #    multiplier below supplies the actual learning rate.
    optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
    # 2. Label smoothing
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    # 3. Warmup scheduler
    d_model = model.d_model
    warmup_steps = 4000

    # lr = d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5))
    def lr_lambda(step):
        step = step + 1  # 1-indexed for formula
        return (d_model ** -0.5) * min(step ** -0.5, step * (warmup_steps ** -1.5))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    ids = tokenizer.encode(text)
    model.train()
    print(f"Training with: Adam(betas=(0.9, 0.98)), LabelSmoothing=0.1, WarmupSteps={warmup_steps}")
    for epoch in range(epochs):
        # Honor the advertised graceful Ctrl+C: stop between steps.
        if is_interrupted():
            print(f"\n[!] Training interrupted at epoch {epoch}.")
            break
        batch_inputs = []
        batch_targets = []
        for _ in range(batch_size):
            if len(ids) <= seq_len + 1:
                # Corpus shorter than one window: use everything available.
                start_idx = 0
                curr_seq_len = len(ids) - 1
            else:
                start_idx = torch.randint(0, len(ids) - seq_len - 1, (1,)).item()
                curr_seq_len = seq_len
            batch_inputs.append(ids[start_idx : start_idx + curr_seq_len])
            # Targets are the inputs shifted one token to the right.
            batch_targets.append(ids[start_idx + 1 : start_idx + curr_seq_len + 1])
        # Right-pad with 0 (<pad>) so ragged sequences stack into one tensor.
        max_len = max(len(inp) for inp in batch_inputs)
        inputs = torch.zeros(len(batch_inputs), max_len, dtype=torch.long)
        targets = torch.zeros(len(batch_targets), max_len, dtype=torch.long)
        for i, (inp, tgt) in enumerate(zip(batch_inputs, batch_targets)):
            inputs[i, :len(inp)] = torch.tensor(inp)
            targets[i, :len(tgt)] = torch.tensor(tgt)
        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits.view(-1, model.vocab_size), targets.view(-1))
        loss.backward()
        # Clip to guard against the occasional exploding gradient step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        if epoch % 100 == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}, LR: {current_lr:.2e}")
# ============================================
# 11. DATASET SELECTION MENU
# ============================================
def safe_input(prompt, default=""):
    """Prompt the user; return `default` on Ctrl+C or EOF instead of raising.

    Keeps the interactive menus alive when input is cancelled or the stream ends.
    """
    try:
        return input(prompt)
    except (KeyboardInterrupt, EOFError):
        print("\n[Input cancelled]")
        return default
# Load datasets from config file
# Path to datasets_config.json, resolved next to this script (assumes the
# module is loaded from a file so __file__ is defined).
DATASETS_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "datasets_config.json")
def load_datasets_config():
    """Load datasets configuration from JSON file.

    Returns the "datasets" list from DATASETS_CONFIG_PATH, or [] when the
    file is missing or unreadable (a warning is printed either way).
    """
    import json
    if not os.path.exists(DATASETS_CONFIG_PATH):
        print(f"Warning: Config file not found at {DATASETS_CONFIG_PATH}")
        return []
    try:
        with open(DATASETS_CONFIG_PATH, 'r', encoding='utf-8') as f:
            return json.load(f).get("datasets", [])
    except Exception as e:
        # Best-effort: a corrupt config degrades to "no datasets" rather than a crash.
        print(f"Error loading config: {e}")
        return []
# Load datasets at module import
# DATASETS: list of dataset-descriptor dicts; entries are read elsewhere via
# 'name', 'description', 'text_field' and 'category' keys.
DATASETS = load_datasets_config()
def get_datasets_by_category(category):
    """Return the subset of DATASETS whose "category" field equals `category`."""
    return list(filter(lambda ds: ds.get("category") == category, DATASETS))
def select_dataset():
    """
    Interactive dataset selection menu with graceful interrupt handling
    Returns: (dataset_name, text_field, dataset_type)
    """
    if not DATASETS:
        print("No datasets found. Check datasets_config.json")
        return None, None, None
    print("\n" + "=" * 60)
    print("DATASET SELECTION")
    print("=" * 60)
    for i, ds in enumerate(DATASETS, 1):
        category_tag = f"[{ds.get('category', 'unknown')}]"
        print(f"{i}. {ds['name']} {category_tag}")
        print(f" {ds['description']}")
    print(f"{len(DATASETS) + 1}. Cancel")
    ds_choice = safe_input(f"\nSelect dataset (1-{len(DATASETS) + 1}): ", "")
    try:
        idx = int(ds_choice) - 1
    except ValueError:
        idx = -1  # non-numeric input falls through to "Invalid selection"
    if idx == len(DATASETS):  # Cancel option
        return None, None, None
    if 0 <= idx < len(DATASETS):
        chosen = DATASETS[idx]
        return chosen['name'], chosen['text_field'], chosen.get('category', 'unknown')
    print("Invalid selection")
    return None, None, None
# Local dataset cache directory
# Downloaded datasets are stored here as JSON, one file per dataset
# (see get_local_dataset_path / download_dataset_async).
DATASET_CACHE_DIR = os.path.join(os.path.dirname(__file__), "datasets_cache")
def get_local_dataset_path(dataset_name):
    """Map a dataset name (e.g. 'org/name') to its JSON cache-file path."""
    # Path separators would create sub-directories, so flatten them to underscores.
    flat_name = dataset_name.translate(str.maketrans("/\\", "__"))
    return os.path.join(DATASET_CACHE_DIR, flat_name + ".json")
async def download_dataset_async(dataset_name, text_field, num_examples=10000, batch_size=100):
    """
    Async download dataset from HuggingFace with progress tracking
    Args:
        dataset_name: Full dataset name (e.g., 'HuggingFaceFW/fineweb')
        text_field: The field containing text data
        num_examples: Number of examples to download
        batch_size: Number of examples to collect before yielding
    Returns:
        Path to the local dataset file
    """
    import json
    from datasets import load_dataset
    # Clear any stale Ctrl+C flag so a prior interrupt doesn't abort this run.
    reset_interrupt()
    # Create cache directory if it doesn't exist
    os.makedirs(DATASET_CACHE_DIR, exist_ok=True)
    local_path = get_local_dataset_path(dataset_name)
    # Check if already downloaded
    if os.path.exists(local_path):
        print(f"Dataset already cached at: {local_path}")
        redownload = safe_input("Re-download? (y/n, default: n): ", "n").lower() == 'y'
        if not redownload:
            return local_path
    print(f"Downloading {num_examples} examples from {dataset_name}...")
    print("[Press Ctrl+C to stop and save partial download]")
    try:
        # Streaming mode avoids materializing the full dataset up front.
        dataset = load_dataset(dataset_name, split="train", streaming=True)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Trying with 'sample' config...")
        try:
            dataset = load_dataset(dataset_name, "sample", split="train", streaming=True)
        except Exception as e2:
            print(f"Failed to load dataset: {e2}")
            return None
    data = []
    batch = []
    for i, entry in enumerate(dataset):
        # Check for interrupt
        if is_interrupted():
            print(f"\n[!] Interrupt detected at {i} examples. Saving partial download...")
            break
        text = entry.get(text_field, "")
        if text:
            batch.append({"text": text})
        # Yield control periodically to allow interrupt checking
        if len(batch) >= batch_size:
            data.extend(batch)
            batch = []
            await asyncio.sleep(0)  # Yield to event loop
        # NOTE(review): i counts streamed entries, not kept examples, so fewer
        # than num_examples may be saved when entries have empty text — confirm intended.
        if i >= num_examples - 1:
            break
        if (i + 1) % 1000 == 0:
            print(f" Downloaded {i + 1}/{num_examples} examples...")
    # Add remaining batch
    data.extend(batch)
    if len(data) == 0:
        print("No data downloaded.")
        return None
    # Save to local JSON file
    with open(local_path, 'w', encoding='utf-8') as f:
        json.dump({
            "dataset_name": dataset_name,
            "text_field": text_field,
            "num_examples": len(data),
            "data": data
        }, f, ensure_ascii=False, indent=2)
    print(f"Saved {len(data)} examples to: {local_path}")
    reset_interrupt()
    return local_path
def download_dataset(dataset_name, text_field, num_examples=10000):
    """
    Synchronous wrapper for async download
    Args:
        dataset_name: Full dataset name (e.g., 'HuggingFaceFW/fineweb')
        text_field: The field containing text data
        num_examples: Number of examples to download
    Returns:
        Path to the local dataset file
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No running loop: we are in a plain synchronous context.
        loop = None
    if loop and loop.is_running():
        # Already in async context - use nest_asyncio or create a new thread
        import concurrent.futures
        # asyncio.run cannot be called from a running loop, so the coroutine
        # is run on its own loop in a worker thread and we block on the result.
        with concurrent.futures.ThreadPoolExecutor() as pool:
            future = pool.submit(
                asyncio.run,
                download_dataset_async(dataset_name, text_field, num_examples)
            )
            return future.result()
    else:
        # Not in async context - run directly
        return asyncio.run(download_dataset_async(dataset_name, text_field, num_examples))
def load_local_dataset(dataset_name, percent=100.0):
    """Load a cached dataset and return its combined text.

    Args:
        dataset_name: Full dataset name (to find the cached file)
        percent: Percentage of cached data to use
    Returns:
        Combined text joined with blank lines, or None if not cached.
    """
    import json
    local_path = get_local_dataset_path(dataset_name)
    if not os.path.exists(local_path):
        print(f"Dataset not found locally: {local_path}")
        print("Please download the dataset first.")
        return None
    print(f"Loading from local cache: {local_path}")
    with open(local_path, 'r', encoding='utf-8') as f:
        cached_data = json.load(f)
    all_data = cached_data.get("data", [])
    total_examples = len(all_data)
    # Keep at least one example, even for very small percentages.
    num_to_use = max(1, int(total_examples * (percent / 100.0)))
    # Collect non-empty texts from the requested prefix of the cache.
    texts = [entry.get("text", "") for entry in all_data[:num_to_use] if entry.get("text")]
    combined_text = "\n\n".join(texts)
    print(f"Loaded {len(texts)} examples ({percent}% of {total_examples} cached) - {len(combined_text)} characters")
    return combined_text
def load_selected_dataset(dataset_name, text_field, percent=0.001):
    """
    Load a dataset - downloads if not cached, then loads from local cache.

    Fixes vs. original: uses safe_input() instead of raw input() (consistent
    with the rest of the UI, so Ctrl+C cancels gracefully instead of killing
    the app) and guards the float()/int() parses so bad input falls back to
    the defaults instead of raising ValueError.

    Args:
        dataset_name: Full dataset name (e.g., 'HuggingFaceFW/fineweb')
        text_field: The field containing text data
        percent: Percentage of dataset to load (used for download count estimation)
    Returns:
        Combined text from the dataset
    """
    # Rough total example counts used to translate `percent` into a download count.
    estimated_counts = {
        "HuggingFaceFW/fineweb": 52500000000,
        "HuggingFaceFW/fineweb-edu": 3500000000,
        "HuggingFaceFW/fineweb-edu-score-2": 13900000000,
        "HuggingFaceFW/fineweb-2": 4480000000,
        "HuggingFaceFW/finewiki": 61600000,
        "HuggingFaceFW/clean-wikipedia": 61200000,
        "HuggingFaceFW/finepdfs": 476000000,
        "HuggingFaceFW/finepdfs-edu": 49500000,
        "HuggingFaceFW/ocr-annotations": 1620,
        "roneneldan/TinyStories": 2141709,
        "roneneldan/TinyStoriesInstruct": 21974061,
    }
    local_path = get_local_dataset_path(dataset_name)
    # Check if dataset exists locally
    if os.path.exists(local_path):
        print(f"Found cached dataset: {local_path}")
        use_cached = safe_input("Use cached dataset? (y/n, default: y): ", "").lower() != 'n'
        if use_cached:
            try:
                use_percent = float(safe_input("Percentage of cached data to use (default: 100): ", "") or "100")
            except ValueError:
                use_percent = 100.0
            return load_local_dataset(dataset_name, use_percent)
    # Dataset not cached - need to download
    print(f"Dataset not found in cache. Downloading...")
    total_count = estimated_counts.get(dataset_name, 1000000)
    default_count = max(1, int(total_count * (percent / 100.0)))
    default_count = min(default_count, 50000)  # Cap at 50k
    try:
        num_examples = int(safe_input(f"Number of examples to download (default: {default_count}): ", "") or str(default_count))
    except ValueError:
        num_examples = default_count
    local_path = download_dataset(dataset_name, text_field, num_examples)
    if local_path is None:
        return None
    try:
        use_percent = float(safe_input("Percentage of downloaded data to use for training (default: 100): ", "") or "100")
    except ValueError:
        use_percent = 100.0
    return load_local_dataset(dataset_name, use_percent)
# ============================================
# 12. CHECKPOINTING
# ============================================
def save_checkpoint(model, tokenizer, path="models_cache/model_v4.pth"):
    """Serialize model weights, vocabulary and architecture config to `path`.

    The config block stores everything load_checkpoint needs to rebuild the
    model before loading the state dict.

    Fix vs. original: creates the parent directory if missing — the default
    path assumes models_cache/ exists, which only the __main__ script created.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    checkpoint = {
        'model_state': model.state_dict(),
        'vocab': tokenizer.chars,
        'vocab_size': tokenizer.vocab_size,
        'config': {
            'd_model': model.d_model,
            'num_heads': model.num_heads,
            'num_layers': model.num_layers,
            'd_ff': model.d_ff,
            'vocab_size': model.vocab_size,
            # Only Pre-LN models carry a final_norm module.
            'use_pre_ln': model.final_norm is not None,
        }
    }
    torch.save(checkpoint, path)
    print(f"Model saved to {path}")
def load_checkpoint(path="models_cache/model_v4.pth"):
    """Rebuild tokenizer and model from a checkpoint written by save_checkpoint.

    Fix vs. original: loads with map_location='cpu' so checkpoints saved on a
    GPU machine also load on CPU-only machines.

    Returns:
        (model, tokenizer) in eval mode, or (None, None) if `path` is missing.
    """
    if not os.path.exists(path):
        return None, None
    checkpoint = torch.load(path, map_location='cpu')
    vocab = checkpoint['vocab']
    config = checkpoint['config']
    tokenizer = CharTokenizer(chars=vocab)
    # Rebuild the architecture from the saved config before loading weights;
    # .get defaults cover checkpoints written before those keys existed.
    model = TransformerV4(
        vocab_size=config['vocab_size'],
        d_model=config['d_model'],
        num_heads=config['num_heads'],
        num_layers=config['num_layers'],
        d_ff=config.get('d_ff', 1024),
        use_pre_ln=config.get('use_pre_ln', False),
    )
    model.load_state_dict(checkpoint['model_state'])
    model.eval()
    print(f"Model loaded from {path}")
    return model, tokenizer
# ============================================
# 13. INTERACTIVE DEMO
# ============================================
if __name__ == "__main__":
    # Generate unique model name for this execution so separate runs never
    # overwrite each other's checkpoints.
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    os.makedirs("models_cache", exist_ok=True)
    MODEL_PATH = f"models_cache/model_v4_{timestamp}.pth"
    # Global verbose flag, toggled from menu option 5.
    VERBOSE_MODE = False

    def select_model_file(default_path=None):
        """Allow user to select a model file from current directory"""
        os.makedirs("models_cache", exist_ok=True)
        files = glob.glob("models_cache/model_v4*.pth")
        # Newest checkpoints first.
        files.sort(key=os.path.getmtime, reverse=True)
        if not files:
            print("No existing model files found.")
            return default_path
        print("\nAvailable Models:")
        for i, f in enumerate(files):
            size_mb = os.path.getsize(f) / (1024 * 1024)
            print(f"{i+1}. {f} ({size_mb:.2f} MB)")
        print(f"{len(files)+1}. Cancel / Use Default ({default_path})")
        choice = safe_input(f"Select model (1-{len(files)+1}, default 1): ", "1")
        try:
            idx = int(choice) - 1
            if 0 <= idx < len(files):
                return files[idx]
        except ValueError:
            pass
        return default_path

    print("\n[Tip: Press Ctrl+C at any time to gracefully interrupt operations]")
    while True:
        reset_interrupt()  # Reset interrupt flag at start of each loop
        print("\n" + "=" * 60)
        print("TRANSFORMER v4")
        print("=" * 60)
        print("1. Train with Dataset Selection")
        print("2. Load & Retrain Existing Model")
        print("3. Generate Text from Model")
        print("4. Educational Step-by-Step Demo")
        print(f"5. Toggle Verbose Mode [Current: {'ON' if VERBOSE_MODE else 'OFF'}]")
        print("6. Exit")
        choice = safe_input("\nEnter choice (1-6): ", "6")
        if choice == '1':
            # Dataset selection and training
            dataset_name, text_field, dataset_type = select_dataset()
            if dataset_name is None:
                print("Dataset selection cancelled.")
                continue
            percent_str = safe_input("Enter percentage to load (default: 0.001): ", "0.001")
            try:
                percent = float(percent_str)
            except ValueError:
                percent = 0.001
            print(f"\nLoading dataset: {dataset_name}")
            text = load_selected_dataset(dataset_name, text_field, percent)
            if text is None or len(text) < 100:
                print("Failed to load dataset or dataset too small.")
                continue
            tokenizer = CharTokenizer(text=text)
            use_pre_ln = safe_input("Use Pre-LN? (y/n, default: n): ", "n").lower() == 'y'
            model = TransformerV4(
                vocab_size=tokenizer.vocab_size,
                d_model=256, num_heads=8, num_layers=4, d_ff=1024,
                use_pre_ln=use_pre_ln
            )
            print(f"\nTraining on {dataset_name} (Pre-LN={use_pre_ln})...")
            print("[Press Ctrl+C to stop training and save current progress]")
            train_model_basic(model, tokenizer, text, epochs=2000, seq_len=64, verbose=VERBOSE_MODE)
            save_checkpoint(model, tokenizer, MODEL_PATH)
        elif choice == '2':
            # Select model to load
            load_path = select_model_file(MODEL_PATH)
            if not load_path or not os.path.exists(load_path):
                print(f"Model file not found: {load_path}")
                continue
            model, tokenizer = load_checkpoint(load_path)
            if model is None:
                print("No model found. Train first.")
                continue
            # Use dataset selection for retraining too
            dataset_name, text_field, dataset_type = select_dataset()
            if dataset_name is None:
                print("Dataset selection cancelled.")
                continue
            percent_str = safe_input("Enter percentage (default: 0.0005): ", "0.0005")
            try:
                percent = float(percent_str)
            except ValueError:
                percent = 0.0005
            print(f"\nLoading dataset: {dataset_name}")
            text = load_selected_dataset(dataset_name, text_field, percent)
            if text is None or len(text) < 100:
                print("Failed to load dataset or dataset too small.")
                continue
            print(f"\nRetraining on {dataset_name}...")
            print("[Press Ctrl+C to stop training and save current progress]")
            # Save to the NEW unique path for this session, not overwriting the old one
            print(f"Will save updated model to: {MODEL_PATH}")
            train_model_basic(model, tokenizer, text, epochs=500, seq_len=64, verbose=VERBOSE_MODE)
            save_checkpoint(model, tokenizer, MODEL_PATH)
        elif choice == '3':
            # Select model to load
            load_path = select_model_file(MODEL_PATH)
            if not load_path or not os.path.exists(load_path):
                print(f"Model file not found: {load_path}")
                continue
            model, tokenizer = load_checkpoint(load_path)
            if model is None:
                print("No model found. Train first.")
                continue
            print("\nGeneration Options:")
            print("1. Free generation")
            print("2. Prompt-guided (TinyStoriesInstruct style)")
            gen_choice = safe_input("Choose (1-2): ", "1")
            if gen_choice == '1':
                prompt = safe_input("Enter prompt: ", "Once upon a time")
            else:
                print("\nSample story prompts:")
                sample_prompts = [
                    "Once upon a time",
                    "a quick brown fox",
                    "One day, a little girl",
                    "The sun was shining",
                ]
                for i, p in enumerate(sample_prompts):
                    print(f" {i+1}. {p}")
                choice_str = safe_input("Select (1-4) or 0 for custom: ", "1")
                try:
                    choice_idx = int(choice_str) - 1
                except ValueError:
                    choice_idx = 0
                prompt = sample_prompts[choice_idx] if 0 <= choice_idx < len(sample_prompts) else safe_input("Custom: ", "Once upon a time")
            temp_str = safe_input("Temperature (0.1-2.0, default 0.8): ", "0.8")
            try:
                temp = float(temp_str)
            except ValueError:
                temp = 0.8
            print(f"\nGenerating from: '{prompt}'")
            token_ids = tokenizer.encode(prompt)
            input_tensor = torch.tensor([token_ids])
            generated = prompt
            # Verbose for first token generation only if requested
            first_gen_verbose = VERBOSE_MODE
            for _ in range(100):
                if is_interrupted():
                    print("\n[Generation interrupted]")
                    break
                next_id, _ = model.predict_next_token_topp(
                    input_tensor,
                    p=0.9,
                    temperature=temp,
                    verbose=first_gen_verbose
                )
                if first_gen_verbose:
                    first_gen_verbose = False
                next_char = tokenizer.id2char.get(next_id.item(), " ")
                generated += next_char
                token_ids.append(next_id.item())
                # Keep only the last 64 tokens as a sliding context window.
                input_tensor = torch.tensor([token_ids[-64:]])
            # BUG FIX: the result was printed twice; print it once, after the loop.
            print(f"\n[GENERATED]\n{generated}\n")
        elif choice == '4':
            # BUG FIX: demonstrate_steps is not defined anywhere in this file;
            # guard the call so a missing demo doesn't crash the menu loop.
            try:
                demonstrate_steps()
            except NameError:
                print("Educational demo is unavailable (demonstrate_steps is not defined).")
        elif choice == '5':
            VERBOSE_MODE = not VERBOSE_MODE
            print(f"\nVerbose Mode is now {'ON' if VERBOSE_MODE else 'OFF'}")
        elif choice == '6':
            print("Exiting...")
            break
        else:
            print("Invalid choice")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment