Skip to content

Instantly share code, notes, and snippets.

@amitpuri
Created January 3, 2026 05:46
Show Gist options
  • Select an option

  • Save amitpuri/19334362c91184c9dda3868763f24ce2 to your computer and use it in GitHub Desktop.

Select an option

Save amitpuri/19334362c91184c9dda3868763f24ce2 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# ============================================
# 1. TOKENIZATION (using simple mapping)
# ============================================
class SimpleTokenizer:
    """Whitespace tokenizer backed by a fixed vocabulary.

    Words not in the vocabulary map to the "<unk>" entry when present,
    otherwise to id 0.
    """

    def __init__(self, vocab):
        # Bidirectional word <-> id lookup tables.
        self.word2id = {word: index for index, word in enumerate(vocab)}
        self.id2word = {index: word for word, index in self.word2id.items()}
        self.vocab_size = len(vocab)

    def encode(self, text):
        """Convert text to token IDs"""
        fallback = self.word2id.get("<unk>", 0)
        return [self.word2id.get(word, fallback) for word in text.lower().split()]

    def decode(self, token_ids):
        """Convert token IDs back to text"""
        return [self.id2word[token_id] for token_id in token_ids]
# ============================================
# 2. POSITIONAL ENCODING
# ============================================
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encodings (Vaswani et al., 2017).

    Args:
        d_model: embedding dimension (even or odd — see fix below)
        max_seq_len: maximum sequence length
    """

    def __init__(self, d_model, max_seq_len=5000):
        super().__init__()
        # Precompute the full encoding table once.
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        # Geometric frequency progression: 10000^(-2i/d_model).
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        # sin for even indices, cos for odd.
        pe[:, 0::2] = torch.sin(position * div_term)
        # Fix: for odd d_model there is one fewer odd slot than there are
        # frequencies, so trim the cos term; a no-op for even d_model.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        # Register as buffer: moves with the module (.to/.cuda) and is saved
        # in state_dict, but is not a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """
        Args:
            x: embeddings of shape [batch_size, seq_len, d_model]
        Returns:
            x with positional encoding added
        """
        return x + self.pe[:, :x.size(1)]
# ============================================
# 3. LAYER NORMALIZATION
# ============================================
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable affine.

    Args:
        d_model: dimension size
        eps: small constant for numerical stability
    """

    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        # Learnable scale/shift, initialized to the identity transform.
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        """Normalize [batch_size, seq_len, d_model] input along d_model."""
        mu = x.mean(dim=-1, keepdim=True)
        # Biased (population) variance, matching standard LayerNorm.
        sigma_sq = x.var(dim=-1, keepdim=True, unbiased=False)
        standardized = (x - mu) / torch.sqrt(sigma_sq + self.eps)
        # Scale and shift.
        return self.gamma * standardized + self.beta
# ============================================
# 4. MULTI-HEAD SELF-ATTENTION
# ============================================
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        d_model: embedding dimension
        num_heads: number of attention heads (must divide d_model)
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head dimension
        # Projections for queries, keys, values, and the final output.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def _split_heads(self, t, batch_size, seq_len):
        # [batch, seq, d_model] -> [batch, heads, seq, d_k]
        return t.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None):
        """Self-attend over x of shape [batch_size, seq_len, d_model].

        Args:
            x: input of shape [batch_size, seq_len, d_model]
            mask: optional 0/1 attention mask; 0 positions are excluded
        Returns:
            attention output of the same shape as x
        """
        batch_size, seq_len, _ = x.shape
        queries = self._split_heads(self.W_q(x), batch_size, seq_len)
        keys = self._split_heads(self.W_k(x), batch_size, seq_len)
        values = self._split_heads(self.W_v(x), batch_size, seq_len)
        # Scaled dot-product scores: Q K^T / sqrt(d_k).
        scores = queries @ keys.transpose(-2, -1) / math.sqrt(self.d_k)
        if mask is not None:
            # -inf scores become zero weight after the softmax.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        attended = weights @ values
        # Merge heads back: [batch, heads, seq, d_k] -> [batch, seq, d_model].
        attended = attended.transpose(1, 2).contiguous()
        attended = attended.view(batch_size, seq_len, self.d_model)
        return self.W_o(attended)
# ============================================
# 5. FEED-FORWARD NETWORK
# ============================================
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: embedding dimension
        d_ff: hidden dimension (typically 4x d_model)
        dropout: dropout rate applied after the activation
    """

    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the MLP at every position of [batch_size, seq_len, d_model]."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(self.dropout(hidden))
# ============================================
# 6. TRANSFORMER BLOCK
# ============================================
class TransformerBlock(nn.Module):
    """One decoder block: self-attention + FFN, each with Add & Norm.

    Args:
        d_model: embedding dimension
        num_heads: number of attention heads
        d_ff: feed-forward hidden dimension
        dropout: dropout rate used on both sublayer outputs and inside
            the feed-forward network
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # Self-attention
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = LayerNorm(d_model)
        # Feed-forward. Fix: forward this block's dropout rate instead of
        # silently falling back to FeedForward's own default of 0.1.
        self.ffn = FeedForward(d_model, d_ff, dropout=dropout)
        self.norm2 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: input of shape [batch_size, seq_len, d_model]
            mask: optional attention mask
        Returns:
            output of same shape
        """
        # Self-attention with residual connection and layer norm.
        attn_output = self.attention(x, mask)
        x = self.norm1(x + self.dropout(attn_output))  # Add & Norm
        # Feed-forward with residual connection and layer norm.
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_output))  # Add & Norm
        return x
# ============================================
# 7. COMPLETE TRANSFORMER MODEL
# ============================================
class SimpleTransformer(nn.Module):
    """Minimal decoder-only (GPT-style) language model.

    Pipeline: token embedding -> positional encoding -> dropout ->
    N causally-masked transformer blocks -> linear head over the vocabulary.
    """

    def __init__(self, vocab_size, d_model=256, num_heads=4,
                 num_layers=2, d_ff=1024, max_seq_len=5000, dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        # Input side: learned token embeddings plus fixed positional encodings.
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len)
        self.dropout = nn.Dropout(dropout)
        # Stack of identical decoder blocks.
        self.transformer_blocks = nn.ModuleList(
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        )
        # Projects hidden states to per-token vocabulary scores.
        self.lm_head = nn.Linear(d_model, vocab_size)

    def generate_causal_mask(self, seq_len):
        """
        Creates a lower triangular matrix for causal attention.
        1 = look at this token, 0 = ignore this token.
        """
        return torch.tril(torch.ones(seq_len, seq_len))

    def forward(self, token_ids):
        """
        Args:
            token_ids: tensor of shape [batch_size, seq_len]
        Returns:
            logits of shape [batch_size, seq_len, vocab_size]
        """
        seq_len = token_ids.size(1)
        # Embed tokens, add positions, apply input dropout.
        hidden = self.token_embedding(token_ids)
        hidden = self.dropout(self.positional_encoding(hidden))
        # Causal mask: each position attends only to itself and the past.
        causal_mask = self.generate_causal_mask(seq_len).to(token_ids.device)
        for block in self.transformer_blocks:
            hidden = block(hidden, mask=causal_mask)
        return self.lm_head(hidden)

    def predict_next_token(self, token_ids):
        """Greedy next-token prediction.

        Args:
            token_ids: tensor of shape [batch_size, seq_len]
        Returns:
            next_token_id: most likely next token (argmax over the last
                position's distribution)
            probabilities: softmax probabilities of shape
                [batch_size, vocab_size]
        """
        with torch.no_grad():
            # Only the final position's logits matter for the next token.
            last_logits = self.forward(token_ids)[:, -1, :]
            probabilities = F.softmax(last_logits, dim=-1)
            next_token_ids = torch.argmax(probabilities, dim=-1)
        return next_token_ids, probabilities
# ============================================
# 8. TRAINING UTILITY
# ============================================
def train_model(model, tokenizer, text, epochs=50):
    """Overfit *model* on a single sentence (next-token demonstration).

    Args:
        model: module mapping [batch, seq] ids to [batch, seq, vocab] logits,
            with a ``vocab_size`` attribute.
        tokenizer: object with an ``encode(text) -> list[int]`` method.
        text: the training sentence.
        epochs: number of full-sequence gradient steps.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
    criterion = nn.CrossEntropyLoss()
    # Shift by one: each position is trained to predict its successor.
    token_ids = tokenizer.encode(text)
    inputs = torch.tensor([token_ids[:-1]])
    targets = torch.tensor([token_ids[1:]])
    model.train()
    for _ in range(epochs):
        optimizer.zero_grad()
        logits = model(inputs)
        # Flatten (batch, seq) so the loss sees one row per token.
        loss = criterion(logits.view(-1, model.vocab_size), targets.view(-1))
        loss.backward()
        optimizer.step()
    model.eval()
# ============================================
# 9. DEMONSTRATION
# ============================================
if __name__ == "__main__":
    # Vocabulary and tokenizer setup.
    vocab = ["<pad>", "a", "quick", "brown", "fox", "jumps", "over",
             "the", "lazy", "dog", "<unk>"]
    tokenizer = SimpleTokenizer(vocab)

    # A deliberately tiny model so the demo trains in seconds.
    model = SimpleTransformer(
        vocab_size=len(vocab),
        d_model=128,
        num_heads=4,
        num_layers=2,
        d_ff=512
    )

    # --- PARAMETERS ---
    target_sentence = "a quick brown fox jumps over the lazy dog"
    test_inputs = [
        "a quick brown fox",
        "a quick brown fox jumps",
        "a quick brown ",
        "a quick brown fox jumps over",
        "a quick brown fox jumps over the lazy"
    ]
    # ------------------

    banner = "=" * 70
    print(banner)
    print("TRANSFORMER ARCHITECTURE DEMONSTRATION")
    print(banner)

    # Step 0: teach the model the target sentence by overfitting.
    print("\n[STEP 0] PRE-TRAINING (teaching the model the sentence)")
    print(f"Training on: '{target_sentence}'...")
    train_model(model, tokenizer, target_sentence, epochs=150)
    print("Training complete.")

    # Greedy next-word prediction for each test prefix.
    for input_text in test_inputs:
        print("\n" + "-" * 50)
        print(f"TESTING INPUT: '{input_text}'")
        print("-" * 50)
        # Step 1: tokenization; step 12: predict the next token.
        token_tensor = torch.tensor([tokenizer.encode(input_text)])
        next_token_id, probs = model.predict_next_token(token_tensor)
        next_word = vocab[next_token_id.item()]
        confidence = probs[0, next_token_id].item()
        print(f"Predicted next word: '{next_word}' (confidence: {confidence:.2%})")
        print(f"Sequence: '{input_text} {next_word}'")

    print("\n" + banner)
    print("DETAILED STEP-BY-STEP FOR LAST INPUT:")
    # Re-run the last prefix with full details at each stage.
    input_text = test_inputs[-1]
    token_ids = tokenizer.encode(input_text)
    token_tensor = torch.tensor([token_ids])

    # Step 1: tokenization details.
    print("\n[STEP 1] TOKENIZATION")
    print(f"Input text: '{input_text}'")
    print(f"Tokens: {tokenizer.decode(token_ids)}")
    print(f"Token IDs: {token_ids}")
    print(f"Tensor shape: {token_tensor.shape}")

    # Step 2: token embeddings.
    print("\n[STEP 2] TOKEN EMBEDDINGS")
    token_emb = model.token_embedding(token_tensor)
    print(f"Token embedding shape: {token_emb.shape}")
    print(f"Token embedding for 'quick' (first 10 dims):")
    print(f" {token_emb[0, 1, :10]}")

    # Step 3: positional encoding values.
    print("\n[STEP 3] POSITIONAL ENCODING")
    pos_enc = model.positional_encoding.pe[:, :token_emb.size(1), :10]
    print(f"Positional encoding shape: {model.positional_encoding.pe.shape}")
    print(f"Positional encoding for positions 0-2 (first 10 dims):")
    print(f" Position 0: {pos_enc[0, 0]}")
    print(f" Position 1: {pos_enc[0, 1]}")

    # Step 4: combined embeddings.
    print("\n[STEP 4] TOKEN + POSITIONAL EMBEDDINGS")
    combined_emb = token_emb + model.positional_encoding(token_emb)
    print(f"Combined embedding shape: {combined_emb.shape}")

    # Steps 5-8: full forward pass through the transformer blocks.
    print("\n[STEP 5-8] TRANSFORMER BLOCKS (Attention + FFN)")
    logits = model.forward(token_tensor)
    print(f"Logits shape: {logits.shape}")

    # Step 10: language-model head output for the final position.
    print("\n[STEP 10] LANGUAGE MODEL HEAD")
    last_logits = logits[:, -1, :]
    print(f"Logits for last token: {last_logits.shape}")
    print(f"Logits for top 5 vocabulary items:")
    top_5_logits, top_5_indices = torch.topk(last_logits, 5, dim=-1)
    for rank, (idx, logit) in enumerate(zip(top_5_indices[0], top_5_logits[0])):
        print(f" {rank+1}. {vocab[idx.item()]:15} logit: {logit.item():.4f}")

    # Step 11: softmax converts logits into a probability distribution.
    print("\n[STEP 11] SOFTMAX")
    probabilities = F.softmax(last_logits, dim=-1)
    top_5_probs, top_5_indices = torch.topk(probabilities, 5, dim=-1)
    print(f"Top 5 probabilities:")
    for rank, (idx, prob) in enumerate(zip(top_5_indices[0], top_5_probs[0])):
        print(f" {rank+1}. {vocab[idx.item()]:15} probability: {prob.item():.4%}")

    print("\n" + banner)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment