Created
January 8, 2026 14:06
-
-
Save amitpuri/218233bd9272d6f8de3cf171341b7a45 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Transformer Architecture v7 | |
| ================================================================== | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| import os | |
| import signal | |
| import sys | |
| import time | |
| import glob | |
| import tiktoken | |
| from collections import Counter | |
| from typing import Tuple, List, Dict | |
# Module-level flag that cooperating loops poll for graceful shutdown.
_interrupted = False


def signal_handler(signum, frame):
    """SIGINT handler: raise the shutdown flag and arm a hard exit.

    The first Ctrl+C only sets the flag so long-running loops can wind down
    cleanly; the handler then restores the default SIGINT behaviour so a
    second Ctrl+C terminates the process immediately.
    """
    global _interrupted
    _interrupted = True
    print("\n\n[!] Interrupt received. Finishing current operation gracefully...")
    print(" Press Ctrl+C again to force exit.")
    # The next SIGINT falls through to the default (terminate) behaviour.
    signal.signal(signal.SIGINT, signal.SIG_DFL)


def reset_interrupt():
    """Lower the shutdown flag so a new operation starts clean."""
    global _interrupted
    _interrupted = False


def is_interrupted():
    """Report whether a graceful shutdown has been requested."""
    return _interrupted


# Route Ctrl+C (SIGINT) through the graceful handler above.
signal.signal(signal.SIGINT, signal_handler)
| import array | |
| # ============================================ | |
| # 0. HELPER FUNCTIONS | |
| # ============================================ | |
def safe_input(prompt, default_value):
    """Prompt the user, falling back to *default_value* on blank or EOF input."""
    try:
        raw = input(prompt)
    except EOFError:
        # Non-interactive stdin (e.g. piped input exhausted) -> use the default.
        return default_value
    # Blank / whitespace-only answers also mean "take the default".
    return default_value if not raw.strip() else raw
def batch_encode(tokenizer, story_generator, total_target, batch_size=1000):
    """Tokenize a stream of stories into one flat token sequence.

    Progress is reported every *batch_size* stories.  Tokens accumulate in an
    ``array.array('H')`` (unsigned 16-bit) rather than a list: the GPT-2
    vocabulary (50257 ids) fits in 0..65535, so each token costs only 2 bytes.
    An end-of-text token follows every non-empty story so the model can learn
    where stories end.  Encoding stops early if the user interrupts.
    """
    print(f" Encoding approx {total_target} stories in batches of {batch_size}...")
    encoded = array.array('H')
    eos = tokenizer.eos_token_id  # hoisted: constant for the whole stream
    for count, story in enumerate(story_generator, start=1):
        if is_interrupted():
            print("\n [!] Interrupted during encoding.")
            break
        if story.strip():
            encoded.extend(tokenizer.encode(story))
            # Append EOS token so the model learns to stop
            encoded.append(eos)
        if count % batch_size == 0:
            print(f" Processed {count}/{total_target} stories...", end='\r')
    print(f"\n Finished encoding. Total tokens: {len(encoded)}")
    return encoded
| # ============================================ | |
| # 1. ENHANCED TOKENIZATION (Tiktoken) | |
| # ============================================ | |
class TiktokenTokenizer:
    """Subword tokenizer backed by tiktoken (GPT-2 BPE by default)."""

    def __init__(self, encoding_name="gpt2"):
        self.encoding_name = encoding_name
        self.encoder = tiktoken.get_encoding(encoding_name)
        self.vocab_size = self.encoder.n_vocab
        # tiktoken's gpt2 encoding reserves <|endoftext|> (its eot_token).
        self.eos_token_id = self.encoder.eot_token
        # Minimal special-token map kept for compatibility with the v6
        # word-level tokenizer interface.
        self.SPECIAL_TOKENS = {
            '<pad>': 0,  # gpt2 doesn't have a pad token by default, but 0 is usually fine or we could add one
            '<eos>': self.eos_token_id,
        }

    def encode(self, text):
        """Text -> list of token ids; the literal <|endoftext|> marker is permitted."""
        return self.encoder.encode(text, allowed_special={'<|endoftext|>'})

    def decode(self, ids):
        """Token ids (tensor, list, or array.array) -> text."""
        return self.encoder.decode(ids.tolist() if torch.is_tensor(ids) else ids)
| # ============================================ | |
| # 2. TRANSFORMER COMPONENTS | |
| # ============================================ | |
class PositionalEncoding(nn.Module):
    """Classic fixed sinusoidal positional encoding (sin on even dims, cos on odd)."""

    def __init__(self, d_model, max_seq_len=5000, dropout=0.1):
        super().__init__()
        positions = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, d_model, 2).float() *
                          -(math.log(10000.0) / d_model))
        table = torch.zeros(max_seq_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)
        # Registered as a buffer: follows .to(device)/state_dict but is not trained.
        self.register_buffer('pe', table.unsqueeze(0))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """x: [batch_size, seq_len, d_model] -> same shape, positions added, dropout applied."""
        return self.dropout(x + self.pe[:, :x.size(1)])
class LayerNorm(nn.Module):
    """Layer normalization with learnable scale (gamma) and shift (beta)."""

    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps  # guards the sqrt against zero variance

    def forward(self, x):
        mu = x.mean(dim=-1, keepdim=True)
        # Biased (population) variance, matching the usual LayerNorm definition.
        sigma_sq = x.var(dim=-1, keepdim=True, unbiased=False)
        normed = (x - mu) / torch.sqrt(sigma_sq + self.eps)
        return self.gamma * normed + self.beta
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a single input sequence."""

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head feature width
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t, batch, length):
        # [B, T, d_model] -> [B, num_heads, T, d_k]
        return t.view(batch, length, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None):
        batch, length, _ = x.shape
        q = self._split_heads(self.W_q(x), batch, length)
        k = self._split_heads(self.W_k(x), batch, length)
        v = self._split_heads(self.W_v(x), batch, length)
        # Scaled dot-product scores: [B, num_heads, T, T]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Positions where mask == 0 may not be attended to.
            scores = scores.masked_fill(mask.to(scores.device) == 0, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))
        merged = torch.matmul(weights, v)
        merged = merged.transpose(1, 2).contiguous().view(batch, length, self.d_model)
        return self.W_o(merged)
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        hidden = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(hidden)
class TransformerBlockV7(nn.Module):
    """Pre-LN transformer block: LayerNorm precedes each sublayer, and each
    sublayer's (dropped-out) output is added back through a residual path."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.norm1 = LayerNorm(d_model)
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm2 = LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Attention sublayer with residual connection.
        attn_out = self.attention(self.norm1(x), mask)
        x = x + self.dropout(attn_out)
        # Feed-forward sublayer with residual connection.
        ffn_out = self.ffn(self.norm2(x))
        return x + self.dropout(ffn_out)
class TransformerV7(nn.Module):
    """Decoder-only transformer language model (pre-LN blocks + final norm)."""

    def __init__(self, vocab_size, d_model=256, num_heads=8, num_layers=6,
                 d_ff=1024, max_seq_len=5000, dropout=0.1):
        super().__init__()
        # Hyper-parameters are kept on the instance so checkpoints can
        # record the configuration needed to rebuild the model.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.d_ff = d_ff
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        self.dropout = nn.Dropout(dropout)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlockV7(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.final_norm = LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)
        self._init_parameters()

    def _init_parameters(self):
        """Xavier-initialize every weight matrix (vectors/biases untouched)."""
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def generate_causal_mask(self, seq_len, device):
        """Lower-triangular [seq_len, seq_len] mask: 1 = may attend, 0 = blocked."""
        return torch.tril(torch.ones(seq_len, seq_len, device=device))

    def forward(self, token_ids):
        """token_ids: [batch, seq_len] ints -> logits [batch, seq_len, vocab_size]."""
        _, seq_len = token_ids.shape
        hidden = self.dropout(self.positional_encoding(self.token_embedding(token_ids)))
        mask = self.generate_causal_mask(seq_len, token_ids.device)
        for block in self.transformer_blocks:
            hidden = block(hidden, mask=mask)
        return self.lm_head(self.final_norm(hidden))

    def get_num_params(self):
        """Total parameter count (trainable or not)."""
        return sum(p.numel() for p in self.parameters())
| # ============================================ | |
| # 3. DATA LOADING & TRAINING | |
| # ============================================ | |
def get_tiny_stories_generator(percent=1.0):
    """Return (story_iterator, target_count) over a slice of TinyStories.

    The dataset is streamed from the Hugging Face hub, so nothing is
    materialized up front; the iterator yields raw story strings.
    *percent* is a percentage (of the ~2.1M-story train split) to take.
    """
    from datasets import load_dataset  # third-party; imported lazily on purpose
    print(f"Loading {percent}% of TinyStories dataset (Streaming mode)...")
    dataset = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
    # Approximate size of the full train split, used to size the slice.
    total_count = 2119719
    target_count = max(1, int(total_count * (percent / 100.0)))

    def take_stories():
        for index, record in enumerate(dataset):
            if index >= target_count:
                break
            yield record['text']

    return take_stories(), target_count
def train_epoch(model, tokenizer, data_ids, batch_size=32, seq_len=64, optimizer=None, criterion=None):
    """Run one training epoch over a flat token sequence.

    Contiguous windows of ``seq_len + 1`` tokens are sliced from *data_ids*
    (a list or array.array); inputs are the first seq_len tokens and targets
    the same window shifted by one.  Gradients are clipped to norm 1.0.
    Stops early on user interrupt.  *optimizer* and *criterion* must be
    provided despite their None defaults (interface kept for compatibility).

    Returns the mean loss over the batches actually processed (0.0 if none).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    last_loss = None
    n_tokens = len(data_ids)
    # Hoisted out of the loop: the device is constant for the whole epoch.
    device = model.token_embedding.weight.device
    for i in range(0, n_tokens - batch_size * (seq_len + 1), batch_size * seq_len):
        if is_interrupted():
            break
        inputs_list = []
        targets_list = []
        for b in range(batch_size):
            start_idx = i + b * seq_len
            if start_idx + seq_len + 1 >= n_tokens:
                break
            chunk = data_ids[start_idx : start_idx + seq_len + 1]
            inputs_list.append(list(chunk[:-1]))   # window tokens
            targets_list.append(list(chunk[1:]))   # same window shifted by one
        if not inputs_list:
            break
        inputs = torch.tensor(inputs_list).to(device)
        targets = torch.tensor(targets_list).to(device)
        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits.view(-1, model.vocab_size), targets.view(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        last_loss = loss.item()
        total_loss += last_loss
        num_batches += 1
        if num_batches % 10 == 0:
            print(f" Batch {num_batches}, Loss: {last_loss:.4f}", end='\r')
    # BUG FIX: the original printed `loss` unconditionally after the loop,
    # raising NameError when no batch ran (tiny dataset or immediate interrupt).
    if last_loss is not None:
        print(f" Batch {num_batches}, Final Loss: {last_loss:.4f}")
    return total_loss / max(1, num_batches)
def generate_text(model, tokenizer, prompt="Once upon a time", max_len=100, temperature=0.8):
    """Autoregressively sample up to *max_len* tokens after *prompt*.

    Sampling is multinomial over temperature-scaled logits.  Generation
    stops early on EOS, on user interrupt, or when the model's positional
    encoding table (its context window) is exhausted.
    """
    model.eval()
    device = model.token_embedding.weight.device
    input_ids = torch.tensor([tokenizer.encode(prompt)]).to(device)
    print(f"Generating (Prompt: '{prompt}')...")
    start_len = input_ids.shape[1]
    context_limit = model.positional_encoding.pe.size(1)
    if start_len > context_limit:
        # Keep only the most recent tokens that fit the model.
        print(f" [!] Prompt length ({start_len}) exceeds model limit ({context_limit}). Truncating.")
        input_ids = input_ids[:, -context_limit:]
    for _ in range(max_len):
        if is_interrupted():
            break
        if input_ids.size(1) >= context_limit:
            print(f"\n [!] Context window reached ({context_limit}). Stopping generation.")
            break
        with torch.no_grad():
            logits = model(input_ids)
            # Temperature-scale the final position's logits (floor avoids /0).
            scaled = logits[:, -1, :] / max(temperature, 1e-5)
            next_token = torch.multinomial(F.softmax(scaled, dim=-1), num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
        if next_token.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(input_ids[0])
| # ============================================ | |
| # 4. CHECKPOINTING | |
| # ============================================ | |
def save_checkpoint(model, tokenizer, path):
    """Serialize the model weights plus the hyper-parameters needed to rebuild it."""
    payload = {
        'model_state': model.state_dict(),
        'encoding_name': tokenizer.encoding_name,
        # Enough configuration for load_checkpoint to reconstruct the model.
        'config': {
            'd_model': model.d_model,
            'num_heads': model.num_heads,
            'num_layers': model.num_layers,
            'd_ff': model.d_ff,
            'vocab_size': model.vocab_size,
        },
    }
    torch.save(payload, path)
    print(f"Model and configuration saved to {path}")
def load_checkpoint(path):
    """Rebuild a TransformerV7 and its tokenizer from a checkpoint file.

    Returns (model, tokenizer), or (None, None) if *path* does not exist.
    The model comes back in eval mode on the CPU; callers are responsible
    for moving it to their device (the menu code does model.to(device)).
    """
    if not os.path.exists(path):
        return None, None
    print(f"Loading checkpoint from {path}...")
    # BUG FIX: map_location='cpu' so checkpoints saved on a GPU machine still
    # load on CPU-only hosts; the caller re-dispatches to its own device.
    checkpoint = torch.load(path, map_location='cpu')
    encoding_name = checkpoint.get('encoding_name', 'gpt2')
    config = checkpoint['config']
    tokenizer = TiktokenTokenizer(encoding_name=encoding_name)
    model = TransformerV7(
        vocab_size=config['vocab_size'],
        d_model=config['d_model'],
        num_heads=config['num_heads'],
        num_layers=config['num_layers'],
        d_ff=config.get('d_ff', 1024),  # default for pre-d_ff checkpoints
    )
    model.load_state_dict(checkpoint['model_state'])
    model.eval()
    return model, tokenizer
def select_model_file(default_path=None):
    """Interactive picker over cached v7 checkpoints, newest first.

    Returns the chosen file path, or *default_path* when no checkpoints
    exist, the user cancels, or the input is not a valid index.
    """
    os.makedirs("models_cache", exist_ok=True)
    candidates = sorted(glob.glob("models_cache/model_v7*.pth"),
                        key=os.path.getmtime, reverse=True)
    if not candidates:
        print("No existing v7 model files found.")
        return default_path
    print("\nAvailable Models:")
    for pos, name in enumerate(candidates, start=1):
        print(f"{pos}. {name} ({os.path.getsize(name) / (1024 * 1024):.2f} MB)")
    print(f"{len(candidates)+1}. Cancel / Use New Name ({default_path})")
    answer = safe_input(f"Select model (1-{len(candidates)+1}, default 1): ", "1")
    try:
        picked = int(answer) - 1
    except ValueError:
        # Non-numeric answer falls back to the default path.
        return default_path
    return candidates[picked] if 0 <= picked < len(candidates) else default_path
| # ============================================ | |
| # 5. MAIN EXECUTION MENU | |
| # ============================================ | |
def main():
    """Interactive CLI: train, resume, or sample from a TransformerV7 model.

    Presents a menu loop; each iteration clears the interrupt flag so a
    previous Ctrl+C does not leak into the next action.  Checkpoints go to
    models_cache/ and encoded token streams are cached in datasets_cache/,
    keyed by the dataset percentage.
    """
    print("Transformer Architecture v7")
    print("=============================================")
    # Timestamped default checkpoint path so new runs never overwrite old ones.
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    os.makedirs("models_cache", exist_ok=True)
    MODEL_PATH = f"models_cache/model_v7_{timestamp}.pth"
    model = None
    tokenizer = None
    while True:
        # Clear any interrupt left over from the previous menu action.
        reset_interrupt()
        print("\n\nMAIN MENU")
        print("1. Train New Model (TinyStories)")
        print("2. Load Model & Resume Training")
        print("3. Generate Text")
        print("4. Exit")
        choice = safe_input("\nEnter choice (1-4): ", "4")
        if choice == '1':
            # --- Option 1: train a fresh model from scratch ---
            percent_str = safe_input("Enter percentage of TinyStories to load (0.01-100, default 0.1): ", "0.1")
            try:
                percent = float(percent_str)
            # NOTE(review): bare except also swallows KeyboardInterrupt/SystemExit;
            # `except ValueError` would be safer here.
            except:
                percent = 0.1
            print("\n[1] Preparing Data Stream...")
            story_gen, target_count = get_tiny_stories_generator(percent=percent)
            print("\n[2] Initializing Tokenizer...")
            tokenizer = TiktokenTokenizer("gpt2")
            print(f"Tokenizer ready. Vocab size: {tokenizer.vocab_size}")
            print("\n[3] Initializing Model...")
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            print(f"Training on: {device}")
            # Using smaller default for subword to keep speed vs v6 word-level
            model = TransformerV7(
                vocab_size=tokenizer.vocab_size,
                d_model=256,
                num_heads=8,
                num_layers=4,
                d_ff=1024
            ).to(device)
            print(f"Params: {model.get_num_params():,}")
            # --- TOKEN CACHING LOGIC ---
            # Encoded token streams are cached on disk (2 bytes per token) so
            # re-runs with the same percentage skip the encoding pass.
            os.makedirs("datasets_cache", exist_ok=True)
            cache_file = f"datasets_cache/tinystories_p{percent}.bin"
            if os.path.exists(cache_file):
                print(f"\n[4] Loading Encoded Data from Cache ({cache_file})...")
                train_ids = array.array('H')
                with open(cache_file, 'rb') as f:
                    # // 2 because each cached token is a 2-byte unsigned short.
                    train_ids.fromfile(f, os.path.getsize(cache_file) // 2)
                print(f" Loaded {len(train_ids)} tokens.")
            else:
                print("\n[4] Encoding Data (Streaming)...")
                train_ids = batch_encode(tokenizer, story_gen, total_target=target_count, batch_size=1000)
                # Only cache complete encodings (skip if interrupted midway).
                if train_ids and not is_interrupted():
                    print(f" Saving tokens to cache: {cache_file}")
                    with open(cache_file, 'wb') as f:
                        train_ids.tofile(f)
            if not train_ids:
                print("No tokens found. Training cancelled.")
                continue
            print("\n[5] Training Loop (Press Ctrl+C to stop & save)")
            optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
            criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.SPECIAL_TOKENS['<pad>'])
            epochs = int(safe_input("Epochs (default 3): ", "3"))
            try:
                for epoch in range(1, epochs + 1):
                    if is_interrupted(): break
                    print(f"\n--- Epoch {epoch}/{epochs} ---")
                    avg_loss = train_epoch(model, tokenizer, train_ids, batch_size=32, seq_len=64,
                                           optimizer=optimizer, criterion=criterion)
                    print(f"Avg Loss: {avg_loss:.4f}")
                    # `epoch % 1 == 0` is always true: checkpoint every epoch.
                    if epoch % 1 == 0:
                        save_checkpoint(model, tokenizer, MODEL_PATH)
            except KeyboardInterrupt:
                # Defensive: a raw KeyboardInterrupt reaching here (e.g. if the
                # custom SIGINT handler was not active) still exits cleanly.
                print("\nInterrupted.")
            if is_interrupted():
                print("\n[!] Loop interrupted. Saving checkpoint...")
                save_checkpoint(model, tokenizer, MODEL_PATH)
        elif choice == '2':
            # --- Option 2: load an existing checkpoint and keep training ---
            path = select_model_file()
            if not path:
                continue
            model, tokenizer = load_checkpoint(path)
            if not model:
                print("Failed to load model.")
                continue
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            model = model.to(device)
            # Future checkpoints overwrite the file we resumed from.
            MODEL_PATH = path
            print("\nNote: To resume training, we need data.")
            percent_str = safe_input("Enter percentage of TinyStories to load for training (default 0.1): ", "0.1")
            percent = float(percent_str)
            # --- TOKEN CACHING LOGIC ---
            os.makedirs("datasets_cache", exist_ok=True)
            cache_file = f"datasets_cache/tinystories_p{percent}.bin"
            if os.path.exists(cache_file):
                print(f"Loading Encoded Data from Cache ({cache_file})...")
                train_ids = array.array('H')
                with open(cache_file, 'rb') as f:
                    train_ids.fromfile(f, os.path.getsize(cache_file) // 2)
            else:
                story_gen, target_count = get_tiny_stories_generator(percent=percent)
                print("Encoding data...")
                train_ids = batch_encode(tokenizer, story_gen, total_target=target_count, batch_size=1000)
                if train_ids and not is_interrupted():
                    print(f"Saving tokens to cache: {cache_file}")
                    with open(cache_file, 'wb') as f:
                        train_ids.tofile(f)
            if not train_ids:
                print("Data loading failed.")
                continue
            print("\nResuming Training (Press Ctrl+C to stop & save)")
            # NOTE(review): optimizer state is not checkpointed (save_checkpoint
            # stores only model_state + config), so Adam's moments restart here.
            optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
            criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.SPECIAL_TOKENS['<pad>'])
            epochs = int(safe_input("Additional Epochs (default 3): ", "3"))
            try:
                for epoch in range(1, epochs + 1):
                    if is_interrupted(): break
                    print(f"\n--- Epoch {epoch}/{epochs} ---")
                    avg_loss = train_epoch(model, tokenizer, train_ids, batch_size=32, seq_len=64,
                                           optimizer=optimizer, criterion=criterion)
                    print(f"Avg Loss: {avg_loss:.4f}")
            except KeyboardInterrupt:
                print("\nInterrupted.")
            if is_interrupted():
                print("\n[!] Loop interrupted. Saving checkpoint...")
                save_checkpoint(model, tokenizer, MODEL_PATH)
        elif choice == '3':
            # --- Option 3: sample text from the current (or a freshly loaded) model ---
            if model is None:
                path = select_model_file()
                if path:
                    model, tokenizer = load_checkpoint(path)
                    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
                    model = model.to(device)
            if model is None:
                print("No model loaded. Please Train or Load first.")
                continue
            print("\n--- Text Generation ---")
            prompt = safe_input("Enter prompt (default: 'Once upon a time'): ", "Once upon a time")
            length = int(safe_input("Length (default 100): ", "100"))
            temp = float(safe_input("Temperature (default 0.8): ", "0.8"))
            generated = generate_text(model, tokenizer, prompt, max_len=length, temperature=temp)
            print(f"\n[OUTPUT]\n{generated}\n")
        elif choice == '4':
            print("Exiting.")
            break
        else:
            print("Invalid choice.")
| if __name__ == "__main__": | |
| main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment