Last active
January 8, 2026 14:21
-
-
Save amitpuri/6721307c293a120eec34911902eccae5 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Transformer Architecture v8 | |
| ================================================================== | |
| 1. **Architecture**: GPT-2 style with Learnable Positional Embeddings & aligned naming. | |
| 2. **Pre-training**: Train from scratch on TinyStories (restored from v7). | |
| 3. **Weight Loading**: Load OpenAI GPT-2 pretrained weights. | |
| 4. **Fine-tuning**: Instruction fine-tuning (Alpaca style). | |
| 5. **Persistence**: Save/Load checkpoints (restored from v7). | |
| Dependencies: torch, tiktoken, requests, tqdm, numpy, datasets (for training), tensorflow (for weight loading) | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| import os | |
| import signal | |
| import sys | |
| import time | |
| import glob | |
| import json | |
| import array | |
| import requests | |
| import tiktoken | |
| import numpy as np | |
| from tqdm import tqdm | |
# ============================================
# 0. CONFIG & GLOBALS
# ============================================
# Hyperparameters of the smallest GPT-2 model (124M parameters).
GPT_CONFIG_124M = dict(
    vocab_size=50257,      # size of the tiktoken "gpt2" vocabulary
    context_length=1024,   # number of learned positional embeddings
    emb_dim=768,
    n_heads=12,
    n_layers=12,
    drop_rate=0.1,
    qkv_bias=True,         # the OpenAI checkpoint ships Q/K/V biases
)
# Cooperative-cancellation flag: the SIGINT handler sets it, and long-running
# loops poll it through is_interrupted() so Ctrl+C stops work at a safe point.
_interrupted = False


def signal_handler(signum, frame):
    """SIGINT handler: record that an interrupt was requested and notify the user."""
    global _interrupted
    _interrupted = True
    print("\n\n[!] Interrupt received. Finishing current operation gracefully...")


# Replace the default KeyboardInterrupt behaviour with the flag-setting handler.
signal.signal(signal.SIGINT, signal_handler)


def is_interrupted():
    """Report whether an interrupt has been requested since the flag was last reset."""
    return _interrupted
def safe_input(prompt, default_value):
    """Ask the user for input, substituting ``default_value`` when nothing useful is given.

    The default is returned when the reply is empty or whitespace-only, or
    when stdin is exhausted (EOFError). A non-blank reply is returned exactly
    as typed (it is not stripped).
    """
    try:
        reply = input(prompt)
    except EOFError:
        return default_value
    if not reply.strip():
        return default_value
    return reply
# ============================================
# 1. ARCHITECTURE (GPT-2 Style)
# ============================================
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learnable ``scale``/``shift``.

    Uses the biased (population) variance; ``eps`` is added to the variance
    before taking the square root.
    """

    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        variance = x.var(dim=-1, keepdim=True, unbiased=False)
        normalized = centered / torch.sqrt(variance + self.eps)
        return normalized * self.scale + self.shift
class GELU(nn.Module):
    """GELU activation using the tanh approximation.

    Computes 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    """

    def forward(self, x):
        coeff = torch.sqrt(torch.tensor(2.0 / torch.pi))
        inner = x + 0.044715 * torch.pow(x, 3)
        return 0.5 * x * (1 + torch.tanh(coeff * inner))
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(emb_dim -> 4*emb_dim) -> GELU -> Linear(4*emb_dim -> emb_dim).

    The submodules live in ``self.layers`` with the two Linear layers at
    indices 0 and 2. ``load_weights_into_gpt`` and saved state_dicts
    (``ff.layers.0.*`` / ``ff.layers.2.*``) address them by those positions.

    No dropout is applied here. ``TransformerBlock`` already applies its
    ``drop_shortcut`` dropout to this module's output before the residual
    add. The previous extra Dropout meant the FF branch was dropped out twice
    during training. Dropout has no parameters, so removing it leaves
    checkpoint keys unchanged.

    Args:
        cfg: config dict; only ``cfg["emb_dim"]`` is read.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        self.layers = nn.Sequential(
            nn.Linear(emb_dim, 4 * emb_dim),
            GELU(),
            nn.Linear(4 * emb_dim, emb_dim),
        )

    def forward(self, x):
        return self.layers(x)
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with separate Q/K/V projections.

    Args:
        d_in: size of the input's last dimension.
        d_out: total projection size; must be divisible by ``num_heads``.
        context_length: side length of the causal mask, i.e. the longest
            sequence this module can attend over.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the query/key/value projections have a bias.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark "future" positions. Keeping it
        # as a buffer makes it follow .to(device) and appear in the state_dict
        # under the key "mask".
        future = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer("mask", future)

    def _split_heads(self, t, batch, seq):
        # (batch, seq, d_out) -> (batch, num_heads, seq, head_dim)
        return t.view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        batch, seq, _ = x.shape
        q = self._split_heads(self.W_query(x), batch, seq)
        k = self._split_heads(self.W_key(x), batch, seq)
        v = self._split_heads(self.W_value(x), batch, seq)

        scores = q @ k.transpose(2, 3)
        blocked = self.mask.bool()[:seq, :seq]
        scores = scores.masked_fill(blocked, -torch.inf)
        # Scaled dot-product: divide by sqrt(head_dim) before the softmax.
        weights = torch.softmax(scores / self.head_dim ** 0.5, dim=-1)
        weights = self.dropout(weights)

        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq, self.d_out)
        return self.out_proj(merged)
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block.

    forward computes ``x + drop(att(norm1(x)))`` followed by
    ``x + drop(ff(norm2(x)))``, where ``drop`` is ``self.drop_shortcut``.

    Args:
        cfg: config dict; reads emb_dim, context_length, n_heads, drop_rate
            and qkv_bias (FeedForward reads its own keys from the same dict).
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        drop_rate = cfg["drop_rate"]
        self.att = MultiHeadAttention(
            d_in=emb_dim,
            d_out=emb_dim,
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=drop_rate,
            qkv_bias=cfg["qkv_bias"],
        )
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(emb_dim)
        self.norm2 = LayerNorm(emb_dim)
        self.drop_shortcut = nn.Dropout(drop_rate)

    def _residual(self, x, norm, sublayer):
        # Normalize, run the sub-layer, apply dropout, then add the skip path.
        return x + self.drop_shortcut(sublayer(norm(x)))

    def forward(self, x):
        x = self._residual(x, self.norm1, self.att)
        return self._residual(x, self.norm2, self.ff)
class GPTModel(nn.Module):
    """GPT-2 style decoder-only language model.

    Token embeddings plus learned positional embeddings pass through
    ``cfg["n_layers"]`` TransformerBlocks, a final LayerNorm, and a bias-free
    linear head that produces vocabulary logits.

    Args:
        cfg: config dict (vocab_size, context_length, emb_dim, n_layers,
            drop_rate, plus the keys TransformerBlock reads). It is stored on
            ``self.cfg`` so checkpoints can rebuild the model.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg  # Save config for checkpointing
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])
        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )
        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx):
        """Map a (batch, seq_len) tensor of token ids to (batch, seq_len, vocab_size) logits.

        Raises:
            ValueError: if ``seq_len`` exceeds the number of positional
                embeddings (the configured context_length). Without this
                check, the failure surfaces deep inside the embedding lookup
                with a much less helpful error (on CUDA, a device-side assert).
        """
        _, seq_len = in_idx.shape
        max_len = self.pos_emb.num_embeddings
        if seq_len > max_len:
            raise ValueError(
                f"input has {seq_len} tokens but the model supports at most "
                f"{max_len} (context_length); crop the input first"
            )
        positions = torch.arange(seq_len, device=in_idx.device)
        x = self.tok_emb(in_idx) + self.pos_emb(positions)
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        return self.out_head(x)

    def get_num_params(self):
        """Return the total number of scalar parameters in the model."""
        return sum(p.numel() for p in self.parameters())
# ============================================
# 2. DATA UTILS & TRAINING (Restored from v7)
# ============================================
def get_tiny_stories_generator(percent=1.0):
    """Stream the first ``percent`` % of the TinyStories training split.

    Returns:
        ``(story_text_generator, number_of_stories_requested)``, or
        ``(None, 0)`` when the optional ``datasets`` package is not installed.
    """
    try:
        from datasets import load_dataset
    except ImportError:
        print("[!] 'datasets' library needed for training from scratch (pip install datasets)")
        return None, 0

    print(f"Loading {percent}% of TinyStories dataset (Streaming mode)...")
    dataset = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
    # Hard-coded size of the training split, used to turn a percentage into a
    # story count (at least one story is always requested).
    total_count = 2119719
    target_count = max(1, int(total_count * (percent / 100.0)))

    def _take_texts():
        for position, record in enumerate(dataset):
            if position >= target_count:
                return
            yield record['text']

    return _take_texts(), target_count
def batch_encode(tokenizer, story_generator, total_target, batch_size=1000):
    """Encode stories into a single flat ``array('H')`` of token IDs.

    Each non-blank story is tokenized and followed by token 50256, which acts
    as a document separator (EOS). Encoding stops early once an interrupt has
    been requested (see ``is_interrupted``); the tokens encoded so far are
    returned in that case.

    Args:
        tokenizer: object with ``encode(text, allowed_special=...)``
            (a tiktoken encoding in this script).
        story_generator: iterable of story strings.
        total_target: expected number of stories; used only for progress output.
        batch_size: print progress every ``batch_size`` stories.

    Returns:
        array.array('H') holding all token IDs.
    """
    eos_id = 50256
    print(f" Encoding approx {total_target} stories...")
    # 'H' (unsigned 16-bit) fits every GPT-2 id (< 65536) and is far more
    # compact than a Python list of ints.
    all_tokens = array.array('H')
    for i, story in enumerate(story_generator):
        if is_interrupted():
            break
        if story.strip():
            all_tokens.extend(tokenizer.encode(story, allowed_special={'<|endoftext|>'}))
            all_tokens.append(eos_id)
        if (i + 1) % batch_size == 0:
            # Real carriage return so the progress line updates in place
            # (the previous '\\r' printed a literal backslash-r).
            print(f" Processed {i + 1}/{total_target}...", end='\r')
    print(f"\n Finished. Total tokens: {len(all_tokens)}")
    return all_tokens
def train_epoch(model, data_ids, batch_size, seq_len, optimizer, device):
    """Train ``model`` for one pass over the token stream ``data_ids``.

    The stream is cut into windows of ``seq_len + 1`` tokens whose starts are
    ``seq_len`` apart. Each window gives an input (all but its last token) and
    a next-token target (all but its first). Windows are grouped
    ``batch_size`` at a time, and the final batch may hold fewer windows, so
    every window that fits inside the data is used. Previously the trailing
    partial batch, and any window ending exactly at the last token, were
    skipped. Gradients are clipped to a global norm of 1.0. The loop stops
    early once an interrupt has been requested.

    Returns:
        Mean loss over the batches that ran, or 0.0 if none ran (data shorter
        than one window, or interrupted before the first batch).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    window = seq_len + 1
    # Largest start index whose window still ends inside the data.
    last_start = len(data_ids) - window
    batch_stride = batch_size * seq_len

    for batch_start in range(0, last_start + 1, batch_stride):
        if is_interrupted():
            break
        starts = range(batch_start, min(batch_start + batch_stride, last_start + 1), seq_len)
        chunks = [list(data_ids[s:s + window]) for s in starts]
        inputs = torch.tensor([c[:-1] for c in chunks], dtype=torch.long).to(device)
        targets = torch.tensor([c[1:] for c in chunks], dtype=torch.long).to(device)  # Shape: [B, T]

        optimizer.zero_grad()
        logits = model(inputs)  # [B, T, Vocab]
        loss = F.cross_entropy(logits.flatten(0, 1), targets.flatten())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
        if num_batches % 10 == 0:
            print(f" Batch {num_batches}, Loss: {loss.item():.4f}", end='\r')

    return total_loss / num_batches if num_batches else 0.0
# ============================================
# 3. WEIGHT DOWNLOADING (GPT-2)
# ============================================
def download_file(url, destination, chunk_size=1024 * 1024):
    """Download ``url`` to ``destination`` with a tqdm progress bar.

    The transfer is skipped when ``destination`` already exists and its size
    equals the server's Content-Length. This is a best-effort helper: any
    error is printed and reported as ``False`` instead of being raised.

    Args:
        url: URL to fetch.
        destination: local file path to write.
        chunk_size: bytes per streamed read (default 1 MiB; the old 1 KiB
            chunks added needless per-chunk overhead on large files).

    Returns:
        True if the file was downloaded or already present, False on error.
    """
    try:
        # The context manager closes the streamed connection on every path,
        # including the early "already downloaded" return.
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            file_size = int(response.headers.get("content-length", 0))
            if os.path.exists(destination) and os.path.getsize(destination) == file_size:
                print(f"File exists: {destination}")
                return True
            with tqdm(total=file_size, unit="B", unit_scale=True, desc=os.path.basename(destination)) as bar:
                with open(destination, "wb") as f:
                    for chunk in response.iter_content(chunk_size):
                        f.write(chunk)
                        bar.update(len(chunk))
        return True
    except Exception as e:
        print(f"Error downloading {url}: {e}")
        return False
def download_and_load_gpt2_params(model_size, models_dir="models_cache/gpt2"):
    """Download OpenAI's GPT-2 TensorFlow checkpoint and return its weights as nested dicts.

    The result has a ``"blocks"`` list with one dict per transformer layer
    (variables under ``h<i>/...`` go to ``blocks[i]``). Every other variable
    is stored by its path below the top-level scope. Single-element
    dimensions are squeezed out of each array.

    Args:
        model_size: one of "124M", "355M", "774M", "1558M".
        models_dir: cache directory; files go to ``models_dir/<model_size>``.

    Returns:
        The nested parameter dict, or None if a download failed or no
        checkpoint could be located.

    Raises:
        ValueError: for an unsupported ``model_size``. This is checked before
            TensorFlow is imported, so it is reported even without TensorFlow.
    """
    allowed_sizes = ("124M", "355M", "774M", "1558M")
    if model_size not in allowed_sizes:
        raise ValueError(f"Invalid model size {model_size!r}; expected one of {allowed_sizes}")

    # TensorFlow is only needed to read the original checkpoint format.
    import tensorflow as tf

    model_dir = os.path.join(models_dir, model_size)
    os.makedirs(model_dir, exist_ok=True)
    base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
    filenames = ["checkpoint", "encoder.json", "hparams.json", "model.ckpt.data-00000-of-00001", "model.ckpt.index", "model.ckpt.meta", "vocab.bpe"]
    for filename in filenames:
        if not download_file(f"{base_url}/{model_size}/{filename}", os.path.join(model_dir, filename)):
            return None

    tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
    if tf_ckpt_path is None:
        print(f"No TensorFlow checkpoint found in {model_dir}")
        return None
    with open(os.path.join(model_dir, "hparams.json"), encoding="utf-8") as f:
        settings = json.load(f)

    params = {"blocks": [{} for _ in range(settings["n_layer"])]}
    print("Loading TF weights...")
    for name, _ in tf.train.list_variables(tf_ckpt_path):
        val = np.squeeze(tf.train.load_variable(tf_ckpt_path, name))
        parts = name.split("/")[1:]  # drop the top-level scope name
        target = params
        if parts[0].startswith("h"):
            # "h<i>" selects the i-th transformer block.
            target = params["blocks"][int(parts[0][1:])]
        for key in parts[1:-1]:
            target = target.setdefault(key, {})
        target[parts[-1]] = val
    return params
def load_weights_into_gpt(gpt, params):
    """Copy GPT-2 weights from ``params`` (see download_and_load_gpt2_params) into ``gpt``.

    Matrices from the checkpoint are transposed before being stored in the
    nn.Linear weights. The fused ``c_attn`` projection is split into three
    equal query/key/value parts along its last axis. The output head
    receives a copy of the token-embedding matrix (``wte``). Every target is
    replaced by a new ``nn.Parameter`` after its shape is checked.

    Raises:
        ValueError: if a source array's shape differs from the target's.
    """
    def to_param(current, source):
        if current.shape != source.shape:
            raise ValueError(f"Shape mismatch {current.shape} vs {source.shape}")
        return torch.nn.Parameter(torch.tensor(source))

    def put(module, attr, source):
        setattr(module, attr, to_param(getattr(module, attr), source))

    put(gpt.pos_emb, "weight", params['wpe'])
    put(gpt.tok_emb, "weight", params['wte'])

    qkv_names = ("W_query", "W_key", "W_value")
    for idx in range(len(params["blocks"])):
        block_params = params["blocks"][idx]

        # Attention: fused c_attn -> separate Q/K/V, then output projection.
        attn_params = block_params["attn"]
        weight_parts = np.split(attn_params["c_attn"]["w"], 3, axis=-1)
        attn = gpt.trf_blocks[idx].att
        for proj_name, part in zip(qkv_names, weight_parts):
            put(getattr(attn, proj_name), "weight", part.T)
        bias_parts = np.split(attn_params["c_attn"]["b"], 3, axis=-1)
        for proj_name, part in zip(qkv_names, bias_parts):
            put(getattr(attn, proj_name), "bias", part)
        put(attn.out_proj, "weight", attn_params["c_proj"]["w"].T)
        put(attn.out_proj, "bias", attn_params["c_proj"]["b"])

        # Feed-forward: layers[0] is the expansion, layers[2] the projection.
        mlp_params = block_params["mlp"]
        ff_layers = gpt.trf_blocks[idx].ff.layers
        put(ff_layers[0], "weight", mlp_params["c_fc"]["w"].T)
        put(ff_layers[0], "bias", mlp_params["c_fc"]["b"])
        put(ff_layers[2], "weight", mlp_params["c_proj"]["w"].T)
        put(ff_layers[2], "bias", mlp_params["c_proj"]["b"])

        # LayerNorms: gain "g" -> scale, bias "b" -> shift.
        block = gpt.trf_blocks[idx]
        for norm, key in ((block.norm1, "ln_1"), (block.norm2, "ln_2")):
            put(norm, "scale", block_params[key]["g"])
            put(norm, "shift", block_params[key]["b"])

    put(gpt.final_norm, "scale", params["g"])
    put(gpt.final_norm, "shift", params["b"])
    put(gpt.out_head, "weight", params["wte"])
# ============================================
# 4. CHECKPOINTING & GEN UTILS
# ============================================
def save_checkpoint(model, path, optimizer=None):
    """Write ``model``'s weights and ``model.cfg`` (plus optional optimizer state) to ``path``.

    The saved object is a dict with keys 'model_state' and 'config', and also
    'optimizer_state' when an optimizer is supplied. load_checkpoint rebuilds
    the model from the first two.
    """
    extra = {'optimizer_state': optimizer.state_dict()} if optimizer else {}
    torch.save({'model_state': model.state_dict(), 'config': model.cfg, **extra}, path)
    print(f"Saved model to {path}")
def load_checkpoint(path, device):
    """Rebuild a GPTModel from a file written by save_checkpoint.

    Returns the model (weights loaded, moved to ``device``), or None after
    printing a message when ``path`` does not exist. Any stored optimizer
    state is ignored.
    NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    """
    if os.path.exists(path):
        print(f"Loading {path}...")
        ckpt = torch.load(path, map_location=device)
        restored = GPTModel(ckpt['config']).to(device)
        restored.load_state_dict(ckpt['model_state'])
        return restored
    print(f"Checkpoint {path} not found.")
    return None
def select_model_file():
    """Let the user pick a ``.pth`` checkpoint from ``models_cache/``.

    Files are listed in sorted order so the numbering is stable between runs.
    The default answer is "1".

    Returns:
        The chosen path, or None when there are no checkpoints or the answer
        is not a number in range.
    """
    os.makedirs("models_cache", exist_ok=True)
    files = sorted(glob.glob("models_cache/*.pth"))
    if not files:
        print("No saved models found in models_cache/.")
        return None
    # Real newline; the previous "\\n" printed a literal backslash-n.
    print("\nAvailable Models:")
    for number, path in enumerate(files, start=1):
        print(f"{number}. {path}")
    choice = safe_input(f"Select (1-{len(files)}): ", "1")
    try:
        idx = int(choice) - 1
    except ValueError:  # previously a bare except silently hid the problem
        print(f"Invalid selection: {choice!r}")
        return None
    if 0 <= idx < len(files):
        return files[idx]
    print(f"Selection out of range: {choice!r}")
    return None
def text_to_token_ids(text, tokenizer):
    """Encode ``text`` as a batch-of-one tensor of shape (1, num_tokens).

    The '<|endoftext|>' special token is allowed to appear in ``text``.
    """
    ids = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    return torch.tensor(ids)[None, :]
def token_ids_to_text(token_ids, tokenizer):
    """Decode a batch-of-one tensor of token ids (leading dim squeezed) back to text."""
    flat_ids = token_ids.squeeze(0)
    return tokenizer.decode(flat_ids.tolist())
def generate_text_simple(model, idx, max_new_tokens, context_size, eos_id=50256):
    """Greedily extend ``idx`` by up to ``max_new_tokens`` tokens.

    The model is switched to eval mode for the duration, so dropout is off;
    callers in this script often leave it in train mode. Its previous
    train/eval mode is restored afterwards.

    Args:
        model: nn.Module mapping (batch, seq) token ids to (batch, seq, vocab) logits.
        idx: (batch, seq) tensor of prompt token ids.
        max_new_tokens: maximum number of tokens to append.
        context_size: only the last ``context_size`` tokens are fed to the model.
        eos_id: generation stops once every sequence in the batch has just
            produced this id (default 50256, the EOS id used throughout this
            file). The EOS token is kept in the output. Pass None to disable.

    Returns:
        ``idx`` with the generated tokens appended along dim 1.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(max_new_tokens):
                if is_interrupted():
                    break
                logits = model(idx[:, -context_size:])[:, -1, :]
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
                idx = torch.cat((idx, idx_next), dim=1)
                # .all() keeps the check valid for batch sizes > 1, where a
                # bare tensor comparison in `if` would raise.
                if eos_id is not None and bool((idx_next == eos_id).all()):
                    break
    finally:
        model.train(was_training)
    return idx
# ============================================
# 5. INSTRUCTION TUNING UTILS
# ============================================
def format_input(entry):
    """Build the Alpaca-style prompt (instruction plus optional input) for ``entry``.

    ``entry`` must contain 'instruction'. A truthy 'input' value adds an
    '### Input:' section; a missing or empty one is skipped. The
    '### Response:' section is not included.

    Real newlines are used. The previous "\\n" literals inserted a backslash
    followed by 'n' into the prompt instead of line breaks.
    """
    instruction_text = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
        f"\n\n### Instruction:\n{entry['instruction']}"
    )
    input_text = f"\n\n### Input:\n{entry['input']}" if entry.get("input") else ""
    return instruction_text + input_text
def get_instruction_batch_generator(data_path, tokenizer, batch_size=8, device='cpu'):
    """Load instruction data from a JSON file and return an endless batch generator.

    The file must hold a non-empty JSON list of entries. Each entry is passed
    to ``format_input`` and must also have an 'output' key. Each batch:

    * samples ``batch_size`` entries uniformly at random, with replacement
      (``np.random.randint``);
    * encodes prompt + "\\n\\n### Response:\\n" + output, then appends EOS id 50256;
    * sets the target positions covered by the prompt's own token count to
      -100, so a loss using ``ignore_index=-100`` skips them;
    * right-pads to the longest sequence: inputs with 50256, targets with -100.

    Returns:
        ``(generator, number_of_entries)``. The generator yields
        ``(inputs, targets)`` int64 tensors of shape (batch_size, max_len) on
        ``device``.

    Raises:
        ValueError: if the file contains no entries. This is raised up front
            rather than on the first ``next()``.
    """
    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not data:
        raise ValueError(f"No instruction entries found in {data_path}")
    print(f"Loaded {len(data)} instruction entries.")

    eos_id = 50256
    ignore_index = -100
    special = {'<|endoftext|>'}

    def _encode_entry(entry):
        # Returns (token ids, targets with the prompt part masked out).
        prompt = format_input(entry)
        # Real newlines; the previous "\\n" literals put backslash-n into the text.
        full_text = prompt + f"\n\n### Response:\n{entry['output']}"
        encoded = tokenizer.encode(full_text, allowed_special=special)
        encoded.append(eos_id)
        prompt_len = len(tokenizer.encode(prompt, allowed_special=special))
        targets = [ignore_index] * prompt_len + encoded[prompt_len:]
        return encoded, targets

    def generator():
        while True:
            batch = [_encode_entry(data[np.random.randint(0, len(data))]) for _ in range(batch_size)]
            max_len = max(len(enc) for enc, _ in batch)
            padded_inputs, padded_targets = [], []
            for enc, tgt in batch:
                pad_len = max_len - len(enc)
                padded_inputs.append(enc + [eos_id] * pad_len)
                padded_targets.append(tgt + [ignore_index] * pad_len)
            yield (
                torch.tensor(padded_inputs, dtype=torch.long).to(device),
                torch.tensor(padded_targets, dtype=torch.long).to(device),
            )

    return generator(), len(data)
# ============================================
# 6. MAIN
# ============================================
def main():
    """Interactive console menu: train, load GPT-2 weights, load a checkpoint, generate, fine-tune.

    Ctrl+C invokes ``signal_handler`` (installed at import time), which only
    sets a flag. Long-running loops poll ``is_interrupted()`` and stop at a
    safe point, and the flag is cleared at the top of every menu iteration.
    Invalid numeric answers return to the menu instead of crashing.
    """
    global _interrupted
    print("Transformer Architecture v8 (Universal)")
    print("=======================================")
    tokenizer = tiktoken.get_encoding("gpt2")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    os.makedirs("models_cache", exist_ok=True)
    model = None
    while True:
        _interrupted = False
        print("\nMAIN MENU")
        print("1. Train New Model (TinyStories)")
        print("2. Load Pretrained GPT-2 (Weights)")
        print("3. Load Checkpoint (Local)")
        print("4. Generate Text")
        print("5. Instruction Fine-tune")
        print("6. Exit")
        choice = safe_input("\nEnter choice: ", "6")

        # --- 1. Train New Model ---
        if choice == '1':
            try:
                percent = float(safe_input("Dataset % (0.1-100, default 0.1): ", "0.1"))
            except ValueError:
                print("Invalid percentage.")
                continue
            # Token cache keyed by the requested percentage.
            os.makedirs("datasets_cache", exist_ok=True)
            cache_file = f"datasets_cache/tinystories_p{percent}.bin"
            if os.path.exists(cache_file):
                print(f"Loading tokens from {cache_file}...")
                data_ids = array.array('H')
                with open(cache_file, 'rb') as f:
                    # 'H' items are 2 bytes each.
                    data_ids.fromfile(f, os.path.getsize(cache_file) // 2)
            else:
                # Only open the (streaming) dataset when no cache exists.
                story_gen, target_count = get_tiny_stories_generator(percent)
                if not story_gen:
                    continue
                data_ids = batch_encode(tokenizer, story_gen, target_count)
                # An interrupted encode is partial; caching it would make later
                # runs silently train on truncated data.
                if data_ids and not is_interrupted():
                    print(f"Saving to {cache_file}...")
                    with open(cache_file, 'wb') as f:
                        data_ids.tofile(f)
            if not data_ids or is_interrupted():
                continue
            # Ask before building the model so a typo does not discard the current one.
            try:
                epochs = int(safe_input("Epochs (default 1): ", "1"))
            except ValueError:
                print("Invalid number of epochs.")
                continue
            print("Initializing GPT-2 Small Config...")
            # Full GPT-2 124M config; lower cfg['n_layers'] (e.g. to 4) for faster experiments.
            cfg = GPT_CONFIG_124M.copy()
            model = GPTModel(cfg).to(device)
            print(f"Params: {model.get_num_params():,}")
            optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)
            try:
                for epoch in range(epochs):
                    if is_interrupted():
                        break
                    loss = train_epoch(model, data_ids, 8, 256, optimizer, device)  # batch 8, seq 256
                    print(f"Epoch {epoch+1} Loss: {loss:.4f}")
                    save_checkpoint(model, f"models_cache/v8_scratch_e{epoch+1}.pth", optimizer)
            except KeyboardInterrupt:
                print("Interrupted")

        # --- 2. Load GPT-2 Weights ---
        elif choice == '2':
            try:
                import tensorflow  # noqa: F401 -- availability check only
            except ImportError:
                print("Install tensorflow first.")
                continue
            params = download_and_load_gpt2_params("124M")
            if params:
                model = GPTModel(GPT_CONFIG_124M.copy())
                load_weights_into_gpt(model, params)
                model.to(device)
                save_checkpoint(model, "models_cache/gpt2_124m_converted.pth")

        # --- 3. Load Checkpoint ---
        elif choice == '3':
            path = select_model_file()
            if path:
                model = load_checkpoint(path, device)

        # --- 4. Generate ---
        elif choice == '4':
            if not model:
                print("No model loaded.")
                continue
            prompt = safe_input("Prompt: ", "Once upon a time")
            # Crop to what this model's positional embeddings support rather
            # than assuming 1024 (checkpoints may use other configs).
            context_size = model.cfg["context_length"]
            out = generate_text_simple(model, text_to_token_ids(prompt, tokenizer).to(device), 50, context_size)
            print("Output:", token_ids_to_text(out, tokenizer))

        # --- 5. Fine-tune ---
        elif choice == '5':
            if not model:
                print("Load a model first!")
                continue
            data_file = "instruction-data.json"
            if not os.path.exists(data_file):
                # Small synthetic dataset so fine-tuning works out of the box.
                print(f"Creating synthetic {data_file}...")
                synthetic_data = [
                    {"instruction": "What is the capital of France?", "input": "", "output": "The capital of France is Paris."},
                    {"instruction": "Summarize the following text.", "input": "Artificial intelligence is intelligence demonstrated by machines, as opposed to natural intelligence displayed by animals including humans.", "output": "AI is machine intelligence, distinct from animal or human natural intelligence."},
                    {"instruction": "Write a python function to add two numbers.", "input": "", "output": "def add(a, b):\n return a + b"},
                    {"instruction": "Translate 'Hello' to Spanish.", "input": "", "output": "Hola"},
                    {"instruction": "Explain the concept of gravity.", "input": "", "output": "Gravity is a fundamental interaction which causes mutual attraction between all things with mass or energy."},
                    {"instruction": "List three primary colors.", "input": "", "output": "Red, Blue, Yellow"},
                    {"instruction": "Who wrote 'Romeo and Juliet'?", "input": "", "output": "William Shakespeare"},
                    {"instruction": "Convert 100 degrees Celsius to Fahrenheit.", "input": "", "output": "212 degrees Fahrenheit"},
                    {"instruction": "What is the largest planet in our solar system?", "input": "", "output": "Jupiter"},
                    {"instruction": "Compose a haiku about code.", "input": "", "output": "Logic flows clearly,\nBugs hide in the shadows deep,\nScreen light glows all night."}
                ] * 10
                with open(data_file, "w", encoding="utf-8") as f:
                    json.dump(synthetic_data, f, indent=2)
            try:
                gen, _ = get_instruction_batch_generator(data_file, tokenizer, device=device)
            except (OSError, ValueError) as e:  # unreadable, malformed or empty file
                print(f"Could not load {data_file}: {e}")
                continue
            optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.1)
            model.train()
            print("Fine-tuning (Ctrl+C to stop)...")
            try:
                for step in range(20):
                    if is_interrupted():
                        break
                    inputs, targets = next(gen)
                    optimizer.zero_grad()
                    logits = model(inputs)
                    # Shift by one: position t predicts token t+1; -100 marks prompt/padding.
                    loss = F.cross_entropy(logits[:, :-1, :].flatten(0, 1), targets[:, 1:].flatten(), ignore_index=-100)
                    loss.backward()
                    # Same clipping as train_epoch for stable updates.
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    print(f"Step {step+1} Loss: {loss.item():.4f}")
            except KeyboardInterrupt:
                pass
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            save_path = f"models_cache/v8_finetuned_{timestamp}.pth"
            save_checkpoint(model, save_path, optimizer)

        elif choice == '6':
            break


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment