Last active
January 3, 2026 06:11
-
-
Save amitpuri/3390df3f802729975990520ff531a868 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Transformer Architecture v4 | |
| ================================================================== | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| import os | |
| import asyncio | |
| import signal | |
| import sys | |
| import time | |
| import glob | |
| from typing import Tuple, List, Dict | |
# Module-level flag flipped by the SIGINT handler so long-running loops can
# stop at a safe point instead of dying mid-operation.
_interrupted = False

def signal_handler(signum, frame):
    """SIGINT handler: request a graceful stop; a second Ctrl+C force-exits."""
    global _interrupted
    _interrupted = True
    print("\n\n[!] Interrupt received. Finishing current operation gracefully...")
    print(" Press Ctrl+C again to force exit.")
    signal.signal(signal.SIGINT, signal.SIG_DFL)  # Reset to default on next interrupt

def reset_interrupt():
    """Clear the pending-interrupt flag."""
    global _interrupted
    _interrupted = False

def is_interrupted():
    """Return True if the user pressed Ctrl+C since the last reset."""
    return _interrupted

# Route Ctrl+C through the graceful handler as soon as the module loads.
signal.signal(signal.SIGINT, signal_handler)
| # ============================================ | |
| # 1. ENHANCED TOKENIZATION | |
| # ============================================ | |
class CharTokenizer:
    """Character-level tokenizer whose vocabulary starts with a reserved
    set of special task/control tokens at the lowest ids."""

    SPECIAL_TOKENS = {
        '<pad>': 0,
        '<unk>': 1,
        '<story>': 2,  # Story generation mode
        '<summary>': 3,  # Summarization mode
        '<dialog>': 4,  # Dialogue mode
        '<instruct>': 5,  # Instruction following mode
        '<eos>': 6,  # End of sequence
    }

    def __init__(self, text=None, chars=None):
        """Build the vocab from an explicit `chars` list or from the distinct
        characters of `text`; exactly one of the two must be provided."""
        if chars is not None:
            base_chars = chars
        elif text is not None:
            base_chars = sorted(set(text))
        else:
            raise ValueError("Must provide either text or chars")
        # Special tokens come first so their ids match SPECIAL_TOKENS.
        self.special_token_list = list(self.SPECIAL_TOKENS.keys())
        self.chars = self.special_token_list + [c for c in base_chars if c not in self.special_token_list]
        self.vocab_size = len(self.chars)
        self.char2id = {symbol: idx for idx, symbol in enumerate(self.chars)}
        self.id2char = dict(enumerate(self.chars))

    def encode(self, text):
        """Convert text to token ids; unknown characters map to '<unk>'."""
        unk = self.char2id['<unk>']
        return [self.char2id.get(symbol, unk) for symbol in text]

    def decode(self, ids):
        """Convert token ids (list or tensor) back to text; '?' for bad ids."""
        if torch.is_tensor(ids):
            ids = ids.tolist()
        return "".join(self.id2char.get(token_id, "?") for token_id in ids)

    def encode_with_task(self, text, task_token='<story>'):
        """Encode text with a task-token prefix (falls back to '<story>')."""
        prefix = self.char2id.get(task_token, self.char2id['<story>'])
        return [prefix] + self.encode(text)

    def get_special_token_id(self, token_name):
        """Return the id of a special token, or the '<unk>' id if unknown."""
        return self.char2id.get(token_name, self.char2id['<unk>'])
| # ============================================ | |
| # 2. POSITIONAL ENCODING | |
| # ============================================ | |
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al., 2017).

    Adds fixed, position-dependent sin/cos signals to the token embeddings so
    the (order-agnostic) attention layers can distinguish token positions.
    Assumes an even d_model (sin fills even dims, cos fills odd dims).
    """

    def __init__(self, d_model, max_seq_len=5000, dropout=0.1):
        """Precompute the [1, max_seq_len, d_model] encoding table."""
        super().__init__()
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        # BUG FIX: the original only filled the odd (cos) dimensions and left
        # every even dimension at zero; the sin half of the encoding was missing.
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Buffer (not Parameter): moves with the module but is never trained.
        self.register_buffer('pe', pe.unsqueeze(0))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, verbose=False):
        """x: [batch_size, seq_len, d_model] -> same shape, positions added."""
        if verbose:
            print("\n [Positional Encoding] Adding sinusoidal position info to embeddings")
            print(" - Intuition: Transformer sees all words at once; this injects order info (like page numbers)")
            print(f" - Input shape: {x.shape}")
            print(f" - Formula: PE(pos, 2i) = sin(pos/10000^(2i/d_model))")
            print(f" - Formula: PE(pos, 2i+1) = cos(pos/10000^(2i/d_model))")
            # FIX: index [0, 0, :8] — the buffer is [1, max_seq_len, d_model],
            # so [0, :8] printed 8 whole positions, not 8 dims of position 0.
            print(f" - Sample PE (first 8 dims for pos 0): {self.pe[0, 0, :8].tolist()}")
        return self.dropout(x + self.pe[:, :x.size(1)])
| # ============================================ | |
| # 3. LAYER NORMALIZATION | |
| # ============================================ | |
class LayerNorm(nn.Module):
    """Layer normalization over the last (embedding) dimension, with a
    learned per-dimension scale (gamma) and shift (beta)."""

    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps  # guards against division by zero for constant inputs

    def forward(self, x):
        """Return gamma * (x - mean) / sqrt(var + eps) + beta, per position."""
        mu = x.mean(dim=-1, keepdim=True)
        # Population (biased) variance, matching the standard LayerNorm definition.
        sigma_sq = x.var(dim=-1, keepdim=True, unbiased=False)
        standardized = (x - mu) / torch.sqrt(sigma_sq + self.eps)
        return self.gamma * standardized + self.beta
| # ============================================ | |
| # 4. MULTI-HEAD SELF-ATTENTION WITH CAUSAL MASK | |
| # ============================================ | |
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with optional masking."""

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head width
        # One projection per role, plus the output projection.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None, verbose=False):
        """Self-attend over x [batch, seq, d_model]; mask==0 blocks positions."""
        n_batch, n_tok, _ = x.shape

        def split_heads(projected):
            # [batch, seq, d_model] -> [batch, heads, seq, d_k]
            return projected.view(n_batch, n_tok, self.num_heads, self.d_k).transpose(1, 2)

        if verbose:
            print(f"\n [Multi-Head Attention] Splitting into {self.num_heads} heads (d_model={self.d_model} -> d_k={self.d_k})")
            print(" - Intuition: 'Looking around' the sentence to gather context (e.g., matching 'it' to 'dog')")
            print(" - Step 1: Linear projections for Queries (Q), Keys (K), Values (V)")
        q = split_heads(self.W_q(x))
        k = split_heads(self.W_k(x))
        v = split_heads(self.W_v(x))
        if verbose:
            print(f" - Q, K, V shapes: {q.shape} (batch, heads, seq, d_k)")
            print(f" - Sample Q (head 0, first token, first 4 dims): {q[0, 0, 0, :4].tolist()}")
            print(" - Step 2: Scaled Dot-Product Attention (Score = Q · K^T / √d_k)")
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            if verbose:
                print(" - Masked Attention: Hiding future tokens (causal mask)")
                print(f" - Scores before mask (head 0, row 0, first 4 cols): {scores[0, 0, 0, :4].tolist()}")
            mask = mask.to(scores.device)
            # -inf scores become exactly 0 probability after softmax.
            scores = scores.masked_fill(mask == 0, float('-inf'))
            if verbose:
                print(f" - Scores after mask (head 0, row 0, first 4 cols): {scores[0, 0, 0, :4].tolist()}")
        if verbose:
            print(" - Step 4: Softmax -> Attention Weights (Probability distribution)")
            print(" - Intuition: Deciding 'how much' to focus on each word (sum of weights = 1.0)")
        weights = F.softmax(scores, dim=-1)
        if verbose:
            print(f" - Sample Weights (head 0, row 0, first 4 cols): {weights[0, 0, 0, :4].tolist()}")
        weights = self.dropout(weights)
        if verbose:
            print(" - Step 5: Weighted sum of Values (Context aggregation)")
            print(f" - Formula: Context = Sum(Weight_i * Value_i)")
        # Weighted sum over values, then merge heads back to d_model.
        merged = torch.matmul(weights, v).transpose(1, 2).contiguous()
        merged = merged.view(n_batch, n_tok, self.d_model)
        if verbose:
            print(" - Concatenating heads and final linear projection")
        return self.W_o(merged)
| # ============================================ | |
| # 5. FEED-FORWARD NETWORK | |
| # ============================================ | |
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: expand to d_ff, ReLU, project back."""

    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x, verbose=False):
        """Apply the MLP independently at every sequence position of x."""
        if verbose:
            print(f"\n [Feed-Forward] Expansion (d_model={x.shape[-1]} -> d_ff={self.fc1.out_features} -> d_model={x.shape[-1]})")
            print(" - Intuition: The 'brain' processing the gathered context to extract complex features")
            print(f" - Formula: ReLU(xW_1 + b_1)W_2 + b_2")
        expanded = F.relu(self.fc1(x))
        if verbose:
            print(f" - Sample intermediate (after ReLU, first 4 dims): {expanded[0, 0, :4].tolist()}")
        return self.fc2(self.dropout(expanded))
| # ============================================ | |
| # 6. TRANSFORMER BLOCKS (POST-LN & PRE-LN) | |
| # ============================================ | |
class TransformerBlockPostLN(nn.Module):
    """Original-Transformer ordering: sublayer -> residual add -> LayerNorm."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm2 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None, verbose=False):
        """One Post-LN block: attention sublayer then FFN sublayer."""
        if verbose:
            print("\n[Transformer Block (Post-LN)]")
        attended = self.attention(x, mask, verbose=verbose)
        if verbose:
            print(" [Residual + Norm] Adding attention output to input and normalizing")
        # Residual first, then normalize (the Post-LN signature).
        x = self.norm1(x + self.dropout(attended))
        if verbose:
            print(" [Feed-Forward Network] Point-wise expansion and compression")
        transformed = self.ffn(x, verbose=verbose)
        if verbose:
            print(" [Residual + Norm] Adding FFN output and normalizing")
        return self.norm2(x + self.dropout(transformed))
class TransformerBlockPreLN(nn.Module):
    """Pre-LN ordering: normalize BEFORE each sublayer, residual after."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.norm1 = LayerNorm(d_model)
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm2 = LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None, verbose=False):
        """One Pre-LN block: norm -> attention -> add; norm -> FFN -> add."""
        if verbose:
            print("\n[Transformer Block (Pre-LN)]")
            print(" [Layer Norm] Normalizing input before attention (Pre-LN benefit: stable gradients)")
        attended = self.attention(self.norm1(x), mask, verbose=verbose)
        if verbose:
            print(" [Residual] Adding attention output to original input path")
        x = x + self.dropout(attended)
        if verbose:
            print(" [Layer Norm] Normalizing before FFN")
        normalized = self.norm2(x)
        if verbose:
            print(" [Feed-Forward] Processing features")
        transformed = self.ffn(normalized, verbose=verbose)
        if verbose:
            print(" [Residual] Final skip connection")
        return x + self.dropout(transformed)
| # ============================================ | |
| # 7. SAMPLING METHODS | |
| # ============================================ | |
def top_k_sampling(logits, k=10, temperature=1.0):
    """Top-k sampling over next-token logits.

    Args:
        logits: [batch, vocab_size] unnormalized next-token scores.
        k: number of highest-scoring tokens kept; clamped to [1, vocab_size]
           (the original crashed in torch.topk when k exceeded the vocab).
        temperature: softmax temperature, floored at 1e-5 to avoid div-by-zero.

    Returns:
        (next_token [batch, 1] sampled ids, probs [batch, vocab_size]).
    """
    logits = logits / max(temperature, 1e-5)
    # Robustness: a k larger than the vocabulary is a RuntimeError in topk.
    k = max(1, min(k, logits.size(-1)))
    topk_logits, topk_indices = torch.topk(logits, k, dim=-1)
    # Everything outside the top-k gets -inf -> probability exactly 0.
    masked = torch.full_like(logits, float('-inf'))
    masked.scatter_(-1, topk_indices, topk_logits)
    probs = F.softmax(masked, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token, probs
def top_p_sampling(logits, p=0.9, temperature=1.0):
    """Top-p (nucleus) sampling: sample from the smallest set of tokens whose
    cumulative probability exceeds p (the top token is always kept).

    Returns (next_token [batch, 1], probs [batch, vocab_size]).
    """
    logits = logits / max(temperature, 1e-5)
    ranked_logits, ranked_indices = torch.sort(logits, descending=True, dim=-1)
    ranked_probs = F.softmax(ranked_logits, dim=-1)
    cumulative = torch.cumsum(ranked_probs, dim=-1)
    # Mark everything past the nucleus; shift right so the token that crosses
    # the threshold stays in, and never drop the most likely token.
    drop = cumulative > p
    drop[..., 1:] = drop[..., :-1].clone()
    drop[..., 0] = False
    ranked_logits[drop] = float('-inf')
    # Scatter the surviving logits back to their original vocabulary slots.
    restored = torch.full_like(logits, float('-inf'))
    restored.scatter_(1, ranked_indices, ranked_logits)
    probs = F.softmax(restored, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token, probs
| # ============================================ | |
| # 8. TRANSFORMER MODEL | |
| # ============================================ | |
class TransformerV4(nn.Module):
    """Decoder-only (GPT-style) autoregressive Transformer language model.

    Pipeline: token embedding -> sinusoidal positional encoding -> dropout ->
    `num_layers` Transformer blocks (Post-LN by default, Pre-LN optional) ->
    optional final LayerNorm (Pre-LN only) -> linear head over the vocabulary.
    A causal mask keeps every position from attending to later positions.
    """

    def __init__(self, vocab_size, d_model=256, num_heads=8, num_layers=4,
                 d_ff=1024, max_seq_len=5000, dropout=0.1, use_pre_ln=False):
        """Build the model.

        Args:
            vocab_size: number of token ids (also the output logit width).
            d_model: embedding / residual-stream width.
            num_heads: attention heads per block (must divide d_model).
            num_layers: number of stacked Transformer blocks.
            d_ff: hidden width of each block's feed-forward sublayer.
            max_seq_len: longest sequence the positional table supports.
            dropout: dropout probability used throughout.
            use_pre_ln: use Pre-LN blocks plus a final LayerNorm instead of
                the default Post-LN blocks.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.d_ff = d_ff
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        self.dropout = nn.Dropout(dropout)
        BlockClass = TransformerBlockPreLN if use_pre_ln else TransformerBlockPostLN
        self.transformer_blocks = nn.ModuleList([
            BlockClass(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        # Pre-LN stacks end with an extra LayerNorm; its presence also encodes
        # the use_pre_ln flag when checkpointing (see save_checkpoint).
        if use_pre_ln:
            self.final_norm = LayerNorm(d_model)
        else:
            self.final_norm = None
        self.lm_head = nn.Linear(d_model, vocab_size)
        self._init_parameters()

    def _init_parameters(self):
        """Initialize parameters with Xavier uniform"""
        # Applies to every >=2-D parameter (projections, embedding table);
        # biases and LayerNorm gains/shifts keep their default init.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def generate_causal_mask(self, seq_len, device):
        """Create lower triangular causal mask (1 = visible, 0 = blocked)."""
        return torch.tril(torch.ones(seq_len, seq_len, device=device))

    def forward(self, token_ids, verbose=False):
        """Map token ids [batch, seq] to next-token logits [batch, seq, vocab].

        verbose=True prints an educational trace of every stage.
        """
        batch_size, seq_len = token_ids.shape
        if verbose:
            print(f"\n--- TRANSFORMER FORWARD PASS (Batch: {batch_size}, Seq Len: {seq_len}) ---")
            print("1. [Embedding] Converting token IDs to vectors")
            print(" - Intuition: Looking up the 'meaning' vector for each specific word ID")
        x = self.token_embedding(token_ids)
        x = self.positional_encoding(x, verbose=verbose)
        x = self.dropout(x)
        # Same causal mask is shared by every block in the stack.
        causal_mask = self.generate_causal_mask(seq_len, token_ids.device)
        if verbose:
            print(f"2. [Stacking] Passing through {self.num_layers} Transformer blocks")
        for i, block in enumerate(self.transformer_blocks):
            if verbose:
                print(f"\n--- Block {i+1}/{self.num_layers} ---")
            x = block(x, mask=causal_mask, verbose=verbose)
        # Only set for Pre-LN models (see __init__).
        if self.final_norm is not None:
            if verbose:
                print("\n3. [Final Norm] Stabilizing output features")
            x = self.final_norm(x)
        if verbose:
            print("\n4. [Output Head] Projecting to vocabulary size (Logits)")
        logits = self.lm_head(x)
        if verbose:
            print(f" Output shape: {logits.shape} (Batch, Seq, Vocab)")
            print("--- END FORWARD PASS ---\n")
        return logits

    def predict_next_token_greedy(self, token_ids, verbose=False):
        """Greedy decoding: argmax next-token ids [batch] and probs [batch, vocab]."""
        with torch.no_grad():
            logits = self.forward(token_ids, verbose=verbose)
            # Only the final position predicts the next token.
            last_logits = logits[:, -1, :]
            probs = F.softmax(last_logits, dim=-1)
            next_token_ids = torch.argmax(probs, dim=-1)
            return next_token_ids, probs

    def predict_next_token_topk(self, token_ids, k=10, temperature=1.0, verbose=False):
        """Sample the next token from the k most likely candidates."""
        with torch.no_grad():
            logits = self.forward(token_ids, verbose=verbose)
            last_logits = logits[:, -1, :]
            next_token_ids, probs = top_k_sampling(last_logits, k=k, temperature=temperature)
            return next_token_ids.squeeze(-1), probs

    def predict_next_token_topp(self, token_ids, p=0.9, temperature=1.0, verbose=False):
        """Sample the next token from the smallest set with probability mass >= p."""
        with torch.no_grad():
            logits = self.forward(token_ids, verbose=verbose)
            last_logits = logits[:, -1, :]
            next_token_ids, probs = top_p_sampling(last_logits, p=p, temperature=temperature)
            return next_token_ids.squeeze(-1), probs
| # ============================================ | |
| # 10. TRAINING UTILITIES | |
| # ============================================ | |
def train_model_basic(model, tokenizer, text, epochs=1000, batch_size=4, seq_len=64, lr=None, verbose=False):
    """Train `model` on `text` with warmup ("Noam") schedule and label smoothing.

    Args:
        model: language model exposing d_model / vocab_size and mapping
            [batch, seq] token ids to [batch, seq, vocab] logits.
        tokenizer: object with encode(text) -> list[int].
        text: raw training corpus (must encode to at least 2 tokens).
        epochs: number of optimization steps (one random batch per step).
        batch_size, seq_len: geometry of the random crops taken each step.
        lr: optional fixed learning rate. FIX: previously this parameter was
            silently ignored; when given, it now overrides the warmup schedule
            (default None keeps the original schedule, so callers are unaffected).
        verbose: when True, print the educational forward-pass trace for the
            first batch only. FIX: the flag was previously never forwarded.

    Checks is_interrupted() each step so Ctrl+C stops training gracefully and
    the caller can still save a checkpoint (as the interactive menu promises).
    """
    ids = tokenizer.encode(text)
    if len(ids) < 2:  # robustness: need at least one (input, target) pair
        raise ValueError("Training text is too short after tokenization (need >= 2 tokens)")
    # 1. Adam with the betas/eps from the original Transformer paper.
    base_lr = 1.0 if lr is None else lr
    optimizer = torch.optim.Adam(model.parameters(), lr=base_lr, betas=(0.9, 0.98), eps=1e-9)
    # 2. Label smoothing regularizes the per-character targets.
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    # 3. Warmup scheduler
    d_model = model.d_model
    warmup_steps = 4000

    # lr = d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5))
    def lr_lambda(step):
        step = step + 1  # 1-indexed for formula
        return (d_model ** -0.5) * min(step ** -0.5, step * (warmup_steps ** -1.5))

    if lr is None:
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    else:
        # Fixed-LR override: identity multiplier keeps optimizer lr constant.
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
    model.train()
    print(f"Training with: Adam(betas=(0.9, 0.98)), LabelSmoothing=0.1, WarmupSteps={warmup_steps}")
    first_step_verbose = verbose  # trace only the very first forward pass
    for epoch in range(epochs):
        # Graceful Ctrl+C: stop early, keep the partially trained weights.
        if is_interrupted():
            print(f"\n[!] Training interrupted at epoch {epoch}. Stopping early.")
            break
        batch_inputs = []
        batch_targets = []
        for _ in range(batch_size):
            if len(ids) <= seq_len + 1:
                # Corpus shorter than one crop: use the whole thing every time.
                start_idx = 0
                curr_seq_len = len(ids) - 1
            else:
                start_idx = torch.randint(0, len(ids) - seq_len - 1, (1,)).item()
                curr_seq_len = seq_len
            batch_inputs.append(ids[start_idx : start_idx + curr_seq_len])
            batch_targets.append(ids[start_idx + 1 : start_idx + curr_seq_len + 1])
        max_len = max(len(inp) for inp in batch_inputs)
        # Zero-padded tensors; note 0 is the '<pad>' token id.
        inputs = torch.zeros(len(batch_inputs), max_len, dtype=torch.long)
        targets = torch.zeros(len(batch_targets), max_len, dtype=torch.long)
        for i, (inp, tgt) in enumerate(zip(batch_inputs, batch_targets)):
            inputs[i, :len(inp)] = torch.tensor(inp)
            targets[i, :len(tgt)] = torch.tensor(tgt)
        optimizer.zero_grad()
        if first_step_verbose:
            logits = model(inputs, verbose=True)
            first_step_verbose = False
        else:
            logits = model(inputs)
        loss = criterion(logits.view(-1, model.vocab_size), targets.view(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        if epoch % 100 == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}, LR: {current_lr:.2e}")
| # ============================================ | |
| # 11. DATASET SELECTION MENU | |
| # ============================================ | |
def safe_input(prompt, default=""):
    """Prompt the user; return `default` if input is interrupted or closed."""
    try:
        response = input(prompt)
    except (KeyboardInterrupt, EOFError):
        print("\n[Input cancelled]")
        return default
    return response
# Path of the JSON file describing the selectable training datasets.
DATASETS_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "datasets_config.json")

def load_datasets_config():
    """Read the dataset list from datasets_config.json; [] on any failure."""
    import json
    if not os.path.exists(DATASETS_CONFIG_PATH):
        print(f"Warning: Config file not found at {DATASETS_CONFIG_PATH}")
        return []
    try:
        with open(DATASETS_CONFIG_PATH, 'r', encoding='utf-8') as f:
            parsed = json.load(f)
        return parsed.get("datasets", [])
    except Exception as e:
        print(f"Error loading config: {e}")
        return []

# Populated once at import time; the menu code reads this list.
DATASETS = load_datasets_config()

def get_datasets_by_category(category):
    """Return the configured datasets whose 'category' field matches."""
    return [entry for entry in DATASETS if entry.get("category") == category]
def select_dataset():
    """
    Interactive dataset selection menu with graceful interrupt handling
    Returns: (dataset_name, text_field, dataset_type), all None on cancel
    or invalid input.
    """
    if not DATASETS:
        print("No datasets found. Check datasets_config.json")
        return None, None, None
    print("\n" + "=" * 60)
    print("DATASET SELECTION")
    print("=" * 60)
    for position, entry in enumerate(DATASETS, 1):
        tag = f"[{entry.get('category', 'unknown')}]"
        print(f"{position}. {entry['name']} {tag}")
        print(f" {entry['description']}")
    cancel_option = len(DATASETS) + 1
    print(f"{cancel_option}. Cancel")
    raw_choice = safe_input(f"\nSelect dataset (1-{cancel_option}): ", "")
    try:
        chosen = int(raw_choice) - 1
    except ValueError:
        chosen = -1  # falls through to "Invalid selection", as before
    if chosen == len(DATASETS):  # Cancel option
        return None, None, None
    if 0 <= chosen < len(DATASETS):
        picked = DATASETS[chosen]
        return picked['name'], picked['text_field'], picked.get('category', 'unknown')
    print("Invalid selection")
    return None, None, None
# Directory (next to this file) where downloaded datasets are cached as JSON.
DATASET_CACHE_DIR = os.path.join(os.path.dirname(__file__), "datasets_cache")

def get_local_dataset_path(dataset_name):
    """Map a dataset name to its cache file, sanitizing path separators."""
    sanitized = dataset_name.replace("/", "_").replace("\\", "_")
    return os.path.join(DATASET_CACHE_DIR, f"{sanitized}.json")
async def download_dataset_async(dataset_name, text_field, num_examples=10000, batch_size=100):
    """
    Async download dataset from HuggingFace with progress tracking
    Args:
        dataset_name: Full dataset name (e.g., 'HuggingFaceFW/fineweb')
        text_field: The field containing text data
        num_examples: Number of examples to download
        batch_size: Number of examples to collect before yielding
    Returns:
        Path to the local dataset file (str), or None if nothing was downloaded
    """
    import json
    from datasets import load_dataset  # third-party HuggingFace 'datasets'
    # Clear any stale Ctrl+C request from a previous operation.
    reset_interrupt()
    # Create cache directory if it doesn't exist
    os.makedirs(DATASET_CACHE_DIR, exist_ok=True)
    local_path = get_local_dataset_path(dataset_name)
    # Check if already downloaded
    if os.path.exists(local_path):
        print(f"Dataset already cached at: {local_path}")
        redownload = safe_input("Re-download? (y/n, default: n): ", "n").lower() == 'y'
        if not redownload:
            return local_path
    print(f"Downloading {num_examples} examples from {dataset_name}...")
    print("[Press Ctrl+C to stop and save partial download]")
    try:
        # Streaming avoids materializing the (potentially multi-TB) split.
        dataset = load_dataset(dataset_name, split="train", streaming=True)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Trying with 'sample' config...")
        try:
            dataset = load_dataset(dataset_name, "sample", split="train", streaming=True)
        except Exception as e2:
            print(f"Failed to load dataset: {e2}")
            return None
    data = []
    batch = []
    for i, entry in enumerate(dataset):
        # Check for interrupt (flag is set by the module's SIGINT handler)
        if is_interrupted():
            print(f"\n[!] Interrupt detected at {i} examples. Saving partial download...")
            break
        # NOTE(review): entries with an empty/missing text field still count
        # toward num_examples, since `i` indexes raw stream entries.
        text = entry.get(text_field, "")
        if text:
            batch.append({"text": text})
        # Yield control periodically to allow interrupt checking
        if len(batch) >= batch_size:
            data.extend(batch)
            batch = []
            await asyncio.sleep(0)  # Yield to event loop
        if i >= num_examples - 1:
            break
        if (i + 1) % 1000 == 0:
            print(f" Downloaded {i + 1}/{num_examples} examples...")
    # Add remaining batch
    data.extend(batch)
    if len(data) == 0:
        print("No data downloaded.")
        return None
    # Save to local JSON file
    with open(local_path, 'w', encoding='utf-8') as f:
        json.dump({
            "dataset_name": dataset_name,
            "text_field": text_field,
            "num_examples": len(data),
            "data": data
        }, f, ensure_ascii=False, indent=2)
    print(f"Saved {len(data)} examples to: {local_path}")
    # Leave the interrupt flag clean for the caller's next operation.
    reset_interrupt()
    return local_path
def download_dataset(dataset_name, text_field, num_examples=10000):
    """
    Synchronous wrapper for async download
    Args:
        dataset_name: Full dataset name (e.g., 'HuggingFaceFW/fineweb')
        text_field: The field containing text data
        num_examples: Number of examples to download
    Returns:
        Path to the local dataset file
    """
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        running_loop = None  # no event loop in this thread
    coro = download_dataset_async(dataset_name, text_field, num_examples)
    if running_loop is not None and running_loop.is_running():
        # Already inside an event loop (e.g. a notebook): run the coroutine
        # to completion on a worker thread with its own loop.
        import concurrent.futures
        with concurrent.futures.ThreadPoolExecutor() as pool:
            return pool.submit(asyncio.run, coro).result()
    # No running loop here: drive the coroutine directly.
    return asyncio.run(coro)
def load_local_dataset(dataset_name, percent=100.0):
    """
    Load dataset from local cache
    Args:
        dataset_name: Full dataset name (to find the cached file)
        percent: Percentage of cached data to use
    Returns:
        Combined text from the local dataset, or None when not cached
    """
    import json
    local_path = get_local_dataset_path(dataset_name)
    if not os.path.exists(local_path):
        print(f"Dataset not found locally: {local_path}")
        print("Please download the dataset first.")
        return None
    print(f"Loading from local cache: {local_path}")
    with open(local_path, 'r', encoding='utf-8') as f:
        cached = json.load(f)
    entries = cached.get("data", [])
    total_examples = len(entries)
    # Always keep at least one example, even for tiny percentages.
    keep = max(1, int(total_examples * (percent / 100.0)))
    texts = [item.get("text", "") for item in entries[:keep] if item.get("text")]
    combined_text = "\n\n".join(texts)
    print(f"Loaded {len(texts)} examples ({percent}% of {total_examples} cached) - {len(combined_text)} characters")
    return combined_text
def _prompt_float(prompt, default):
    """Ask for a float via safe_input; fall back to `default` on blank or bad input."""
    raw = safe_input(prompt, "").strip()
    if not raw:
        return default
    try:
        return float(raw)
    except ValueError:
        print(f"Invalid number; using default {default}")
        return default

def _prompt_int(prompt, default):
    """Ask for an int via safe_input; fall back to `default` on blank or bad input."""
    raw = safe_input(prompt, "").strip()
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError:
        print(f"Invalid number; using default {default}")
        return default

def load_selected_dataset(dataset_name, text_field, percent=0.001):
    """
    Load a dataset - downloads if not cached, then loads from local cache
    Args:
        dataset_name: Full dataset name (e.g., 'HuggingFaceFW/fineweb')
        text_field: The field containing text data
        percent: Percentage of dataset to load (used for download count estimation)
    Returns:
        Combined text from the dataset, or None on failure
    """
    # Rough published sizes, used only to suggest a sensible download count.
    estimated_counts = {
        "HuggingFaceFW/fineweb": 52500000000,
        "HuggingFaceFW/fineweb-edu": 3500000000,
        "HuggingFaceFW/fineweb-edu-score-2": 13900000000,
        "HuggingFaceFW/fineweb-2": 4480000000,
        "HuggingFaceFW/finewiki": 61600000,
        "HuggingFaceFW/clean-wikipedia": 61200000,
        "HuggingFaceFW/finepdfs": 476000000,
        "HuggingFaceFW/finepdfs-edu": 49500000,
        "HuggingFaceFW/ocr-annotations": 1620,
        "roneneldan/TinyStories": 2141709,
        "roneneldan/TinyStoriesInstruct": 21974061,
    }
    local_path = get_local_dataset_path(dataset_name)
    # Check if dataset exists locally
    if os.path.exists(local_path):
        print(f"Found cached dataset: {local_path}")
        # CONSISTENCY FIX: use safe_input like the rest of the file so Ctrl+C
        # cancels the prompt gracefully instead of crashing the menu loop;
        # also parse numbers robustly instead of raising ValueError on typos.
        use_cached = safe_input("Use cached dataset? (y/n, default: y): ", "y").lower() != 'n'
        if use_cached:
            use_percent = _prompt_float("Percentage of cached data to use (default: 100): ", 100.0)
            return load_local_dataset(dataset_name, use_percent)
    # Dataset not cached - need to download
    print(f"Dataset not found in cache. Downloading...")
    total_count = estimated_counts.get(dataset_name, 1000000)
    default_count = max(1, int(total_count * (percent / 100.0)))
    default_count = min(default_count, 50000)  # Cap at 50k
    num_examples = _prompt_int(f"Number of examples to download (default: {default_count}): ", default_count)
    local_path = download_dataset(dataset_name, text_field, num_examples)
    if local_path is None:
        return None
    use_percent = _prompt_float("Percentage of downloaded data to use for training (default: 100): ", 100.0)
    return load_local_dataset(dataset_name, use_percent)
| # ============================================ | |
| # 12. CHECKPOINTING | |
| # ============================================ | |
def save_checkpoint(model, tokenizer, path="models_cache/model_v4.pth"):
    """Persist model weights plus the vocab/config needed to rebuild it.

    Args:
        model: trained model exposing state_dict() and the d_model/num_heads/
            num_layers/d_ff/vocab_size/final_norm attributes.
        tokenizer: tokenizer exposing .chars and .vocab_size.
        path: destination file; parent directories are created as needed.
    """
    # FIX: torch.save fails if the target directory does not exist yet
    # (e.g. a fresh checkout without models_cache/) — create it first.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    checkpoint = {
        'model_state': model.state_dict(),
        'vocab': tokenizer.chars,
        'vocab_size': tokenizer.vocab_size,
        'config': {
            'd_model': model.d_model,
            'num_heads': model.num_heads,
            'num_layers': model.num_layers,
            'd_ff': model.d_ff,
            'vocab_size': model.vocab_size,
            # A final norm exists only for Pre-LN models (see TransformerV4).
            'use_pre_ln': model.final_norm is not None,
        }
    }
    torch.save(checkpoint, path)
    print(f"Model saved to {path}")
def load_checkpoint(path="models_cache/model_v4.pth"):
    """Rebuild a TransformerV4 + CharTokenizer from a saved checkpoint.

    Returns:
        (model, tokenizer), or (None, None) when `path` does not exist.
    """
    if not os.path.exists(path):
        return None, None
    # FIX: map_location="cpu" lets checkpoints saved on a GPU machine load on
    # CPU-only hosts (the default would try to restore onto the saving device).
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoint files (consider weights_only=True on torch >= 1.13).
    checkpoint = torch.load(path, map_location="cpu")
    vocab = checkpoint['vocab']
    config = checkpoint['config']
    tokenizer = CharTokenizer(chars=vocab)
    model = TransformerV4(
        vocab_size=config['vocab_size'],
        d_model=config['d_model'],
        num_heads=config['num_heads'],
        num_layers=config['num_layers'],
        d_ff=config.get('d_ff', 1024),          # defaults cover older checkpoints
        use_pre_ln=config.get('use_pre_ln', False),
    )
    model.load_state_dict(checkpoint['model_state'])
    model.eval()  # inference mode: disables dropout
    print(f"Model loaded from {path}")
    return model, tokenizer
| # ============================================ | |
| # 13. INTERACTIVE DEMO | |
| # ============================================ | |
| if __name__ == "__main__": | |
| # Generate unique model name for this execution | |
| timestamp = time.strftime("%Y%m%d_%H%M%S") | |
| os.makedirs("models_cache", exist_ok=True) | |
| MODEL_PATH = f"models_cache/model_v4_{timestamp}.pth" | |
| # Global verbose flag | |
| VERBOSE_MODE = False | |
| def select_model_file(default_path=None): | |
| """Allow user to select a model file from current directory""" | |
| os.makedirs("models_cache", exist_ok=True) | |
| files = glob.glob("models_cache/model_v4*.pth") | |
| files.sort(key=os.path.getmtime, reverse=True) | |
| if not files: | |
| print("No existing model files found.") | |
| return default_path | |
| print("\nAvailable Models:") | |
| for i, f in enumerate(files): | |
| size_mb = os.path.getsize(f) / (1024 * 1024) | |
| print(f"{i+1}. {f} ({size_mb:.2f} MB)") | |
| print(f"{len(files)+1}. Cancel / Use Default ({default_path})") | |
| choice = safe_input(f"Select model (1-{len(files)+1}, default 1): ", "1") | |
| try: | |
| idx = int(choice) - 1 | |
| if 0 <= idx < len(files): | |
| return files[idx] | |
| except ValueError: | |
| pass | |
| return default_path | |
| print("\n[Tip: Press Ctrl+C at any time to gracefully interrupt operations]") | |
| while True: | |
| reset_interrupt() # Reset interrupt flag at start of each loop | |
| print("\n" + "=" * 60) | |
| print("TRANSFORMER v4") | |
| print("=" * 60) | |
| print("1. Train with Dataset Selection") | |
| print("2. Load & Retrain Existing Model") | |
| print("3. Generate Text from Model") | |
| print("4. Educational Step-by-Step Demo") | |
| print(f"5. Toggle Verbose Mode [Current: {'ON' if VERBOSE_MODE else 'OFF'}]") | |
| print("6. Exit") | |
| choice = safe_input("\nEnter choice (1-6): ", "6") | |
| if choice == '1': | |
| # Dataset selection and training | |
| dataset_name, text_field, dataset_type = select_dataset() | |
| if dataset_name is None: | |
| print("Dataset selection cancelled.") | |
| continue | |
| percent_str = safe_input("Enter percentage to load (default: 0.001): ", "0.001") | |
| try: | |
| percent = float(percent_str) | |
| except ValueError: | |
| percent = 0.001 | |
| print(f"\nLoading dataset: {dataset_name}") | |
| text = load_selected_dataset(dataset_name, text_field, percent) | |
| if text is None or len(text) < 100: | |
| print("Failed to load dataset or dataset too small.") | |
| continue | |
| tokenizer = CharTokenizer(text=text) | |
| use_pre_ln = safe_input("Use Pre-LN? (y/n, default: n): ", "n").lower() == 'y' | |
| model = TransformerV4( | |
| vocab_size=tokenizer.vocab_size, | |
| d_model=256, num_heads=8, num_layers=4, d_ff=1024, | |
| use_pre_ln=use_pre_ln | |
| ) | |
| print(f"\nTraining on {dataset_name} (Pre-LN={use_pre_ln})...") | |
| print("[Press Ctrl+C to stop training and save current progress]") | |
| train_model_basic(model, tokenizer, text, epochs=2000, seq_len=64, verbose=VERBOSE_MODE) | |
| save_checkpoint(model, tokenizer, MODEL_PATH) | |
| elif choice == '2': | |
| # Select model to load | |
| load_path = select_model_file(MODEL_PATH) | |
| if not load_path or not os.path.exists(load_path): | |
| print(f"Model file not found: {load_path}") | |
| continue | |
| model, tokenizer = load_checkpoint(load_path) | |
| if model is None: | |
| print("No model found. Train first.") | |
| continue | |
| # Use dataset selection for retraining too | |
| dataset_name, text_field, dataset_type = select_dataset() | |
| if dataset_name is None: | |
| print("Dataset selection cancelled.") | |
| continue | |
| percent_str = safe_input("Enter percentage (default: 0.0005): ", "0.0005") | |
| try: | |
| percent = float(percent_str) | |
| except ValueError: | |
| percent = 0.0005 | |
| print(f"\nLoading dataset: {dataset_name}") | |
| text = load_selected_dataset(dataset_name, text_field, percent) | |
| if text is None or len(text) < 100: | |
| print("Failed to load dataset or dataset too small.") | |
| continue | |
| print(f"\nRetraining on {dataset_name}...") | |
| print("[Press Ctrl+C to stop training and save current progress]") | |
| # Save to the NEW unique path for this session, not overwriting the old one | |
| print(f"Will save updated model to: {MODEL_PATH}") | |
| train_model_basic(model, tokenizer, text, epochs=500, seq_len=64, verbose=VERBOSE_MODE) | |
| save_checkpoint(model, tokenizer, MODEL_PATH) | |
| elif choice == '3': | |
| # Select model to load | |
| load_path = select_model_file(MODEL_PATH) | |
| if not load_path or not os.path.exists(load_path): | |
| print(f"Model file not found: {load_path}") | |
| continue | |
| model, tokenizer = load_checkpoint(load_path) | |
| if model is None: | |
| print("No model found. Train first.") | |
| continue | |
| print("\nGeneration Options:") | |
| print("1. Free generation") | |
| print("2. Prompt-guided (TinyStoriesInstruct style)") | |
| gen_choice = safe_input("Choose (1-2): ", "1") | |
| if gen_choice == '1': | |
| prompt = safe_input("Enter prompt: ", "Once upon a time") | |
| else: | |
| print("\nSample story prompts:") | |
| sample_prompts = [ | |
| "Once upon a time", | |
| "a quick brown fox", | |
| "One day, a little girl", | |
| "The sun was shining", | |
| ] | |
| for i, p in enumerate(sample_prompts): | |
| print(f" {i+1}. {p}") | |
| choice_str = safe_input("Select (1-4) or 0 for custom: ", "1") | |
| try: | |
| choice_idx = int(choice_str) - 1 | |
| except ValueError: | |
| choice_idx = 0 | |
| prompt = sample_prompts[choice_idx] if 0 <= choice_idx < len(sample_prompts) else safe_input("Custom: ", "Once upon a time") | |
| temp_str = safe_input("Temperature (0.1-2.0, default 0.8): ", "0.8") | |
| try: | |
| temp = float(temp_str) | |
| except ValueError: | |
| temp = 0.8 | |
| print(f"\nGenerating from: '{prompt}'") | |
| token_ids = tokenizer.encode(prompt) | |
| input_tensor = torch.tensor([token_ids]) | |
| generated = prompt | |
| # Verbose for first token generation only if requested | |
| first_gen_verbose = VERBOSE_MODE | |
| for _ in range(100): | |
| if is_interrupted(): | |
| print("\n[Generation interrupted]") | |
| break | |
| next_id, _ = model.predict_next_token_topp( | |
| input_tensor, | |
| p=0.9, | |
| temperature=temp, | |
| verbose=first_gen_verbose | |
| ) | |
| if first_gen_verbose: | |
| first_gen_verbose = False | |
| next_char = tokenizer.id2char.get(next_id.item(), " ") | |
| generated += next_char | |
| token_ids.append(next_id.item()) | |
| input_tensor = torch.tensor([token_ids[-64:]]) | |
| print(f"\n[GENERATED]\n{generated}\n") | |
| print(f"\n[GENERATED]\n{generated}\n") | |
| elif choice == '4': | |
| demonstrate_steps() | |
| elif choice == '5': | |
| VERBOSE_MODE = not VERBOSE_MODE | |
| print(f"\nVerbose Mode is now {'ON' if VERBOSE_MODE else 'OFF'}") | |
| elif choice == '6': | |
| print("Exiting...") | |
| break | |
| else: | |
| print("Invalid choice") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment