Skip to content

Instantly share code, notes, and snippets.

@amitpuri
Created January 3, 2026 05:50
Show Gist options
  • Select an option

  • Save amitpuri/f06b192295c4caffd635ed472d5108e4 to your computer and use it in GitHub Desktop.

Select an option

Save amitpuri/f06b192295c4caffd635ed472d5108e4 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
# ============================================
# 1. TOKENIZATION (using simple mapping)
# ============================================
class SimpleTokenizer:
    """Word-level tokenizer backed by a fixed vocabulary list."""

    def __init__(self, vocab):
        # Forward map word -> id, and its inverse for decoding.
        self.word2id = {w: i for i, w in enumerate(vocab)}
        self.id2word = {i: w for w, i in self.word2id.items()}
        self.vocab_size = len(vocab)

    def encode(self, text):
        """Convert text to token IDs"""
        # Lowercase + whitespace split; unknown words map to "<unk>" (or id 0).
        unk_id = self.word2id.get("<unk>", 0)
        return [self.word2id.get(word, unk_id) for word in text.lower().split()]

    def decode(self, token_ids):
        """Convert token IDs back to text"""
        return [self.id2word[i] for i in token_ids]
class CharTokenizer:
    """Character-level tokenizer built from raw text or an explicit char list."""

    def __init__(self, text=None, chars=None):
        # An explicit character list wins; otherwise derive the sorted set
        # of unique characters from the sample text.
        if chars is not None:
            self.chars = chars
        elif text is not None:
            self.chars = sorted(set(text))
        else:
            raise ValueError("Must provide either text or chars")
        self.vocab_size = len(self.chars)
        self.char2id = {c: i for i, c in enumerate(self.chars)}
        self.id2char = {i: c for i, c in enumerate(self.chars)}

    def encode(self, text):
        """Map each character to its id; unknown characters fall back to 0."""
        lookup = self.char2id
        return [lookup.get(ch, 0) for ch in text]

    def decode(self, ids):
        """Map ids (list or tensor) back to a string; unknown ids render as '?'."""
        if torch.is_tensor(ids):
            ids = ids.tolist()
        return "".join(self.id2char.get(i, "?") for i in ids)
# ============================================
# 2. POSITIONAL ENCODING
# ============================================
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding from "Attention Is All You Need".

    Adds a fixed (non-learned) position-dependent signal to token
    embeddings so the model can distinguish token order.
    """

    def __init__(self, d_model, max_seq_len=5000):
        """
        Args:
            d_model: embedding dimension (even or odd)
            max_seq_len: maximum sequence length
        """
        super().__init__()
        # Create positional encoding matrix
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        # Geometric progression of inverse frequencies: 10000^(-2i/d_model).
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        # sin for even indices, cos for odd.  Slicing div_term for the cos
        # half fixes a shape mismatch when d_model is odd (there is one
        # fewer odd channel than even channels).
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Register as buffer (not a parameter, but part of the model state;
        # follows .to(device) and save/load).
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """
        Args:
            x: embeddings of shape [batch_size, seq_len, d_model]
        Returns:
            x with positional encoding added (same shape)
        """
        return x + self.pe[:, :x.size(1)]
# ============================================
# 3. LAYER NORMALIZATION
# ============================================
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learnable scale/shift."""

    def __init__(self, d_model, eps=1e-6):
        """
        Args:
            d_model: dimension size
            eps: small constant for numerical stability
        """
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))   # learnable scale
        self.beta = nn.Parameter(torch.zeros(d_model))   # learnable shift
        self.eps = eps

    def forward(self, x):
        """Normalize x over its last dim, then scale and shift.

        Args:
            x: input of shape [batch_size, seq_len, d_model]
        Returns:
            normalized output of the same shape
        """
        mu = x.mean(dim=-1, keepdim=True)
        # Biased (population) variance, as in the original transformer.
        sigma2 = x.var(dim=-1, keepdim=True, unbiased=False)
        normalized = (x - mu) / torch.sqrt(sigma2 + self.eps)
        return self.gamma * normalized + self.beta
# ============================================
# 4. MULTI-HEAD SELF-ATTENTION
# ============================================
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention."""

    def __init__(self, d_model, num_heads):
        """
        Args:
            d_model: embedding dimension
            num_heads: number of attention heads
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head dimension
        # Separate projections for queries, keys, values, plus the output mix.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """
        Args:
            x: input of shape [batch_size, seq_len, d_model]
            mask: optional attention mask; positions where mask == 0 are hidden
        Returns:
            attention output of shape [batch_size, seq_len, d_model]
        """
        bsz, n, _ = x.shape

        def split_heads(t):
            # [bsz, n, d_model] -> [bsz, num_heads, n, d_k]
            return t.view(bsz, n, self.num_heads, self.d_k).transpose(1, 2)

        q = split_heads(self.W_q(x))
        k = split_heads(self.W_k(x))
        v = split_heads(self.W_v(x))

        # Scaled dot-product attention scores: Q @ K^T / sqrt(d_k).
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
        if mask is not None:
            # Blocked positions get -inf so softmax assigns them zero weight.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        context = weights @ v

        # Merge heads back: [bsz, num_heads, n, d_k] -> [bsz, n, d_model].
        merged = context.transpose(1, 2).contiguous().view(bsz, n, self.d_model)
        return self.W_o(merged)
# ============================================
# 5. FEED-FORWARD NETWORK
# ============================================
class FeedForward(nn.Module):
    """Position-wise two-layer MLP with ReLU activation and dropout."""

    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        """
        Args:
            d_model: embedding dimension
            d_ff: hidden dimension (typically 4x d_model)
            dropout: dropout rate
        """
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: input of shape [batch_size, seq_len, d_model]
        Returns:
            output of same shape
        """
        hidden = F.relu(self.fc1(x))
        return self.fc2(self.dropout(hidden))
# ============================================
# 6. TRANSFORMER BLOCK
# ============================================
class TransformerBlock(nn.Module):
    """One decoder layer: self-attention then FFN, each with residual Add & Norm."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # Sub-layer 1: multi-head self-attention.
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = LayerNorm(d_model)
        # Sub-layer 2: position-wise feed-forward network.
        self.ffn = FeedForward(d_model, d_ff)
        self.norm2 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: input of shape [batch_size, seq_len, d_model]
            mask: optional attention mask
        Returns:
            output of same shape
        """
        # Post-norm residual around attention (Add & Norm).
        x = self.norm1(x + self.dropout(self.attention(x, mask)))
        # Post-norm residual around the feed-forward net (Add & Norm).
        return self.norm2(x + self.dropout(self.ffn(x)))
# ============================================
# 7. SAMPLING METHODS
# ============================================
def top_k_sampling(logits, k=10, temperature=1.0):
    """
    Top-k sampling with temperature.

    Keeps only the k highest-scoring tokens, renormalizes, and samples.

    Args:
        logits: tensor of shape [batch_size, vocab_size]
        k: number of top candidates to keep (clamped to the vocab size)
        temperature: softmax temperature; lower values are greedier
    Returns:
        (next_token, probs): sampled token ids of shape [batch_size, 1]
        and the filtered probability distribution.
    """
    # Apply temperature; the floor avoids division by zero.
    logits = logits / max(temperature, 1e-5)
    # Clamp k so callers may pass k > vocab_size without torch.topk raising.
    k = min(k, logits.size(-1))
    topk_logits, topk_indices = torch.topk(logits, k, dim=-1)
    # Everything outside the top k becomes -inf -> probability 0 after softmax.
    masked = torch.full_like(logits, float('-inf'))
    masked.scatter_(-1, topk_indices, topk_logits)
    probs = F.softmax(masked, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token, probs
def top_p_sampling(logits, p=0.9, temperature=1.0):
    """
    Top-p (nucleus) sampling with temperature.

    Keeps the smallest set of tokens whose cumulative probability exceeds p
    (always at least one token), renormalizes, and samples.

    Args:
        logits: tensor of shape [batch_size, vocab_size]
        p: cumulative probability cutoff in (0, 1]
        temperature: softmax temperature; lower values are greedier
    Returns:
        (next_token, probs): sampled token ids of shape [batch_size, 1]
        and the filtered probability distribution.
    """
    # Apply temperature; the floor avoids division by zero.
    logits = logits / max(temperature, 1e-5)
    # Sort descending so the nucleus is a prefix of the sorted tokens.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    sorted_probs = F.softmax(sorted_logits, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    # Drop tokens once the cumulative mass passes p; the one-position shift
    # keeps the first token that crosses the threshold (and always token 0).
    cutoff_mask = cumulative_probs > p
    cutoff_mask[..., 1:] = cutoff_mask[..., :-1].clone()
    cutoff_mask[..., 0] = False
    sorted_logits[cutoff_mask] = float('-inf')
    # Scatter the filtered logits back to their original vocabulary order
    # (dim=-1, not a hard-coded dim 1, so any batch layout works).
    unsorted_logits = torch.full_like(logits, float('-inf'))
    unsorted_logits.scatter_(-1, sorted_indices, sorted_logits)
    probs = F.softmax(unsorted_logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token, probs
# ============================================
# 8. COMPLETE TRANSFORMER MODEL
# ============================================
class SimpleTransformer(nn.Module):
    """Decoder-only transformer language model over a fixed vocabulary."""

    def __init__(self, vocab_size, d_model=256, num_heads=4,
                 num_layers=2, d_ff=1024, max_seq_len=5000, dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.d_ff = d_ff
        # Input pipeline: token embedding + sinusoidal positions + dropout.
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len)
        self.dropout = nn.Dropout(dropout)
        # Stack of identical decoder blocks.
        self.transformer_blocks = nn.ModuleList(
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        )
        # Projects final hidden states to vocabulary logits.
        self.lm_head = nn.Linear(d_model, vocab_size)

    def generate_causal_mask(self, seq_len):
        """
        Creates a lower triangular matrix for causal attention.
        1 = look at this token, 0 = ignore this token.
        """
        return torch.tril(torch.ones(seq_len, seq_len))

    def forward(self, token_ids):
        """
        Args:
            token_ids: tensor of shape [batch_size, seq_len]
        Returns:
            logits of shape [batch_size, seq_len, vocab_size]
        """
        seq_len = token_ids.shape[1]
        # Embed tokens, inject position information, then regularize.
        hidden = self.dropout(self.positional_encoding(self.token_embedding(token_ids)))
        # Causal mask keeps each position from attending to the future.
        causal = self.generate_causal_mask(seq_len).to(token_ids.device)
        for layer in self.transformer_blocks:
            hidden = layer(hidden, mask=causal)
        return self.lm_head(hidden)

    def predict_next_token(self, token_ids):
        """
        Predict the next token given a sequence.
        Args:
            token_ids: tensor of shape [batch_size, seq_len]
        Returns:
            next_token_id: most likely next token
            probabilities: softmax probabilities
        """
        with torch.no_grad():
            # Only the final position's logits matter for next-token prediction.
            last_logits = self.forward(token_ids)[:, -1, :]  # [batch, vocab]
            probabilities = F.softmax(last_logits, dim=-1)
            next_token_ids = torch.argmax(probabilities, dim=-1)
            return next_token_ids, probabilities

    def predict_next_token_greedy(self, token_ids):
        """
        Greedy decoding: argmax over probabilities.
        """
        return self.predict_next_token(token_ids)

    def predict_next_token_topk(self, token_ids, k=10, temperature=1.0):
        """
        Top-k sampling with temperature.
        """
        with torch.no_grad():
            last_logits = self.forward(token_ids)[:, -1, :]
            sampled, probs = top_k_sampling(last_logits, k=k, temperature=temperature)
            return sampled.squeeze(-1), probs

    def predict_next_token_topp(self, token_ids, p=0.9, temperature=1.0):
        """
        Top-p (nucleus) sampling with temperature.
        """
        with torch.no_grad():
            last_logits = self.forward(token_ids)[:, -1, :]
            sampled, probs = top_p_sampling(last_logits, p=p, temperature=temperature)
            return sampled.squeeze(-1), probs
# ============================================
# 9. PERSISTENCE
# ============================================
def save_checkpoint(model, tokenizer, path="model.pth"):
    """Persist model weights, tokenizer vocabulary, and architecture config.

    Assumes a CharTokenizer-style tokenizer exposing `.chars` — verify if a
    different tokenizer type is ever passed in.
    """
    torch.save(
        {
            'model_state': model.state_dict(),
            'vocab': tokenizer.chars,
            # Enough hyperparameters to rebuild the architecture on load.
            'config': {
                'd_model': model.d_model,
                'num_heads': model.num_heads,
                'num_layers': model.num_layers,
                'd_ff': model.d_ff,
                'vocab_size': model.vocab_size,
            },
        },
        path,
    )
    print(f"Model saved to {path}")
def load_checkpoint(path="model.pth"):
    """Restore a SimpleTransformer and its CharTokenizer from a checkpoint.

    Args:
        path: checkpoint file written by `save_checkpoint`.
    Returns:
        (model, tokenizer), or (None, None) when no checkpoint exists.
    """
    if not os.path.exists(path):
        return None, None
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load(path, map_location='cpu')
    tokenizer = CharTokenizer(chars=checkpoint['vocab'])
    config = checkpoint['config']
    model = SimpleTransformer(
        vocab_size=config['vocab_size'],
        d_model=config['d_model'],
        num_heads=config['num_heads'],
        num_layers=config['num_layers'],
        # Older checkpoints may predate d_ff being stored.
        d_ff=config.get('d_ff', 1024)
    )
    model.load_state_dict(checkpoint['model_state'])
    model.eval()  # inference mode by default; caller can switch to .train()
    print(f"Model loaded from {path}")
    return model, tokenizer
# ============================================
# 10. TRAINING UTILITY
# ============================================
def train_model(model, tokenizer, text, epochs=1000, batch_size=4, seq_len=64):
    """Train the model on next-token prediction with randomly sampled batches.

    Args:
        model: language model mapping [B, S] token ids -> [B, S, V] logits;
            must expose a `vocab_size` attribute.
        tokenizer: object with an `encode(text) -> list[int]` method.
        text: training corpus.
        epochs: number of optimization steps.
        batch_size: sequences per step.
        seq_len: maximum context length per sequence.
    Raises:
        ValueError: if `text` encodes to fewer than 2 tokens (cannot form a
            single (input, target) pair).
    """
    ids = tokenizer.encode(text)
    if len(ids) < 2:
        # Previously this crashed deep inside the loss with empty tensors.
        raise ValueError("Training text must encode to at least 2 tokens")
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)  # Lower LR for stability
    criterion = nn.CrossEntropyLoss()
    model.train()
    for epoch in range(epochs):
        batch_inputs = []
        batch_targets = []
        for _ in range(batch_size):
            if len(ids) <= seq_len + 1:
                # Corpus shorter than the context window: use all of it.
                start_idx = 0
                curr_seq_len = len(ids) - 1
            else:
                start_idx = torch.randint(0, len(ids) - seq_len - 1, (1,)).item()
                curr_seq_len = seq_len
            # Targets are the inputs shifted one position to the right.
            batch_inputs.append(ids[start_idx : start_idx + curr_seq_len])
            batch_targets.append(ids[start_idx + 1 : start_idx + curr_seq_len + 1])
        # All sequences in a batch have equal length, so stacking is safe.
        inputs = torch.tensor(batch_inputs)
        targets = torch.tensor(batch_targets)
        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits.view(-1, model.vocab_size), targets.view(-1))
        loss.backward()
        optimizer.step()
        if epoch % 100 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
    model.eval()
# ============================================
# 11. DATASET UTILITY
# ============================================
def load_tiny_stories_subset(percent):
    """Load a percentage of the TinyStories dataset."""
    from datasets import load_dataset
    TOTAL_STORIES = 2141709
    count = max(1, int(TOTAL_STORIES * (percent / 100.0)))
    print(f"Loading {percent}% of TinyStories ({count} stories)...")
    # Streaming avoids downloading the full dataset up front; we stop
    # iterating as soon as `count` stories have been collected.
    stream = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
    collected = []
    for entry in stream:
        collected.append(entry['text'])
        if len(collected) >= count:
            break
    training_text = "\n\n".join(collected)
    print(f"Fetched {len(training_text)} characters.")
    return training_text
# ============================================
# 12. DEMONSTRATION
# ============================================
if __name__ == "__main__":
    MODEL_PATH = "model.pth"
    # Interactive loop: train a model, retrain a saved one, or sample text.
    while True:
        banner = "=" * 50
        print("\n" + banner)
        print("TRANSFORMER INTERACTIVE MENU")
        print(banner)
        for item in ("1. Train Fresh Model",
                     "2. Load & Retrain Existing Model",
                     "3. Generate Text from Model",
                     "4. Exit"):
            print(item)
        choice = input("Enter choice (1-4): ")
        if choice == '1':
            raw = input("Enter percentage of dataset to train on (default: 0.001): ")
            percent = float(raw or "0.001")
            training_text = load_tiny_stories_subset(percent)
            tokenizer = CharTokenizer(training_text)
            # Scaled-up architecture relative to the class defaults.
            model = SimpleTransformer(
                vocab_size=tokenizer.vocab_size,
                d_model=256,
                num_heads=8,
                num_layers=4,
                d_ff=1024,
            )
            print("\nStarting training (Scaled Up Architecture)...")
            train_model(model, tokenizer, training_text, epochs=2000, seq_len=64)
            save_checkpoint(model, tokenizer, MODEL_PATH)
        elif choice == '2':
            model, tokenizer = load_checkpoint(MODEL_PATH)
            if model is None:
                print("No local model found. Please train a fresh one first.")
                continue
            raw = input("Enter percentage of new data to retrain on (default: 0.001): ")
            percent = float(raw or "0.001")
            training_text = load_tiny_stories_subset(percent)
            print("\nStarting retraining...")
            train_model(model, tokenizer, training_text, epochs=500, seq_len=64)
            save_checkpoint(model, tokenizer, MODEL_PATH)
        elif choice == '3':
            model, tokenizer = load_checkpoint(MODEL_PATH)
            if model is None:
                print("No local model found. Please train a fresh one first.")
                continue
            print("\nSelect a prompt:")
            prompts = [
                "Once upon a time",
                "a quick brown fox",
                "a quick brown fox jumps",
                "a quick brown ",
                "a quick brown fox jumps over",
                "a quick brown fox jumps over the lazy",
            ]
            for number, text in enumerate(prompts, start=1):
                print(f" {number}. {text}")
            print(f" {len(prompts)+1}. Custom input")
            p_choice = input(f"Enter choice (1-{len(prompts)+1}, default: 1): ") or "1"
            temperature = float(input("Enter temperature (0.1-2.0, default 0.8): ") or "0.8")
            if p_choice.isdigit() and 1 <= int(p_choice) <= len(prompts):
                prompt = prompts[int(p_choice) - 1]
            elif p_choice == str(len(prompts) + 1):
                prompt = input("Enter custom prompt: ")
            else:
                prompt = prompts[0]
            generated = prompt
            context = prompt
            print(f"\nGenerating (temp={temperature})...")
            # Autoregressive sampling: 100 characters, sliding 64-char context.
            for _ in range(100):
                context_tensor = torch.tensor([tokenizer.encode(context)])
                next_id, _ = model.predict_next_token_topp(
                    context_tensor, p=0.9, temperature=temperature
                )
                generated += tokenizer.id2char.get(next_id.item(), " ")
                context = generated[-64:]
            print(f"\nRESULT:\n{generated}")
        elif choice == '4':
            print("Exiting...")
            break
        else:
            print("Invalid choice.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment