File size: 2,445 Bytes
421a883
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eacebe3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
---

language:
- zho
- eng
- fra
- spa
- por
- deu
- ita
- rus
- jpn
- kor
- vie
- tha
- ara
base_model:
- Qwen/Qwen2.5-0.5B-Instruct
pipeline_tag: text-generation
license: apache-2.0
datasets:
- BAAI/IndustryCorpus2
- BAAI/Infinity-Instruct
- BAAI/Infinity-Preference
---


# mini_qwen



## Introduction

mini_qwen是一个从头开始训练的1B参数的大型语言模型(LLM)项目,包括预训练(PT)、微调(SFT)和直接偏好优化(DPO)3个部分。其中预训练和微调仅需要12G显存即可训练,直接偏好优化仅需要14G显存即可训练,这意味着使用T4显卡就可以开始你的训练之旅。  

mini_qwen是以Qwen2.5-0.5B-Instruct模型为基础,通过扩充模型隐藏状态层数、隐藏状态维度和注意力头数,增加参数量到1B,并进行参数随机初始化。训练数据使用北京智源人工智能研究院的预训练(16B token)、微调(9M 条)和偏好数据(60K 条),使用flash_attention_2进行加速,使用deepspeed在6张H800上训练25h(pt 1epoch)、43h(sft 3epoch)、1h(dpo 3epoch)。  



这是一次非常有趣且有价值的尝试,在整个过程中,本项目探究了尺度定律(scaling law)、复读机现象与微调阶段的知识注入,也解决了很多bug。本项目将尽可能详细地介绍整个训练过程,也欢迎交流讨论。



更多内容详见:https://github.com/qiufengqijun/mini_qwen

## Quickstart
使用方法如下:
```

from transformers import AutoModelForCausalLM, AutoTokenizer

import logging



logging.getLogger("transformers").setLevel(logging.ERROR) # 忽略警告



# 加载分词器与模型 

model_path = "/path/to/your/model"

model = AutoModelForCausalLM.from_pretrained(model_path)

tokenizer = AutoTokenizer.from_pretrained(model_path)





while True:

    prompt = input("用户:")

    

    text = prompt  # 预训练模型

    text = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"  # 微调和直接偏好优化模型

    

    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    generated_ids = model.generate(**model_inputs, max_new_tokens=512)

    generated_ids = [

        output_ids[len(input_ids) :]

        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)

    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]



    print("助手:", response)

```