# This is a modified version of TRL's `SFTTrainer` example
# (https://github.com/huggingface/trl/blob/main/examples/scripts/sft_trainer.py),
# adapted to run with DeepSpeed ZeRO-3 and Mistral-7B-v0.1. The settings below were run on 1 node of
# 8 x A100 (80GB) GPUs.
#
# Usage:
#   - Install the latest transformers & accelerate versions: `pip install -U transformers accelerate`
#   - Install deepspeed: `pip install deepspeed==0.9.5`
#   - Install TRL from main: `pip install git+https://github.com/huggingface/trl.git`
#   - Clone the repo: `git clone https://github.com/huggingface/trl.git`
#   - Copy this Gist into trl/examples/scripts
#   - Run from root of trl repo with:
#       accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
#           --gradient_accumulation_steps 8 examples/scripts/sft_trainer.py
from dataclasses import dataclass, field
from typing import Optional

import torch
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, HfArgumentParser, TrainingArguments, AutoTokenizer

from trl import SFTTrainer


tqdm.pandas()


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    Command-line arguments for fine-tuning a causal LM with `SFTTrainer`.

    Every field is exposed as a `--<field_name>` flag by `HfArgumentParser`; the per-field
    description lives in its `help` metadata.
    """

    model_name: Optional[str] = field(default="mistralai/Mistral-7B-v0.1", metadata={"help": "the model name"})
    dataset_name: Optional[str] = field(
        default="stingning/ultrachat", metadata={"help": "the dataset name"}
    )
    dataset_text_field: Optional[str] = field(default="text", metadata={"help": "the text field of the dataset"})
    log_with: Optional[str] = field(default="wandb", metadata={"help": "use 'wandb' to log with wandb"})
    learning_rate: Optional[float] = field(default=2.0e-5, metadata={"help": "the learning rate"})
    batch_size: Optional[int] = field(default=8, metadata={"help": "the batch size"})
    seq_length: Optional[int] = field(default=1024, metadata={"help": "Input sequence length"})
    gradient_accumulation_steps: Optional[int] = field(
        default=8, metadata={"help": "the number of gradient accumulation steps"}
    )
    load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
    load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
    use_peft: Optional[bool] = field(default=False, metadata={"help": "Wether to use PEFT or not to train adapters"})
    trust_remote_code: Optional[bool] = field(default=False, metadata={"help": "Enable `trust_remote_code`"})
    output_dir: Optional[str] = field(default="output", metadata={"help": "the output directory"})
    peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"})
    peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
    logging_steps: Optional[int] = field(default=5, metadata={"help": "the number of logging steps"})
    use_auth_token: Optional[bool] = field(default=True, metadata={"help": "Use HF auth token to access the model"})
    num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"})
    max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"})
    save_steps: Optional[int] = field(
        default=1000, metadata={"help": "Number of updates steps before two checkpoint saves"}
    )
    save_total_limit: Optional[int] = field(default=10, metadata={"help": "Limits total number of checkpoints."})
    push_to_hub: Optional[bool] = field(default=True, metadata={"help": "Push the model to HF Hub"})
    hub_model_id: Optional[str] = field(
        default="mistral-7b-finetuned-ultrachat", metadata={"help": "The name of the model on HF Hub"}
    )
    seed: Optional[int] = field(
        default=42, metadata={"help": "Random seed for the train/test split and for training"}
    )


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

# Step 1: Load the dataset
# The tokenizer is only needed here for its EOS token, which terminates every dialogue turn.
# It gets the same `trust_remote_code` setting as the model so both loads follow one policy.
tokenizer = AutoTokenizer.from_pretrained(
    script_args.model_name, trust_remote_code=script_args.trust_remote_code
)
dataset = load_dataset(script_args.dataset_name, split="train[:20000]")
# The split must be seeded: this script runs once per process under `accelerate launch`, and an
# unseeded shuffle would give every process a different train/test partition.
dataset = dataset.train_test_split(test_size=0.1, seed=script_args.seed)

# Role tags by turn parity: even-indexed turns are the user's, odd-indexed the assistant's.
ROLE_TAGS = ("<|user|>", "<|assistant|>")


def prepare_dialogue(example):
    """
    Flatten the list of turns in `example["data"]` into a single training string.

    Each turn is rendered as `<role tag>\\n<message><eos>\\n`, alternating between the user and
    assistant tags starting with the user. The result is stored in `example["text"]` and the
    (mutated) example is returned.
    """
    eos = tokenizer.eos_token
    example["text"] = "".join(
        f"{ROLE_TAGS[idx % 2]}\n{msg}{eos}\n" for idx, msg in enumerate(example["data"])
    )
    return example


dataset = dataset.map(prepare_dialogue, num_proc=4, remove_columns=["id", "data"])

# Step 2: Load the model
if script_args.load_in_8bit and script_args.load_in_4bit:
    raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")

if script_args.load_in_8bit or script_args.load_in_4bit:
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
    )
    # Copy the model to each device
    device_map = {"": Accelerator().local_process_index}
    torch_dtype = torch.bfloat16
else:
    # No quantization: leave device placement and dtype at the `from_pretrained` defaults.
    device_map = None
    quantization_config = None
    torch_dtype = None

model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name,
    quantization_config=quantization_config,
    device_map=device_map,
    trust_remote_code=script_args.trust_remote_code,
    torch_dtype=torch_dtype,
    use_auth_token=script_args.use_auth_token,
)

# Step 3: Define the training arguments
training_args = TrainingArguments(
    output_dir=script_args.output_dir,
    per_device_train_batch_size=script_args.batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    gradient_checkpointing=True,
    learning_rate=script_args.learning_rate,
    logging_steps=script_args.logging_steps,
    num_train_epochs=script_args.num_train_epochs,
    max_steps=script_args.max_steps,
    report_to=script_args.log_with,
    save_steps=script_args.save_steps,
    save_total_limit=script_args.save_total_limit,
    push_to_hub=script_args.push_to_hub,
    hub_model_id=script_args.hub_model_id,
    bf16=True,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    evaluation_strategy="epoch",
    logging_first_step=True,
    seed=script_args.seed,
)

# Step 4: Define the LoraConfig
if script_args.use_peft:
    peft_config = LoraConfig(
        r=script_args.peft_lora_r,
        lora_alpha=script_args.peft_lora_alpha,
        bias="none",
        task_type="CAUSAL_LM",
    )
else:
    peft_config = None

# Step 5: Define the Trainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    max_seq_length=script_args.seq_length,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    dataset_text_field=script_args.dataset_text_field,
    peft_config=peft_config,
    packing=True,
)

trainer.train()

# Step 6: Save the model
trainer.save_model(script_args.output_dir)