# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# /// script
# dependencies = [
#     "trl @ git+https://github.com/huggingface/trl.git",
#     "peft",
#     "math-verify",
#     "latex2sympy2_extended",
# ]
# ///

"""
pip install math_verify

# For Qwen/Qwen2.5-VL-3B-Instruct
accelerate launch \
    --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
    examples/scripts/gspo_vlm.py \
    --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
    --output_dir gspo-Qwen2.5-VL-3B-Instruct \
    --learning_rate 1e-5 \
    --torch_dtype bfloat16 \
    --max_prompt_length 2048 \
    --max_completion_length 1024 \
    --use_peft \
    --lora_target_modules q_proj v_proj \
    --log_completions \
    --per_device_train_batch_size 8 \
    --num_generations 8 \
    --bf16 True \
    --importance_sampling_level sequence \
    --epsilon 3e-4 \
    --epsilon_high 4e-4 \
    --beta 0.0 \
    --loss_type grpo \
    --gradient_accumulation_steps 2 \
    --steps_per_generation 8
"""

import torch
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify

from trl import (
    GRPOConfig,
    GRPOTrainer,
    ModelConfig,
    ScriptArguments,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl.rewards import think_format_reward


if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()

    ################
    # Model & Processor
    ################
    # Turn the CLI dtype string into a torch dtype; "auto" and None are passed through unchanged.
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    # The model is passed to GRPOTrainer by name (see below), so these kwargs are forwarded to its loader.
    training_args.model_init_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        # A k-bit device map is only requested when quantization is enabled.
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    ################
    # Dataset
    ################
    # The preprocessing below relies on this dataset's "problem", "image" and "solution" columns,
    # so the name is fixed here rather than taken from `script_args.dataset_name`.
    DATASET_NAME = "lmms-lab/multimodal-open-r1-8k-verified"
    dataset = load_dataset(DATASET_NAME, split="train")
    # Hold out a fixed, seeded 100-example split for evaluation.
    dataset = dataset.train_test_split(test_size=100, seed=42)

    # The tag names must match what `think_format_reward` scores, otherwise the model is never told
    # which format it is being rewarded for.
    SYSTEM_PROMPT = (
        "A conversation between user and assistant. The user asks a question, and the assistant solves it. The "
        "assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
        "The reasoning process and answer are enclosed within <think></think> tags, i.e., <think>\nThis is my "
        "reasoning.\n</think>\nThis is my answer."
    )

    def make_conversation(example):
        """Build a conversational prompt (system + user turn) from an example's "problem" text."""
        prompt = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["problem"]},
        ]
        return {"prompt": prompt}

    dataset = dataset.map(make_conversation)

    # Filter out big images
    def filter_big_images(example):
        """Keep only examples whose image is strictly smaller than 512 px in both width and height."""
        image = example["image"]
        return image.size[0] < 512 and image.size[1] < 512

    dataset = dataset.filter(filter_big_images)

    def convert_to_rgb(example):
        """Normalize every image to RGB mode (e.g. grayscale, RGBA or palette images are converted)."""
        image = example["image"]
        if image.mode != "RGB":
            image = image.convert("RGB")
        example["image"] = image
        return example

    dataset = dataset.map(convert_to_rgb)

    train_dataset = dataset["train"]
    # Only build an eval set when evaluation is actually enabled.
    eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None

    ################
    # Reward Function for Training
    ################
    def accuracy_reward(completions, solution: list[str], **kwargs):
        """Score each completion against its ground-truth solution.

        - If the gold solution parses to a non-empty result, the completion is parsed with a LaTeX
          extraction config (preferring ``\\boxed{}`` content) and checked with math-verify's ``verify``;
          the reward is 1.0 or 0.0.
        - If parsing/verifying the completion raises, the reward is ``None`` so this sample is skipped for
          this reward function rather than scored as wrong.
        - If the gold solution cannot be parsed, fall back to a case-insensitive, whitespace-stripped
          exact string comparison.

        Args:
            completions: Conversational completions; the text is read from ``completion[0]["content"]``.
            solution: Ground-truth answers, one per completion (the dataset's "solution" column).

        Returns:
            A list with one float (or ``None``) per completion.
        """
        rewards = []
        contents = [completion[0]["content"] for completion in completions]
        for content, sol in zip(contents, solution):
            try:
                gold_parsed = parse(sol, extraction_mode="first_match")
            except Exception:
                gold_parsed = []

            if len(gold_parsed) != 0:
                # Try parsing predicted answer too
                try:
                    answer_parsed = parse(
                        content,
                        extraction_config=[
                            LatexExtractionConfig(
                                normalization_config=NormalizationConfig(
                                    nits=False,
                                    malformed_operators=False,
                                    basic_latex=True,
                                    boxed="all",
                                    units=True,
                                ),
                                boxed_match_priority=0,
                                try_extract_without_anchor=False,
                            )
                        ],
                        extraction_mode="first_match",
                    )
                    reward = float(verify(gold_parsed, answer_parsed))
                except Exception as e:
                    print(f"verify failed: {e}, answer: {content}, gold: {sol}")
                    # Deliberately None (not 0.0): GRPOTrainer ignores None rewards for that sample.
                    reward = None
            else:
                # fallback to text match
                reward = float(content.strip().lower() == sol.strip().lower())

            rewards.append(reward)

        return rewards

    ################
    # Training
    ################
    trainer = GRPOTrainer(
        model=model_args.model_name_or_path,
        args=training_args,
        reward_funcs=[think_format_reward, accuracy_reward],
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=get_peft_config(model_args),
    )

    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        # Record the dataset that was actually loaded; `script_args.dataset_name` is never used above.
        trainer.push_to_hub(dataset_name=DATASET_NAME)