Safetensors · omni · custom_code

AlvinSunYooo committed (verified) commit d6c94dc · 1 parent: c52b24e

Upload folder using huggingface_hub
README.md ADDED
@@ -0,0 +1,1114 @@
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+ <div align="center">
5
+
6
+ <img src="https://github.com/baichuan-inc/Baichuan-Omni-1.5/raw/main/assets/logo.png" width="300em" ></img>
7
+
8
+ <!-- <img src="https://raw.githubusercontent.com/baichuan-inc/Baichuan-Omni-1.5/refs/heads/main/assets/logo.png" width="300em" ></img>
9
+ <img src="https://github.com/baichuan-inc/Baichuan-Omni-1.5/raw/main/assets/train-pipeline.png" width="300em" ></img> -->
10
+ <!-- <img src="https://github.com/OpenBMB/MiniCPM-o/raw/main/assets/minicpm-o-26-framework-v2.png" width="300em" ></img> -->
11
+ **Open-source Omni-modal Foundation Model Supporting Text, Image, Video, and Audio Inputs as Well as Text and Audio Outputs**
12
+
13
+
14
+
15
+ <p align="center">
16
+ Baichuan-Omni-1.5 <a href="https://huggingface.co/baichuan-inc/Baichuan-Omni-1d5">🤗</a> | Baichuan-Omni-1.5-Base <a href="https://huggingface.co/baichuan-inc/Baichuan-Omni-1d5-Base">🤗</a> | GitHub <a href="https://github.com/baichuan-inc/Baichuan-Omni-1.5/">📖</a> | Report <a href="https://github.com/baichuan-inc/Baichuan-Omni-1.5/raw/main/baichuan_omni_1_5.pdf">📖</a>
17
+ </p>
19
+ <p align="center">
20
+ OpenMM-Medical <a href="https://huggingface.co/datasets/baichuan-inc/OpenMM-Medical">🤗</a> | OpenAudioBench <a href="https://huggingface.co/datasets/baichuan-inc/OpenAudioBench">🤗</a>
21
+ </p>
22
+ </div>
23
+
24
+
25
+ <!-- ## Introduction
26
+ **Baichuan-Omni-1.5** is the latest multimodal large model upgraded from Baichuan-omni: it is trained end-to-end and supports omni-modal input with dual-modal (text and audio) output. Built on Qwen2.5-7B as its language-model base, it accepts images, video, text, and audio as input and generates high-quality text and speech in a controllable manner, all in an end-to-end fashion.
27
+
28
+ - **Baichuan-Omni-1.5-Base**: To advance omni-modal large models, we open-source this omni-modal foundation model trained on massive high-quality data. It has not undergone SFT instruction tuning, remains highly adaptable, and is **the industry's first** open-source **omni-modal foundation model**.
29
+
30
+ - **Baichuan-Omni-1.5**: Built on the strong Baichuan-Omni-1.5-Base, it is trained end-to-end on high-quality omni-modal alignment and multimodal instruction data. Its text, image, video, and audio understanding reaches the GPT-4o-mini level, and its controllable audio generation is very strong, achieving top results on the xxx and xxx benchmarks. -->
31
+
32
+
33
+ ## Baichuan-Omni-1.5
34
+
35
+ Baichuan-Omni-1.5 is the latest and best-performing model in the Baichuan-omni series, trained and run end-to-end. Compared with Baichuan-omni, it brings significant improvements in text/image/audio/video understanding and text/audio generation, and adds new features such as controllable real-time voice conversation and multimodal real-time interaction. The main features of Baichuan-Omni-1.5 include:
36
+
37
+ - 🔥 **Possess Multimodal Understanding and Interaction Capabilities.**
38
+ Baichuan-Omni-1.5 not only accepts image, video, text, and audio inputs and generates high-quality text and speech outputs, but also **supports continuous video and audio streaming and real-time voice interaction with users**. On OmniBench, a comprehensive benchmark for omni-modal understanding, Baichuan-Omni-1.5 reaches a leading level among open-source models and surpasses GPT-4o-mini.
39
+
40
+ - 💪 **Strong Visual Capability.**
41
+ Baichuan-Omni-1.5 scores an average of 73.3 on the OpenCompass leaderboard (covering 10 mainstream multimodal evaluation benchmarks). **At only 7B parameters, it surpasses mainstream commercial closed-source multimodal models such as GPT-4o-mini, Gemini 1.5 Pro, and Claude 3.5 Sonnet in single-image understanding**. Its video understanding also outperforms GPT-4V, Claude 3.5 Sonnet, and open-source omni-modal models.
42
+
43
+ - 🚀 **Leading Medical Image Understanding Capabilities.**
44
+ Baichuan-Omni-1.5 achieves the best performance on GMAI-MMBench and OpenMM-Medical. Using only a 7B LLM, its average score on OpenMM-Medical exceeds Qwen2-VL-72B by about 3 points (83.8% vs. 80.7%).
45
+
46
+ - 🎙 **Excellent Voice Capabilities.**
47
+ Baichuan-Omni-1.5 **supports high-quality, controllable, real-time bilingual (Chinese and English) voice conversations**. It **outperforms GPT-4o-realtime** on speech understanding tasks (such as ASR and STT) and demonstrates **the highest speech generation performance among open-source models** in both semantic and acoustic evaluations of voice conversations.
48
+
49
+ - 🎬 **Powerful Real-world Understanding and Other Features.**
50
+ Baichuan-Omni-1.5 further optimizes many of the visual understanding capabilities of Baichuan-omni. It can process images of any aspect ratio with up to 1.8 million pixels (e.g., 1344x1344); a resizing sketch is given after this list. It scores 68.8 on RealWorldQA, **surpassing commercial closed-source models such as GPT-4o-mini** and recently open-sourced omni-modal models, and scores 85.6/83.6 on the English/Chinese subsets of MMBench, placing it in the first tier of models of comparable size.
51
+
52
+ - 💫 **Provides [🤗 Base Model](https://huggingface.co/baichuan-inc/Baichuan-Omni-1d5-Base) and [🤗 Instruct Model](https://huggingface.co/baichuan-inc/Baichuan-Omni-1d5).**
53
+ Baichuan-Omni-1.5-Base is a high-performance omni-modal foundation model. Building on this base, Baichuan-Omni-1.5 is trained end-to-end on high-quality omni-modal alignment and multimodal instruction data.
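+
+ As a rough illustration of the any-aspect-ratio, ~1.8-million-pixel budget mentioned above, the sketch below downscales an arbitrary image so its pixel count stays within that budget while preserving the aspect ratio. This is not the model's official preprocessing code; the budget constant and helper name are assumptions for illustration only.
+
+ ```python
+ from PIL import Image
+
+ MAX_PIXELS = 1344 * 1344  # assumed budget (~1.8M pixels), per the description above
+
+ def fit_to_pixel_budget(image: Image.Image, max_pixels: int = MAX_PIXELS) -> Image.Image:
+     """Downscale `image` so width * height <= max_pixels, keeping its aspect ratio."""
+     w, h = image.size
+     if w * h <= max_pixels:
+         return image  # already within budget, keep the native resolution
+     scale = (max_pixels / (w * h)) ** 0.5
+     new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
+     return image.resize(new_size, Image.BICUBIC)
+ ```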
54
+
55
+ **Model Architecture**
56
+ <div align="center">
57
+ <img src="https://github.com/baichuan-inc/Baichuan-Omni-1.5/raw/main/assets/train-pipeline.png" width="80%"></img>
58
+
59
+ </div>
60
+
61
+ <br>
62
+
63
+ - **End-to-end Omni-modal Architecture.** We carefully design **multi-stage, end-to-end** progressive training of the encoding/decoding modules for the different modalities, so that the rich knowledge in each modality is fully exploited and the modalities complement one another.
64
+ Notably, the model is trained fully end-to-end with the next-token-prediction (NTP) loss throughout the pre-training stage; a minimal sketch of this objective follows this list.
65
+ - **High-quality Controllable Audio Solution.** Multimodal system prompts have been redesigned to include traditional text system prompts and **speech system prompts** for specifying model sounds. It provides the flexibility to control voice style through text or speech samples at inference time, and supports advanced capabilities such as end-to-end voice cloning and timbre creation.
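+
+ As a minimal sketch of the next-token-prediction (NTP) objective referenced above (a generic illustration, not the project's actual training code; tensor names and the masking convention are assumptions), all modalities are mapped into one interleaved token sequence and the model is optimized with a standard shifted cross-entropy loss:
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def ntp_loss(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
+     """Next-token prediction: predict the token at position t+1 from positions <= t.
+
+     logits: (batch, seq_len, vocab_size) produced by the LLM over interleaved
+             text/image/audio/video token positions.
+     labels: (batch, seq_len) target ids; positions that should not contribute to
+             the loss (e.g. raw visual patches) can be set to `ignore_index`.
+     """
+     shift_logits = logits[:, :-1, :].contiguous()
+     shift_labels = labels[:, 1:].contiguous()
+     return F.cross_entropy(
+         shift_logits.view(-1, shift_logits.size(-1)),
+         shift_labels.view(-1),
+         ignore_index=ignore_index,
+     )
+ ```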
66
+
67
+
68
+ ### Open-source Evaluation Datasets
69
+
70
+ **OpenMM-Medical**
71
+
72
+ To comprehensively evaluate the model's multi-modal medical capabilities, we have constructed OpenMM-Medical, which includes data from 42 publicly available medical image datasets such as ACRIMA (retinal images), BioMediTech (microscope images), and CoronaHack (X-rays), totaling 88,996 images.
73
+
74
+ **OpenAudioBench**
75
+
76
+ To efficiently assess the model's reasoning ("IQ"), we developed OpenAudioBench, comprising five end-to-end audio understanding sub-datasets: four public benchmarks (Llama Questions, Web Questions, TriviaQA, AlpacaEval) and a speech logical-reasoning dataset created internally by the Baichuan team, totaling 2,701 entries. Together, these reflect the model's overall reasoning ability.
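+
+ Both evaluation sets are hosted on the Hugging Face Hub (links above). A minimal sketch for pulling them with the `datasets` library is shown below; the exact splits and field names depend on the dataset repositories, so inspect the returned objects before use.
+
+ ```python
+ from datasets import load_dataset
+
+ # Repository ids are taken from the links above.
+ openmm_medical = load_dataset("baichuan-inc/OpenMM-Medical")
+ openaudio_bench = load_dataset("baichuan-inc/OpenAudioBench")
+
+ print(openmm_medical)   # inspect available splits and features
+ print(openaudio_bench)
+ ```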
77
+
78
+ <!-- **High-quality Medical Image Evaluation Dataset--Openmm-Medical**
79
+
80
+ - We have built a more diverse medical evaluation dataset named **Openmm-Medical** to evaluate large models in medical scenarios.
81
+ - The images in Openmm-Medical come from **42 public medical image datasets**, such as ACRIMA (fundus images), BioMediTech (microscope images), and CoronaHack (X-rays).
82
+ - **Openmm-Medical contains a total of 88,996 images**, and each image is designed as a **multiple-choice question to facilitate the evaluation of different large models.**
83
+ - To promote the development of omnimodal large models in the medical field, we will soon **open** this evaluation dataset.
84
+ -->
85
+
86
+ ### Evaluation
87
+
88
+ We suggest that readers refer to our [**GitHub**](https://github.com/baichuan-inc/Baichuan-Omni-1.5/) repository for more details.
89
+
90
+ <div align="center">
91
+ <img src="https://github.com/baichuan-inc/Baichuan-Omni-1.5/raw/main/assets/performance.png" width="80%">
92
+ </div>
93
+
94
+ <br>
95
+
96
+ <details>
97
+
98
+ <summary>click to view</summary>
99
+
100
+ #### Pure Text Understanding
101
+ <div align="center">
102
+ <table style="margin: 0 auto; text-align: center;">
103
+ <thead>
104
+ <tr>
105
+ <th class="tg-c3ow" colspan="7">Comprehensive Tasks</th>
106
+ </tr>
107
+ </thead>
108
+ <tbody>
109
+ <tr>
110
+ <td>Model</td>
111
+ <td>Size</td>
112
+ <td>MMLU (Acc.)</td>
113
+ <td>CMMLU (Acc.)</td>
114
+ <td>AGIEval (Acc.)</td>
115
+ <td>C-Eval (Acc.)</td>
116
+ <td>GAOKAO (Acc.)</td>
117
+ </tr>
118
+ <tr>
119
+ <td colspan="7">Proprietary Models</td>
120
+ </tr>
121
+ <tr>
122
+ <td>GPT 4o</td>
123
+ <td>-</td>
124
+ <td><b>88.0♢<br></td>
125
+ <td><b>78.3♢<br></td>
126
+ <td><b>62.3♢<br></td>
127
+ <td><b>86.0♢<br></td>
128
+ <td>-</td>
129
+ </tr>
130
+ <tr>
131
+ <td>GPT 4o mini</td>
132
+ <td>-</td>
133
+ <td>82.0</td>
134
+ <td>67.6</td>
135
+ <td>52.2</td>
136
+ <td>63.6</td>
137
+ <td>70.8</td>
138
+ </tr>
139
+ <tr>
140
+ <td colspan="7">Open-source Models (Pure text)</td>
141
+ </tr>
142
+ <tr>
143
+ <td>MAP-Neo</td>
144
+ <td>7B</td>
145
+ <td>58.2</td>
146
+ <td>55.1</td>
147
+ <td>33.9</td>
148
+ <td>57.5</td>
149
+ <td>-</td>
150
+ </tr>
151
+ <tr>
152
+ <td>Qwen1.5-Chat</td>
153
+ <td>7B</td>
154
+ <td>61.5</td>
155
+ <td>68.0</td>
156
+ <td>39.3</td>
157
+ <td>68.8</td>
158
+ <td>-</td>
159
+ </tr>
160
+ <tr>
161
+ <td>Llama3-Instruct</td>
162
+ <td>8B</td>
163
+ <td>67.1</td>
164
+ <td>51.7</td>
165
+ <td>38.4</td>
166
+ <td>50.7</td>
167
+ <td>-</td>
168
+ </tr>
169
+ <tr>
170
+ <td>OLMo</td>
171
+ <td>7B</td>
172
+ <td>28.4</td>
173
+ <td>25.6</td>
174
+ <td>19.9</td>
175
+ <td>27.3</td>
176
+ <td>-</td>
177
+ </tr>
178
+ <tr>
179
+ <td colspan="7">Open-source Models (Omni-modal)</td>
180
+ </tr>
181
+ <tr>
182
+ <td>VITA</td>
183
+ <td>8x7B</td>
184
+ <td>71.0*</td>
185
+ <td>46.6</td>
186
+ <td>46.2*</td>
187
+ <td>56.7*</td>
188
+ <td>-</td>
189
+ </tr>
190
+ <tr>
191
+ <td>VITA-1.5</td>
192
+ <td>7B</td>
193
+ <td>71.0</td>
194
+ <td>75.1</td>
195
+ <td>47.9</td>
196
+ <td>65.6</td>
197
+ <td>57.4</td>
198
+ </tr>
199
+ <tr>
200
+ <td>Baichuan-Omni</td>
201
+ <td>7B</td>
202
+ <td>65.3</td>
203
+ <td>72.2</td>
204
+ <td>47.7</td>
205
+ <td>68.9</td>
206
+ <td>-</td>
207
+ </tr>
208
+ <tr>
209
+ <td>MiniCPM-o 2.6</td>
210
+ <td>7B</td>
211
+ <td>65.3</td>
212
+ <td>63.3</td>
213
+ <td>50.9</td>
214
+ <td>61.5</td>
215
+ <td>56.3</td>
216
+ </tr>
217
+ <tr>
218
+ <td><b>Baichuan-Omni-1.5<br></td>
219
+ <td>7B</td>
220
+ <td>72.2</td>
221
+ <td>75.5</td>
222
+ <td>54.4</td>
223
+ <td>73.1</td>
224
+ <td><b>73.5<br></td>
225
+ </tr>
226
+ </tbody>
227
+ </table>
228
+ </div>
229
+
230
+ </details>
231
+
232
+
233
+ <details>
234
+
235
+ <summary>click to view</summary>
236
+
237
+ #### Image Understanding
238
+
239
+ <div align="center">
240
+ <table style="margin: 0 auto; text-align: center;">
241
+ <thead>
242
+ <tr>
243
+ <th class="tg-c3ow" colspan="9">Multi-choice &amp; Yes-or-No Question</th>
244
+ </tr>
245
+ </thead>
246
+ <tbody>
247
+ <tr>
248
+ <td>Model</td>
249
+ <td>Size</td>
250
+ <td>MMBench-EN (Acc.)</td>
251
+ <td>MMBench-CN (Acc.)</td>
252
+ <td>SEED-IMG (Acc.)</td>
253
+ <td>MMMU-val (Acc.)</td>
254
+ <td>HallusionBench (Acc.)</td>
255
+ </tr>
256
+ <tr>
257
+ <td colspan="9">Proprietary Models</td>
258
+ </tr>
259
+ <tr>
260
+ <td>GPT-4o</td>
261
+ <td>-</td>
262
+ <td>83.4♢</td>
263
+ <td>82.1♢</td>
264
+ <td>-</td>
265
+ <td><b>69.1♢<br></td>
266
+ <td><b>55.0♢<br></td>
267
+ </tr>
268
+ <tr>
269
+ <td>GPT-4o-mini</td>
270
+ <td>-</td>
271
+ <td>77.7</td>
272
+ <td>76.9</td>
273
+ <td>72.3</td>
274
+ <td>60.0♢</td>
275
+ <td>46.1♢</td>
276
+ </tr>
277
+ <tr>
278
+ <td colspan="9">Open Source Models (Vision-Language)</td>
279
+ </tr>
280
+ <tr>
281
+ <td>Qwen2-VL-7B</td>
282
+ <td>7B</td>
283
+ <td><b>86.4<br></td>
284
+ <td>81.9</td>
285
+ <td><b>76.5<br></td>
286
+ <td>52.7</td>
287
+ <td>50.6∗</td>
288
+ </tr>
289
+ <tr>
290
+ <td>MiniCPM-Llama3-V 2.5</td>
291
+ <td>8B</td>
292
+ <td>76.7</td>
293
+ <td>73.3</td>
294
+ <td>72.4</td>
295
+ <td>45.8∗</td>
296
+ <td>42.5</td>
297
+ </tr>
298
+ <tr>
299
+ <td colspan="9">Open Source Models (Omni-modal)</td>
300
+ </tr>
301
+ <tr>
302
+ <td>VITA</td>
303
+ <td>8x7B</td>
304
+ <td>74.7</td>
305
+ <td>71.4</td>
306
+ <td>72.6</td>
307
+ <td>45.3</td>
308
+ <td>39.7∗</td>
309
+ </tr>
310
+ <tr>
311
+ <td>VITA-1.5</td>
312
+ <td>7B</td>
313
+ <td>80.8</td>
314
+ <td>80.2</td>
315
+ <td>74.2</td>
316
+ <td>53.1</td>
317
+ <td>44.1</td>
318
+ </tr>
319
+ <tr>
320
+ <td>Baichuan-Omni</td>
321
+ <td>7B</td>
322
+ <td>76.2</td>
323
+ <td>74.9</td>
324
+ <td>74.1</td>
325
+ <td>47.3</td>
326
+ <td>47.8</td>
327
+ </tr>
328
+ <tr>
329
+ <td>MiniCPM-o 2.6</td>
330
+ <td>7B</td>
331
+ <td>83.6</td>
332
+ <td>81.8</td>
333
+ <td>75.4</td>
334
+ <td>51.1</td>
335
+ <td>50.1</td>
336
+ </tr>
337
+ <tr>
338
+ <td><b>Baichuan-Omni-1.5<br></td>
339
+ <td>7B</td>
340
+ <td>85.6</td>
341
+ <td><b>83.6<br></td>
342
+ <td>75.7</td>
343
+ <td>53.9</td>
344
+ <td>49.7</td>
345
+ </tr>
346
+ </tbody>
347
+ </table>
348
+ </div>
349
+
350
+
351
+ <br>
352
+
353
+ <div align="center">
354
+ <table style="margin: 0 auto; text-align: center;">
355
+ <thead>
356
+ <tr>
357
+ <th class="tg-c3ow" colspan="9">Visual Question Answering</th>
358
+ </tr>
359
+ </thead>
360
+ <tbody>
361
+ <tr>
362
+ <td>Model</td>
363
+ <td>Size</td>
364
+ <td>RealWorldQA (Acc.)</td>
365
+ <td>MathVista-mini (Acc.)</td>
366
+ <td>TextVQA-val (Acc.)</td>
367
+ <td>ChartQA (Acc.)</td>
368
+ <td>OCRBench (Acc.)</td>
369
+ </tr>
370
+ <tr>
371
+ <td colspan="8">Proprietary Models</td>
372
+ </tr>
373
+ <tr>
374
+ <td>GPT-4o</td>
375
+ <td>-</td>
376
+ <td><b>75.4♢<br></td>
377
+ <td>63.8♢</td>
378
+ <td>-</td>
379
+ <td>85.7♢</td>
380
+ <td>73.6♢</td>
381
+ </tr>
382
+ <tr>
383
+ <td>GPT-4o-mini</td>
384
+ <td>-</td>
385
+ <td>66.3</td>
386
+ <td>53.4</td>
387
+ <td>66.8</td>
388
+ <td>-</td>
389
+ <td>77.4</td>
390
+ </tr>
391
+ <tr>
392
+ <td colspan="8">Open Source Models (Vision-Language)</td>
393
+ </tr>
394
+ <tr>
395
+ <td>Qwen2-VL-7B</td>
396
+ <td>7B</td>
397
+ <td>69.7</td>
398
+ <td>58.2∗</td>
399
+ <td><b>84.3∗<br></td>
400
+ <td>83.0∗</td>
401
+ <td>84.5∗</td>
402
+ </tr>
403
+ <tr>
404
+ <td>MiniCPM-Llama3-V 2.5</td>
405
+ <td>8B</td>
406
+ <td>63.5</td>
407
+ <td>54.3∗</td>
408
+ <td>76.6</td>
409
+ <td>72.0</td>
410
+ <td>72.5</td>
411
+ </tr>
412
+ <tr>
413
+ <td colspan="8">Open Source Models (Omni-modal)</td>
414
+ </tr>
415
+ <tr>
416
+ <td>VITA</td>
417
+ <td>8x7B</td>
418
+ <td>59.0</td>
419
+ <td>44.9∗</td>
420
+ <td>71.8</td>
421
+ <td>76.6</td>
422
+ <td>68.5∗</td>
423
+ </tr>
424
+ <tr>
425
+ <td>VITA-1.5</td>
426
+ <td>7B</td>
427
+ <td>66.8</td>
428
+ <td><b>66.5<br></td>
429
+ <td>74.9</td>
430
+ <td>79.6</td>
431
+ <td>73.3</td>
432
+ </tr>
433
+ <tr>
434
+ <td>Baichuan-Omni</td>
435
+ <td>7B</td>
436
+ <td>62.6</td>
437
+ <td>51.9</td>
438
+ <td>74.3</td>
439
+ <td>79.6</td>
440
+ <td>70.0</td>
441
+ </tr>
442
+ <tr>
443
+ <td>MiniCPM-o 2.6</td>
444
+ <td>7B</td>
445
+ <td>67.7</td>
446
+ <td>64.6</td>
447
+ <td>80.1</td>
448
+ <td><b>87.6<br></td>
449
+ <td><b>89.7∗<br></td>
450
+ </tr>
451
+ <tr>
452
+ <td>Baichuan-Omni-1.5 </td>
453
+ <td>7B</td>
454
+ <td>68.8</td>
455
+ <td>63.6</td>
456
+ <td>83.2</td>
457
+ <td>84.9</td>
458
+ <td>84.0</td>
459
+ </tr>
460
+ </tbody>
461
+ </table>
462
+ </div>
463
+
464
+
465
+ </details>
466
+
467
+ <details>
468
+
469
+ <summary>click to view</summary>
470
+
471
+ #### Video Understanding
472
+ <div align="center">
473
+ <table style="margin: 0 auto; text-align: center;">
474
+ <thead>
475
+ <tr>
476
+ <th colspan="7">General VQA&nbsp;&nbsp;&nbsp;</th>
477
+ </tr>
478
+ </thead>
479
+ <tbody>
480
+ <tr>
481
+ <td>Model</td>
482
+ <td>Size</td>
483
+ <td># Frames</td>
484
+ <td>MVBench (Acc.)</td>
485
+ <td>Egoschema (Acc.)</td>
486
+ <td>VideoMME (Acc.)</td>
487
+ <td>Perception-Test (Acc.)</td>
488
+ </tr>
489
+ <tr>
490
+ <td colspan="7">Proprietary Models</td>
491
+ </tr>
492
+ <tr>
493
+ <td>Gemini 1.5 Pro</td>
494
+ <td>-</td>
495
+ <td>-</td>
496
+ <td><b>81.3♢<br></td>
497
+ <td>63.2*</td>
498
+ <td><b>75.0♢<br></td>
499
+ <td>-</td>
500
+ </tr>
501
+ <tr>
502
+ <td>GPT 4o mini</td>
503
+ <td>-</td>
504
+ <td>-</td>
505
+ <td>55.2</td>
506
+ <td>58.5</td>
507
+ <td>63.6</td>
508
+ <td>48.2</td>
509
+ </tr>
510
+ <tr>
511
+ <td>GPT 4o</td>
512
+ <td>-</td>
513
+ <td>-</td>
514
+ <td>-</td>
515
+ <td><b>77.2*<br></td>
516
+ <td>71.9♢</td>
517
+ <td>-</td>
518
+ </tr>
519
+ <tr>
520
+ <td>GPT 4V</td>
521
+ <td>-</td>
522
+ <td>-</td>
523
+ <td>43.7♢</td>
524
+ <td>55.6*</td>
525
+ <td>59.9♢</td>
526
+ <td>-</td>
527
+ </tr>
528
+ <tr>
529
+ <td colspan="7">Open-source Models (Vision-language)</td>
530
+ </tr>
531
+ <tr>
532
+ <td>Qwen2-VL-7B</td>
533
+ <td>7B</td>
534
+ <td>2 fps (max 768)</td>
535
+ <td>67.0* | 64.4</td>
536
+ <td>66.7* | 66.6</td>
537
+ <td>63.3* | 59.0</td>
538
+ <td>62.3* | 60.3</td>
539
+ </tr>
540
+ <tr>
541
+ <td>AnyGPT</td>
542
+ <td>8B</td>
543
+ <td>48</td>
544
+ <td>33.2</td>
545
+ <td>32.1</td>
546
+ <td>29.8</td>
547
+ <td>29.1</td>
548
+ </tr>
549
+ <tr>
550
+ <td>VideoLLaMA 2</td>
551
+ <td>7B</td>
552
+ <td>16</td>
553
+ <td>54.6*</td>
554
+ <td>51.7*</td>
555
+ <td>46.6*</td>
556
+ <td>51.4*</td>
557
+ </tr>
558
+ <tr>
559
+ <td>VideoChat2</td>
560
+ <td>7B</td>
561
+ <td>16</td>
562
+ <td>51.1*</td>
563
+ <td>42.1♢</td>
564
+ <td>33.7♢</td>
565
+ <td>47.3♢</td>
566
+ </tr>
567
+ <tr>
568
+ <td>LLaVA-NeXT-Video</td>
569
+ <td>7B</td>
570
+ <td>32</td>
571
+ <td>46.5♢</td>
572
+ <td>43.9♢</td>
573
+ <td>33.7♢</td>
574
+ <td>48.8♢</td>
575
+ </tr>
576
+ <tr>
577
+ <td>Video-LLaVA</td>
578
+ <td>7B</td>
579
+ <td>8</td>
580
+ <td>41.0♢</td>
581
+ <td>38.4♢</td>
582
+ <td>39.9♢</td>
583
+ <td>44.3♢</td>
584
+ </tr>
585
+ <tr>
586
+ <td colspan="7">Open-source Models (Omni-modal)</td>
587
+ </tr>
588
+ <tr>
589
+ <td>VITA</td>
590
+ <td>8x7B</td>
591
+ <td>1 fps (max 32)</td>
592
+ <td>53.4</td>
593
+ <td>53.9</td>
594
+ <td>56.1</td>
595
+ <td>56.2</td>
596
+ </tr>
597
+ <tr>
598
+ <td>VITA-1.5</td>
599
+ <td>7B</td>
600
+ <td>1 fps (max 32)</td>
601
+ <td>55.5</td>
602
+ <td>54.7</td>
603
+ <td>57.3</td>
604
+ <td>57.6</td>
605
+ </tr>
606
+ <tr>
607
+ <td>Baichuan-Omni</td>
608
+ <td>7B</td>
609
+ <td>1 fps (max 32)</td>
610
+ <td>60.9</td>
611
+ <td>58.8</td>
612
+ <td>58.2</td>
613
+ <td>56.8</td>
614
+ </tr>
615
+ <tr>
616
+ <td>MiniCPM-o 2.6</td>
617
+ <td>7B</td>
618
+ <td>1 fps (max 64)</td>
619
+ <td>58.6</td>
620
+ <td>50.7</td>
621
+ <td>63.4</td>
622
+ <td>66.6</td>
623
+ </tr>
624
+ <tr>
625
+ <td>Baichuan-Omni-1.5</td>
626
+ <td>7B</td>
627
+ <td>1 fps (max 32)</td>
628
+ <td> 63.7 </td>
629
+ <td> 62.4 </td>
630
+ <td> 60.1 </td>
631
+ <td> <b>68.9 <br> </td>
632
+ </tr>
633
+ </tbody>
634
+ </table>
635
+ </div>
636
+
637
+ <br>
638
+
639
+ <div align="center">
640
+ <table style="margin: 0 auto; text-align: center;">
641
+ <thead>
642
+ <tr>
643
+ <th colspan="7">Open-ended VQA</th>
644
+ </tr>
645
+ </thead>
646
+ <tbody>
647
+ <tr>
648
+ <td rowspan="2">Model</td>
649
+ <td rowspan="2">Size</td>
650
+ <td rowspan="2"># Frames</td>
651
+ <td colspan="2">ActivityNet-QA</td>
652
+ <td colspan="2">MSVD-QA</td>
653
+ </tr>
654
+ <tr>
655
+ <td>(Acc.)</td>
656
+ <td>(Score)</td>
657
+ <td>(Acc.)</td>
658
+ <td>(Score)</td>
659
+ </tr>
660
+ <tr>
661
+ <td colspan="7">Proprietary Models</td>
662
+ </tr>
663
+ <tr>
664
+ <td>Gemini 1.5 Pro</td>
665
+ <td>-</td>
666
+ <td>-</td>
667
+ <td>56.7*</td>
668
+ <td>-</td>
669
+ <td>-</td>
670
+ <td>-</td>
671
+ </tr>
672
+ <tr>
673
+ <td>GPT 4o mini</td>
674
+ <td>-</td>
675
+ <td>1 fps (max 32)</td>
676
+ <td>62.1</td>
677
+ <td>3.1</td>
678
+ <td>67.5</td>
679
+ <td>3.3</td>
680
+ </tr>
681
+ <tr>
682
+ <td>GPT 4o</td>
683
+ <td>-</td>
684
+ <td>-</td>
685
+ <td>61.9*</td>
686
+ <td>-</td>
687
+ <td>-</td>
688
+ <td>-</td>
689
+ </tr>
690
+ <tr>
691
+ <td>GPT 4V</td>
692
+ <td>-</td>
693
+ <td>-</td>
694
+ <td>59.5*</td>
695
+ <td>-</td>
696
+ <td>-</td>
697
+ <td>-</td>
698
+ </tr>
699
+ <tr>
700
+ <td colspan="7">Open-source Models (Vision-language)</td>
701
+ </tr>
702
+ <tr>
703
+ <td>Qwen2 VL</td>
704
+ <td>7B</td>
705
+ <td>2 fps (max 768)</td>
706
+ <td>17.4</td>
707
+ <td>1.9</td>
708
+ <td>61.1</td>
709
+ <td>3.5</td>
710
+ </tr>
711
+ <tr>
712
+ <td>VideoLLaMA 2</td>
713
+ <td>7B</td>
714
+ <td>16</td>
715
+ <td>50.2*</td>
716
+ <td>3.3*</td>
717
+ <td>70.9*</td>
718
+ <td>3.8*</td>
719
+ </tr>
720
+ <tr>
721
+ <td>VideoChat2</td>
722
+ <td>7B</td>
723
+ <td>16</td>
724
+ <td>49.1*</td>
725
+ <td>3.3*</td>
726
+ <td>70.0*</td>
727
+ <td>3.9*</td>
728
+ </tr>
729
+ <tr>
730
+ <td>LLaVA-NeXT-Video</td>
731
+ <td>7B</td>
732
+ <td>32</td>
733
+ <td>53.5*</td>
734
+ <td>3.2*</td>
735
+ <td>67.4</td>
736
+ <td>3.4</td>
737
+ </tr>
738
+ <tr>
739
+ <td>Video-LLaVA</td>
740
+ <td>7B</td>
741
+ <td>8</td>
742
+ <td>45.3*</td>
743
+ <td>3.3*</td>
744
+ <td>70.7*</td>
745
+ <td>3.9*</td>
746
+ </tr>
747
+ <tr>
748
+ <td colspan="7">Open-source Models (Omni-modal)</td>
749
+ </tr>
750
+ <tr>
751
+ <td>VITA</td>
752
+ <td>8x7B</td>
753
+ <td>1 fps (max 32)</td>
754
+ <td>55.0</td>
755
+ <td>3.5</td>
756
+ <td>63.9</td>
757
+ <td>3.7</td>
758
+ </tr>
759
+ <tr>
760
+ <td>VITA-1.5</td>
761
+ <td>7B</td>
762
+ <td>1 fps (max 32)</td>
763
+ <td>59.6</td>
764
+ <td>3.0</td>
765
+ <td>67.6</td>
766
+ <td>3.3</td>
767
+ </tr>
768
+ <tr>
769
+ <td>Baichuan-Omni</td>
770
+ <td>7B</td>
771
+ <td>1 fps (max 48)</td>
772
+ <td>58.6</td>
773
+ <td><b>3.7<br></td>
774
+ <td>72.2</td>
775
+ <td> <b>4.0<br> </td>
776
+ </tr>
777
+ <tr>
778
+ <td>MiniCPM-o 2.6</td>
779
+ <td>7B</td>
780
+ <td>1 fps (max 64)</td>
781
+ <td><b>63.0<br></td>
782
+ <td>3.1</td>
783
+ <td>73.7</td>
784
+ <td>3.6</td>
785
+ </tr>
786
+ <tr>
787
+ <td>Baichuan-Omni-1.5</td>
788
+ <td>7B</td>
789
+ <td>1 fps (max 48)</td>
790
+ <td> 62.0</td>
791
+ <td> 3.1</td>
792
+ <td> <b> 74.2 <br></td>
793
+ <td> 3.6</td>
794
+ </tr>
795
+ </tbody>
796
+ </table>
797
+ </div>
798
+
799
+ </details>
800
+
801
+
802
+ <details>
803
+
804
+ <summary>click to view</summary>
805
+
806
+ #### Audio Comprehension and Speech Generation
807
+ <div align="center">
808
+ <table style="margin: 0 auto; text-align: center;">
809
+ <thead>
810
+ <tr>
811
+ <th colspan="12">Audio Comprehension Capability</th>
812
+ </tr></thead>
813
+ <tbody>
814
+ <tr>
815
+ <td rowspan="2">Model</td>
816
+ <td rowspan="2">Size</td>
817
+ <td colspan="2">Reasoning QA</td>
818
+ <td colspan="2">Llama Questions</td>
819
+ <td colspan="2">Web Questions</td>
820
+ <td colspan="2">TriviaQA</td>
821
+ <td colspan="2">AlpacaEval</td>
822
+ </tr>
823
+ <tr>
824
+ <td>s→t</td>
825
+ <td>s→s</td>
826
+ <td>s→t</td>
827
+ <td>s→s</td>
828
+ <td>s→t</td>
829
+ <td>s→s</td>
830
+ <td>s→t</td>
831
+ <td>s→s</td>
832
+ <td>s→t</td>
833
+ <td>s→s</td>
834
+ </tr>
835
+ <tr>
836
+ <td colspan="12">Proprietary Models</td>
837
+ </tr>
838
+ <tr>
839
+ <td>GPT-4o-Audio</td>
840
+ <td>-</td>
841
+ <td><b>55.6</td>
842
+ <td>-</td>
843
+ <td><b>88.4</td>
844
+ <td>-</td>
845
+ <td><b>8.10</td>
846
+ <td>-</td>
847
+ <td><b>9.06</td>
848
+ <td>-</td>
849
+ <td><b>8.01</td>
850
+ <td>-</td>
851
+ </tr>
852
+ <tr>
853
+ <td colspan="12">Open-source Models (Pure Audio)</td>
854
+ </tr>
855
+ <tr>
856
+ <td>GLM-4-Voice</td>
857
+ <td>9B</td>
858
+ <td>-</td>
859
+ <td>26.5</td>
860
+ <td>-</td>
861
+ <td>71.0</td>
862
+ <td>-</td>
863
+ <td>5.15</td>
864
+ <td>-</td>
865
+ <td>4.66</td>
866
+ <td>-</td>
867
+ <td>4.89</td>
868
+ </tr>
869
+ <tr>
870
+ <td colspan="12">Open-source Models (Omni-modal)</td>
871
+ </tr>
872
+ <tr>
873
+ <td>VITA-1.5</td>
874
+ <td>7B</td>
875
+ <td>41.0</td>
876
+ <td>-</td>
877
+ <td>74.2</td>
878
+ <td>-</td>
879
+ <td>5.73</td>
880
+ <td>-</td>
881
+ <td>4.68</td>
882
+ <td>-</td>
883
+ <td>6.82</td>
884
+ <td>-</td>
885
+ </tr>
886
+ <tr>
887
+ <td>MiniCPM-o 2.6</td>
888
+ <td>7B</td>
889
+ <td>38.6</td>
890
+ <td>-</td>
891
+ <td>77.8</td>
892
+ <td>-</td>
893
+ <td>6.86</td>
894
+ <td>-</td>
895
+ <td>6.19</td>
896
+ <td>-</td>
897
+ <td>5.18</td>
898
+ <td>-</td>
899
+ </tr>
900
+ <tr>
901
+ <td><b>Baichuan-Omni-1.5</td>
902
+ <td>7B</td>
903
+ <td>50.0</td>
904
+ <td><b>40.9</td>
905
+ <td>78.5</td>
906
+ <td><b>75.3</td>
907
+ <td>5.91</td>
908
+ <td><b>5.52</td>
909
+ <td>5.72</td>
910
+ <td>5.31</td>
911
+ <td>7.79</td>
912
+ <td><b>6.94</td>
913
+ </tr>
914
+ </tbody>
915
+ </table>
916
+ </div>
917
+
918
+
919
+ </details>
920
+
921
+
922
+
923
+ <details>
924
+
925
+ <summary>click to view</summary>
926
+
927
+ #### Omni-modal Understanding
928
+
929
+ <div align="center">
930
+ <table style="margin: 0 auto; text-align: center;">
931
+ <thead>
932
+ <tr>
933
+ <th colspan="7">Omni-Understanding</th>
934
+ </tr>
935
+ </thead>
936
+ <tbody>
937
+ <tr>
938
+ <td>Model</td>
939
+ <td>Size</td>
940
+ <td>Image & Audio</td>
941
+ <td>Image Caption & Audio</td>
942
+ <td>Image & Audio Transcript</td>
943
+ <td>Image Caption & Audio Transcript</td>
944
+ </tr>
945
+ </thead>
946
+ <tr>
947
+ <td colspan="6">Proprietary Models</td>
948
+ </tr>
949
+ <tr>
950
+ <td>GPT4o-mini</td>
951
+ <td>-</td>
952
+ <td>-</td>
953
+ <td>-</td>
954
+ <td>37.0</td>
955
+ <td>37.7</td>
956
+ </tr>
957
+ <tr>
958
+ <td colspan="6">Open-source Models (Omni-modal)</td>
959
+ </tr>
960
+ <tr>
961
+ <td>VITA</td>
962
+ <td>8x7B</td>
963
+ <td>33.1</td>
964
+ <td>31.8</td>
965
+ <td>42.0</td>
966
+ <td>44.2</td>
967
+ </tr>
968
+ <tr>
969
+ <td>VITA-1.5</td>
970
+ <td>7B</td>
971
+ <td>33.4</td>
972
+ <td>29.6</td>
973
+ <td>48.5</td>
974
+ <td><b>47.2<br></td>
975
+ </tr>
976
+ <tr>
977
+ <td>Baichuan-Omni</td>
978
+ <td>7B</td>
979
+ <td>32.2</td>
980
+ <td>26.5</td>
981
+ <td>42.6</td>
982
+ <td>44.2</td>
983
+ </tr>
984
+ <tr>
985
+ <td>MiniCPM-o 2.6</td>
986
+ <td>7B</td>
987
+ <td>40.5</td>
988
+ <td>30.8</td>
989
+ <td><b>53.2<br></td>
990
+ <td>46.3</td>
991
+ </tr>
992
+ <tr>
993
+ <td><b>Baichuan-Omni-1.5<br></td>
994
+ <td>7B</td>
995
+ <td><b>42.9<br></td>
996
+ <td><b>37.7<br></td>
997
+ <td>47.9</td>
998
+ <td>46.9</td>
999
+ </tr>
1000
+ </tbody>
1001
+ </table>
1002
+ </div>
1003
+
1004
+ </details>
1005
+
1006
+ <details>
1007
+
1008
+ <summary>click to view</summary>
1009
+
1010
+ #### Medical Image Understanding Capabilities
1011
+
1012
+ <div align="center">
1013
+ <table style="margin: 0 auto; text-align: center;">
1014
+ <thead>
1015
+ <tr>
1016
+ <th colspan="7">Medical Understanding&nbsp;&nbsp;&nbsp;</th>
1017
+ </tr>
1018
+ </thead>
1019
+ <tbody>
1020
+ <tr>
1021
+ <td>Model</td>
1022
+ <td>Size</td>
1023
+ <td>GMAI-MMB-VAL (Acc.)</td>
1024
+ <td>OpenMM-Medical (Acc.)</td>
1025
+ </tr>
1027
+ <tr>
1028
+ <td colspan="4">Proprietary Models</td>
1029
+ </tr>
1030
+ <tr>
1031
+ <td>GPT4o-mini</td>
1032
+ <td>-</td>
1033
+ <td>46.4</td>
1034
+ <td>74.3</td>
1035
+ </tr>
1036
+ <tr>
1037
+ <td colspan="4">Open-source Models (Vision-Language)</td>
1038
+ </tr>
1039
+ <tr>
1040
+ <td>Qwen2 VL</td>
1041
+ <td>7B</td>
1042
+ <td>46.3</td>
1043
+ <td>76.9</td>
1044
+ </tr>
1045
+ <tr>
1046
+ <td>Qwen2 VL</td>
1047
+ <td>72B</td>
1048
+ <td><b>50.7<br></td>
1049
+ <td>80.7</td>
1050
+ </tr>
1051
+ <tr>
1052
+ <td colspan="4">Open-source Models (Omni-modal)</td>
1053
+ </tr>
1054
+ <tr>
1055
+ <td>VITA-1.5</td>
1056
+ <td>7B</td>
1057
+ <td>36.7</td>
1058
+ <td>67.1</td>
1059
+ </tr>
1060
+ <tr>
1061
+ <td>MiniCPM-o 2.6</td>
1062
+ <td>7B</td>
1063
+ <td>41.5</td>
1064
+ <td>73.6</td>
1065
+ </tr>
1066
+ <tr>
1067
+ <td><b>Baichuan-Omni-1.5<br></td>
1068
+ <td>7B</td>
1069
+ <td>49.9</td>
1070
+ <td><b>83.8<br></td>
1071
+ </tr>
1072
+ </tbody>
1073
+ </table>
1074
+ </div>
1075
+
1076
+ </details>
1077
+
1078
+ ## Examples
1079
+ <br>
1080
+
1081
+ <div style="display: flex; flex-direction: column; align-items: center;">
1082
+ <img src="https://github.com/baichuan-inc/Baichuan-Omni-1.5/raw/main/assets/pipeline.png" alt="pipeline" style="margin-bottom: 5px;">
1083
+ <img src="https://github.com/baichuan-inc/Baichuan-Omni-1.5/raw/main/assets/math.png" alt="math" style="margin-bottom: 5px;">
1084
+ <img src="https://github.com/baichuan-inc/Baichuan-Omni-1.5/raw/main/assets/fly_bill.png" alt="fly_bill" style="margin-bottom: 5px;">
1085
+ </div>
1086
+
1087
+
1088
+ ## 🚀 Quick Start
1089
+ We recommend visiting our [**GitHub**](https://github.com/baichuan-inc/Baichuan-Omni-1.5/) repository for full installation and usage details.
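+
+ For orientation only, below is a minimal loading sketch that assumes the standard `transformers` remote-code path (the checkpoint ships custom modeling code, hence `trust_remote_code=True`). The multimodal chat/generation interface is defined by the repository's custom code, so please follow the GitHub examples for actual inference.
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model_id = "baichuan-inc/Baichuan-Omni-1d5"  # model id from the links above
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     torch_dtype=torch.bfloat16,   # assumption: bf16 weights on a GPU
+     device_map="auto",
+     trust_remote_code=True,
+ )
+ # For image/audio/video inputs, use the processors and chat helpers from the
+ # GitHub repository; the plain-text loading above is only a starting point.
+ ```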
1090
+
1091
+
1092
+ ### Statement
1093
+ - We hereby declare that our team has not developed any applications based on the Baichuan-Omni-1.5/Baichuan-Omni-1.5-Base models, whether on iOS, Android, the web, or any other platform. We strongly urge all users not to use the Baichuan-Omni-1.5/Baichuan-Omni-1.5-Base models for any activities that harm national or social security or violate the law. We also ask users not to use these models for Internet services that have not undergone appropriate security review and filing. We hope all users abide by these principles so that technological development proceeds in a regulated and lawful environment.
1094
+
1095
+ - We have done our best to ensure the compliance of the data used during model training. However, despite considerable effort, unforeseen issues may still arise given the complexity of the model and data. Therefore, we will not assume any responsibility for problems arising from the use of the Baichuan-Omni-1.5/Baichuan-Omni-1.5-Base open-source models, including but not limited to data security issues, public-opinion risks, or any risks and problems caused by the models being misled, abused, disseminated, or improperly exploited.
1096
+
1097
+
1098
+
1099
+ ### License
1100
+ Community use of Baichuan-Omni-1.5/Baichuan-Omni-1.5-Base requires adherence to [Apache 2.0](https://github.com/baichuan-inc/Baichuan-Omni-1.5/blob/main/LICENSE) and the [Community License for Baichuan-Omni-1.5 Models](https://github.com/baichuan-inc/Baichuan-Omni-1.5/blob/main/LICENSE). The Baichuan-Omni-1.5/Baichuan-Omni-1.5-Base models support commercial use. If you plan to use these models or their derivatives for commercial purposes, please ensure that your entity meets the following conditions:
1101
+
1102
+ 1. The Daily Active Users (DAU) of your or your affiliate's service or product is less than 1 million.
1103
+ 2. Neither you nor your affiliates are software service providers or cloud service providers.
1104
+ 3. Neither you nor your affiliates may sublicense or otherwise transfer the commercial license granted to you to any third party without Baichuan's permission.
1105
+
1106
+ Upon meeting the above conditions, submit the application materials required by the Baichuan-Omni-1.5 Model Community License Agreement to the following contact email: [email protected]. Once approved, Baichuan will grant you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable commercial copyright license.
1107
+
1108
+ <!-- ### Citation
1109
+
1110
+ If you find our work helpful, please consider citing our papers 📝 and liking this project ❤️!
1111
+ ```bib
1112
+ @article{
1113
+ }
1114
+ ```
+ -->
added_tokens.json ADDED
@@ -0,0 +1,39 @@
1
+ {
2
+ "<B_APE>": 151652,
3
+ "<B_CODE>": 151651,
4
+ "<B_FUNC>": 151650,
5
+ "<B_SYS>": 151646,
6
+ "<B_USYS>": 151647,
7
+ "<C_A>": 151649,
8
+ "<C_Q>": 151648,
9
+ "<audio_delim_baichuan>": 151674,
10
+ "<audio_end_baichuan>": 151658,
11
+ "<audio_pad_baichuan>": 151659,
12
+ "<audio_start_baichuan>": 151657,
13
+ "<audiogen_end_baichuan>": 151679,
14
+ "<audiogen_start_baichuan>": 151678,
15
+ "<audiotext_end_baichuan>": 151676,
16
+ "<audiotext_pad_baichuan>": 151677,
17
+ "<audiotext_start_baichuan>": 151675,
18
+ "<baichuan_pad_token>": 151672,
19
+ "<box_delim_baichuan>": 151666,
20
+ "<box_end_baichuan>": 151665,
21
+ "<box_start_baichuan>": 151664,
22
+ "<calc_end>": 151655,
23
+ "<calc_start>": 151654,
24
+ "<function_calling>": 151653,
25
+ "<img_delim_baichuan>": 151669,
26
+ "<img_end_baichuan>": 151661,
27
+ "<img_newline_baichuan>": 151663,
28
+ "<img_pad_baichuan>": 151662,
29
+ "<img_start_baichuan>": 151660,
30
+ "<inner_think>": 151656,
31
+ "<polygon_end_baichuan>": 151671,
32
+ "<polygon_start_baichuan>": 151670,
33
+ "<ref_end_baichuan>": 151668,
34
+ "<ref_start_baichuan>": 151667,
35
+ "<reserved_113>": 151673,
36
+ "<|endoftext|>": 151643,
37
+ "<|im_end|>": 151645,
38
+ "<|im_start|>": 151644
39
+ }
audio_modeling_omni.py ADDED
@@ -0,0 +1,658 @@
1
+ import torch, fire
2
+ from typing import Optional
3
+ import torch.distributed
4
+ from torch.nn import functional as F
5
+ from flash_attn import flash_attn_varlen_func
6
+ from torch import nn
7
+ import numpy as np
8
+ import deepspeed
9
+ from transformers.activations import ACT2FN
10
+ from dataclasses import dataclass
11
+ from transformers.modeling_outputs import ModelOutput
12
+ try:
13
+ from .vector_quantize import VectorQuantize
14
+ except ImportError:
15
+ from vector_quantize import VectorQuantize
16
+
17
+ from .flow_matching import (
18
+ ConditionalDecoder,
19
+ ConditionalCFM,
20
+ )
21
+
22
+ import math
23
+ import copy
24
+
25
+ def sinusoids(length, channels, max_timescale=10000):
26
+ """Returns sinusoids for positional embedding"""
27
+ assert channels % 2 == 0
28
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
29
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
30
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
31
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
32
+
33
+ def get_sequence_mask(inputs, inputs_length):
34
+ if inputs.dim() == 3:
35
+ bsz, tgt_len, _ = inputs.size()
36
+ else:
37
+ bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
38
+ sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
39
+ sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
40
+ unpacking_index = torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1 # convert cumulative counts into gather indices
41
+ return sequence_mask, unpacking_index
42
+
43
+ def unpack_hidden_states(hidden_states, lengths):
44
+ bsz = lengths.shape[0]
45
+ sequence_mask, unpacking_index = get_sequence_mask(hidden_states, lengths)
46
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
47
+ bsz, torch.max(lengths), hidden_states.shape[-1]
48
+ )
49
+ hidden_states = torch.where(
50
+ sequence_mask, hidden_states, 0
51
+ ) # 3d (bsz, max_input_len, d)
52
+ return hidden_states
53
+
54
+
55
+ class RMSNorm(nn.Module):
56
+ def __init__(self, hidden_size, eps=1e-6):
57
+ """
58
+ RMSNorm is equivalent to T5LayerNorm
59
+ """
60
+ super().__init__()
61
+ self.weight = nn.Parameter(torch.ones(hidden_size))
62
+ self.variance_epsilon = eps
63
+
64
+ def forward(self, hidden_states):
65
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
66
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
67
+
68
+ # convert into half-precision if necessary
69
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
70
+ hidden_states = hidden_states.to(self.weight.dtype)
71
+
72
+ return self.weight * hidden_states
73
+
74
+
75
+ class OmniWhisperAttention(nn.Module):
76
+ def __init__(self, embed_dim, num_heads, causal=False):
77
+ super().__init__()
78
+ self.embed_dim = embed_dim
79
+ self.num_heads = num_heads
80
+ self.head_dim = embed_dim // num_heads
81
+
82
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
83
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
84
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
85
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
86
+
87
+ self.causal = causal
88
+
89
+ def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
90
+ bsz, _ = hidden_states.size()
91
+
92
+ query_states = self.q_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
93
+ key_states = self.k_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
94
+ value_states = self.v_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
95
+
96
+ cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(torch.int32)
97
+ max_seqlen = torch.max(seq_len).to(torch.int32).detach()
98
+ attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_len, cu_len, max_seqlen,
99
+ max_seqlen, causal=self.causal) # (bsz * qlen, nheads, headdim)
100
+ attn_output = attn_output.reshape(bsz, self.embed_dim)
101
+ attn_output = self.out_proj(attn_output)
102
+ return attn_output
103
+
104
+
105
+ class OmniWhisperTransformerLayer(nn.Module):
106
+ def __init__(
107
+ self,
108
+ act,
109
+ d_model,
110
+ encoder_attention_heads,
111
+ encoder_ffn_dim,
112
+ causal,
113
+ ln_type="LayerNorm",
114
+ ):
115
+ super().__init__()
116
+ self.embed_dim = d_model
117
+ self.self_attn = OmniWhisperAttention(
118
+ self.embed_dim, encoder_attention_heads, causal
119
+ )
120
+
121
+ if ln_type == "LayerNorm":
122
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
123
+ elif ln_type == "RMSNorm":
124
+ self.self_attn_layer_norm = RMSNorm(self.embed_dim)
125
+ else:
126
+ raise ValueError(f"Unknown ln_type: {ln_type}")
127
+
128
+ self.activation_fn = act
129
+ self.fc1 = nn.Linear(self.embed_dim, encoder_ffn_dim)
130
+ self.fc2 = nn.Linear(encoder_ffn_dim, self.embed_dim)
131
+
132
+ if ln_type == "LayerNorm":
133
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
134
+ elif ln_type == "RMSNorm":
135
+ self.final_layer_norm = RMSNorm(self.embed_dim)
136
+ else:
137
+ raise ValueError(f"Unknown ln_type: {ln_type}")
138
+
139
+ def forward(
140
+ self, hidden_states: torch.Tensor, seq_len: torch.Tensor
141
+ ) -> torch.Tensor:
142
+ residual = hidden_states
143
+ hidden_states = self.self_attn_layer_norm(hidden_states)
144
+ hidden_states = self.self_attn(hidden_states, seq_len)
145
+ hidden_states = residual + hidden_states
146
+ residual = hidden_states
147
+ hidden_states = self.final_layer_norm(hidden_states)
148
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
149
+ hidden_states = self.fc2(hidden_states)
150
+ hidden_states = residual + hidden_states
151
+
152
+ if (
153
+ hidden_states.dtype == torch.float16
154
+ or hidden_states.dtype == torch.bfloat16
155
+ ) and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
156
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
157
+ hidden_states = torch.clamp(
158
+ hidden_states, min=-clamp_value, max=clamp_value
159
+ )
160
+ return hidden_states
161
+
162
+
163
+ class OmniAudioEncoder(nn.Module):
164
+ def __init__(self, config):
165
+ super().__init__()
166
+ config._attn_implementation = 'flash_attention_2' #
167
+ self.config = config
168
+ self.max_source_positions = (config.max_audio_seconds * config.sampling_rate // config.hop_length) // config.stride_size
169
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
170
+
171
+ self.conv1 = nn.Conv1d(config.num_mel_bins, config.d_model, kernel_size=config.kernel_size, padding=1)
172
+ self.conv2 = nn.Conv1d(config.d_model, config.d_model, kernel_size=config.kernel_size,
173
+ stride=config.stride_size, padding=1)
174
+ self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, config.d_model)) # 1500 * d
175
+
176
+ self.layers = nn.ModuleList([OmniWhisperTransformerLayer(
177
+ ACT2FN[config.activation_function],
178
+ config.d_model,
179
+ config.encoder_attention_heads,
180
+ config.encoder_ffn_dim,
181
+ False) for _ in range(config.encoder_layers)])
182
+ self.layer_norm = nn.LayerNorm(config.d_model)
183
+
184
+ @torch.no_grad()
185
+ def fake_input(self, device):
186
+ input_features = torch.rand([2, self.config.num_mel_bins, 10], dtype=torch.float32, device=device)
187
+ encoder_length = torch.ones([2], dtype=torch.int32, device=device) * 3
188
+ bridge_length = torch.ones([2], dtype=torch.int32, device=device)
189
+ return input_features, encoder_length, bridge_length
190
+
191
+ def forward(
192
+ self,
193
+ input_features,
194
+ output_length,
195
+ ):
196
+ input_features = input_features.to(self.conv1.weight.dtype)
197
+ inputs_embeds = nn.functional.gelu(self.conv1(input_features)) # (bs, channels, frames)
198
+ inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) # (bs, channels, frames // 2)
199
+ inputs_embeds = inputs_embeds.permute(0, 2, 1) # (bs, frames, channels)
200
+ bsz, tgt_len, _ = inputs_embeds.size()
201
+ if tgt_len < self.positional_embedding.shape[0]:
202
+ current_positional_embedding = self.positional_embedding[:tgt_len]
203
+ else:
204
+ current_positional_embedding = self.positional_embedding
205
+ hidden_states = (inputs_embeds.to(torch.float32) + current_positional_embedding).to(inputs_embeds.dtype)
206
+
207
+ # packing hidden states
208
+ attention_mask, unpacking_index = get_sequence_mask(hidden_states, output_length)
209
+ hidden_states = torch.masked_select(hidden_states, attention_mask).view(torch.sum(output_length),
210
+ self.config.d_model)
211
+
212
+ for idx, encoder_layer in enumerate(self.layers):
213
+ hidden_states = encoder_layer(hidden_states, output_length)
214
+ hidden_states = self.layer_norm(hidden_states)
215
+ # unpacking
216
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, tgt_len, self.config.d_model)
217
+ hidden_states = torch.where(attention_mask, hidden_states, 0)
218
+ return hidden_states
219
+
220
+
221
+ class CasualConvTranspose1d(nn.Module): # causal transposed (de-)convolution
222
+ def __init__(self, in_channels, out_channels, kernel_size, stride):
223
+ super().__init__()
224
+ self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
225
+ self.norm = nn.GroupNorm(1, out_channels)
226
+ self.in_channels = in_channels
227
+ self.out_channels = out_channels
228
+
229
+ def forward(self, hidden_states, input_length, output_dim=None):
230
+ kernel_size = self.conv.kernel_size[0]
231
+ stride = self.conv.stride[0]
232
+ bsz = input_length.shape[0]
233
+
234
+ if output_dim is None:
235
+ output_dim = hidden_states.dim()
236
+ if hidden_states.dim() <= 2: # unpack sequence to 3d
237
+ sequence_mask, unpacking_index = get_sequence_mask(hidden_states, input_length)
238
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, torch.max(input_length),
239
+ self.in_channels)
240
+ hidden_states = torch.where(sequence_mask, hidden_states, 0) # 3d (bsz, max_input_len, d)
241
+
242
+ hidden_states = hidden_states.transpose(2, 1) # (N, L, C) -> (N, C, L)
243
+ hidden_states = self.conv(hidden_states)
244
+ hidden_states = self.norm(hidden_states)
245
+ hidden_states = hidden_states.transpose(2, 1) # (N, C, L) -> (N, L, C)
246
+
247
+ casual_padding_right = max(0, kernel_size - stride)
248
+ hidden_states = hidden_states[:, :hidden_states.shape[1] - casual_padding_right,
249
+ :]
250
+ output_length = (input_length - 1) * stride + kernel_size - casual_padding_right
251
+ sequence_mask, _ = get_sequence_mask(hidden_states, output_length)
252
+ if output_dim <= 2:
253
+ hidden_states = torch.masked_select(hidden_states, sequence_mask).view(-1, self.out_channels)
254
+ else:
255
+ hidden_states = torch.where(sequence_mask, hidden_states, 0)
256
+ hidden_states = hidden_states[:, :torch.max(output_length), :] # truncate to the maximum valid length
257
+ return hidden_states, output_length
258
+
259
+
260
+ class MelSpecRefineNet(nn.Module):
261
+ """
262
+ # post net, coarse to refined mel-spectrogram frames
263
+ # ref1: Autoregressive Speech Synthesis without Vector Quantization
264
+ # ref2: CosyVoice length_regulator.py
265
+ # ref3: Neural Speech Synthesis with Transformer Network https://github.com/soobinseo/Transformer-TTS/blob/master/network.py
266
+ """
267
+
268
+ def __init__(self, encoder_config, vocoder_config):
269
+ super().__init__()
270
+ self.encoder_config = encoder_config
271
+ self.vocoder_config = vocoder_config
272
+
273
+ layers = nn.ModuleList([])
274
+ in_channels = self.vocoder_config.num_mel_bins
275
+ for i, out_channels in enumerate(self.vocoder_config.channels[:-1]):
276
+ module = nn.Conv1d(in_channels, out_channels, 5, 1, 2) # cosyvoice kernel=3, stride=1, pad=1
277
+ in_channels = out_channels
278
+ norm = nn.GroupNorm(1, out_channels)
279
+ act = nn.Mish()
280
+ layers.extend([module, norm, act])
281
+ layers.append(nn.Conv1d(in_channels, self.vocoder_config.num_mel_bins, 1, 1)) # projector
282
+ self.layers = nn.Sequential(*layers)
283
+
284
+ def compute_output_length(self, input_length):
285
+ output_length = input_length.to(
286
+ torch.float32) * self.encoder_config.hop_length / self.encoder_config.sampling_rate
287
+ output_length = output_length * self.vocoder_config.sampling_rate / self.vocoder_config.hop_length
288
+ return output_length.to(torch.int64)
289
+
290
+ def forward(self, coarse_mel, input_length, output_length=None):
291
+ bsz, _, d = coarse_mel.shape
292
+ assert (d == self.vocoder_config.num_mel_bins)
293
+ if output_length is None or not self.training:
294
+ output_length = self.compute_output_length(input_length)
295
+ coarse_mel, default_dtype = coarse_mel[:, :torch.max(input_length), :], coarse_mel.dtype
296
+ coarse_mel = F.interpolate(coarse_mel.to(torch.float32).transpose(1, 2).contiguous(), size=output_length.max(),
297
+ mode='nearest').to(default_dtype)
298
+ refined_mel = self.layers(coarse_mel).transpose(1, 2).contiguous() # (bs, t, d)
299
+ coarse_mel = coarse_mel.transpose(1, 2) # (bs, max(output_length), d)
300
+ refined_mel += coarse_mel # residual connection
301
+ sequence_mask, _ = get_sequence_mask(refined_mel, output_length)
302
+ coarse_mel = torch.where(sequence_mask, coarse_mel, 0)
303
+ refined_mel = torch.where(sequence_mask, refined_mel, 0)
304
+ return refined_mel, coarse_mel, output_length
305
+
306
+
307
+ @dataclass
308
+ class OmniAudioDecoderOutput(ModelOutput):
309
+ refined_mel: Optional[torch.FloatTensor] = None
310
+ coarse_mel: Optional[torch.FloatTensor] = None
311
+ mel_length: Optional[torch.Tensor] = None
312
+ hidden_states_before_dconv2: Optional[torch.FloatTensor] = None
313
+ output_length_before_dconv2: Optional[torch.Tensor] = None
314
+
315
+
316
+ class OmniAudioDecoder(nn.Module):
317
+ def __init__(self, config):
318
+ super().__init__()
319
+ self.config = config.audio_config
320
+ self.vocoder_config = config.vocoder_config
321
+ self.max_source_positions = self.config.max_audio_seconds * self.config.sampling_rate // self.config.hop_length
322
+
323
+ self.dconv1 = CasualConvTranspose1d(
324
+ self.config.d_model,
325
+ self.config.d_model,
326
+ self.config.decoder_kernel_size,
327
+ self.config.avg_pooler,
328
+ )
329
+ self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, self.config.d_model))
330
+ # causal transformer layers
331
+ self.layers = nn.ModuleList(
332
+ [OmniWhisperTransformerLayer(
333
+ ACT2FN[self.config.activation_function],
334
+ self.config.d_model,
335
+ self.config.decoder_attention_heads,
336
+ self.config.decoder_ffn_dim,
337
+ True # causal
338
+ ) for _ in range(self.config.decoder_layers)
339
+ ])
340
+ self.layer_norm = nn.LayerNorm(self.config.d_model)
341
+ self.dconv2 = CasualConvTranspose1d(
342
+ self.config.d_model,
343
+ self.vocoder_config.num_mel_bins,
344
+ self.config.decoder_kernel_size,
345
+ self.config.decoder_stride_size
346
+ )
347
+ self.post_net = MelSpecRefineNet(config.audio_config, config.vocoder_config)
348
+ self.gradient_checkpointing = True
349
+
350
+ @torch.no_grad()
351
+ def fake_input(self, device):
352
+ audio_embed = torch.rand([1, 10, self.config.d_model], dtype=torch.float32, device=device)
353
+ input_length = torch.ones([1], dtype=torch.int32, device=device) * 10
354
+ mel_labels_length = self.post_net.compute_output_length(input_length)
355
+ return audio_embed, input_length, None, mel_labels_length
356
+
357
+ def forward(self,
358
+ audio_embed,
359
+ input_length,
360
+ mel_labels=None,
361
+ mel_labels_length=None,
362
+ fake_input=False,
363
+ ):
364
+ if fake_input:
365
+ audio_embed, input_length, mel_labels, mel_labels_length = self.fake_input(self.layer_norm.weight.device)
366
+
367
+ assert (audio_embed.shape[-1] == self.config.d_model)
368
+ audio_embed = audio_embed.to(self.layer_norm.weight) # device and type
369
+ audio_embed, output_length = self.dconv1(audio_embed, input_length, output_dim=3) # (b, l*2, d_model)
370
+ _, tgt_len, _ = audio_embed.size()
371
+ if tgt_len < self.positional_embedding.shape[0]:
372
+ current_positional_embedding = self.positional_embedding[:tgt_len]
373
+ else:
374
+ current_positional_embedding = self.positional_embedding
375
+ hidden_states = (audio_embed.to(torch.float32) + current_positional_embedding).to(audio_embed.dtype)
376
+
377
+ # packing hidden states
378
+ attention_mask, _ = get_sequence_mask(hidden_states, output_length)
379
+ hidden_states = torch.masked_select(hidden_states, attention_mask).view(torch.sum(output_length), self.config.d_model)
380
+
381
+ for idx, encoder_layer in enumerate(self.layers):
382
+ hidden_states = encoder_layer(hidden_states, output_length)
383
+
384
+ hidden_states = self.layer_norm(hidden_states)
385
+ hidden_states_before_dconv2 = hidden_states
386
+ output_length_before_dconv2 = output_length
387
+
388
+ coarse_mel, output_length = self.dconv2(hidden_states, output_length, output_dim=3)
389
+ refined_mel, coarse_mel, mel_labels_length = self.post_net(coarse_mel, output_length, mel_labels_length)
390
+
391
+ return OmniAudioDecoderOutput(
392
+ refined_mel=refined_mel,
393
+ coarse_mel=coarse_mel,
394
+ mel_length=mel_labels_length,
395
+ hidden_states_before_dconv2=hidden_states_before_dconv2,
396
+ output_length_before_dconv2=output_length_before_dconv2,
397
+ )
398
+
399
+
400
+ class OmniAudioVQBridgeTokenizer(nn.Module):
401
+ def __init__(self, config):
402
+ super().__init__()
403
+ self.config = config.audio_config
404
+ self.gradient_checkpointing = False
405
+ self.intermediate_dim = self.config.d_model * self.config.avg_pooler
406
+ self.gate_proj = nn.Conv1d(self.config.d_model, self.intermediate_dim, self.config.avg_pooler, self.config.avg_pooler, bias=False)
407
+ self.up_proj = nn.Conv1d(self.config.d_model, self.intermediate_dim, self.config.avg_pooler, self.config.avg_pooler, bias=False)
408
+
409
+ self.down_proj = nn.Linear(self.intermediate_dim, self.intermediate_dim, bias=False)
410
+ self.act_fn = ACT2FN['silu']
411
+ self.layer_norm = nn.LayerNorm(self.intermediate_dim)
412
+ self.proj_decoder = nn.Linear(self.intermediate_dim, self.config.d_model)
413
+
414
+ self.vq_list = nn.ModuleList([])
415
+ for idx, codebook_size in enumerate(self.config.vq_config.codebook_sizes):
416
+ vq_config = copy.deepcopy(self.config.vq_config)
417
+ vq_config.dim = self.intermediate_dim
418
+ vq_config.codebook_size = codebook_size
419
+ self.vq_list.append(VectorQuantize(vq_config))
420
+ for vq_layer in self.vq_list:
421
+ deepspeed.zero.register_external_parameter(self, vq_layer.codebook.embed)
422
+
423
+ def rvq_op(self, inputs, output_length):
424
+ def rvq_layer_op(vq_layer, residual_encoding, output_length):
425
+ q_v_i, code_ids_i = vq_layer(residual_encoding, output_length)
426
+ residual_encoding = residual_encoding.float() - q_v_i.float()
427
+ residual_encoding = residual_encoding.to(inputs.dtype)
428
+ return residual_encoding, code_ids_i
429
+
430
+ cmt_loss, residual_encoding = 0, inputs
431
+ code_ids_list = []
432
+ for i, vq_layer in enumerate(self.vq_list):
433
+ residual_encoding, code_ids_i = rvq_layer_op(vq_layer, residual_encoding, output_length)
434
+ code_ids_list.append(code_ids_i)
435
+ return torch.stack(code_ids_list, -1)
436
+
437
+ def forward(self, x, output_length):
438
+ batch_size, _, _ = x.shape
439
+ output_length = output_length.to(x.device)
440
+
441
+ if x.shape[1] % self.config.avg_pooler != 0:
442
+ x = F.pad(x, (0, 0, 0, self.config.avg_pooler - x.shape[1] % self.config.avg_pooler), "constant", 0)
443
+ xt = x.permute(0, 2, 1)
444
+ g = self.gate_proj(xt).permute(0, 2, 1) # (bs, sl // avg_pooler, d_model * avg_pooler)
445
+ u = self.up_proj(xt).permute(0, 2, 1)
446
+ x = x.reshape(batch_size, -1, self.intermediate_dim) # (bs, sl // avg_pooler, d_model * avg_pooler)
447
+
448
+ c = self.down_proj(self.act_fn(g) * u)
449
+ res = self.layer_norm(c + x)
450
+ valid_mask, _ = get_sequence_mask(res, output_length)
451
+ code_ids = self.rvq_op(res, output_length)
452
+ code_ids = torch.masked_select(code_ids, valid_mask).reshape(-1, len(self.vq_list)) # (sum(valid_sequence_length), vq_num)
453
+ return code_ids
454
+
455
+ @torch.no_grad()
456
+ def decode(self, code_ids):
457
+ vq_num = code_ids.shape[-1]
458
+ res = sum(self.vq_list[i].get_output_from_indices(code_ids[:, i]).float() for i in range(vq_num-1,-1,-1)).to(self.proj_decoder.weight)
459
+ decoder_emb = self.proj_decoder(res.to(self.proj_decoder.weight))
460
+ return decoder_emb
461
+
462
+ @torch.no_grad()
463
+ def recover(self, code_ids):
464
+ vq_num = code_ids.shape[-1]
465
+ res = sum(self.vq_list[i].get_output_from_indices(code_ids[:, i]).float() for i in range(vq_num-1,-1,-1)).to(self.proj_decoder.weight)
466
+ return res
467
+
468
+
469
+ class FlowmatchingPrenet(nn.Module):
470
+ def __init__(
471
+ self,
472
+ input_feat_dim,
473
+ out_feat_dim,
474
+ d_model,
475
+ attention_heads,
476
+ ffn_dim,
477
+ nlayers,
478
+ activation_function,
479
+ max_source_positions,
480
+ target_mel_length_scale_ratio,
481
+ ):
482
+ super().__init__()
483
+
484
+ self.d_model = d_model
485
+ self.target_mel_length_scale_ratio = target_mel_length_scale_ratio
486
+ self.gradient_checkpointing = False
487
+
488
+ self.register_buffer(
489
+ "positional_embedding", sinusoids(max_source_positions, d_model)
490
+ )
491
+
492
+ self.in_mlp = nn.Sequential(
493
+ nn.Linear(input_feat_dim, d_model * 4),
494
+ nn.SiLU(),
495
+ nn.Linear(d_model * 4, d_model),
496
+ )
497
+
498
+ self.transformer_layers = nn.ModuleList(
499
+ [
500
+ OmniWhisperTransformerLayer(
501
+ act=ACT2FN[activation_function],
502
+ d_model=d_model,
503
+ encoder_attention_heads=attention_heads,
504
+ encoder_ffn_dim=ffn_dim,
505
+ causal=True, # causal
506
+ ln_type="RMSNorm",
507
+ )
508
+ for _ in range(nlayers)
509
+ ]
510
+ )
511
+
512
+ self.final_norm = RMSNorm(self.d_model)
513
+ self.out_proj = nn.Linear(d_model, out_feat_dim, bias=False)
514
+
515
+ def compute_output_length(self, input_length):
516
+ output_length = input_length.float() * self.target_mel_length_scale_ratio
517
+ return output_length.to(torch.int64)
518
+
519
+ def forward(self, input_feat, input_length, output_length=None):
520
+ """
521
+ Args:
522
+ input_feat: [B, T, input_feat_dim]
523
+ input_length: [B]
524
+ output_length: [B]
525
+
526
+ """
527
+ if output_length is None or not self.training:
528
+ output_length = self.compute_output_length(input_length)
529
+
530
+ input_feat = input_feat[:, : input_length.max(), :] # [B, T, D]
531
+ orig_dtype = input_feat.dtype
532
+
533
+ input_feat = F.interpolate(
534
+ input=input_feat.to(torch.float32).transpose(1, 2).contiguous(),
535
+ size=output_length.max(),
536
+ mode="nearest",
537
+ ).to(orig_dtype)
538
+ input_feat = input_feat.transpose(1, 2).contiguous() # [B, T, D]
539
+ hidden_states = self.in_mlp(input_feat)
540
+
541
+ # packing hidden states
542
+ bsz, tgt_len, d_model = hidden_states.shape
543
+ attention_mask, unpacking_index = get_sequence_mask(
544
+ hidden_states, output_length
545
+ )
546
+ hidden_states = torch.masked_select(hidden_states, attention_mask).view(
547
+ torch.sum(output_length), self.d_model
548
+ )
549
+
550
+ for idx, encoder_layer in enumerate(self.transformer_layers):
551
+ hidden_states = encoder_layer(hidden_states, output_length)
552
+
553
+ # unpacking
554
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
555
+ bsz, tgt_len, d_model
556
+ )
557
+ hidden_states = torch.where(attention_mask, hidden_states, 0)
558
+
559
+ hidden_states = self.final_norm(hidden_states)
560
+ output = self.out_proj(hidden_states)
561
+ return output, output_length
562
+
563
+
564
+ @dataclass
565
+ class OmniAudioFlowMatchingDecoderOutput(ModelOutput):
566
+ flow_matching_mel: Optional[torch.FloatTensor] = None
567
+ flow_matching_mel_lengths: Optional[torch.FloatTensor] = None
568
+
569
+
570
+ class OmniAudioFlowMatchingDecoder(nn.Module):
571
+ def __init__(self, config):
572
+ super().__init__()
573
+ self.config = config.flow_matching_config
574
+ self.in_channels = self.config.in_channels
575
+ self.spk_emb_dim = self.config.spk_emb_dim
576
+ self.diffusion_steps = self.config.diffusion_steps
577
+ self.cal_mel_mae = self.config.cal_mel_mae
578
+ self.forward_step = -1
579
+
580
+ self.prenet = FlowmatchingPrenet(
581
+ input_feat_dim=self.config.prenet_in_dim,
582
+ out_feat_dim=self.config.prenet_out_dim,
583
+ d_model=self.config.prenet_d_model,
584
+ attention_heads=self.config.prenet_attention_heads,
585
+ ffn_dim=self.config.prenet_ffn_dim,
586
+ nlayers=self.config.prenet_nlayers,
587
+ activation_function=self.config.prenet_activation_function,
588
+ max_source_positions=self.config.prenet_max_source_positions,
589
+ target_mel_length_scale_ratio=self.config.prenet_target_mel_length_scale_ratio,
590
+ )
591
+
592
+ self.conditional_decoder = ConditionalDecoder(
593
+ in_channels=self.in_channels * 2 + self.spk_emb_dim,
594
+ out_channels=self.in_channels,
595
+ causal=True,
596
+ channels=self.config.channels,
597
+ dropout=self.config.dropout,
598
+ attention_head_dim=self.config.attention_head_dim,
599
+ n_blocks=self.config.n_blocks,
600
+ num_mid_blocks=self.config.num_mid_blocks,
601
+ num_heads=self.config.num_heads,
602
+ act_fn=self.config.act_fn,
603
+ )
604
+
605
+ self.cfm = ConditionalCFM(
606
+ in_channels=self.in_channels,
607
+ cfm_params=self.config.cfm_params,
608
+ n_spks=0,
609
+ spk_emb_dim=self.spk_emb_dim,
610
+ )
611
+
612
+
613
+ def unpack_hidden_states(self, hidden_states, output_length):
614
+ unpacked = unpack_hidden_states(hidden_states, output_length)
615
+ return unpacked, output_length
616
+
617
+ def forward(
618
+ self, refined_mel, input_length, mel_labels=None, mel_labels_length=None
619
+ ):
620
+ """
621
+ :param refined_mel: [bs, max_input_len, mel_bin]
622
+ :param input_length: [batch_size]
623
+ :param mel_labels: optional target mel features
+ :param mel_labels_length: [batch_size], optional target mel lengths
+ :return: OmniAudioFlowMatchingDecoderOutput with flow_matching_mel of shape [bs, max_output_len, mel_bin]
625
+ """
626
+ self.forward_step += 1
627
+
628
+ orig_dtype = refined_mel.dtype
629
+ prenet_mae_metric = torch.tensor(0.0).to(refined_mel.device)
630
+ prenet_regression_loss = torch.tensor(0.0).to(refined_mel.device)
631
+
632
+ if self.prenet is not None:
633
+ refined_mel = refined_mel[:, : torch.max(input_length), :]
634
+ if mel_labels_length is None:
635
+ mel_labels_length = self.prenet.compute_output_length(input_length)
636
+ refined_mel, input_length = self.prenet(
637
+ refined_mel, input_length, mel_labels_length
638
+ )
639
+
640
+ float_dtype = refined_mel.dtype
641
+ refined_mel = refined_mel.float()
642
+ input_length = input_length.long()
643
+
644
+ refined_mel = refined_mel[:, : torch.max(input_length), :]
645
+ sequence_mask, unpacking_index = get_sequence_mask(refined_mel, input_length)
646
+ refined_mel = refined_mel.transpose(1, 2) # (bs, mel_bin, max_input_len)
647
+ sequence_mask = sequence_mask.transpose(2, 1) # (bs, 1, sl)
648
+
649
+ fm_mel = self.cfm.forward(
650
+ estimator=self.conditional_decoder,
651
+ mu=refined_mel.to(float_dtype),
652
+ mask=sequence_mask.float(),
653
+ n_timesteps=self.diffusion_steps,
654
+ )
655
+ return OmniAudioFlowMatchingDecoderOutput(
656
+ flow_matching_mel=fm_mel.transpose(1, 2),
657
+ flow_matching_mel_lengths=mel_labels_length,
658
+ )
config.json ADDED
@@ -0,0 +1,240 @@
1
+ {
2
+ "_name_or_path": "_",
3
+ "architectures": [
4
+ "OmniForCausalLM"
5
+ ],
6
+ "attention_qkv_bias": true,
7
+ "attention_qkv_pack": true,
8
+ "audio_config": {
9
+ "audio_head_transformer_layers": 3,
10
+ "audio_delim_token_id": 151674,
11
+ "audio_end_token_id": 151658,
12
+ "audio_pad_token_id": 151659,
13
+ "audio_start_token_id": 151657,
14
+ "audiogen_end_token_id": 151679,
15
+ "audiogen_start_token_id": 151678,
16
+ "audiotext_end_token_id": 151676,
17
+ "audiotext_pad_token_id": 151677,
18
+ "audiotext_start_token_id": 151675,
19
+ "avg_pooler": 4,
20
+ "d_model": 1280,
21
+ "decoder_attention_heads": 20,
22
+ "decoder_ffn_dim": 5120,
23
+ "decoder_kernel_size": 3,
24
+ "decoder_layers": 8,
25
+ "decoder_stride_size": 2,
26
+ "enable": true,
27
+ "encoder_attention_heads": 20,
28
+ "encoder_ffn_dim": 5120,
29
+ "encoder_layers": 32,
30
+ "hop_length": 160,
31
+ "kernel_size": 3,
32
+ "max_audio_seconds": 30,
33
+ "n_fft": 400,
34
+ "num_mel_bins": 128,
35
+ "sampling_rate": 16000,
36
+ "stride_size": 2,
37
+ "split_overlap": 0.0,
38
+ "vq_config":{
39
+ "enable": true,
40
+ "codebook_sizes": [8192, 4096, 2048, 1024, 1024, 1024, 1024, 1024]
41
+ }
42
+ },
43
+ "auto_map": {
44
+ "AutoConfig": "configuration_omni.OmniConfig",
45
+ "AutoModelForCausalLM": "modeling_omni.OmniForCausalLM"
46
+ },
47
+ "omni_tokenizer_type": "auto",
48
+ "bos_token_id": 1,
49
+ "eos_token_id": 2,
50
+ "flow_matching_config": {
51
+ "enable": true,
52
+ "use_hires_mel": true,
53
+ "sampling_rate": 24000,
54
+ "hop_length": 480,
55
+ "max_audio_seconds": 30,
56
+ "split_overlap": 0.1,
57
+ "use_hidden_states_before_dconv2": true,
58
+ "prenet_in_dim": 1280,
59
+ "prenet_out_dim": 80,
60
+ "prenet_d_model": 512,
61
+ "prenet_attention_heads": 8,
62
+ "prenet_ffn_dim": 2048,
63
+ "prenet_nlayers": 12,
64
+ "prenet_activation_function": "gelu",
65
+ "prenet_max_source_positions": 5000,
66
+ "prenet_target_mel_length_scale_ratio": 1.0,
67
+ "prenet_loss_weight": 1.0,
68
+ "unet_use_omni_attn": false,
69
+ "loss_weight": 1.0,
70
+ "in_channels": 80,
71
+ "spk_emb_dim": 0,
72
+ "diffusion_steps": 10,
73
+ "channels": [256],
74
+ "dropout": 0.0,
75
+ "attention_head_dim": 64,
76
+ "n_blocks": 4,
77
+ "num_mid_blocks": 12,
78
+ "num_heads": 8,
79
+ "act_fn": "gelu",
80
+ "cal_mel_mae": true,
81
+ "cfm_params": {
82
+ "sigma_min": 1e-6,
83
+ "solver": "euler",
84
+ "t_scheduler": "cosine",
85
+ "training_cfg_rate": 0.2,
86
+ "inference_cfg_rate": 0.7,
87
+ "reg_loss_type": "l1"
88
+ }
89
+ },
90
+ "head_dim": 128,
91
+ "hidden_act": "silu",
92
+ "hidden_size": 3584,
93
+ "initializer_range": 0.02,
94
+ "intermediate_size": 18944,
95
+ "max_position_embeddings": 65536,
96
+ "max_window_layers": 28,
97
+ "model_type": "omni",
98
+ "multimodal": [
99
+ "audio",
100
+ "audiogen"
101
+ ],
102
+ "multimodal_special_token_list": [
103
+ 151657,
104
+ 151658,
105
+ 151659,
106
+ 151674,
107
+ 151675,
108
+ 151676,
109
+ 151677,
110
+ 151678,
111
+ 151679
112
+ ],
113
+ "num_attention_heads": 28,
114
+ "num_hidden_layers": 28,
115
+ "num_key_value_heads": 4,
116
+ "pad_token_id": 0,
117
+ "position_embedding_type": "rope",
118
+ "rms_norm_eps": 1e-06,
119
+ "rope_theta": 1000000.0,
120
+ "sliding_window": 131072,
121
+ "sparse_attention_heads": null,
122
+ "sparse_attention_layers": [],
123
+ "tie_word_embeddings": false,
124
+ "torch_dtype": "bfloat16",
125
+ "train_multimodal_special_tokens_only": false,
126
+ "transformers_version": "4.45.0.dev0",
127
+ "use_cache": false,
128
+ "use_norm_head": false,
129
+ "use_sliding_window": false,
130
+ "video_config": {
131
+ "_name_or_path": "",
132
+ "_attn_implementation": "flash_attention_2",
133
+ "decode_way": "1fps",
134
+ "depth": 32,
135
+ "embed_dim": 1280,
136
+ "enable": false,
137
+ "hidden_act": "quick_gelu",
138
+ "hidden_size": 3584,
139
+ "image_delimiter_token_id": 151688,
140
+ "image_end_token_id": 151680,
141
+ "image_line_token_id": 151682,
142
+ "image_mean": [
143
+ 0.48145466,
144
+ 0.4578275,
145
+ 0.40821073
146
+ ],
147
+ "image_pad_token_id": 151681,
148
+ "image_size": 224,
149
+ "image_start_token_id": 151679,
150
+ "image_std": [
151
+ 0.26862954,
152
+ 0.26130258,
153
+ 0.27577711
154
+ ],
155
+ "in_channels": 3,
156
+ "in_chans": 3,
157
+ "intermediate_size": 3072,
158
+ "layer_norm_eps": 1e-05,
159
+ "max_frame_num": 32,
160
+ "max_length": 20,
161
+ "max_pixels": 602112,
162
+ "merge_size": 2,
163
+ "min_length": 0,
164
+ "min_pixels": 3136,
165
+ "mlp_ratio": 4,
166
+ "model_type": "clip_vision_model",
167
+ "num_attention_heads": 12,
168
+ "num_channels": 3,
169
+ "num_heads": 16,
170
+ "num_hidden_layers": 12,
171
+ "patch_size": 14,
172
+ "spatial_merge_size": 2,
173
+ "spatial_patch_size": 14,
174
+ "temporal_patch_size": 2,
175
+ "video_end_token_id": 151696,
176
+ "video_place_token_id": 151694,
177
+ "video_start_token_id": 151695
178
+ },
179
+ "visual_config": {
180
+ "_name_or_path": "",
181
+ "_attn_implementation": "flash_attention_2",
182
+ "depth": 32,
183
+ "diversity_penalty": 0.0,
184
+ "do_sample": false,
185
+ "early_stopping": false,
186
+ "embed_dim": 1280,
187
+ "enable": false,
188
+ "hidden_act": "quick_gelu",
189
+ "hidden_size": 3584,
190
+ "image_delimiter_token_id": 151688,
191
+ "image_end_token_id": 151680,
192
+ "image_line_token_id": 151682,
193
+ "image_mean": [
194
+ 0.48145466,
195
+ 0.4578275,
196
+ 0.40821073
197
+ ],
198
+ "image_pad_token_id": 151681,
199
+ "image_size": 224,
200
+ "image_start_token_id": 151679,
201
+ "image_std": [
202
+ 0.26862954,
203
+ 0.26130258,
204
+ 0.27577711
205
+ ],
206
+ "in_channels": 3,
207
+ "in_chans": 3,
208
+ "intermediate_size": 3072,
209
+ "layer_norm_eps": 1e-05,
210
+ "length_penalty": 1.0,
211
+ "max_length": 20,
212
+ "max_pixels": 3211264,
213
+ "merge_size": 2,
214
+ "min_length": 0,
215
+ "min_pixels": 3136,
216
+ "mlp_ratio": 4,
217
+ "model_type": "clip_vision_model",
218
+ "num_attention_heads": 12,
219
+ "num_channels": 3,
220
+ "num_heads": 16,
221
+ "num_hidden_layers": 12,
222
+ "patch_size": 14,
223
+ "projection_dim": 512,
224
+ "spatial_merge_size": 2,
225
+ "spatial_patch_size": 14,
226
+ "temporal_patch_size": 2
227
+ },
228
+ "vocab_size": 152064,
229
+ "vocoder_config":{
230
+ "enable": true,
231
+ "enable_multi_scale": true,
232
+ "max_audio_seconds": 30,
233
+ "sampling_rate": 16000,
234
+ "hop_length": 256,
235
+ "split_overlap": 0.0,
236
+ "n_fft": 1024,
237
+ "num_mel_bins": 80,
238
+ "channels": [256, 256, 256, 256, 256]
239
+ }
240
+ }
configuration_omni.py ADDED
@@ -0,0 +1,120 @@
1
+ # Copyright 2023 Baichuan Inc. All Rights Reserved.
2
+
3
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
+ # and OPT implementations in this library. It has been modified from its
7
+ # original forms to accommodate minor architectural differences compared
8
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.utils import logging
24
+ from transformers import WhisperConfig
25
+ from transformers import CLIPVisionConfig
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ class OmniConfig(PretrainedConfig):
31
+ model_type = "omni"
32
+ keys_to_ignore_at_inference = ["past_key_values"]
33
+
34
+ def __init__(
35
+ self,
36
+ vocab_size=125696,
37
+ hidden_size=4096,
38
+ intermediate_size=11008,
39
+ num_hidden_layers=32,
40
+ num_attention_heads=32,
41
+ num_key_value_heads=None,
42
+ sparse_attention_heads=None,
43
+ sparse_attention_layers=[],
44
+ head_dim=None,
45
+ attention_qkv_pack=True,
46
+ attention_qkv_bias=False,
47
+ use_norm_head=True,
48
+ hidden_act="silu",
49
+ max_position_embeddings=4096,
50
+ position_embedding_type="rope",
51
+ initializer_range=0.02,
52
+ rms_norm_eps=1e-6,
53
+ use_cache=True,
54
+ pad_token_id=0,
55
+ bos_token_id=1,
56
+ eos_token_id=2,
57
+ tie_word_embeddings=False,
58
+ audio_config=None,
59
+ visual_config=None,
60
+ video_config=None,
61
+ vocoder_config=None,
62
+ flow_matching_config=None,
63
+ **kwargs,
64
+ ):
65
+ self.vocab_size = vocab_size
66
+ self.max_position_embeddings = max_position_embeddings
67
+ self.hidden_size = hidden_size
68
+ self.intermediate_size = intermediate_size
69
+ self.num_hidden_layers = num_hidden_layers
70
+ self.num_attention_heads = num_attention_heads
71
+ self.num_key_value_heads = num_key_value_heads or self.num_attention_heads
72
+ self.sparse_attention_heads = sparse_attention_heads
73
+ self.sparse_attention_layers = sparse_attention_layers
74
+ self.head_dim = head_dim or self.hidden_size // self.num_attention_heads
75
+ self.attention_qkv_pack = attention_qkv_pack
76
+ self.attention_qkv_bias = attention_qkv_bias
77
+ self.use_norm_head = use_norm_head
78
+ self.hidden_act = hidden_act
79
+ self.position_embedding_type = position_embedding_type
80
+ self.initializer_range = initializer_range
81
+ self.rms_norm_eps = rms_norm_eps
82
+ self.use_cache = use_cache
83
+ assert self.position_embedding_type.lower() in ("rope", "alibi")
84
+ super().__init__(
85
+ pad_token_id=pad_token_id,
86
+ bos_token_id=bos_token_id,
87
+ eos_token_id=eos_token_id,
88
+ tie_word_embeddings=tie_word_embeddings,
89
+ **kwargs,
90
+ )
91
+ if audio_config is not None:
92
+ self.audio_config = WhisperConfig(**audio_config)
93
+ if self.audio_config.vq_config is not None:
94
+ self.audio_config.vq_config = PretrainedConfig(**self.audio_config.vq_config)
95
+ if vocoder_config is not None:
96
+ self.vocoder_config = WhisperConfig(**vocoder_config)
97
+ if flow_matching_config is not None:
98
+ self.flow_matching_config = PretrainedConfig(**flow_matching_config)
99
+ self.flow_matching_config.cfm_params = PretrainedConfig(**self.flow_matching_config.cfm_params)
100
+ if visual_config is not None:
101
+ self.visual_config = CLIPVisionConfig(**visual_config)
102
+ if video_config is not None:
103
+ self.video_config = CLIPVisionConfig(**video_config)
104
+
105
+
106
+ def to_diff_dict(self):
107
+ data = super().to_diff_dict()
108
+ data["model_type"] = self.model_type
109
+ return data
110
+
111
+ def get_rotary_base(self):
112
+ if hasattr(self, "rotary_emb_base"):
113
+ return self.rotary_emb_base
114
+ else:
115
+ return self.rope_theta
116
+
117
+ if __name__ == '__main__':
118
+ from transformers import AutoConfig
119
+ config = AutoConfig.from_pretrained("./", trust_remote_code=True)
120
+ print(config)
flow_matching.py ADDED
@@ -0,0 +1,791 @@
1
+ # Modified from CosyVoice https://github.com/FunAudioLLM/CosyVoice/tree/main
2
+ """
3
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+
18
+ from abc import ABC
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from typing import Dict, Optional
22
+
23
+ import torch.nn as nn
24
+ from einops import pack, rearrange, repeat
25
+ from .matcha_components import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
26
+ from .matcha_transformer import BasicTransformerBlock
27
+ from omegaconf import DictConfig
28
+
29
+
30
+ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
31
+ assert mask.dtype == torch.bool
32
+ assert dtype in [torch.float32, torch.bfloat16, torch.float16]
33
+ mask = mask.to(dtype)
34
+ # attention mask bias
35
+ # NOTE(Mddct): torch.finfo jit issues
36
+ # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
37
+ mask = (1.0 - mask) * torch.finfo(dtype).min
38
+ return mask
39
+
40
+
41
+ def subsequent_chunk_mask(
42
+ size: int,
43
+ chunk_size: int,
44
+ num_left_chunks: int = -1,
45
+ device: torch.device = torch.device("cpu"),
46
+ ) -> torch.Tensor:
47
+ """Create mask for subsequent steps (size, size) with chunk size,
48
+ this is for streaming encoder
49
+
50
+ Args:
51
+ size (int): size of mask
52
+ chunk_size (int): size of chunk
53
+ num_left_chunks (int): number of left chunks
54
+ <0: use full chunk
55
+ >=0: use num_left_chunks
56
+ device (torch.device): "cpu" or "cuda" or torch.Tensor.device
57
+
58
+ Returns:
59
+ torch.Tensor: mask
60
+
61
+ Examples:
62
+ >>> subsequent_chunk_mask(4, 2)
63
+ [[1, 1, 0, 0],
64
+ [1, 1, 0, 0],
65
+ [1, 1, 1, 1],
66
+ [1, 1, 1, 1]]
67
+ """
68
+ # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
69
+ # actually this is not needed after we have inference cache implemented, will remove it later
70
+ pos_idx = torch.arange(size, device=device)
71
+ block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
72
+ ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
73
+ return ret
74
+
75
+ def subsequent_mask(
76
+ size: int,
77
+ device: torch.device = torch.device("cpu"),
78
+ ) -> torch.Tensor:
79
+ """Create mask for subsequent steps (size, size).
80
+
81
+ This mask is used only in decoder which works in an auto-regressive mode.
82
+ This means the current step could only do attention with its left steps.
83
+
84
+ In the encoder, full attention is used when streaming is not necessary and
85
+ the sequence is not long. In this case, no attention mask is needed.
86
+
87
+ When streaming is needed, chunk-based attention is used in the encoder. See
88
+ subsequent_chunk_mask for the chunk-based attention mask.
89
+
90
+ Args:
91
+ size (int): size of mask
92
+ str device (str): "cpu" or "cuda" or torch.Tensor.device
93
+ dtype (torch.device): result dtype
94
+
95
+ Returns:
96
+ torch.Tensor: mask
97
+
98
+ Examples:
99
+ >>> subsequent_mask(3)
100
+ [[1, 0, 0],
101
+ [1, 1, 0],
102
+ [1, 1, 1]]
103
+ """
104
+ arange = torch.arange(size, device=device)
105
+ mask = arange.expand(size, size)
106
+ arange = arange.unsqueeze(-1)
107
+ mask = mask <= arange
108
+ return mask
109
+
110
+
111
+ def add_optional_chunk_mask(xs: torch.Tensor,
112
+ masks: torch.Tensor,
113
+ use_dynamic_chunk: bool,
114
+ use_dynamic_left_chunk: bool,
115
+ decoding_chunk_size: int,
116
+ static_chunk_size: int,
117
+ num_decoding_left_chunks: int,
118
+ enable_full_context: bool = True):
119
+ """ Apply optional mask for encoder.
120
+
121
+ Args:
122
+ xs (torch.Tensor): padded input, (B, L, D), L for max length
123
+ mask (torch.Tensor): mask for xs, (B, 1, L)
124
+ use_dynamic_chunk (bool): whether to use dynamic chunk or not
125
+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
126
+ training.
127
+ decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
128
+ 0: default for training, use random dynamic chunk.
129
+ <0: for decoding, use full chunk.
130
+ >0: for decoding, use fixed chunk size as set.
131
+ static_chunk_size (int): chunk size for static chunk training/decoding
132
+ if it's greater than 0, if use_dynamic_chunk is true,
133
+ this parameter will be ignored
134
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
135
+ the chunk size is decoding_chunk_size.
136
+ >=0: use num_decoding_left_chunks
137
+ <0: use all left chunks
138
+ enable_full_context (bool):
139
+ True: chunk size is either [1, 25] or full context(max_len)
140
+ False: chunk size ~ U[1, 25]
141
+
142
+ Returns:
143
+ torch.Tensor: chunk mask of the input xs.
144
+ """
145
+ # Whether to use chunk mask or not
146
+ if use_dynamic_chunk:
147
+ max_len = xs.size(1)
148
+ if decoding_chunk_size < 0:
149
+ chunk_size = max_len
150
+ num_left_chunks = -1
151
+ elif decoding_chunk_size > 0:
152
+ chunk_size = decoding_chunk_size
153
+ num_left_chunks = num_decoding_left_chunks
154
+ else:
155
+ # chunk size is either [1, 25] or full context(max_len).
156
+ # Since we use 4 times subsampling and allow up to 1s(100 frames)
157
+ # delay, the maximum frame is 100 / 4 = 25.
158
+ chunk_size = torch.randint(1, max_len, (1, )).item()
159
+ num_left_chunks = -1
160
+ if chunk_size > max_len // 2 and enable_full_context:
161
+ chunk_size = max_len
162
+ else:
163
+ chunk_size = chunk_size % 25 + 1
164
+ if use_dynamic_left_chunk:
165
+ max_left_chunks = (max_len - 1) // chunk_size
166
+ num_left_chunks = torch.randint(0, max_left_chunks,
167
+ (1, )).item()
168
+ chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
169
+ num_left_chunks,
170
+ xs.device) # (L, L)
171
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
172
+ chunk_masks = masks & chunk_masks # (B, L, L)
173
+ elif static_chunk_size > 0:
174
+ num_left_chunks = num_decoding_left_chunks
175
+ chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
176
+ num_left_chunks,
177
+ xs.device) # (L, L)
178
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
179
+ chunk_masks = masks & chunk_masks # (B, L, L)
180
+ else:
181
+ chunk_masks = masks
182
+ return chunk_masks
183
+
184
+
185
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
186
+ """Make mask tensor containing indices of padded part.
187
+
188
+ See description of make_non_pad_mask.
189
+
190
+ Args:
191
+ lengths (torch.Tensor): Batch of lengths (B,).
192
+ Returns:
193
+ torch.Tensor: Mask tensor containing indices of padded part.
194
+
195
+ Examples:
196
+ >>> lengths = [5, 3, 2]
197
+ >>> make_pad_mask(lengths)
198
+ masks = [[0, 0, 0, 0 ,0],
199
+ [0, 0, 0, 1, 1],
200
+ [0, 0, 1, 1, 1]]
201
+ """
202
+ batch_size = lengths.size(0)
203
+ max_len = max_len if max_len > 0 else lengths.max().item()
204
+ seq_range = torch.arange(0,
205
+ max_len,
206
+ dtype=torch.int64,
207
+ device=lengths.device)
208
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
209
+ seq_length_expand = lengths.unsqueeze(-1)
210
+ mask = seq_range_expand >= seq_length_expand
211
+ return mask
212
+
213
+ # Causal
214
+ class Transpose(torch.nn.Module):
215
+ def __init__(self, dim0: int, dim1: int):
216
+ super().__init__()
217
+ self.dim0 = dim0
218
+ self.dim1 = dim1
219
+
220
+ def forward(self, x: torch.Tensor):
221
+ x = torch.transpose(x, self.dim0, self.dim1)
222
+ return x
223
+
224
+ class CausalBlock1D(Block1D):
225
+ def __init__(self, dim: int, dim_out: int):
226
+ super(CausalBlock1D, self).__init__(dim, dim_out)
227
+ self.block = torch.nn.Sequential(
228
+ CausalConv1d(dim, dim_out, 3),
229
+ Transpose(1, 2),
230
+ nn.LayerNorm(dim_out),
231
+ Transpose(1, 2),
232
+ nn.Mish(),
233
+ )
234
+
235
+ def forward(self, x: torch.Tensor, mask: torch.Tensor):
236
+ output = self.block(x * mask)
237
+ return output * mask
238
+
239
+ class CausalResnetBlock1D(ResnetBlock1D):
240
+ def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
241
+ super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
242
+ self.block1 = CausalBlock1D(dim, dim_out)
243
+ self.block2 = CausalBlock1D(dim_out, dim_out)
244
+
245
+ class CausalConv1d(torch.nn.Conv1d):
246
+ def __init__(
247
+ self,
248
+ in_channels: int,
249
+ out_channels: int,
250
+ kernel_size: int,
251
+ stride: int = 1,
252
+ dilation: int = 1,
253
+ groups: int = 1,
254
+ bias: bool = True,
255
+ padding_mode: str = 'zeros',
256
+ device=None,
257
+ dtype=None
258
+ ) -> None:
259
+ super(CausalConv1d, self).__init__(in_channels, out_channels,
260
+ kernel_size, stride,
261
+ padding=0, dilation=dilation,
262
+ groups=groups, bias=bias,
263
+ padding_mode=padding_mode,
264
+ device=device, dtype=dtype)
265
+ assert stride == 1
266
+ self.causal_padding = (kernel_size - 1, 0)
267
+
268
+ def forward(self, x: torch.Tensor):
269
+ x = F.pad(x, self.causal_padding)
270
+ x = super(CausalConv1d, self).forward(x)
271
+ return x
272
+
273
+
274
+ class BASECFM(torch.nn.Module, ABC):
275
+ def __init__(
276
+ self,
277
+ n_feats,
278
+ cfm_params,
279
+ n_spks=1,
280
+ spk_emb_dim=128,
281
+ ):
282
+ super().__init__()
283
+ self.n_feats = n_feats
284
+ self.n_spks = n_spks
285
+ self.spk_emb_dim = spk_emb_dim
286
+ self.solver = cfm_params.solver
287
+ if hasattr(cfm_params, "sigma_min"):
288
+ self.sigma_min = cfm_params.sigma_min
289
+ else:
290
+ self.sigma_min = 1e-4
291
+
292
+ self.estimator = None
293
+
294
+ @torch.inference_mode()
295
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
296
+ """Forward diffusion
297
+
298
+ Args:
299
+ mu (torch.Tensor): output of encoder
300
+ shape: (batch_size, n_feats, mel_timesteps)
301
+ mask (torch.Tensor): output_mask
302
+ shape: (batch_size, 1, mel_timesteps)
303
+ n_timesteps (int): number of diffusion steps
304
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
305
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
306
+ shape: (batch_size, spk_emb_dim)
307
+ cond: Not used but kept for future purposes
308
+
309
+ Returns:
310
+ sample: generated mel-spectrogram
311
+ shape: (batch_size, n_feats, mel_timesteps)
312
+ """
313
+ z = torch.randn_like(mu) * temperature
314
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
315
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
316
+
317
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
318
+ """
319
+ Fixed euler solver for ODEs.
320
+ Args:
321
+ x (torch.Tensor): random noise
322
+ t_span (torch.Tensor): n_timesteps interpolated
323
+ shape: (n_timesteps + 1,)
324
+ mu (torch.Tensor): output of encoder
325
+ shape: (batch_size, n_feats, mel_timesteps)
326
+ mask (torch.Tensor): output_mask
327
+ shape: (batch_size, 1, mel_timesteps)
328
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
329
+ shape: (batch_size, spk_emb_dim)
330
+ cond: Not used but kept for future purposes
331
+ """
332
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
333
+
334
+ # Keep all intermediate states so they can be inspected or plotted when debugging;
+ # a return_all_steps flag may be added in the future.
336
+ sol = []
337
+
338
+ for step in range(1, len(t_span)):
339
+ dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
340
+
341
+ x = x + dt * dphi_dt
342
+ t = t + dt
343
+ sol.append(x)
344
+ if step < len(t_span) - 1:
345
+ dt = t_span[step + 1] - t
346
+
347
+ return sol[-1]
348
+
349
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
350
+ """Computes diffusion loss
351
+
352
+ Args:
353
+ x1 (torch.Tensor): Target
354
+ shape: (batch_size, n_feats, mel_timesteps)
355
+ mask (torch.Tensor): target mask
356
+ shape: (batch_size, 1, mel_timesteps)
357
+ mu (torch.Tensor): output of encoder
358
+ shape: (batch_size, n_feats, mel_timesteps)
359
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
360
+ shape: (batch_size, spk_emb_dim)
361
+
362
+ Returns:
363
+ loss: conditional flow matching loss
364
+ y: conditional flow
365
+ shape: (batch_size, n_feats, mel_timesteps)
366
+ """
367
+ b, _, t = mu.shape
368
+
369
+ # random timestep
370
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
371
+ # sample noise p(x_0)
372
+ z = torch.randn_like(x1)
373
+
374
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
375
+ u = x1 - (1 - self.sigma_min) * z
376
+
377
+ loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
378
+ torch.sum(mask) * u.shape[1]
379
+ )
380
+ return loss, y
381
+
382
+
383
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
384
+ """Make mask tensor containing indices of padded part.
385
+
386
+ See description of make_non_pad_mask.
387
+
388
+ Args:
389
+ lengths (torch.Tensor): Batch of lengths (B,).
390
+ Returns:
391
+ torch.Tensor: Mask tensor containing indices of padded part.
392
+
393
+ Examples:
394
+ >>> lengths = [5, 3, 2]
395
+ >>> make_pad_mask(lengths)
396
+ masks = [[0, 0, 0, 0 ,0],
397
+ [0, 0, 0, 1, 1],
398
+ [0, 0, 1, 1, 1]]
399
+ """
400
+ batch_size = lengths.size(0)
401
+ max_len = max_len if max_len > 0 else lengths.max().item()
402
+ seq_range = torch.arange(0,
403
+ max_len,
404
+ dtype=torch.int64,
405
+ device=lengths.device)
406
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
407
+ seq_length_expand = lengths.unsqueeze(-1)
408
+ mask = seq_range_expand >= seq_length_expand
409
+ return mask
410
+
411
+
412
+ class ConditionalDecoder(nn.Module):
413
+ def __init__(
414
+ self,
415
+ in_channels,
416
+ out_channels,
417
+ causal=False,
418
+ channels=(256, 256),
419
+ dropout=0.05,
420
+ attention_head_dim=64,
421
+ n_blocks=1,
422
+ num_mid_blocks=2,
423
+ num_heads=4,
424
+ act_fn="snake",
425
+ gradient_checkpointing=True,
426
+ ):
427
+ """
428
+ This decoder requires an input with the same shape as the target. If your text content
+ is shorter or longer than the output, please resample it before feeding it to the decoder.
430
+ """
431
+ super().__init__()
432
+ channels = tuple(channels)
433
+ self.in_channels = in_channels
434
+ self.out_channels = out_channels
435
+ self.causal = causal
436
+ self.static_chunk_size = 2 * 25 * 2 # 2*input_frame_rate*token_mel_ratio
437
+ self.gradient_checkpointing = gradient_checkpointing
438
+
439
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
440
+ time_embed_dim = channels[0] * 4
441
+ self.time_mlp = TimestepEmbedding(
442
+ in_channels=in_channels,
443
+ time_embed_dim=time_embed_dim,
444
+ act_fn="silu",
445
+ )
446
+ self.down_blocks = nn.ModuleList([])
447
+ self.mid_blocks = nn.ModuleList([])
448
+ self.up_blocks = nn.ModuleList([])
449
+
450
+ output_channel = in_channels
451
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
452
+ input_channel = output_channel
453
+ output_channel = channels[i]
454
+ is_last = i == len(channels) - 1
455
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
456
+ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
457
+ transformer_blocks = nn.ModuleList(
458
+ [
459
+ BasicTransformerBlock(
460
+ dim=output_channel,
461
+ num_attention_heads=num_heads,
462
+ attention_head_dim=attention_head_dim,
463
+ dropout=dropout,
464
+ activation_fn=act_fn,
465
+ )
466
+ for _ in range(n_blocks)
467
+ ]
468
+ )
469
+ downsample = (
470
+ Downsample1D(output_channel) if not is_last else
471
+ CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
472
+ )
473
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
474
+
475
+ for _ in range(num_mid_blocks):
476
+ input_channel = channels[-1]
477
+ out_channels = channels[-1]
478
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
479
+ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
480
+
481
+ transformer_blocks = nn.ModuleList(
482
+ [
483
+ BasicTransformerBlock(
484
+ dim=output_channel,
485
+ num_attention_heads=num_heads,
486
+ attention_head_dim=attention_head_dim,
487
+ dropout=dropout,
488
+ activation_fn=act_fn,
489
+ )
490
+ for _ in range(n_blocks)
491
+ ]
492
+ )
493
+
494
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
495
+
496
+ channels = channels[::-1] + (channels[0],)
497
+ for i in range(len(channels) - 1):
498
+ input_channel = channels[i] * 2
499
+ output_channel = channels[i + 1]
500
+ is_last = i == len(channels) - 2
501
+ resnet = CausalResnetBlock1D(
502
+ dim=input_channel,
503
+ dim_out=output_channel,
504
+ time_emb_dim=time_embed_dim,
505
+ ) if self.causal else ResnetBlock1D(
506
+ dim=input_channel,
507
+ dim_out=output_channel,
508
+ time_emb_dim=time_embed_dim,
509
+ )
510
+ transformer_blocks = nn.ModuleList(
511
+ [
512
+ BasicTransformerBlock(
513
+ dim=output_channel,
514
+ num_attention_heads=num_heads,
515
+ attention_head_dim=attention_head_dim,
516
+ dropout=dropout,
517
+ activation_fn=act_fn,
518
+ )
519
+ for _ in range(n_blocks)
520
+ ]
521
+ )
522
+ upsample = (
523
+ Upsample1D(output_channel, use_conv_transpose=True)
524
+ if not is_last
525
+ else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
526
+ )
527
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
528
+ self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
529
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
530
+ self.initialize_weights()
531
+
532
+ def initialize_weights(self):
533
+ for m in self.modules():
534
+ if isinstance(m, nn.Conv1d):
535
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
536
+ if m.bias is not None:
537
+ nn.init.constant_(m.bias, 0)
538
+ elif isinstance(m, nn.GroupNorm):
539
+ nn.init.constant_(m.weight, 1)
540
+ nn.init.constant_(m.bias, 0)
541
+ elif isinstance(m, nn.Linear):
542
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
543
+ if m.bias is not None:
544
+ nn.init.constant_(m.bias, 0)
545
+
546
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
547
+ """Forward pass of the UNet1DConditional model.
548
+
549
+ Args:
550
+ x (torch.Tensor): shape (batch_size, in_channels, time)
551
+ mask (torch.Tensor): shape (batch_size, 1, time)
+ mu (torch.Tensor): conditioning from the encoder, same shape as x
+ t (torch.Tensor): diffusion timestep, shape (batch_size)
+ spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
+ cond (torch.Tensor, optional): placeholder for future use. Defaults to None.
+
+ Returns:
+ torch.Tensor: output of shape (batch_size, out_channels, time)
562
+ """
563
+ t = self.time_embeddings(t)
564
+ t = t.to(x.dtype)
565
+ t = self.time_mlp(t)
566
+ x = pack([x, mu], "b * t")[0]
567
+ mask = mask.to(x.dtype)
568
+ if spks is not None:
569
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
570
+ x = pack([x, spks], "b * t")[0]
571
+ if cond is not None:
572
+ x = pack([x, cond], "b * t")[0]
573
+
574
+ hiddens = []
575
+ masks = [mask]
576
+ for resnet, transformer_blocks, downsample in self.down_blocks:
577
+ mask_down = masks[-1]
578
+ x = resnet(x, mask_down, t)
579
+ x = rearrange(x, "b c t -> b t c").contiguous()
580
+ # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
581
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
582
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
583
+ for transformer_block in transformer_blocks:
584
+ if self.gradient_checkpointing and self.training:
585
+ def create_custom_forward(module):
586
+ def custom_forward(*inputs):
587
+ return module(*inputs)
588
+ return custom_forward
589
+ x = torch.utils.checkpoint.checkpoint(
590
+ create_custom_forward(transformer_block),
591
+ x,
592
+ attn_mask,
593
+ t,
594
+ )
595
+ else:
596
+ x = transformer_block(
597
+ hidden_states=x,
598
+ attention_mask=attn_mask,
599
+ timestep=t,
600
+ )
601
+ x = rearrange(x, "b t c -> b c t").contiguous()
602
+ hiddens.append(x) # Save hidden states for skip connections
603
+ x = downsample(x * mask_down)
604
+ masks.append(mask_down[:, :, ::2])
605
+ masks = masks[:-1]
606
+ mask_mid = masks[-1]
607
+
608
+ for resnet, transformer_blocks in self.mid_blocks:
609
+ x = resnet(x, mask_mid, t)
610
+ x = rearrange(x, "b c t -> b t c").contiguous()
611
+ # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
612
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
613
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
614
+ for transformer_block in transformer_blocks:
615
+ if self.gradient_checkpointing and self.training:
616
+ def create_custom_forward(module):
617
+ def custom_forward(*inputs):
618
+ return module(*inputs)
619
+ return custom_forward
620
+ x = torch.utils.checkpoint.checkpoint(
621
+ create_custom_forward(transformer_block),
622
+ x,
623
+ attn_mask,
624
+ t,
625
+ )
626
+ else:
627
+ x = transformer_block(
628
+ hidden_states=x,
629
+ attention_mask=attn_mask,
630
+ timestep=t,
631
+ )
632
+ x = rearrange(x, "b t c -> b c t").contiguous()
633
+
634
+ for resnet, transformer_blocks, upsample in self.up_blocks:
635
+ mask_up = masks.pop()
636
+ skip = hiddens.pop()
637
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
638
+ x = resnet(x, mask_up, t)
639
+ x = rearrange(x, "b c t -> b t c").contiguous()
640
+ # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
641
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
642
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
643
+ for transformer_block in transformer_blocks:
644
+ if self.gradient_checkpointing and self.training:
645
+ def create_custom_forward(module):
646
+ def custom_forward(*inputs):
647
+ return module(*inputs)
648
+ return custom_forward
649
+ x = torch.utils.checkpoint.checkpoint(
650
+ create_custom_forward(transformer_block),
651
+ x,
652
+ attn_mask,
653
+ t,
654
+ )
655
+ else:
656
+ x = transformer_block(
657
+ hidden_states=x,
658
+ attention_mask=attn_mask,
659
+ timestep=t,
660
+ )
661
+ x = rearrange(x, "b t c -> b c t").contiguous()
662
+ x = upsample(x * mask_up)
663
+ x = self.final_block(x, mask_up)
664
+ output = self.final_proj(x * mask_up)
665
+ return output * mask
666
+
667
+
668
+ class ConditionalCFM(BASECFM):
669
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64):
670
+ super().__init__(
671
+ n_feats=in_channels,
672
+ cfm_params=cfm_params,
673
+ n_spks=n_spks,
674
+ spk_emb_dim=spk_emb_dim,
675
+ )
676
+ self.t_scheduler = cfm_params.t_scheduler
677
+ self.training_cfg_rate = cfm_params.training_cfg_rate
678
+ self.inference_cfg_rate = cfm_params.inference_cfg_rate
679
+
680
+ @torch.inference_mode()
681
+ def forward(self, estimator, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
682
+ """Forward diffusion
683
+
684
+ Args:
685
+ mu (torch.Tensor): output of encoder
686
+ shape: (batch_size, n_feats, mel_timesteps)
687
+ mask (torch.Tensor): output_mask
688
+ shape: (batch_size, 1, mel_timesteps)
689
+ n_timesteps (int): number of diffusion steps
690
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
691
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
692
+ shape: (batch_size, spk_emb_dim)
693
+ cond: Not used but kept for future purposes
694
+
695
+ Returns:
696
+ sample: generated mel-spectrogram
697
+ shape: (batch_size, n_feats, mel_timesteps)
698
+ """
699
+ z = torch.randn_like(mu) * temperature
700
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
701
+ if self.t_scheduler == 'cosine':
702
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
703
+ return self.solve_euler(estimator, z, t_span=t_span.to(mu.dtype), mu=mu, mask=mask, spks=spks, cond=cond)
704
+
705
+ def solve_euler(self, estimator, x, t_span, mu, mask, spks, cond):
706
+ """
707
+ Fixed euler solver for ODEs.
708
+ Args:
709
+ x (torch.Tensor): random noise
710
+ t_span (torch.Tensor): n_timesteps interpolated
711
+ shape: (n_timesteps + 1,)
712
+ mu (torch.Tensor): output of encoder
713
+ shape: (batch_size, n_feats, mel_timesteps)
714
+ mask (torch.Tensor): output_mask
715
+ shape: (batch_size, 1, mel_timesteps)
716
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
717
+ shape: (batch_size, spk_emb_dim)
718
+ cond: Not used but kept for future purposes
719
+ """
720
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
721
+
722
+ # Keep all intermediate states so they can be inspected or plotted when debugging;
+ # a return_all_steps flag may be added in the future.
724
+ sol = []
725
+
726
+ for step in range(1, len(t_span)):
727
+ dphi_dt = estimator(x, mask, mu, t, spks, cond)
728
+ # Classifier-Free Guidance inference introduced in VoiceBox
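+ # Extrapolate away from the unconditional prediction:
+ # dphi_dt = (1 + r) * cond_pred - r * uncond_pred, with r = inference_cfg_rate.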
729
+ if self.inference_cfg_rate > 0:
730
+ cfg_dphi_dt = estimator(
731
+ x, mask,
732
+ torch.zeros_like(mu), t,
733
+ torch.zeros_like(spks) if spks is not None else None,
734
+ cond=cond
735
+ )
736
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
737
+ self.inference_cfg_rate * cfg_dphi_dt)
738
+ x = x + dt * dphi_dt
739
+ t = t + dt
740
+ sol.append(x)
741
+ if step < len(t_span) - 1:
742
+ dt = t_span[step + 1] - t
743
+
744
+ return sol[-1]
745
+
746
+ def compute_loss(self, estimator, x1, mask, mu, spks=None, cond=None):
747
+ """Computes diffusion loss
748
+
749
+ Args:
750
+ x1 (torch.Tensor): Target
751
+ shape: (batch_size, n_feats, mel_timesteps)
752
+ mask (torch.Tensor): target mask
753
+ shape: (batch_size, 1, mel_timesteps)
754
+ mu (torch.Tensor): output of encoder
755
+ shape: (batch_size, n_feats, mel_timesteps)
756
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
757
+ shape: (batch_size, spk_emb_dim)
758
+
759
+ Returns:
760
+ loss: conditional flow matching loss
761
+ y: conditional flow
762
+ shape: (batch_size, n_feats, mel_timesteps)
763
+ """
764
+ org_dtype = x1.dtype
765
+
766
+ b, _, t = mu.shape
767
+ # random timestep
768
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
769
+ if self.t_scheduler == 'cosine':
770
+ t = 1 - torch.cos(t * 0.5 * torch.pi)
771
+ # sample noise p(x_0)
772
+ z = torch.randn_like(x1)
773
+
774
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
775
+ u = x1 - (1 - self.sigma_min) * z
776
+
777
+ # during training, we randomly drop condition to trade off mode coverage and sample fidelity
778
+ if self.training_cfg_rate > 0:
779
+ cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
780
+ mu = mu * cfg_mask.view(-1, 1, 1)
781
+ if spks is not None:
782
+ spks = spks * cfg_mask.view(-1, 1)
783
+ if cond is not None:
784
+ cond = cond * cfg_mask.view(-1, 1, 1)
785
+
786
+ pred = estimator(y, mask, mu, t.squeeze(), spks, cond)
787
+ pred = pred.float()
788
+ u = u.float()
789
+ loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
790
+ loss = loss.to(org_dtype)
791
+ return loss, y
generation_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "bos_token_id": 151643,
3
+ "eos_token_id": 151643,
4
+ "max_new_tokens": 2048,
5
+ "transformers_version": "4.45.0.dev0"
6
+ }
generation_utils.py ADDED
@@ -0,0 +1,83 @@
1
+ from typing import List
2
+ from queue import Queue
3
+
4
+ import torch
5
+
6
+
7
+ def build_chat_input(model, tokenizer, messages: List[dict], max_new_tokens: int=0):
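+ # Split the conversation into a system prompt and user-led rounds, then pack rounds
+ # from newest to oldest until the token budget (model_max_length - max_new_tokens)
+ # is reached; finally append the assistant token if the last message is not from
+ # the assistant and truncate from the left.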
8
+ def _parse_messages(messages, split_role="user"):
9
+ system, rounds = "", []
10
+ round = []
11
+ for i, message in enumerate(messages):
12
+ if message["role"] == "system":
13
+ assert i == 0
14
+ system = message["content"]
15
+ continue
16
+ if message["role"] == split_role and round:
17
+ rounds.append(round)
18
+ round = []
19
+ round.append(message)
20
+ if round:
21
+ rounds.append(round)
22
+ return system, rounds
23
+
24
+ max_new_tokens = max_new_tokens or model.generation_config.max_new_tokens
25
+ max_input_tokens = model.config.model_max_length - max_new_tokens
26
+ system, rounds = _parse_messages(messages, split_role="user")
27
+ system_tokens = tokenizer.encode(system)
28
+ max_history_tokens = max_input_tokens - len(system_tokens)
29
+
30
+ history_tokens = []
31
+ for round in rounds[::-1]:
32
+ round_tokens = []
33
+ for message in round:
34
+ if message["role"] == "user":
35
+ round_tokens.append(model.generation_config.user_token_id)
36
+ else:
37
+ round_tokens.append(model.generation_config.assistant_token_id)
38
+ round_tokens.extend(tokenizer.encode(message["content"]))
39
+ if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
40
+ history_tokens = round_tokens + history_tokens # concat left
41
+ if len(history_tokens) < max_history_tokens:
42
+ continue
43
+ break
44
+
45
+ input_tokens = system_tokens + history_tokens
46
+ if messages[-1]["role"] != "assistant":
47
+ input_tokens.append(model.generation_config.assistant_token_id)
48
+ input_tokens = input_tokens[-max_input_tokens:] # truncate left
49
+ return torch.LongTensor([input_tokens]).to(model.device)
50
+
51
+
52
+ class TextIterStreamer:
53
+ def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
54
+ self.tokenizer = tokenizer
55
+ self.skip_prompt = skip_prompt
56
+ self.skip_special_tokens = skip_special_tokens
57
+ self.tokens = []
58
+ self.text_queue = Queue()
59
+ self.next_tokens_are_prompt = True
60
+
61
+ def put(self, value):
62
+ if self.skip_prompt and self.next_tokens_are_prompt:
63
+ self.next_tokens_are_prompt = False
64
+ else:
65
+ if len(value.shape) > 1:
66
+ value = value[0]
67
+ self.tokens.extend(value.tolist())
68
+ self.text_queue.put(
69
+ self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))
70
+
71
+ def end(self):
72
+ self.text_queue.put(None)
73
+
74
+ def __iter__(self):
75
+ return self
76
+
77
+ def __next__(self):
78
+ value = self.text_queue.get()
79
+ if value is None:
80
+ raise StopIteration()
81
+ else:
82
+ return value
83
+
matcha_components.py ADDED
@@ -0,0 +1,189 @@
1
+ # Modified from Matcha-TTS https://github.com/shivammehta25/Matcha-TTS
2
+ """
3
+ MIT License
4
+
5
+ Copyright (c) 2023 Shivam Mehta
6
+
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in all
15
+ copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ SOFTWARE.
24
+ """
25
+
26
+ import math
27
+ from typing import Optional
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+
33
+ from diffusers.models.activations import get_activation
34
+
35
+
36
+ class SinusoidalPosEmb(torch.nn.Module):
37
+ def __init__(self, dim):
38
+ super().__init__()
39
+ self.dim = dim
40
+ assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
41
+
42
+ def forward(self, x, scale=1000):
43
+ if x.ndim < 1:
44
+ x = x.unsqueeze(0)
45
+ device = x.device
46
+ half_dim = self.dim // 2
47
+ emb = math.log(10000) / (half_dim - 1)
48
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
49
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
50
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
51
+ return emb
52
+
53
+
54
+ class Block1D(torch.nn.Module):
55
+ def __init__(self, dim, dim_out, groups=8):
56
+ super().__init__()
57
+ self.block = torch.nn.Sequential(
58
+ torch.nn.Conv1d(dim, dim_out, 3, padding=1),
59
+ torch.nn.GroupNorm(groups, dim_out),
60
+ nn.Mish(),
61
+ )
62
+
63
+ def forward(self, x, mask):
64
+ output = self.block(x * mask)
65
+ return output * mask
66
+
67
+
68
+ class ResnetBlock1D(torch.nn.Module):
69
+ def __init__(self, dim, dim_out, time_emb_dim, groups=8):
70
+ super().__init__()
71
+ self.mlp = torch.nn.Sequential(
72
+ nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)
73
+ )
74
+
75
+ self.block1 = Block1D(dim, dim_out, groups=groups)
76
+ self.block2 = Block1D(dim_out, dim_out, groups=groups)
77
+
78
+ self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
79
+
80
+ def forward(self, x, mask, time_emb):
81
+ h = self.block1(x, mask)
82
+ h += self.mlp(time_emb).unsqueeze(-1)
83
+ h = self.block2(h, mask)
84
+ output = h + self.res_conv(x * mask)
85
+ return output
86
+
87
+
88
+ class Downsample1D(nn.Module):
89
+ def __init__(self, dim):
90
+ super().__init__()
91
+ self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
92
+
93
+ def forward(self, x):
94
+ return self.conv(x)
95
+
96
+
97
+ class TimestepEmbedding(nn.Module):
98
+ def __init__(
99
+ self,
100
+ in_channels: int,
101
+ time_embed_dim: int,
102
+ act_fn: str = "silu",
103
+ out_dim: int = None,
104
+ post_act_fn: Optional[str] = None,
105
+ cond_proj_dim=None,
106
+ ):
107
+ super().__init__()
108
+
109
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
110
+
111
+ if cond_proj_dim is not None:
112
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
113
+ else:
114
+ self.cond_proj = None
115
+
116
+ self.act = get_activation(act_fn)
117
+
118
+ if out_dim is not None:
119
+ time_embed_dim_out = out_dim
120
+ else:
121
+ time_embed_dim_out = time_embed_dim
122
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
123
+
124
+ if post_act_fn is None:
125
+ self.post_act = None
126
+ else:
127
+ self.post_act = get_activation(post_act_fn)
128
+
129
+ def forward(self, sample, condition=None):
130
+ if condition is not None:
131
+ sample = sample + self.cond_proj(condition)
132
+ sample = self.linear_1(sample)
133
+
134
+ if self.act is not None:
135
+ sample = self.act(sample)
136
+
137
+ sample = self.linear_2(sample)
138
+
139
+ if self.post_act is not None:
140
+ sample = self.post_act(sample)
141
+ return sample
142
+
143
+
144
+ class Upsample1D(nn.Module):
145
+ """A 1D upsampling layer with an optional convolution.
146
+
147
+ Parameters:
148
+ channels (`int`):
149
+ number of channels in the inputs and outputs.
150
+ use_conv (`bool`, default `False`):
151
+ option to use a convolution.
152
+ use_conv_transpose (`bool`, default `False`):
153
+ option to use a convolution transpose.
154
+ out_channels (`int`, optional):
155
+ number of output channels. Defaults to `channels`.
156
+ """
157
+
158
+ def __init__(
159
+ self,
160
+ channels,
161
+ use_conv=False,
162
+ use_conv_transpose=True,
163
+ out_channels=None,
164
+ name="conv",
165
+ ):
166
+ super().__init__()
167
+ self.channels = channels
168
+ self.out_channels = out_channels or channels
169
+ self.use_conv = use_conv
170
+ self.use_conv_transpose = use_conv_transpose
171
+ self.name = name
172
+
173
+ self.conv = None
174
+ if use_conv_transpose:
175
+ self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
176
+ elif use_conv:
177
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
178
+
179
+ def forward(self, inputs):
180
+ assert inputs.shape[1] == self.channels
181
+ if self.use_conv_transpose:
182
+ return self.conv(inputs)
183
+
184
+ outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
185
+
186
+ if self.use_conv:
187
+ outputs = self.conv(outputs)
188
+
189
+ return outputs
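The classes above are the 1-D U-Net building blocks used by the flow-matching decoder. The snippet below is only a minimal shape-check sketch: the tensor sizes and the `matcha_components` import path are illustrative assumptions, not values taken from this model's configuration.

```python
# Hypothetical shape check for the components above (sizes and import path are
# assumptions for illustration; diffusers must be installed for TimestepEmbedding).
import torch
from matcha_components import (
    SinusoidalPosEmb, TimestepEmbedding, Downsample1D, Upsample1D,
)

dim = 256
t = torch.rand(4)                                   # one timestep per batch element
t_emb = SinusoidalPosEmb(dim)(t)                    # -> (4, 256)
t_emb = TimestepEmbedding(in_channels=dim, time_embed_dim=dim)(t_emb)  # -> (4, 256)

x = torch.randn(4, dim, 100)                        # (batch, channels, frames)
down = Downsample1D(dim)(x)                         # stride-2 conv: (4, 256, 50)
up = Upsample1D(dim)(down)                          # transposed conv: (4, 256, 100)
print(t_emb.shape, down.shape, up.shape)
```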
matcha_feat.py ADDED
@@ -0,0 +1,107 @@
1
+ # Modified from Matcha-TTS https://github.com/shivammehta25/Matcha-TTS
2
+ """
3
+ MIT License
4
+
5
+ Copyright (c) 2023 Shivam Mehta
6
+
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in all
15
+ copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ SOFTWARE.
24
+ """
25
+
26
+ import numpy as np
27
+ import torch
28
+ import torch.utils.data
29
+ from librosa.filters import mel as librosa_mel_fn
30
+ from scipy.io.wavfile import read
31
+
32
+ MAX_WAV_VALUE = 32768.0
33
+
34
+
35
+ def load_wav(full_path):
36
+ sampling_rate, data = read(full_path)
37
+ return data, sampling_rate
38
+
39
+
40
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
41
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
42
+
43
+
44
+ def dynamic_range_decompression(x, C=1):
45
+ return np.exp(x) / C
46
+
47
+
48
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
49
+ return torch.log(torch.clamp(x, min=clip_val) * C)
50
+
51
+
52
+ def dynamic_range_decompression_torch(x, C=1):
53
+ return torch.exp(x) / C
54
+
55
+
56
+ def spectral_normalize_torch(magnitudes):
57
+ output = dynamic_range_compression_torch(magnitudes)
58
+ return output
59
+
60
+
61
+ def spectral_de_normalize_torch(magnitudes):
62
+ output = dynamic_range_decompression_torch(magnitudes)
63
+ return output
64
+
65
+
66
+ mel_basis = {}
67
+ hann_window = {}
68
+
69
+
70
+ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
71
+ if torch.min(y) < -1.0:
72
+ print("min value is ", torch.min(y))
73
+ if torch.max(y) > 1.0:
74
+ print("max value is ", torch.max(y))
75
+
76
+ global mel_basis, hann_window # pylint: disable=global-statement
77
+ if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
78
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
79
+ mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
80
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
81
+
82
+ y = torch.nn.functional.pad(
83
+ y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
84
+ )
85
+ y = y.squeeze(1)
86
+
87
+ spec = torch.view_as_real(
88
+ torch.stft(
89
+ y,
90
+ n_fft,
91
+ hop_length=hop_size,
92
+ win_length=win_size,
93
+ window=hann_window[str(y.device)],
94
+ center=center,
95
+ pad_mode="reflect",
96
+ normalized=False,
97
+ onesided=True,
98
+ return_complex=True,
99
+ )
100
+ )
101
+
102
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
103
+
104
+ spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
105
+ spec = spectral_normalize_torch(spec)
106
+
107
+ return spec
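`mel_spectrogram` above caches the mel filterbank and Hann window per device, pads the waveform reflectively, and returns a log-compressed mel spectrogram. Below is a hedged usage sketch; the waveform and STFT/mel parameters are placeholders chosen for illustration, not necessarily the values this model's audio config uses.

```python
# Illustrative call to mel_spectrogram (parameters are placeholders; the real
# values come from the model's audio configuration, not from this sketch).
import torch
from matcha_feat import mel_spectrogram  # assumed local import path

wav = torch.rand(1, 16000) * 2 - 1       # 1 s of fake audio in [-1, 1], shape (B, T)
mel = mel_spectrogram(
    wav,
    n_fft=1024, num_mels=80, sampling_rate=16000,
    hop_size=256, win_size=1024, fmin=0, fmax=8000, center=False,
)
print(mel.shape)                         # (B, num_mels, frames), log-compressed
```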
matcha_transformer.py ADDED
@@ -0,0 +1,480 @@
1
+ # Modified from Matcha-TTS https://github.com/shivammehta25/Matcha-TTS
2
+ """
3
+ MIT License
4
+
5
+ Copyright (c) 2023 Shivam Mehta
6
+
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in all
15
+ copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ SOFTWARE.
24
+ """
25
+
26
+ from typing import Any, Dict, Optional
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+ from diffusers.models.attention import (
31
+ GEGLU,
32
+ GELU,
33
+ AdaLayerNorm,
34
+ AdaLayerNormZero,
35
+ ApproximateGELU,
36
+ )
37
+ from diffusers.models.attention_processor import Attention
38
+ from diffusers.models.lora import LoRACompatibleLinear
39
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
40
+
41
+ import torch.nn.functional as F
42
+ from flash_attn import flash_attn_varlen_func
43
+
44
+
45
+ def get_sequence_mask(inputs, inputs_length):
46
+ if inputs.dim() == 3:
47
+ bsz, tgt_len, _ = inputs.size()
48
+ else:
49
+ bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
50
+ sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
51
+ sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(
52
+ bsz, tgt_len, 1
53
+ )
54
+ unpacking_index = (
55
+ torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1
56
+ ) # convert the mask into flat indices
57
+ return sequence_mask, unpacking_index
58
+
59
+
60
+ class OmniWhisperAttention(nn.Module):
61
+ def __init__(self, embed_dim, num_heads, causal=False):
62
+ super().__init__()
63
+ self.embed_dim = embed_dim
64
+ self.num_heads = num_heads
65
+ self.head_dim = embed_dim // num_heads
66
+
67
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
68
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
69
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
70
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
71
+
72
+ self.causal = causal
73
+
74
+ def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
75
+ bsz, _ = hidden_states.size()
76
+
77
+ query_states = self.q_proj(hidden_states).view(
78
+ bsz, self.num_heads, self.head_dim
79
+ )
80
+ key_states = self.k_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
81
+ value_states = self.v_proj(hidden_states).view(
82
+ bsz, self.num_heads, self.head_dim
83
+ )
84
+
85
+ cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(
86
+ torch.int32
87
+ )
88
+ max_seqlen = torch.max(seq_len).to(torch.int32).detach()
89
+ attn_output = flash_attn_varlen_func(
90
+ query_states,
91
+ key_states,
92
+ value_states,
93
+ cu_len,
94
+ cu_len,
95
+ max_seqlen,
96
+ max_seqlen,
97
+ causal=self.causal,
98
+ ) # (bsz * qlen, nheads, headdim)
99
+ attn_output = attn_output.reshape(bsz, self.embed_dim)
100
+ attn_output = self.out_proj(attn_output)
101
+ return attn_output
102
+
103
+
104
+ class SnakeBeta(nn.Module):
105
+ """
106
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
107
+ Shape:
108
+ - Input: (B, C, T)
109
+ - Output: (B, C, T), same shape as the input
110
+ Parameters:
111
+ - alpha - trainable parameter that controls frequency
112
+ - beta - trainable parameter that controls magnitude
113
+ References:
114
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
115
+ https://arxiv.org/abs/2006.08195
116
+ Examples:
117
+ >>> a1 = snakebeta(256)
118
+ >>> x = torch.randn(256)
119
+ >>> x = a1(x)
120
+ """
121
+
122
+ def __init__(
123
+ self,
124
+ in_features,
125
+ out_features,
126
+ alpha=1.0,
127
+ alpha_trainable=True,
128
+ alpha_logscale=True,
129
+ ):
130
+ """
131
+ Initialization.
132
+ INPUT:
133
+ - in_features: shape of the input
134
+ - alpha - trainable parameter that controls frequency
135
+ - beta - trainable parameter that controls magnitude
136
+ alpha is initialized to 1 by default, higher values = higher-frequency.
137
+ beta is initialized to 1 by default, higher values = higher-magnitude.
138
+ alpha will be trained along with the rest of your model.
139
+ """
140
+ super().__init__()
141
+ self.in_features = (
142
+ out_features if isinstance(out_features, list) else [out_features]
143
+ )
144
+ self.proj = LoRACompatibleLinear(in_features, out_features)
145
+
146
+ # initialize alpha
147
+ self.alpha_logscale = alpha_logscale
148
+ if self.alpha_logscale: # log scale alphas initialized to zeros
149
+ self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
150
+ self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
151
+ else: # linear scale alphas initialized to ones
152
+ self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
153
+ self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
154
+
155
+ self.alpha.requires_grad = alpha_trainable
156
+ self.beta.requires_grad = alpha_trainable
157
+
158
+ self.no_div_by_zero = 0.000000001
159
+
160
+ def forward(self, x):
161
+ """
162
+ Forward pass of the function.
163
+ Applies the function to the input elementwise.
164
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
165
+ """
166
+ x = self.proj(x)
167
+ if self.alpha_logscale:
168
+ alpha = torch.exp(self.alpha)
169
+ beta = torch.exp(self.beta)
170
+ else:
171
+ alpha = self.alpha
172
+ beta = self.beta
173
+
174
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
175
+ torch.sin(x * alpha), 2
176
+ )
177
+
178
+ return x
179
+
180
+
181
+ class FeedForward(nn.Module):
182
+ r"""
183
+ A feed-forward layer.
184
+
185
+ Parameters:
186
+ dim (`int`): The number of channels in the input.
187
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
188
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
189
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
190
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
191
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
192
+ """
193
+
194
+ def __init__(
195
+ self,
196
+ dim: int,
197
+ dim_out: Optional[int] = None,
198
+ mult: int = 4,
199
+ dropout: float = 0.0,
200
+ activation_fn: str = "geglu",
201
+ final_dropout: bool = False,
202
+ ):
203
+ super().__init__()
204
+ inner_dim = int(dim * mult)
205
+ dim_out = dim_out if dim_out is not None else dim
206
+
207
+ if activation_fn == "gelu":
208
+ act_fn = GELU(dim, inner_dim)
209
+ if activation_fn == "gelu-approximate":
210
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
211
+ elif activation_fn == "geglu":
212
+ act_fn = GEGLU(dim, inner_dim)
213
+ elif activation_fn == "geglu-approximate":
214
+ act_fn = ApproximateGELU(dim, inner_dim)
215
+ elif activation_fn == "snakebeta":
216
+ act_fn = SnakeBeta(dim, inner_dim)
217
+
218
+ self.net = nn.ModuleList([])
219
+ # project in
220
+ self.net.append(act_fn)
221
+ # project dropout
222
+ self.net.append(nn.Dropout(dropout))
223
+ # project out
224
+ self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
225
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
226
+ if final_dropout:
227
+ self.net.append(nn.Dropout(dropout))
228
+
229
+ def forward(self, hidden_states):
230
+ for module in self.net:
231
+ hidden_states = module(hidden_states)
232
+ return hidden_states
233
+
234
+
235
+ @maybe_allow_in_graph
236
+ class BasicTransformerBlock(nn.Module):
237
+ r"""
238
+ A basic Transformer block.
239
+
240
+ Parameters:
241
+ dim (`int`): The number of channels in the input and output.
242
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
243
+ attention_head_dim (`int`): The number of channels in each head.
244
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
245
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
246
+ only_cross_attention (`bool`, *optional*):
247
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
248
+ double_self_attention (`bool`, *optional*):
249
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
250
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
251
+ num_embeds_ada_norm (:
252
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
253
+ attention_bias (:
254
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
255
+ """
256
+
257
+ def __init__(
258
+ self,
259
+ dim: int,
260
+ num_attention_heads: int,
261
+ attention_head_dim: int,
262
+ dropout=0.0,
263
+ cross_attention_dim: Optional[int] = None,
264
+ activation_fn: str = "geglu",
265
+ num_embeds_ada_norm: Optional[int] = None,
266
+ attention_bias: bool = False,
267
+ only_cross_attention: bool = False,
268
+ double_self_attention: bool = False,
269
+ upcast_attention: bool = False,
270
+ norm_elementwise_affine: bool = True,
271
+ norm_type: str = "layer_norm",
272
+ final_dropout: bool = False,
273
+ use_omni_attn: bool = False,
274
+ ):
275
+ super().__init__()
276
+
277
+ self.use_omni_attn = use_omni_attn
278
+ self.dim = dim
279
+
280
+ self.only_cross_attention = only_cross_attention
281
+
282
+ self.use_ada_layer_norm_zero = (
283
+ num_embeds_ada_norm is not None
284
+ ) and norm_type == "ada_norm_zero"
285
+ self.use_ada_layer_norm = (
286
+ num_embeds_ada_norm is not None
287
+ ) and norm_type == "ada_norm"
288
+
289
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
290
+ raise ValueError(
291
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
292
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
293
+ )
294
+
295
+ # Define 3 blocks. Each block has its own normalization layer.
296
+ # 1. Self-Attn
297
+ if self.use_ada_layer_norm:
298
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
299
+ elif self.use_ada_layer_norm_zero:
300
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
301
+ else:
302
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
303
+
304
+ if self.use_omni_attn:
305
+ if only_cross_attention:
306
+ raise NotImplementedError
307
+ print(
308
+ "Use OmniWhisperAttention with flash attention. Dropout is ignored."
309
+ )
310
+ self.attn1 = OmniWhisperAttention(
311
+ embed_dim=dim, num_heads=num_attention_heads, causal=False
312
+ )
313
+ else:
314
+ self.attn1 = Attention(
315
+ query_dim=dim,
316
+ heads=num_attention_heads,
317
+ dim_head=attention_head_dim,
318
+ dropout=dropout,
319
+ bias=attention_bias,
320
+ cross_attention_dim=(
321
+ cross_attention_dim if only_cross_attention else None
322
+ ),
323
+ upcast_attention=upcast_attention,
324
+ )
325
+
326
+ # 2. Cross-Attn
327
+ if cross_attention_dim is not None or double_self_attention:
328
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
329
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
330
+ # the second cross attention block.
331
+ self.norm2 = (
332
+ AdaLayerNorm(dim, num_embeds_ada_norm)
333
+ if self.use_ada_layer_norm
334
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
335
+ )
336
+ self.attn2 = Attention(
337
+ query_dim=dim,
338
+ cross_attention_dim=(
339
+ cross_attention_dim if not double_self_attention else None
340
+ ),
341
+ heads=num_attention_heads,
342
+ dim_head=attention_head_dim,
343
+ dropout=dropout,
344
+ bias=attention_bias,
345
+ upcast_attention=upcast_attention,
346
+ # scale_qk=False, # uncomment this to not to use flash attention
347
+ ) # is self-attn if encoder_hidden_states is none
348
+ else:
349
+ self.norm2 = None
350
+ self.attn2 = None
351
+
352
+ # 3. Feed-forward
353
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
354
+ self.ff = FeedForward(
355
+ dim,
356
+ dropout=dropout,
357
+ activation_fn=activation_fn,
358
+ final_dropout=final_dropout,
359
+ )
360
+
361
+ # let chunk size default to None
362
+ self._chunk_size = None
363
+ self._chunk_dim = 0
364
+
365
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
366
+ # Sets chunk feed-forward
367
+ self._chunk_size = chunk_size
368
+ self._chunk_dim = dim
369
+
370
+ def forward(
371
+ self,
372
+ hidden_states: torch.FloatTensor,
373
+ attention_mask: Optional[torch.FloatTensor] = None,
374
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
375
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
376
+ timestep: Optional[torch.LongTensor] = None,
377
+ cross_attention_kwargs: Dict[str, Any] = None,
378
+ class_labels: Optional[torch.LongTensor] = None,
379
+ ):
380
+
381
+ bsz, tgt_len, d_model = hidden_states.shape
382
+
383
+ # Notice that normalization is always applied before the real computation in the following blocks.
384
+ # 1. Self-Attention
385
+ if self.use_ada_layer_norm:
386
+ norm_hidden_states = self.norm1(hidden_states, timestep)
387
+ elif self.use_ada_layer_norm_zero:
388
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
389
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
390
+ )
391
+ else:
392
+ norm_hidden_states = self.norm1(hidden_states)
393
+
394
+ cross_attention_kwargs = (
395
+ cross_attention_kwargs if cross_attention_kwargs is not None else {}
396
+ )
397
+
398
+ if self.use_omni_attn:
399
+ seq_len = attention_mask[:, 0, :].float().long().sum(dim=1)
400
+ var_len_attention_mask, unpacking_index = get_sequence_mask(
401
+ norm_hidden_states, seq_len
402
+ )
403
+ norm_hidden_states = torch.masked_select(
404
+ norm_hidden_states, var_len_attention_mask
405
+ )
406
+ norm_hidden_states = norm_hidden_states.view(torch.sum(seq_len), self.dim)
407
+ attn_output = self.attn1(norm_hidden_states, seq_len)
408
+ # unpacking
409
+ attn_output = torch.index_select(attn_output, 0, unpacking_index).view(
410
+ bsz, tgt_len, d_model
411
+ )
412
+ attn_output = torch.where(var_len_attention_mask, attn_output, 0)
413
+ else:
414
+ attn_output = self.attn1(
415
+ norm_hidden_states,
416
+ encoder_hidden_states=(
417
+ encoder_hidden_states if self.only_cross_attention else None
418
+ ),
419
+ attention_mask=(
420
+ encoder_attention_mask
421
+ if self.only_cross_attention
422
+ else attention_mask
423
+ ),
424
+ **cross_attention_kwargs,
425
+ )
426
+
427
+ if self.use_ada_layer_norm_zero:
428
+ attn_output = gate_msa.unsqueeze(1) * attn_output
429
+ hidden_states = attn_output + hidden_states
430
+
431
+ # 2. Cross-Attention
432
+ if self.attn2 is not None:
433
+ norm_hidden_states = (
434
+ self.norm2(hidden_states, timestep)
435
+ if self.use_ada_layer_norm
436
+ else self.norm2(hidden_states)
437
+ )
438
+
439
+ attn_output = self.attn2(
440
+ norm_hidden_states,
441
+ encoder_hidden_states=encoder_hidden_states,
442
+ attention_mask=encoder_attention_mask,
443
+ **cross_attention_kwargs,
444
+ )
445
+ hidden_states = attn_output + hidden_states
446
+
447
+ # 3. Feed-forward
448
+ norm_hidden_states = self.norm3(hidden_states)
449
+
450
+ if self.use_ada_layer_norm_zero:
451
+ norm_hidden_states = (
452
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
453
+ )
454
+
455
+ if self._chunk_size is not None:
456
+ # "feed_forward_chunk_size" can be used to save memory
457
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
458
+ raise ValueError(
459
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
460
+ )
461
+
462
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
463
+ ff_output = torch.cat(
464
+ [
465
+ self.ff(hid_slice)
466
+ for hid_slice in norm_hidden_states.chunk(
467
+ num_chunks, dim=self._chunk_dim
468
+ )
469
+ ],
470
+ dim=self._chunk_dim,
471
+ )
472
+ else:
473
+ ff_output = self.ff(norm_hidden_states)
474
+
475
+ if self.use_ada_layer_norm_zero:
476
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
477
+
478
+ hidden_states = ff_output + hidden_states
479
+
480
+ return hidden_states
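When `use_omni_attn` is set, `BasicTransformerBlock` packs variable-length sequences with `get_sequence_mask`, runs flash attention on the flattened tokens, and scatters the result back using `unpacking_index`. The sketch below reproduces just that pack/unpack round trip in plain PyTorch; sizes are illustrative, and it assumes `get_sequence_mask` from the file above is in scope (importing the file itself also requires `diffusers` and `flash_attn`).

```python
# Pack/unpack round trip used around OmniWhisperAttention (illustrative sizes;
# assumes get_sequence_mask from matcha_transformer.py above is in scope).
import torch

bsz, tgt_len, dim = 2, 5, 4
hidden = torch.randn(bsz, tgt_len, dim)
lengths = torch.tensor([5, 3])                         # true length of each sequence

mask, unpack_idx = get_sequence_mask(hidden, lengths)  # mask: (bsz, tgt_len, 1)
packed = torch.masked_select(hidden, mask).view(int(lengths.sum()), dim)

# ... flash_attn_varlen_func would run on `packed` here ...

unpacked = torch.index_select(packed, 0, unpack_idx).view(bsz, tgt_len, dim)
unpacked = torch.where(mask, unpacked, 0)              # zero the padded positions
assert torch.allclose(unpacked, hidden * mask)
```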
model-00001-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:887a4aafba70ac6740debcf22c58c4f40555f584c702a85776901991498ce59a
3
+ size 4877656728
model-00002-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d959e03801a1794b6b6e4382c2ea49a4070789ab60d94734274cb7923604547
3
+ size 4932746496
model-00003-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3876a179fd44ca371bade65d23887555ba0fa6945bfeb1809ba901cd296ae735
3
+ size 4999921608
model-00004-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d90afae3003cab9992d3f6ccf75f76d7c99779439049a230b898f7a63eb19f39
3
+ size 4677721496
model-00005-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d81a0faa09bbed4856c1c86a3800317232888d1326fa0a3854cbc91febc67139
3
+ size 1640609776
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_omni.py ADDED
@@ -0,0 +1,1011 @@
1
+ # Copyright 2023 Baichuan Inc. All Rights Reserved.
2
+ #
3
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
+ # and OPT implementations in this library. It has been modified from its
7
+ # original forms to accommodate minor architectural differences compared
8
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ """ PyTorch omni model."""
22
+ import os
23
+ import time
24
+ import json
25
+ import math
26
+ import numpy as np
27
+ from typing import List, Optional, Tuple, Union, Any
28
+ from threading import Thread
29
+ from easydict import EasyDict
30
+
31
+ import torch
32
+ import torch.distributed
33
+ import torch.utils.checkpoint
34
+ from torch import nn
35
+ from torch.nn import CrossEntropyLoss
36
+ from torch.nn import functional as F
37
+ import torch.distributed as dist
38
+ from transformers import PreTrainedModel
39
+ from transformers.activations import ACT2FN
40
+ from dataclasses import dataclass
41
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
42
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
43
+ from transformers.generation.utils import GenerationConfig
44
+ from transformers.utils import logging
45
+ # imported for dynamic module loading; not used directly in this file
46
+ from .vector_quantize import VectorQuantize, EuclideanCodebook
47
+ from .matcha_components import (
48
+ SinusoidalPosEmb,
49
+ Block1D,
50
+ ResnetBlock1D,
51
+ Downsample1D,
52
+ TimestepEmbedding,
53
+ Upsample1D,
54
+ )
55
+ from .matcha_transformer import BasicTransformerBlock
56
+ from .flow_matching import ConditionalDecoder, ConditionalCFM
57
+
58
+ from .configuration_omni import OmniConfig
59
+ from .audio_modeling_omni import (RMSNorm,
60
+ OmniAudioEncoder,
61
+ OmniAudioDecoder,
62
+ OmniAudioVQBridgeTokenizer,
63
+ OmniAudioFlowMatchingDecoder)
64
+ from .visual_modeling_omni import OmniVisualEncoder, OmniVisualBridge
65
+ from .processor_omni import OmniMMProcessor
66
+
67
+ # support model paths that contain a dot (.)
68
+ try:
69
+ # step1: copy relative imports to transformers_modules
70
+ from .generation_utils import build_chat_input, TextIterStreamer
71
+ from .sequence_parallel_utils import (
72
+ create_attention_layer,
73
+ get_sequence_parallel_size,
74
+ get_sequence_parallel_chunk,
75
+ )
76
+ except ModuleNotFoundError:
77
+ # step2: direct import from transformers_modules
78
+ try: # bypass check_imports failure
79
+ import sys
80
+ sys.path.append(os.path.dirname(__file__))
81
+ from generation_utils import build_chat_input, TextIterStreamer
82
+ from sequence_parallel_utils import (
83
+ create_attention_layer,
84
+ get_sequence_parallel_size,
85
+ get_sequence_parallel_chunk,
86
+ )
87
+ except Exception:
88
+ raise
89
+
90
+ logger = logging.get_logger(__name__)
91
+
92
+ def get_slopes(n):
93
+ def get_slopes_power_of_2(n):
94
+ start = (2 ** (-2 ** -(math.log2(n) - 3)))
95
+ ratio = start
96
+ return [start * ratio ** i for i in range(n)]
97
+
98
+ if math.log2(n).is_integer():
99
+ return get_slopes_power_of_2(
100
+ n) # In the paper, we only train models that have 2^a heads for some a. This function has
101
+ else: # some good properties that only occur when the input is a power of 2. To maintain that even
102
+ closest_power_of_2 = 2 ** math.floor(
103
+ math.log2(n)) # when the number of heads is not a power of 2, we use this workaround.
104
+ return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][
105
+ :n - closest_power_of_2]
106
+
107
+
108
+ class RotaryEmbedding(torch.nn.Module):
109
+ def __init__(self, dim, max_position_embeddings=2048, base=5e6, device=None):
110
+ super().__init__()
111
+ # Fix the RoPE initialization precision issue: https://zhuanlan.zhihu.com/p/678963442
112
+ # DeepSpeed hooks torch.arange and forces it to run on the GPU, so use the native torch.arange here
113
+ try:
114
+ import deepspeed
115
+ self.arange = deepspeed.runtime.zero.partition_parameters._orig_torch_arange
116
+ except Exception:
117
+ self.arange = torch.arange
118
+
119
+ self.inv_freq = 1.0 / (base ** (self.arange(0, dim, 2).float().to(device) / dim))
120
+ self.max_seq_len_cached = max_position_embeddings
121
+ t = self.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
122
+ freqs = torch.outer(t, self.inv_freq)
123
+ emb = torch.cat((freqs, freqs), dim=-1)
124
+ self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32)
125
+ self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32)
126
+
127
+ def forward(self, x, seq_len=None):
128
+ # x: [bs, num_attention_heads, seq_len, head_size]
129
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
130
+ if seq_len > self.max_seq_len_cached:
131
+ self.max_seq_len_cached = seq_len
132
+ t = self.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
133
+ freqs = torch.outer(t, self.inv_freq)
134
+ emb = torch.cat((freqs, freqs), dim=-1)
135
+ self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32).to(x.device)
136
+ self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32).to(x.device)
137
+ return (
138
+ self.cos_cached[:, :, :seq_len, ...].to(torch.float32).to(x.device),
139
+ self.sin_cached[:, :, :seq_len, ...].to(torch.float32).to(x.device),
140
+ )
141
+
142
+
143
+ def rotate_half(x):
144
+ """Rotates half the hidden dims of the input."""
145
+ x1 = x[..., : x.shape[-1] // 2]
146
+ x2 = x[..., x.shape[-1] // 2:]
147
+ return torch.cat((-x2, x1), dim=-1)
148
+
149
+
150
+ def apply_rotary_pos_emb(q, k, cos_, sin_, position_ids):
151
+ cos = cos_.squeeze(1).squeeze(0) # [seq_len, dim]
152
+ sin = sin_.squeeze(1).squeeze(0) # [seq_len, dim]
153
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
154
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
155
+ q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
156
+ k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
157
+ return q_embed.to(q.dtype), k_embed.to(k.dtype)
158
+
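For reference, a small sketch of how the rotary-embedding helpers above fit together: `RotaryEmbedding` precomputes the cos/sin cache, and `apply_rotary_pos_emb` rotates the query/key states at the given positions. The head count and sizes below are made up for illustration and assume the two definitions above are in scope.

```python
# Illustrative use of RotaryEmbedding / apply_rotary_pos_emb (assumed in scope;
# dimensions are placeholders, the real ones come from OmniConfig).
import torch

bsz, n_heads, seq_len, head_dim = 2, 4, 16, 64
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)
v = torch.randn(bsz, n_heads, seq_len, head_dim)

rope = RotaryEmbedding(dim=head_dim, max_position_embeddings=2048)
cos, sin = rope(v, seq_len=seq_len)                   # each (1, 1, seq_len, head_dim)

position_ids = torch.arange(seq_len).unsqueeze(0).expand(bsz, -1)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```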
159
+
160
+ class MLP(nn.Module):
161
+ def __init__(
162
+ self,
163
+ hidden_size: int,
164
+ intermediate_size: int,
165
+ hidden_act: str,
166
+ ):
167
+ super().__init__()
168
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
169
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
170
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
171
+ self.act_fn = ACT2FN[hidden_act]
172
+
173
+ def forward(self, x):
174
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
175
+
176
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
177
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
178
+ """
179
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
180
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
181
+ """
182
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
183
+ if n_rep == 1:
184
+ return hidden_states
185
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
186
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
187
+
188
+
189
+ class Attention(nn.Module):
190
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
191
+ def __init__(self, config: OmniConfig, is_sparse=False):
192
+ super().__init__()
193
+ self.config = config
194
+ self.position_embedding_type = config.position_embedding_type.lower()
195
+ self.num_kv_heads = config.num_key_value_heads
196
+ self.head_dim = config.head_dim
197
+ self.hidden_size = config.num_attention_heads * self.head_dim
198
+ self.hidden_kv_size = self.num_kv_heads * self.head_dim
199
+
200
+ if is_sparse:
201
+ self.num_heads = config.sparse_attention_heads
202
+ assert self.num_kv_heads == config.num_attention_heads
203
+ self.W_pack = nn.Linear(self.hidden_size, 3 * self.num_heads * self.head_dim, bias=config.attention_qkv_bias)
204
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
205
+ else:
206
+ self.num_heads = config.num_attention_heads
207
+ if self.config.attention_qkv_pack:
208
+ self.W_pack = nn.Linear(config.hidden_size, self.hidden_size + self.hidden_kv_size * 2, bias=config.attention_qkv_bias)
209
+ else:
210
+ self.q_proj = nn.Linear(config.hidden_size, self.hidden_size, bias=config.attention_qkv_bias)
211
+ self.k_proj = nn.Linear(config.hidden_size, self.hidden_kv_size, bias=config.attention_qkv_bias)
212
+ self.v_proj = nn.Linear(config.hidden_size, self.hidden_kv_size, bias=config.attention_qkv_bias)
213
+
214
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
215
+
216
+ if self.position_embedding_type == 'rope':
217
+ self.rotary_emb = RotaryEmbedding(
218
+ dim=self.head_dim,
219
+ max_position_embeddings=config.max_position_embeddings,
220
+ base=config.get_rotary_base()
221
+ )
222
+ elif self.position_embedding_type == 'alibi':
223
+ self.alibi_slopes = get_slopes(self.num_heads)
224
+ self.attention = create_attention_layer(self.hidden_size, self.num_heads, self.head_dim)
225
+
226
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
227
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
228
+
229
+ def _repeat_kv(self, hidden_states: torch.Tensor, num_heads: int) -> torch.Tensor:
230
+ assert hidden_states.size(1) <= num_heads and num_heads % hidden_states.size(1) == 0
231
+ return repeat_kv(hidden_states, num_heads // hidden_states.size(1))
232
+
233
+ def forward(
234
+ self,
235
+ hidden_states: torch.Tensor,
236
+ attention_mask: Optional[torch.Tensor] = None,
237
+ position_ids: Optional[torch.LongTensor] = None,
238
+ seqlens: Optional[torch.IntTensor] = None,
239
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
240
+ output_attentions: bool = False,
241
+ use_cache: bool = False,
242
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
243
+ bsz, q_len = hidden_states.shape[:2]
244
+
245
+ if self.config.attention_qkv_pack:
246
+ proj = self.W_pack(hidden_states)
247
+ query_states, key_states, value_states = proj.split([self.hidden_size, self.hidden_kv_size, self.hidden_kv_size], dim=-1)
248
+ else:
249
+ query_states = self.q_proj(hidden_states)
250
+ key_states = self.k_proj(hidden_states)
251
+ value_states = self.v_proj(hidden_states)
252
+
253
+ # (B, S, hidden_size) -> (B, num_heads, S, head_size)
254
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
255
+ # (B, S, hidden_size) -> (B, num_kv_heads, S, head_size)
256
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
257
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
258
+
259
+ kv_seq_len = key_states.shape[-2]
260
+ if past_key_value is not None:
261
+ kv_seq_len += past_key_value[0].shape[-2]
262
+ if self.position_embedding_type == 'rope':
263
+ max_position = position_ids.max().item()+1 if position_ids is not None else kv_seq_len * get_sequence_parallel_size()
264
+ cos, sin = self.rotary_emb(value_states, seq_len=max_position)
265
+ query_states, key_states = apply_rotary_pos_emb(
266
+ query_states, key_states, cos, sin,
267
+ get_sequence_parallel_chunk(position_ids)
268
+ )
269
+
270
+ if past_key_value is not None:
271
+ # reuse k, v, self_attention
272
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
273
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
274
+ past_key_value = (key_states, value_states) if use_cache else None
275
+
276
+ # repeat k/v heads if n_kv_heads < n_heads
277
+ key_states = self._repeat_kv(key_states, query_states.size(1))
278
+ value_states = self._repeat_kv(value_states, query_states.size(1))
279
+
280
+ if seqlens is not None:
281
+ seqlens = seqlens.to(dtype=torch.int32)
282
+ max_seqlen = (seqlens[1:] - seqlens[:-1]).max().item()
283
+ if self.position_embedding_type == 'alibi':
284
+ alibi_slopes = torch.tensor(self.alibi_slopes, dtype=torch.float32).to(query_states.device)
285
+ else:
286
+ alibi_slopes = None
287
+ attn_output = self.attention(
288
+ query_states, key_states, value_states, seqlens, seqlens,
289
+ max_seqlen, max_seqlen, causal=True, alibi_slopes=alibi_slopes, use_flash=True)
290
+ else:
291
+ attn_output = self.attention(
292
+ query_states, key_states, value_states, attn_mask=attention_mask, use_flash=False)
293
+
294
+ attn_output = attn_output.reshape(bsz, q_len, -1)
295
+ attn_output = self.o_proj(attn_output)
296
+
297
+ return attn_output, None, past_key_value
298
+
299
+
300
+ class DecoderLayer(nn.Module):
301
+ def __init__(self, config: OmniConfig, is_sparse=False):
302
+ super().__init__()
303
+ self.hidden_size = config.hidden_size
304
+ self.self_attn = Attention(config=config, is_sparse=is_sparse)
305
+ self.mlp = MLP(
306
+ hidden_size=self.hidden_size,
307
+ intermediate_size=config.intermediate_size,
308
+ hidden_act=config.hidden_act,
309
+ )
310
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
311
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
312
+
313
+ def forward(
314
+ self,
315
+ hidden_states: torch.Tensor,
316
+ attention_mask: Optional[torch.Tensor] = None,
317
+ position_ids: Optional[torch.LongTensor] = None,
318
+ seqlens: Optional[torch.IntTensor] = None,
319
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
320
+ output_attentions: Optional[bool] = False,
321
+ use_cache: Optional[bool] = False,
322
+ group_index=None,
323
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
324
+
325
+ residual = hidden_states
326
+
327
+ hidden_states = self.input_layernorm(hidden_states)
328
+
329
+ # Self Attention
330
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
331
+ hidden_states=hidden_states,
332
+ attention_mask=attention_mask,
333
+ position_ids=position_ids,
334
+ seqlens=seqlens,
335
+ past_key_value=past_key_value,
336
+ output_attentions=output_attentions,
337
+ use_cache=use_cache,
338
+ )
339
+ hidden_states = residual + hidden_states
340
+
341
+ # Fully Connected
342
+ residual = hidden_states
343
+ hidden_states = self.post_attention_layernorm(hidden_states)
344
+ hidden_states = self.mlp(hidden_states)
345
+ hidden_states = residual + hidden_states
346
+
347
+ outputs = (hidden_states,)
348
+
349
+ if output_attentions:
350
+ outputs += (self_attn_weights,)
351
+
352
+ if use_cache:
353
+ outputs += (present_key_value,)
354
+
355
+ return outputs
356
+
357
+
358
+ class OmniPreTrainedModel(PreTrainedModel):
359
+ config_class = OmniConfig
360
+ base_model_prefix = "model"
361
+ supports_gradient_checkpointing = True
362
+ _no_split_modules = ["DecoderLayer"]
363
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
364
+
365
+ def _init_weights(self, module):
366
+ std = self.config.initializer_range
367
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d) or isinstance(module, nn.ConvTranspose1d):
368
+ module.weight.data.normal_(mean=0.0, std=std)
369
+ if module.bias is not None:
370
+ module.bias.data.zero_()
371
+ elif isinstance(module, nn.Embedding):
372
+ module.weight.data.normal_(mean=0.0, std=std)
373
+ if module.padding_idx is not None:
374
+ module.weight.data[module.padding_idx].zero_()
375
+ elif isinstance(module, nn.LayerNorm) or isinstance(module, nn.GroupNorm):
376
+ module.weight.data.fill_(1.0)
377
+ module.bias.data.zero_()
378
+ elif isinstance(module, RMSNorm):
379
+ module.weight.data.fill_(1.0)
380
+
381
+ def _set_gradient_checkpointing(self, module, value=False):
382
+ if isinstance(module, OmniModel):
383
+ module.gradient_checkpointing = value
384
+
385
+ @dataclass
386
+ class OmniModelOutputWithPast(BaseModelOutputWithPast):
387
+ audio_encoder_ret: Optional[Any] = None
388
+ audio_decoder_ret: Optional[Any] = None
389
+
390
+ class OmniModel(OmniPreTrainedModel):
391
+ def __init__(self, config: OmniConfig):
392
+ super().__init__(config)
393
+ self.padding_idx = config.pad_token_id
394
+ self.vocab_size = config.vocab_size
395
+
396
+ if config.visual_config.enable:
397
+ self.visual_model = OmniVisualEncoder(config.visual_config)
398
+ self.visual_bridge_model = OmniVisualBridge(config.visual_config)
399
+ if config.video_config.enable and not config.visual_config.enable: # in case only video_config is enabled and visual_config is not
400
+ self.visual_model = OmniVisualEncoder(config.video_config)
401
+ self.visual_bridge_model = OmniVisualBridge(config.video_config)
402
+
403
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
404
+ self.layers = nn.ModuleList([
405
+ DecoderLayer(config, is_sparse=layer_idx in config.sparse_attention_layers)
406
+ for layer_idx in range(config.num_hidden_layers)
407
+ ])
408
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
409
+
410
+ self.audio_embed_layers = nn.ModuleList([
411
+ nn.Embedding(codedim + 1, config.hidden_size)
412
+ for i, codedim in enumerate(config.audio_config.vq_config.codebook_sizes)
413
+ ])
414
+
415
+ self.gradient_checkpointing = True
416
+ # Initialize weights and apply final processing
417
+ self.post_init()
418
+
419
+ def get_input_embeddings(self):
420
+ return self.embed_tokens
421
+
422
+ def set_input_embeddings(self, value):
423
+ self.embed_tokens = value
424
+
425
+ @torch.no_grad()
426
+ def get_multimodal_mask(self, input_ids, pad_token_id, special_token_list):
427
+ '''
428
+ 获取任意模态的特殊mask,包含以下
429
+ 1. pad mask 表示文本中图像/语音/视频模态提前留出的token位置
430
+ 2. special token mask 特殊token 例如对理解模型<start> <end> 不需要next token prediction
431
+ 3. embedding mask / lm_head mask 标记出特殊token在embedding中的mask
432
+ '''
433
+ pad_mask = torch.eq(input_ids, pad_token_id)
434
+ sp_mask = torch.zeros_like(input_ids, dtype=torch.bool)
435
+ lm_head_mask = torch.zeros([self.config.vocab_size, 1], dtype=torch.bool)
436
+ for sp_id in special_token_list:
437
+ sp_mask = torch.logical_or(sp_mask, torch.eq(input_ids, sp_id))
438
+ lm_head_mask[sp_id, 0] = True
439
+ return pad_mask, sp_mask, lm_head_mask
440
+
441
+ def get_multimodal_embed(
442
+ self,
443
+ input_ids,
444
+ text_embedding, # 1. self.embed_tokens(input_ids) 2. output already merged from other modalities
445
+ multimodal_embed,
446
+ pad_token_id,
447
+ fake_input,
448
+ group_index=None, # index assigned to each modality
449
+ ):
450
+ pad_mask, sp_mask, _ = self.get_multimodal_mask(input_ids, pad_token_id, self.config.multimodal_special_token_list)
451
+ if not self.training: # for inference with auto device map, move the multimodal module output onto the same device as input_ids
452
+ multimodal_embed = multimodal_embed.to(input_ids.device)
453
+ if not fake_input: # check that the number of multimodal tokens matches the pad-mask count (incorrect truncation causes a mismatch)
454
+ assert pad_mask.sum() == multimodal_embed.shape[0]
455
+ else:
456
+ assert pad_mask.sum() <= 0
457
+
458
+ # merge the current modality embeddings with the text embeddings
459
+ input_ids = torch.where(pad_mask, torch.cumsum(pad_mask.view(-1).to(input_ids), dim=0).view(input_ids.shape)-1, input_ids)
460
+ text_embedding = (1 - pad_mask.to(text_embedding)).unsqueeze(-1) * text_embedding # zero out the pad-token positions
461
+ multimodal_embedding = torch.embedding(multimodal_embed, input_ids * pad_mask) # non-pad positions pick up the idx=0 entry
462
+ multimodal_embedding = pad_mask.to(multimodal_embedding).unsqueeze(-1) * multimodal_embedding # zero out the non-pad-token positions
463
+ final_embedding = multimodal_embedding.to(text_embedding) + text_embedding
464
+
465
+ if group_index is None:
466
+ group_index = pad_mask.to(torch.int32)
467
+ else:
468
+ current_index = torch.max(group_index) + 1
469
+ group_index += pad_mask.to(torch.int32) * current_index # assumes modalities do not overlap
470
+
471
+ return final_embedding, group_index
472
+
473
+ def get_visual_embed(
474
+ self,
475
+ input_ids,
476
+ text_embedding, # 1. self.embed_tokens(input_ids) 2. output already merged from other modalities
477
+ images = None,
478
+ patch_nums = None,
479
+ images_grid = None,
480
+ videos = None,
481
+ videos_patch_nums = None,
482
+ videos_grid = None,
483
+ group_index = None, # index assigned to each modality
484
+ ):
485
+ if images is None or len(images) <= 0:
486
+ images, images_grid, patch_nums = self.visual_model.fake_input(input_ids.device)
487
+ image_fake_input = True
488
+ else:
489
+ image_fake_input = False
490
+
491
+ if videos is None or len(videos) <= 0 :
492
+ videos, videos_grid, videos_patch_nums = self.visual_model.fake_input(input_ids.device)
493
+ video_fake_input = True
494
+ else:
495
+ video_fake_input = False
496
+
497
+ visual_input = images + videos
498
+ visual_grid = images_grid + videos_grid
499
+
500
+ visual_input = torch.cat(visual_input, dim=0)
501
+ visual_grid = torch.tensor(np.array(visual_grid))
502
+
503
+ visual_embed = self.visual_model(visual_input, grid_thw=visual_grid)
504
+ visual_embed = self.visual_bridge_model(visual_embed)
505
+
506
+ assert sum(patch_nums) + sum(videos_patch_nums) == visual_embed.shape[0]
507
+ images_embed = visual_embed[:sum(patch_nums)]
508
+ videos_embed = visual_embed[sum(patch_nums):]
509
+
510
+ final_embedding, group_index = self.get_multimodal_embed(input_ids, text_embedding, images_embed, self.config.visual_config.image_pad_token_id, image_fake_input, group_index=group_index)
511
+ final_embedding, group_index = self.get_multimodal_embed(input_ids, final_embedding, videos_embed, self.config.video_config.video_place_token_id, video_fake_input, group_index=group_index)
512
+ return final_embedding, group_index
513
+
514
+
515
+ @torch.no_grad()
516
+ def audio_fake_input(self, device):
517
+ return torch.zeros(5, len(self.config.audio_config.vq_config.codebook_sizes), dtype=torch.int32, device=device)
518
+
519
+ def forward(
520
+ self,
521
+ input_ids: torch.LongTensor = None,
522
+ attention_mask: Optional[torch.Tensor] = None,
523
+ position_ids: Optional[torch.LongTensor] = None,
524
+ seqlens: Optional[torch.IntTensor] = None,
525
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
526
+ inputs_embeds: Optional[torch.FloatTensor] = None,
527
+ audios_tokens: Optional[List|torch.Tensor] = None, # audio tokens, bs*seqlen*vq_num
528
+ images: Optional[List|torch.Tensor] = None,
529
+ patch_nums: Optional[torch.Tensor] = None,
530
+ images_grid: Optional[List|torch.Tensor] = None,
531
+ videos: Optional[List|torch.Tensor] = None,
532
+ videos_patch_nums: Optional[torch.Tensor] = None,
533
+ videos_grid: Optional[List|torch.Tensor] = None,
534
+ use_cache: Optional[bool] = None,
535
+ output_attentions: Optional[bool] = None,
536
+ output_hidden_states: Optional[bool] = None,
537
+ return_dict: Optional[bool] = None,
538
+ ) -> Union[Tuple, OmniModelOutputWithPast]:
539
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
540
+ output_hidden_states = (
541
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
542
+ )
543
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
544
+ return_dict = True if (return_dict is not None or self.training) else self.config.use_return_dict
545
+
546
+ # retrieve input_ids and inputs_embeds
547
+ if input_ids is not None and inputs_embeds is not None:
548
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
549
+ elif input_ids is not None:
550
+ batch_size, seq_length = input_ids.shape
551
+ elif inputs_embeds is not None:
552
+ batch_size, seq_length, _ = inputs_embeds.shape
553
+ else:
554
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
555
+
556
+ seq_length_with_past = seq_length
557
+ past_key_values_length = 0
558
+
559
+ if past_key_values is not None:
560
+ past_key_values_length = past_key_values[0][0].shape[2]
561
+ seq_length_with_past = seq_length_with_past + past_key_values_length
562
+
563
+ if position_ids is None:
564
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
565
+ position_ids = torch.arange(
566
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
567
+ )
568
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
569
+ else:
570
+ position_ids = position_ids.view(-1, seq_length).long()
571
+
572
+ group_index, audio_decoder_ret = None, None
573
+ if inputs_embeds is None:
574
+ sp_input_ids = get_sequence_parallel_chunk(input_ids)
575
+ inputs_embeds = self.embed_tokens(sp_input_ids)
576
+ if audios_tokens is None or len(audios_tokens) <= 0 :
577
+ audios_tokens = torch.zeros(5, len(self.config.audio_config.vq_config.codebook_sizes), dtype=torch.int32, device=input_ids.device) # a fake input
578
+ fake_input = True
579
+ else:
580
+ fake_input = False
581
+ for i, audio_emb_layer in enumerate(self.audio_embed_layers):
582
+ if i==0:
583
+ audio_embs = audio_emb_layer(audios_tokens[..., i])
584
+ else:
585
+ audio_embs += audio_emb_layer(audios_tokens[..., i])
586
+ inputs_embeds, group_index = self.get_multimodal_embed(sp_input_ids, inputs_embeds, audio_embs, self.config.audio_config.audio_pad_token_id, fake_input, group_index=group_index)
587
+
588
+ if self.config.visual_config.enable or self.config.video_config.enable:
589
+ inputs_embeds, group_index = self.get_visual_embed(sp_input_ids, inputs_embeds, images, patch_nums, images_grid, videos, videos_patch_nums, videos_grid, group_index=group_index) # note: group_index is updated here
590
+
591
+ if seqlens is not None and seqlens.ndim == 2:
592
+ cu_seqlens = []
593
+ offset, seqlen = 0, seqlens.size(1)
594
+ for lens in seqlens:
595
+ cu_seqlens.append(offset)
596
+ cu_seqlens.extend((lens[(lens > 0) & (lens < seqlen)] + offset).tolist())
597
+ offset += seqlen
598
+ cu_seqlens.append(offset)
599
+ seqlens = torch.tensor(cu_seqlens, dtype=seqlens.dtype, device=seqlens.device)
600
+ elif seqlens is None and self.training:
601
+ seqlens = torch.arange(
602
+ end=input_ids.size(0) + 1,
603
+ dtype=torch.int32,
604
+ device=input_ids.device
605
+ ) * input_ids.size(1)
606
+ if seqlens is not None:
607
+ attention_mask = None # unset attention_mask to save memory
608
+
609
+ if seqlens is None and attention_mask is None:
610
+ attention_mask = torch.ones(
611
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
612
+ )
613
+ if attention_mask is not None:
614
+ attention_mask = _prepare_4d_causal_attention_mask(
615
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
616
+ )
617
+
618
+ # embed positions
619
+ hidden_states = inputs_embeds
620
+
621
+ if self.gradient_checkpointing and self.training:
622
+ if use_cache:
623
+ logger.warning_once(
624
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
625
+ )
626
+ use_cache = False
627
+
628
+ # decoder layers
629
+ all_hidden_states = () if output_hidden_states else None
630
+ all_self_attns = () if output_attentions else None
631
+ next_decoder_cache = () if use_cache else None
632
+
633
+ for idx, decoder_layer in enumerate(self.layers):
634
+ if output_hidden_states:
635
+ all_hidden_states += (hidden_states,)
636
+
637
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
638
+
639
+ if self.gradient_checkpointing and self.training:
640
+
641
+ def create_custom_forward(module):
642
+ def custom_forward(*inputs):
643
+ # None for past_key_value
644
+ return module(*inputs, output_attentions, False, group_index)
645
+
646
+ return custom_forward
647
+
648
+ layer_outputs = torch.utils.checkpoint.checkpoint(
649
+ create_custom_forward(decoder_layer),
650
+ hidden_states,
651
+ attention_mask,
652
+ position_ids,
653
+ seqlens,
654
+ None,
655
+ )
656
+ else:
657
+ layer_outputs = decoder_layer(
658
+ hidden_states,
659
+ attention_mask=attention_mask,
660
+ position_ids=position_ids,
661
+ seqlens=seqlens,
662
+ past_key_value=past_key_value,
663
+ output_attentions=output_attentions,
664
+ use_cache=use_cache,
665
+ group_index=group_index,
666
+ )
667
+
668
+ hidden_states = layer_outputs[0]
669
+
670
+ if use_cache:
671
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
672
+
673
+ if output_attentions:
674
+ all_self_attns += (layer_outputs[1],)
675
+
676
+ hidden_states = self.norm(hidden_states)
677
+
678
+ # add hidden states from the last decoder layer
679
+ if output_hidden_states:
680
+ all_hidden_states += (hidden_states,)
681
+
682
+ next_cache = next_decoder_cache if use_cache else None
683
+ if not return_dict:
684
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
685
+ return BaseModelOutputWithPast(
686
+ last_hidden_state=hidden_states,
687
+ past_key_values=next_cache,
688
+ hidden_states=all_hidden_states,
689
+ attentions=all_self_attns,
690
+ )
691
+
692
+
693
+ class NormHead(nn.Module):
694
+ def __init__(self, hidden_size, vocab_size, bias=False):
695
+ super().__init__()
696
+ self.hidden_size = hidden_size
697
+ self.vocab_size = vocab_size
698
+ self.weight = nn.Parameter(torch.empty((self.vocab_size, self.hidden_size)))
699
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
700
+
701
+ def forward(self, hidden_states, mask=None):
702
+ norm_weight = nn.functional.normalize(self.weight)
703
+ if mask is not None:
704
+ mask = mask.to(norm_weight)
705
+ norm_weight = norm_weight * mask + (1 - mask) * norm_weight.detach()
706
+ return nn.functional.linear(hidden_states, norm_weight)
707
+
708
+
709
+ def extra_repr(self) -> str:
710
+ return f'in_features={self.hidden_size}, out_features={self.vocab_size}'
711
+
712
+ @dataclass
713
+ class OmniMMCausalLMOutputWithPast(ModelOutput):
714
+ loss: Optional[torch.FloatTensor] = None
715
+ logits: Optional[torch.FloatTensor] = None
716
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
717
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
718
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
719
+ audios_emb_for_infer: Optional[torch.FloatTensor] = None # embeddings used for audio-head inference
720
+
721
+
722
+ class CasualDepthTransformerLayer(nn.Module):
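+ # Causal transformer layer applied along the codebook-depth dimension of the audio head (the class name keeps the original spelling): causal self-attention followed by a per-depth-position feed-forward implemented with einsum over reshaped linear weights.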
723
+ def __init__(self, config, depth):
724
+ super().__init__()
725
+ self.config = config
726
+ embed_size = config.hidden_size
727
+ assert embed_size % 128 == 0
728
+ num_heads = embed_size // 128
729
+ self.self_attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads,batch_first=True)
730
+ self.layernorm1 = RMSNorm(embed_size)
731
+ self.layernorm2 = RMSNorm(embed_size)
732
+ self.linear1 = nn.Linear(embed_size * depth, 2 * embed_size)
733
+ self.linear2 = nn.Linear(2 * embed_size * depth, embed_size)
734
+
735
+ def forward(self, x):
736
+ seq_len = x.size(1)
737
+ res = x
738
+ x = self.layernorm1(x)
739
+ src_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(x.device)
740
+ _x, _ = self.self_attention(x, x, x, is_causal=True, attn_mask=src_mask)
741
+ res = _x + res # (bs, sl, d)
742
+ res = self.layernorm2(res)
743
+ x = torch.einsum('bld,tld->blt', res, torch.reshape(self.linear1.weight, (2 * self.config.hidden_size, -1, self.config.hidden_size)))
744
+ x = torch.nn.functional.gelu(x)
745
+ x = torch.einsum('blt,dlt->bld', x, torch.reshape(self.linear2.weight, (self.config.hidden_size, -1, 2 * self.config.hidden_size)))
746
+ return res + x
747
+
748
+ class OmniAudioHead(nn.Module):
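+ # Audio head: stacks the LM hidden state with cumulative sums of the previous codebook embeddings, runs the depth transformer layers, and applies one linear head per codebook (output size = codebook size + 1).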
749
+ def __init__(self, config):
750
+ super().__init__()
751
+ self.config = config
752
+ hidden_size = config.hidden_size
753
+ self.transformer_layers = nn.ModuleList([
754
+ CasualDepthTransformerLayer(config, len(config.audio_config.vq_config.codebook_sizes))
755
+ for _ in range(config.audio_config.audio_head_transformer_layers)
756
+ ])
757
+ self.headnorm = RMSNorm(hidden_size)
758
+ self.heads = nn.ModuleList([
759
+ nn.Linear(hidden_size, vq_size+1)
760
+ for vq_size in config.audio_config.vq_config.codebook_sizes
761
+ ])
762
+ self.gradient_checkpointing = True
763
+
764
+ def forward(self, x, audios_tokens, audio_emb_layers):
765
+ cumsum_audio_embed = torch.stack([
766
+ audio_emb_layers[i](audios_tokens[..., i])
767
+ for i, vq_size in enumerate(self.config.audio_config.vq_config.codebook_sizes[:-1])
768
+ ], dim=1)
769
+ cumsum_audio_embed = torch.cumsum(cumsum_audio_embed, dim=1) # (bs, depth-1, d)
770
+ hidden_states = torch.concat([x.reshape(-1, 1, self.config.hidden_size), cumsum_audio_embed], dim=1) # (bs, depth, d)
771
+ assert hidden_states.size(1) == len(self.config.audio_config.vq_config.codebook_sizes)
772
+ for i, tlayer in enumerate(self.transformer_layers):
773
+ hidden_states = tlayer(hidden_states,)
774
+ hidden_states = self.headnorm(hidden_states)
775
+ logits = [head(hidden_states[:,i]) for i, head in enumerate(self.heads)]
776
+ return logits
777
+
778
+
779
+ class OmniForCausalLM(OmniPreTrainedModel):
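+ # Top-level causal LM: wraps OmniModel, the audio tokenizer (encoder + VQ bridge), the audio head, and either a NormHead or a plain linear lm_head for text logits.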
780
+ def __init__(self, config):
781
+ super().__init__(config)
782
+ self.config = config
783
+ self.model = OmniModel(config)
784
+ self.audio_tokenizer = OmniAudioTokenizer(config)
785
+ self.audio_head = OmniAudioHead(config)
786
+ if config.use_norm_head:
787
+ self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False)
788
+ else:
789
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
790
+ # Initialize weights and apply final processing
791
+ self.post_init()
792
+
793
+ @property
794
+ def main_device(self):
795
+ return self.lm_head.weight.device
796
+
797
+ def bind_processor(self, tokenizer, **kwargs):
798
+ self.processor = OmniMMProcessor(
799
+ tokenizer=tokenizer,
800
+ config=self.config,
801
+ **kwargs,
802
+ )
803
+ return self.processor
804
+
805
+ def get_input_embeddings(self):
806
+ return self.model.embed_tokens
807
+
808
+ def set_input_embeddings(self, value):
809
+ self.model.embed_tokens = value
810
+
811
+ def get_output_embeddings(self):
812
+ return self.lm_head
813
+
814
+ def set_output_embeddings(self, new_embeddings):
815
+ self.lm_head = new_embeddings
816
+
817
+ def set_decoder(self, decoder):
818
+ self.model = decoder
819
+
820
+ def get_decoder(self):
821
+ return self.model
822
+
823
+ def forward(
824
+ self,
825
+ input_ids: torch.LongTensor = None,
826
+ attention_mask: Optional[torch.Tensor] = None,
827
+ position_ids: Optional[torch.LongTensor] = None,
828
+ seqlens: Optional[torch.IntTensor] = None,
829
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
830
+ inputs_embeds: Optional[torch.FloatTensor] = None,
831
+ labels: Optional[torch.LongTensor] = None,
832
+ audios: Optional[List|torch.Tensor] = None,
833
+ audios_tokens: Optional[List|torch.Tensor] = None,
834
+ encoder_length: Optional[torch.Tensor] = None,
835
+ bridge_length: Optional[torch.Tensor] = None,
836
+ images: Optional[torch.Tensor] = None,
837
+ patch_nums: Optional[torch.Tensor] = None,
838
+ images_grid: Optional[torch.Tensor] = None,
839
+ videos: Optional[torch.Tensor] = None,
840
+ videos_patch_nums: Optional[torch.Tensor] = None,
841
+ videos_grid: Optional[torch.Tensor] = None,
842
+ use_cache: Optional[bool] = None,
843
+ output_attentions: Optional[bool] = None,
844
+ output_hidden_states: Optional[bool] = None,
845
+ return_dict: Optional[bool] = None,
846
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
847
+
848
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
849
+ output_hidden_states = (
850
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
851
+ )
852
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
853
+
854
+ if audios_tokens is not None:
855
+ assert isinstance(audios_tokens, torch.Tensor)
856
+ else:
857
+ if audios is None or len(audios) == 0:
858
+ audios_tokens = None
859
+ else:
860
+ audios_tokens = self.audio_tokenizer(audios,encoder_length,bridge_length)
861
+
862
+ outputs = self.model(
863
+ input_ids=input_ids,
864
+ attention_mask=attention_mask,
865
+ position_ids=position_ids,
866
+ seqlens=seqlens,
867
+ past_key_values=past_key_values,
868
+ inputs_embeds=inputs_embeds,
869
+ audios_tokens=audios_tokens,
870
+ images=images,
871
+ patch_nums=patch_nums,
872
+ images_grid=images_grid,
873
+ videos=videos,
874
+ videos_patch_nums=videos_patch_nums,
875
+ videos_grid=videos_grid,
876
+ use_cache=use_cache,
877
+ output_attentions=output_attentions,
878
+ output_hidden_states=output_hidden_states,
879
+ return_dict=return_dict,
880
+ )
881
+ hidden_states = outputs.last_hidden_state
882
+ audios_emb_for_infer = hidden_states[:,-1,:]
883
+ logits = self.lm_head(hidden_states)
884
+
885
+ return OmniMMCausalLMOutputWithPast(
886
+ logits=logits,
887
+ past_key_values=outputs.past_key_values,
888
+ hidden_states=outputs.hidden_states,
889
+ attentions=outputs.attentions,
890
+ audios_emb_for_infer=audios_emb_for_infer
891
+ )
892
+
893
+ def prepare_inputs_for_generation(
894
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
895
+ ):
896
+ if past_key_values:
897
+ input_ids = input_ids[:, past_key_values[0][0].shape[-2]:]
898
+
899
+ position_ids = kwargs.get("position_ids", None)
900
+ if attention_mask is not None and position_ids is None:
901
+ # create position_ids on the fly for batch generation
902
+ position_ids = attention_mask.long().cumsum(-1)
903
+ # position_ids.masked_fill_(attention_mask == 0, 1)
904
+ if past_key_values:
905
+ position_ids = position_ids[:, past_key_values[0][0].shape[-2]:]
906
+
907
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
908
+ if inputs_embeds is not None and past_key_values is None:
909
+ model_inputs = {"inputs_embeds": inputs_embeds}
910
+ elif past_key_values is not None:
911
+ model_inputs = {"input_ids": input_ids}
912
+ else:
913
+ model_inputs = {"input_ids": input_ids,
914
+ "audios": kwargs.get("audios", None), "encoder_length": kwargs.get("encoder_length", None), "bridge_length": kwargs.get("bridge_length", None),
915
+ "audios_tokens": kwargs.get("audios_tokens", None),
916
+ "images": kwargs.get("images", None),
917
+ "videos": kwargs.get("videos", None)
918
+ }
919
+
920
+ model_inputs.update(
921
+ {
922
+ "position_ids": position_ids,
923
+ "past_key_values": past_key_values,
924
+ "use_cache": kwargs.get("use_cache"),
925
+ "attention_mask": attention_mask,
926
+ "images_grid": kwargs.get("images_grid"),
927
+ "videos_grid": kwargs.get("videos_grid"),
928
+ "patch_nums": kwargs.get("patch_nums"),
929
+ "videos_patch_nums": kwargs.get("videos_patch_nums"),
930
+ }
931
+ )
932
+ return model_inputs
933
+
934
+ @staticmethod
935
+ def _reorder_cache(past_key_values, beam_idx):
936
+ reordered_past = ()
937
+ for layer_past in past_key_values:
938
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
939
+ return reordered_past
940
+
941
+ def chat(self, tokenizer, messages: List[dict], stream=False,
942
+ generation_config: Optional[GenerationConfig]=None):
943
+ generation_config = generation_config or self.generation_config
944
+ input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens)
945
+ if stream:
946
+ streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
947
+ Thread(target=self.generate, kwargs=dict(
948
+ inputs=input_ids, streamer=streamer,
949
+ generation_config=generation_config,
950
+ )).start()
951
+ return streamer
952
+ else:
953
+ outputs = self.generate(input_ids, generation_config=generation_config)
954
+ response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
955
+ return response
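+ # Minimal usage sketch (illustrative assumption, not part of the original file):
+ # messages = [{"role": "user", "content": "Hello"}]
+ # response = model.chat(tokenizer, messages) # non-streaming
+ # streamer = model.chat(tokenizer, messages, stream=True) # iterate the streamer for incremental text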
956
+
957
+
958
+ class OmniAudioTokenizer(OmniPreTrainedModel):
959
+ """
960
+ Construct an audio tokenizer and decoder.
961
+ """
962
+ def __init__(self, config: OmniConfig):
963
+ super().__init__(config)
964
+ self.padding_idx = None
965
+ self.vocab_size = config.vocab_size
966
+ self.training = False
967
+ self.eval()
968
+ self.audio_model = OmniAudioEncoder(config.audio_config)
969
+ self.audio_bridge_model = OmniAudioVQBridgeTokenizer(config)
970
+ if config.vocoder_config.enable:
971
+ self.audio_decoder = OmniAudioDecoder(config)
972
+ if config.flow_matching_config.enable:
973
+ self.audio_flow_matching_decoder = OmniAudioFlowMatchingDecoder(config)
974
+
975
+ def encode(self, x, encoder_length: Optional[torch.Tensor] = None,
976
+ bridge_length: Optional[torch.Tensor] = None):
977
+ audio_emb = self.audio_model(x, encoder_length)
978
+ audios_tokens = self.audio_bridge_model(audio_emb, bridge_length)
979
+ return audios_tokens
980
+
981
+ def decode(self, audio_code_ids, bridge_length: Optional[torch.Tensor] = None):
982
+ assert self.config.vocoder_config.enable, "Vocoder is not enabled in config."
983
+ audio_emb = self.audio_bridge_model.decode(audio_code_ids)
984
+ audio_dec = self.audio_decoder(
985
+ audio_emb.to(next(self.audio_decoder.parameters())), bridge_length
986
+ )
987
+ if self.config.flow_matching_config.enable:
988
+ if self.config.flow_matching_config.use_hidden_states_before_dconv2:
989
+ hidden_states, hidden_states_length = (
990
+ self.audio_flow_matching_decoder.unpack_hidden_states(
991
+ audio_dec.hidden_states_before_dconv2,
992
+ audio_dec.output_length_before_dconv2,
993
+ )
994
+ )
995
+ audio_flow_matching_decoder_ret = self.audio_flow_matching_decoder(
996
+ hidden_states, hidden_states_length
997
+ )
998
+
999
+ else:
1000
+ audio_flow_matching_decoder_ret = self.audio_flow_matching_decoder(
1001
+ audio_dec.refined_mel, audio_dec.mel_length
1002
+ )
1003
+ return audio_flow_matching_decoder_ret
1004
+ else:
1005
+ return audio_dec
1006
+
1007
+ @torch.no_grad()
1008
+ def forward(self, audios, encoder_length: Optional[torch.Tensor] = None, bridge_length: Optional[torch.Tensor] = None):
1009
+ self.eval()
1010
+ audios_tokens = self.encode(audios, encoder_length, bridge_length)
1011
+ return audios_tokens
processor_omni.py ADDED
@@ -0,0 +1,865 @@
1
+ import requests
2
+ import re, ujson, os, sys, fire, glob, random, time, json
3
+ import numpy as np
4
+ import io
5
+ import torch
6
+ from torch.utils.data import default_collate
7
+ import torchaudio
8
+ from typing import *
9
+ from dataclasses import dataclass, field
10
+ import transformers
11
+ from transformers.modeling_outputs import ModelOutput
12
+ from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
13
+ from functools import lru_cache
14
+ from io import BytesIO
15
+ from PIL import Image
16
+ import concurrent.futures as cf
17
+ from transformers.image_transforms import resize, center_crop, get_resize_output_image_size
18
+ from transformers.image_utils import PILImageResampling
19
+ from PIL import Image, ImageOps
20
+ from PIL import ImageFile
21
+ torch.set_num_threads(1) # limit torch's thread count, otherwise processing may hang
22
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
23
+ import base64
24
+ from decord import VideoReader, cpu
25
+ import cv2
26
+ import av
27
+ import imagesize
28
+ import tempfile
29
+ import math
30
+ from multiprocessing import Pool
31
+ from cairosvg import svg2png
32
+ import hashlib
33
+
34
+ IMAGE_FACTOR = 28
35
+ MIN_PIXELS = 4 * 28 * 28
36
+ MAX_PIXELS = 16384 * 28 * 28
37
+ MAX_RATIO = 200
38
+
39
+ VIDEO_MIN_PIXELS = 128 * 28 * 28
40
+ VIDEO_MAX_PIXELS = 768 * 28 * 28
41
+ VIDEO_TOTAL_PIXELS = 24576 * 28 * 28
42
+ FRAME_FACTOR = 2
43
+ FPS = 2.0
44
+ FPS_MIN_FRAMES = 4
45
+ FPS_MAX_FRAMES = 768
46
+
47
+ def round_by_factor(number: int, factor: int) -> int:
48
+ """Returns the closest integer to 'number' that is divisible by 'factor'."""
49
+ return round(number / factor) * factor
50
+
51
+
52
+ def ceil_by_factor(number: int, factor: int) -> int:
53
+ """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
54
+ return math.ceil(number / factor) * factor
55
+
56
+
57
+ def floor_by_factor(number: int, factor: int) -> int:
58
+ """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
59
+ return math.floor(number / factor) * factor
60
+
61
+
62
+ def smart_resize(
63
+ height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
64
+ ) -> tuple[int, int]:
65
+ """
66
+ Rescales the image so that the following conditions are met:
67
+
68
+ 1. Both dimensions (height and width) are divisible by 'factor'.
69
+
70
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
71
+
72
+ 3. The aspect ratio of the image is maintained as closely as possible.
73
+ """
74
+ if max(height, width) / min(height, width) > MAX_RATIO:
75
+ raise ValueError(
76
+ f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
77
+ )
78
+ h_bar = max(factor, round_by_factor(height, factor))
79
+ w_bar = max(factor, round_by_factor(width, factor))
80
+ if h_bar * w_bar > max_pixels:
81
+ beta = math.sqrt((height * width) / max_pixels)
82
+ h_bar = floor_by_factor(height / beta, factor)
83
+ w_bar = floor_by_factor(width / beta, factor)
84
+ elif h_bar * w_bar < min_pixels:
85
+ beta = math.sqrt(min_pixels / (height * width))
86
+ h_bar = ceil_by_factor(height * beta, factor)
87
+ w_bar = ceil_by_factor(width * beta, factor)
88
+ return h_bar, w_bar
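+ # worked example: smart_resize(1080, 1920) -> (1092, 1932); both sides divisible by 28 and the pixel count stays within [MIN_PIXELS, MAX_PIXELS]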
89
+
90
+
91
+ def split_text(text, match_regex):
92
+ matches = list(re.finditer(match_regex, text))
93
+ # initialize the result lists
94
+ result = []
95
+ match_flag_list = []
96
+ # end position of the previous match
97
+ last_end = 0
98
+ # iterate over all matches
99
+ for match in matches:
100
+ # append the text before the match
101
+ if text[last_end:match.start()]:
102
+ result.append(text[last_end:match.start()])
103
+ match_flag_list.append(False)
104
+ # append the match itself
105
+ result.append(match.group(0))
106
+ match_flag_list.append(True)
107
+ # update the end position of the previous match
108
+ last_end = match.end()
109
+ # append the text after the last match
110
+ if text[last_end:]:
111
+ result.append(text[last_end:])
112
+ match_flag_list.append(False)
113
+ return result, match_flag_list
114
+
115
+
116
+ def read_video(image_path, max_frame_number, decode_way):
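+ # Decode a video into frames: decode_way='1fps' samples roughly one frame per second via decord; decode_way='key' keeps only key frames via PyAV. Returns (frames, frame_times), or None on failure.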
117
+ if decode_way=='1fps':
118
+ try:
119
+ # print(image_path)
120
+ vr = VideoReader(image_path, ctx=cpu(0))
121
+ total_frame_num = len(vr)
122
+ fps = round(vr.get_avg_fps())
123
+ frame_idx = [i for i in range(0, len(vr), fps)]
124
+ frames = vr.get_batch(frame_idx).asnumpy()
125
+ cnt = len(frames)
126
+ frame_times = range(cnt)
127
+ except Exception as e:
128
+ print(image_path)
129
+ print('error is', e)
130
+ return None
131
+ elif decode_way=='key':
132
+ try:
133
+ with av.open(image_path) as container:
134
+ stream = container.streams.video[0]
135
+ stream.codec_context.skip_frame = 'NONKEY'
136
+ frames = []
137
+ frame_times = []
138
+ fps = int(stream.average_rate)
139
+ cnt = 0
140
+ for frame in container.decode(stream): # store key frames as image patches
141
+ image = np.array(frame.to_image())
142
+ frames.append(image)
143
+ frame_time = int(frame.time)
144
+ frame_times.append(frame_time)
145
+ cnt += 1
146
+ except Exception as e:
147
+ print('error is', e)
148
+ return None
149
+ if frames is None or len(frames)==0:
150
+ return None
151
+ if len(frames)>max_frame_number and max_frame_number>0:
152
+ # generate max_frame_number evenly spaced indices
153
+ indices = np.linspace(0, len(frames) - 1, max_frame_number, dtype=int)
154
+ # pick the corresponding elements by index
155
+ frames = [frames[i] for i in indices] # use explicit indexing so this also works when frames is a list
156
+ frame_times = [frame_times[i] for i in indices]
157
+ return frames, frame_times
158
+
159
+
160
+ class OmniImageProcessor:
161
+ def __init__(self, config, **kwargs):
162
+ self.config = config # visual_config
163
+ self.min_pixels = self.config.min_pixels if hasattr(self.config, 'min_pixels') else 56 * 56
164
+ self.max_pixels = self.config.max_pixels if hasattr(self.config, 'max_pixels') else 28 * 28 * 1280
165
+ self.patch_size = self.config.patch_size if hasattr(self.config, 'patch_size') else 14
166
+ self.temporal_patch_size = self.config.temporal_patch_size if hasattr(self.config, 'temporal_patch_size') else 2
167
+ self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
168
+ self.spatial_merge_size = self.config.spatial_merge_size if hasattr(self.config, 'spatial_merge_size') else 2
169
+
170
+ def image_transform(self, strseq, return_mm_data = True):
171
+ image = None
172
+ if isinstance(strseq, str):
173
+ if return_mm_data:
174
+ image = Image.open(strseq).convert("RGB")
175
+ else:
176
+ try:
177
+ image = Image.open(BytesIO(strseq)).convert("RGB")
178
+ except:
179
+ image = Image.open(BytesIO(svg2png(bytestring=strseq))).convert("RGB") # some interleaved images are vector graphics (SVG) and need conversion
180
+
181
+ image = np.array(image.convert("RGB")) # convert to RGB (three channels) and then to a NumPy array for further processing
182
+ image_org_size = image.shape[:2] # keep the original (height, width) for later use
183
+
184
+ # resize, crop, scale, normalize
185
+ # compute the target (height, width) used for the subsequent resize
186
+ resized_height, resized_width = smart_resize(
187
+ image_org_size[0], image_org_size[1],
188
+ factor=self.patch_size * self.spatial_merge_size,
189
+ min_pixels=self.min_pixels,
190
+ max_pixels=self.max_pixels,
191
+ )
192
+ output_size = (resized_height, resized_width)
193
+
194
+ # resize the image to output_size with bicubic interpolation for good quality
195
+ # image: input image (NumPy array or PIL image); output_size: target (height, width)
196
+ # resample: interpolation method; PILImageResampling.BICUBIC is a smooth interpolation commonly used for scaling
197
+ image = resize(image, output_size, PILImageResampling.BICUBIC)
198
+ img = image.transpose(2, 0, 1)
199
+ # scale to [0, 1] and normalize with image_mean / image_std
200
+ image = (img / 255.0 - np.array(self.config.image_mean)[:, np.newaxis, np.newaxis]) / np.array(self.config.image_std)[:,np.newaxis,np.newaxis]
201
+ # split into patches
202
+ patches = image[np.newaxis, :]
203
+ if patches.shape[0] == 1:
204
+ patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
205
+ channel = patches.shape[1]
206
+ grid_t = patches.shape[0] // self.temporal_patch_size
207
+ grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
208
+ patches = patches.reshape(
209
+ grid_t,
210
+ self.temporal_patch_size,
211
+ channel,
212
+ grid_h // self.spatial_merge_size,
213
+ self.spatial_merge_size,
214
+ self.patch_size,
215
+ grid_w // self.spatial_merge_size,
216
+ self.spatial_merge_size,
217
+ self.patch_size,
218
+ )
219
+ patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
220
+ flatten_patches = patches.reshape(
221
+ grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
222
+ )
223
+
224
+ return flatten_patches, image_org_size, (grid_t, grid_h, grid_w)
225
+
226
+
227
+ class OmniAudioProcessor:
228
+ # basic audio feature extraction plus input-data parsing
229
+ def __init__(
230
+ self,
231
+ config, # audio processor config
232
+ **kwargs
233
+ ):
234
+ # make sure ffmpeg is installed for torchaudio, e.g. conda install -c conda-forge 'ffmpeg<7'
235
+ assert(len(torchaudio.list_audio_backends()) > 0)
236
+ self.config = config
237
+ self.mel_filters = mel_filter_bank(
238
+ num_frequency_bins=1 + self.config.n_fft // 2,
239
+ num_mel_filters=self.config.num_mel_bins,
240
+ min_frequency=0.0,
241
+ max_frequency=self.config.sampling_rate / 2.0,
242
+ sampling_rate=self.config.sampling_rate,
243
+ norm="slaney",
244
+ mel_scale="slaney",
245
+ )
246
+ self.window = torch.hann_window(self.config.n_fft)
247
+
248
+ @staticmethod
249
+ def dynamic_range_compression(x, C=1, clip_val=1e-6):
250
+ return torch.log(torch.clamp(x, min=clip_val) * C)
251
+
252
+ @staticmethod
253
+ def zero_mean_unit_var_norm(x):
254
+ return (x - x.mean()) / torch.sqrt(x.var() + 1e-8)
255
+
256
+ def load_audio_waveform(self, uri, return_tensors=True, do_normalize=False):
257
+ metadata = torchaudio.info(uri) # sample_rate, num_frames, num_channels, bits_per_sample, encoding=PCM_S
258
+ assert(metadata.num_channels <= 2), "acoustic file with {} channels.".format(metadata.num_channels) # whisper only accept mono channel audio
259
+ waveform_tensor, _ = torchaudio.load(uri, normalize=True)
260
+ if self.config.sampling_rate != metadata.sample_rate:
261
+ waveform_tensor = torchaudio.functional.resample(waveform_tensor, metadata.sample_rate, self.config.sampling_rate, lowpass_filter_width=128)
262
+
263
+ # downmix to mono channel https://trac.ffmpeg.org/wiki/AudioChannelManipulation
264
+ if metadata.num_channels > 1:
265
+ waveform_tensor = torch.mean(waveform_tensor, dim=0, keepdim=True)
266
+
267
+ # normalized to zero mean
268
+ if do_normalize:
269
+ waveform_tensor = self.zero_mean_unit_var_norm(waveform_tensor)
270
+
271
+ if return_tensors: # (channels, samples)
272
+ return waveform_tensor
273
+ else:
274
+ return waveform_tensor.numpy()
275
+
276
+ def split_with_overlap(self, waveform): # if the waveform exceeds the max length, split it into overlapping segments
277
+ channels, wave_samples = waveform.shape
278
+ max_audio_samples = self.config.max_audio_seconds * self.config.sampling_rate
279
+ if wave_samples <= max_audio_samples or self.config.split_overlap < 0:
280
+ return [waveform] # within the max length (or overlap splitting disabled); always return a list
281
+
282
+ split_waveform, start = [], 0
283
+ while start < wave_samples: # align the overlap in whole seconds
284
+ if start > int(self.config.sampling_rate * self.config.split_overlap):
285
+ start -= int(self.config.sampling_rate * self.config.split_overlap) # 0 means no overlap; >0 is the overlap in seconds
286
+ end = min(start + max_audio_samples, wave_samples)
287
+ if end - start >= self.config.n_fft: # ensure at least one frame of data
288
+ split_waveform.append(waveform[:, start:end]) # very short segments may be produced here; they should be detected and dropped during preprocessing
289
+ start = end
290
+ return split_waveform
291
+
292
+ @classmethod
293
+ def inference_output_length(cls, config, input_length):
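+ # Estimate output lengths of the audio front-end: two conv layers (the second strided by stride_size) give the encoder length, and average pooling by avg_pooler gives the bridge length.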
294
+ # for whisper + bridge
295
+ kernel_size = config.kernel_size
296
+ stride_size = config.stride_size
297
+ avg_pooler = config.avg_pooler
298
+ encoder_length = (input_length + 2 * (kernel_size // 2) - kernel_size) // 1 + 1 # conv layer1 with pad=1
299
+ encoder_length = (encoder_length + 2 * (kernel_size // 2) - kernel_size) // stride_size + 1 # conv layer2 with pad=1
300
+ if avg_pooler > 1:
301
+ bridge_length = encoder_length // avg_pooler
+ else:
+ bridge_length = encoder_length # assumed fallback: with no pooling the bridge length equals the encoder length
302
+ return encoder_length, bridge_length
303
+
304
+ def extract_fbank_features(self, waveform):
305
+ # ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py
306
+ channels, wave_samples = waveform.shape
307
+ assert(wave_samples >= self.config.n_fft)
308
+ valid_frame_nums = min(self.config.max_audio_seconds * self.config.sampling_rate // self.config.hop_length, wave_samples // self.config.hop_length + 1)
309
+ if wave_samples < self.config.max_audio_seconds * self.config.sampling_rate:
310
+ waveform = torch.nn.functional.pad(waveform, (0, self.config.max_audio_seconds * self.config.sampling_rate - wave_samples), "constant", 0)
311
+ else:
312
+ waveform = waveform[:, :self.config.max_audio_seconds * self.config.sampling_rate]
313
+
314
+ # window = torch.hann_window(self.config.n_fft)
315
+ stft = torch.stft(waveform, self.config.n_fft, self.config.hop_length, window=self.window, return_complex=True) # fft, len(wave) // n_fft // 2 + 1
316
+ magnitudes = stft[..., :-1].abs() ** 2
317
+
318
+ mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
319
+ mel_spec = mel_filters.T @ magnitudes
320
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
321
+ if waveform.dim() == 2:
322
+ max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
323
+ log_spec = torch.maximum(log_spec, max_val - 8.0)
324
+ else:
325
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
326
+ log_spec = (log_spec + 4.0) / 4.0
327
+
328
+ log_spec = log_spec[0].numpy() # (channel, filters, samples) -> (filters, samples)
329
+ log_spec[:, valid_frame_nums:] = 0.0 # pad0
330
+
331
+ return log_spec, valid_frame_nums
332
+
333
+ def data_augment(self, feature: np.array, input_length, training=True):
334
+ # reference https://arxiv.org/pdf/1904.08779
335
+ def mask_start_indices(input_length, mask_length, min_masks, mask_prob):
336
+ num_masked_span = int(mask_prob * input_length / mask_length + random.random())
337
+ num_masked_span = max(num_masked_span, min_masks)
338
+ start_indices = list(range(input_length - mask_length))
339
+ random.shuffle(start_indices)
340
+ start_indices = start_indices[:num_masked_span]
341
+ return start_indices
342
+
343
+ if not training or (self.config.mask_time_prob <= 0 and self.config.mask_feature_prob <= 0):
344
+ return feature
345
+ if input_length < self.config.mask_time_length * self.config.mask_time_min_masks + 1:
346
+ return feature
347
+ if self.config.num_mel_bins < self.config.mask_feature_length * self.config.mask_feature_min_masks + 1:
348
+ return feature
349
+
350
+ if self.config.mask_time_prob > 0:
351
+ start_indices = mask_start_indices(input_length, self.config.mask_time_length, self.config.mask_time_min_masks, self.config.mask_time_prob)
352
+ for start_idx in start_indices:
353
+ feature[:, start_idx: start_idx + self.config.mask_time_length] = 0.0
354
+ if self.config.mask_feature_prob > 0:
355
+ start_indices = mask_start_indices(self.config.num_mel_bins, self.config.mask_feature_length, self.config.mask_feature_min_masks, self.config.mask_feature_prob)
356
+ for start_idx in start_indices:
357
+ feature[start_idx: start_idx + self.config.mask_feature_length, :] = 0.0
358
+
359
+ return feature
360
+
361
+ @dataclass
362
+ class OmniProcessorOutput(ModelOutput):
363
+ input_ids: Optional["List|torch.Tensor"] = None
364
+ labels: Optional["List|torch.Tensor"] = None
365
+ attention_mask: Optional["List|torch.Tensor"] = None
366
+ position_ids: Optional["List|torch.Tensor"] = None
367
+ seqlens: Optional["List|torch.Tensor"] = None # used together with the Omni modeling code
368
+ # audio fields
369
+ audios: Optional["List|torch.Tensor"] = None
370
+ encoder_length: Optional["List|torch.Tensor"] = None
371
+ bridge_length: Optional["List|torch.Tensor"] = None
372
+ # image fields
373
+ images: Optional["List|torch.Tensor"] = None
374
+ patch_nums: Optional["List|torch.Tensor"] = None
375
+ images_size: Optional["List|torch.Tensor"] = None
376
+ crop_size: Optional["List|torch.Tensor"] = None
377
+ images_grid: Optional["List|torch.Tensor"] = None
378
+ # video fields
379
+ videos: Optional["List|torch.Tensor"] = None
380
+ videos_patch_nums: Optional["List|torch.Tensor"] = None
381
+ videos_size: Optional["List|torch.Tensor"] = None
382
+ videos_crop_size: Optional["List|torch.Tensor"] = None
383
+ videos_grid: Optional["List|torch.Tensor"] = None
384
+ # processor fields
385
+ raw_text: Optional[str] = None
386
+ index: Optional[int] = None
387
+
388
+ def concatenate(self, other): # only valid while the fields are still lists
389
+ def concat_one(a, b):
390
+ if a is None and b is None:
391
+ return None
392
+ elif a is None and b is not None:
393
+ return b
394
+ elif a is not None and b is None:
395
+ return a
396
+ else:
397
+ return a + b
398
+ return OmniProcessorOutput(
399
+ input_ids=concat_one(self.input_ids, other.input_ids),
400
+ labels=concat_one(self.labels, other.labels),
401
+ audios=concat_one(self.audios, other.audios),
402
+ encoder_length=concat_one(self.encoder_length, other.encoder_length),
403
+ bridge_length=concat_one(self.bridge_length, other.bridge_length),
404
+ images=concat_one(self.images, other.images),
405
+ images_grid=concat_one(self.images_grid, other.images_grid),
406
+ patch_nums=concat_one(self.patch_nums, other.patch_nums),
407
+
408
+ videos=concat_one(self.videos, other.videos),
409
+ videos_grid=concat_one(self.videos_grid, other.videos_grid),
410
+ videos_patch_nums=concat_one(self.videos_patch_nums, other.videos_patch_nums),
411
+
412
+ position_ids=concat_one(self.position_ids, other.position_ids),
413
+ seqlens=concat_one(self.seqlens, other.seqlens),
414
+ images_size=concat_one(self.images_size, other.images_size),
415
+ videos_size=concat_one(self.videos_size, other.videos_size),
416
+ index = self.index # keep the index unchanged after concatenation
417
+ )
418
+
419
+ class OmniMMProcessor(object):
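+ # Multimodal processor: parses text containing embedded JSON audio/image/video descriptions, extracts features via the audio/visual/video processors, and expands each placeholder into the matching number of pad tokens for the model.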
420
+ def __init__(self,
421
+ tokenizer: transformers.PreTrainedTokenizer,
422
+ config,
423
+ training,
424
+ relative_path=None,
425
+ parallel=None,
426
+ **kwargs,
427
+ ):
428
+ self.tokenizer = tokenizer
429
+ self.config = config
430
+ self.audio_processor = OmniAudioProcessor(config.audio_config)
431
+ self.visual_processor = None
432
+ if hasattr(config, "visual_config"):
433
+ self.visual_processor = OmniImageProcessor(config.visual_config)
434
+ self.video_processor = None
435
+ if hasattr(config, "video_config"):
436
+ self.video_processor = OmniImageProcessor(config.video_config)
437
+ self.training = training
438
+ self.relative_path = relative_path
439
+ self.parallel = parallel
440
+ # audio tag
441
+ self.audio_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_start_token_id)
442
+ self.audio_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_end_token_id)
443
+ self.audio_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_pad_token_id)
444
+ self.audio_delim_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_delim_token_id)
445
+ self.audiogen_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audiogen_start_token_id)
446
+ self.audiogen_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audiogen_end_token_id)
447
+ # image tag
448
+ self.image_start_tag = None
449
+ self.image_end_tag = None
450
+ self.image_pad_tag = None
451
+ self.video_start_tag = None
452
+ self.video_end_tag = None
453
+ # the videoframe tags only exist to support passing a video as individual frame images; they have no token ids and are replaced with the image start/end tags when the video frames are extracted
454
+ self.videoframe_start_tag = '<videoframe_start_omni>'
455
+ self.videoframe_end_tag = '<videoframe_end_omni>'
456
+ if hasattr(self.config, "visual_config"):
457
+ # special token for start_tag
458
+ self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_start_token_id)
459
+ # special token for end_tag
460
+ self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_end_token_id)
461
+ # special token for pad_tag
462
+ self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_pad_token_id)
463
+ self.image_line_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_line_token_id)
464
+ self.image_delimiter_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_delimiter_token_id)
465
+ if hasattr(self.config, "video_config"):
466
+ self.video_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_start_token_id)
467
+ self.video_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_end_token_id)
468
+ self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_start_token_id)
469
+ self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_end_token_id)
470
+ self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_pad_token_id)
471
+ self.video_place_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_place_token_id)
472
+
473
+ self.frame_pattern = getattr(self.config.video_config, 'frame_pattern', '<frame>')
474
+
475
+
476
+ # @lru_cache(maxsize=1024)
477
+ def _get_audio(self, audio_info):
478
+ try:
479
+ audio_info = ujson.loads(audio_info)
480
+ if 'path' in audio_info.keys():
481
+ audio_uri = None
482
+ if os.path.exists(audio_info['path']):
483
+ audio_uri = audio_info['path']
484
+ elif self.relative_path is not None:
485
+ audio_uri = os.path.join(self.relative_path, audio_info['path'].lstrip('/'))
486
+ if not os.path.exists(audio_uri):
487
+ audio_uri = None
488
+ if audio_uri is not None:
489
+ waveform = self.audio_processor.load_audio_waveform(audio_uri, True)
490
+ waveforms = self.audio_processor.split_with_overlap(waveform)
491
+
492
+ ret = OmniProcessorOutput() # default init: the audios field is None
493
+ for i, waveform in enumerate(waveforms): #(zip(waveforms,vocoder_waveforms)):
494
+ audio, input_length = self.audio_processor.extract_fbank_features(waveform)
495
+ audio = self.audio_processor.data_augment(audio, input_length, self.training)
496
+ encoder_length, bridge_length = self.audio_processor.inference_output_length(self.config.audio_config, input_length)
497
+ if bridge_length <= 0:
498
+ continue
499
+ current_ret = OmniProcessorOutput(
500
+ audios=[audio[:,:input_length]],
501
+ encoder_length=[encoder_length],
502
+ bridge_length=[bridge_length],
503
+ )
504
+ if ret.audios is None:
505
+ ret = current_ret
506
+ else:
507
+ ret = ret.concatenate(current_ret) # concatenate multiple segments
508
+ return ret
509
+ else:
510
+ raise ValueError("can not find path in audio_info")
511
+ except Exception as e:
512
+ print("**** get audio error: {}, info: {} *****".format(str(e), str(audio_info)))
513
+ return OmniProcessorOutput()
514
+
515
+ # @lru_cache(maxsize=1024)
516
+ def _get_image(self, image_info):
517
+ try:
518
+ try:
519
+ image_info = ujson.loads(image_info)
520
+ except:
521
+ image_info = re.sub(r"(?<!\\)'", '"', image_info)
522
+ image_info = ujson.loads(image_info)
523
+ if 'base64' in image_info.keys():
524
+ image_data = base64.b64decode(image_info['base64'])
525
+ image_feat, org_size, image_list = self.visual_processor.image_transform(image_data)
526
+ elif 'local' in image_info.keys():
527
+ image_feat, org_size, image_list = self.visual_processor.image_transform(image_info['local'])
528
+ elif 'path' in image_info.keys() and os.path.exists(image_info['path']):
529
+ image_feat, org_size, image_list = self.visual_processor.image_transform(image_info['path'])
530
+ elif 'url' in image_info.keys():
531
+ image_bytes = self._get_vision_obj_byte('url', image_info['url'])
532
+ image_feat, org_size, image_list = self.visual_processor.image_transform(image_bytes)
533
+ else:
534
+ raise ValueError("can not find any path in image_info")
535
+
536
+ merge_length = self.visual_processor.merge_size**2
537
+ patch_nums = np.array(image_list).prod() // merge_length
538
+
539
+ if org_size[0] * org_size[1] > 16**2: # filter out extremely small images
540
+ return OmniProcessorOutput(
541
+ images=[image_feat],
542
+ patch_nums=[patch_nums],
543
+ crop_size=[image_list],
544
+ images_size= [org_size],
545
+ images_grid=[image_list]
546
+ )
547
+ else:
548
+ print("**** image too small: {}, info: {} *****".format(str(org_size), str(image_info)))
549
+ return OmniProcessorOutput()
550
+
551
+ except Exception as e:
552
+ print("**** get image error: {}, info: {} *****".format(str(e), str(image_info)))
553
+ return OmniProcessorOutput()
554
+
555
+ # @lru_cache(maxsize=1024)
556
+ def _get_video_frame(self, video_frame_infos):
557
+ try:
558
+ pattern = r'\{.*?\}'
559
+ matches = re.findall(pattern, video_frame_infos)
560
+ ret = OmniProcessorOutput()
561
+ # parse one by one
562
+ for match in matches:
563
+ video_frame_info = ujson.loads(match)
564
+ # video_frame_info = ujson.loads(video_frame_info)
565
+ if 'local' in video_frame_info.keys():
566
+ image_feat, org_size, image_list = self.video_processor.image_transform(video_frame_info['local'])
567
+ elif 'path' in video_frame_info.keys() and os.path.exists(video_frame_info['path']):
568
+ image_feat, org_size, image_list = self.video_processor.image_transform(video_frame_info['path'])
569
+ else:
570
+ raise ValueError("can not find any path in video_info")
571
+
572
+ merge_length = self.video_processor.merge_size**2
573
+ patch_nums = np.array(image_list).prod() // merge_length
574
+
575
+ if org_size[0] * org_size[1] > 16**2: # filter out extremely small frames
576
+ ret = ret.concatenate(
577
+ OmniProcessorOutput(
578
+ videos=[image_feat],
579
+ videos_patch_nums=[patch_nums],
580
+ videos_crop_size=[image_list],
581
+ videos_size= [org_size],
582
+ videos_grid=[image_list]
583
+ )
584
+ )
585
+ else:
586
+ print("**** video too small: {}, info: {} *****".format(str(org_size), str(video_frame_info)))
587
+ return ret
588
+
589
+ except Exception as e:
590
+ print("**** get video error: {}, info: {} *****".format(str(e), str(video_frame_info)))
591
+ return OmniProcessorOutput()
592
+
593
+ # read the video bytes
594
+ def _get_vision_obj_byte(self, source, path):
595
+ vision_obj_byte = None
596
+ if source == "local":
597
+ if os.path.exists(path):
598
+ vision_obj_byte = open(path, "rb").read()
599
+ else:
600
+ vision_obj_byte = None
601
+ if source == "base64":
602
+ vision_obj_byte = base64.b64decode(path)
603
+ if source == "url":
604
+ vision_obj_byte = requests.get(url=path).content
605
+ return vision_obj_byte
606
+
607
+ # split the video into frames and save them to a subdirectory
608
+ def _split_video_to_frames(self, video_info, max_frame_number=-1, decode_way="1fps"):
609
+ if decode_way=='1fps':
610
+ frame_suffix = f'_frames'
611
+ elif decode_way=='key':
612
+ frame_suffix = f'_keyframes'
613
+ else:
614
+ raise ValueError('invalid decode way!')
615
+
616
+ server = "local"
617
+ if 'local' in video_info.keys():
618
+ # local path
619
+ video_path = video_info['local']
620
+ # local path where the frames are saved
621
+ frame_path = video_path[:video_path.rfind('.')] + frame_suffix
622
+ mm_obj_byte = self._get_vision_obj_byte('local', video_path)
623
+ elif 'base64' in video_info.keys():
624
+ md5 = hashlib.md5(video_info['base64'].encode('utf-8')).hexdigest()
625
+ if self.relative_path is not None:
626
+ video_path = os.path.join(self.relative_path, md5)
627
+ else:
628
+ video_path = os.path.join(os.getcwd(), md5)
629
+ frame_path = video_path + frame_suffix
630
+ mm_obj_byte = self._get_vision_obj_byte('base64', video_info['base64'])
631
+ elif 'url' in video_info.keys():
632
+ md5 = hashlib.md5(video_info['url'].encode('utf-8')).hexdigest()
633
+ if self.relative_path is not None:
634
+ video_path = os.path.join(self.relative_path, md5)
635
+ else:
636
+ video_path = os.path.join(os.getcwd(), md5)
637
+ frame_path = video_path + frame_suffix
638
+ mm_obj_byte = self._get_vision_obj_byte('url', video_info['url'])
639
+ else:
640
+ raise ValueError('invalid video source!')
641
+ return ""
642
+
643
+ if mm_obj_byte is None: # failed to read the video file
644
+ return ""
645
+ if not os.path.exists(frame_path) or len(os.listdir(frame_path))==0:
646
+ # save the frames
647
+ os.makedirs(frame_path, exist_ok=True)
648
+ frames, frame_times = read_video(io.BytesIO(mm_obj_byte), max_frame_number=-1, decode_way=decode_way) # read all frames
649
+ for frame_idx, frame in enumerate(frames):
650
+ output_filename = os.path.join(frame_path, f"{frame_times[frame_idx]}.jpg")
651
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
652
+ cv2.imwrite(output_filename, frame)
653
+ frame_paths = os.listdir(frame_path)
654
+
655
+ # select frames
656
+ frame_times = [int(filename.split('/')[-1].replace('.jpg', '')) for filename in frame_paths if filename.endswith('.jpg')] # file names correspond to seconds
657
+ frame_times.sort() # sort in ascending order
658
+ frame_number = len(frame_times)
659
+ if frame_number > max_frame_number:
660
+ indices = np.linspace(0, frame_number - 1, max_frame_number, dtype=int)
661
+ else:
662
+ indices = np.linspace(0, frame_number - 1, frame_number, dtype=int)
663
+ # concatenation mode: build the replacement string frame by frame
664
+ replace_str = ""
665
+ for frame_idx, idx in enumerate(indices):
666
+ frame_time = frame_times[idx] # frame_time is the frame's time in seconds and also its stored file name
667
+ frame_dict = {"local": os.path.join(frame_path, f'{frame_time}.jpg')}
668
+ frame_str = self.frame_pattern.format(frame_idx) if '{}' in self.frame_pattern else self.frame_pattern # {} is replaced with the frame index
669
+ frame_str = frame_str.replace('<TIMEIDX>', str(frame_time)) # TIMEIDX is the time in seconds
670
+ frame_str = frame_str.replace('<TIMESTAMP>', time.strftime("%H:%M:%S", time.gmtime(frame_time))) # TIMESTAMP is the formatted timestamp
671
+ frame_str = frame_str.replace('<frame>', f'{self.image_start_tag}{json.dumps(frame_dict)}{self.image_end_tag}')
672
+ replace_str += frame_str
673
+
674
+ return replace_str
675
+
676
+ def sample_frame(self,frames_str,max_frame = 32):
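+ # Uniformly sample at most max_frame image chunks from the frame string while keeping the surrounding text unchanged.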
677
+ def uniform_sample(lst, num_samples):
678
+ if num_samples > len(lst):
679
+ return lst
680
+ interval = len(lst) / num_samples
681
+ samples = [lst[int(i * interval)] for i in range(num_samples)]
682
+ return samples
683
+ p = rf'({self.image_start_tag}.*?{self.image_end_tag})'
684
+ frames_str_split = re.split(p,frames_str)
685
+ frame_idxs = [idx for idx in range(len(frames_str_split)) if self.image_start_tag in frames_str_split[idx]]
686
+ sample_frame_idxs = set(uniform_sample(frame_idxs, max_frame))
687
+ return ''.join([item for idx,item in enumerate(frames_str_split) if idx in sample_frame_idxs or self.image_start_tag not in frames_str_split[idx]])
688
+
689
+ def _get_video_frame_str(self, video_info):
690
+ try:
691
+ if self.videoframe_start_tag in video_info: # if the video is given as individual frames, replace the videoframe tags with image tags
692
+ frames_str = video_info
693
+ frames_str = frames_str.replace(self.videoframe_start_tag,self.image_start_tag).replace(self.videoframe_end_tag,self.image_end_tag)
694
+ return self.sample_frame(frames_str, max_frame = self.config.video_config.max_frame_num)
695
+ video_info = ujson.loads(video_info)
696
+ # get a string containing the image paths of multiple frames, with at most max_frame_number frames
697
+ frames_str = self._split_video_to_frames(video_info, max_frame_number=self.config.video_config.max_frame_num, decode_way=self.config.video_config.decode_way)
698
+ return frames_str
699
+ except Exception as e:
700
+ print("**** get video error: {}, info: {} *****".format(str(e), str(video_info)))
701
+ return ""
702
+
703
+ def _replace_image(self, image_text):
704
+ image_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', image_text)
705
+ ret = self._get_image(image_info) # repeated lookups can reuse the cached result
706
+ if ret.patch_nums is None:
707
+ return OmniProcessorOutput(), '' # return an empty result pair so callers can unpack it
708
+ return ret, self.image_start_tag + self.image_pad_tag * ret.patch_nums[0] + self.image_end_tag
709
+
710
+ def _replace_video_frame(self, video_frame_text):
711
+ video_frame_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', video_frame_text)
712
+ ret = self._get_video_frame(video_frame_info) # repeated lookups can reuse the cached result
713
+ if ret.videos_patch_nums is None:
714
+ return OmniProcessorOutput(), '' # return an empty result pair so callers can unpack it
715
+ video_frame_str = [self.image_start_tag + self.video_place_tag * ret.videos_patch_nums[i] + self.image_end_tag for i in range(len(ret.videos_patch_nums))]
716
+ return ret, ''.join(video_frame_str)
717
+
718
+
719
+ def split_multimodal_chunk(self, text_list, mm_label_list, trainable_list, mtype='audio'):
720
+ # extract the JSON-formatted audio/image info from the text, load it and convert it to features, estimate the number of encoder tokens, and fill in the corresponding number of pad tokens
721
+ if (self.audio_start_tag != None) and (mtype == 'audio'):
722
+ match_regex = re.compile(self.audio_start_tag + '.*?' + self.audio_end_tag,re.S)
723
+ drop_regex = re.compile(self.audio_start_tag + "|" + self.audio_end_tag,re.S)
724
+ elif (self.image_start_tag != None) and (mtype == 'image'):
725
+ match_regex = re.compile(self.image_start_tag + '.*?' + self.image_end_tag,re.S)
726
+ drop_regex = re.compile(self.image_start_tag + "|" + self.image_end_tag,re.S)
727
+ elif (self.audiogen_start_tag != None) and (mtype == 'audiogen'):
728
+ match_regex = re.compile(self.audiogen_start_tag + '.*?' + self.audiogen_end_tag,re.S)
729
+ drop_regex = re.compile(self.audiogen_start_tag + "|" + self.audiogen_end_tag,re.S)
730
+ elif (self.video_start_tag != None) and (mtype == 'video'):
731
+ match_regex = re.compile(self.video_start_tag + '.*?' + self.video_end_tag,re.S)
732
+ drop_regex = re.compile(self.video_start_tag + "|" + self.video_end_tag,re.S)
733
+ else:
734
+ raise ValueError("mtype not supported!")
735
+ new_text_list = []
736
+ new_mm_label_list = []
737
+ new_trainable_flag_list = []
738
+ for text,mm_label,trainable in zip(text_list,mm_label_list,trainable_list):
739
+ for t,m in zip(*split_text(text, match_regex)):
740
+ new_trainable_flag_list.append(trainable)
741
+ if m:
742
+ new_text_list.append(re.sub(drop_regex, '', t))
743
+ new_mm_label_list.append(mtype)
744
+ else:
745
+ new_text_list.append(t)
746
+ new_mm_label_list.append(mm_label)
747
+ return new_text_list, new_mm_label_list, new_trainable_flag_list
748
+
749
+ def process_multimodal_chunk(self, text, mm_label, trainable):
750
+ ret = OmniProcessorOutput()
751
+ if mm_label == 'audio':
752
+ ret = self._get_audio(text)
753
+ if ret.bridge_length is not None:
754
+ ret.input_ids = self.tokenizer.encode(self.audio_start_tag,add_special_tokens=False) + self.tokenizer.encode(self.audio_pad_tag,add_special_tokens=False) * sum(ret.bridge_length) + self.tokenizer.encode(self.audio_end_tag,add_special_tokens=False)
755
+ else:
756
+ raise ValueError(f"Get audio data Failed at Process audio chunk {text}")
757
+ elif mm_label == 'audiogen':
758
+ ret = self._get_audio(text)
759
+ if ret.bridge_length is not None:
760
+ ret.input_ids = self.tokenizer.encode(self.audiogen_start_tag,add_special_tokens=False) + self.tokenizer.encode(self.audio_pad_tag,add_special_tokens=False) * sum(ret.bridge_length) + self.tokenizer.encode(self.audiogen_end_tag,add_special_tokens=False)
761
+ else:
762
+ raise ValueError(f"Get audio data Failed at Process audio chunk {text}")
763
+ elif mm_label == 'image':
764
+ ret, input_str = self._replace_image(text)
765
+ if input_str:
766
+ ret.input_ids = self.tokenizer.encode(input_str, add_special_tokens=False)
767
+ else:
768
+ raise ValueError("Get image data Failed at Process image chunk")
769
+ elif mm_label == 'video':
770
+ frame_str = self.video_start_tag+self._get_video_frame_str(text)+self.video_end_tag
771
+ ret, input_str = self._replace_video_frame(frame_str)
772
+ if input_str:
773
+ ret.input_ids = self.tokenizer.encode(input_str, add_special_tokens=False)
774
+ else:
775
+ raise ValueError("Get video data Failed at Process video chunk")
776
+ elif mm_label == 'text':
777
+ ret.input_ids = self.tokenizer.encode(text, add_special_tokens=False)
778
+ if len(ret.input_ids) > self.tokenizer.model_max_length-1: # filter overly long text
779
+ raise ValueError(f"Text too long, please check text length! 【{text[:5]+'...'*6+text[-5:]}】")
780
+ else:
781
+ raise ValueError(f"mm_label not supported! must be one of ['audio', 'audiogen', 'image', 'video', 'text'] but got {mm_label}")
782
+ return ret
783
+
784
+ def process_one(self, text, index=0, raw_only=False):
785
+ ret = OmniProcessorOutput(index=index)
786
+ all_text_list = []
787
+ all_mm_label_list = []
788
+ all_trainable_flag_list = []
789
+ text_list, match_flag = split_text(text, re.compile("<trainable_start>.*?<trainable_end>",re.S))
790
+ if len(text_list) == 1:
791
+ text = re.sub(re.compile("<trainable_start>|<trainable_end>",re.S), '', text_list[0])
792
+ all_text_list.append(text)
793
+ all_mm_label_list.append('text')
794
+ all_trainable_flag_list.append(True)
795
+ else:
796
+ for text, match in zip(text_list, match_flag):
797
+ text = re.sub(re.compile("<trainable_start>|<trainable_end>",re.S), '', text)
798
+ if text.strip() == '':
799
+ continue # drop redundant whitespace-only chunks
800
+ all_text_list.append(text)
801
+ all_mm_label_list.append('text')
802
+ all_trainable_flag_list.append(match)
803
+ # process the multimodal information
804
+ for mtype in self.config.multimodal: # loop to collect audio and image results
805
+ all_text_list, all_mm_label_list, all_trainable_flag_list = self.split_multimodal_chunk(all_text_list, all_mm_label_list, all_trainable_flag_list, mtype)
806
+ if len(all_text_list) == 0:
807
+ print(f"Process {text} chunk error: No valid Text data!!!!!")
808
+ return OmniProcessorOutput(index=index)
809
+
810
+ for text, mm_label, trainable in zip(all_text_list, all_mm_label_list, all_trainable_flag_list):
811
+ try:
812
+ mret = self.process_multimodal_chunk(text, mm_label, trainable)
813
+ ret = ret.concatenate(mret)
814
+ except ValueError as e:
815
+ tt = text[:24].replace('\n','<LF>')
816
+ print(f"Process {tt if mm_label == 'text' else text} {mm_label} chunk error: {str(e)}")
817
+ return OmniProcessorOutput(index=index)
818
+
819
+ if raw_only:
820
+ ret.raw_text = self.tokenizer.decode(ret.input_ids, skip_special_tokens=False)
821
+ return ret
822
+ return ret
823
+
824
+ @torch.no_grad()
825
+ def __call__(self, example, parallel=128):
826
+ if isinstance(example, Dict):
827
+ pass
828
+ elif isinstance(example, str):
829
+ return self.process_one(example)
830
+ elif isinstance(example, List): # batch inference with asynchronous multithreading
831
+ with cf.ThreadPoolExecutor(min(parallel, len(example))) as executor:
832
+ future_list = [executor.submit(self.process_one, di, idx) for idx, di in enumerate(example)]
833
+ batch_data = [key.result() for key in cf.as_completed(future_list)]
834
+ valid_num = sum([1 if x.input_ids is not None else 0 for x in batch_data])
835
+ assert(valid_num == len(batch_data)) # inference strictly requires every sample to be valid
836
+ batch_data = sorted(batch_data, key=lambda x: x.index) # keep the original order
837
+
838
+ ret = OmniProcessorOutput()
839
+ for i in range(len(batch_data)):
840
+ ret = ret.concatenate(batch_data[i])
841
+ self.tokenizer.padding_side = "left"
842
+ max_len = min(max([len(x.input_ids) for x in batch_data]),self.tokenizer.model_max_length)
843
+ padding_result = self.tokenizer.pad({"input_ids": [r.input_ids for r in batch_data]}, return_tensors='pt')
844
+ ret.input_ids = padding_result["input_ids"]
845
+ ret.attention_mask = padding_result["attention_mask"] # batch inference does not pack sequences, so seqlens is not needed
846
+
847
+ if ret.audios is not None:
848
+ max_audios_len = max([x.shape[-1] for x in ret.audios])
849
+ ret.audios = default_collate([np.pad(x, ((0,0),(0,max_audios_len - x.shape[-1])), 'constant', constant_values=0) for x in ret.audios])
850
+
851
+ ret.encoder_length = default_collate(ret.encoder_length)
852
+ ret.bridge_length = default_collate(ret.bridge_length)
853
+
854
+ if ret.images is not None:
855
+ ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.images]
856
+ ret.patch_nums = default_collate(ret.patch_nums)
857
+
858
+ if ret.videos is not None:
859
+ ret.videos = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.videos]
860
+ ret.videos_patch_nums = default_collate(ret.videos_patch_nums)
861
+
862
+ return ret
863
+
864
+ else:
865
+ raise ValueError("example format not supported yet")
sequence_parallel_utils.py ADDED
@@ -0,0 +1,186 @@
1
+ from typing import Any, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch import Tensor
7
+ from flash_attn import flash_attn_varlen_func
8
+ try:
9
+ import deepspeed.comm as dist
10
+ except:
11
+ dist = None
12
+
13
+
14
+ try:
15
+ from utils import (
16
+ get_sequence_parallel_group,
17
+ get_sequence_parallel_size,
18
+ get_sequence_parallel_rank
19
+ )
20
+ except (ModuleNotFoundError, ImportError):
21
+ # get the sequence-parallel settings from utils; if the import fails, default to disabled
22
+ get_sequence_parallel_group = lambda : None
23
+ get_sequence_parallel_size = lambda : 1
24
+ get_sequence_parallel_rank = lambda : 0
25
+
26
+
27
+ def single_all_to_all(input, scatter_idx, gather_idx, group):
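+ # All-to-all primitive for sequence parallelism: scatter the input along scatter_idx across the process group and gather along gather_idx (e.g. scatter heads and gather sequence, or the reverse).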
28
+ seq_world_size = dist.get_world_size(group)
29
+ inp_shape = list(input.shape)
30
+ inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size
31
+ if scatter_idx < 2:
32
+ input_t = input.reshape(
33
+ [seq_world_size, inp_shape[scatter_idx]] + \
34
+ inp_shape[scatter_idx + 1:]
35
+ ).contiguous()
36
+ else:
37
+ # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
38
+ input_t = input.reshape(
39
+ [-1, seq_world_size, inp_shape[scatter_idx]] + \
40
+ inp_shape[scatter_idx + 1:]
41
+ ).transpose(0, 1).contiguous()
42
+
43
+ output = torch.empty_like(input_t)
44
+ dist.all_to_all_single(output, input_t, group=group)
45
+
46
+ # if scattering the seq-dim, transpose the heads back to the original dimension
47
+ # [sp_size, seq_len//sp_size, batch_size, head_num // sp_size, head_dim] -->
48
+ # [seq_len//sp_size,batch_size, sp_size, head_num // sp_size, head_dim]
49
+ if scatter_idx < 2:
50
+ output = output.transpose(0, 1).transpose(1, 2).contiguous()
51
+
52
+ return output.reshape(
53
+ inp_shape[: gather_idx] + \
54
+ [inp_shape[gather_idx] * seq_world_size,] + \
55
+ inp_shape[gather_idx + 1:]).contiguous()
56
+
57
+
58
+ class _SeqAllToAll(torch.autograd.Function):
59
+
60
+ @staticmethod
61
+ def forward(ctx: Any, group: 'dist.ProcessGroup', input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
62
+ ctx.group = group
63
+ ctx.scatter_idx = scatter_idx
64
+ ctx.gather_idx = gather_idx
65
+
66
+ return single_all_to_all(input, scatter_idx, gather_idx, group)
67
+
68
+ @staticmethod
69
+ def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
70
+ return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None)
71
+
72
+
73
+ # import from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py
74
+ # but with fixes so the dimension layout matches our training setup
75
+ class DistributedAttention(nn.Module):
76
+ """Initialization.
77
+
78
+ Arguments:
79
+ local_attention (Module): local attention with q,k,v
80
+ sequence_process_group (ProcessGroup): sequence parallel process group
81
+ scatter_idx (int): scatter_idx for all2all comm
82
+ gather_idx (int): gather_idx for all2all comm
83
+ """
84
+
85
+ def __init__(
86
+ self,
87
+ local_attention: nn.Module,
88
+ sequence_process_group: 'dist.ProcessGroup',
89
+ scatter_idx: int = 2,
90
+ gather_idx: int = 0,
91
+ ) -> None:
92
+
93
+ super(DistributedAttention, self).__init__()
94
+ self.local_attn = local_attention
95
+ self.spg = sequence_process_group
96
+ self.scatter_idx = scatter_idx
97
+ self.gather_idx = gather_idx
98
+
99
+ def pad_attention_head(self, query: Tensor, key: Tensor, value: Tensor):
100
+ # pad the head dimension of the inputs to a multiple of sp_size
101
+ sp_size = torch.distributed.get_world_size(self.spg)
102
+ pad_size = (sp_size - query.size(1) % sp_size) % sp_size
103
+ if pad_size > 0:
104
+ # [bs, num_head, seq_len, head_dim] -> [bs, num_head+pad_size, seq_len, head_dim]
105
+ query = torch.nn.functional.pad(query, (0,0,0,0,0,pad_size), value = 0.01)
106
+ key = torch.nn.functional.pad(key, (0,0,0,0,0,pad_size), value = 0.01)
107
+ value = torch.nn.functional.pad(value, (0,0,0,0,0,pad_size),value=0.0)
108
+ return query, key, value
109
+
110
+ def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor:
111
+ """ forward
112
+
113
+ Arguments:
114
+ query (Tensor): query input to the layer [batch_size, num_head, seq_len, head_dim]
115
+ key (Tensor): key input to the layer
116
+ value (Tensor): value input to the layer
117
+ args: other args
118
+
119
+ Returns:
120
+ * output (Tensor): context output
121
+ """
122
+ # TODO Merge three alltoall calls into one
123
+ # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
124
+ # [batch_size,num_head,seq_len, head_dim ]trans to [seq_len,batch_size,num_head,head_dim]
125
+ origin_num_head = query.size(1)
126
+ query, key, value = self.pad_attention_head(query,key,value)
127
+
128
+ query = query.transpose(1,2).transpose(0,1)
129
+ key = key.transpose(1,2).transpose(0,1)
130
+ value = value.transpose(1,2).transpose(0,1)
131
+ #in shape : e.g., [s/p,bs,h,head_dim]
132
+ query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx).transpose(0,1).transpose(1,2).contiguous()
133
+ key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx).transpose(0,1).transpose(1,2).contiguous()
134
+ value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx).transpose(0,1).transpose(1,2).contiguous()
135
+
136
+ context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)
137
+ context_layer = context_layer.transpose(0,1).contiguous()
138
+ # [seq_len, batch_size, num_head, head_dim]
139
+ output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
140
+ return output.transpose(0,1)[:,:,:origin_num_head,:]
141
+
142
+
143
+ class LocalAttention(nn.Module):
144
+ def __init__(self, hidden_size, num_heads, head_dim):
145
+ super().__init__()
146
+ self.hidden_size = hidden_size
147
+ self.num_heads = num_heads
148
+ self.head_dim = head_dim
149
+
150
+ def forward(self, q, k, v, *args, use_flash=True, **kwargs):
151
+ # input q,k,v [batch_size, num_head, seq_len, head_dim]
152
+ # output [batch_size, seq_len, num_head, head_dim]
153
+ if use_flash:
154
+ q_len, num_heads = q.shape[2], q.shape[1]
155
+ q = q.transpose(1,2).reshape(-1, num_heads, self.head_dim)
156
+ k = k.transpose(1,2).reshape(-1, num_heads, self.head_dim)
157
+ v = v.transpose(1,2).reshape(-1, num_heads, self.head_dim)
158
+ return flash_attn_varlen_func(q,k,v,*args, **kwargs).reshape(-1,q_len, num_heads, self.head_dim)
159
+ else:
160
+ with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
161
+ attn_output = F.scaled_dot_product_attention(
162
+ q,k,v, *args, **kwargs)
163
+ attn_output = attn_output.transpose(1, 2)
164
+ return attn_output
165
+
166
+
167
+ def create_attention_layer(hidden_size, num_heads, head_dim):
168
+ if get_sequence_parallel_group() is None:
169
+ return LocalAttention(hidden_size, num_heads, head_dim)
170
+ else:
171
+ return DistributedAttention(
172
+ local_attention=LocalAttention(hidden_size, num_heads, head_dim),
173
+ sequence_process_group=get_sequence_parallel_group()
174
+ )
175
+
176
+
177
+ def get_sequence_parallel_chunk(tensor, dim=1, shift=0):
178
+ assert tensor.size(dim) % get_sequence_parallel_size() == 0
179
+ original_size = tensor.size(dim)
180
+ if shift:
181
+ tensor = tensor.split([shift, tensor.size(dim) - shift], dim=dim)[1]
182
+ if get_sequence_parallel_group() is None:
183
+ return tensor
184
+ else:
185
+ chunk_size = original_size // get_sequence_parallel_size()
186
+ return tensor.split(chunk_size, dim=dim)[get_sequence_parallel_rank()]
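
A single-process sketch (no process group required) of the shape bookkeeping that `DistributedAttention` relies on: the head dimension is padded to a multiple of the sequence-parallel size, and tensors are moved from the `[batch, heads, seq, dim]` layout to the `[seq, batch, heads, dim]` layout expected by the all-to-all. Values below are illustrative.

```python
# Illustrates the head padding and transpose chain used by DistributedAttention above.
import torch

def pad_heads(q, sp_size):
    pad = (sp_size - q.size(1) % sp_size) % sp_size
    # F.pad with a 6-tuple pads the last three dims; the final pair pads the head dim
    return torch.nn.functional.pad(q, (0, 0, 0, 0, 0, pad)) if pad else q

bs, heads, seq, dim, sp_size = 2, 14, 8, 16, 4
q = torch.randn(bs, heads, seq, dim)
q = pad_heads(q, sp_size)                   # heads: 14 -> 16, now divisible by sp_size
q_sbhd = q.transpose(1, 2).transpose(0, 1)  # [seq, batch, heads, dim], the all-to-all layout
print(q_sbhd.shape)                         # torch.Size([8, 2, 16, 16])
```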
special_tokens_map.json ADDED
@@ -0,0 +1,68 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>",
16
+ "<B_SYS>",
17
+ "<B_USYS>",
18
+ "<C_Q>",
19
+ "<C_A>",
20
+ "<B_FUNC>",
21
+ "<B_CODE>",
22
+ "<B_APE>",
23
+ "<function_calling>",
24
+ "<calc_start>",
25
+ "<calc_end>",
26
+ "<inner_think>",
27
+ "<audio_start_baichuan>",
28
+ "<audio_end_baichuan>",
29
+ "<audio_pad_baichuan>",
30
+ "<img_start_baichuan>",
31
+ "<img_end_baichuan>",
32
+ "<img_pad_baichuan>",
33
+ "<img_newline_baichuan>",
34
+ "<box_start_baichuan>",
35
+ "<box_end_baichuan>",
36
+ "<box_delim_baichuan>",
37
+ "<ref_start_baichuan>",
38
+ "<ref_end_baichuan>",
39
+ "<img_delim_baichuan>",
40
+ "<polygon_start_baichuan>",
41
+ "<polygon_end_baichuan>",
42
+ "<baichuan_pad_token>",
43
+ "<reserved_113>",
44
+ "<audio_delim_baichuan>",
45
+ "<video_start_baichuan>",
46
+ "<video_end_baichuan>",
47
+ "<video_palce_baichuan>",
48
+ "<audiotext_start_baichuan>",
49
+ "<audiotext_end_baichuan>",
50
+ "<audiotext_pad_baichuan>",
51
+ "<audiogen_start_baichuan>",
52
+ "<audiogen_end_baichuan>"
53
+ ],
54
+ "eos_token": {
55
+ "content": "<|endoftext|>",
56
+ "lstrip": false,
57
+ "normalized": false,
58
+ "rstrip": false,
59
+ "single_word": false
60
+ },
61
+ "pad_token": {
62
+ "content": "<|endoftext|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false
67
+ }
68
+ }
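
A small sketch of how these special tokens can be checked after loading the tokenizer, assuming `transformers` is installed and the files above are available locally or on the Hub under this repo id:

```python
# Verify that the added special tokens each resolve to a single id.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan-Omni-1d5", trust_remote_code=True)
for t in ["<audio_start_baichuan>", "<img_pad_baichuan>", "<|im_start|>"]:
    ids = tok.encode(t, add_special_tokens=False)
    print(t, "->", ids)  # each special token should map to exactly one id
```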
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,349 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "151643": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "151644": {
13
+ "content": "<|im_start|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "151645": {
21
+ "content": "<|im_end|>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "151646": {
29
+ "content": "<B_SYS>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "151647": {
37
+ "content": "<B_USYS>",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "151648": {
45
+ "content": "<C_Q>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "151649": {
53
+ "content": "<C_A>",
54
+ "lstrip": false,
55
+ "normalized": false,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "151650": {
61
+ "content": "<B_FUNC>",
62
+ "lstrip": false,
63
+ "normalized": false,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "151651": {
69
+ "content": "<B_CODE>",
70
+ "lstrip": false,
71
+ "normalized": false,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "151652": {
77
+ "content": "<B_APE>",
78
+ "lstrip": false,
79
+ "normalized": false,
80
+ "rstrip": false,
81
+ "single_word": true,
82
+ "special": true
83
+ },
84
+ "151653": {
85
+ "content": "<function_calling>",
86
+ "lstrip": false,
87
+ "normalized": false,
88
+ "rstrip": false,
89
+ "single_word": true,
90
+ "special": true
91
+ },
92
+ "151654": {
93
+ "content": "<calc_start>",
94
+ "lstrip": false,
95
+ "normalized": false,
96
+ "rstrip": false,
97
+ "single_word": true,
98
+ "special": true
99
+ },
100
+ "151655": {
101
+ "content": "<calc_end>",
102
+ "lstrip": false,
103
+ "normalized": false,
104
+ "rstrip": false,
105
+ "single_word": true,
106
+ "special": true
107
+ },
108
+ "151656": {
109
+ "content": "<inner_think>",
110
+ "lstrip": false,
111
+ "normalized": false,
112
+ "rstrip": false,
113
+ "single_word": true,
114
+ "special": true
115
+ },
116
+ "151657": {
117
+ "content": "<audio_start_baichuan>",
118
+ "lstrip": false,
119
+ "normalized": false,
120
+ "rstrip": false,
121
+ "single_word": false,
122
+ "special": true
123
+ },
124
+ "151658": {
125
+ "content": "<audio_end_baichuan>",
126
+ "lstrip": false,
127
+ "normalized": false,
128
+ "rstrip": false,
129
+ "single_word": false,
130
+ "special": true
131
+ },
132
+ "151659": {
133
+ "content": "<audio_pad_baichuan>",
134
+ "lstrip": false,
135
+ "normalized": false,
136
+ "rstrip": false,
137
+ "single_word": false,
138
+ "special": true
139
+ },
140
+ "151660": {
141
+ "content": "<img_start_baichuan>",
142
+ "lstrip": false,
143
+ "normalized": false,
144
+ "rstrip": false,
145
+ "single_word": false,
146
+ "special": true
147
+ },
148
+ "151661": {
149
+ "content": "<img_end_baichuan>",
150
+ "lstrip": false,
151
+ "normalized": false,
152
+ "rstrip": false,
153
+ "single_word": false,
154
+ "special": true
155
+ },
156
+ "151662": {
157
+ "content": "<img_pad_baichuan>",
158
+ "lstrip": false,
159
+ "normalized": false,
160
+ "rstrip": false,
161
+ "single_word": false,
162
+ "special": true
163
+ },
164
+ "151663": {
165
+ "content": "<img_newline_baichuan>",
166
+ "lstrip": false,
167
+ "normalized": false,
168
+ "rstrip": false,
169
+ "single_word": false,
170
+ "special": true
171
+ },
172
+ "151664": {
173
+ "content": "<box_start_baichuan>",
174
+ "lstrip": false,
175
+ "normalized": false,
176
+ "rstrip": false,
177
+ "single_word": false,
178
+ "special": true
179
+ },
180
+ "151665": {
181
+ "content": "<box_end_baichuan>",
182
+ "lstrip": false,
183
+ "normalized": false,
184
+ "rstrip": false,
185
+ "single_word": false,
186
+ "special": true
187
+ },
188
+ "151666": {
189
+ "content": "<box_delim_baichuan>",
190
+ "lstrip": false,
191
+ "normalized": false,
192
+ "rstrip": false,
193
+ "single_word": false,
194
+ "special": true
195
+ },
196
+ "151667": {
197
+ "content": "<ref_start_baichuan>",
198
+ "lstrip": false,
199
+ "normalized": false,
200
+ "rstrip": false,
201
+ "single_word": false,
202
+ "special": true
203
+ },
204
+ "151668": {
205
+ "content": "<ref_end_baichuan>",
206
+ "lstrip": false,
207
+ "normalized": false,
208
+ "rstrip": false,
209
+ "single_word": false,
210
+ "special": true
211
+ },
212
+ "151669": {
213
+ "content": "<img_delim_baichuan>",
214
+ "lstrip": false,
215
+ "normalized": false,
216
+ "rstrip": false,
217
+ "single_word": false,
218
+ "special": true
219
+ },
220
+ "151670": {
221
+ "content": "<polygon_start_baichuan>",
222
+ "lstrip": false,
223
+ "normalized": false,
224
+ "rstrip": false,
225
+ "single_word": false,
226
+ "special": true
227
+ },
228
+ "151671": {
229
+ "content": "<polygon_end_baichuan>",
230
+ "lstrip": false,
231
+ "normalized": false,
232
+ "rstrip": false,
233
+ "single_word": false,
234
+ "special": true
235
+ },
236
+ "151672": {
237
+ "content": "<baichuan_pad_token>",
238
+ "lstrip": false,
239
+ "normalized": false,
240
+ "rstrip": false,
241
+ "single_word": false,
242
+ "special": true
243
+ },
244
+ "151673": {
245
+ "content": "<reserved_113>",
246
+ "lstrip": false,
247
+ "normalized": false,
248
+ "rstrip": false,
249
+ "single_word": false,
250
+ "special": true
251
+ },
252
+ "151674": {
253
+ "content": "<audio_delim_baichuan>",
254
+ "lstrip": false,
255
+ "normalized": false,
256
+ "rstrip": false,
257
+ "single_word": false,
258
+ "special": true
259
+ },
260
+ "151675": {
261
+ "content": "<audiotext_start_baichuan>",
262
+ "lstrip": false,
263
+ "normalized": false,
264
+ "rstrip": false,
265
+ "single_word": false,
266
+ "special": true
267
+ },
268
+ "151676": {
269
+ "content": "<audiotext_end_baichuan>",
270
+ "lstrip": false,
271
+ "normalized": false,
272
+ "rstrip": false,
273
+ "single_word": false,
274
+ "special": true
275
+ },
276
+ "151677": {
277
+ "content": "<audiotext_pad_baichuan>",
278
+ "lstrip": false,
279
+ "normalized": false,
280
+ "rstrip": false,
281
+ "single_word": false,
282
+ "special": true
283
+ },
284
+ "151678": {
285
+ "content": "<audiogen_start_baichuan>",
286
+ "lstrip": false,
287
+ "normalized": false,
288
+ "rstrip": false,
289
+ "single_word": false,
290
+ "special": true
291
+ },
292
+ "151679": {
293
+ "content": "<audiogen_end_baichuan>",
294
+ "lstrip": false,
295
+ "normalized": false,
296
+ "rstrip": false,
297
+ "single_word": false,
298
+ "special": true
299
+ }
300
+ },
301
+ "additional_special_tokens": [
302
+ "<|im_start|>",
303
+ "<|im_end|>",
304
+ "<B_SYS>",
305
+ "<B_USYS>",
306
+ "<C_Q>",
307
+ "<C_A>",
308
+ "<B_FUNC>",
309
+ "<B_CODE>",
310
+ "<B_APE>",
311
+ "<function_calling>",
312
+ "<calc_start>",
313
+ "<calc_end>",
314
+ "<inner_think>",
315
+ "<audio_start_baichuan>",
316
+ "<audio_end_baichuan>",
317
+ "<audio_pad_baichuan>",
318
+ "<img_start_baichuan>",
319
+ "<img_end_baichuan>",
320
+ "<img_pad_baichuan>",
321
+ "<img_newline_baichuan>",
322
+ "<box_start_baichuan>",
323
+ "<box_end_baichuan>",
324
+ "<box_delim_baichuan>",
325
+ "<ref_start_baichuan>",
326
+ "<ref_end_baichuan>",
327
+ "<img_delim_baichuan>",
328
+ "<polygon_start_baichuan>",
329
+ "<polygon_end_baichuan>",
330
+ "<baichuan_pad_token>",
331
+ "<reserved_113>",
332
+ "<audio_delim_baichuan>",
333
+ "<audiotext_start_baichuan>",
334
+ "<audiotext_end_baichuan>",
335
+ "<audiotext_pad_baichuan>",
336
+ "<audiogen_start_baichuan>",
337
+ "<audiogen_end_baichuan>"
338
+ ],
339
+ "bos_token": null,
340
+ "chat_template": "{% for message in messages %}{% if message['role'] == 'system' %}{{'<B_SYS>' + message['content']}}{% elif message['role'] == 'user_system' %}{{'<B_USYS>' + message['content']}}{% elif message['role'] == 'user' %}{{'<H_Q>' + message['content']}}{% elif message['role'] == 'assistant' %}{{'<H_A>' + message['content']}}{% elif message['role'] == 'function' %}{{'<B_FUNC>' + message['content']}}{% elif message['role'] == 'code' %}{{'<B_CODE>' + message['content']}}{% else %}{{ raise_exception('Invalid message role: ' + message['role']) }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{'<H_A>'}}{% endif %}",
341
+ "clean_up_tokenization_spaces": false,
342
+ "eos_token": "<|endoftext|>",
343
+ "errors": "replace",
344
+ "model_max_length": 8192,
345
+ "pad_token": "<|endoftext|>",
346
+ "split_special_tokens": false,
347
+ "tokenizer_class": "Qwen2Tokenizer",
348
+ "unk_token": null
349
+ }
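
A hedged usage sketch of the `chat_template` declared above (system, user, and assistant turns map to `<B_SYS>`, `<H_Q>`, and `<H_A>` markers); the repo id is assumed:

```python
# Render a text prompt with the chat template from tokenizer_config.json.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan-Omni-1d5", trust_remote_code=True)
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Describe this image."},
]
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)  # '<B_SYS>You are a helpful assistant.<H_Q>Describe this image.<H_A>'
```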
vector_quantize.py ADDED
@@ -0,0 +1,78 @@
1
+ import torch, random
2
+ from torch.nn import functional as F
3
+ from torch import nn
4
+ import numpy as np
5
+ from torch.cuda.amp import autocast
6
+
7
+ def uniform_init(*shape):
8
+ t = torch.zeros(shape)
9
+ nn.init.kaiming_uniform_(t)
10
+ return t
11
+
12
+ def cdist(x, y):
13
+ x2 = torch.sum(x ** 2, dim=-1, keepdims=True) # (b, 1)
14
+ y2 = torch.sum(y ** 2, dim=-1).reshape(1, -1) # (1, c)
15
+ xy = torch.einsum('bd,cd->bc', x, y) * -2
16
+ return (x2 + y2 + xy).clamp(min=0).sqrt() # (b, c)
17
+
18
+ def get_sequence_mask(inputs, inputs_length):
19
+ if inputs.dim() == 3:
20
+ bsz, tgt_len, _ = inputs.size()
21
+ else:
22
+ bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
23
+ sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
24
+ sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
25
+ unpacking_index = torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1 # convert cumulative counts to indices
26
+ return sequence_mask, unpacking_index
27
+
28
+
29
+ class EuclideanCodebook(nn.Module):
30
+ def __init__(
31
+ self,
32
+ dim,
33
+ codebook_size,
34
+ init_std=0.02,
35
+ ):
36
+ super().__init__()
37
+ self.init_std = init_std
38
+ self.dim = dim
39
+ self.codebook_size = codebook_size
40
+
41
+ embed = uniform_init(codebook_size, dim).to(torch.float32)
42
+ self.cluster_size = nn.Parameter(torch.ones(codebook_size))
43
+ self.embed_avg = nn.Parameter(embed.clone())
44
+ self.embed = nn.Parameter(embed)
45
+ del embed
46
+
47
+ @autocast(enabled=True, dtype=torch.float32)
48
+ @torch.no_grad()
49
+ def forward(self, x):
50
+ assert(len(x.shape) == 2)
51
+ assert(x.dtype == torch.float32)
52
+ embed = self.embed.detach().to(x.device)
53
+ dist = -cdist(x, embed) # dist((bs*sl, d), (c, d)) --> (bs*sl, c)
54
+ embed_ind = dist.argmax(dim=-1)
55
+ quantize = embed[embed_ind] # (bs*sl, d)
56
+ return quantize, embed_ind, dist
57
+
58
+ class VectorQuantize(nn.Module):
59
+ def __init__(self, config, *args, **kwargs):
60
+ super().__init__(*args, **kwargs)
61
+ self.config = config
62
+ self.codebook = EuclideanCodebook(dim=config.dim, codebook_size=config.codebook_size)
63
+
64
+ def forward(self, x, input_length):
65
+ batch_size, seq_len, _ = x.shape
66
+ mask, unpacking_index = get_sequence_mask(x, input_length)
67
+ if x.dtype != torch.float32:
68
+ x = x.to(torch.float32)
69
+ x = torch.masked_select(x, mask).reshape(-1, self.config.dim) # (bs*sl?, d)
70
+ quantize, embed_ind, _ = self.codebook(x)
71
+ quantize = torch.index_select(quantize, 0, unpacking_index).view(batch_size, seq_len, self.config.dim)
72
+ quantize = torch.where(mask, quantize, 0)
73
+ embed_ind = torch.index_select(embed_ind.reshape(-1, 1), 0, unpacking_index).view(batch_size, seq_len, 1)
74
+ embed_ind = torch.where(mask, embed_ind, -1).squeeze()
75
+ return quantize, embed_ind
76
+
77
+ def get_output_from_indices(self, indices):
78
+ return self.codebook.embed[indices]
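
An illustrative, self-contained sketch of the nearest-neighbour lookup that `EuclideanCodebook` performs: distances use the same expanded-square formulation as `cdist` above, and each input vector picks the closest codebook entry. All names and sizes below are hypothetical.

```python
# Nearest-codebook lookup, mirroring the cdist/argmax logic in EuclideanCodebook.forward.
import torch

def nearest_code(x, embed):                    # x: (n, d), embed: (codebook_size, d)
    x2 = (x ** 2).sum(-1, keepdim=True)        # (n, 1)
    y2 = (embed ** 2).sum(-1).reshape(1, -1)   # (1, c)
    xy = torch.einsum('bd,cd->bc', x, embed) * -2
    dist = (x2 + y2 + xy).clamp(min=0).sqrt()  # Euclidean distance matrix (n, c)
    idx = dist.argmin(dim=-1)                  # equivalent to argmax of the negated distance
    return embed[idx], idx

x = torch.randn(5, 16)
codebook = torch.randn(128, 16)
quantized, indices = nearest_code(x, codebook)
print(quantized.shape, indices.shape)          # torch.Size([5, 16]) torch.Size([5])
```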
visual_modeling_omni.py ADDED
@@ -0,0 +1,87 @@
1
+
2
+ from typing import List, Optional, Tuple, Union
3
+ import torch, math
4
+ import torch.utils.checkpoint
5
+ from torch import nn
6
+ import transformers
7
+ from flash_attn import flash_attn_varlen_func
8
+ from transformers.activations import ACT2FN
9
+ from PIL import Image
10
+ import io, fire
11
+ from torch.nn import functional as F
12
+
13
+ class OmniVisualEncoder(transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VisionTransformerPretrainedModel):
14
+ def __init__(self, config):
15
+ super().__init__(config)
16
+ self.config_attn_implementation = 'flash_attention_2'
17
+ self.gradient_checkpointing = True # 强制开启
18
+ self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
19
+ self.merge_size = config.merge_size if hasattr(config, 'merge_size') else 2
20
+ del self.merger
21
+
22
+ def forward(
23
+ self,
24
+ pixel_values: torch.Tensor,
25
+ grid_thw: torch.Tensor,
26
+ ):
27
+ hidden_states = pixel_values.to(self.get_dtype())
28
+ grid_thw = grid_thw.to(pixel_values.device)
29
+
30
+ hidden_states = self.patch_embed(hidden_states)
31
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
32
+
33
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
34
+ dim=0, dtype=torch.int32
35
+ )
36
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
37
+
38
+ for blk in self.blocks:
39
+ if self.gradient_checkpointing and self.training:
40
+ hidden_states = self._gradient_checkpointing_func(blk.__call__, hidden_states, cu_seqlens, rotary_pos_emb)
41
+ else:
42
+ hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
43
+
44
+ return hidden_states
45
+
46
+ @torch.no_grad()
47
+ def fake_input(self, device):
48
+ merge_size = max(self.merge_size, self.config.spatial_merge_size)
49
+ fake_image = torch.zeros([
50
+ 1,
51
+ self.config.temporal_patch_size,
52
+ 3,
53
+ merge_size // self.config.spatial_merge_size,
54
+ self.config.spatial_merge_size,
55
+ self.config.patch_size,
56
+ merge_size // self.config.spatial_merge_size,
57
+ self.config.spatial_merge_size,
58
+ self.config.patch_size,
59
+ ], dtype=torch.float32, device=device)
60
+ patches = fake_image.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
61
+ flatten_patches = patches.reshape(
62
+ merge_size * merge_size, 3 * self.config.temporal_patch_size * self.config.patch_size * self.config.patch_size
63
+ )
64
+ return [flatten_patches], [(1, merge_size, merge_size)], [1]
65
+
66
+
67
+ class OmniVisualBridge(nn.Module):
68
+ def __init__(self, config):
69
+ super().__init__()
70
+ self.config = config
71
+ self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
72
+ self.hidden_size = config.embed_dim * (self.merge_size**2)
73
+ self.ln_q = nn.LayerNorm(config.embed_dim, eps=1e-6)
74
+ self.mlp = nn.Sequential(
75
+ nn.Linear(self.hidden_size, self.hidden_size),
76
+ nn.GELU(),
77
+ nn.Linear(self.hidden_size, config.hidden_size),
78
+ )
79
+
80
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
81
+ x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
82
+ return x
83
+
84
+
85
+ if __name__ == '__main__':
86
+ fire.Fire()
87
+
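
A short sketch of the `cu_seqlens` construction in `OmniVisualEncoder.forward`: every image or video contributes `t` sequences of `h*w` patch tokens, and the varlen flash-attention kernel expects their cumulative boundaries. The grid values below are made up for illustration.

```python
# Cumulative sequence lengths for variable-sized visual inputs, as built in the encoder above.
import torch
import torch.nn.functional as F

grid_thw = torch.tensor([[1, 4, 6],    # one image: 1 temporal patch of 4x6 spatial patches
                         [2, 2, 2]])   # one video: 2 temporal patches of 2x2
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
    dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
print(cu_seqlens)  # tensor([ 0, 24, 28, 32], dtype=torch.int32)
```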
vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
zero_to_fp32.py ADDED
@@ -0,0 +1,604 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example: python zero_to_fp32.py . pytorch_model.bin
14
+
15
+ import argparse
16
+ import torch
17
+ import glob
18
+ import math
19
+ import os
20
+ import re
21
+ from collections import OrderedDict
22
+ from dataclasses import dataclass
23
+
24
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
25
+ # DeepSpeed data structures it has to be available in the current python environment.
26
+ from deepspeed.utils import logger
27
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
28
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
29
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
30
+
31
+
32
+ @dataclass
33
+ class zero_model_state:
34
+ buffers: dict()
35
+ param_shapes: dict()
36
+ shared_params: list
37
+ ds_version: int
38
+ frozen_param_shapes: dict()
39
+ frozen_param_fragments: dict()
40
+
41
+
42
+ debug = 0
43
+
44
+ # load to cpu
45
+ device = torch.device('cpu')
46
+
47
+
48
+ def atoi(text):
49
+ return int(text) if text.isdigit() else text
50
+
51
+
52
+ def natural_keys(text):
53
+ '''
54
+ alist.sort(key=natural_keys) sorts in human order
55
+ http://nedbatchelder.com/blog/200712/human_sorting.html
56
+ (See Toothy's implementation in the comments)
57
+ '''
58
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
59
+
60
+
61
+ def get_model_state_file(checkpoint_dir, zero_stage):
62
+ if not os.path.isdir(checkpoint_dir):
63
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
64
+
65
+ # there should be only one file
66
+ if zero_stage <= 2:
67
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
68
+ elif zero_stage == 3:
69
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
70
+
71
+ if not os.path.exists(file):
72
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
73
+
74
+ return file
75
+
76
+
77
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
78
+ # XXX: need to test that this simple glob rule works for multi-node setup too
79
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
80
+
81
+ if len(ckpt_files) == 0:
82
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
83
+
84
+ return ckpt_files
85
+
86
+
87
+ def get_optim_files(checkpoint_dir):
88
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
89
+
90
+
91
+ def get_model_state_files(checkpoint_dir):
92
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
93
+
94
+
95
+ def parse_model_states(files):
96
+ zero_model_states = []
97
+ for file in files:
98
+ state_dict = torch.load(file, map_location=device)
99
+
100
+ if BUFFER_NAMES not in state_dict:
101
+ raise ValueError(f"{file} is not a model state checkpoint")
102
+ buffer_names = state_dict[BUFFER_NAMES]
103
+ if debug:
104
+ print("Found buffers:", buffer_names)
105
+
106
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
107
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
108
+ param_shapes = state_dict[PARAM_SHAPES]
109
+
110
+ # collect parameters that are included in param_shapes
111
+ param_names = []
112
+ for s in param_shapes:
113
+ for name in s.keys():
114
+ param_names.append(name)
115
+
116
+ # update with frozen parameters
117
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
118
+ if frozen_param_shapes is not None:
119
+ if debug:
120
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
121
+ param_names += list(frozen_param_shapes.keys())
122
+
123
+ # handle shared params
124
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
125
+
126
+ ds_version = state_dict.get(DS_VERSION, None)
127
+
128
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
129
+
130
+ z_model_state = zero_model_state(buffers=buffers,
131
+ param_shapes=param_shapes,
132
+ shared_params=shared_params,
133
+ ds_version=ds_version,
134
+ frozen_param_shapes=frozen_param_shapes,
135
+ frozen_param_fragments=frozen_param_fragments)
136
+ zero_model_states.append(z_model_state)
137
+
138
+ return zero_model_states
139
+
140
+
141
+ def parse_optim_states(files, ds_checkpoint_dir):
142
+
143
+ total_files = len(files)
144
+ state_dicts = []
145
+ for f in files:
146
+ state_dict = torch.load(f, map_location=device)
147
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
148
+ # and also handle the case where it was already removed by another helper script
149
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
150
+ state_dicts.append(state_dict)
151
+
152
+ if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
153
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
154
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
155
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
156
+
157
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
158
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
159
+ # use the max of the partition_count to get the dp world_size.
160
+
161
+ if type(world_size) is list:
162
+ world_size = max(world_size)
163
+
164
+ if world_size != total_files:
165
+ raise ValueError(
166
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
167
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
168
+ )
169
+
170
+ # the groups are named differently in each stage
171
+ if zero_stage <= 2:
172
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
173
+ elif zero_stage == 3:
174
+ fp32_groups_key = FP32_FLAT_GROUPS
175
+ else:
176
+ raise ValueError(f"unknown zero stage {zero_stage}")
177
+
178
+ if zero_stage <= 2:
179
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
180
+ elif zero_stage == 3:
181
+ # if there is more than one param group, there will be multiple flattened tensors - one
182
+ # flattened tensor per group - for simplicity merge them into a single tensor
183
+ #
184
+ # XXX: could make the script more memory efficient for when there are multiple groups - it
185
+ # will require matching the sub-lists of param_shapes for each param group flattened tensor
186
+
187
+ fp32_flat_groups = [
188
+ torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
189
+ ]
190
+
191
+ return zero_stage, world_size, fp32_flat_groups
192
+
193
+
194
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
195
+ """
196
+ Returns fp32 state_dict reconstructed from ds checkpoint
197
+
198
+ Args:
199
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
200
+
201
+ """
202
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
203
+
204
+ optim_files = get_optim_files(ds_checkpoint_dir)
205
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
206
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
207
+
208
+ model_files = get_model_state_files(ds_checkpoint_dir)
209
+
210
+ zero_model_states = parse_model_states(model_files)
211
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
212
+
213
+ if zero_stage <= 2:
214
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
215
+ exclude_frozen_parameters)
216
+ elif zero_stage == 3:
217
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
218
+ exclude_frozen_parameters)
219
+
220
+
221
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
222
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
223
+ return
224
+
225
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
226
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
227
+
228
+ if debug:
229
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
230
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
231
+
232
+ wanted_params = len(frozen_param_shapes)
233
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
234
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
235
+ print(f'Frozen params: Have {avail_numel} numels to process.')
236
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
237
+
238
+ total_params = 0
239
+ total_numel = 0
240
+ for name, shape in frozen_param_shapes.items():
241
+ total_params += 1
242
+ unpartitioned_numel = shape.numel()
243
+ total_numel += unpartitioned_numel
244
+
245
+ state_dict[name] = frozen_param_fragments[name]
246
+
247
+ if debug:
248
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
249
+
250
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
251
+
252
+
253
+ def _has_callable(obj, fn):
254
+ attr = getattr(obj, fn, None)
255
+ return callable(attr)
256
+
257
+
258
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
259
+ param_shapes = zero_model_states[0].param_shapes
260
+
261
+ # Reconstruction protocol:
262
+ #
263
+ # XXX: document this
264
+
265
+ if debug:
266
+ for i in range(world_size):
267
+ for j in range(len(fp32_flat_groups[0])):
268
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
269
+
270
+ # XXX: memory usage doubles here (zero2)
271
+ num_param_groups = len(fp32_flat_groups[0])
272
+ merged_single_partition_of_fp32_groups = []
273
+ for i in range(num_param_groups):
274
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
275
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
276
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
277
+ avail_numel = sum(
278
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
279
+
280
+ if debug:
281
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
282
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
283
+ # not asserting if there is a mismatch due to possible padding
284
+ print(f"Have {avail_numel} numels to process.")
285
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
286
+
287
+ # params
288
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
289
+ # out-of-core computing solution
290
+ total_numel = 0
291
+ total_params = 0
292
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
293
+ offset = 0
294
+ avail_numel = full_single_fp32_vector.numel()
295
+ for name, shape in shapes.items():
296
+
297
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
298
+ total_numel += unpartitioned_numel
299
+ total_params += 1
300
+
301
+ if debug:
302
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
303
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
304
+ offset += unpartitioned_numel
305
+
306
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
307
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
308
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
309
+ # live optimizer object, so we are checking that the numbers are within the right range
310
+ align_to = 2 * world_size
311
+
312
+ def zero2_align(x):
313
+ return align_to * math.ceil(x / align_to)
314
+
315
+ if debug:
316
+ print(f"original offset={offset}, avail_numel={avail_numel}")
317
+
318
+ offset = zero2_align(offset)
319
+ avail_numel = zero2_align(avail_numel)
320
+
321
+ if debug:
322
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
323
+
324
+ # Sanity check
325
+ if offset != avail_numel:
326
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
327
+
328
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
329
+
330
+
331
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
332
+ exclude_frozen_parameters):
333
+ state_dict = OrderedDict()
334
+
335
+ # buffers
336
+ buffers = zero_model_states[0].buffers
337
+ state_dict.update(buffers)
338
+ if debug:
339
+ print(f"added {len(buffers)} buffers")
340
+
341
+ if not exclude_frozen_parameters:
342
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
343
+
344
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
345
+
346
+ # recover shared parameters
347
+ for pair in zero_model_states[0].shared_params:
348
+ if pair[1] in state_dict:
349
+ state_dict[pair[0]] = state_dict[pair[1]]
350
+
351
+ return state_dict
352
+
353
+
354
+ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
355
+ remainder = unpartitioned_numel % world_size
356
+ padding_numel = (world_size - remainder) if remainder else 0
357
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
358
+ return partitioned_numel, padding_numel
359
+
360
+
361
+ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
362
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
363
+ return
364
+
365
+ if debug:
366
+ for i in range(world_size):
367
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
368
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
369
+
370
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
371
+ wanted_params = len(frozen_param_shapes)
372
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
373
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
374
+ print(f'Frozen params: Have {avail_numel} numels to process.')
375
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
376
+
377
+ total_params = 0
378
+ total_numel = 0
379
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
380
+ total_params += 1
381
+ unpartitioned_numel = shape.numel()
382
+ total_numel += unpartitioned_numel
383
+
384
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
385
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
386
+
387
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
388
+
389
+ if debug:
390
+ print(
391
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
392
+ )
393
+
394
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
395
+
396
+
397
+ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
398
+ param_shapes = zero_model_states[0].param_shapes
399
+ avail_numel = fp32_flat_groups[0].numel() * world_size
400
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
401
+ # param, re-consolidating each param, while dealing with padding if any
402
+
403
+ # merge list of dicts, preserving order
404
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
405
+
406
+ if debug:
407
+ for i in range(world_size):
408
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
409
+
410
+ wanted_params = len(param_shapes)
411
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
412
+ # not asserting if there is a mismatch due to possible padding
413
+ avail_numel = fp32_flat_groups[0].numel() * world_size
414
+ print(f"Trainable params: Have {avail_numel} numels to process.")
415
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
416
+
417
+ # params
418
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
419
+ # out-of-core computing solution
420
+ offset = 0
421
+ total_numel = 0
422
+ total_params = 0
423
+ for name, shape in param_shapes.items():
424
+
425
+ unpartitioned_numel = shape.numel()
426
+ total_numel += unpartitioned_numel
427
+ total_params += 1
428
+
429
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
430
+
431
+ if debug:
432
+ print(
433
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
434
+ )
435
+
436
+ # XXX: memory usage doubles here
437
+ state_dict[name] = torch.cat(
438
+ tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
439
+ 0).narrow(0, 0, unpartitioned_numel).view(shape)
440
+ offset += partitioned_numel
441
+
442
+ offset *= world_size
443
+
444
+ # Sanity check
445
+ if offset != avail_numel:
446
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
447
+
448
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
449
+
450
+
451
+ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
452
+ exclude_frozen_parameters):
453
+ state_dict = OrderedDict()
454
+
455
+ # buffers
456
+ buffers = zero_model_states[0].buffers
457
+ state_dict.update(buffers)
458
+ if debug:
459
+ print(f"added {len(buffers)} buffers")
460
+
461
+ if not exclude_frozen_parameters:
462
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
463
+
464
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
465
+
466
+ # recover shared parameters
467
+ for pair in zero_model_states[0].shared_params:
468
+ if pair[1] in state_dict:
469
+ state_dict[pair[0]] = state_dict[pair[1]]
470
+
471
+ return state_dict
472
+
473
+
474
+ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
475
+ """
476
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
477
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
478
+ via a model hub.
479
+
480
+ Args:
481
+ - ``checkpoint_dir``: path to the desired checkpoint folder
482
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
483
+ - ``exclude_frozen_parameters``: exclude frozen parameters
484
+
485
+ Returns:
486
+ - pytorch ``state_dict``
487
+
488
+ Note: this approach may not work if your application doesn't have sufficient free CPU memory and
489
+ you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
490
+ the checkpoint.
491
+
492
+ A typical usage might be ::
493
+
494
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
495
+ # do the training and checkpoint saving
496
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
497
+ model = model.cpu() # move to cpu
498
+ model.load_state_dict(state_dict)
499
+ # submit to model hub or save the model to share with others
500
+
501
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
502
+ application. i.e. you will need to re-initialize the deepspeed engine, since
503
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
504
+
505
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
506
+
507
+ """
508
+ if tag is None:
509
+ latest_path = os.path.join(checkpoint_dir, 'latest')
510
+ if os.path.isfile(latest_path):
511
+ with open(latest_path, 'r') as fd:
512
+ tag = fd.read().strip()
513
+ else:
514
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
515
+
516
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
517
+
518
+ if not os.path.isdir(ds_checkpoint_dir):
519
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
520
+
521
+ return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
522
+
523
+
524
+ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False):
525
+ """
526
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
527
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
528
+
529
+ Args:
530
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
531
+ - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
532
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
533
+ - ``exclude_frozen_parameters``: exclude frozen parameters
534
+ """
535
+
536
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
537
+ print(f"Saving fp32 state dict to {output_file}")
538
+ torch.save(state_dict, output_file)
539
+
540
+
541
+ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
542
+ """
543
+ 1. Put the provided model to cpu
544
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
545
+ 3. Load it into the provided model
546
+
547
+ Args:
548
+ - ``model``: the model object to update
549
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
550
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
551
+
552
+ Returns:
553
+ - ``model``: modified model
554
+
555
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
556
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
557
+ conveniently placed for you in the checkpoint folder.
558
+
559
+ A typical usage might be ::
560
+
561
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
562
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
563
+ # submit to model hub or save the model to share with others
564
+
565
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
566
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
567
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
568
+
569
+ """
570
+ logger.info(f"Extracting fp32 weights")
571
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
572
+
573
+ logger.info(f"Overwriting model with fp32 weights")
574
+ model = model.cpu()
575
+ model.load_state_dict(state_dict, strict=False)
576
+
577
+ return model
578
+
579
+
580
+ if __name__ == "__main__":
581
+
582
+ parser = argparse.ArgumentParser()
583
+ parser.add_argument("checkpoint_dir",
584
+ type=str,
585
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
586
+ parser.add_argument(
587
+ "output_file",
588
+ type=str,
589
+ help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
590
+ parser.add_argument("-t",
591
+ "--tag",
592
+ type=str,
593
+ default=None,
594
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
595
+ parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
596
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
597
+ args = parser.parse_args()
598
+
599
+ debug = args.debug
600
+
601
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
602
+ args.output_file,
603
+ tag=args.tag,
604
+ exclude_frozen_parameters=args.exclude_frozen_parameters)
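
A minimal usage sketch of the script above, assuming a DeepSpeed checkpoint directory named `./checkpoint-12` (paths are illustrative): run the offline conversion from the command line, then load the consolidated fp32 weights without DeepSpeed.

```python
# Offline ZeRO -> fp32 conversion, then loading the result with plain PyTorch.
#
#   python zero_to_fp32.py ./checkpoint-12 ./checkpoint-12/pytorch_model.bin
#
import torch

state_dict = torch.load("./checkpoint-12/pytorch_model.bin", map_location="cpu")
# model.load_state_dict(state_dict)  # `model` is whatever nn.Module matches these weights
print(len(state_dict), "tensors reconstructed")
```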