# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# /// script
# dependencies = [
#     "trl",
#     "Pillow",
#     "peft",
#     "math-verify",
#     "latex2sympy2_extended",
#     "torchvision",
#     "trackio",
#     "vllm",
# ]
# ///
| """ | |
| LoRA-optimized GRPO VLM script following best practices from "LoRA Without Regret" (Schulman et al. 2025). | |
| Source: https://thinkingmachines.ai/blog/lora/ | |
| Key finding for RL: LoRA performs equivalently to full fine-tuning even with very small ranks. | |
| Policy gradient algorithms learn roughly 1 bit of information per episode, requiring minimal capacity. | |
| Recommended ranks for RL tasks: 8-32 (much lower than SFT which needs 64-256+) | |
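For reference, a minimal sketch (an illustration, not copied from TRL internals) of the peft.LoraConfig
that the --use_peft / --lora_* flags below roughly map to via TRL's get_peft_config:

```python
from peft import LoraConfig

# Illustrative only: values mirror the rank-16 example command below.
peft_config = LoraConfig(
    r=16,                         # --lora_r (8-32 recommended for RL)
    lora_alpha=16,                # --lora_alpha
    lora_dropout=0.0,             # --lora_dropout (default)
    target_modules="all-linear",  # --lora_target_modules (attention + MLP layers)
    task_type="CAUSAL_LM",
)
```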
# For Qwen/Qwen2.5-VL-3B-Instruct with optimal LoRA (rank 16, all-linear)
```
hf jobs uv run \
    --flavor a100-large \
    --timeout 6h \
    --secrets HF_TOKEN \
    "https://gist.githubusercontent.com/burtenshaw/986f53790607d6378d959f31ddb416d2/raw/grpo_lora.py" \
    --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
    --output_dir grpo-Qwen2.5-VL-3B-Instruct-LoRA \
    --learning_rate 1e-5 \
    --gradient_checkpointing \
    --torch_dtype bfloat16 \
    --max_prompt_length 2048 \
    --max_completion_length 1024 \
    --use_vllm \
    --vllm_mode colocate \
    --use_peft \
    --lora_r 16 \
    --lora_alpha 16 \
    --lora_target_modules all-linear \
    --log_completions \
    --report_to trackio \
    --push_to_hub
```

# For HuggingFaceTB/SmolVLM2-2.2B-Instruct with a lower rank (rank 8)
```
hf jobs uv run \
    --flavor a100-large \
    --timeout 6h \
    --secrets HF_TOKEN \
    "https://gist.githubusercontent.com/burtenshaw/986f53790607d6378d959f31ddb416d2/raw/grpo_lora.py" \
    --model_name_or_path HuggingFaceTB/SmolVLM2-2.2B-Instruct \
    --output_dir grpo-SmolVLM2-2.2B-Instruct-LoRA \
    --learning_rate 1e-5 \
    --torch_dtype bfloat16 \
    --max_prompt_length 2048 \
    --max_completion_length 1024 \
    --use_peft \
    --lora_r 8 \
    --lora_alpha 8 \
    --lora_target_modules all-linear \
    --log_completions \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 2 \
    --num_generations 2 \
    --report_to trackio \
    --push_to_hub
```

# Higher rank for complex VLM reasoning (rank 32)
```
hf jobs uv run \
    --flavor a100-large \
    --timeout 8h \
    --secrets HF_TOKEN \
    "https://gist.githubusercontent.com/burtenshaw/986f53790607d6378d959f31ddb416d2/raw/grpo_lora.py" \
    --model_name_or_path Qwen/Qwen2.5-VL-7B-Instruct \
    --output_dir grpo-Qwen2.5-VL-7B-Instruct-LoRA \
    --learning_rate 1e-5 \
    --gradient_checkpointing \
    --torch_dtype bfloat16 \
    --max_prompt_length 2048 \
    --max_completion_length 1024 \
    --use_vllm \
    --vllm_mode colocate \
    --use_peft \
    --lora_r 32 \
    --lora_alpha 16 \
    --lora_target_modules all-linear \
    --log_completions \
    --report_to trackio \
    --push_to_hub
```
"""
import os

import torch
from accelerate import logging
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify

from trl import (
    GRPOConfig,
    GRPOTrainer,
    ModelConfig,
    ScriptArguments,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl.rewards import think_format_reward

logger = logging.get_logger(__name__)

# Enable logging in a Hugging Face Space
os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")

def validate_lora_config_rl(model_args, training_args):
    """
    Validate and provide guidance on LoRA configuration for RL tasks.

    Based on "LoRA Without Regret" (Schulman et al., 2025):
    https://thinkingmachines.ai/blog/lora/

    Key insight: RL requires very low capacity (~1 bit per episode).
    Policy gradient algorithms learn minimal information per episode.
    """
    if not model_args.use_peft:
        return

    # Check if target_modules is set appropriately
    if model_args.lora_target_modules is None:
        logger.warning(
            "⚠️ No lora_target_modules specified. For best performance, set --lora_target_modules all-linear "
            "to apply LoRA to ALL weight matrices (not just attention). Research shows that attention-only "
            "LoRA underperforms even when using a higher rank to match parameter count."
        )
    elif isinstance(model_args.lora_target_modules, (list, str)):
        target_str = str(model_args.lora_target_modules)
        if "all-linear" not in target_str and "q_proj" in target_str:
            logger.warning(
                "⚠️ Detected attention-only LoRA configuration (q_proj, v_proj). For best performance, use "
                "--lora_target_modules all-linear to apply LoRA to ALL weight matrices including MLP layers. "
                "Research shows this significantly improves performance compared to attention-only LoRA."
            )

    # Check learning rate
    if training_args.learning_rate > 5e-4:
        logger.warning(
            f"⚠️ Learning rate {training_args.learning_rate} seems high. For RL tasks, use learning rates "
            f"similar to full fine-tuning (typically 1e-5 to 5e-5). Consider reducing the learning rate."
        )

    # Provide rank guidance specific to RL
    if model_args.lora_r < 8:
        logger.warning(
            f"⚠️ Rank {model_args.lora_r} is very low. For RL tasks, rank 8-32 is recommended. "
            f"The current rank may be too low even for RL."
        )
    elif model_args.lora_r > 64:
        logger.info(
            f"💡 Rank {model_args.lora_r} is higher than typically needed for RL. Research shows that "
            f"policy gradient algorithms learn ~1 bit per episode, requiring minimal capacity (rank 8-32). "
            f"Consider using a lower rank to save compute and memory."
        )
    elif 8 <= model_args.lora_r <= 32:
        logger.info(
            f"✅ Rank {model_args.lora_r} is optimal for RL tasks. Research shows LoRA performs equivalently "
            f"to full fine-tuning with these small ranks for policy gradient methods."
        )

    # Check batch size
    total_batch_size = (
        training_args.per_device_train_batch_size
        * training_args.gradient_accumulation_steps
        * training_args.world_size
    )
    if total_batch_size > 256:
        logger.warning(
            f"⚠️ Large effective batch size detected: {total_batch_size}. Research shows LoRA may be less "
            f"tolerant of very large batch sizes compared to full fine-tuning. Consider reducing batch size "
            f"if you observe suboptimal performance."
        )

    # Log configuration summary
    logger.info("=" * 80)
    logger.info("LoRA Configuration Summary for RL (based on 'LoRA Without Regret'):")
    logger.info(f" Rank (r): {model_args.lora_r} (optimal for RL: 8-32)")
    logger.info(f" Alpha: {model_args.lora_alpha}")
    logger.info(f" Alpha/r ratio: {model_args.lora_alpha / model_args.lora_r:.2f}")
    logger.info(f" Target modules: {model_args.lora_target_modules}")
    logger.info(f" Dropout: {model_args.lora_dropout}")
    logger.info(f" Learning rate: {training_args.learning_rate}")
    logger.info(f" Effective batch size: {total_batch_size}")
    logger.info(
        f" Quantization: {'4-bit' if model_args.load_in_4bit else '8-bit' if model_args.load_in_8bit else 'None'}"
    )
    logger.info(
        " RL insight: policy gradient learns ~1 bit/episode → very low capacity needed"
    )
    logger.info("=" * 80)

if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()

    ################
    # Validate LoRA configuration
    ################
    validate_lora_config_rl(model_args, training_args)

    ################
    # Model & Processor
    ################
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    training_args.model_init_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
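    # Note: model_init_kwargs only take effect because the model is passed to GRPOTrainer
    # as a string name/path below; the trainer forwards these kwargs to from_pretrained.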

    ################
    # Dataset
    ################
    dataset = load_dataset("lmms-lab/multimodal-open-r1-8k-verified", split="train")
    dataset = dataset.train_test_split(test_size=100, seed=42)

    SYSTEM_PROMPT = (
        "A conversation between user and assistant. The user asks a question, and the assistant solves it. The "
        "assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
        "The reasoning process and answer are enclosed within <think></think> tags, i.e., <think>\nThis is my "
        "reasoning.\n</think>\nThis is my answer."
    )
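    # The <think>...</think> convention requested here is the same format that the
    # think_format_reward used during training checks, so format and accuracy are reinforced together.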

    def make_conversation(example):
        prompt = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["problem"]},
        ]
        return {"prompt": prompt}
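
    # Illustrative record shape after make_conversation (the question text is hypothetical):
    # {
    #     "prompt": [
    #         {"role": "system", "content": SYSTEM_PROMPT},
    #         {"role": "user", "content": "What is the area of the shaded region?"},
    #     ],
    #     "image": <PIL.Image.Image>,  # original columns such as image/solution are kept by .map()
    #     "solution": "...",
    # }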

    dataset = dataset.map(make_conversation)

    # Filter out examples with large images (both dimensions must be below 512 px)
    def filter_big_images(example):
        image = example["image"]
        return image.size[0] < 512 and image.size[1] < 512

    dataset = dataset.filter(filter_big_images)

    def convert_to_rgb(example):
        image = example["image"]
        if image.mode != "RGB":
            image = image.convert("RGB")
        example["image"] = image
        return example

    dataset = dataset.map(convert_to_rgb)

    train_dataset = dataset["train"]
    eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None

    ################
    # Reward Function for Training
    ################
    def accuracy_reward(completions, solution: list[str], **kwargs):
        """Reward function that checks if the completion matches the ground truth.

        - If both gold and prediction are parseable → use math verification.
        - If not parseable → compare as normalized text.
        """
        rewards = []
        contents = [completion[0]["content"] for completion in completions]
        for content, sol in zip(contents, solution):
            try:
                gold_parsed = parse(sol, extraction_mode="first_match")
            except Exception:
                gold_parsed = []

            if len(gold_parsed) != 0:
                # Try parsing predicted answer too
                try:
                    answer_parsed = parse(
                        content,
                        extraction_config=[
                            LatexExtractionConfig(
                                normalization_config=NormalizationConfig(
                                    nits=False,
                                    malformed_operators=False,
                                    basic_latex=True,
                                    boxed="all",
                                    units=True,
                                ),
                                boxed_match_priority=0,
                                try_extract_without_anchor=False,
                            )
                        ],
                        extraction_mode="first_match",
                    )
                    reward = float(verify(gold_parsed, answer_parsed))
                except Exception as e:
                    print(f"verify failed: {e}, answer: {content}, gold: {sol}")
                    reward = None
            else:
                # fallback to text match
                reward = float(content.strip().lower() == sol.strip().lower())
            rewards.append(reward)
        return rewards
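
    # Illustrative behavior of accuracy_reward (comments only):
    # - gold solution parses as LaTeX and the completion's boxed answer verifies -> reward 1.0
    # - gold parses but the completion cannot be parsed or verified -> reward None
    #   (GRPOTrainer treats a None reward as "this reward function does not apply to this sample")
    # - gold is not parseable LaTeX -> case-insensitive exact text match (1.0 or 0.0)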

    ################
    # Training
    ################
    trainer = GRPOTrainer(
        model=model_args.model_name_or_path,
        args=training_args,
        reward_funcs=[think_format_reward, accuracy_reward],
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=get_peft_config(model_args),
    )

    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)