Skip to content

Instantly share code, notes, and snippets.

@burtenshaw
Created November 10, 2025 19:46
Show Gist options
  • Select an option

  • Save burtenshaw/88ae119d0090d35d267ddd5000bcf2e1 to your computer and use it in GitHub Desktop.

Select an option

Save burtenshaw/88ae119d0090d35d267ddd5000bcf2e1 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
#
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "trl[vllm] @ git+https://github.com/huggingface/trl.git",
#     "trackio==0.8.1",
# ]
# ///
"""
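Minimal GRPO run that logs every training step to Trackio.

Trains MODEL_ID on a synthetic dataset (DATASET_SIZE copies of one prompt)
with a constant reward, generating completions with vLLM in "colocate" mode.
Sizes are controlled by the MODEL_ID, GRAD_ACCUM_STEPS, PER_DEVICE_BATCH_SIZE,
NUM_GENERATIONS and DATASET_SIZE environment variables.

Usage:
    CUDA_VISIBLE_DEVICES=1 python trackio_late_start.py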
"""
import os
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
def env_int(name: str, default: int, *, minimum: int = 1) -> int:
    """Read an integer setting from the environment.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset.
        minimum: Smallest accepted value (inclusive).

    Returns:
        The parsed integer, or ``default`` when ``name`` is unset.

    Raises:
        ValueError: If the variable is set but is not an integer, or the
            resulting value is below ``minimum``. Unlike a bare ``int()``
            failure, the message names the offending variable.
    """
    raw = os.getenv(name)
    if raw is None:
        value = default
    else:
        try:
            value = int(raw)
        except ValueError as err:
            raise ValueError(
                f"Environment variable {name} must be an integer, got {raw!r}"
            ) from err
    # Reject zero/negative sizes here rather than letting them surface as an
    # obscure failure later (e.g. a negative DATASET_SIZE yields an empty list).
    if value < minimum:
        raise ValueError(
            f"Environment variable {name} must be >= {minimum}, got {value}"
        )
    return value


# Each setting can be overridden through the named environment variable.
MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen3-0.6B")
GRADIENT_ACCUMULATION_STEPS = env_int("GRAD_ACCUM_STEPS", 2)
PER_DEVICE_BATCH_SIZE = env_int("PER_DEVICE_BATCH_SIZE", 1)
NUM_GENERATIONS = env_int("NUM_GENERATIONS", 2)
DATASET_SIZE = env_int("DATASET_SIZE", 100)
# ---------------------------------------------------------------------------
# Rewards
# ---------------------------------------------------------------------------
def reward_one(completions, **kwargs):
    """Give every completion the same constant reward of 1.0.

    Args:
        completions: Sized collection of generated completions; only its
            length is consulted.
        **kwargs: Any additional keyword arguments the caller supplies
            (presumably trainer-provided metadata); ignored.

    Returns:
        A new list holding ``1.0`` once per completion.

    NOTE(review): a reward that never varies presumably provides no learning
    signal to GRPO, so this looks intended to exercise the training/logging
    pipeline rather than to improve the model — confirm.
    """
    return [1.0 for _ in range(len(completions))]
# ---------------------------------------------------------------------------
# Main entrypoint
# ---------------------------------------------------------------------------
def main() -> None:
    """Build a synthetic dataset, configure GRPO, and run training.

    The dataset holds DATASET_SIZE copies of the single prompt
    "BrowserGym agent". Generation uses vLLM with TRL's "colocate" vLLM mode,
    the only reward function is ``reward_one``, and metrics are logged every
    step (``logging_steps=1``) and reported to Trackio.
    """
    prompts = ["BrowserGym agent" for _ in range(DATASET_SIZE)]
    train_dataset = Dataset.from_dict({"prompt": prompts})

    training_args = GRPOConfig(
        # Generation backend.
        use_vllm=True,
        vllm_mode="colocate",
        # Batch sizing, all overridable via environment variables.
        per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        num_generations=NUM_GENERATIONS,
        # Logging and checkpointing.
        logging_steps=1,
        report_to="trackio",
        save_strategy="steps",
        save_total_limit=None,
    )

    GRPOTrainer(
        model=MODEL_ID,
        reward_funcs=[reward_one],
        train_dataset=train_dataset,
        args=training_args,
    ).train()


if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment