# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# This is a summary file containing the main takeaways from chapter 6.

import urllib.request
import zipfile
import os
from pathlib import Path
import time

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import tiktoken
import torch
from torch.utils.data import Dataset, DataLoader

from gpt_download import download_and_load_gpt2
from previous_chapters import GPTModel, load_weights_into_gpt
def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path, test_mode=False):
    if data_file_path.exists():
        print(f"{data_file_path} already exists. Skipping download and extraction.")
        return

    if test_mode:  # Try multiple times since CI sometimes has connectivity issues
        max_retries = 5
        delay = 5  # delay between retries in seconds
        for attempt in range(max_retries):
            try:
                # Downloading the file
                with urllib.request.urlopen(url, timeout=10) as response:
                    with open(zip_path, "wb") as out_file:
                        out_file.write(response.read())
                break  # if download is successful, break out of the loop
            except urllib.error.URLError as e:
                print(f"Attempt {attempt + 1} failed: {e}")
                if attempt < max_retries - 1:
                    time.sleep(delay)  # wait before retrying
                else:
                    print("Failed to download file after several attempts.")
                    return  # exit if all retries fail

    else:  # Code as it appears in the chapter
        # Downloading the file
        with urllib.request.urlopen(url) as response:
            with open(zip_path, "wb") as out_file:
                out_file.write(response.read())

    # Unzipping the file
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extracted_path)

    # Add .tsv file extension
    original_file_path = Path(extracted_path) / "SMSSpamCollection"
    os.rename(original_file_path, data_file_path)
    print(f"File downloaded and saved as {data_file_path}")
def create_balanced_dataset(df):
    # Count the instances of "spam"
    num_spam = df[df["Label"] == "spam"].shape[0]

    # Randomly sample "ham" instances to match the number of "spam" instances
    ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)

    # Combine the ham subset with all "spam" instances
    balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])

    return balanced_df
def random_split(df, train_frac, validation_frac):
    # Shuffle the entire DataFrame
    df = df.sample(frac=1, random_state=123).reset_index(drop=True)

    # Calculate split indices
    train_end = int(len(df) * train_frac)
    validation_end = train_end + int(len(df) * validation_frac)

    # Split the DataFrame
    train_df = df[:train_end]
    validation_df = df[train_end:validation_end]
    test_df = df[validation_end:]

    return train_df, validation_df, test_df
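
# For instance, with train_frac=0.7 and validation_frac=0.1 (the values used
# below), the test split simply receives the remaining ~20% of rows; the three
# splits are disjoint and together cover the whole shuffled DataFrame.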
class SpamDataset(Dataset):
    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256, ignore_index=-100):
        self.data = pd.read_csv(csv_file)

        # Pre-tokenize texts
        self.labels = self.data.Label.values
        self.encoded_texts = [
            tokenizer.encode(text) for text in self.data["Text"]
        ]

        if max_length is None:
            self.max_length = self._longest_encoded_length()
        else:
            self.max_length = max_length
            # Truncate sequences if they are longer than max_length
            self.encoded_texts = [
                encoded_text[:self.max_length]
                for encoded_text in self.encoded_texts
            ]

        # Assign the message label to every real token; padded positions get
        # ignore_index so that the loss and accuracy can skip them
        # (this must happen before the texts themselves are padded below)
        self.labels = [
            [label] * len(encoded_text) + [ignore_index] * (self.max_length - len(encoded_text))
            for label, encoded_text in zip(self.labels, self.encoded_texts)
        ]

        # Pad sequences to self.max_length
        self.encoded_texts = [
            encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
            for encoded_text in self.encoded_texts
        ]
    def __getitem__(self, index):
        encoded = self.encoded_texts[index]
        label = self.labels[index]
        return (
            torch.tensor(encoded, dtype=torch.long),
            torch.tensor(label, dtype=torch.long)
        )

    def __len__(self):
        return len(self.data)

    def _longest_encoded_length(self):
        max_length = 0
        for encoded_text in self.encoded_texts:
            encoded_length = len(encoded_text)
            if encoded_length > max_length:
                max_length = encoded_length
        return max_length
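
# Illustration (not executed; values chosen for the example): for a message
# whose encoding is [5, 6, 7], label 1 ("spam"), max_length=5,
# pad_token_id=50256, and ignore_index=-100, SpamDataset yields
#     input:  [5, 6, 7, 50256, 50256]
#     target: [1, 1, 1,  -100,  -100]
# i.e. every real token carries the message label, while padded positions are
# marked so the loss (and a masked accuracy) can skip them.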
def calc_accuracy_loader(data_loader, model, device, num_batches=None, ignore_index=-100):
    model.eval()
    correct_predictions, num_examples = 0, 0

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            input_batch, target_batch = input_batch.to(device), target_batch.to(device)

            with torch.no_grad():
                logits = model(input_batch)  # (batch_size, seq_len, num_classes)

            # The last position of a padded sequence carries ignore_index rather
            # than a real label, so score each example at its last *real* token
            # (real tokens always form a prefix of the padded sequence)
            last_real = (target_batch != ignore_index).sum(dim=1) - 1
            batch_idx = torch.arange(input_batch.shape[0], device=device)
            predicted_labels = torch.argmax(logits[batch_idx, last_real], dim=-1)

            num_examples += predicted_labels.shape[0]
            correct_predictions += (predicted_labels == target_batch[batch_idx, last_real]).sum().item()
        else:
            break
    return correct_predictions / num_examples
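
# A possible alternative (a sketch, not from the chapter): score every real
# token rather than only the last one, e.g. inside the loop above:
#     mask = target_batch != ignore_index
#     predicted_labels = torch.argmax(logits, dim=-1)
#     correct_predictions += (predicted_labels[mask] == target_batch[mask]).sum().item()
#     num_examples += mask.sum().item()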
def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss
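
# Note: torch.nn.functional.cross_entropy ignores target positions equal to
# -100 by default (its ignore_index argument), which matches the ignore_index
# used by SpamDataset, so padded positions contribute nothing to this loss.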
def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss
def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                            eval_freq, eval_iter, tokenizer):
    # Initialize lists to track losses and examples seen
    # (tokenizer is unused here but kept to match the chapter's signature)
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    examples_seen, global_step = 0, -1

    # Main training loop
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()  # Reset loss gradients from previous batch iteration
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()  # Calculate loss gradients
            optimizer.step()  # Update model weights using loss gradients
            examples_seen += input_batch.shape[0]  # New: track examples instead of tokens
            global_step += 1

            # Optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        # Calculate accuracy after each epoch
        train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)
        val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)
        print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
        print(f"Validation accuracy: {val_accuracy*100:.2f}%")
        train_accs.append(train_accuracy)
        val_accs.append(val_accuracy)

    return train_losses, val_losses, train_accs, val_accs, examples_seen
def plot_values(epochs_seen, examples_seen, train_values, val_values, label="loss"):
    fig, ax1 = plt.subplots(figsize=(5, 3))

    # Plot training and validation values against epochs
    ax1.plot(epochs_seen, train_values, label=f"Training {label}")
    ax1.plot(epochs_seen, val_values, linestyle="-.", label=f"Validation {label}")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel(label.capitalize())
    ax1.legend()

    # Create a second x-axis for examples seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(examples_seen, train_values, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Examples seen")

    fig.tight_layout()  # Adjust layout to make room
    plt.savefig(f"{label}-plot.pdf")
    # plt.show()
def predict_and_visualize(model, tokenizer, message, device):
    tokens_encoded = tokenizer.encode(message)
    with torch.no_grad():
        output = model(torch.tensor([tokens_encoded]).to(device))

    # Per-token probability of the "spam" class (index 1)
    p_spam = torch.softmax(output, dim=-1)[0, :, 1].cpu().numpy()

    # Negative positions so the first token appears at the top of the bar chart
    token_labels = -np.arange(output.shape[1])
    # Decode with errors="replace" since individual BPE tokens are not always valid UTF-8
    tokens_decoded = [
        tokenizer.decode_single_token_bytes(t).decode(errors="replace")
        for t in tokens_encoded
    ]

    plt.barh(token_labels, p_spam)
    plt.xlim(0, 1)
    plt.yticks(token_labels, tokens_decoded)
    plt.xlabel("Probability of being spam")
    plt.title(f"Message: {message}")
    plt.show()
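
# Example usage (a sketch; the message text is hypothetical, any short string works):
#     predict_and_visualize(model, tokenizer, "You won a free prize! Call now", device)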
if __name__ == "__main__":

    import argparse

    parser = argparse.ArgumentParser(
        description="Finetune a GPT model for classification"
    )
    parser.add_argument(
        "--test_mode",
        default=False,
        action="store_true",
        help=("This flag runs the model in test mode for internal testing purposes. "
              "Otherwise, it runs the model as it is used in the chapter (recommended).")
    )
    args = parser.parse_args()
    ########################################
    # Download and prepare dataset
    ########################################

    url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
    zip_path = "sms_spam_collection.zip"
    extracted_path = "sms_spam_collection"
    data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv"

    download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path, test_mode=args.test_mode)
    df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])

    # Use the full (imbalanced) dataset rather than the chapter's balanced subset:
    # balanced_df = create_balanced_dataset(df)
    balanced_df = df
    balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})

    train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)
    train_df.to_csv("train.csv", index=None)
    validation_df.to_csv("validation.csv", index=None)
    test_df.to_csv("test.csv", index=None)
    ########################################
    # Create data loaders
    ########################################
    tokenizer = tiktoken.get_encoding("gpt2")

    train_dataset = SpamDataset(
        csv_file="train.csv",
        max_length=None,
        tokenizer=tokenizer
    )

    val_dataset = SpamDataset(
        csv_file="validation.csv",
        max_length=train_dataset.max_length,
        tokenizer=tokenizer
    )

    test_dataset = SpamDataset(
        csv_file="test.csv",
        max_length=train_dataset.max_length,
        tokenizer=tokenizer
    )

    num_workers = 0
    batch_size = 8

    torch.manual_seed(123)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )

    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=False,
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=False,
    )
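
    # Optional sanity check (a sketch): both tensors of the first training batch
    # should have shape (batch_size, train_dataset.max_length), e.g.
    #     inputs, targets = next(iter(train_loader))
    #     print(inputs.shape, targets.shape)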
    ########################################
    # Load pretrained model
    ########################################

    # Small GPT model for testing purposes
    if args.test_mode:
        BASE_CONFIG = {
            "vocab_size": 50257,
            "context_length": 120,
            "drop_rate": 0.0,
            "qkv_bias": False,
            "emb_dim": 12,
            "n_layers": 1,
            "n_heads": 2
        }
        model = GPTModel(BASE_CONFIG)
        model.eval()
        device = "cpu"

    # Code as it is used in the main chapter
    else:
        CHOOSE_MODEL = "gpt2-small (124M)"
        INPUT_PROMPT = "Every effort moves"

        BASE_CONFIG = {
            "vocab_size": 50257,     # Vocabulary size
            "context_length": 1024,  # Context length
            "drop_rate": 0.0,        # Dropout rate
            "qkv_bias": True         # Query-key-value bias
        }

        model_configs = {
            "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
            "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
            "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
            "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
        }

        BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

        assert train_dataset.max_length <= BASE_CONFIG["context_length"], (
            f"Dataset length {train_dataset.max_length} exceeds model's context "
            f"length {BASE_CONFIG['context_length']}. Reinitialize data sets with "
            f"`max_length={BASE_CONFIG['context_length']}`"
        )

        model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
        settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")

        model = GPTModel(BASE_CONFIG)
        load_weights_into_gpt(model, params)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ########################################
    # Modify the pretrained model
    ########################################

    # Freeze all pretrained weights, ...
    for param in model.parameters():
        param.requires_grad = False

    torch.manual_seed(123)

    # ...replace the language-model head with a 2-class classification head
    # (a freshly created Linear layer is trainable by default), ...
    num_classes = 2
    model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=num_classes)
    model.to(device)

    # ...and unfreeze the last transformer block and the final LayerNorm
    for param in model.trf_blocks[-1].parameters():
        param.requires_grad = True

    for param in model.final_norm.parameters():
        param.requires_grad = True
    ########################################
    # Finetune modified model
    ########################################

    start_time = time.time()
    torch.manual_seed(123)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2.5e-5, weight_decay=0.1)

    num_epochs = 10
    train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
        model, train_loader, val_loader, optimizer, device,
        num_epochs=num_epochs, eval_freq=50, eval_iter=5,
        tokenizer=tokenizer
    )

    end_time = time.time()
    execution_time_minutes = (end_time - start_time) / 60
    print(f"Training completed in {execution_time_minutes:.2f} minutes.")

    predict_and_visualize(
        model,
        tokenizer,
        "Do you enjoy hiking and camping under the stars? Register now and get 50% off on a new hammock",
        device,
    )
    ########################################
    # Plot results
    ########################################

    # loss plot
    epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
    examples_seen_tensor = torch.linspace(0, examples_seen, len(train_losses))
    plot_values(epochs_tensor, examples_seen_tensor, train_losses, val_losses)

    # accuracy plot
    epochs_tensor = torch.linspace(0, num_epochs, len(train_accs))
    examples_seen_tensor = torch.linspace(0, examples_seen, len(train_accs))
    plot_values(epochs_tensor, examples_seen_tensor, train_accs, val_accs, label="accuracy")