Created
January 3, 2026 05:50
-
-
Save amitpuri/f06b192295c4caffd635ed472d5108e4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| import os | |
# ============================================
# 1. TOKENIZATION (using simple mapping)
# ============================================
class SimpleTokenizer:
    """Word-level tokenizer backed by a fixed vocabulary list."""

    def __init__(self, vocab):
        # Forward and reverse lookup tables between words and integer ids.
        self.word2id = {}
        for position, word in enumerate(vocab):
            self.word2id[word] = position
        self.id2word = {i: w for w, i in self.word2id.items()}
        self.vocab_size = len(vocab)

    def encode(self, text):
        """Convert text to token IDs; unknown words map to <unk> (or 0)."""
        fallback = self.word2id.get("<unk>", 0)
        ids = []
        for token in text.lower().split():
            ids.append(self.word2id.get(token, fallback))
        return ids

    def decode(self, token_ids):
        """Convert token IDs back to a list of words."""
        words = []
        for idx in token_ids:
            words.append(self.id2word[idx])
        return words
class CharTokenizer:
    """Character-level tokenizer built from raw text or an explicit char list."""

    def __init__(self, text=None, chars=None):
        # An explicit character list wins; otherwise derive one from the text.
        if chars is not None:
            self.chars = chars
        elif text is not None:
            self.chars = sorted(set(text))
        else:
            raise ValueError("Must provide either text or chars")
        self.vocab_size = len(self.chars)
        self.char2id = {}
        self.id2char = {}
        for i, ch in enumerate(self.chars):
            self.char2id[ch] = i
            self.id2char[i] = ch

    def encode(self, text):
        """Map each character to its ID; unknown characters become 0."""
        lookup = self.char2id
        return [lookup.get(c, 0) for c in text]

    def decode(self, ids):
        """Map IDs back to a string; unknown IDs render as '?'."""
        if torch.is_tensor(ids):
            ids = ids.tolist()
        pieces = []
        for i in ids:
            pieces.append(self.id2char.get(i, "?"))
        return "".join(pieces)
# ============================================
# 2. POSITIONAL ENCODING
# ============================================
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (Vaswani et al., 2017).

    Adds a non-learned position signal to token embeddings so the model
    can distinguish token order.
    """

    def __init__(self, d_model, max_seq_len=5000):
        """
        Args:
            d_model: embedding dimension (odd values are supported)
            max_seq_len: maximum sequence length
        """
        super().__init__()
        # Precompute the [max_seq_len, d_model] encoding table once.
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        # Geometric progression of inverse wavelengths, one per sin/cos pair.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        # sin for even indices, cos for odd indices.
        pe[:, 0::2] = torch.sin(position * div_term)
        # Slice div_term so an odd d_model does not raise a shape mismatch
        # (the cos half then has one fewer column than the sin half).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Register as buffer (not a parameter, but part of the model state).
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """
        Args:
            x: embeddings of shape [batch_size, seq_len, d_model]
        Returns:
            x with positional encoding added for its first seq_len positions
        """
        return x + self.pe[:, :x.size(1)]
# ============================================
# 3. LAYER NORMALIZATION
# ============================================
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learned scale/shift."""

    def __init__(self, d_model, eps=1e-6):
        """
        Args:
            d_model: dimension size
            eps: small constant for numerical stability
        """
        super().__init__()
        # gamma scales and beta shifts; initialized to the identity transform.
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        """
        Args:
            x: input of shape [batch_size, seq_len, d_model]
        Returns:
            normalized output of the same shape
        """
        mu = x.mean(dim=-1, keepdim=True)
        # Population (biased) variance, the standard choice for LayerNorm.
        sigma_sq = x.var(dim=-1, keepdim=True, unbiased=False)
        normalized = (x - mu) / torch.sqrt(sigma_sq + self.eps)
        return normalized * self.gamma + self.beta
# ============================================
# 4. MULTI-HEAD SELF-ATTENTION
# ============================================
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention."""

    def __init__(self, d_model, num_heads):
        """
        Args:
            d_model: embedding dimension
            num_heads: number of attention heads (must divide d_model)
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head dimension
        # Projections for queries, keys, values, and the merged output.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """
        Args:
            x: input of shape [batch_size, seq_len, d_model]
            mask: optional attention mask; positions where mask == 0 are blocked
        Returns:
            attention output of shape [batch_size, seq_len, d_model]
        """
        bsz, n, _ = x.shape

        def split_heads(t):
            # [bsz, n, d_model] -> [bsz, num_heads, n, d_k]
            return t.view(bsz, n, self.num_heads, self.d_k).transpose(1, 2)

        q = split_heads(self.W_q(x))
        k = split_heads(self.W_k(x))
        v = split_heads(self.W_v(x))

        # Scaled dot-product scores: [bsz, num_heads, n, n].
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
        if mask is not None:
            # -inf scores vanish after softmax, blocking masked positions.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        # Weighted sum of values, then merge the heads back together.
        merged = (weights @ v).transpose(1, 2).contiguous().view(bsz, n, self.d_model)
        return self.W_o(merged)
# ============================================
# 5. FEED-FORWARD NETWORK
# ============================================
class FeedForward(nn.Module):
    """Position-wise two-layer MLP with ReLU activation and dropout."""

    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        """
        Args:
            d_model: embedding dimension
            d_ff: hidden dimension (typically 4x d_model)
            dropout: dropout rate applied to the hidden activations
        """
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: input of shape [batch_size, seq_len, d_model]
        Returns:
            output of the same shape
        """
        hidden = F.relu(self.fc1(x))
        hidden = self.dropout(hidden)
        return self.fc2(hidden)
# ============================================
# 6. TRANSFORMER BLOCK
# ============================================
class TransformerBlock(nn.Module):
    """Post-norm transformer layer: self-attention then FFN, each wrapped
    in a residual connection followed by layer normalization."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff)
        self.norm2 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: input of shape [batch_size, seq_len, d_model]
            mask: optional attention mask
        Returns:
            output of the same shape
        """
        # Sub-layer 1: attention -> dropout -> add residual -> norm.
        residual = x
        x = self.norm1(residual + self.dropout(self.attention(x, mask)))
        # Sub-layer 2: feed-forward -> dropout -> add residual -> norm.
        residual = x
        x = self.norm2(residual + self.dropout(self.ffn(x)))
        return x
# ============================================
# 7. SAMPLING METHODS
# ============================================
def top_k_sampling(logits, k=10, temperature=1.0):
    """
    Top-k sampling with temperature.

    Args:
        logits: tensor of shape [batch_size, vocab_size]
        k: number of highest-probability tokens to keep
           (clamped to the vocabulary size so a large k cannot raise)
        temperature: softmax temperature; lower values are greedier
    Returns:
        (next_token, probs): sampled token ids of shape [batch_size, 1] and
        the renormalized distribution that was sampled from
    """
    # Temperature scaling; the floor avoids division by zero.
    logits = logits / max(temperature, 1e-5)
    # Clamp k: torch.topk raises when k exceeds the last dimension.
    k = min(k, logits.size(-1))
    topk_logits, topk_indices = torch.topk(logits, k, dim=-1)
    # Everything outside the top k gets -inf, i.e. zero probability.
    # scatter_ along dim=-1 (not a hard-coded dim 1) keeps this correct
    # for any number of leading batch dimensions.
    masked = torch.full_like(logits, float('-inf'))
    masked.scatter_(-1, topk_indices, topk_logits)
    probs = F.softmax(masked, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token, probs
def top_p_sampling(logits, p=0.9, temperature=1.0):
    """
    Top-p (nucleus) sampling with temperature.

    Keeps the smallest set of tokens whose cumulative probability exceeds p
    and samples from the renormalized distribution over that set.

    Args:
        logits: tensor of shape [batch_size, vocab_size]
        p: cumulative probability threshold in (0, 1]
        temperature: softmax temperature; lower values are greedier
    Returns:
        (next_token, probs): sampled token ids of shape [batch_size, 1] and
        the renormalized distribution that was sampled from
    """
    # Temperature scaling; the floor avoids division by zero.
    logits = logits / max(temperature, 1e-5)
    # Sort descending so the nucleus is a prefix of each row.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    sorted_probs = F.softmax(sorted_logits, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    # Mark tokens past the nucleus; shift right one slot so the token that
    # first crosses the threshold is kept (and the top-1 token always is).
    cutoff_mask = cumulative_probs > p
    cutoff_mask[..., 1:] = cutoff_mask[..., :-1].clone()
    cutoff_mask[..., 0] = False
    sorted_logits[cutoff_mask] = float('-inf')
    # Scatter the filtered logits back to their original positions.
    # dim=-1 (not a hard-coded dim 1) keeps this correct for any number
    # of leading batch dimensions.
    unsorted_logits = torch.full_like(logits, float('-inf'))
    unsorted_logits.scatter_(-1, sorted_indices, sorted_logits)
    probs = F.softmax(unsorted_logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token, probs
# ============================================
# 8. COMPLETE TRANSFORMER MODEL
# ============================================
class SimpleTransformer(nn.Module):
    """Decoder-only transformer language model over integer token ids."""

    def __init__(self, vocab_size, d_model=256, num_heads=4,
                 num_layers=2, d_ff=1024, max_seq_len=5000, dropout=0.1):
        super().__init__()
        # Stored so checkpoints can record the architecture.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.d_ff = d_ff
        # Token embeddings plus fixed sinusoidal positions.
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len)
        self.dropout = nn.Dropout(dropout)
        # Stack of identical transformer layers.
        self.transformer_blocks = nn.ModuleList(
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        )
        # Projects hidden states to per-token vocabulary logits.
        self.lm_head = nn.Linear(d_model, vocab_size)

    def generate_causal_mask(self, seq_len):
        """
        Lower-triangular [seq_len, seq_len] mask for causal attention.
        1 = look at this token, 0 = ignore this token.
        """
        return torch.tril(torch.ones(seq_len, seq_len))

    def forward(self, token_ids):
        """
        Args:
            token_ids: tensor of shape [batch_size, seq_len]
        Returns:
            logits of shape [batch_size, seq_len, vocab_size]
        """
        seq_len = token_ids.size(1)
        # Embed tokens, add positions, then regularize with dropout.
        hidden = self.dropout(self.positional_encoding(self.token_embedding(token_ids)))
        # Causal mask keeps each position from attending to the future.
        causal_mask = self.generate_causal_mask(seq_len).to(token_ids.device)
        for layer in self.transformer_blocks:
            hidden = layer(hidden, mask=causal_mask)
        return self.lm_head(hidden)

    def predict_next_token(self, token_ids):
        """
        Predict the next token greedily (argmax over the distribution).

        Args:
            token_ids: tensor of shape [batch_size, seq_len]
        Returns:
            next_token_ids: most likely next token per batch row
            probabilities: softmax distribution over the vocabulary
        """
        with torch.no_grad():
            # Only the final position's logits matter for the next token.
            last_logits = self.forward(token_ids)[:, -1, :]
            probabilities = F.softmax(last_logits, dim=-1)
            next_token_ids = torch.argmax(probabilities, dim=-1)
        return next_token_ids, probabilities

    def predict_next_token_greedy(self, token_ids):
        """Greedy decoding: alias for predict_next_token."""
        return self.predict_next_token(token_ids)

    def predict_next_token_topk(self, token_ids, k=10, temperature=1.0):
        """Sample the next token with top-k sampling and temperature."""
        with torch.no_grad():
            last_logits = self.forward(token_ids)[:, -1, :]
            sampled, probs = top_k_sampling(last_logits, k=k, temperature=temperature)
        return sampled.squeeze(-1), probs

    def predict_next_token_topp(self, token_ids, p=0.9, temperature=1.0):
        """Sample the next token with top-p (nucleus) sampling and temperature."""
        with torch.no_grad():
            last_logits = self.forward(token_ids)[:, -1, :]
            sampled, probs = top_p_sampling(last_logits, p=p, temperature=temperature)
        return sampled.squeeze(-1), probs
# ============================================
# 9. PERSISTENCE
# ============================================
def save_checkpoint(model, tokenizer, path="model.pth"):
    """Serialize model weights, tokenizer vocabulary, and architecture config.

    Args:
        model: a SimpleTransformer-like module exposing d_model, num_heads,
               num_layers, d_ff, and vocab_size attributes
        tokenizer: tokenizer exposing a .chars vocabulary list
        path: destination file
    """
    # Record the architecture alongside the weights so load_checkpoint
    # can rebuild an identical model.
    config = {
        'd_model': model.d_model,
        'num_heads': model.num_heads,
        'num_layers': model.num_layers,
        'd_ff': model.d_ff,
        'vocab_size': model.vocab_size
    }
    payload = {
        'model_state': model.state_dict(),
        'vocab': tokenizer.chars,
        'config': config
    }
    torch.save(payload, path)
    print(f"Model saved to {path}")
def load_checkpoint(path="model.pth"):
    """Load a model and tokenizer saved by save_checkpoint.

    Args:
        path: checkpoint file produced by save_checkpoint
    Returns:
        (model, tokenizer) with the model in eval mode,
        or (None, None) if no checkpoint exists at path.
    """
    if not os.path.exists(path):
        return None, None
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load(path, map_location="cpu")
    tokenizer = CharTokenizer(chars=checkpoint['vocab'])
    config = checkpoint['config']
    model = SimpleTransformer(
        vocab_size=config['vocab_size'],
        d_model=config['d_model'],
        num_heads=config['num_heads'],
        num_layers=config['num_layers'],
        d_ff=config.get('d_ff', 1024)  # default for older checkpoints without d_ff
    )
    model.load_state_dict(checkpoint['model_state'])
    model.eval()  # inference mode: disables dropout
    print(f"Model loaded from {path}")
    return model, tokenizer
# ============================================
# 10. TRAINING UTILITY
# ============================================
def train_model(model, tokenizer, text, epochs=1000, batch_size=4, seq_len=64):
    """Train a language model on random contiguous crops of the encoded text.

    Args:
        model: module mapping [B, S] token ids to [B, S, vocab_size] logits,
               exposing a vocab_size attribute
        tokenizer: object with .encode(text) -> list[int]
        text: training corpus
        epochs: number of optimization steps (one batch per step)
        batch_size: number of sequences per batch
        seq_len: crop length (shortened automatically for tiny corpora)
    Raises:
        ValueError: if the encoded text is too short to form one input/target pair.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)  # Lower LR for stability
    criterion = nn.CrossEntropyLoss()
    ids = tokenizer.encode(text)
    # Guard: with fewer than 2 tokens there is no (input, next-token) pair,
    # and the original code would build empty/invalid batches.
    if len(ids) < 2:
        raise ValueError("Training text must encode to at least 2 tokens")
    model.train()
    for epoch in range(epochs):
        batch_inputs = []
        batch_targets = []
        for _ in range(batch_size):
            if len(ids) <= seq_len + 1:
                # Corpus shorter than the crop: use the whole thing.
                start_idx = 0
                curr_seq_len = len(ids) - 1
            else:
                start_idx = torch.randint(0, len(ids) - seq_len - 1, (1,)).item()
                curr_seq_len = seq_len
            # Next-token objective: target is the input shifted by one.
            batch_inputs.append(ids[start_idx : start_idx + curr_seq_len])
            batch_targets.append(ids[start_idx + 1 : start_idx + curr_seq_len + 1])
        # All crops in a batch share the same length, so stacking is safe.
        inputs = torch.tensor(batch_inputs, dtype=torch.long)
        targets = torch.tensor(batch_targets, dtype=torch.long)
        optimizer.zero_grad()
        logits = model(inputs)
        # Flatten to [B*S, V] vs [B*S] for cross-entropy.
        loss = criterion(logits.view(-1, model.vocab_size), targets.view(-1))
        loss.backward()
        optimizer.step()
        if epoch % 100 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
    model.eval()
# ============================================
# 11. DATASET UTILITY
# ============================================
def load_tiny_stories_subset(percent):
    """Stream a percentage of the TinyStories dataset into one training string.

    Args:
        percent: fraction of the train split to fetch, in percent (e.g. 0.001)
    Returns:
        all fetched stories joined with blank lines
    """
    from datasets import load_dataset
    # Hard-coded size of the TinyStories train split, used to turn the
    # percentage into a story count (at least one story is always fetched).
    TOTAL_STORIES = 2141709
    count = max(1, int(TOTAL_STORIES * (percent / 100.0)))
    print(f"Loading {percent}% of TinyStories ({count} stories)...")
    # Streaming avoids downloading the full dataset up front.
    dataset = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
    stories = []
    for entry in dataset:
        stories.append(entry['text'])
        if len(stories) >= count:
            break
    training_text = "\n\n".join(stories)
    print(f"Fetched {len(training_text)} characters.")
    return training_text
# ============================================
# 12. DEMONSTRATION
# ============================================
if __name__ == "__main__":
    # Checkpoint location shared by all menu actions.
    MODEL_PATH = "model.pth"
    # Interactive loop: train, retrain, generate, or exit.
    while True:
        print("\n" + "=" * 50)
        print("TRANSFORMER INTERACTIVE MENU")
        print("=" * 50)
        print("1. Train Fresh Model")
        print("2. Load & Retrain Existing Model")
        print("3. Generate Text from Model")
        print("4. Exit")
        choice = input("Enter choice (1-4): ")
        if choice == '1':
            # Option 1: train a fresh model on a TinyStories subset.
            percent_str = input("Enter percentage of dataset to train on (default: 0.001): ") or "0.001"
            percent = float(percent_str)
            training_text = load_tiny_stories_subset(percent)
            # Character-level vocabulary derived from the fetched text.
            tokenizer = CharTokenizer(training_text)
            # Scaled up model
            model = SimpleTransformer(
                vocab_size=tokenizer.vocab_size,
                d_model=256,
                num_heads=8,
                num_layers=4,
                d_ff=1024
            )
            print("\nStarting training (Scaled Up Architecture)...")
            train_model(model, tokenizer, training_text, epochs=2000, seq_len=64)
            save_checkpoint(model, tokenizer, MODEL_PATH)
        elif choice == '2':
            # Option 2: continue training a saved model on fresh data.
            model, tokenizer = load_checkpoint(MODEL_PATH)
            if model is None:
                print("No local model found. Please train a fresh one first.")
                continue
            percent_str = input("Enter percentage of new data to retrain on (default: 0.001): ") or "0.001"
            percent = float(percent_str)
            training_text = load_tiny_stories_subset(percent)
            print("\nStarting retraining...")
            # Fewer epochs than a fresh run: the model is already trained.
            train_model(model, tokenizer, training_text, epochs=500, seq_len=64)
            save_checkpoint(model, tokenizer, MODEL_PATH)
        elif choice == '3':
            # Option 3: generate text with nucleus (top-p) sampling.
            model, tokenizer = load_checkpoint(MODEL_PATH)
            if model is None:
                print("No local model found. Please train a fresh one first.")
                continue
            print("\nSelect a prompt:")
            prompts = [
                "Once upon a time",
                "a quick brown fox",
                "a quick brown fox jumps",
                "a quick brown ",
                "a quick brown fox jumps over",
                "a quick brown fox jumps over the lazy"
            ]
            for i, p in enumerate(prompts):
                print(f" {i+1}. {p}")
            print(f" {len(prompts)+1}. Custom input")
            p_choice = input(f"Enter choice (1-{len(prompts)+1}, default: 1): ") or "1"
            temp_str = input("Enter temperature (0.1-2.0, default 0.8): ") or "0.8"
            temperature = float(temp_str)
            # Resolve the prompt choice; any invalid input falls back to prompt 1.
            if p_choice.isdigit() and 1 <= int(p_choice) <= len(prompts):
                prompt = prompts[int(p_choice) - 1]
            elif p_choice == str(len(prompts) + 1):
                prompt = input("Enter custom prompt: ")
            else:
                prompt = prompts[0]
            input_text = prompt
            generated = prompt
            print(f"\nGenerating (temp={temperature})...")
            # Autoregressive loop: emit 100 characters, feeding back at most
            # the 64 most recent characters as context each step.
            for _ in range(100):
                token_ids = tokenizer.encode(input_text)
                token_tensor = torch.tensor([token_ids])
                next_token_id, _ = model.predict_next_token_topp(token_tensor, p=0.9, temperature=temperature)
                next_char = tokenizer.id2char.get(next_token_id.item(), " ")
                generated += next_char
                input_text = generated[-64:]
            print(f"\nRESULT:\n{generated}")
        elif choice == '4':
            print("Exiting...")
            break
        else:
            print("Invalid choice.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment