Skip to content

Instantly share code, notes, and snippets.

@amitpuri
Last active January 8, 2026 14:21
Show Gist options
  • Select an option

  • Save amitpuri/6721307c293a120eec34911902eccae5 to your computer and use it in GitHub Desktop.

Select an option

Save amitpuri/6721307c293a120eec34911902eccae5 to your computer and use it in GitHub Desktop.
"""
Transformer Architecture v8
==================================================================
1. **Architecture**: GPT-2 style with Learnable Positional Embeddings & aligned naming.
2. **Pre-training**: Train from scratch on TinyStories (restored from v7).
3. **Weight Loading**: Load OpenAI GPT-2 pretrained weights.
4. **Fine-tuning**: Instruction fine-tuning (Alpaca style).
5. **Persistence**: Save/Load checkpoints (restored from v7).
Dependencies: torch, tiktoken, requests, tqdm, numpy, datasets (for training), tensorflow (for weight loading)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import signal
import sys
import time
import glob
import json
import array
import requests
import tiktoken
import numpy as np
from tqdm import tqdm
# ============================================
# 0. CONFIG & GLOBALS
# ============================================
# Default GPT-2 124M Config
GPT_CONFIG_124M = {
    "vocab_size": 50257,      # GPT-2 BPE vocabulary size (50256 is <|endoftext|>)
    "context_length": 1024,   # max sequence length (learned positional embeddings)
    "emb_dim": 768,           # embedding / hidden width
    "n_heads": 12,            # attention heads per layer
    "n_layers": 12,           # transformer blocks
    "drop_rate": 0.1,         # dropout probability used throughout
    "qkv_bias": True          # GPT-2 checkpoints include Q/K/V projection biases
}
# Cooperative Ctrl+C: instead of a KeyboardInterrupt mid-batch, SIGINT just
# flips a module-level flag that long-running loops poll via is_interrupted().
_interrupted = False


def signal_handler(signum, frame):
    """SIGINT handler: record the interrupt and let the current op finish."""
    global _interrupted
    _interrupted = True
    print("\n\n[!] Interrupt received. Finishing current operation gracefully...")


signal.signal(signal.SIGINT, signal_handler)


def is_interrupted():
    """Return True once the user has pressed Ctrl+C since the last reset."""
    return _interrupted
def safe_input(prompt, default_value):
    """Prompt the user; return default_value on EOF or a blank/whitespace reply."""
    try:
        raw = input(prompt)
    except EOFError:
        # Non-interactive stdin (piped / closed) — fall back silently.
        return default_value
    return raw if raw.strip() else default_value
# ============================================
# 1. ARCHITECTURE (GPT-2 Style)
# ============================================
class LayerNorm(nn.Module):
    """Layer normalization with learnable scale/shift.

    Uses biased variance (unbiased=False), matching GPT-2's TF implementation.
    """

    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()
        self.eps = eps  # numerical-stability floor inside the sqrt
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mu = x.mean(dim=-1, keepdim=True)
        sigma2 = x.var(dim=-1, keepdim=True, unbiased=False)
        normalized = (x - mu) / torch.sqrt(sigma2 + self.eps)
        return self.scale * normalized + self.shift
class GELU(nn.Module):
    """GELU activation using the tanh approximation (as in the GPT-2 paper)."""

    def forward(self, x):
        # 0.5 * x * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3) ))
        coeff = math.sqrt(2.0 / math.pi)
        inner = coeff * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1.0 + torch.tanh(inner))
class FeedForward(nn.Module):
def __init__(self, cfg):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
GELU(),
nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
)
self.dropout = nn.Dropout(cfg["drop_rate"])
def forward(self, x):
return self.dropout(self.layers(x))
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention (GPT-2 style).

    Separate Q/K/V linear projections, an upper-triangular mask buffer for
    causality, scaled dot-product attention, and a final output projection.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # Strict upper triangle = positions each query must NOT attend to.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        bsz, seq_len, _ = x.shape
        per_head = (bsz, seq_len, self.num_heads, self.head_dim)
        # Project, split into heads, and move the head axis before the seq axis.
        q = self.W_query(x).view(per_head).transpose(1, 2)
        k = self.W_key(x).view(per_head).transpose(1, 2)
        v = self.W_value(x).view(per_head).transpose(1, 2)
        scores = q @ k.transpose(-2, -1)
        causal = self.mask.bool()[:seq_len, :seq_len]
        scores = scores.masked_fill(causal, -torch.inf)
        weights = torch.softmax(scores / self.head_dim ** 0.5, dim=-1)
        weights = self.dropout(weights)
        # Merge heads back: [b, h, t, hd] -> [b, t, d_out].
        merged = (weights @ v).transpose(1, 2).contiguous().view(bsz, seq_len, self.d_out)
        return self.out_proj(merged)
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: LN -> attention -> residual, LN -> MLP -> residual."""

    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"],
        )
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Attention sub-layer with residual connection.
        x = x + self.drop_shortcut(self.att(self.norm1(x)))
        # Feed-forward sub-layer with residual connection.
        x = x + self.drop_shortcut(self.ff(self.norm2(x)))
        return x
class GPTModel(nn.Module):
    """GPT-2 style decoder-only LM.

    Token + learned positional embeddings, a stack of pre-norm transformer
    blocks, a final LayerNorm, and a bias-free LM head.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg  # kept so checkpoints can rebuild the architecture
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])
        blocks = [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        self.trf_blocks = nn.Sequential(*blocks)
        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx):
        """in_idx: [batch, seq] integer token IDs -> logits [batch, seq, vocab]."""
        _, seq_len = in_idx.shape
        positions = torch.arange(seq_len, device=in_idx.device)
        hidden = self.drop_emb(self.tok_emb(in_idx) + self.pos_emb(positions))
        hidden = self.trf_blocks(hidden)
        return self.out_head(self.final_norm(hidden))

    def get_num_params(self):
        """Total parameter count, embeddings included."""
        return sum(p.numel() for p in self.parameters())
# ============================================
# 2. DATA UTILS & TRAINING (Restored from v7)
# ============================================
def get_tiny_stories_generator(percent=1.0):
    """Stream a slice of the TinyStories train split.

    Returns (story_text_generator, target_story_count), or (None, 0) when the
    optional 'datasets' dependency is missing.
    """
    try:
        from datasets import load_dataset
    except ImportError:
        print("[!] 'datasets' library needed for training from scratch (pip install datasets)")
        return None, 0
    print(f"Loading {percent}% of TinyStories dataset (Streaming mode)...")
    stream = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
    total_count = 2119719  # known size of the TinyStories train split
    target_count = max(1, int(total_count * (percent / 100.0)))

    def stories():
        for idx, row in enumerate(stream):
            if idx >= target_count:
                break
            yield row['text']

    return stories(), target_count
def batch_encode(tokenizer, story_generator, total_target, batch_size=1000):
    """Encode an iterable of stories into one flat array of GPT-2 token IDs.

    Each non-empty story is tokenized and terminated with the <|endoftext|>
    token (50256). Returns array.array('H'); unsigned short suffices because
    the GPT-2 vocab (50257 IDs, max 50256) fits in 16 bits. Stops early if
    the user pressed Ctrl+C (is_interrupted()).
    """
    print(f" Encoding approx {total_target} stories...")
    # Use array.array 'H' (unsigned short) for memory efficiency vs a list.
    all_tokens = array.array('H')
    for i, story in enumerate(story_generator):
        if is_interrupted():
            break
        if story.strip():
            tokens = tokenizer.encode(story, allowed_special={'<|endoftext|>'})
            all_tokens.extend(tokens)
            all_tokens.append(50256)  # EOS separator between stories
        if (i + 1) % batch_size == 0:
            # Fixed: the escapes were doubled ('\\r'), so a literal backslash-r
            # was printed instead of returning the cursor to overwrite the line.
            print(f" Processed {i + 1}/{total_target}...", end='\r')
    print(f"\n Finished. Total tokens: {len(all_tokens)}")
    return all_tokens
def train_epoch(model, data_ids, batch_size, seq_len, optimizer, device):
    """Run one epoch of next-token training over a flat token array.

    Slides non-overlapping windows of batch_size * seq_len tokens; each sample
    is (chunk[:-1], chunk[1:]). Gradients are clipped to norm 1.0. Returns the
    mean batch loss, or 0.0 if no full batch fit. Honors Ctrl+C via
    is_interrupted().
    """
    model.train()
    total_loss = 0
    num_batches = 0
    n_tokens = len(data_ids)
    # stride = batch_size * seq_len (non-overlapping macro-batches)
    for i in range(0, n_tokens - batch_size * (seq_len + 1), batch_size * seq_len):
        if is_interrupted():
            break
        inputs_list = []
        targets_list = []
        for b in range(batch_size):
            start_idx = i + b * seq_len
            if start_idx + seq_len + 1 >= n_tokens:
                break
            # seq_len + 1 tokens so input/target can be shifted by one.
            chunk = data_ids[start_idx: start_idx + seq_len + 1]
            inputs_list.append(list(chunk[:-1]))
            targets_list.append(list(chunk[1:]))
        if not inputs_list:
            break
        inputs = torch.tensor(inputs_list).to(device)    # [B, T]
        targets = torch.tensor(targets_list).to(device)  # [B, T]
        optimizer.zero_grad()
        logits = model(inputs)                           # [B, T, Vocab]
        loss = F.cross_entropy(logits.flatten(0, 1), targets.flatten())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
        if num_batches % 10 == 0:
            # Fixed: '\r' was written as '\\r', printing a literal backslash
            # instead of overwriting the progress line in place.
            print(f" Batch {num_batches}, Loss: {loss.item():.4f}", end='\r')
    if num_batches > 0:
        return total_loss / num_batches
    return 0.0
# ============================================
# 3. WEIGHT DOWNLOADING (GPT-2)
# ============================================
def download_file(url, destination):
    """Stream url to destination with a tqdm progress bar.

    Skips the download when a local file of the exact expected size already
    exists. Returns True on success, False on any network/IO error.
    """
    try:
        # Fixed: the Response is now a context manager, so the connection is
        # released deterministically — previously the early "file exists"
        # return leaked the open streaming response until GC.
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            file_size = int(response.headers.get("content-length", 0))
            if os.path.exists(destination) and os.path.getsize(destination) == file_size:
                print(f"File exists: {destination}")
                return True
            with tqdm(total=file_size, unit="B", unit_scale=True,
                      desc=os.path.basename(destination)) as bar:
                with open(destination, "wb") as f:
                    for chunk in response.iter_content(1024):
                        f.write(chunk)
                        bar.update(len(chunk))
        return True
    except Exception as e:
        print(f"Error downloading {url}: {e}")
        return False
def download_and_load_gpt2_params(model_size, models_dir="models_cache/gpt2"):
    """Download the official OpenAI GPT-2 TF checkpoint and return its tensors.

    Returns a dict of numpy arrays: top-level embedding/final-norm entries plus
    a "blocks" list with one nested dict per transformer layer, or None if any
    file fails to download. Requires tensorflow.
    """
    import tensorflow as tf
    import numpy as np
    allowed_sizes = ("124M", "355M", "774M", "1558M")
    if model_size not in allowed_sizes:
        raise ValueError("Invalid model size")
    model_dir = os.path.join(models_dir, model_size)
    os.makedirs(model_dir, exist_ok=True)
    base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
    filenames = ["checkpoint", "encoder.json", "hparams.json",
                 "model.ckpt.data-00000-of-00001", "model.ckpt.index",
                 "model.ckpt.meta", "vocab.bpe"]
    for filename in filenames:
        # Fixed: the remote URL must interpolate each filename — the previous
        # f-string had a broken placeholder, so every request hit a bad URL.
        if not download_file(f"{base_url}/{model_size}/{filename}",
                             os.path.join(model_dir, filename)):
            return None
    tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
    settings = json.load(open(os.path.join(model_dir, "hparams.json")))
    params = {"blocks": [{} for _ in range(settings["n_layer"])]}
    print("Loading TF weights...")
    for name, _ in tf.train.list_variables(tf_ckpt_path):
        val = np.squeeze(tf.train.load_variable(tf_ckpt_path, name))
        parts = name.split("/")[1:]  # drop the leading "model" scope
        target = params
        if parts[0].startswith("h"):  # per-layer vars live under h<layer_idx>/
            target = params["blocks"][int(parts[0][1:])]
        for key in parts[1:-1]:
            target = target.setdefault(key, {})
        target[parts[-1]] = val
    return params
def load_weights_into_gpt(gpt, params):
    """Copy OpenAI GPT-2 TF weights (from download_and_load_gpt2_params) into gpt.

    TF stores linear weights as [in, out] while PyTorch nn.Linear expects
    [out, in], hence the .T transpose on every weight matrix. Raises
    ValueError on any shape mismatch between checkpoint and model.
    """
    def assign(left, right):
        # Fail fast if the checkpoint does not match the model architecture.
        if left.shape != right.shape: raise ValueError(f"Shape mismatch {left.shape} vs {right.shape}")
        return torch.nn.Parameter(torch.tensor(right))
    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])  # positional embeddings
    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])  # token embeddings
    for b in range(len(params["blocks"])):
        # Attn: TF fuses Q/K/V into a single c_attn tensor; split into thirds.
        p_attn = params["blocks"][b]["attn"]
        q_w, k_w, v_w = np.split(p_attn["c_attn"]["w"], 3, axis=-1)
        gpt.trf_blocks[b].att.W_query.weight = assign(gpt.trf_blocks[b].att.W_query.weight, q_w.T)
        gpt.trf_blocks[b].att.W_key.weight = assign(gpt.trf_blocks[b].att.W_key.weight, k_w.T)
        gpt.trf_blocks[b].att.W_value.weight = assign(gpt.trf_blocks[b].att.W_value.weight, v_w.T)
        q_b, k_b, v_b = np.split(p_attn["c_attn"]["b"], 3, axis=-1)
        gpt.trf_blocks[b].att.W_query.bias = assign(gpt.trf_blocks[b].att.W_query.bias, q_b)
        gpt.trf_blocks[b].att.W_key.bias = assign(gpt.trf_blocks[b].att.W_key.bias, k_b)
        gpt.trf_blocks[b].att.W_value.bias = assign(gpt.trf_blocks[b].att.W_value.bias, v_b)
        gpt.trf_blocks[b].att.out_proj.weight = assign(gpt.trf_blocks[b].att.out_proj.weight, p_attn["c_proj"]["w"].T)
        gpt.trf_blocks[b].att.out_proj.bias = assign(gpt.trf_blocks[b].att.out_proj.bias, p_attn["c_proj"]["b"])
        # FF: c_fc is the 4x expansion layer, c_proj projects back down.
        p_mlp = params["blocks"][b]["mlp"]
        gpt.trf_blocks[b].ff.layers[0].weight = assign(gpt.trf_blocks[b].ff.layers[0].weight, p_mlp["c_fc"]["w"].T)
        gpt.trf_blocks[b].ff.layers[0].bias = assign(gpt.trf_blocks[b].ff.layers[0].bias, p_mlp["c_fc"]["b"])
        gpt.trf_blocks[b].ff.layers[2].weight = assign(gpt.trf_blocks[b].ff.layers[2].weight, p_mlp["c_proj"]["w"].T)
        gpt.trf_blocks[b].ff.layers[2].bias = assign(gpt.trf_blocks[b].ff.layers[2].bias, p_mlp["c_proj"]["b"])
        # Norms: TF names the LayerNorm gain/bias 'g'/'b'; ours are scale/shift.
        gpt.trf_blocks[b].norm1.scale = assign(gpt.trf_blocks[b].norm1.scale, params["blocks"][b]["ln_1"]["g"])
        gpt.trf_blocks[b].norm1.shift = assign(gpt.trf_blocks[b].norm1.shift, params["blocks"][b]["ln_1"]["b"])
        gpt.trf_blocks[b].norm2.scale = assign(gpt.trf_blocks[b].norm2.scale, params["blocks"][b]["ln_2"]["g"])
        gpt.trf_blocks[b].norm2.shift = assign(gpt.trf_blocks[b].norm2.shift, params["blocks"][b]["ln_2"]["b"])
    gpt.final_norm.scale = assign(gpt.final_norm.scale, params["g"])
    gpt.final_norm.shift = assign(gpt.final_norm.shift, params["b"])
    # Weight tying: the output head reuses the token-embedding matrix.
    gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])
# ============================================
# 4. CHECKPOINTING & GEN UTILS
# ============================================
def save_checkpoint(model, path, optimizer=None):
    """Persist model weights and its config (and optionally optimizer state)."""
    payload = {
        'model_state': model.state_dict(),
        'config': model.cfg,  # stored so load_checkpoint can rebuild the model
    }
    if optimizer:
        payload['optimizer_state'] = optimizer.state_dict()
    torch.save(payload, path)
    print(f"Saved model to {path}")
def load_checkpoint(path, device):
    """Rebuild a GPTModel from a save_checkpoint file; None if path is missing."""
    if not os.path.exists(path):
        print(f"Checkpoint {path} not found.")
        return None
    print(f"Loading {path}...")
    state = torch.load(path, map_location=device)
    model = GPTModel(state['config']).to(device)
    model.load_state_dict(state['model_state'])
    return model
def select_model_file():
    """List *.pth files in models_cache/ and let the user pick one.

    Returns the chosen path, or None when no files exist or the selection is
    invalid.
    """
    os.makedirs("models_cache", exist_ok=True)
    files = glob.glob("models_cache/*.pth")
    if not files:
        return None
    # Fixed: the escape was doubled ('\\n'), printing a literal backslash-n
    # instead of a blank line before the listing.
    print("\nAvailable Models:")
    for i, f in enumerate(files):
        print(f"{i+1}. {f}")
    choice = safe_input(f"Select (1-{len(files)}): ", "1")
    try:
        idx = int(choice) - 1
    except ValueError:  # narrowed from a bare except: only int() can raise here
        return None
    if 0 <= idx < len(files):
        return files[idx]
    return None
def text_to_token_ids(text, tokenizer):
    """Encode text into a [1, T] tensor of token IDs (batch dim prepended)."""
    ids = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    return torch.tensor([ids])
def token_ids_to_text(token_ids, tokenizer):
    """Decode a [1, T] tensor of token IDs back into text (batch dim removed)."""
    flat_ids = token_ids.squeeze(0).tolist()
    return tokenizer.decode(flat_ids)
def generate_text_simple(model, idx, max_new_tokens, context_size):
    """Greedy decoding: repeatedly append the argmax next token.

    Stops after max_new_tokens, on Ctrl+C, or when <|endoftext|> (50256) is
    produced. Returns the prompt with generated tokens appended, shape [1, T'].
    """
    for _ in range(max_new_tokens):
        if is_interrupted():
            break
        window = idx[:, -context_size:]  # clip to the model's context length
        with torch.no_grad():
            logits = model(window)
        next_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        idx = torch.cat((idx, next_id), dim=1)
        if next_id == 50256:  # <|endoftext|>
            break
    return idx
# ============================================
# 5. INSTRUCTION TUNING UTILS
# ============================================
def format_input(entry):
    """Build the Alpaca-style prompt (instruction + optional input) for one entry.

    Fixed: the newline escapes were doubled ('\\n'), so prompts contained
    literal backslash-n text instead of real newlines, corrupting the
    fine-tuning format.
    """
    instruction_text = (
        f"Below is an instruction that describes a task. "
        f"Write a response that appropriately completes the request."
        f"\n\n### Instruction:\n{entry['instruction']}"
    )
    input_text = f"\n\n### Input:\n{entry['input']}" if entry["input"] else ""
    return instruction_text + input_text
def get_instruction_batch_generator(data_path, tokenizer, batch_size=8, device='cpu'):
    """Yield endless (inputs, targets) batches for instruction fine-tuning.

    Samples entries uniformly at random, appends the "### Response:" section,
    pads each batch to its longest sequence with EOS (50256), and masks prompt
    and padding positions in targets with -100 so F.cross_entropy ignores
    them. Returns (generator, dataset_size).
    """
    with open(data_path, "r") as f:
        data = json.load(f)
    print(f"Loaded {len(data)} instruction entries.")

    def generator():
        while True:
            batch_data = []
            for _ in range(batch_size):
                entry = data[np.random.randint(0, len(data))]
                input_text = format_input(entry)
                # Fixed: the newline escapes were doubled ('\\n'), injecting
                # literal backslash-n text into training examples.
                full_text = input_text + f"\n\n### Response:\n{entry['output']}"
                encoded = tokenizer.encode(full_text, allowed_special={'<|endoftext|>'})
                encoded.append(50256)  # EOS
                input_len = len(tokenizer.encode(input_text, allowed_special={'<|endoftext|>'}))
                targets = encoded.copy()
                targets[:input_len] = [-100] * input_len  # don't train on the prompt
                batch_data.append((encoded, targets))
            max_len = max(len(x[0]) for x in batch_data)
            padded_inputs, padded_targets = [], []
            for enc, tgt in batch_data:
                pad_len = max_len - len(enc)
                padded_inputs.append(enc + [50256] * pad_len)
                padded_targets.append(tgt + [-100] * pad_len)
            yield (torch.tensor(padded_inputs).to(device), torch.tensor(padded_targets).to(device))

    return generator(), len(data)
# ============================================
# 6. MAIN
# ============================================
def main():
    """Interactive menu loop: train from scratch, load GPT-2 weights or a local
    checkpoint, generate text, or instruction fine-tune the loaded model."""
    print("Transformer Architecture v8 (Universal)")
    print("=======================================")
    tokenizer = tiktoken.get_encoding("gpt2")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    os.makedirs("models_cache", exist_ok=True)
    model = None  # currently-loaded GPTModel, shared across menu actions
    while True:
        global _interrupted
        _interrupted = False  # reset the Ctrl+C flag before each menu action
        print("\nMAIN MENU")
        print("1. Train New Model (TinyStories)")
        print("2. Load Pretrained GPT-2 (Weights)")
        print("3. Load Checkpoint (Local)")
        print("4. Generate Text")
        print("5. Instruction Fine-tune")
        print("6. Exit")
        choice = safe_input("\nEnter choice: ", "6")
        # --- 1. Train New Model ---
        if choice == '1':
            percent = float(safe_input("Dataset % (0.1-100, default 0.1): ", "0.1"))
            story_gen, target_count = get_tiny_stories_generator(percent)
            if not story_gen: continue
            # Cache logic: tokenized data is persisted per dataset percentage.
            os.makedirs("datasets_cache", exist_ok=True)
            cache_file = f"datasets_cache/tinystories_p{percent}.bin"
            if os.path.exists(cache_file):
                print(f"Loading tokens from {cache_file}...")
                data_ids = array.array('H')
                with open(cache_file, 'rb') as f:
                    # // 2 because each 'H' (unsigned short) token is 2 bytes.
                    data_ids.fromfile(f, os.path.getsize(cache_file)//2)
            else:
                data_ids = batch_encode(tokenizer, story_gen, target_count)
                if data_ids:
                    print(f"Saving to {cache_file}...")
                    with open(cache_file, 'wb') as f: data_ids.tofile(f)
            if not data_ids: continue
            # Init Model
            print("Initializing GPT-2 Small Config...")
            # Use smaller config for faster playground training if desired,
            # but User wanted standard GPT-2. We'll use GPT_CONFIG_124M but maybe fewer layers for speed?
            # Let's stick to 124M logic but allow user to override if they want,
            # for now, use standard config.
            cfg = GPT_CONFIG_124M.copy()
            # cfg['n_layers'] = 4 # Uncomment for faster testing
            model = GPTModel(cfg).to(device)
            print(f"Params: {model.get_num_params():,}")
            epochs = int(safe_input("Epochs (default 1): ", "1"))
            optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)
            try:
                for epoch in range(epochs):
                    if is_interrupted(): break
                    loss = train_epoch(model, data_ids, 8, 256, optimizer, device) # batch 8, seq 256
                    print(f"Epoch {epoch+1} Loss: {loss:.4f}")
                    # Checkpoint after every epoch so interrupts lose little work.
                    save_checkpoint(model, f"models_cache/v8_scratch_e{epoch+1}.pth", optimizer)
            except KeyboardInterrupt:
                print("Interrupted")
        # --- 2. Load GPT-2 Weights ---
        elif choice == '2':
            # NOTE(review): bare except also hides non-ImportError failures here.
            try: import tensorflow
            except:
                print("Install tensorflow first.")
                continue
            params = download_and_load_gpt2_params("124M")
            if params:
                model = GPTModel(GPT_CONFIG_124M)
                load_weights_into_gpt(model, params)
                model.to(device)
                # Persist as a native checkpoint so option 3 can reload it
                # without tensorflow.
                save_checkpoint(model, "models_cache/gpt2_124m_converted.pth")
        # --- 3. Load Checkpoint ---
        elif choice == '3':
            path = select_model_file()
            if path:
                model = load_checkpoint(path, device)
        # --- 4. Generate ---
        elif choice == '4':
            if not model:
                print("No model loaded.")
                continue
            prompt = safe_input("Prompt: ", "Once upon a time")
            # 50 new tokens max, context window clipped to 1024.
            out = generate_text_simple(model, text_to_token_ids(prompt, tokenizer).to(device), 50, 1024)
            print("Output:", token_ids_to_text(out, tokenizer))
        # --- 5. Fine-tune ---
        elif choice == '5':
            if not model:
                print("Load a model first!")
                continue
            data_file = "instruction-data.json"
            if not os.path.exists(data_file):
                # ... [Synthetic generator code from previous step] ...
                print(f"Creating synthetic {data_file}...")
                synthetic_data = [
                    {"instruction": "What is the capital of France?", "input": "", "output": "The capital of France is Paris."},
                    {"instruction": "Summarize the following text.", "input": "Artificial intelligence is intelligence demonstrated by machines, as opposed to natural intelligence displayed by animals including humans.", "output": "AI is machine intelligence, distinct from animal or human natural intelligence."},
                    {"instruction": "Write a python function to add two numbers.", "input": "", "output": "def add(a, b):\n return a + b"},
                    {"instruction": "Translate 'Hello' to Spanish.", "input": "", "output": "Hola"},
                    {"instruction": "Explain the concept of gravity.", "input": "", "output": "Gravity is a fundamental interaction which causes mutual attraction between all things with mass or energy."},
                    {"instruction": "List three primary colors.", "input": "", "output": "Red, Blue, Yellow"},
                    {"instruction": "Who wrote 'Romeo and Juliet'?", "input": "", "output": "William Shakespeare"},
                    {"instruction": "Convert 100 degrees Celsius to Fahrenheit.", "input": "", "output": "212 degrees Fahrenheit"},
                    {"instruction": "What is the largest planet in our solar system?", "input": "", "output": "Jupiter"},
                    {"instruction": "Compose a haiku about code.", "input": "", "output": "Logic flows clearly,\nBugs hide in the shadows deep,\nScreen light glows all night."}
                ] * 10
                with open(data_file, "w") as f: json.dump(synthetic_data, f, indent=2)
            gen, count = get_instruction_batch_generator(data_file, tokenizer, device=device)
            optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.1)
            model.train()
            print("Fine-tuning (Ctrl+C to stop)...")
            try:
                for step in range(20):
                    if is_interrupted(): break
                    inputs, targets = next(gen)
                    optimizer.zero_grad()
                    logits = model(inputs)
                    # Shift logits/targets by one for next-token prediction;
                    # -100 positions (prompt + padding) are ignored by the loss.
                    loss = F.cross_entropy(logits[:, :-1, :].flatten(0,1), targets[:, 1:].flatten(), ignore_index=-100)
                    loss.backward()
                    optimizer.step()
                    print(f"Step {step+1} Loss: {loss.item():.4f}")
            except KeyboardInterrupt: pass
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            save_path = f"models_cache/v8_finetuned_{timestamp}.pth"
            save_checkpoint(model, save_path, optimizer)
        elif choice == '6':
            break

if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment