Qwen2.5-0.5B-Instruct-GSM8K-PPO-Merged
(Perhaps) a state-of-the-art math model on GSM8K among models with fewer than 0.5 billion parameters
Model Overview
This model merges three high-performing checkpoints obtained by fine-tuning Qwen2.5-0.5B-Instruct with PPO (Proximal Policy Optimization) on the GSM8K mathematical reasoning dataset.
Key Features
- Base Model: Qwen/Qwen2.5-0.5B-Instruct (494M parameters)
- Training Algorithm: PPO via the VERL framework
- Specialization: Mathematical reasoning and problem-solving
- Model Merging: Weighted (linear) average of the 3 best-performing checkpoints, merged with mergekit
Performance
| Dataset | Score | Improvement |
|---|---|---|
| GSM8K | 58.91% | +9.31 percentage points over Qwen2.5-0.5B-Instruct |
This is a large gain in mathematical reasoning for a model with only 0.5B parameters.
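Assuming the +9.31 figure is an absolute gain in percentage points (not a relative gain), the implied baseline for Qwen2.5-0.5B-Instruct and the corresponding relative improvement work out as follows:

$$
\text{baseline} = 58.91\% - 9.31 = 49.60\%, \qquad
\text{relative gain} = \frac{9.31}{49.60} \approx 18.8\%
$$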
Usage
Quick Start
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load model and tokenizer
model_name = "alphadl/ppo-gsm8k-0.5b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Example: mathematical reasoning
prompt = """Solve this step by step:
Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?
Let's think step by step and output the final answer after "####"."""

# Put the inputs on the same device as the model (device_map="auto" may choose a GPU)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,  # greedy decoding, so no temperature is needed
        pad_token_id=tokenizer.eos_token_id,
    )

# Decode only the newly generated tokens, not the echoed prompt
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
response = tokenizer.decode(new_tokens, skip_special_tokens=True)
print(response)
```
Expected Output Format
The model is trained to provide step-by-step mathematical reasoning followed by the final answer in the format:
#### [numerical_answer]
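To check outputs automatically, the final answer can be parsed from the text after `####` and compared numerically with the answer at the end of a GSM8K reference solution, which uses the same `#### <number>` convention. This is a minimal sketch: the regex, the comma stripping, and the numeric comparison are assumptions, not necessarily what VERL's GSM8K reward or the evaluation behind the 58.91% score does.

```python
import re
from typing import Optional, Sequence

# Assumed pattern: "####", an optional "$", then a number that may contain
# thousands separators or a decimal part (e.g. "#### 18" or "#### 1,234").
_FINAL_ANSWER = re.compile(r"####\s*\$?\s*(-?[0-9][0-9,]*(?:\.[0-9]+)?)")


def extract_answer(text: str) -> Optional[str]:
    """Return the last number written after '####', with commas removed."""
    matches = _FINAL_ANSWER.findall(text)
    if not matches:
        return None
    return matches[-1].replace(",", "")


def is_correct(response: str, reference_solution: str) -> bool:
    """Compare the final answers of a model response and a reference solution."""
    pred = extract_answer(response)
    gold = extract_answer(reference_solution)
    if pred is None or gold is None:
        return False  # a missing "####" answer counts as wrong
    return float(pred) == float(gold)


def accuracy(responses: Sequence[str], references: Sequence[str]) -> float:
    """Fraction of responses whose final answer matches its reference."""
    pairs = list(zip(responses, references))
    return sum(is_correct(r, g) for r, g in pairs) / len(pairs)
```

For the Quick Start question, a correct response ends with `#### 18` (16 − 3 − 4 = 9 eggs sold at $2 each), so `extract_answer(response)` should return `"18"`.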
Training Details
Training Framework
- Framework: VERL (Volcano Engine Reinforcement Learning)
- Algorithm: PPO (Proximal Policy Optimization)
- Data Source: GSM8K mathematical reasoning dataset
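PPO updates the policy with a clipped surrogate objective: the probability ratio between the updated policy and the policy that generated the rollouts is clipped to [1 − ε, 1 + ε], so a single update cannot move the policy too far. The sketch below computes that loss in plain Python for a batch of per-token log-probabilities and advantages. It is an illustration only, not VERL's implementation (which works on batched, masked tensors and also trains a critic), and the default `clip_eps=0.2` is a common choice, not necessarily the setting used for this model.

```python
import math
from typing import Sequence


def ppo_clip_loss(
    logp_new: Sequence[float],
    logp_old: Sequence[float],
    advantages: Sequence[float],
    clip_eps: float = 0.2,
) -> float:
    """Mean negated PPO clipped surrogate objective over a batch of actions (tokens)."""
    losses = []
    for new, old, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(new - old)  # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Keep the pessimistic (smaller) of the unclipped and clipped objectives.
        objective = min(ratio * adv, clipped * adv)
        losses.append(-objective)  # negate because optimizers minimize
    return sum(losses) / len(losses)
```

With a positive advantage, the clip caps how much the loss rewards raising a token's probability; with a negative advantage, the `min` keeps the full (unclipped) penalty.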
Model Merging Strategy
This model was created by merging 3 high-performing checkpoints using linear interpolation (a weighted average of their parameters):
| Checkpoint | GSM8K Score | Weight |
|---|---|---|
| global_step_5000 | 58.3% | 33% |
| global_step_6000 | 58.7% | 33% |
| global_step_7500 | 58.9% | 34% |
Result: the merged model scored 58.91%, matching or marginally exceeding the best individual checkpoint (global_step_7500, 58.9%).
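A linear merge sets each merged parameter to the weighted sum of the corresponding checkpoint parameters, θ_merged = Σᵢ wᵢ θᵢ. mergekit does this from a config file; the function below is a hypothetical, minimal re-implementation that only illustrates the arithmetic, assuming every checkpoint has the same parameter names and shapes.

```python
import math
from typing import Mapping, Sequence


def linear_merge(state_dicts: Sequence[Mapping], weights: Sequence[float]) -> dict:
    """Weighted average of parameter dicts that share the same keys.

    Works with any values supporting scalar multiplication and addition,
    e.g. torch tensors from model.state_dict() or plain floats.
    For bf16 tensors, consider casting to float32 first to limit rounding error.
    """
    if len(state_dicts) != len(weights):
        raise ValueError("need exactly one weight per checkpoint")
    if not math.isclose(sum(weights), 1.0):
        raise ValueError("weights should sum to 1 for a plain weighted average")
    merged = {}
    for key in state_dicts[0]:
        # Accumulate w_i * theta_i for this parameter across all checkpoints.
        merged[key] = sum(w * sd[key] for w, sd in zip(weights, state_dicts))
    return merged
```

With the weights in the table, a call such as `linear_merge([sd_5000, sd_6000, sd_7500], [0.33, 0.33, 0.34])`, where `sd_*` are hypothetical names for each checkpoint's `model.state_dict()`, would return the merged parameters, which could then be loaded with `model.load_state_dict(...)`.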
Training Configuration
- Base Model: Qwen/Qwen2.5-0.5B-Instruct
- Training Steps: 7,500+ steps
- Validation Frequency: Every 1,000 steps
- Optimization: AdamW with learning rate scheduling
Use Cases
This model is intended for:
- Mathematical Problem Solving: grade-school-level arithmetic, algebra, and basic geometry word problems
- Step-by-Step Reasoning: Breaking down complex problems
- Educational Applications: Math tutoring and explanation
- Computational Tasks: Basic calculations with reasoning
Limitations
- Model Size: As a 0.5B parameter model, it may struggle with very complex mathematical concepts
- Domain Specificity: Optimized for GSM8K-style problems; may not perform as well on other domains
- Context Length: Limited by the base model's context window (32K tokens)
License
This model inherits the license from the base Qwen2.5-0.5B-Instruct model. Please refer to the original model card for licensing details.
Acknowledgments
- Base Model: Qwen Team for Qwen2.5-0.5B-Instruct
- Training Framework: VERL Team for the PPO implementation
- Model Merging: mergekit for the averaging capabilities
- Dataset: GSM8K for mathematical reasoning data
Contact
For questions or issues, please open an issue in the repository or contact the model author.
This model was trained with the VERL framework and merged with mergekit to improve mathematical reasoning performance.