Created
January 3, 2026 05:46
-
-
Save amitpuri/19334362c91184c9dda3868763f24ce2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| # ============================================ | |
| # 1. TOKENIZATION (using simple mapping) | |
| # ============================================ | |
class SimpleTokenizer:
    """Minimal whitespace tokenizer backed by a fixed vocabulary list."""

    def __init__(self, vocab):
        # Forward table first; id2word is derived from it so that
        # duplicate vocab entries resolve exactly like a dict comprehension.
        self.word2id = {}
        for idx, word in enumerate(vocab):
            self.word2id[word] = idx
        self.id2word = {idx: word for word, idx in self.word2id.items()}
        self.vocab_size = len(vocab)

    def encode(self, text):
        """Convert text to token IDs.

        Unknown words map to the "<unk>" id, or 0 if the vocabulary
        has no "<unk>" entry.
        """
        unknown_id = self.word2id.get("<unk>", 0)
        ids = []
        for token in text.lower().split():
            ids.append(self.word2id.get(token, unknown_id))
        return ids

    def decode(self, token_ids):
        """Convert token IDs back to a list of words (not a joined string)."""
        return [self.id2word[idx] for idx in token_ids]
| # ============================================ | |
| # 2. POSITIONAL ENCODING | |
| # ============================================ | |
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding ("Attention Is All You Need").

    Precomputes a [1, max_seq_len, d_model] sin/cos table once and stores
    it as a non-trainable buffer.

    Args:
        d_model: embedding dimension
        max_seq_len: maximum sequence length supported
    """

    def __init__(self, d_model, max_seq_len=5000):
        super().__init__()
        positions = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        # Geometric frequency schedule: 10000^(-2i/d_model) for pair index i.
        frequencies = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        table = torch.zeros(max_seq_len, d_model)
        table[:, 0::2] = torch.sin(positions * frequencies)  # even dims
        table[:, 1::2] = torch.cos(positions * frequencies)  # odd dims
        # Buffer: moves with .to(device) and is saved in state_dict,
        # but receives no gradients.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Add positional encodings to ``x``.

        Args:
            x: embeddings of shape [batch_size, seq_len, d_model]
        Returns:
            ``x`` plus the first seq_len rows of the encoding table
        """
        return x + self.pe[:, :x.size(1)]
| # ============================================ | |
| # 3. LAYER NORMALIZATION | |
| # ============================================ | |
class LayerNorm(nn.Module):
    """Layer normalization over the last (feature) dimension.

    Args:
        d_model: feature dimension size
        eps: small constant added to the variance for numerical stability
    """

    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        # Learnable scale/shift, initialized to the identity transform.
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        """Normalize each position's feature vector, then scale and shift.

        Args:
            x: input of shape [batch_size, seq_len, d_model]
        Returns:
            tensor of the same shape
        """
        mu = x.mean(dim=-1, keepdim=True)
        # Biased (population) variance, as is standard for layer norm.
        sigma_sq = x.var(dim=-1, keepdim=True, unbiased=False)
        normalized = (x - mu) / torch.sqrt(sigma_sq + self.eps)
        return self.gamma * normalized + self.beta
| # ============================================ | |
| # 4. MULTI-HEAD SELF-ATTENTION | |
| # ============================================ | |
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        d_model: embedding dimension (must be divisible by num_heads)
        num_heads: number of parallel attention heads
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head feature size
        # Query / key / value projections, then output projection.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def _split_heads(self, t, batch_size, seq_len):
        # [B, S, d_model] -> [B, num_heads, S, d_k]
        return t.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None):
        """Self-attend over ``x``.

        Args:
            x: input of shape [batch_size, seq_len, d_model]
            mask: optional mask; positions where ``mask == 0`` are hidden
        Returns:
            attention output of shape [batch_size, seq_len, d_model]
        """
        batch_size, seq_len, _ = x.shape
        q = self._split_heads(self.W_q(x), batch_size, seq_len)
        k = self._split_heads(self.W_k(x), batch_size, seq_len)
        v = self._split_heads(self.W_v(x), batch_size, seq_len)
        # Scaled dot-product scores: [B, H, S, S]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # -inf before softmax => zero attention weight for masked slots.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        # Weighted sum of values, then merge heads back into d_model.
        context = torch.matmul(weights, v).transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_len, self.d_model)
        return self.W_o(context)
| # ============================================ | |
| # 5. FEED-FORWARD NETWORK | |
| # ============================================ | |
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: embedding dimension
        d_ff: hidden dimension (typically 4x d_model)
        dropout: dropout rate applied after the activation
    """

    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the two-layer MLP independently at every position.

        Args:
            x: input of shape [batch_size, seq_len, d_model]
        Returns:
            output of the same shape
        """
        hidden = F.relu(self.fc1(x))
        hidden = self.dropout(hidden)
        return self.fc2(hidden)
| # ============================================ | |
| # 6. TRANSFORMER BLOCK | |
| # ============================================ | |
class TransformerBlock(nn.Module):
    """One decoder-style transformer layer (post-norm):
    self-attention and a feed-forward network, each wrapped in a
    residual connection followed by LayerNorm.

    Args:
        d_model: embedding dimension
        num_heads: number of attention heads
        d_ff: feed-forward hidden dimension
        dropout: dropout rate for both residual branches and the FFN
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # Self-attention sub-layer
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = LayerNorm(d_model)
        # Feed-forward sub-layer.
        # Bug fix: propagate `dropout` to FeedForward — previously the FFN
        # silently used its own default rate regardless of this argument.
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm2 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run one attention + FFN layer.

        Args:
            x: input of shape [batch_size, seq_len, d_model]
            mask: optional attention mask (0 = ignore)
        Returns:
            output of the same shape
        """
        # Self-attention with residual connection and layer norm
        attn_output = self.attention(x, mask)
        x = self.norm1(x + self.dropout(attn_output))  # Add & Norm
        # Feed-forward with residual connection and layer norm
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_output))  # Add & Norm
        return x
| # ============================================ | |
| # 7. COMPLETE TRANSFORMER MODEL | |
| # ============================================ | |
class SimpleTransformer(nn.Module):
    """Minimal decoder-only (GPT-style) transformer language model.

    Pipeline: token embedding -> sinusoidal positional encoding -> dropout
    -> N causally-masked transformer blocks -> linear head over the vocab.
    """

    def __init__(self, vocab_size, d_model=256, num_heads=4,
                 num_layers=2, d_ff=1024, max_seq_len=5000, dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        # Input representation layers.
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len)
        self.dropout = nn.Dropout(dropout)
        # Stack of identical decoder blocks.
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, d_ff, dropout)
             for _ in range(num_layers)]
        )
        # Projects final hidden states to vocabulary logits.
        self.lm_head = nn.Linear(d_model, vocab_size)

    def generate_causal_mask(self, seq_len):
        """Lower-triangular [seq_len, seq_len] mask.

        1 = may attend to this token, 0 = ignore (future position).
        """
        return torch.tril(torch.ones(seq_len, seq_len))

    def forward(self, token_ids):
        """Compute next-token logits for every position.

        Args:
            token_ids: integer tensor of shape [batch_size, seq_len]
        Returns:
            logits of shape [batch_size, seq_len, vocab_size]
        """
        batch_size, seq_len = token_ids.shape
        # Embed, add position information, regularize.
        hidden = self.token_embedding(token_ids)
        hidden = self.dropout(self.positional_encoding(hidden))
        # Causal mask keeps each position from attending to the future.
        causal_mask = self.generate_causal_mask(seq_len).to(token_ids.device)
        for block in self.transformer_blocks:
            hidden = block(hidden, mask=causal_mask)
        return self.lm_head(hidden)

    def predict_next_token(self, token_ids):
        """Greedy prediction of the token following the sequence.

        Args:
            token_ids: tensor of shape [batch_size, seq_len]
        Returns:
            Tuple ``(next_token_ids, probabilities)`` — the argmax token id
            per batch row and the softmax distribution over the vocabulary
            for the final position.
        """
        with torch.no_grad():
            last_logits = self.forward(token_ids)[:, -1, :]  # [B, vocab]
            probabilities = F.softmax(last_logits, dim=-1)
            next_token_ids = probabilities.argmax(dim=-1)
            return next_token_ids, probabilities
| # ============================================ | |
| # 8. TRAINING UTILITY | |
| # ============================================ | |
def train_model(model, tokenizer, text, epochs=50):
    """Overfit ``model`` on a single sentence (next-token prediction demo).

    Each epoch is one full-sequence gradient step: the input is the
    sentence minus its last token, the target is the sentence shifted
    left by one. Leaves the model in eval mode when done.

    Args:
        model: module mapping [batch, seq] token IDs to
            [batch, seq, vocab] logits; must expose a ``vocab_size`` attribute.
        tokenizer: object with ``encode(text) -> list[int]``.
        text: training sentence (must encode to at least two tokens).
        epochs: number of gradient steps.
    Returns:
        List of per-epoch loss values (handy for checking convergence).
        Previously nothing was returned, so existing callers are unaffected.
    Raises:
        ValueError: if ``text`` encodes to fewer than two tokens.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
    criterion = nn.CrossEntropyLoss()
    # Prepare (input, target) pair for next-token prediction.
    full_ids = tokenizer.encode(text)
    if len(full_ids) < 2:
        raise ValueError("text must encode to at least two tokens")
    inputs = torch.tensor([full_ids[:-1]])
    targets = torch.tensor([full_ids[1:]])
    losses = []
    model.train()
    for _ in range(epochs):
        optimizer.zero_grad()
        logits = model(inputs)
        # CrossEntropyLoss expects [N, vocab] logits and [N] class targets.
        loss = criterion(logits.view(-1, model.vocab_size), targets.view(-1))
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    model.eval()
    return losses
| # ============================================ | |
| # 9. DEMONSTRATION | |
| # ============================================ | |
if __name__ == "__main__":
    # --- Setup: vocabulary and tokenizer. <pad>=0; <unk> catches unseen words.
    vocab = ["<pad>", "a", "quick", "brown", "fox", "jumps", "over",
             "the", "lazy", "dog", "<unk>"]
    tokenizer = SimpleTokenizer(vocab)
    # Model hyperparameters (small, for a fast CPU demo)
    d_model = 128
    num_heads = 4
    num_layers = 2
    d_ff = 512
    # Initialize model
    model = SimpleTransformer(
        vocab_size=len(vocab),
        d_model=d_model,
        num_heads=num_heads,
        num_layers=num_layers,
        d_ff=d_ff
    )
    # --- PARAMETERS ---
    target_sentence = "a quick brown fox jumps over the lazy dog"
    test_inputs = [
        "a quick brown fox",
        "a quick brown fox jumps",
        "a quick brown ",
        "a quick brown fox jumps over",
        "a quick brown fox jumps over the lazy"
    ]
    # ------------------
    print("=" * 70)
    print("TRANSFORMER ARCHITECTURE DEMONSTRATION")
    print("=" * 70)
    # Step 0: overfit the model on the sentence so predictions are meaningful.
    print("\n[STEP 0] PRE-TRAINING (teaching the model the sentence)")
    print(f"Training on: '{target_sentence}'...")
    train_model(model, tokenizer, target_sentence, epochs=150)
    print("Training complete.")
    # Predict the next word for each test prefix.
    for input_text in test_inputs:
        print("\n" + "-" * 50)
        print(f"TESTING INPUT: '{input_text}'")
        print("-" * 50)
        # Step 1: Tokenization
        token_ids = tokenizer.encode(input_text)
        token_tensor = torch.tensor([token_ids])
        # Step 12: Predict next token
        next_token_id, probs = model.predict_next_token(token_tensor)
        next_word = vocab[next_token_id.item()]
        confidence = probs[0, next_token_id].item()
        print(f"Predicted next word: '{next_word}' (confidence: {confidence:.2%})")
        print(f"Sequence: '{input_text} {next_word}'")
    print("\n" + "=" * 70)
    print("DETAILED STEP-BY-STEP FOR LAST INPUT:")
    # Re-run the last prompt, printing intermediate tensors at each stage.
    input_text = test_inputs[-1]
    token_ids = tokenizer.encode(input_text)
    token_tensor = torch.tensor([token_ids])
    # Step 1: Tokenization details
    print("\n[STEP 1] TOKENIZATION")
    print(f"Input text: '{input_text}'")
    print(f"Tokens: {tokenizer.decode(token_ids)}")
    print(f"Token IDs: {token_ids}")
    print(f"Tensor shape: {token_tensor.shape}")
    # Step 2: Token embeddings
    print("\n[STEP 2] TOKEN EMBEDDINGS")
    token_emb = model.token_embedding(token_tensor)
    print(f"Token embedding shape: {token_emb.shape}")
    print(f"Token embedding for 'quick' (first 10 dims):")
    # Fix: the original line was a bare `{token_emb[0, 1, :10]}` set-literal
    # expression (a no-op); a print was clearly intended here.
    print(f" {token_emb[0, 1, :10]}")
    # Step 3: Positional encoding
    print("\n[STEP 3] POSITIONAL ENCODING")
    pos_enc = model.positional_encoding.pe[:, :token_emb.size(1), :10]
    print(f"Positional encoding shape: {model.positional_encoding.pe.shape}")
    print(f"Positional encoding for positions 0-2 (first 10 dims):")
    print(f" Position 0: {pos_enc[0, 0]}")
    print(f" Position 1: {pos_enc[0, 1]}")
    # Step 4: Combined embeddings
    print("\n[STEP 4] TOKEN + POSITIONAL EMBEDDINGS")
    # Bug fix: PositionalEncoding.forward already returns x + pe, so the
    # original `token_emb + model.positional_encoding(token_emb)` added the
    # token embeddings twice.
    combined_emb = model.positional_encoding(token_emb)
    print(f"Combined embedding shape: {combined_emb.shape}")
    # Step 5-8: Transformer Blocks
    print("\n[STEP 5-8] TRANSFORMER BLOCKS (Attention + FFN)")
    logits = model.forward(token_tensor)
    print(f"Logits shape: {logits.shape}")
    # Step 10: Language model head
    print("\n[STEP 10] LANGUAGE MODEL HEAD")
    last_logits = logits[:, -1, :]
    print(f"Logits for last token: {last_logits.shape}")
    print(f"Logits for top 5 vocabulary items:")
    top_5_logits, top_5_indices = torch.topk(last_logits, 5, dim=-1)
    for i, (idx, logit) in enumerate(zip(top_5_indices[0], top_5_logits[0])):
        print(f" {i+1}. {vocab[idx.item()]:15} logit: {logit.item():.4f}")
    # Step 11: Softmax converts logits to a probability distribution.
    print("\n[STEP 11] SOFTMAX")
    probabilities = F.softmax(last_logits, dim=-1)
    top_5_probs, top_5_indices = torch.topk(probabilities, 5, dim=-1)
    print(f"Top 5 probabilities:")
    for i, (idx, prob) in enumerate(zip(top_5_indices[0], top_5_probs[0])):
        print(f" {i+1}. {vocab[idx.item()]:15} probability: {prob.item():.4%}")
    print("\n" + "=" * 70)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment