Created
November 10, 2025 19:46
-
-
Save burtenshaw/88ae119d0090d35d267ddd5000bcf2e1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| # | |
| # /// script | |
| # requires-python = ">=3.12" | |
| # dependencies = [ | |
| # "git+https://github.com/huggingface/trl.git#egg=trl[vllm]", | |
| # "trackio==0.8.1", | |
| # /// | |
| """ | |
| CUDA_VISIBLE_DEVICES=1 python trackio_late_start.py | |
| """ | |
| import os | |
| from datasets import Dataset | |
| from trl import GRPOConfig, GRPOTrainer | |
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Every setting below can be overridden through an environment variable and
# falls back to the default shown when that variable is unset. Note that
# GRADIENT_ACCUMULATION_STEPS is read from the shorter GRAD_ACCUM_STEPS name.
def _env_int(name: str, default: str) -> int:
    """Return environment variable *name* parsed as an int, or *default* if unset.

    Raises ValueError when the value is not a valid integer literal.
    """
    return int(os.environ.get(name, default))


MODEL_ID: str = os.environ.get("MODEL_ID", "Qwen/Qwen3-0.6B")
GRADIENT_ACCUMULATION_STEPS: int = _env_int("GRAD_ACCUM_STEPS", "2")
PER_DEVICE_BATCH_SIZE: int = _env_int("PER_DEVICE_BATCH_SIZE", "1")
NUM_GENERATIONS: int = _env_int("NUM_GENERATIONS", "2")
# Number of rows in the synthetic training dataset built by main().
DATASET_SIZE: int = _env_int("DATASET_SIZE", "100")
# ---------------------------------------------------------------------------
# Rewards
# ---------------------------------------------------------------------------
def reward_one(completions, **kwargs):
    """Score every completion with the same constant reward of 1.0.

    The contents of ``completions`` are never inspected and any extra keyword
    arguments are ignored; only ``len(completions)`` determines the result.

    Returns:
        A list of ``1.0`` floats, one entry per completion.
    """
    return [1.0 for _ in range(len(completions))]
# ---------------------------------------------------------------------------
# Main entrypoint
# ---------------------------------------------------------------------------
def main() -> None:
    """Build a synthetic prompt dataset and run a GRPO training job on it.

    Every dataset row carries the same prompt string, the only reward function
    is ``reward_one``, and metrics are logged every step to Trackio.
    """
    prompts = ["BrowserGym agent"] * DATASET_SIZE
    train_dataset = Dataset.from_dict({"prompt": prompts})

    training_args = GRPOConfig(
        # Generation runs through vLLM in TRL's "colocate" mode.
        use_vllm=True,
        vllm_mode="colocate",
        # Batch shape, all driven by the environment-configured constants.
        per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        num_generations=NUM_GENERATIONS,
        # Logging and checkpointing: log every step; checkpoints follow the
        # default `save_steps`, with no limit on how many are kept.
        logging_steps=1,
        save_strategy="steps",
        save_total_limit=None,
        report_to="trackio",
    )

    GRPOTrainer(
        model=MODEL_ID,
        reward_funcs=[reward_one],
        train_dataset=train_dataset,
        args=training_args,
    ).train()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment