# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# /// script
# dependencies = [
#     "trl @ git+https://github.com/huggingface/trl.git",
#     "peft",
#     "math-verify",
#     "latex2sympy2_extended",
# ]
# ///
""" | |
pip install math_verify | |
# For Qwen/Qwen2.5-VL-3B-Instruct | |
accelerate launch \ | |
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \ | |
examples/scripts/gspo_vlm.py \ | |
--model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \ | |
--output_dir gspo-Qwen2.5-VL-3B-Instruct \ | |
--learning_rate 1e-5 \ | |
--torch_dtype bfloat16 \ | |
--max_prompt_length 2048 \ | |
--max_completion_length 1024 \ | |
--use_peft \ | |
    --lora_target_modules q_proj v_proj \
    --log_completions \
    --per_device_train_batch_size 8 \
    --num_generations 8 \
    --bf16 True \
    --importance_sampling_level sequence \
    --epsilon 3e-4 \
    --epsilon_high 4e-4 \
    --beta 0.0 \
    --loss_type grpo \
    --gradient_accumulation_steps 2 \
    --steps_per_generation 8
"""

import torch
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify

from trl import (
    GRPOConfig,
    GRPOTrainer,
    ModelConfig,
    ScriptArguments,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl.rewards import think_format_reward


if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()

    ################
    # Model & Processor
    ################
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    training_args.model_init_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    ################
    # Dataset
    ################
    dataset = load_dataset("lmms-lab/multimodal-open-r1-8k-verified", split="train")
    dataset = dataset.train_test_split(test_size=100, seed=42)

    SYSTEM_PROMPT = (
        "A conversation between user and assistant. The user asks a question, and the assistant solves it. The "
        "assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
        "The reasoning process and answer are enclosed within <think></think> tags, i.e., <think>\nThis is my "
        "reasoning.\n</think>\nThis is my answer."
    )
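    # The <think></think> convention in this prompt is what think_format_reward
    # (used below as a reward function) checks for, so the two must stay in sync.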

    def make_conversation(example):
        prompt = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["problem"]},
        ]
        return {"prompt": prompt}

    dataset = dataset.map(make_conversation)

    # Filter out examples with big images (keep only images smaller than 512x512)
    def filter_big_images(example):
        image = example["image"]
        return image.size[0] < 512 and image.size[1] < 512

    dataset = dataset.filter(filter_big_images)
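
    # Convert non-RGB images (palette, grayscale, RGBA) up front; the assumption
    # here is that downstream image processing is simplest with uniform RGB inputs.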
    def convert_to_rgb(example):
        image = example["image"]
        if image.mode != "RGB":
            image = image.convert("RGB")
        example["image"] = image
        return example

    dataset = dataset.map(convert_to_rgb)

    train_dataset = dataset["train"]
    eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None

    ################
    # Reward Function for Training
    ################
    def accuracy_reward(completions, solution: list[str], **kwargs):
        """Reward function that checks if the completion matches the ground truth.
        - If both gold and prediction are parseable → use math verification.
        - If not parseable → compare as normalized text.
        """
        rewards = []
        contents = [completion[0]["content"] for completion in completions]
        for content, sol in zip(contents, solution):
            try:
                gold_parsed = parse(sol, extraction_mode="first_match")
            except Exception:
                gold_parsed = []

            if len(gold_parsed) != 0:
                # Try parsing the predicted answer too
                try:
                    answer_parsed = parse(
                        content,
                        extraction_config=[
                            LatexExtractionConfig(
                                normalization_config=NormalizationConfig(
                                    nits=False,
                                    malformed_operators=False,
                                    basic_latex=True,
                                    boxed="all",
                                    units=True,
                                ),
                                boxed_match_priority=0,
                                try_extract_without_anchor=False,
                            )
                        ],
                        extraction_mode="first_match",
                    )
                    reward = float(verify(gold_parsed, answer_parsed))
                except Exception as e:
                    print(f"verify failed: {e}, answer: {content}, gold: {sol}")
                    reward = None
            else:
                # Fallback to text match
                reward = float(content.strip().lower() == sol.strip().lower())
            rewards.append(reward)
        return rewards
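
    # Illustrative call (not executed): the trainer passes completions in
    # conversational format and extra dataset columns (here `solution`) as kwargs:
    #     accuracy_reward(
    #         completions=[[{"role": "assistant", "content": r"<think>...</think> \boxed{2}"}]],
    #         solution=[r"\boxed{2}"],
    #     )  # -> [1.0]
    # A None reward marks the sample as "not applicable" for this reward function.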

    ################
    # Training
    ################
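    # Passing the model name as a string lets GRPOTrainer instantiate the model
    # itself, applying the model_init_kwargs set above (revision, dtype, attention
    # implementation, quantization/device map).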
    trainer = GRPOTrainer(
        model=model_args.model_name_or_path,
        args=training_args,
        reward_funcs=[think_format_reward, accuracy_reward],
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=get_peft_config(model_args),
    )
    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)