# train_grpo.py

import re
from dataclasses import dataclass, field
from typing import Optional

import torch
from datasets import load_dataset, Dataset, load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer, TrlParser

@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a supported model type."},
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use a fast tokenizer (backed by the `tokenizers` library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust the execution of code from datasets/models defined on the Hub."
                " This option should only be set to `True` for repositories you trust and in which you have read the"
                " code, as it will execute code present on the Hub on your local machine."
            )
        },
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "Create the model as an empty shell, then materialize its parameters only when the pretrained "
                "weights are loaded. Setting this to True reduces loading time and RAM consumption for large models."
            )
        },
    )

    def __post_init__(self):
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )

# Load and prep dataset

SYSTEM_PROMPT = """
Respond in the following format:

<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    """Return the text between the last <answer> tag and its closing </answer> tag."""
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()


def extract_hash_answer(text: str) -> str | None:
    """Return the final answer after the GSM8K '####' marker, or None if absent."""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()
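
# Quick sanity check of the extractors (illustrative strings):
#
#   extract_xml_answer("<reasoning>\n2+2=4\n</reasoning>\n<answer>\n4\n</answer>")  # -> "4"
#   extract_hash_answer("Two plus two is four.\n#### 4")                            # -> "4"
#   extract_hash_answer("no marker here")                                           # -> None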

# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split: str = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split]  # type: ignore
    # data = load_from_disk("path to gsm8k")[split]  # for a local copy
    data = data.map(lambda x: {  # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            # {'role': 'user', 'content': 'What is the largest single-digit prime number?'},
            # {'role': 'assistant', 'content': XML_COT_FORMAT.format(
            #     reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
            #     answer="7"
            # )},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    })  # type: ignore
    return data  # type: ignore


dataset = get_gsm8k_questions()
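
# After mapping, each example looks roughly like this (fields abridged):
#
#   {
#       "prompt": [
#           {"role": "system", "content": SYSTEM_PROMPT},
#           {"role": "user", "content": "Natalia sold clips to ..."},
#       ],
#       "answer": "72",
#   }
#
# GRPOTrainer uses the conversational "prompt" column as the generation input
# and passes any extra dataset columns (here, "answer") to the reward
# functions as keyword arguments.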

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward 2.0 when the extracted answer exactly matches the reference answer."""
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-' * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}",
          f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]


def int_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 when the extracted answer is a plain integer."""
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]


def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion matches the exact expected format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets .*? span newlines inside multi-line reasoning and answers
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]


def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion loosely follows the expected format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]


def count_xml(text) -> float:
    """Give partial credit (0.125 per tag) for each correctly placed XML tag,
    minus a small penalty for text trailing the closing </answer> tag."""
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count
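
# Worked example: a completion of exactly
#
#   "<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n"
#
# hits all four tag checks (4 * 0.125 = 0.5) with no trailing text after the
# closing tag, so count_xml returns 0.5; every character that follows
# "</answer>\n" costs 0.001.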

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Apply count_xml to every completion."""
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
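
# GRPOTrainer sums the per-completion outputs of all reward functions above
# (weighted equally unless `reward_weights` is set in GRPOConfig), so a
# correct, perfectly formatted integer answer can earn up to
# 0.5 + 0.5 + 0.5 + 0.5 + 2.0 = 4.5 total reward.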


def main(model_args, training_args):
    # peft_config = LoraConfig(
    #     r=16,
    #     lora_alpha=64,
    #     target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    #     task_type="CAUSAL_LM",
    #     lora_dropout=0.05,
    # )

    # Resolve the requested dtype string ("bfloat16", "float16", ...) into a torch.dtype
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch_dtype,
        attn_implementation="flash_attention_2",
    )
    model = model.to("cuda")

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token

    # use peft at your own risk; not working for me with multi-GPU training
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            xmlcount_reward_func,
            soft_format_reward_func,
            strict_format_reward_func,
            int_reward_func,
            correctness_reward_func,
        ],
        args=training_args,
        train_dataset=dataset,
        # peft_config=peft_config,
    )
    trainer.train()


if __name__ == "__main__":
    parser = TrlParser((ModelArguments, GRPOConfig))
    model_args, training_args = parser.parse_args_and_config()
    main(model_args, training_args)
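
# Illustrative launch command (model name, output dir, and hyperparameters are
# placeholders; adjust to your hardware and TRL version):
#
#   accelerate launch train_grpo.py \
#       --model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
#       --output_dir outputs/grpo-gsm8k \
#       --per_device_train_batch_size 4 \
#       --num_generations 4 \
#       --max_completion_length 256
#
# TrlParser can also read the same settings from a YAML file passed via
# `--config`, with command-line flags taking precedence.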