SophieA17's picture
Update README.md
f225fc7 verified
metadata
license: apache-2.0
datasets:
  - K-and-K/knights-and-knaves
language:
  - en
  - zh
base_model:
  - SophieA17/Sophie0-Reasoning-SFT

Sophie0-Reasoning-GRPO

Introduction

Sophie0是一个从头实现的单人0.5B大语言模型项目,主要核心在于完整跑通预训练(Pretrain)监督微调(Supervised Fine-tune, SFT)直接偏好优化(Direct Preference Optimization, DPO)、以及基于 组内相对策略优化(Group Relative Policy Optimization, GRPO) 的显示思维链推理等主要流程。其中预训练阶段使用BAAI开源的多领域数据集,总数据量约11B Tokens,消耗52x4 GPU hours;微调阶段使用BAAI以及数学CoT数据在内总计9.74M行对话数据,消耗24x8 GPU hours;DPO阶段使用BAAI的偏好数据以及从LLama 3提取的英语对话数据在内总计159.3k对数据,消耗1x4 GPU hours;GRPO阶段使用Knights & Knaves 3ppl数据集以及从DeepSeek-R1提取的思维链对模型进行SFT和GRPO,SFT阶段总计有1.5k条数据,GRPO阶段总计有500条prompt,前者消耗10min x 1 GPU Times,后者消耗51x2 GPU hours.

此外,本项目进一步探讨了在下游SFT和DPO阶段完全使用变长(varlen)序列训练的可行性以及实现方式,充分利用了flash attention 2自带的varlen attentionvarlen RoPE算子,同时也探讨了批量推理时引入的填充token对输出的影响,以及如何通过设计兼容varlen的KV Cache类直接基于Huggingface GenerationMixin接口无缝切块填充推理和无填充变长序列推理

更多内容详见此处

QuickStart

可通过如下方法直接调用本模型

import os
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

model: AutoModelForCausalLM = AutoModelForCausalLM.from_pretrained("SophieA17/Sophie0-Reasoning-GRPO", trust_remote_code=True)
tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained("SophieA17/Sophie0-Reasoning-GRPO", trust_remote_code=True)

model = model.to(device="cuda:0", dtype=torch.bfloat16)
inputs = [
    "<s><user>Could you please introduce youself?</s>\n",
    "<s><user>Where is the best place for traveling in summer?</s>\n"
]

input_ids = tokenizer(inputs, return_tensors="pt", padding=True, padding_side="left", return_token_type_ids=False).to(model.device)

generation_config = GenerationConfig(
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    max_new_tokens=1024,
    do_sample=True,
    top_k=20,
    top_p=0.8,
    temperature=0.8,
    repeat_penalty=1.1,
    use_cache=True
)

outputs = model.generate(**input_ids, use_varlen_inference=True, generation_config=generation_config)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=False)