Created
January 3, 2026 06:13
-
-
Save amitpuri/5a21b7cdc6086990230a6878c5ad2000 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Transformer Architecture v6 (NLTK Tokenization) | |
| ================================================================== | |
| Improved version of v5 using NLTK for word-level tokenization. | |
| Includes interactive menu for Training, Saving, Loading, and Generation. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| import os | |
| import asyncio | |
| import signal | |
| import sys | |
| import time | |
| import glob | |
| import nltk | |
| from collections import Counter | |
| from typing import Tuple, List, Dict | |
# Ensure NLTK data is available at startup.
# Best-effort: a download failure (e.g. no network) must not stop the
# script -- encode/decode fall back to str.split() on LookupError later.
try:
    nltk.download('punkt', quiet=True)
    nltk.download('punkt_tab', quiet=True)
except Exception as e:
    print(f"Warning: Failed to download NLTK data: {e}")
# --- Graceful Ctrl+C handling ---------------------------------------
# A module-level flag records that the user asked to stop; long-running
# loops poll is_interrupted() and wind down cleanly instead of dying
# mid-batch.
_interrupted = False

def signal_handler(signum, frame):
    """Record the interrupt request and arm a hard exit on the next Ctrl+C."""
    global _interrupted
    _interrupted = True
    print("\n\n[!] Interrupt received. Finishing current operation gracefully...")
    print(" Press Ctrl+C again to force exit.")
    # Restore the default disposition so a second Ctrl+C kills the process.
    signal.signal(signal.SIGINT, signal.SIG_DFL)  # Reset to default on next interrupt

def reset_interrupt():
    """Clear the interrupt flag (called at the top of each menu loop)."""
    global _interrupted
    _interrupted = False

def is_interrupted():
    """Return True if an interrupt has been requested since the last reset."""
    return _interrupted

# Route SIGINT through the graceful handler.
signal.signal(signal.SIGINT, signal_handler)
| # ============================================ | |
| # 0. HELPER FUNCTIONS | |
| # ============================================ | |
def safe_input(prompt, default_value):
    """Prompt the user, returning default_value for blank or unavailable input.

    Falls back to the default when stdin is exhausted (EOFError) or not
    readable at all (OSError, e.g. running detached from a console or
    with captured stdio) -- the original only handled EOFError and
    crashed in the latter case.

    Args:
        prompt: text shown to the user.
        default_value: value returned when no usable input is given.
    """
    try:
        user_input = input(prompt)
    except (EOFError, OSError):
        return default_value
    # A blank / whitespace-only answer means "use the default".
    if not user_input.strip():
        return default_value
    return user_input
def batch_encode(tokenizer, text, batch_size=1000):
    """Encode *text* line by line with periodic progress output.

    Splits on newlines (TinyStories uses one story per chunk), encodes
    each non-empty line independently, and returns one flat token list.
    """
    lines = text.split('\n')
    total_lines = len(lines)
    print(f" Encoding {total_lines} lines/chunks in batches of {batch_size}...")
    all_tokens = []
    for batch_start in range(0, total_lines, batch_size):
        # Tokenizing line-by-line keeps the fallback str.split path sane.
        for line in lines[batch_start:batch_start + batch_size]:
            if line.strip():
                all_tokens.extend(tokenizer.encode(line))
        # Print progress for every 10th batch only.
        if (batch_start // batch_size) % 10 == 0:
            done = min(batch_start + batch_size, total_lines)
            print(f" Processed line {done}/{total_lines}...", end='\r')
    print(f" Finished encoding. Total tokens: {len(all_tokens)}")
    return all_tokens
def build_vocab_batched(text, max_vocab_size=10000, batch_size=1000):
    """Count word frequencies line by line and return the top words.

    Polls is_interrupted() between batches and returns an empty list if
    the user interrupted. Falls back to whitespace splitting when the
    NLTK punkt data is unavailable.
    """
    lines = text.split('\n')
    total_lines = len(lines)
    print(f" Tokenizing {total_lines} lines/chunks in batches of {batch_size} to build vocab...")
    counter = Counter()
    for start in range(0, total_lines, batch_size):
        if is_interrupted():
            print("\n [!] Interrupted during vocabulary building.")
            return []
        for line in lines[start:start + batch_size]:
            if not line.strip():
                continue
            try:
                counter.update(nltk.word_tokenize(line))
            except LookupError:
                # punkt data missing -- degrade to naive splitting.
                counter.update(line.split())
        # Progress indicator for every 10th batch.
        if (start // batch_size) % 10 == 0:
            print(f" Processed line {min(start + batch_size, total_lines)}/{total_lines}...", end='\r')
    most_common = counter.most_common(max_vocab_size)
    vocab = [word for word, _count in most_common]
    print(f"\n Vocabulary built with {len(vocab)} words (top frequency: {most_common[0][1] if most_common else 0})")
    return vocab
| # ============================================ | |
| # 1. ENHANCED TOKENIZATION (NLTK) | |
| # ============================================ | |
class NLTKTokenizer:
    """Word-level tokenizer using NLTK with a dynamic vocabulary.

    Token ids 0..5 are reserved for SPECIAL_TOKENS; corpus words get
    consecutive ids starting at len(SPECIAL_TOKENS).
    """

    SPECIAL_TOKENS = {
        '<pad>': 0,
        '<unk>': 1,
        '<story>': 2,     # Story generation mode
        '<summary>': 3,   # Summarization mode
        '<instruct>': 4,  # Instruction following mode
        '<eos>': 5,       # End of sequence
    }

    def __init__(self, vocab=None):
        """Build id mappings from an optional list of corpus words.

        Args:
            vocab: iterable of words appended after the special tokens.
                Duplicates (including words equal to a special token)
                are skipped. None yields a specials-only vocabulary.
        """
        self.special_token_list = list(self.SPECIAL_TOKENS.keys())
        self.word2id = self.SPECIAL_TOKENS.copy()
        if vocab is not None:
            next_id = len(self.SPECIAL_TOKENS)
            for word in vocab:
                if word not in self.word2id:
                    self.word2id[word] = next_id
                    next_id += 1
        self.id2word = {v: k for k, v in self.word2id.items()}
        self.vocab_size = len(self.word2id)

    def encode(self, text):
        """Convert text to token IDs; out-of-vocabulary words map to <unk>."""
        try:
            tokens = nltk.word_tokenize(text)
        except LookupError:
            # Fallback if punkt is missing
            tokens = text.split()
        unk_id = self.word2id['<unk>']
        return [self.word2id.get(token, unk_id) for token in tokens]

    def decode(self, ids):
        """Convert token IDs (list or tensor) back to a text string."""
        if torch.is_tensor(ids):
            ids = ids.tolist()
        words = [self.id2word.get(i, "<unk>") for i in ids]
        # Treebank detokenization is best-effort; a plain join always works.
        from nltk.tokenize.treebank import TreebankWordDetokenizer
        try:
            return TreebankWordDetokenizer().detokenize(words)
        except Exception:
            # FIX: was a bare `except:` which also swallowed
            # KeyboardInterrupt/SystemExit.
            return " ".join(words)
def build_vocab_from_text(text, max_vocab_size=10000):
    """Build a frequency-ranked vocabulary from a text corpus.

    Args:
        text: raw corpus string.
        max_vocab_size: maximum number of words to keep.

    Returns:
        List of up to max_vocab_size words, most frequent first.
    """
    print(" Tokenizing corpus to build vocabulary...")
    try:
        tokens = nltk.word_tokenize(text)
    except LookupError:
        print(" Warning: NLTK punkt not found, using split()")
        tokens = text.split()
    print(f" Found {len(tokens)} tokens. Counting frequencies...")
    counter = Counter(tokens)
    # Get most common words
    most_common = counter.most_common(max_vocab_size)
    vocab = [word for word, count in most_common]
    # FIX: guard against an empty corpus -- most_common[0][1] raised
    # IndexError (build_vocab_batched already guarded this case).
    top_freq = most_common[0][1] if most_common else 0
    print(f" Vocabulary built with {len(vocab)} words (top frequency: {top_freq})")
    return vocab
| # ============================================ | |
| # 2. TRANSFORMER COMPONENTS | |
| # ============================================ | |
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al., 2017).

    Even embedding dimensions carry sin(pos * freq), odd dimensions
    cos(pos * freq); dropout is applied after adding to the input.
    Assumes d_model is even (div_term/cos shapes only match then).
    """

    def __init__(self, d_model, max_seq_len=5000, dropout=0.1):
        super().__init__()
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        # BUG FIX: the sine half of the encoding was missing, leaving all
        # even dimensions permanently at zero.
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Buffer (not a Parameter): moves with .to(device), never trained.
        self.register_buffer('pe', pe.unsqueeze(0))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, verbose=False):
        """x: [batch_size, seq_len, d_model] -> same shape with positions added."""
        return self.dropout(x + self.pe[:, :x.size(1)])
class LayerNorm(nn.Module):
    """Layer normalization with learnable scale (gamma) and shift (beta)."""

    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        """Normalize x over its last dimension, then apply gamma/beta."""
        mu = x.mean(dim=-1, keepdim=True)
        # Population variance (unbiased=False), matching standard LayerNorm.
        sigma_sq = x.var(dim=-1, keepdim=True, unbiased=False)
        normalized = (x - mu) / torch.sqrt(sigma_sq + self.eps)
        return self.gamma * normalized + self.beta
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention."""

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t, batch, length):
        """Reshape [B, S, d_model] -> [B, num_heads, S, d_k]."""
        return t.view(batch, length, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None, verbose=False):
        """Self-attend over x [B, S, d_model]; positions where mask == 0 are hidden."""
        batch, length, _ = x.shape
        q = self._split_heads(self.W_q(x), batch, length)
        k = self._split_heads(self.W_k(x), batch, length)
        v = self._split_heads(self.W_v(x), batch, length)
        # Scaled dot-product scores: [B, heads, S, S].
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask.to(scores.device) == 0, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))
        # Merge heads back: [B, S, d_model], then final projection.
        merged = torch.matmul(weights, v).transpose(1, 2).contiguous()
        merged = merged.view(batch, length, self.d_model)
        return self.W_o(merged)
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: d_model -> d_ff -> d_model with ReLU."""

    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x, verbose=False):
        """Expand, activate, drop, project back."""
        return self.fc2(self.dropout(F.relu(self.fc1(x))))
class TransformerBlockV6(nn.Module):
    """Pre-LN Transformer block (norm before each sub-layer, as in modern LLMs)."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.norm1 = LayerNorm(d_model)
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm2 = LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None, verbose=False):
        """Apply the two pre-normalized residual sub-layers."""
        # Attention sub-layer: norm -> attend -> dropout -> residual add.
        attended = self.attention(self.norm1(x), mask, verbose=verbose)
        x = x + self.dropout(attended)
        # Feed-forward sub-layer: norm -> MLP -> dropout -> residual add.
        transformed = self.ffn(self.norm2(x), verbose=verbose)
        return x + self.dropout(transformed)
class TransformerV6(nn.Module):
    """Decoder-only (GPT-style) transformer language model.

    Token embeddings plus sinusoidal positional encodings feed a stack
    of Pre-LN transformer blocks under a causal mask; a final LayerNorm
    and linear head produce next-token logits over the vocabulary.
    """

    def __init__(self, vocab_size, d_model=256, num_heads=8, num_layers=6,
                 d_ff=1024, max_seq_len=5000, dropout=0.1):
        super().__init__()
        # Config is kept as plain attributes so save_checkpoint() can
        # record the architecture alongside the weights.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.d_ff = d_ff
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        self.dropout = nn.Dropout(dropout)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlockV6(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.final_norm = LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)
        self._init_parameters()

    def _init_parameters(self):
        """Xavier-initialize every weight matrix (parameters with dim > 1)."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def generate_causal_mask(self, seq_len, device):
        # Lower-triangular [seq_len, seq_len] of ones: position i may
        # attend only to positions j <= i (zeros are masked out).
        return torch.tril(torch.ones(seq_len, seq_len, device=device))

    def forward(self, token_ids, verbose=False):
        """token_ids: [batch, seq_len] int ids -> logits [batch, seq_len, vocab_size]."""
        batch_size, seq_len = token_ids.shape
        x = self.token_embedding(token_ids)
        x = self.positional_encoding(x, verbose=verbose)
        x = self.dropout(x)
        causal_mask = self.generate_causal_mask(seq_len, token_ids.device)
        for block in self.transformer_blocks:
            x = block(x, mask=causal_mask, verbose=verbose)
        x = self.final_norm(x)
        logits = self.lm_head(x)
        return logits

    def get_num_params(self):
        """Total parameter count (trainable and non-trainable)."""
        return sum(p.numel() for p in self.parameters())

    def predict_next_token_topp(self, token_ids, p=0.9, temperature=1.0, verbose=False):
        """Sample the next token via nucleus (top-p) sampling.

        Keeps the smallest prefix of probability-sorted tokens whose
        cumulative probability exceeds p (always at least the top-1
        token), then samples from the renormalized distribution.

        Returns:
            (next_token [batch, 1], probs [batch, vocab_size])
        """
        with torch.no_grad():
            logits = self.forward(token_ids, verbose=verbose)
            last_logits = logits[:, -1, :]
            # Apply temperature (clamped so division by zero is impossible).
            last_logits = last_logits / max(temperature, 1e-5)
            # Sort descending so a cumulative sum identifies the nucleus.
            sorted_logits, sorted_indices = torch.sort(last_logits, descending=True, dim=-1)
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
            # Shift the cutoff right by one so the token that first crosses
            # p is kept, and force-keep position 0 (the most likely token).
            cutoff_mask = cumulative_probs > p
            cutoff_mask[..., 1:] = cutoff_mask[..., :-1].clone()
            cutoff_mask[..., 0] = False
            sorted_logits[cutoff_mask] = float('-inf')
            # Scatter filtered logits back to their original vocabulary positions.
            unsorted_logits = torch.full_like(last_logits, float('-inf'))
            unsorted_logits.scatter_(1, sorted_indices, sorted_logits)
            probs = F.softmax(unsorted_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            return next_token, probs
| # ============================================ | |
| # 3. DATA LOADING & TRAINING | |
| # ============================================ | |
def load_tiny_stories(percent=1.0):
    """Stream roughly *percent* percent of the TinyStories training split.

    Uses the streaming API so only the requested fraction is downloaded.
    Returns all fetched stories joined into one newline-separated string.
    """
    from datasets import load_dataset
    print(f"Loading {percent}% of TinyStories dataset...")
    dataset = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
    # The full split holds roughly 2M stories; derive a target from that.
    total_count = 2000000
    target_count = max(1, int(total_count * (percent / 100.0)))
    print(f"Fetching approx {target_count} examples...")
    texts = []
    for i, entry in enumerate(dataset):
        if i >= target_count:
            break
        texts.append(entry['text'])
        # Progress line every 100 stories.
        if i % 100 == 0:
            print(f" Fetched {i} stories...", end='\r')
    print(f"\nLoaded {len(texts)} stories.")
    return "\n".join(texts)
def train_epoch(model, tokenizer, data_ids, batch_size=32, seq_len=64, optimizer=None, criterion=None):
    """Run one epoch of teacher-forced LM training over data_ids.

    Slides non-overlapping windows of batch_size * seq_len tokens,
    building (input, target) pairs shifted by one token. Stops early
    when is_interrupted() becomes true.

    Args:
        model: TransformerV6-like module (needs .token_embedding, .vocab_size).
        tokenizer: unused here; kept for interface symmetry with callers.
        data_ids: flat list of token ids.
        optimizer / criterion: must be provided by the caller.

    Returns:
        Mean batch loss (0.0 if no batch completed).
    """
    model.train()
    total_loss = 0
    num_batches = 0
    n_tokens = len(data_ids)
    device = model.token_embedding.weight.device
    # Simple sliding window batching
    for i in range(0, n_tokens - batch_size * (seq_len + 1), batch_size * seq_len):
        if is_interrupted():
            break
        inputs_list = []
        targets_list = []
        for b in range(batch_size):
            start_idx = i + b * seq_len
            # Drop the tail when there are not enough tokens for a full window.
            if start_idx + seq_len + 1 >= n_tokens:
                break
            chunk = data_ids[start_idx : start_idx + seq_len + 1]
            inputs_list.append(chunk[:-1])   # tokens t .. t+seq_len-1
            targets_list.append(chunk[1:])   # tokens t+1 .. t+seq_len
        if len(inputs_list) == 0:
            break
        inputs = torch.tensor(inputs_list).to(device)
        targets = torch.tensor(targets_list).to(device)
        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits.view(-1, model.vocab_size), targets.view(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
        if num_batches % 10 == 0:
            print(f" Batch {num_batches}, Loss: {loss.item():.4f}", end='\r')
    # BUG FIX: the summary print referenced `loss`, which is undefined when
    # the loop never completed a batch (tiny dataset or immediate interrupt).
    if num_batches > 0:
        print(f" Batch {num_batches}, Final Loss: {loss.item():.4f}")
    return total_loss / max(1, num_batches)
def generate_text(model, tokenizer, prompt="Once upon a time", max_len=100, temperature=0.8):
    """Autoregressively sample up to max_len tokens after the prompt.

    Plain temperature-scaled multinomial sampling over the full vocab;
    generation stops early at the <eos> token. Returns the decoded text
    including the prompt.
    """
    model.eval()
    seed_ids = tokenizer.encode(prompt)
    device = model.token_embedding.weight.device
    input_ids = torch.tensor([seed_ids]).to(device)
    print(f"Generating (Prompt: '{prompt}')...")
    eos_id = tokenizer.SPECIAL_TOKENS['<eos>']
    for _ in range(max_len):
        with torch.no_grad():
            logits = model(input_ids)
        # Sample only from the distribution at the final position.
        scaled = logits[:, -1, :] / temperature
        probs = F.softmax(scaled, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=1)
        if next_token.item() == eos_id:
            break
    return tokenizer.decode(input_ids[0])
| # ============================================ | |
| # 4. CHECKPOINTING | |
| # ============================================ | |
def save_checkpoint(model, tokenizer, path):
    """Persist model weights, vocabulary words, and architecture config.

    Only non-special vocabulary words are stored in id order;
    NLTKTokenizer re-creates the special tokens (ids 0-5) itself when
    rebuilt from this list, so they must not be saved twice.
    """
    specials = set(tokenizer.SPECIAL_TOKENS.keys())
    ordered_words = (tokenizer.id2word[i] for i in range(tokenizer.vocab_size))
    saved_vocab = [w for w in ordered_words if w not in specials]
    checkpoint = {
        'model_state': model.state_dict(),
        'vocab_words': saved_vocab,
        'config': {
            'd_model': model.d_model,
            'num_heads': model.num_heads,
            'num_layers': model.num_layers,
            'd_ff': model.d_ff,
            'vocab_size': model.vocab_size,
        }
    }
    torch.save(checkpoint, path)
    print(f"Model and vocabulary saved to {path}")
def load_checkpoint(path):
    """Load a checkpoint written by save_checkpoint.

    Returns:
        (model, tokenizer) on success, or (None, None) if path does not
        exist. The model is returned on CPU in eval mode; callers move
        it to their device.
    """
    if not os.path.exists(path):
        return None, None
    print(f"Loading checkpoint from {path}...")
    # FIX: map_location='cpu' lets checkpoints trained on a GPU load on
    # CPU-only machines (the caller .to(device)s afterwards anyway).
    checkpoint = torch.load(path, map_location='cpu')
    vocab_words = checkpoint['vocab_words']
    config = checkpoint['config']
    # Reconstruct tokenizer; special tokens are re-added automatically.
    tokenizer = NLTKTokenizer(vocab=vocab_words)
    # Initialize model from the saved architecture config.
    model = TransformerV6(
        vocab_size=tokenizer.vocab_size,
        d_model=config['d_model'],
        num_heads=config['num_heads'],
        num_layers=config['num_layers'],
        d_ff=config.get('d_ff', 1024),  # default for older checkpoints
    )
    model.load_state_dict(checkpoint['model_state'])
    model.eval()
    return model, tokenizer
def select_model_file(default_path=None):
    """Interactively pick a saved v6 model file.

    Lists models_cache/model_v6*.pth newest first and lets the user
    choose one; returns default_path when none exist, on cancel, or on
    an out-of-range choice.
    """
    os.makedirs("models_cache", exist_ok=True)
    candidates = glob.glob("models_cache/model_v6*.pth")
    candidates.sort(key=os.path.getmtime, reverse=True)
    if not candidates:
        print("No existing v6 model files found.")
        return default_path
    print("\nAvailable Models:")
    for i, f in enumerate(candidates):
        size_mb = os.path.getsize(f) / (1024 * 1024)
        print(f"{i+1}. {f} ({size_mb:.2f} MB)")
    print(f"{len(candidates)+1}. Cancel / Use New Name ({default_path})")
    choice = safe_input(f"Select model (1-{len(candidates)+1}, default 1): ", "1")
    try:
        idx = int(choice) - 1
    except ValueError:
        return default_path
    if 0 <= idx < len(candidates):
        return candidates[idx]
    return default_path
| # ============================================ | |
| # 5. MAIN EXECUTION MENU | |
| # ============================================ | |
def main():
    """Interactive CLI: train a new model, resume training, or generate text.

    Menu state (model, tokenizer, MODEL_PATH) persists across loop
    iterations so a model trained via option 1 is immediately usable by
    option 3.
    """
    print("Transformer Architecture v6 (NLTK Tokenization)")
    print("=============================================")
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    os.makedirs("models_cache", exist_ok=True)
    # Fresh timestamped checkpoint path; option 2 re-points this at the
    # loaded file so resumed training keeps saving there.
    MODEL_PATH = f"models_cache/model_v6_{timestamp}.pth"
    # State variables
    model = None
    tokenizer = None
    while True:
        # Clear any interrupt left over from the previous operation.
        reset_interrupt()
        print("\n\nMAIN MENU")
        print("1. Train New Model (TinyStories)")
        print("2. Load Model & Resume Training")
        print("3. Generate Text")
        print("4. Exit")
        choice = safe_input("\nEnter choice (1-4): ", "4")
        if choice == '1':
            # --- TRAIN NEW ---
            percent_str = safe_input("Enter percentage of TinyStories to load (0.01-100, default 0.5): ", "0.5")
            try:
                percent = float(percent_str)
            except:
                percent = 0.5
            print("\n[1] Loading Data...")
            raw_text = load_tiny_stories(percent=percent)
            print("\n[2] Building Vocabulary...")
            vocab_size_str = safe_input("Max vocab size (default 10000): ", "10000")
            # Batched vocab builder shows progress and honors Ctrl+C.
            vocab_words = build_vocab_batched(raw_text, max_vocab_size=int(vocab_size_str), batch_size=1000)
            if not vocab_words:
                print("Vocabulary is empty or interrupted. Exiting training setup.")
                continue
            tokenizer = NLTKTokenizer(vocab=vocab_words)
            print(f"Tokenizer ready. Vocab size: {tokenizer.vocab_size}")
            print("\n[3] Initializing Model...")
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            print(f"Training on: {device}")
            model = TransformerV6(
                vocab_size=tokenizer.vocab_size,
                d_model=256,
                num_heads=8,
                num_layers=4,
                d_ff=1024
            ).to(device)
            print(f"Params: {model.get_num_params():,}")
            print("\n[4] Encoding Data...")
            # Use batch encoding for progress
            train_ids = batch_encode(tokenizer, raw_text, batch_size=1000)
            print("\n[5] Training Loop (Press Ctrl+C to stop & save)")
            optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
            # <pad> positions are excluded from the loss.
            criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.SPECIAL_TOKENS['<pad>'])
            epochs = int(safe_input("Epochs (default 5): ", "5"))
            try:
                for epoch in range(1, epochs + 1):
                    if is_interrupted(): break
                    print(f"\n--- Epoch {epoch}/{epochs} ---")
                    avg_loss = train_epoch(model, tokenizer, train_ids, batch_size=32, seq_len=64,
                                           optimizer=optimizer, criterion=criterion)
                    print(f"Avg Loss: {avg_loss:.4f}")
                    # Periodic checkpoint every other epoch.
                    if epoch % 2 == 0:
                        save_checkpoint(model, tokenizer, MODEL_PATH)
            except KeyboardInterrupt:
                print("\nInterrupted.")
            if is_interrupted():
                print("\n[!] Loop interrupted. Saving checkpoint...")
            # Save at the end
            save_checkpoint(model, tokenizer, MODEL_PATH)
        elif choice == '2':
            # --- LOAD & RESUME ---
            path = select_model_file()
            if not path:
                continue
            model, tokenizer = load_checkpoint(path)
            if not model:
                print("Failed to load model.")
                continue
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            model = model.to(device)
            MODEL_PATH = path  # Continue saving to same file
            print("\nNote: To resume training, we need data.")
            percent_str = safe_input("Enter percentage of TinyStories to load for training (default 0.5): ", "0.5")
            raw_text = load_tiny_stories(percent=float(percent_str))
            print("Encoding data with loaded tokenizer...")
            train_ids = batch_encode(tokenizer, raw_text, batch_size=1000)
            print("\nResuming Training (Press Ctrl+C to stop & save)")
            optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
            criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.SPECIAL_TOKENS['<pad>'])
            epochs = int(safe_input("Additional Epochs (default 5): ", "5"))
            try:
                for epoch in range(1, epochs + 1):
                    if is_interrupted(): break
                    print(f"\n--- Epoch {epoch}/{epochs} ---")
                    avg_loss = train_epoch(model, tokenizer, train_ids, batch_size=32, seq_len=64,
                                           optimizer=optimizer, criterion=criterion)
                    print(f"Avg Loss: {avg_loss:.4f}")
            except KeyboardInterrupt:
                print("\nInterrupted.")
            if is_interrupted():
                print("\n[!] Loop interrupted. Saving checkpoint...")
            save_checkpoint(model, tokenizer, MODEL_PATH)
        elif choice == '3':
            # --- GENERATE ---
            if model is None:
                # No model in memory yet -- offer to load one from disk.
                path = select_model_file()
                if path:
                    # NOTE(review): load_checkpoint can return (None, None)
                    # for a vanished file, which would crash .to(device)
                    # below -- confirm the race is acceptable here.
                    model, tokenizer = load_checkpoint(path)
                    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
                    model = model.to(device)
            if model is None:
                print("No model loaded. Please Train or Load first.")
                continue
            print("\n--- Text Generation ---")
            print("1. Custom Prompt")
            print("2. Preset: 'Once upon a time'")
            print("3. Preset: 'The little dog'")
            p_choice = safe_input("Choice (1-3): ", "2")
            if p_choice == '1':
                prompt = safe_input("Enter prompt: ", "Once upon a time")
            elif p_choice == '3':
                prompt = "The little dog"
            else:
                prompt = "Once upon a time"
            length = int(safe_input("Length (default 100): ", "100"))
            temp = float(safe_input("Temperature (default 0.8): ", "0.8"))
            generated = generate_text(model, tokenizer, prompt, max_len=length, temperature=temp)
            print(f"\n[OUTPUT]\n{generated}\n")
        elif choice == '4':
            print("Exiting.")
            break
        else:
            print("Invalid choice.")
# Run the interactive menu only when executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment