R1-AQA --- Reinforcement Learning Outperforms Supervised Fine-Tuning: A Case Study on Audio Question Answering
Introduction
R1-AQA is an audio question answering (AQA) model based on Qwen2-Audio-7B-Instruct, optimized through reinforcement learning with the group relative policy optimization (GRPO) algorithm.
This implementation has achieved state-of-the-art performance on the MMAU benchmark with only 38k post-training samples.
For more details, please refer to our GitHub repository and Technical Report.
Our main findings are as follows:
- The GRPO algorithm can be directly and effectively applied to the audio modality, even to Qwen2-Audio-7B-Instruct with only 8.2B parameters (a minimal reward sketch is given after the additional notes below).
- With only 38k post-training samples, reinforcement learning outperforms supervised fine-tuning, indicating that RL-based approaches can be effective without large datasets.
- The explicit reasoning process has not shown significant benefits for AQA tasks, and how to efficiently leverage deep thinking or step-by-step reasoning remains an open question for further research.
- Large audio language models (LALMs) still lag far behind humans in auditory-language reasoning, suggesting that RL-based approaches warrant further exploration.
Additional Notes:
- The AVQA training set originally consists of approximately 40k samples. However, we use only about 38k samples because some data sources have become invalid. Other datasets sourced from YouTube, such as AudioSet, face a similar issue. We believe the missing ~2k samples do not have a significant impact on the training results.
- The statement about the 8.2B parameters is based on the Qwen2-Audio Technical Report.
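To make the training setup more concrete, the snippet below is a minimal, illustrative sketch of a GRPO-style rule-based reward and the group-relative advantage normalization it is named after. It assumes the answer is extracted from the `<answer> </answer>` tags used in our prompt template; helper names such as `accuracy_reward` and `group_relative_advantages` are hypothetical and the exact reward used for training is defined in the code on GitHub.

```python
import re
from typing import List

def accuracy_reward(completion: str, answer: str) -> float:
    """Hypothetical rule-based reward: 1.0 if the text inside <answer>...</answer>
    matches the ground-truth option, else 0.0."""
    match = re.search(r"<answer>(.*?)</answer>", completion, re.DOTALL)
    predicted = match.group(1).strip() if match else ""
    return 1.0 if predicted.lower() == answer.strip().lower() else 0.0

def group_relative_advantages(rewards: List[float], eps: float = 1e-4) -> List[float]:
    """GRPO-style advantages: normalize each reward by the mean and std of the
    group of completions sampled for the same question."""
    mean = sum(rewards) / len(rewards)
    std = (sum((r - mean) ** 2 for r in rewards) / len(rewards)) ** 0.5
    return [(r - mean) / (std + eps) for r in rewards]

# Example: four sampled completions for one question; two answer correctly.
completions = [
    "<answer>Woman</answer>", "<answer>Man</answer>",
    "The speaker is female. <answer>Woman</answer>", "<answer>Robot</answer>",
]
rewards = [accuracy_reward(c, "Woman") for c in completions]
print(group_relative_advantages(rewards))  # correct samples receive positive advantages
```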
Table: Accuracies (%) on the MMAU benchmark
Model | Method | Sound (Test-mini) | Sound (Test) | Music (Test-mini) | Music (Test) | Speech (Test-mini) | Speech (Test) | Avg. (Test-mini) | Avg. (Test) |
---|---|---|---|---|---|---|---|---|---|
- | Human* | 86.31 | - | 78.22 | - | 82.17 | - | 82.23 | - |
Gemini Pro 2.0 Flash | Direct Inference* | 56.46 | 61.73 | 58.68 | 56.53 | 51.65 | 61.53 | 55.60 | 59.93 |
Audio Flamingo 2 | Direct Inference* | 61.56 | 65.10 | 73.95 | 72.90 | 30.93 | 40.26 | 55.48 | 59.42 |
GPT4o + Strong Cap. | Direct Inference* | 57.35 | 55.83 | 49.70 | 51.73 | 64.86 | 68.66 | 57.30 | 58.74 |
Llama-3-8B-Instruct + Strong Cap. | Direct Inference* | 50.75 | 49.10 | 48.93 | 48.93 | 55.25 | 62.70 | 52.10 | 53.57 |
Qwen2-Audio-7B-Instruct | Direct Inference* | 54.95 | 45.90 | 50.98 | 53.26 | 42.04 | 45.90 | 49.20 | 52.50 |
SALAMONN | Direct Inference* | 41.00 | 40.30 | 34.80 | 33.76 | 25.50 | 24.24 | 33.70 | 32.77 |
Qwen2-Audio-7B-Instruct | CoTA [1] | 60.06 | - | 64.30 | - | 60.70 | - | 61.71 | - |
Qwen2-Audio-7B-Instruct | Zero-Shot-CoT [2] | 61.86 | - | 56.29 | - | 55.26 | - | 57.80 | - |
Qwen2-Audio-7B-Instruct | GRPO (Ours) 1️⃣ | 69.37 | - | 66.77 | - | 57.36 | - | 64.50 | - |
Qwen2-Audio-7B-Instruct | GRPO (Ours) 2️⃣ | 68.77 | 69.76 | 64.37 | 61.40 | 63.66 | 62.70 | 65.60 | 64.36 |
Notes
* The data are sourced from the MMAU leaderboard.
[1] Xie, Zhifei, et al. "Audio-Reasoner: Improving Reasoning Capability in Large Audio Language Models." arXiv preprint arXiv:2503.02318 (2025).
[2] Ma, Ziyang, et al. "Audio-CoT: Exploring Chain-of-Thought Reasoning in Large Audio Language Model." arXiv preprint arXiv:2501.07246 (2025).
1️⃣ The original model, identical to the one on Hugging Face and described in our technical report.
2️⃣ The model submitted to the MMAU leaderboard, trained multiple times to achieve balanced results.
Inference
```python
import torch
import torchaudio
from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor
# Load model
model_name = "mispeech/r1-aqa"
processor = AutoProcessor.from_pretrained(model_name)
model = Qwen2AudioForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
# Load example audio
wav_path = "test-mini-audios/3fe64f3d-282c-4bc8-a753-68f8f6c35652.wav" # from MMAU dataset
waveform, sampling_rate = torchaudio.load(wav_path)
# Resample to the 16 kHz sampling rate expected by the model
if sampling_rate != 16000:
    waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)(waveform)
audios = [waveform[0].numpy()]
# Make prompt text
question = "Based on the given audio, identify the source of the speaking voice."
options = ["Man", "Woman", "Child", "Robot"]
prompt = f"{question} Please choose the answer from the following options: {str(options)}. Output the final answer in <answer> </answer>."
message = [
    {"role": "user", "content": [
        {"type": "audio", "audio_url": wav_path},
        {"type": "text", "text": prompt},
    ]}
]
texts = processor.apply_chat_template(message, add_generation_prompt=True, tokenize=False)
# Process
inputs = processor(text=texts, audios=audios, sampling_rate=16000, return_tensors="pt", padding=True).to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=256)
generated_ids = generated_ids[:, inputs.input_ids.size(1):]
response = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(response)
```
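Because the prompt asks the model to wrap its final choice in `<answer> </answer>` tags, the decoded response usually needs a small post-processing step to recover the chosen option. The `extract_answer` helper below is not part of the released code, just an illustrative sketch that reuses the `response` and `options` variables from the example above.

```python
import re

def extract_answer(response_text: str, options: list) -> str:
    """Hypothetical helper: pull the text inside <answer>...</answer> and map it to an option."""
    match = re.search(r"<answer>(.*?)</answer>", response_text, re.DOTALL)
    predicted = match.group(1).strip() if match else response_text.strip()
    for option in options:
        if option.lower() in predicted.lower():
            return option
    return predicted  # fall back to the raw prediction if no option matches

print(extract_answer(response[0], options))  # e.g. "Woman"
```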
Citation
@article{li2025reinforcement,
title={Reinforcement Learning Outperforms Supervised Fine-Tuning: A Case Study on Audio Question Answering},
author={Li, Gang and Liu, Jizhong and Dinkel, Heinrich and Niu, Yadong and Zhang, Junbo and Luan, Jian},
journal={arXiv preprint arXiv:2503.11197},
year={2025},
url={https://github.com/xiaomi-research/r1-aqa; https://huggingface.co/mispeech/r1-aqa}
}