--- license: apache-2.0 datasets: - K-and-K/knights-and-knaves language: - en - zh base_model: - SophieA17/Sophie0-Reasoning-SFT --- Sophie0-Reasoning-GRPO ### Introduction Sophie0是一个从头实现的单人0.5B大语言模型项目,主要核心在于完整跑通**预训练(Pretrain)**、**监督微调(Supervised Fine-tune, SFT)**、**直接偏好优化(Direct Preference Optimization, DPO)**、以及基于 **组内相对策略优化(Group Relative Policy Optimization, GRPO)** 的显示**思维链推理**等主要流程。其中预训练阶段使用BAAI开源的多领域数据集,总数据量约11B Tokens,消耗52x4 GPU hours;微调阶段使用BAAI以及数学CoT数据在内总计9.74M行对话数据,消耗24x8 GPU hours;DPO阶段使用BAAI的偏好数据以及从LLama 3提取的英语对话数据在内总计159.3k对数据,消耗1x4 GPU hours;GRPO阶段使用Knights & Knaves 3ppl数据集以及从DeepSeek-R1提取的思维链对模型进行SFT和GRPO,SFT阶段总计有1.5k条数据,GRPO阶段总计有500条prompt,前者消耗10min x 1 GPU Times,后者消耗51x2 GPU hours. 此外,本项目进一步探讨了在下游SFT和DPO阶段完全使用变长(varlen)序列训练的可行性以及实现方式,充分利用了flash attention 2自带的`varlen attention` 和 `varlen RoPE`算子,同时也探讨了批量推理时引入的填充token对输出的影响,以及如何通过设计兼容varlen的KV Cache类直接基于Huggingface GenerationMixin接口无缝切块填充推理和无填充变长序列推理 更多内容详见[此处](https://github.com/Sophie10001b/sophie0) ### QuickStart 可通过如下方法直接调用本模型 ```python import os import torch from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig model: AutoModelForCausalLM = AutoModelForCausalLM.from_pretrained("SophieA17/Sophie0-Reasoning-GRPO", trust_remote_code=True) tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained("SophieA17/Sophie0-Reasoning-GRPO", trust_remote_code=True) model = model.to(device="cuda:0", dtype=torch.bfloat16) inputs = [ "Could you please introduce youself?\n", "Where is the best place for traveling in summer?\n" ] input_ids = tokenizer(inputs, return_tensors="pt", padding=True, padding_side="left", return_token_type_ids=False).to(model.device) generation_config = GenerationConfig( bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id, max_new_tokens=1024, do_sample=True, top_k=20, top_p=0.8, temperature=0.8, repeat_penalty=1.1, use_cache=True ) outputs = model.generate(**input_ids, use_varlen_inference=True, generation_config=generation_config) outputs = tokenizer.batch_decode(outputs, skip_special_tokens=False) ```