# create_pairs.py
"""Generate steering-text preference pairs from a chat dataset with vLLM.

For each streamed dataset example, this samples ``pairs_per_example * 2``
completions of a single random token length, groups consecutive completions
into (steer1, steer2) pairs, and writes each pair — together with the chat
history and the reference assistant completion — as one JSONL record.
"""
import argparse
import copy
import json
import random

from tqdm import tqdm
from datasets import load_dataset
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

# ShareGPT-style "from" field -> chat-template role.
_ROLE_MAP = {"system": "system", "gpt": "assistant", "human": "user"}


def _clean_steer_text(text: str) -> str:
    """Strip whitespace and remove steer sentinel tokens from a completion."""
    return (
        text.strip()
        .replace("<|steer_start|>", "")
        .replace("<|steer_end|>", "")
        .strip()
    )


def create_pairs(
    dataset_name: str,
    model_name: str,
    tokenizer_name: str,
    pairs_per_example: int,
    total_pairs: int,
    min_length: int,
    max_length: int,
    seed: int,
    offset: int,
    output_filepath: str,
):
    """Sample steer-text pairs from ``dataset_name`` and write them as JSONL.

    Args:
        dataset_name: HF dataset id; must expose a ``train`` split whose rows
            contain ``conversations`` (ShareGPT format) and ``source``.
        model_name: vLLM model used to generate steer texts.
        tokenizer_name: Tokenizer whose chat template formats the history.
        pairs_per_example: Pairs to extract per dataset example
            (``2 * pairs_per_example`` completions are sampled).
        total_pairs: Exact number of pairs written to ``output_filepath``.
        min_length: Minimum sampled token length (inclusive).
        max_length: Upper bound on sampled token length. NOTE(review):
            ``random.randrange`` excludes this value — confirm whether an
            inclusive bound (``randint``) was intended.
        seed: Seed for both Python RNG and dataset shuffling.
        offset: Number of shuffled examples to skip before sampling.
        output_filepath: Destination JSONL path.

    Raises:
        ValueError: If an example lacks ``source``/``conversations`` or its
            final message is not from the assistant.
    """
    random.seed(seed)
    llm = LLM(model=model_name)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    # Stream so we never materialize the full dataset; shuffle deterministically
    # and skip `offset` rows so parallel runs can partition the data.
    dataset = (
        load_dataset(dataset_name, streaming=True)["train"]
        .shuffle(seed=seed)
        .skip(offset)
    )
    dataset_iter = iter(dataset)

    progress = tqdm(total=total_pairs)
    progress.display(f"Processing dataset {dataset_name}...")

    pairs = []
    while len(pairs) < total_pairs:
        example = next(dataset_iter)
        messages = example.get("conversations")
        source = example.get("source")
        if not source or not messages:
            raise ValueError(f"Missing source or conversations: {example}")

        fmt_msgs = [
            {"role": _ROLE_MAP[msg["from"]], "content": msg["value"]}
            for msg in messages
        ]
        history, last_message = fmt_msgs[:-1], fmt_msgs[-1]
        # The final turn is kept as the reference completion, so it must be
        # an assistant message.
        if last_message["role"] != "assistant":
            raise ValueError(f"Last message is not from assistant: {example}")

        # Prompt = rendered history + sentinel that opens the steer region.
        prefix = (
            tokenizer.apply_chat_template(
                history, add_generation_prompt=True, tokenize=False
            )
            + "<|steer_start|>"
        )

        # One shared length per example so paired steer texts are comparable.
        gen_token_count = random.randrange(min_length, max_length)
        # TODO: add system message requesting `gen_token_count` of thinking tokens
        sampling_params = SamplingParams(
            temperature=1.0,
            min_tokens=gen_token_count,
            max_tokens=gen_token_count,
            n=pairs_per_example * 2,
        )
        output = llm.generate([prefix], sampling_params, use_tqdm=False)[0]

        ex = {
            "prompt": history,
            "steer1": None,
            "steer2": None,
            "completion": [last_message],
        }
        for op in output.outputs:
            txt = _clean_steer_text(op.text)
            if ex["steer1"] is None:
                ex["steer1"] = txt
            elif ex["steer2"] is None:
                ex["steer2"] = txt
                pairs.append(copy.deepcopy(ex))
                progress.update(len(pairs) - progress.n)
                # Reset the slots so the next two completions form a new pair.
                ex.update({"steer1": None, "steer2": None})
            else:
                # Defensive: unreachable given the reset above, kept as a
                # guard against future changes to the pairing logic.
                raise ValueError("Unexpected output")

    # A single example can push `pairs` past `total_pairs`; truncate so the
    # output file contains exactly `total_pairs` records as documented.
    pairs = pairs[:total_pairs]

    with open(output_filepath, "w") as f:
        for pair in pairs:
            f.write(json.dumps(pair) + "\n")


def main():
    """Parse CLI arguments and run :func:`create_pairs`."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", type=str, default="mlabonne/FineTome-100k")
    parser.add_argument(
        "--model_name", type=str, default="meta-llama/Meta-Llama-3.1-8B"
    )
    parser.add_argument(
        "--tokenizer_name", type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct"
    )
    parser.add_argument("--pairs_per_example", type=int, default=10)
    parser.add_argument("--total_pairs", type=int, default=1000)
    parser.add_argument("--min_length", type=int, default=32)
    parser.add_argument("--max_length", type=int, default=2048)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--offset", type=int, default=0)
    parser.add_argument("--output_filepath", type=str, default="pairs_01.jsonl")
    args = parser.parse_args()

    create_pairs(
        dataset_name=args.dataset_name,
        model_name=args.model_name,
        tokenizer_name=args.tokenizer_name,
        pairs_per_example=args.pairs_per_example,
        total_pairs=args.total_pairs,
        min_length=args.min_length,
        max_length=args.max_length,
        seed=args.seed,
        offset=args.offset,
        output_filepath=args.output_filepath,
    )


if __name__ == "__main__":
    main()