Upload folder using huggingface_hub
Browse files- README.md +1114 -0
- added_tokens.json +39 -0
- audio_modeling_omni.py +658 -0
- config.json +240 -0
- configuration_omni.py +120 -0
- flow_matching.py +791 -0
- generation_config.json +6 -0
- generation_utils.py +83 -0
- matcha_components.py +189 -0
- matcha_feat.py +107 -0
- matcha_transformer.py +480 -0
- model-00001-of-00005.safetensors +3 -0
- model-00002-of-00005.safetensors +3 -0
- model-00003-of-00005.safetensors +3 -0
- model-00004-of-00005.safetensors +3 -0
- model-00005-of-00005.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_omni.py +1011 -0
- processor_omni.py +865 -0
- sequence_parallel_utils.py +186 -0
- special_tokens_map.json +68 -0
- tokenizer.json +0 -0
- tokenizer_config.json +349 -0
- vector_quantize.py +78 -0
- visual_modeling_omni.py +87 -0
- vocab.json +0 -0
- zero_to_fp32.py +604 -0
README.md
ADDED
@@ -0,0 +1,1114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
---
|
4 |
+
<div align="center">
|
5 |
+
|
6 |
+
<img src="https://github.com/baichuan-inc/Baichuan-Omni-1.5/raw/main/assets/logo.png" width="300em" ></img>
|
7 |
+
|
8 |
+
<!-- <img src="https://raw.githubusercontent.com/baichuan-inc/Baichuan-Omni-1.5/refs/heads/main/assets/logo.png" width="300em" ></img>
|
9 |
+
<img src="https://github.com/baichuan-inc/Baichuan-Omni-1.5/raw/main/assets/train-pipeline.png" width="300em" ></img> -->
|
10 |
+
<!-- <img src="https://github.com/OpenBMB/MiniCPM-o/raw/main/assets/minicpm-o-26-framework-v2.png" width="300em" ></img> -->
|
11 |
+
**Open-source Omni-modal Foundation Model Supporting Text, Image, Video, and Audio Inputs as Well as Text and Audio Outputs**
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
<p align="center">
|
16 |
+
Baichuan-Omni-1.5 <a href="https://huggingface.co/baichuan-inc/Baichuan-Omni-1d5">🤗</a> | Baichuan-Omni-1.5-Base <a href="https://huggingface.co/baichuan-inc/Baichuan-Omni-1d5-Base">🤗</a> |Github <a href="https://github.com/baichuan-inc/Baichuan-Omni-1.5/">📖 </a> | Report <a href="https://github.com/baichuan-inc/Baichuan-Omni-1.5/raw/main/baichuan_omni_1_5.pdf">📖</a>
|
17 |
+
</p>
|
18 |
+
</p>
|
19 |
+
<p align="center">
|
20 |
+
OpenMM-Medical <a href="https://huggingface.co/datasets/baichuan-inc/OpenMM-Medical">🤗</a> | OpenAudioBench <a href="https://huggingface.co/datasets/baichuan-inc/OpenAudioBench">🤗</a>
|
21 |
+
</p>
|
22 |
+
</div>
|
23 |
+
|
24 |
+
|
25 |
+
<!-- ## 介绍
|
26 |
+
**Baichuan-Omni-1.5** 是从 Baichuan-omni 升级的最新的、端到端训练的、支持全模态输入/双模态输出的多模态大模型。该模型使用Qwen2.5-7B昨晚大语言模型基座,可以以端到端方式,接受图像、视频、文本、音频作为输入,并且以可控的方式生成高质量文本和语音。
|
27 |
+
|
28 |
+
- **Baichuan-Omni-1.5-Base**: 为促进全模态大模型发展,我们开源了使用高质量海量数据训练的全模态基座模型。该模型未经SFT指令微调,可塑性强,是**业内首个**开源的**全模态基座模型**。
|
29 |
+
|
30 |
+
- **Baichuan-Omni-1.5**: 基于性能强悍的Baichuan-Omni-1.5-base,使用高质量的全模态对齐数据,进行端到端的多模态指令数据训练。Baichuan-Omni-1.5的纯文本、图像、视频、音频理解能力达到了 GPT-4o-mini 级别。可控音频生成的能力十分强大,在xxx和xxx评测集上取得最高表现。 -->
|
31 |
+
|
32 |
+
|
33 |
+
## Baichuan-Omni-1.5
|
34 |
+
|
35 |
+
The Baichuan-Omni-1.5 is the latest, top-performing model in the Baichuan-omni series. This model is trained and inferred in an end-to-end manner. Compared with Baichuan-omni, this model has significant improvements in text/image/audio/video understanding and text/audio generation, and supports new features such as controllable real-time voice conversations and multi-modal real-time interactions. The main features of Baichuan-Omni-1.5 include:
|
36 |
+
|
37 |
+
- 🔥 **Possess Multimodal Understanding and Interaction Capabilities.**
|
38 |
+
Baichuan-Omni-1.5 not only supports images, videos, text, and audio as input, and generates high-quality text and voice output, but also **supports continuous video and audio streaming, and real-time voice interaction with users**. In OminiBench, a comprehensive evaluation benchmark for omnimodal understanding, Baichuan-Omni-1.5 has achieved the first-class level of the open source community and surpassed GPT-4o-mini.
|
39 |
+
|
40 |
+
- 💪 **Strong Visual Capability.**
|
41 |
+
Baichuan-Omni-1.5 has an average score of 73.3 on the OpenCompass list (comprehensive 10 mainstream multimodal evaluation benchmarks). **With the size of 7B, it surpasses mainstream commercial closed-source multimodal large models such as GPT-4o-mini, Gemini 1.5 Pro and Claude 3.5 Sonnet in single-image understanding**. In addition, its video understanding performance is also better than GPT-4V and Claude 3.5 Sonnet and open source omnimodal models.
|
42 |
+
|
43 |
+
- 🚀 **Leading Medical Image Understanding Capabilities.**
|
44 |
+
Baichuan-Omni-1.5 achieved the best performance on GMAI-MMBench and Openmm-Medical. Using only 7B LLM, the average score exceeded Qwen2-VL-72b by 3%, i.e. 80.7% v.s 83.8%.
|
45 |
+
|
46 |
+
- 🎙 **Excellent Voice Capabilities.**
|
47 |
+
Baichuan-Omni-1.5 **supports high-quality, controllable voice bilingual real-time conversations in Chinese and English**. It **outperforms GPT-4o-realtime** in speech understanding tasks (such as ASR and STT, etc.), and demonstrates **the highest speech generation performance among open source models** in semantic and acoustic evaluation of voice conversations.
|
48 |
+
|
49 |
+
- 🎬 **Powerful Real-world Understanding and Other Features.**
|
50 |
+
Baichuan-Omni-1.5 further optimizes the many visual understanding capabilities of Baichuan-omni. It can process images of any aspect ratio and up to 1.8 million pixels (such as 1344x1344). It scored 68.8 points on RealWorldQA, **surpassing commercial closed-source models such as GPT-4o-mini** and recently open-sourced omnimodal models. It scored 85.6/83.6 on the English/Chinese evaluation subsets of MMBench, respectively, which is also in the first echelon of models with the same size.
|
51 |
+
|
52 |
+
- 💫 **Provides [🤗 Base Model](https://huggingface.co/baichuan-inc/Baichuan-Omni-1d5-Base) and [🤗 Instruct Model](https://huggingface.co/baichuan-inc/Baichuan-Omni-1d5).**
|
53 |
+
Baichuan-Omni-1.5-Base is a high-performance foundational omni-modal model in the industry. Based on the powerful base, Baichuan-Omni-1.5 employs high-quality omnimodal alignment data to perform end-to-end multimodal instruction data training.
|
54 |
+
|
55 |
+
**Model Architecture**
|
56 |
+
<div align="center">
|
57 |
+
<img src="https://github.com/baichuan-inc/Baichuan-Omni-1.5/raw/main/assets/train-pipeline.png", width=80%></img>
|
58 |
+
|
59 |
+
</div>
|
60 |
+
|
61 |
+
<br>
|
62 |
+
|
63 |
+
- **End-to-end Omni-modal Architecture.** We carefully design **multi-stage and end-to-end** progressive training of different modal encoding/decoding modules to make full use of the rich knowledge in different modalities, we expect different modal knowledge to complement each other.
|
64 |
+
Notably, the model is fully trained end-to-end using NTP loss in the whole pre-training stage.
|
65 |
+
- **High-quality Controllable Audio Solution.** Multimodal system prompts have been redesigned to include traditional text system prompts and **speech system prompts** for specifying model sounds. It provides the flexibility to control voice style through text or speech samples at inference time, and supports advanced capabilities such as end-to-end voice cloning and timbre creation.
|
66 |
+
|
67 |
+
|
68 |
+
### Open-source Evaluation Datasets
|
69 |
+
|
70 |
+
**OpenMM-Medical**
|
71 |
+
|
72 |
+
To comprehensively evaluate the model's multi-modal medical capabilities, we have constructed OpenMM-Medical, which includes data from 42 publicly available medical image datasets such as ACRIMA (retinal images), BioMediTech (microscope images), and CoronaHack (X-rays), totaling 88,996 images.
|
73 |
+
|
74 |
+
**OpenAudioBench**
|
75 |
+
|
76 |
+
To efficiently assess the model's "IQ" issues, we developed OpenAudioBench, comprising five end-to-end audio understanding sub-datasets: four public benchmarks (Llama Question, WEB QA, TriviaQA, AlpacaEval), and an internally created speech logical reasoning dataset by the Baichuan team, totaling 2,701 entries. This suite reflects the model's comprehensive "IQ" level.
|
77 |
+
|
78 |
+
<!-- **High-quality Medical Image Evaluation Dataset--Openmm-Medical**
|
79 |
+
|
80 |
+
- We have built a more diverse medical evaluation dataset named **Openmm-Medical** to evaluate large models in medical scenarios.
|
81 |
+
- The images in Openmm-Medical come from **42 public medical image datasets**, such as ACRIMA (fundus images), BioMediTech (microscope images), and CoronaHack (X-rays).
|
82 |
+
- **Openmm-Medical contains a total of 88,996 images**, and each image is designed as a **multiple-choice question to facilitate the evaluation of different large models.**
|
83 |
+
- To promote the development of omnimodal large models in the medical field, we will soon **open** this evaluation dataset.
|
84 |
+
-->
|
85 |
+
|
86 |
+
### Evaluation
|
87 |
+
|
88 |
+
We sugguest readers to refer to our [**Github**](https://github.com/baichuan-inc/Baichuan-Omni-1.5/) for more details.
|
89 |
+
|
90 |
+
<div align="center">
|
91 |
+
<img src="https://github.com/baichuan-inc/Baichuan-Omni-1.5/raw/main/assets/performance.png" , width=80%>
|
92 |
+
</div>
|
93 |
+
|
94 |
+
<br>
|
95 |
+
|
96 |
+
<details>
|
97 |
+
|
98 |
+
<summary>click to view</summary>
|
99 |
+
|
100 |
+
#### Pure Text Understanding
|
101 |
+
<div align="center">
|
102 |
+
<table style="margin: 0 auto; text-align: center;">
|
103 |
+
<thead>
|
104 |
+
<tr>
|
105 |
+
<th class="tg-c3ow" colspan="7">Comprehensive Tasks</th>
|
106 |
+
</tr>
|
107 |
+
</thead>
|
108 |
+
<tbody>
|
109 |
+
<tr>
|
110 |
+
<td>Model</td>
|
111 |
+
<td>Size</td>
|
112 |
+
<td>MMLU (Acc.)</td>
|
113 |
+
<td>CMMLU (Acc.)</td>
|
114 |
+
<td>AGIEval (Acc.)</td>
|
115 |
+
<td>C-Eval (Acc.)</td>
|
116 |
+
<td>GAOKAO (Acc.)</td>
|
117 |
+
</tr>
|
118 |
+
<tr>
|
119 |
+
<td colspan="7">Proprietary Models</td>
|
120 |
+
</tr>
|
121 |
+
<tr>
|
122 |
+
<td>GPT 4o</td>
|
123 |
+
<td>-</td>
|
124 |
+
<td><b>88.0♢<br></td>
|
125 |
+
<td><b>78.3♢<br></td>
|
126 |
+
<td><b>62.3♢<br></td>
|
127 |
+
<td><b>86.0♢<br></td>
|
128 |
+
<td>-</td>
|
129 |
+
</tr>
|
130 |
+
<tr>
|
131 |
+
<td>GPT 4o mini</td>
|
132 |
+
<td>-</td>
|
133 |
+
<td>82.0</td>
|
134 |
+
<td>67.6</td>
|
135 |
+
<td>52.2</td>
|
136 |
+
<td>63.6</td>
|
137 |
+
<td>70.8</td>
|
138 |
+
</tr>
|
139 |
+
<tr>
|
140 |
+
<td colspan="7">Open-source Models (Pure text)</td>
|
141 |
+
</tr>
|
142 |
+
<tr>
|
143 |
+
<td>MAP-Neo</td>
|
144 |
+
<td>7B</td>
|
145 |
+
<td>58.2</td>
|
146 |
+
<td>55.1</td>
|
147 |
+
<td>33.9</td>
|
148 |
+
<td>57.5</td>
|
149 |
+
<td>-</td>
|
150 |
+
</tr>
|
151 |
+
<tr>
|
152 |
+
<td>Qwen1.5-Chat</td>
|
153 |
+
<td>7B</td>
|
154 |
+
<td>61.5</td>
|
155 |
+
<td>68.0</td>
|
156 |
+
<td>39.3</td>
|
157 |
+
<td>68.8</td>
|
158 |
+
<td>-</td>
|
159 |
+
</tr>
|
160 |
+
<tr>
|
161 |
+
<td>Llama3-Instruct</td>
|
162 |
+
<td>8B</td>
|
163 |
+
<td>67.1</td>
|
164 |
+
<td>51.7</td>
|
165 |
+
<td>38.4</td>
|
166 |
+
<td>50.7</td>
|
167 |
+
<td>-</td>
|
168 |
+
</tr>
|
169 |
+
<tr>
|
170 |
+
<td>OLMo</td>
|
171 |
+
<td>7B</td>
|
172 |
+
<td>28.4</td>
|
173 |
+
<td>25.6</td>
|
174 |
+
<td>19.9</td>
|
175 |
+
<td>27.3</td>
|
176 |
+
<td>-</td>
|
177 |
+
</tr>
|
178 |
+
<tr>
|
179 |
+
<td colspan="7">Open-source Models (Omni-modal)</td>
|
180 |
+
</tr>
|
181 |
+
<tr>
|
182 |
+
<td>VITA</td>
|
183 |
+
<td>8x7B</td>
|
184 |
+
<td>71.0*</td>
|
185 |
+
<td>46.6</td>
|
186 |
+
<td>46.2*</td>
|
187 |
+
<td>56.7*</td>
|
188 |
+
<td>-</td>
|
189 |
+
</tr>
|
190 |
+
<tr>
|
191 |
+
<td>VITA-1.5</td>
|
192 |
+
<td>7B</td>
|
193 |
+
<td>71.0</td>
|
194 |
+
<td>75.1</td>
|
195 |
+
<td>47.9</td>
|
196 |
+
<td>65.6</td>
|
197 |
+
<td>57.4</td>
|
198 |
+
</tr>
|
199 |
+
<tr>
|
200 |
+
<td>Baichuan-Omni</td>
|
201 |
+
<td>7B</td>
|
202 |
+
<td>65.3</td>
|
203 |
+
<td>72.2</td>
|
204 |
+
<td>47.7</td>
|
205 |
+
<td>68.9</td>
|
206 |
+
<td>-</td>
|
207 |
+
</tr>
|
208 |
+
<tr>
|
209 |
+
<td>MiniCPM-o 2.6</td>
|
210 |
+
<td>7B</td>
|
211 |
+
<td>65.3</td>
|
212 |
+
<td>63.3</td>
|
213 |
+
<td>50.9</td>
|
214 |
+
<td>61.5</td>
|
215 |
+
<td>56.3</td>
|
216 |
+
</tr>
|
217 |
+
<tr>
|
218 |
+
<td><b>Baichuan-Omni-1.5<br></td>
|
219 |
+
<td>7B</td>
|
220 |
+
<td>72.2</td>
|
221 |
+
<td>75.5</td>
|
222 |
+
<td>54.4</td>
|
223 |
+
<td>73.1</td>
|
224 |
+
<td><b>73.5<br></td>
|
225 |
+
</tr>
|
226 |
+
</tbody>
|
227 |
+
</table>
|
228 |
+
</div>
|
229 |
+
|
230 |
+
</details>
|
231 |
+
|
232 |
+
|
233 |
+
<details>
|
234 |
+
|
235 |
+
<summary>click to view</summary>
|
236 |
+
|
237 |
+
#### Image Understanding
|
238 |
+
|
239 |
+
<div align="center">
|
240 |
+
<table style="margin: 0 auto; text-align: center;">
|
241 |
+
<thead>
|
242 |
+
<tr>
|
243 |
+
<th class="tg-c3ow" colspan="9">Multi-choice & Yes-or-No Question</th>
|
244 |
+
</tr>
|
245 |
+
</thead>
|
246 |
+
<tbody>
|
247 |
+
<tr>
|
248 |
+
<td>Model</td>
|
249 |
+
<td>Size</td>
|
250 |
+
<td>MMBench-EN (Acc.)</td>
|
251 |
+
<td>MMbench-CN (Acc.)</td>
|
252 |
+
<td>SEED-IMG (Acc.)</td>
|
253 |
+
<td>MMMU-val (Acc.)</td>
|
254 |
+
<td>HallusionBench (Acc.)</td>
|
255 |
+
</tr>
|
256 |
+
<tr>
|
257 |
+
<td colspan="9">Proprietary Models</td>
|
258 |
+
</tr>
|
259 |
+
<tr>
|
260 |
+
<td>GPT-4o</td>
|
261 |
+
<td>-</td>
|
262 |
+
<td>83.4♢</td>
|
263 |
+
<td>82.1♢</td>
|
264 |
+
<td>-</td>
|
265 |
+
<td><b>69.1♢<br></td>
|
266 |
+
<td><b>55.0♢<br></td>
|
267 |
+
</tr>
|
268 |
+
<tr>
|
269 |
+
<td>GPT-4o-mini</td>
|
270 |
+
<td>-</td>
|
271 |
+
<td>77.7</td>
|
272 |
+
<td>76.9</td>
|
273 |
+
<td>72.3</td>
|
274 |
+
<td>60.0♢</td>
|
275 |
+
<td>46.1♢</td>
|
276 |
+
</tr>
|
277 |
+
<tr>
|
278 |
+
<td colspan="9">Open Source Models (Vision-Language)</td>
|
279 |
+
</tr>
|
280 |
+
<tr>
|
281 |
+
<td>Qwen2-VL-7B</td>
|
282 |
+
<td>7B</td>
|
283 |
+
<td><b>86.4<br></td>
|
284 |
+
<td>81.9</td>
|
285 |
+
<td><b>76.5<br></td>
|
286 |
+
<td>52.7</td>
|
287 |
+
<td>50.6∗</td>
|
288 |
+
</tr>
|
289 |
+
<tr>
|
290 |
+
<td>MiniCPM-Llama3-V 2.5</td>
|
291 |
+
<td>8B</td>
|
292 |
+
<td>76.7</td>
|
293 |
+
<td>73.3</td>
|
294 |
+
<td>72.4</td>
|
295 |
+
<td>45.8∗</td>
|
296 |
+
<td>42.5</td>
|
297 |
+
</tr>
|
298 |
+
<tr>
|
299 |
+
<td colspan="9">Open Source Models (Omni-modal)</td>
|
300 |
+
</tr>
|
301 |
+
<tr>
|
302 |
+
<td>VITA</td>
|
303 |
+
<td>8x7B</td>
|
304 |
+
<td>74.7</td>
|
305 |
+
<td>71.4</td>
|
306 |
+
<td>72.6</td>
|
307 |
+
<td>45.3</td>
|
308 |
+
<td>39.7∗</td>
|
309 |
+
</tr>
|
310 |
+
<tr>
|
311 |
+
<td>VITA-1.5</td>
|
312 |
+
<td>7B</td>
|
313 |
+
<td>80.8</td>
|
314 |
+
<td>80.2</td>
|
315 |
+
<td>74.2</td>
|
316 |
+
<td>53.1</td>
|
317 |
+
<td>44.1</td>
|
318 |
+
</tr>
|
319 |
+
<tr>
|
320 |
+
<td>Baichuan-Omni</td>
|
321 |
+
<td>7B</td>
|
322 |
+
<td>76.2</td>
|
323 |
+
<td>74.9</td>
|
324 |
+
<td>74.1</td>
|
325 |
+
<td>47.3</td>
|
326 |
+
<td>47.8</td>
|
327 |
+
</tr>
|
328 |
+
<tr>
|
329 |
+
<td>MiniCPM-o 2.6</td>
|
330 |
+
<td>7B</td>
|
331 |
+
<td>83.6</td>
|
332 |
+
<td>81.8</td>
|
333 |
+
<td>75.4</td>
|
334 |
+
<td>51.1</td>
|
335 |
+
<td>50.1</td>
|
336 |
+
</tr>
|
337 |
+
<tr>
|
338 |
+
<td><b>Baichuan-Omni-1.5<br></td>
|
339 |
+
<td>7B</td>
|
340 |
+
<td>85.6</td>
|
341 |
+
<td><b>83.6<br></td>
|
342 |
+
<td>75.7</td>
|
343 |
+
<td>53.9</td>
|
344 |
+
<td>49.7</td>
|
345 |
+
</tr>
|
346 |
+
</tbody>
|
347 |
+
</table>
|
348 |
+
</div>
|
349 |
+
|
350 |
+
|
351 |
+
<br>
|
352 |
+
|
353 |
+
<div align="center">
|
354 |
+
<table style="margin: 0 auto; text-align: center;">
|
355 |
+
<thead>
|
356 |
+
<tr>
|
357 |
+
<th class="tg-c3ow" colspan="9">Visual Question Answering</th>
|
358 |
+
</tr>
|
359 |
+
</thead>
|
360 |
+
<tbody>
|
361 |
+
<tr>
|
362 |
+
<td>Model</td>
|
363 |
+
<td>Size</td>
|
364 |
+
<td>RealWorldQA (Acc.)</td>
|
365 |
+
<td>MathVista-mini (Acc.)</td>
|
366 |
+
<td>TextVQA-val (Acc.)</td>
|
367 |
+
<td>ChartQA (Acc.)</td>
|
368 |
+
<td>OCRBench (Acc.)</td>
|
369 |
+
</tr>
|
370 |
+
<tr>
|
371 |
+
<td colspan="8">Proprietary Models</td>
|
372 |
+
</tr>
|
373 |
+
<tr>
|
374 |
+
<td>GPT-4o</td>
|
375 |
+
<td>-</td>
|
376 |
+
<td><b>75.4♢<br></td>
|
377 |
+
<td>63.8♢</td>
|
378 |
+
<td>-</td>
|
379 |
+
<td>85.7♢</td>
|
380 |
+
<td>73.6♢</td>
|
381 |
+
</tr>
|
382 |
+
<tr>
|
383 |
+
<td>GPT-4o-mini</td>
|
384 |
+
<td>-</td>
|
385 |
+
<td>66.3</td>
|
386 |
+
<td>53.4</td>
|
387 |
+
<td>66.8</td>
|
388 |
+
<td>-</td>
|
389 |
+
<td>77.4</td>
|
390 |
+
</tr>
|
391 |
+
<tr>
|
392 |
+
<td colspan="8">Open Source Models (Vision-Language)</td>
|
393 |
+
</tr>
|
394 |
+
<tr>
|
395 |
+
<td>Qwen2-VL-7B</td>
|
396 |
+
<td>7B</td>
|
397 |
+
<td>69.7</td>
|
398 |
+
<td>58.2∗</td>
|
399 |
+
<td><b>84.3∗<br></td>
|
400 |
+
<td>83.0∗</td>
|
401 |
+
<td>84.5∗</td>
|
402 |
+
</tr>
|
403 |
+
<tr>
|
404 |
+
<td>MiniCPM-Llama3-V 2.5</td>
|
405 |
+
<td>8B</td>
|
406 |
+
<td>63.5</td>
|
407 |
+
<td>54.3∗</td>
|
408 |
+
<td>76.6</td>
|
409 |
+
<td>72.0</td>
|
410 |
+
<td>72.5</td>
|
411 |
+
</tr>
|
412 |
+
<tr>
|
413 |
+
<td colspan="8">Open Source Models (Omni-modal)</td>
|
414 |
+
</tr>
|
415 |
+
<tr>
|
416 |
+
<td>VITA</td>
|
417 |
+
<td>8x7B</td>
|
418 |
+
<td>59.0</td>
|
419 |
+
<td>44.9∗</td>
|
420 |
+
<td>71.8</td>
|
421 |
+
<td>76.6</td>
|
422 |
+
<td>68.5∗</td>
|
423 |
+
</tr>
|
424 |
+
<tr>
|
425 |
+
<td>VITA-1.5</td>
|
426 |
+
<td>7B</td>
|
427 |
+
<td>66.8</td>
|
428 |
+
<td><b>66.5<br></td>
|
429 |
+
<td>74.9</td>
|
430 |
+
<td>79.6</td>
|
431 |
+
<td>73.3</td>
|
432 |
+
</tr>
|
433 |
+
<tr>
|
434 |
+
<td>Baichuan-Omni</td>
|
435 |
+
<td>7B</td>
|
436 |
+
<td>62.6</td>
|
437 |
+
<td>51.9</td>
|
438 |
+
<td>74.3</td>
|
439 |
+
<td>79.6</td>
|
440 |
+
<td>70.0</td>
|
441 |
+
</tr>
|
442 |
+
<tr>
|
443 |
+
<td>MiniCPM-o 2.6</td>
|
444 |
+
<td>7B</td>
|
445 |
+
<td>67.7</td>
|
446 |
+
<td>64.6</td>
|
447 |
+
<td>80.1</td>
|
448 |
+
<td><b>87.6<br></td>
|
449 |
+
<td><b>89.7∗<br></td>
|
450 |
+
</tr>
|
451 |
+
<tr>
|
452 |
+
<td>Baichuan-Omni-1.5 </td>
|
453 |
+
<td>7B</td>
|
454 |
+
<td>68.8</td>
|
455 |
+
<td>63.6</td>
|
456 |
+
<td>83.2</td>
|
457 |
+
<td>84.9</td>
|
458 |
+
<td>84.0</td>
|
459 |
+
</tr>
|
460 |
+
</tbody>
|
461 |
+
</table>
|
462 |
+
</div>
|
463 |
+
|
464 |
+
|
465 |
+
</details>
|
466 |
+
|
467 |
+
<details>
|
468 |
+
|
469 |
+
<summary>click to view</summary>
|
470 |
+
|
471 |
+
#### Video Understanding
|
472 |
+
<div align="center">
|
473 |
+
<table style="margin: 0 auto; text-align: center;">
|
474 |
+
<thead>
|
475 |
+
<tr>
|
476 |
+
<th colspan="7">General VQA </th>
|
477 |
+
</tr>
|
478 |
+
</thead>
|
479 |
+
<tbody>
|
480 |
+
<tr>
|
481 |
+
<td>Model</td>
|
482 |
+
<td>Size</td>
|
483 |
+
<td># Frames</td>
|
484 |
+
<td>MVBench (Acc.)</td>
|
485 |
+
<td>Egoschema (Acc.)</td>
|
486 |
+
<td>VideoMME (Acc.)</td>
|
487 |
+
<td>Perception-Test (Acc.)</td>
|
488 |
+
</tr>
|
489 |
+
<tr>
|
490 |
+
<td colspan="7">Proprietary Models</td>
|
491 |
+
</tr>
|
492 |
+
<tr>
|
493 |
+
<td>Gemini 1.5 Pro</td>
|
494 |
+
<td>-</td>
|
495 |
+
<td>-</td>
|
496 |
+
<td><b>81.3♢<br></td>
|
497 |
+
<td>63.2*</td>
|
498 |
+
<td><b>75.0♢<br></td>
|
499 |
+
<td>-</td>
|
500 |
+
</tr>
|
501 |
+
<tr>
|
502 |
+
<td>GPT 4o mini</td>
|
503 |
+
<td>-</td>
|
504 |
+
<td>-</td>
|
505 |
+
<td>55.2</td>
|
506 |
+
<td>58.5</td>
|
507 |
+
<td>63.6</td>
|
508 |
+
<td>48.2</td>
|
509 |
+
</tr>
|
510 |
+
<tr>
|
511 |
+
<td>GPT 4o</td>
|
512 |
+
<td>-</td>
|
513 |
+
<td>-</td>
|
514 |
+
<td>-</td>
|
515 |
+
<td><b>77.2*<br></td>
|
516 |
+
<td>71.9♢</td>
|
517 |
+
<td>-</td>
|
518 |
+
</tr>
|
519 |
+
<tr>
|
520 |
+
<td>GPT 4V</td>
|
521 |
+
<td>-</td>
|
522 |
+
<td>-</td>
|
523 |
+
<td>43.7♢</td>
|
524 |
+
<td>55.6*</td>
|
525 |
+
<td>59.9♢</td>
|
526 |
+
<td>-</td>
|
527 |
+
</tr>
|
528 |
+
<tr>
|
529 |
+
<td colspan="7">Open-source Models (Vision-language)</td>
|
530 |
+
</tr>
|
531 |
+
<tr>
|
532 |
+
<td>Qwen2-VL-7B</td>
|
533 |
+
<td>7B</td>
|
534 |
+
<td>2 fps (max 768)</td>
|
535 |
+
<td>67.0* | 64.4</td>
|
536 |
+
<td>66.7* | 66.6</td>
|
537 |
+
<td>63.3* | 59.0</td>
|
538 |
+
<td>62.3* | 60.3</td>
|
539 |
+
</tr>
|
540 |
+
<tr>
|
541 |
+
<td>AnyGPT</td>
|
542 |
+
<td>8B</td>
|
543 |
+
<td>48</td>
|
544 |
+
<td>33.2</td>
|
545 |
+
<td>32.1</td>
|
546 |
+
<td>29.8</td>
|
547 |
+
<td>29.1</td>
|
548 |
+
</tr>
|
549 |
+
<tr>
|
550 |
+
<td>VideoLLaMA 2</td>
|
551 |
+
<td>7B</td>
|
552 |
+
<td>16</td>
|
553 |
+
<td>54.6*</td>
|
554 |
+
<td>51.7*</td>
|
555 |
+
<td>46.6*</td>
|
556 |
+
<td>51.4*</td>
|
557 |
+
</tr>
|
558 |
+
<tr>
|
559 |
+
<td>VideoChat2</td>
|
560 |
+
<td>7B</td>
|
561 |
+
<td>16</td>
|
562 |
+
<td>51.1*</td>
|
563 |
+
<td>42.1♢</td>
|
564 |
+
<td>33.7♢</td>
|
565 |
+
<td>47.3♢</td>
|
566 |
+
</tr>
|
567 |
+
<tr>
|
568 |
+
<td>LLaVA-NeXT-Video</td>
|
569 |
+
<td>7B</td>
|
570 |
+
<td>32</td>
|
571 |
+
<td>46.5♢</td>
|
572 |
+
<td>43.9♢</td>
|
573 |
+
<td>33.7♢</td>
|
574 |
+
<td>48.8♢</td>
|
575 |
+
</tr>
|
576 |
+
<tr>
|
577 |
+
<td>Video-LLaVA</td>
|
578 |
+
<td>7B</td>
|
579 |
+
<td>8</td>
|
580 |
+
<td>41.0♢</td>
|
581 |
+
<td>38.4♢</td>
|
582 |
+
<td>39.9♢</td>
|
583 |
+
<td>44.3♢</td>
|
584 |
+
</tr>
|
585 |
+
<tr>
|
586 |
+
<td colspan="7">Open-source Models (Omni-modal)</td>
|
587 |
+
</tr>
|
588 |
+
<tr>
|
589 |
+
<td>VITA</td>
|
590 |
+
<td>8x7B</td>
|
591 |
+
<td>1 fps (max 32)</td>
|
592 |
+
<td>53.4</td>
|
593 |
+
<td>53.9</td>
|
594 |
+
<td>56.1</td>
|
595 |
+
<td>56.2</td>
|
596 |
+
</tr>
|
597 |
+
<tr>
|
598 |
+
<td>VITA-1.5</td>
|
599 |
+
<td>7B</td>
|
600 |
+
<td>1 fps (max 32)</td>
|
601 |
+
<td>55.5</td>
|
602 |
+
<td>54.7</td>
|
603 |
+
<td>57.3</td>
|
604 |
+
<td>57.6</td>
|
605 |
+
</tr>
|
606 |
+
<tr>
|
607 |
+
<td>Baichuan-Omni</td>
|
608 |
+
<td>7B</td>
|
609 |
+
<td>1 fps (max 32)</td>
|
610 |
+
<td>60.9</td>
|
611 |
+
<td>58.8</td>
|
612 |
+
<td>58.2</td>
|
613 |
+
<td>56.8</td>
|
614 |
+
</tr>
|
615 |
+
<tr>
|
616 |
+
<td>MiniCPM-o 2.6</td>
|
617 |
+
<td>7B</td>
|
618 |
+
<td>1 fps (max 64)</td>
|
619 |
+
<td>58.6</td>
|
620 |
+
<td>50.7</td>
|
621 |
+
<td>63.4</td>
|
622 |
+
<td>66.6</td>
|
623 |
+
</tr>
|
624 |
+
<tr>
|
625 |
+
<td>Baichuan-Omini-1.5</td>
|
626 |
+
<td>7B</td>
|
627 |
+
<td>1 fps (max 32)</td>
|
628 |
+
<td> 63.7 </td>
|
629 |
+
<td> 62.4 </td>
|
630 |
+
<td> 60.1 </td>
|
631 |
+
<td> <b>68.9 <br> </td>
|
632 |
+
</tr>
|
633 |
+
</tbody>
|
634 |
+
</table>
|
635 |
+
</div>
|
636 |
+
|
637 |
+
<br>
|
638 |
+
|
639 |
+
<div align="center">
|
640 |
+
<table style="margin: 0 auto; text-align: center;">
|
641 |
+
<thead>
|
642 |
+
<tr>
|
643 |
+
<th colspan="7">Open-ended VQA</th>
|
644 |
+
</tr>
|
645 |
+
</thead>
|
646 |
+
<tbody>
|
647 |
+
<tr>
|
648 |
+
<td rowspan="2">Model</td>
|
649 |
+
<td rowspan="2">Size</td>
|
650 |
+
<td rowspan="2"># Frames</td>
|
651 |
+
<td colspan="2">ActivityNet-QA</td>
|
652 |
+
<td colspan="2">MSVD-QA</td>
|
653 |
+
</tr>
|
654 |
+
<tr>
|
655 |
+
<td>(Acc.)</td>
|
656 |
+
<td>(Score)</td>
|
657 |
+
<td>(Acc.)</td>
|
658 |
+
<td>(Score)</td>
|
659 |
+
</tr>
|
660 |
+
<tr>
|
661 |
+
<td colspan="7">Proprietary Models</td>
|
662 |
+
</tr>
|
663 |
+
<tr>
|
664 |
+
<td>Gemini 1.5 Pro</td>
|
665 |
+
<td>-</td>
|
666 |
+
<td>-</td>
|
667 |
+
<td>56.7*</td>
|
668 |
+
<td>-</td>
|
669 |
+
<td>-</td>
|
670 |
+
<td>-</td>
|
671 |
+
</tr>
|
672 |
+
<tr>
|
673 |
+
<td>GPT 4o mini</td>
|
674 |
+
<td>-</td>
|
675 |
+
<td>1 fps (max 32)</td>
|
676 |
+
<td>62.1</td>
|
677 |
+
<td>3.1</td>
|
678 |
+
<td>67.5</td>
|
679 |
+
<td>3.3</td>
|
680 |
+
</tr>
|
681 |
+
<tr>
|
682 |
+
<td>GPT 4o</td>
|
683 |
+
<td>-</td>
|
684 |
+
<td>-</td>
|
685 |
+
<td>61.9*</td>
|
686 |
+
<td>-</td>
|
687 |
+
<td>-</td>
|
688 |
+
<td>-</td>
|
689 |
+
</tr>
|
690 |
+
<tr>
|
691 |
+
<td>GPT 4V</td>
|
692 |
+
<td>-</td>
|
693 |
+
<td>-</td>
|
694 |
+
<td>59.5*</td>
|
695 |
+
<td>-</td>
|
696 |
+
<td>-</td>
|
697 |
+
<td>-</td>
|
698 |
+
</tr>
|
699 |
+
<tr>
|
700 |
+
<td colspan="7">Open-source Models (Vision-language)</td>
|
701 |
+
</tr>
|
702 |
+
<tr>
|
703 |
+
<td>Qwen2 VL</td>
|
704 |
+
<td>7B</td>
|
705 |
+
<td>2 fps (max 768)</td>
|
706 |
+
<td>17.4</td>
|
707 |
+
<td>1.9</td>
|
708 |
+
<td>61.1</td>
|
709 |
+
<td>3.5</td>
|
710 |
+
</tr>
|
711 |
+
<tr>
|
712 |
+
<td>VideoLLaMA 2</td>
|
713 |
+
<td>7B</td>
|
714 |
+
<td>16</td>
|
715 |
+
<td>50.2*</td>
|
716 |
+
<td>3.3*</td>
|
717 |
+
<td>70.9*</td>
|
718 |
+
<td>3.8*</td>
|
719 |
+
</tr>
|
720 |
+
<tr>
|
721 |
+
<td>VideoChat2</td>
|
722 |
+
<td>7B</td>
|
723 |
+
<td>16</td>
|
724 |
+
<td>49.1*</td>
|
725 |
+
<td>3.3*</td>
|
726 |
+
<td>70.0*</td>
|
727 |
+
<td>3.9*</td>
|
728 |
+
</tr>
|
729 |
+
<tr>
|
730 |
+
<td>LLaVA-NeXT-Video</td>
|
731 |
+
<td>7B</td>
|
732 |
+
<td>32</td>
|
733 |
+
<td>53.5*</td>
|
734 |
+
<td>3.2*</td>
|
735 |
+
<td>67.4</td>
|
736 |
+
<td>3.4</td>
|
737 |
+
</tr>
|
738 |
+
<tr>
|
739 |
+
<td>Video-LLaVA</td>
|
740 |
+
<td>7B</td>
|
741 |
+
<td>8</td>
|
742 |
+
<td>45.3*</td>
|
743 |
+
<td>3.3*</td>
|
744 |
+
<td>70.7*</td>
|
745 |
+
<td>3.9*</td>
|
746 |
+
</tr>
|
747 |
+
<tr>
|
748 |
+
<td colspan="7">Open-source Models (Omni-modal)</td>
|
749 |
+
</tr>
|
750 |
+
<tr>
|
751 |
+
<td>VITA</td>
|
752 |
+
<td>8x7B</td>
|
753 |
+
<td>1 fps (max 32)</td>
|
754 |
+
<td>55.0</td>
|
755 |
+
<td>3.5</td>
|
756 |
+
<td>63.9</td>
|
757 |
+
<td>3.7</td>
|
758 |
+
</tr>
|
759 |
+
<tr>
|
760 |
+
<td>VITA-1.5</td>
|
761 |
+
<td>7B</td>
|
762 |
+
<td>1 fps (max 32)</td>
|
763 |
+
<td>59.6</td>
|
764 |
+
<td>3.0</td>
|
765 |
+
<td>67.6</td>
|
766 |
+
<td>3.3</td>
|
767 |
+
</tr>
|
768 |
+
<tr>
|
769 |
+
<td>Baichuan-Omni</td>
|
770 |
+
<td>7B</td>
|
771 |
+
<td>1 fps (max 48)</td>
|
772 |
+
<td>58.6</td>
|
773 |
+
<td><b>3.7<br></td>
|
774 |
+
<td>72.2</td>
|
775 |
+
<td> <b>4.0<br> </td>
|
776 |
+
</tr>
|
777 |
+
<tr>
|
778 |
+
<td>MiniCPM-o 2.6</td>
|
779 |
+
<td>7B</td>
|
780 |
+
<td>1 fps (max 64)</td>
|
781 |
+
<td><b>63.0<br></td>
|
782 |
+
<td>3.1</td>
|
783 |
+
<td>73.7</td>
|
784 |
+
<td>3.6</td>
|
785 |
+
</tr>
|
786 |
+
<tr>
|
787 |
+
<td>Baichuan-Omni-1.5</td>
|
788 |
+
<td>7B</td>
|
789 |
+
<td>1 fps (max 48)</td>
|
790 |
+
<td> 62.0</td>
|
791 |
+
<td> 3.1</td>
|
792 |
+
<td> <b> 74.2 <br></td>
|
793 |
+
<td> 3.6</td>
|
794 |
+
</tr>
|
795 |
+
</tbody>
|
796 |
+
</table>
|
797 |
+
</div>
|
798 |
+
|
799 |
+
</details>
|
800 |
+
|
801 |
+
|
802 |
+
<details>
|
803 |
+
|
804 |
+
<summary>click to view</summary>
|
805 |
+
|
806 |
+
#### Audio Comprehensive and Speech Generation
|
807 |
+
<div align="center">
|
808 |
+
<table style="margin: 0 auto; text-align: center;">
|
809 |
+
<thead>
|
810 |
+
<tr>
|
811 |
+
<th colspan="12">Audio Comprehensive Capacity</th>
|
812 |
+
</tr></thead>
|
813 |
+
<tbody>
|
814 |
+
<tr>
|
815 |
+
<td rowspan="2">Model</td>
|
816 |
+
<td rowspan="2">Size</td>
|
817 |
+
<td colspan="2">Reasoning QA</td>
|
818 |
+
<td colspan="2">Llama Questions</td>
|
819 |
+
<td colspan="2">Web Questions</td>
|
820 |
+
<td colspan="2">TriviaQA</td>
|
821 |
+
<td colspan="2">AlpacaEval</td>
|
822 |
+
</tr>
|
823 |
+
<tr>
|
824 |
+
<td>s→t</td>
|
825 |
+
<td>s→s</td>
|
826 |
+
<td>s→t</td>
|
827 |
+
<td>s→s</td>
|
828 |
+
<td>s→t</td>
|
829 |
+
<td>s→s</td>
|
830 |
+
<td>s→t</td>
|
831 |
+
<td>s→s</td>
|
832 |
+
<td>s→t</td>
|
833 |
+
<td>s→s</td>
|
834 |
+
</tr>
|
835 |
+
<tr>
|
836 |
+
<td colspan="12">Proprietary Models</td>
|
837 |
+
</tr>
|
838 |
+
<tr>
|
839 |
+
<td>GPT-4o-Audio</td>
|
840 |
+
<td>-</td>
|
841 |
+
<td><b>55.6</td>
|
842 |
+
<td>-</td>
|
843 |
+
<td><b>88.4</td>
|
844 |
+
<td>-</td>
|
845 |
+
<td><b>8.10</td>
|
846 |
+
<td>-</td>
|
847 |
+
<td><b>9.06</td>
|
848 |
+
<td>-</td>
|
849 |
+
<td><b>8.01</td>
|
850 |
+
<td>-</td>
|
851 |
+
</tr>
|
852 |
+
<tr>
|
853 |
+
<td colspan="12">Open-source Models (Pure Audio)</td>
|
854 |
+
</tr>
|
855 |
+
<tr>
|
856 |
+
<td>GLM-4-Voice</td>
|
857 |
+
<td>9B</td>
|
858 |
+
<td>-</td>
|
859 |
+
<td>26.5</td>
|
860 |
+
<td>-</td>
|
861 |
+
<td>71.0</td>
|
862 |
+
<td>-</td>
|
863 |
+
<td>5.15</td>
|
864 |
+
<td>-</td>
|
865 |
+
<td>4.66</td>
|
866 |
+
<td>-</td>
|
867 |
+
<td>4.89</td>
|
868 |
+
</tr>
|
869 |
+
<tr>
|
870 |
+
<td colspan="12">Open-source Models (Omni-modal)</td>
|
871 |
+
</tr>
|
872 |
+
<tr>
|
873 |
+
<td>VITA-1.5</td>
|
874 |
+
<td>7B</td>
|
875 |
+
<td>41.0</td>
|
876 |
+
<td>-</td>
|
877 |
+
<td>74.2</td>
|
878 |
+
<td>-</td>
|
879 |
+
<td>5.73</td>
|
880 |
+
<td>-</td>
|
881 |
+
<td>4.68</td>
|
882 |
+
<td>-</td>
|
883 |
+
<td>6.82</td>
|
884 |
+
<td>-</td>
|
885 |
+
</tr>
|
886 |
+
<tr>
|
887 |
+
<td>MiniCPM-o 2.6</td>
|
888 |
+
<td>7B</td>
|
889 |
+
<td>38.6</td>
|
890 |
+
<td>-</td>
|
891 |
+
<td>77.8</td>
|
892 |
+
<td>-</td>
|
893 |
+
<td>6.86</td>
|
894 |
+
<td>-</td>
|
895 |
+
<td>6.19</td>
|
896 |
+
<td>-</td>
|
897 |
+
<td>5.18</td>
|
898 |
+
<td>-</td>
|
899 |
+
</tr>
|
900 |
+
<tr>
|
901 |
+
<td><b>Baichuan-Omni-1.5</td>
|
902 |
+
<td>7B</td>
|
903 |
+
<td>50.0</td>
|
904 |
+
<td><b>40.9</td>
|
905 |
+
<td>78.5</td>
|
906 |
+
<td><b>75.3</td>
|
907 |
+
<td>5.91</td>
|
908 |
+
<td><b>5.52</td>
|
909 |
+
<td>5.72</td>
|
910 |
+
<td>5.31</td>
|
911 |
+
<td>7.79</td>
|
912 |
+
<td><b>6.94</td>
|
913 |
+
</tr>
|
914 |
+
</tbody>
|
915 |
+
</table>
|
916 |
+
</div>
|
917 |
+
|
918 |
+
|
919 |
+
</details>
|
920 |
+
|
921 |
+
|
922 |
+
|
923 |
+
<details>
|
924 |
+
|
925 |
+
<summary>click to view</summary>
|
926 |
+
|
927 |
+
#### Omni-modal Understanding
|
928 |
+
|
929 |
+
<div align="center">
|
930 |
+
<table style="margin: 0 auto; text-align: center;">
|
931 |
+
<thead>
|
932 |
+
<tr>
|
933 |
+
<th colspan="7">Omni-Undesratnding </th>
|
934 |
+
</tr>
|
935 |
+
<thead>
|
936 |
+
<tbody>
|
937 |
+
<tr>
|
938 |
+
<td>Model</td>
|
939 |
+
<td>Size</td>
|
940 |
+
<td>Image & Audio</td>
|
941 |
+
<td>Image Caption & Audio</td>
|
942 |
+
<td>Image & Audio Transcript</td>
|
943 |
+
<td>Image Caption & Audio Transcript</td>
|
944 |
+
</tr>
|
945 |
+
</thead>
|
946 |
+
<tr>
|
947 |
+
<td colspan="6">Proprietary Models</td>
|
948 |
+
</tr>
|
949 |
+
<tr>
|
950 |
+
<td>GPT4o-mini</td>
|
951 |
+
<td>-</td>
|
952 |
+
<td>-</td>
|
953 |
+
<td>-</td>
|
954 |
+
<td>37.0</td>
|
955 |
+
<td>37.7</td>
|
956 |
+
</tr>
|
957 |
+
<tr>
|
958 |
+
<td colspan="6">Open-source Models (Omni-modal)</td>
|
959 |
+
</tr>
|
960 |
+
<tr>
|
961 |
+
<td>VITA</td>
|
962 |
+
<td>8x7B</td>
|
963 |
+
<td>33.1</td>
|
964 |
+
<td>31.8</td>
|
965 |
+
<td>42.0</td>
|
966 |
+
<td>44.2</td>
|
967 |
+
</tr>
|
968 |
+
<tr>
|
969 |
+
<td>VITA-1.5</td>
|
970 |
+
<td>7B</td>
|
971 |
+
<td>33.4</td>
|
972 |
+
<td>29.6</td>
|
973 |
+
<td>48.5</td>
|
974 |
+
<td><b>47.2<br></td>
|
975 |
+
</tr>
|
976 |
+
<tr>
|
977 |
+
<td>Baichuan-Omni</td>
|
978 |
+
<td>7B</td>
|
979 |
+
<td>32.2</td>
|
980 |
+
<td>26.5</td>
|
981 |
+
<td>42.6</td>
|
982 |
+
<td>44.2</td>
|
983 |
+
</tr>
|
984 |
+
<tr>
|
985 |
+
<td>MiniCPM-o 2.6</td>
|
986 |
+
<td>7B</td>
|
987 |
+
<td>40.5</td>
|
988 |
+
<td>30.8</td>
|
989 |
+
<td><b>53.2<br></td>
|
990 |
+
<td>46.3</td>
|
991 |
+
</tr>
|
992 |
+
<tr>
|
993 |
+
<td><b>Baichuan-Omni-1.5<br></td>
|
994 |
+
<td>7B</td>
|
995 |
+
<td><b>42.9<br></td>
|
996 |
+
<td><b>37.7<br></td>
|
997 |
+
<td>47.9</td>
|
998 |
+
<td>46.9</td>
|
999 |
+
</tr>
|
1000 |
+
</tbody>
|
1001 |
+
</table>
|
1002 |
+
</div>
|
1003 |
+
|
1004 |
+
</details>
|
1005 |
+
|
1006 |
+
<details>
|
1007 |
+
|
1008 |
+
<summary>click to view</summary>
|
1009 |
+
|
1010 |
+
#### Medical Image Understanding Capabilities
|
1011 |
+
|
1012 |
+
<div align="center">
|
1013 |
+
<table style="margin: 0 auto; text-align: center;">
|
1014 |
+
<thead>
|
1015 |
+
<tr>
|
1016 |
+
<th colspan="7">Medical Understanding </th>
|
1017 |
+
</tr>
|
1018 |
+
</thead>
|
1019 |
+
<tbody>
|
1020 |
+
<tr>
|
1021 |
+
<td>Model</td>
|
1022 |
+
<td>Size</td>
|
1023 |
+
<td>GMAI-MMB-VAL (Acc.)</td>
|
1024 |
+
<td>OpenMM-Medical (Acc.)</td>
|
1025 |
+
</tr>
|
1026 |
+
</thead>
|
1027 |
+
<tr>
|
1028 |
+
<td colspan="4">Proprietary Models</td>
|
1029 |
+
</tr>
|
1030 |
+
<tr>
|
1031 |
+
<td>GPT4o-mini</td>
|
1032 |
+
<td>-</td>
|
1033 |
+
<td>46.4</td>
|
1034 |
+
<td>74.3</td>
|
1035 |
+
</tr>
|
1036 |
+
<tr>
|
1037 |
+
<td colspan="4">Open-source Models (Vision-Language)</td>
|
1038 |
+
</tr>
|
1039 |
+
<tr>
|
1040 |
+
<td>Qwen2 VL</td>
|
1041 |
+
<td>7B</td>
|
1042 |
+
<td>46.3</td>
|
1043 |
+
<td>76.9</td>
|
1044 |
+
</tr>
|
1045 |
+
<tr>
|
1046 |
+
<td>Qwen2 VL</td>
|
1047 |
+
<td>72B</td>
|
1048 |
+
<td><b>50.7<br></td>
|
1049 |
+
<td>80.7</td>
|
1050 |
+
</tr>
|
1051 |
+
<tr>
|
1052 |
+
<td colspan="4">Open-source Models (Omni-modal)</td>
|
1053 |
+
</tr>
|
1054 |
+
<tr>
|
1055 |
+
<td>VITA-1.5</td>
|
1056 |
+
<td>7B</td>
|
1057 |
+
<td>36.7</td>
|
1058 |
+
<td>67.1</td>
|
1059 |
+
</tr>
|
1060 |
+
<tr>
|
1061 |
+
<td>MiniCPM-o 2.6</td>
|
1062 |
+
<td>7B</td>
|
1063 |
+
<td>41.5</td>
|
1064 |
+
<td>73.6</td>
|
1065 |
+
</tr>
|
1066 |
+
<tr>
|
1067 |
+
<td><b>Baichuan-Omni-1.5<br></td>
|
1068 |
+
<td>7B</td>
|
1069 |
+
<td>49.9</td>
|
1070 |
+
<td><b>83.8<br></td>
|
1071 |
+
</tr>
|
1072 |
+
</tbody>
|
1073 |
+
</table>
|
1074 |
+
</div>
|
1075 |
+
|
1076 |
+
</details>
|
1077 |
+
|
1078 |
+
## Examples
|
1079 |
+
<br>
|
1080 |
+
|
1081 |
+
<div style="display: flex; flex-direction: column; align-items: center;">
|
1082 |
+
<img src="https://github.com/baichuan-inc/Baichuan-Omni-1.5/raw/main/assets/pipeline.png" alt="pipeline" style="margin-bottom: 5px;">
|
1083 |
+
<img src="https://github.com/baichuan-inc/Baichuan-Omni-1.5/raw/main/assets/math.png" alt="math" style="margin-bottom: 5px;">
|
1084 |
+
<img src="https://github.com/baichuan-inc/Baichuan-Omni-1.5/raw/main/assets/fly_bill.png" alt="fly_bill" style="margin-bottom: 5px;">
|
1085 |
+
</div>
|
1086 |
+
|
1087 |
+
|
1088 |
+
## 🚀 Quick Start
|
1089 |
+
We recommend interested scholars to visit our github repo for more details. [**Github**](https://github.com/baichuan-inc/Baichuan-Omni-1.5/)
|
1090 |
+
|
1091 |
+
|
1092 |
+
### Statement
|
1093 |
+
- We hereby declare that our team has not developed any applications based on Baichuan-Omni-1.5/Baichuan-Omni-1.5-base models, not on iOS, Android, the web, or any other platform. We strongly call on all users not to use Baichuan-Omni-1.5/Baichuan-Omni-1.5-base models for any activities that harm national / social security or violate the law. Also, we ask users not to use Baichuan-Omni-1.5/Baichuan-Omni-1.5-base models for Internet services that have not undergone appropriate security reviews and filings. We hope that all users can abide by this principle and ensure that the development of technology proceeds in a regulated and legal environment.
|
1094 |
+
|
1095 |
+
- We have done our best to ensure the compliance of the data used in the model training process. However, despite our considerable efforts, there may still be some unforeseeable issues due to the complexity of the model and data. Therefore, if any problems arise due to the use of Baichuan-Omni-1.5/Baichuan-Omni-1.5-base open-source models, including but not limited to data security issues, public opinion risks, or any risks and problems brought about by the model being misled, abused, spread or improperly exploited, we will not assume any responsibility.
|
1096 |
+
|
1097 |
+
|
1098 |
+
|
1099 |
+
### License
|
1100 |
+
The community usage of Baichuan-Omni-1.5/Baichuan-Omni-1.5-base requires adherence to [Apache 2.0](https://github.com/baichuan-inc/Baichuan-Omni-1.5/blob/main/LICENSE) and [Community License for Baichuan-Omni-1.5 Models](https://github.com/baichuan-inc/Baichuan-Omni-1.5/blob/main/LICENSE). The Baichuan-Omni-1.5/Baichuan-Omni-1.5-base models supports commercial use. If you plan to use the Baichuan-Omni-1.5/Baichuan-Omni-1.5-base models or its derivatives for commercial purposes, please ensure that your entity meets the following conditions:
|
1101 |
+
|
1102 |
+
1. The Daily Active Users (DAU) of your or your affiliate's service or product is less than 1 million.
|
1103 |
+
2. Neither you nor your affiliates are software service providers or cloud service providers.
|
1104 |
+
3. There is no possibility for you or your affiliates to grant the commercial license given to you, to reauthorize it to other third parties without Baichuan's permission.
|
1105 |
+
|
1106 |
+
Upon meeting the above conditions, you need to submit the application materials required by the Baichuan-Omni-1.5 Model Community License Agreement via the following contact email: [email protected]. Once approved, Baichuan will hereby grant you a non-exclusive, global, non-transferable, non-sublicensable, revocable commercial copyright license.
|
1107 |
+
|
1108 |
+
<!-- ### Citation
|
1109 |
+
|
1110 |
+
If you find our work helpful, please consider citing our papers 📝 and liking this project ❤️!
|
1111 |
+
```bib
|
1112 |
+
@article{
|
1113 |
+
} -->
|
1114 |
+
```
|
added_tokens.json
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"<B_APE>": 151652,
|
3 |
+
"<B_CODE>": 151651,
|
4 |
+
"<B_FUNC>": 151650,
|
5 |
+
"<B_SYS>": 151646,
|
6 |
+
"<B_USYS>": 151647,
|
7 |
+
"<C_A>": 151649,
|
8 |
+
"<C_Q>": 151648,
|
9 |
+
"<audio_delim_baichuan>": 151674,
|
10 |
+
"<audio_end_baichuan>": 151658,
|
11 |
+
"<audio_pad_baichuan>": 151659,
|
12 |
+
"<audio_start_baichuan>": 151657,
|
13 |
+
"<audiogen_end_baichuan>": 151679,
|
14 |
+
"<audiogen_start_baichuan>": 151678,
|
15 |
+
"<audiotext_end_baichuan>": 151676,
|
16 |
+
"<audiotext_pad_baichuan>": 151677,
|
17 |
+
"<audiotext_start_baichuan>": 151675,
|
18 |
+
"<baichuan_pad_token>": 151672,
|
19 |
+
"<box_delim_baichuan>": 151666,
|
20 |
+
"<box_end_baichuan>": 151665,
|
21 |
+
"<box_start_baichuan>": 151664,
|
22 |
+
"<calc_end>": 151655,
|
23 |
+
"<calc_start>": 151654,
|
24 |
+
"<function_calling>": 151653,
|
25 |
+
"<img_delim_baichuan>": 151669,
|
26 |
+
"<img_end_baichuan>": 151661,
|
27 |
+
"<img_newline_baichuan>": 151663,
|
28 |
+
"<img_pad_baichuan>": 151662,
|
29 |
+
"<img_start_baichuan>": 151660,
|
30 |
+
"<inner_think>": 151656,
|
31 |
+
"<polygon_end_baichuan>": 151671,
|
32 |
+
"<polygon_start_baichuan>": 151670,
|
33 |
+
"<ref_end_baichuan>": 151668,
|
34 |
+
"<ref_start_baichuan>": 151667,
|
35 |
+
"<reserved_113>": 151673,
|
36 |
+
"<|endoftext|>": 151643,
|
37 |
+
"<|im_end|>": 151645,
|
38 |
+
"<|im_start|>": 151644
|
39 |
+
}
|
audio_modeling_omni.py
ADDED
@@ -0,0 +1,658 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch, fire
|
2 |
+
from typing import Optional
|
3 |
+
import torch.distributed
|
4 |
+
from torch.nn import functional as F
|
5 |
+
from flash_attn import flash_attn_varlen_func
|
6 |
+
from torch import nn
|
7 |
+
import numpy as np
|
8 |
+
import deepspeed
|
9 |
+
from transformers.activations import ACT2FN
|
10 |
+
from dataclasses import dataclass
|
11 |
+
from transformers.modeling_outputs import ModelOutput
|
12 |
+
try:
|
13 |
+
from .vector_quantize import VectorQuantize
|
14 |
+
except:
|
15 |
+
from vector_quantize import VectorQuantize
|
16 |
+
|
17 |
+
from .flow_matching import (
|
18 |
+
ConditionalDecoder,
|
19 |
+
ConditionalCFM,
|
20 |
+
)
|
21 |
+
|
22 |
+
import math
|
23 |
+
import copy
|
24 |
+
|
25 |
+
def sinusoids(length, channels, max_timescale=10000):
|
26 |
+
"""Returns sinusoids for positional embedding"""
|
27 |
+
assert channels % 2 == 0
|
28 |
+
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
29 |
+
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
|
30 |
+
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
31 |
+
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
32 |
+
|
33 |
+
def get_sequence_mask(inputs, inputs_length):
|
34 |
+
if inputs.dim() == 3:
|
35 |
+
bsz, tgt_len, _ = inputs.size()
|
36 |
+
else:
|
37 |
+
bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
|
38 |
+
sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
|
39 |
+
sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
|
40 |
+
unpacking_index = torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1 # 转成下标
|
41 |
+
return sequence_mask, unpacking_index
|
42 |
+
|
43 |
+
def unpack_hidden_states(hidden_states, lengths):
|
44 |
+
bsz = lengths.shape[0]
|
45 |
+
sequence_mask, unpacking_index = get_sequence_mask(hidden_states, lengths)
|
46 |
+
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
|
47 |
+
bsz, torch.max(lengths), hidden_states.shape[-1]
|
48 |
+
)
|
49 |
+
hidden_states = torch.where(
|
50 |
+
sequence_mask, hidden_states, 0
|
51 |
+
) # 3d (bsz, max_input_len, d)
|
52 |
+
return hidden_states
|
53 |
+
|
54 |
+
|
55 |
+
class RMSNorm(nn.Module):
|
56 |
+
def __init__(self, hidden_size, eps=1e-6):
|
57 |
+
"""
|
58 |
+
RMSNorm is equivalent to T5LayerNorm
|
59 |
+
"""
|
60 |
+
super().__init__()
|
61 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
62 |
+
self.variance_epsilon = eps
|
63 |
+
|
64 |
+
def forward(self, hidden_states):
|
65 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
66 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
67 |
+
|
68 |
+
# convert into half-precision if necessary
|
69 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
70 |
+
hidden_states = hidden_states.to(self.weight.dtype)
|
71 |
+
|
72 |
+
return self.weight * hidden_states
|
73 |
+
|
74 |
+
|
75 |
+
class OmniWhisperAttention(nn.Module):
|
76 |
+
def __init__(self, embed_dim, num_heads, causal=False):
|
77 |
+
super().__init__()
|
78 |
+
self.embed_dim = embed_dim
|
79 |
+
self.num_heads = num_heads
|
80 |
+
self.head_dim = embed_dim // num_heads
|
81 |
+
|
82 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
|
83 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
84 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
85 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
86 |
+
|
87 |
+
self.causal = causal
|
88 |
+
|
89 |
+
def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
|
90 |
+
bsz, _ = hidden_states.size()
|
91 |
+
|
92 |
+
query_states = self.q_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
|
93 |
+
key_states = self.k_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
|
94 |
+
value_states = self.v_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
|
95 |
+
|
96 |
+
cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(torch.int32)
|
97 |
+
max_seqlen = torch.max(seq_len).to(torch.int32).detach()
|
98 |
+
attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_len, cu_len, max_seqlen,
|
99 |
+
max_seqlen, causal=self.causal) # (bsz * qlen, nheads, headdim)
|
100 |
+
attn_output = attn_output.reshape(bsz, self.embed_dim)
|
101 |
+
attn_output = self.out_proj(attn_output)
|
102 |
+
return attn_output
|
103 |
+
|
104 |
+
|
105 |
+
class OmniWhisperTransformerLayer(nn.Module):
|
106 |
+
def __init__(
|
107 |
+
self,
|
108 |
+
act,
|
109 |
+
d_model,
|
110 |
+
encoder_attention_heads,
|
111 |
+
encoder_ffn_dim,
|
112 |
+
causal,
|
113 |
+
ln_type="LayerNorm",
|
114 |
+
):
|
115 |
+
super().__init__()
|
116 |
+
self.embed_dim = d_model
|
117 |
+
self.self_attn = OmniWhisperAttention(
|
118 |
+
self.embed_dim, encoder_attention_heads, causal
|
119 |
+
)
|
120 |
+
|
121 |
+
if ln_type == "LayerNorm":
|
122 |
+
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
123 |
+
elif ln_type == "RMSNorm":
|
124 |
+
self.self_attn_layer_norm = RMSNorm(self.embed_dim)
|
125 |
+
else:
|
126 |
+
raise ValueError(f"Unknown ln_type: {ln_type}")
|
127 |
+
|
128 |
+
self.activation_fn = act
|
129 |
+
self.fc1 = nn.Linear(self.embed_dim, encoder_ffn_dim)
|
130 |
+
self.fc2 = nn.Linear(encoder_ffn_dim, self.embed_dim)
|
131 |
+
|
132 |
+
if ln_type == "LayerNorm":
|
133 |
+
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
134 |
+
elif ln_type == "RMSNorm":
|
135 |
+
self.final_layer_norm = RMSNorm(self.embed_dim)
|
136 |
+
else:
|
137 |
+
raise ValueError(f"Unknown ln_type: {ln_type}")
|
138 |
+
|
139 |
+
def forward(
|
140 |
+
self, hidden_states: torch.Tensor, seq_len: torch.Tensor
|
141 |
+
) -> torch.Tensor:
|
142 |
+
residual = hidden_states
|
143 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
144 |
+
hidden_states = self.self_attn(hidden_states, seq_len)
|
145 |
+
hidden_states = residual + hidden_states
|
146 |
+
residual = hidden_states
|
147 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
148 |
+
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
149 |
+
hidden_states = self.fc2(hidden_states)
|
150 |
+
hidden_states = residual + hidden_states
|
151 |
+
|
152 |
+
if (
|
153 |
+
hidden_states.dtype == torch.float16
|
154 |
+
or hidden_states.dtype == torch.bfloat16
|
155 |
+
) and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
|
156 |
+
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
157 |
+
hidden_states = torch.clamp(
|
158 |
+
hidden_states, min=-clamp_value, max=clamp_value
|
159 |
+
)
|
160 |
+
return hidden_states
|
161 |
+
|
162 |
+
|
163 |
+
class OmniAudioEncoder(nn.Module):
|
164 |
+
def __init__(self, config):
|
165 |
+
super().__init__()
|
166 |
+
config._attn_implementation = 'flash_attention_2' #
|
167 |
+
self.config = config
|
168 |
+
self.max_source_positions = (config.max_audio_seconds * config.sampling_rate // config.hop_length) // config.stride_size
|
169 |
+
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
170 |
+
|
171 |
+
self.conv1 = nn.Conv1d(config.num_mel_bins, config.d_model, kernel_size=config.kernel_size, padding=1)
|
172 |
+
self.conv2 = nn.Conv1d(config.d_model, config.d_model, kernel_size=config.kernel_size,
|
173 |
+
stride=config.stride_size, padding=1)
|
174 |
+
self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, config.d_model)) # 1500 * d
|
175 |
+
|
176 |
+
self.layers = nn.ModuleList([OmniWhisperTransformerLayer(
|
177 |
+
ACT2FN[config.activation_function],
|
178 |
+
config.d_model,
|
179 |
+
config.encoder_attention_heads,
|
180 |
+
config.encoder_ffn_dim,
|
181 |
+
False) for _ in range(config.encoder_layers)])
|
182 |
+
self.layer_norm = nn.LayerNorm(config.d_model)
|
183 |
+
|
184 |
+
@torch.no_grad()
|
185 |
+
def fake_input(self, device):
|
186 |
+
input_features = torch.rand([2, self.config.num_mel_bins, 10], dtype=torch.float32, device=device)
|
187 |
+
encoder_length = torch.ones([2], dtype=torch.int32, device=device) * 3
|
188 |
+
bridge_length = torch.ones([2], dtype=torch.int32, device=device)
|
189 |
+
return input_features, encoder_length, bridge_length
|
190 |
+
|
191 |
+
def forward(
|
192 |
+
self,
|
193 |
+
input_features,
|
194 |
+
output_length,
|
195 |
+
):
|
196 |
+
input_features = input_features.to(self.conv1.weight.dtype)
|
197 |
+
inputs_embeds = nn.functional.gelu(self.conv1(input_features)) # (bs, channels, frames)
|
198 |
+
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) # (bs, channels, frames // 2)
|
199 |
+
inputs_embeds = inputs_embeds.permute(0, 2, 1) # (bs, frams, channels)
|
200 |
+
bsz, tgt_len, _ = inputs_embeds.size()
|
201 |
+
if tgt_len < self.positional_embedding.shape[0]:
|
202 |
+
current_positional_embedding = self.positional_embedding[:tgt_len]
|
203 |
+
else:
|
204 |
+
current_positional_embedding = self.positional_embedding
|
205 |
+
hidden_states = (inputs_embeds.to(torch.float32) + current_positional_embedding).to(inputs_embeds.dtype)
|
206 |
+
|
207 |
+
# packing hidden states
|
208 |
+
attention_mask, unpacking_index = get_sequence_mask(hidden_states, output_length)
|
209 |
+
hidden_states = torch.masked_select(hidden_states, attention_mask).view(torch.sum(output_length),
|
210 |
+
self.config.d_model)
|
211 |
+
|
212 |
+
for idx, encoder_layer in enumerate(self.layers):
|
213 |
+
hidden_states = encoder_layer(hidden_states, output_length)
|
214 |
+
hidden_states = self.layer_norm(hidden_states)
|
215 |
+
# unpacking
|
216 |
+
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, tgt_len, self.config.d_model)
|
217 |
+
hidden_states = torch.where(attention_mask, hidden_states, 0)
|
218 |
+
return hidden_states
|
219 |
+
|
220 |
+
|
221 |
+
class CasualConvTranspose1d(nn.Module): # 反卷积
|
222 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride):
|
223 |
+
super().__init__()
|
224 |
+
self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
|
225 |
+
self.norm = nn.GroupNorm(1, out_channels)
|
226 |
+
self.in_channels = in_channels
|
227 |
+
self.out_channels = out_channels
|
228 |
+
|
229 |
+
def forward(self, hidden_states, input_length, output_dim=None):
|
230 |
+
kernel_size = self.conv.kernel_size[0]
|
231 |
+
stride = self.conv.stride[0]
|
232 |
+
bsz = input_length.shape[0]
|
233 |
+
|
234 |
+
if output_dim is None:
|
235 |
+
output_dim = hidden_states.dim()
|
236 |
+
if hidden_states.dim() <= 2: # unpack sequence to 3d
|
237 |
+
sequence_mask, unpacking_index = get_sequence_mask(hidden_states, input_length)
|
238 |
+
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, torch.max(input_length),
|
239 |
+
self.in_channels)
|
240 |
+
hidden_states = torch.where(sequence_mask, hidden_states, 0) # 3d (bsz, max_input_len, d)
|
241 |
+
|
242 |
+
hidden_states = hidden_states.transpose(2, 1) # (N, L, C) -> (N, C, L)
|
243 |
+
hidden_states = self.conv(hidden_states)
|
244 |
+
hidden_states = self.norm(hidden_states)
|
245 |
+
hidden_states = hidden_states.transpose(2, 1) # (N, C, L) -> (N, L, C)
|
246 |
+
|
247 |
+
casual_padding_right = max(0, kernel_size - stride)
|
248 |
+
hidden_states = hidden_states[:, :hidden_states.shape[1] - casual_padding_right,
|
249 |
+
:]
|
250 |
+
output_length = (input_length - 1) * stride + kernel_size - casual_padding_right
|
251 |
+
sequence_mask, _ = get_sequence_mask(hidden_states, output_length)
|
252 |
+
if output_dim <= 2:
|
253 |
+
hidden_states = torch.masked_select(hidden_states, sequence_mask).view(-1, self.out_channels)
|
254 |
+
else:
|
255 |
+
hidden_states = torch.where(sequence_mask, hidden_states, 0)
|
256 |
+
hidden_states = hidden_states[:, :torch.max(output_length), :] # 截断到最大有效长度
|
257 |
+
return hidden_states, output_length
|
258 |
+
|
259 |
+
|
260 |
+
class MelSpecRefineNet(nn.Module):
|
261 |
+
"""
|
262 |
+
# post net, coarse to refined mel-spectrogram frames
|
263 |
+
# ref1: Autoregressive Speech Synthesis without Vector Quantization
|
264 |
+
# ref2: CosyVoice length_regulator.py
|
265 |
+
# ref3: Neural Speech Synthesis with Transformer Network https://github.com/soobinseo/Transformer-TTS/blob/master/network.py
|
266 |
+
"""
|
267 |
+
|
268 |
+
def __init__(self, encoder_config, vocoder_config):
|
269 |
+
super().__init__()
|
270 |
+
self.encoder_config = encoder_config
|
271 |
+
self.vocoder_config = vocoder_config
|
272 |
+
|
273 |
+
layers = nn.ModuleList([])
|
274 |
+
in_channels = self.vocoder_config.num_mel_bins
|
275 |
+
for i, out_channels in enumerate(self.vocoder_config.channels[:-1]):
|
276 |
+
module = nn.Conv1d(in_channels, out_channels, 5, 1, 2) # cosyvoice kernel=3, stride=1, pad=1
|
277 |
+
in_channels = out_channels
|
278 |
+
norm = nn.GroupNorm(1, out_channels)
|
279 |
+
act = nn.Mish()
|
280 |
+
layers.extend([module, norm, act])
|
281 |
+
layers.append(nn.Conv1d(in_channels, self.vocoder_config.num_mel_bins, 1, 1)) # projector
|
282 |
+
self.layers = nn.Sequential(*layers)
|
283 |
+
|
284 |
+
def compute_output_length(self, input_length):
|
285 |
+
output_length = input_length.to(
|
286 |
+
torch.float32) * self.encoder_config.hop_length / self.encoder_config.sampling_rate
|
287 |
+
output_length = output_length * self.vocoder_config.sampling_rate / self.vocoder_config.hop_length
|
288 |
+
return output_length.to(torch.int64)
|
289 |
+
|
290 |
+
def forward(self, coarse_mel, input_length, output_length=None):
|
291 |
+
bsz, _, d = coarse_mel.shape
|
292 |
+
assert (d == self.vocoder_config.num_mel_bins)
|
293 |
+
if output_length is None or not self.training:
|
294 |
+
output_length = self.compute_output_length(input_length)
|
295 |
+
coarse_mel, default_dtype = coarse_mel[:, :torch.max(input_length), :], coarse_mel.dtype
|
296 |
+
coarse_mel = F.interpolate(coarse_mel.to(torch.float32).transpose(1, 2).contiguous(), size=output_length.max(),
|
297 |
+
mode='nearest').to(default_dtype)
|
298 |
+
refined_mel = self.layers(coarse_mel).transpose(1, 2).contiguous() # (bs, t, d)
|
299 |
+
coarse_mel = coarse_mel.transpose(1, 2) # (bs, max(output_length), d)
|
300 |
+
refined_mel += coarse_mel # residual conntection
|
301 |
+
sequence_mask, _ = get_sequence_mask(refined_mel, output_length)
|
302 |
+
coarse_mel = torch.where(sequence_mask, coarse_mel, 0)
|
303 |
+
refined_mel = torch.where(sequence_mask, refined_mel, 0)
|
304 |
+
return refined_mel, coarse_mel, output_length
|
305 |
+
|
306 |
+
|
307 |
+
@dataclass
|
308 |
+
class OmniAudioDecoderOutput(ModelOutput):
|
309 |
+
refined_mel: Optional[torch.FloatTensor] = None
|
310 |
+
coarse_mel: Optional[torch.FloatTensor] = None
|
311 |
+
mel_length: Optional[torch.Tensor] = None
|
312 |
+
hidden_states_before_dconv2: Optional[torch.FloatTensor] = None
|
313 |
+
output_length_before_dconv2: Optional[torch.Tensor] = None
|
314 |
+
|
315 |
+
|
316 |
+
class OmniAudioDecoder(nn.Module):
|
317 |
+
def __init__(self, config):
|
318 |
+
super().__init__()
|
319 |
+
self.config = config.audio_config
|
320 |
+
self.vocoder_config = config.vocoder_config
|
321 |
+
self.max_source_positions = self.config.max_audio_seconds * self.config.sampling_rate // self.config.hop_length
|
322 |
+
|
323 |
+
self.dconv1 = CasualConvTranspose1d(
|
324 |
+
self.config.d_model,
|
325 |
+
self.config.d_model,
|
326 |
+
self.config.decoder_kernel_size,
|
327 |
+
self.config.avg_pooler,
|
328 |
+
)
|
329 |
+
self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, self.config.d_model))
|
330 |
+
# causal transformer layers
|
331 |
+
self.layers = nn.ModuleList(
|
332 |
+
[OmniWhisperTransformerLayer(
|
333 |
+
ACT2FN[self.config.activation_function],
|
334 |
+
self.config.d_model,
|
335 |
+
self.config.decoder_attention_heads,
|
336 |
+
self.config.decoder_ffn_dim,
|
337 |
+
True # causal
|
338 |
+
) for _ in range(self.config.decoder_layers)
|
339 |
+
])
|
340 |
+
self.layer_norm = nn.LayerNorm(self.config.d_model)
|
341 |
+
self.dconv2 = CasualConvTranspose1d(
|
342 |
+
self.config.d_model,
|
343 |
+
self.vocoder_config.num_mel_bins,
|
344 |
+
self.config.decoder_kernel_size,
|
345 |
+
self.config.decoder_stride_size
|
346 |
+
)
|
347 |
+
self.post_net = MelSpecRefineNet(config.audio_config, config.vocoder_config)
|
348 |
+
self.gradient_checkpointing = True
|
349 |
+
|
350 |
+
@torch.no_grad()
|
351 |
+
def fake_input(self, device):
|
352 |
+
audio_embed = torch.rand([1, 10, self.config.d_model], dtype=torch.float32, device=device)
|
353 |
+
input_length = torch.ones([1], dtype=torch.int32, device=device) * 10
|
354 |
+
mel_labels_length = self.post_net.compute_output_length(input_length)
|
355 |
+
return audio_embed, input_length, None, mel_labels_length
|
356 |
+
|
357 |
+
def forward(self,
|
358 |
+
audio_embed,
|
359 |
+
input_length,
|
360 |
+
mel_labels=None,
|
361 |
+
mel_labels_length=None,
|
362 |
+
fake_input=False,
|
363 |
+
):
|
364 |
+
if fake_input:
|
365 |
+
audio_embed, input_length, mel_labels, mel_labels_length = self.fake_input(self.layer_norm.weight.device)
|
366 |
+
|
367 |
+
assert (audio_embed.shape[-1] == self.config.d_model)
|
368 |
+
audio_embed = audio_embed.to(self.layer_norm.weight) # device and type
|
369 |
+
audio_embed, output_length = self.dconv1(audio_embed, input_length, output_dim=3) # (b, l*2, d_model)
|
370 |
+
_, tgt_len, _ = audio_embed.size()
|
371 |
+
if tgt_len < self.positional_embedding.shape[0]:
|
372 |
+
current_positional_embedding = self.positional_embedding[:tgt_len]
|
373 |
+
else:
|
374 |
+
current_positional_embedding = self.positional_embedding
|
375 |
+
hidden_states = (audio_embed.to(torch.float32) + current_positional_embedding).to(audio_embed.dtype)
|
376 |
+
|
377 |
+
# packing hidden states
|
378 |
+
attention_mask, _ = get_sequence_mask(hidden_states, output_length)
|
379 |
+
hidden_states = torch.masked_select(hidden_states, attention_mask).view(torch.sum(output_length), self.config.d_model)
|
380 |
+
|
381 |
+
for idx, encoder_layer in enumerate(self.layers):
|
382 |
+
hidden_states = encoder_layer(hidden_states, output_length)
|
383 |
+
|
384 |
+
hidden_states = self.layer_norm(hidden_states)
|
385 |
+
hidden_states_before_dconv2 = hidden_states
|
386 |
+
output_length_before_dconv2 = output_length
|
387 |
+
|
388 |
+
coarse_mel, output_length = self.dconv2(hidden_states, output_length, output_dim=3)
|
389 |
+
refined_mel, coarse_mel, mel_labels_length = self.post_net(coarse_mel, output_length, mel_labels_length)
|
390 |
+
|
391 |
+
return OmniAudioDecoderOutput(
|
392 |
+
refined_mel=refined_mel,
|
393 |
+
coarse_mel=coarse_mel,
|
394 |
+
mel_length=mel_labels_length,
|
395 |
+
hidden_states_before_dconv2=hidden_states_before_dconv2,
|
396 |
+
output_length_before_dconv2=output_length_before_dconv2,
|
397 |
+
)
|
398 |
+
|
399 |
+
|
400 |
+
class OmniAudioVQBridgeTokenizer(nn.Module):
|
401 |
+
def __init__(self, config):
|
402 |
+
super().__init__()
|
403 |
+
self.config = config.audio_config
|
404 |
+
self.gradient_checkpointing = False
|
405 |
+
self.intermediate_dim = self.config.d_model * self.config.avg_pooler
|
406 |
+
self.gate_proj = nn.Conv1d(self.config.d_model, self.intermediate_dim, self.config.avg_pooler, self.config.avg_pooler, bias=False)
|
407 |
+
self.up_proj = nn.Conv1d(self.config.d_model, self.intermediate_dim, self.config.avg_pooler, self.config.avg_pooler, bias=False)
|
408 |
+
|
409 |
+
self.down_proj = nn.Linear(self.intermediate_dim, self.intermediate_dim, bias=False)
|
410 |
+
self.act_fn = ACT2FN['silu']
|
411 |
+
self.layer_norm = nn.LayerNorm(self.intermediate_dim)
|
412 |
+
self.proj_decoder = nn.Linear(self.intermediate_dim, self.config.d_model)
|
413 |
+
|
414 |
+
self.vq_list = nn.ModuleList([])
|
415 |
+
for idx, codebook_size in enumerate(self.config.vq_config.codebook_sizes):
|
416 |
+
vq_config = copy.deepcopy(self.config.vq_config)
|
417 |
+
vq_config.dim = self.intermediate_dim
|
418 |
+
vq_config.codebook_size = codebook_size
|
419 |
+
self.vq_list.append(VectorQuantize(vq_config))
|
420 |
+
for vq_layer in self.vq_list:
|
421 |
+
deepspeed.zero.register_external_parameter(self, vq_layer.codebook.embed)
|
422 |
+
|
423 |
+
def rvq_op(self, inputs, output_length):
|
424 |
+
def rvq_layer_op(vq_layer, residual_encoding, output_length):
|
425 |
+
q_v_i, code_ids_i = vq_layer(residual_encoding, output_length)
|
426 |
+
residual_encoding = residual_encoding.float() - q_v_i.float()
|
427 |
+
residual_encoding = residual_encoding.to(inputs.dtype)
|
428 |
+
return residual_encoding, code_ids_i
|
429 |
+
|
430 |
+
cmt_loss, residual_encoding = 0, inputs
|
431 |
+
code_ids_list = []
|
432 |
+
for i, vq_layer in enumerate(self.vq_list):
|
433 |
+
residual_encoding, code_ids_i = rvq_layer_op(vq_layer, residual_encoding, output_length)
|
434 |
+
code_ids_list.append(code_ids_i)
|
435 |
+
return torch.stack(code_ids_list, -1)
|
436 |
+
|
437 |
+
def forward(self, x, output_length):
|
438 |
+
batch_size, _, _ = x.shape
|
439 |
+
output_length = output_length.to(x.device)
|
440 |
+
|
441 |
+
if x.shape[1] % self.config.avg_pooler != 0:
|
442 |
+
x = F.pad(x, (0, 0, 0, self.config.avg_pooler - x.shape[1] % self.config.avg_pooler), "constant", 0)
|
443 |
+
xt = x.permute(0, 2, 1)
|
444 |
+
g = self.gate_proj(xt).permute(0, 2, 1) # (bs, sl//poolersizre+1, d*2)
|
445 |
+
u = self.up_proj(xt).permute(0, 2, 1)
|
446 |
+
x = x.reshape(batch_size, -1, self.intermediate_dim) # (bs, sl//poolersizre+1, d*2)
|
447 |
+
|
448 |
+
c = self.down_proj(self.act_fn(g) * u)
|
449 |
+
res = self.layer_norm(c + x)
|
450 |
+
valid_mask, _ = get_sequence_mask(res, output_length)
|
451 |
+
code_ids = self.rvq_op(res, output_length)
|
452 |
+
code_ids = torch.masked_select(code_ids, valid_mask).reshape(-1, len(self.vq_list)) # (sum(valid_sequence_length), vq_num)
|
453 |
+
return code_ids
|
454 |
+
|
455 |
+
@torch.no_grad()
|
456 |
+
def decode(self, code_ids):
|
457 |
+
vq_num = code_ids.shape[-1]
|
458 |
+
res = sum(self.vq_list[i].get_output_from_indices(code_ids[:, i]).float() for i in range(vq_num-1,-1,-1)).to(self.proj_decoder.weight)
|
459 |
+
decoder_emb = self.proj_decoder(res.to(self.proj_decoder.weight))
|
460 |
+
return decoder_emb
|
461 |
+
|
462 |
+
@torch.no_grad()
|
463 |
+
def recover(self, code_ids):
|
464 |
+
vq_num = code_ids.shape[-1]
|
465 |
+
res = sum(self.vq_list[i].get_output_from_indices(code_ids[:, i]).float() for i in range(vq_num-1,-1,-1)).to(self.proj_decoder.weight)
|
466 |
+
return res
|
467 |
+
|
468 |
+
|
469 |
+
class FlowmatchingPrenet(nn.Module):
|
470 |
+
def __init__(
|
471 |
+
self,
|
472 |
+
input_feat_dim,
|
473 |
+
out_feat_dim,
|
474 |
+
d_model,
|
475 |
+
attention_heads,
|
476 |
+
ffn_dim,
|
477 |
+
nlayers,
|
478 |
+
activation_function,
|
479 |
+
max_source_positions,
|
480 |
+
target_mel_length_scale_ratio,
|
481 |
+
):
|
482 |
+
super().__init__()
|
483 |
+
|
484 |
+
self.d_model = d_model
|
485 |
+
self.target_mel_length_scale_ratio = target_mel_length_scale_ratio
|
486 |
+
self.gradient_checkpointing = False
|
487 |
+
|
488 |
+
self.register_buffer(
|
489 |
+
"positional_embedding", sinusoids(max_source_positions, d_model)
|
490 |
+
)
|
491 |
+
|
492 |
+
self.in_mlp = nn.Sequential(
|
493 |
+
nn.Linear(input_feat_dim, d_model * 4),
|
494 |
+
nn.SiLU(),
|
495 |
+
nn.Linear(d_model * 4, d_model),
|
496 |
+
)
|
497 |
+
|
498 |
+
self.transformer_layers = nn.ModuleList(
|
499 |
+
[
|
500 |
+
OmniWhisperTransformerLayer(
|
501 |
+
act=ACT2FN[activation_function],
|
502 |
+
d_model=d_model,
|
503 |
+
encoder_attention_heads=attention_heads,
|
504 |
+
encoder_ffn_dim=ffn_dim,
|
505 |
+
causal=True, # causal
|
506 |
+
ln_type="RMSNorm",
|
507 |
+
)
|
508 |
+
for _ in range(nlayers)
|
509 |
+
]
|
510 |
+
)
|
511 |
+
|
512 |
+
self.final_norm = RMSNorm(self.d_model)
|
513 |
+
self.out_proj = nn.Linear(d_model, out_feat_dim, bias=False)
|
514 |
+
|
515 |
+
def compute_output_length(self, input_length):
|
516 |
+
output_length = input_length.float() * self.target_mel_length_scale_ratio
|
517 |
+
return output_length.to(torch.int64)
|
518 |
+
|
519 |
+
def forward(self, input_feat, input_length, output_length=None):
|
520 |
+
"""
|
521 |
+
Args:
|
522 |
+
input_feat: [B, T, input_feat_dim]
|
523 |
+
input_length: [B]
|
524 |
+
output_length: [B]
|
525 |
+
|
526 |
+
"""
|
527 |
+
if output_length is None or not self.training:
|
528 |
+
output_length = self.compute_output_length(input_length)
|
529 |
+
|
530 |
+
input_feat = input_feat[:, : input_length.max(), :] # [B, T, D]
|
531 |
+
orig_dtype = input_feat.dtype
|
532 |
+
|
533 |
+
input_feat = F.interpolate(
|
534 |
+
input=input_feat.to(torch.float32).transpose(1, 2).contiguous(),
|
535 |
+
size=output_length.max(),
|
536 |
+
mode="nearest",
|
537 |
+
).to(orig_dtype)
|
538 |
+
input_feat = input_feat.transpose(1, 2).contiguous() # [B, T, D]
|
539 |
+
hidden_states = self.in_mlp(input_feat)
|
540 |
+
|
541 |
+
# packing hidden states
|
542 |
+
bsz, tgt_len, d_model = hidden_states.shape
|
543 |
+
attention_mask, unpacking_index = get_sequence_mask(
|
544 |
+
hidden_states, output_length
|
545 |
+
)
|
546 |
+
hidden_states = torch.masked_select(hidden_states, attention_mask).view(
|
547 |
+
torch.sum(output_length), self.d_model
|
548 |
+
)
|
549 |
+
|
550 |
+
for idx, encoder_layer in enumerate(self.transformer_layers):
|
551 |
+
hidden_states = encoder_layer(hidden_states, output_length)
|
552 |
+
|
553 |
+
# unpacking
|
554 |
+
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
|
555 |
+
bsz, tgt_len, d_model
|
556 |
+
)
|
557 |
+
hidden_states = torch.where(attention_mask, hidden_states, 0)
|
558 |
+
|
559 |
+
hidden_states = self.final_norm(hidden_states)
|
560 |
+
output = self.out_proj(hidden_states)
|
561 |
+
return output, output_length
|
562 |
+
|
563 |
+
|
564 |
+
@dataclass
|
565 |
+
class OmniAudioFlowMatchingDecoderOutput(ModelOutput):
|
566 |
+
flow_matching_mel: Optional[torch.FloatTensor] = None
|
567 |
+
flow_matching_mel_lengths: Optional[torch.FloatTensor] = None
|
568 |
+
|
569 |
+
|
570 |
+
class OmniAudioFlowMatchingDecoder(nn.Module):
|
571 |
+
def __init__(self, config):
|
572 |
+
super().__init__()
|
573 |
+
self.config = config.flow_matching_config
|
574 |
+
self.in_channels = self.config.in_channels
|
575 |
+
self.spk_emb_dim = self.config.spk_emb_dim
|
576 |
+
self.diffusion_steps = self.config.diffusion_steps
|
577 |
+
self.cal_mel_mae = self.config.cal_mel_mae
|
578 |
+
self.forward_step = -1
|
579 |
+
|
580 |
+
self.prenet = FlowmatchingPrenet(
|
581 |
+
input_feat_dim=self.config.prenet_in_dim,
|
582 |
+
out_feat_dim=self.config.prenet_out_dim,
|
583 |
+
d_model=self.config.prenet_d_model,
|
584 |
+
attention_heads=self.config.prenet_attention_heads,
|
585 |
+
ffn_dim=self.config.prenet_ffn_dim,
|
586 |
+
nlayers=self.config.prenet_nlayers,
|
587 |
+
activation_function=self.config.prenet_activation_function,
|
588 |
+
max_source_positions=self.config.prenet_max_source_positions,
|
589 |
+
target_mel_length_scale_ratio=self.config.prenet_target_mel_length_scale_ratio,
|
590 |
+
)
|
591 |
+
|
592 |
+
self.conditional_decoder = ConditionalDecoder(
|
593 |
+
in_channels=self.in_channels * 2 + self.spk_emb_dim,
|
594 |
+
out_channels=self.in_channels,
|
595 |
+
causal=True,
|
596 |
+
channels=self.config.channels,
|
597 |
+
dropout=self.config.dropout,
|
598 |
+
attention_head_dim=self.config.attention_head_dim,
|
599 |
+
n_blocks=self.config.n_blocks,
|
600 |
+
num_mid_blocks=self.config.num_mid_blocks,
|
601 |
+
num_heads=self.config.num_heads,
|
602 |
+
act_fn=self.config.act_fn,
|
603 |
+
)
|
604 |
+
|
605 |
+
self.cfm = ConditionalCFM(
|
606 |
+
in_channels=self.in_channels,
|
607 |
+
cfm_params=self.config.cfm_params,
|
608 |
+
n_spks=0,
|
609 |
+
spk_emb_dim=self.spk_emb_dim,
|
610 |
+
)
|
611 |
+
|
612 |
+
|
613 |
+
def unpack_hidden_states(self, hidden_states, output_length):
|
614 |
+
unpacked = unpack_hidden_states(hidden_states, output_length)
|
615 |
+
return unpacked, output_length
|
616 |
+
|
617 |
+
def forward(
|
618 |
+
self, refined_mel, input_length, mel_labels=None, mel_labels_length=None
|
619 |
+
):
|
620 |
+
"""
|
621 |
+
:param refined_mel: [bs, max_input_len, mel_bin]
|
622 |
+
:param input_length: [batch_size]
|
623 |
+
:param refined_mel: [bs, mel_bin, max_input_len]
|
624 |
+
:return:
|
625 |
+
"""
|
626 |
+
self.forward_step += 1
|
627 |
+
|
628 |
+
orig_dtype = refined_mel.dtype
|
629 |
+
prenet_mae_metric = torch.tensor(0.0).to(refined_mel.device)
|
630 |
+
prenet_regression_loss = torch.tensor(0.0).to(refined_mel.device)
|
631 |
+
|
632 |
+
if self.prenet is not None:
|
633 |
+
refined_mel = refined_mel[:, : torch.max(input_length), :]
|
634 |
+
if mel_labels_length is None:
|
635 |
+
mel_labels_length = self.prenet.compute_output_length(input_length)
|
636 |
+
refined_mel, input_length = self.prenet(
|
637 |
+
refined_mel, input_length, mel_labels_length
|
638 |
+
)
|
639 |
+
|
640 |
+
float_dtype = refined_mel.dtype
|
641 |
+
refined_mel = refined_mel.float()
|
642 |
+
input_length = input_length.long()
|
643 |
+
|
644 |
+
refined_mel = refined_mel[:, : torch.max(input_length), :]
|
645 |
+
sequence_mask, unpacking_index = get_sequence_mask(refined_mel, input_length)
|
646 |
+
refined_mel = refined_mel.transpose(1, 2) # (bs, mel_bin, max_input_len)
|
647 |
+
sequence_mask = sequence_mask.transpose(2, 1) # (bs, 1, sl)
|
648 |
+
|
649 |
+
fm_mel = self.cfm.forward(
|
650 |
+
estimator=self.conditional_decoder,
|
651 |
+
mu=refined_mel.to(float_dtype),
|
652 |
+
mask=sequence_mask.float(),
|
653 |
+
n_timesteps=self.diffusion_steps,
|
654 |
+
)
|
655 |
+
return OmniAudioFlowMatchingDecoderOutput(
|
656 |
+
flow_matching_mel=fm_mel.transpose(1, 2),
|
657 |
+
flow_matching_mel_lengths=mel_labels_length,
|
658 |
+
)
|
config.json
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "_",
|
3 |
+
"architectures": [
|
4 |
+
"OmniForCausalLM"
|
5 |
+
],
|
6 |
+
"attention_qkv_bias": true,
|
7 |
+
"attention_qkv_pack": true,
|
8 |
+
"audio_config": {
|
9 |
+
"audio_head_transformer_layers": 3,
|
10 |
+
"audio_delim_token_id": 151674,
|
11 |
+
"audio_end_token_id": 151658,
|
12 |
+
"audio_pad_token_id": 151659,
|
13 |
+
"audio_start_token_id": 151657,
|
14 |
+
"audiogen_end_token_id": 151679,
|
15 |
+
"audiogen_start_token_id": 151678,
|
16 |
+
"audiotext_end_token_id": 151676,
|
17 |
+
"audiotext_pad_token_id": 151677,
|
18 |
+
"audiotext_start_token_id": 151675,
|
19 |
+
"avg_pooler": 4,
|
20 |
+
"d_model": 1280,
|
21 |
+
"decoder_attention_heads": 20,
|
22 |
+
"decoder_ffn_dim": 5120,
|
23 |
+
"decoder_kernel_size": 3,
|
24 |
+
"decoder_layers": 8,
|
25 |
+
"decoder_stride_size": 2,
|
26 |
+
"enable": true,
|
27 |
+
"encoder_attention_heads": 20,
|
28 |
+
"encoder_ffn_dim": 5120,
|
29 |
+
"encoder_layers": 32,
|
30 |
+
"hop_length": 160,
|
31 |
+
"kernel_size": 3,
|
32 |
+
"max_audio_seconds": 30,
|
33 |
+
"n_fft": 400,
|
34 |
+
"num_mel_bins": 128,
|
35 |
+
"sampling_rate": 16000,
|
36 |
+
"stride_size": 2,
|
37 |
+
"split_overlap": 0.0,
|
38 |
+
"vq_config":{
|
39 |
+
"enable": true,
|
40 |
+
"codebook_sizes": [8192, 4096, 2048, 1024, 1024, 1024, 1024, 1024]
|
41 |
+
}
|
42 |
+
},
|
43 |
+
"auto_map": {
|
44 |
+
"AutoConfig": "configuration_omni.OmniConfig",
|
45 |
+
"AutoModelForCausalLM": "modeling_omni.OmniForCausalLM"
|
46 |
+
},
|
47 |
+
"omni_tokenizer_type": "auto",
|
48 |
+
"bos_token_id": 1,
|
49 |
+
"eos_token_id": 2,
|
50 |
+
"flow_matching_config": {
|
51 |
+
"enable": true,
|
52 |
+
"use_hires_mel": true,
|
53 |
+
"sampling_rate": 24000,
|
54 |
+
"hop_length": 480,
|
55 |
+
"max_audio_seconds": 30,
|
56 |
+
"split_overlap": 0.1,
|
57 |
+
"use_hidden_states_before_dconv2": true,
|
58 |
+
"prenet_in_dim": 1280,
|
59 |
+
"prenet_out_dim": 80,
|
60 |
+
"prenet_d_model": 512,
|
61 |
+
"prenet_attention_heads": 8,
|
62 |
+
"prenet_ffn_dim": 2048,
|
63 |
+
"prenet_nlayers": 12,
|
64 |
+
"prenet_activation_function": "gelu",
|
65 |
+
"prenet_max_source_positions": 5000,
|
66 |
+
"prenet_target_mel_length_scale_ratio": 1.0,
|
67 |
+
"prenet_loss_weight": 1.0,
|
68 |
+
"unet_use_omni_attn": false,
|
69 |
+
"loss_weight": 1.0,
|
70 |
+
"in_channels": 80,
|
71 |
+
"spk_emb_dim": 0,
|
72 |
+
"diffusion_steps": 10,
|
73 |
+
"channels": [256],
|
74 |
+
"dropout": 0.0,
|
75 |
+
"attention_head_dim": 64,
|
76 |
+
"n_blocks": 4,
|
77 |
+
"num_mid_blocks": 12,
|
78 |
+
"num_heads": 8,
|
79 |
+
"act_fn": "gelu",
|
80 |
+
"cal_mel_mae": true,
|
81 |
+
"cfm_params": {
|
82 |
+
"sigma_min": 1e-6,
|
83 |
+
"solver": "euler",
|
84 |
+
"t_scheduler": "cosine",
|
85 |
+
"training_cfg_rate": 0.2,
|
86 |
+
"inference_cfg_rate": 0.7,
|
87 |
+
"reg_loss_type": "l1"
|
88 |
+
}
|
89 |
+
},
|
90 |
+
"head_dim": 128,
|
91 |
+
"hidden_act": "silu",
|
92 |
+
"hidden_size": 3584,
|
93 |
+
"initializer_range": 0.02,
|
94 |
+
"intermediate_size": 18944,
|
95 |
+
"max_position_embeddings": 65536,
|
96 |
+
"max_window_layers": 28,
|
97 |
+
"model_type": "omni",
|
98 |
+
"multimodal": [
|
99 |
+
"audio",
|
100 |
+
"audiogen"
|
101 |
+
],
|
102 |
+
"multimodal_special_token_list": [
|
103 |
+
151657,
|
104 |
+
151658,
|
105 |
+
151659,
|
106 |
+
151674,
|
107 |
+
151675,
|
108 |
+
151676,
|
109 |
+
151677,
|
110 |
+
151678,
|
111 |
+
151679
|
112 |
+
],
|
113 |
+
"num_attention_heads": 28,
|
114 |
+
"num_hidden_layers": 28,
|
115 |
+
"num_key_value_heads": 4,
|
116 |
+
"pad_token_id": 0,
|
117 |
+
"position_embedding_type": "rope",
|
118 |
+
"rms_norm_eps": 1e-06,
|
119 |
+
"rope_theta": 1000000.0,
|
120 |
+
"sliding_window": 131072,
|
121 |
+
"sparse_attention_heads": null,
|
122 |
+
"sparse_attention_layers": [],
|
123 |
+
"tie_word_embeddings": false,
|
124 |
+
"torch_dtype": "bfloat16",
|
125 |
+
"train_multimodal_special_tokens_only": false,
|
126 |
+
"transformers_version": "4.45.0.dev0",
|
127 |
+
"use_cache": false,
|
128 |
+
"use_norm_head": false,
|
129 |
+
"use_sliding_window": false,
|
130 |
+
"video_config": {
|
131 |
+
"_name_or_path": "",
|
132 |
+
"_attn_implementation": "flash_attention_2",
|
133 |
+
"decode_way": "1fps",
|
134 |
+
"depth": 32,
|
135 |
+
"embed_dim": 1280,
|
136 |
+
"enable": false,
|
137 |
+
"hidden_act": "quick_gelu",
|
138 |
+
"hidden_size": 3584,
|
139 |
+
"image_delimiter_token_id": 151688,
|
140 |
+
"image_end_token_id": 151680,
|
141 |
+
"image_line_token_id": 151682,
|
142 |
+
"image_mean": [
|
143 |
+
0.48145466,
|
144 |
+
0.4578275,
|
145 |
+
0.40821073
|
146 |
+
],
|
147 |
+
"image_pad_token_id": 151681,
|
148 |
+
"image_size": 224,
|
149 |
+
"image_start_token_id": 151679,
|
150 |
+
"image_std": [
|
151 |
+
0.26862954,
|
152 |
+
0.26130258,
|
153 |
+
0.27577711
|
154 |
+
],
|
155 |
+
"in_channels": 3,
|
156 |
+
"in_chans": 3,
|
157 |
+
"intermediate_size": 3072,
|
158 |
+
"layer_norm_eps": 1e-05,
|
159 |
+
"max_frame_num": 32,
|
160 |
+
"max_length": 20,
|
161 |
+
"max_pixels": 602112,
|
162 |
+
"merge_size": 2,
|
163 |
+
"min_length": 0,
|
164 |
+
"min_pixels": 3136,
|
165 |
+
"mlp_ratio": 4,
|
166 |
+
"model_type": "clip_vision_model",
|
167 |
+
"num_attention_heads": 12,
|
168 |
+
"num_channels": 3,
|
169 |
+
"num_heads": 16,
|
170 |
+
"num_hidden_layers": 12,
|
171 |
+
"patch_size": 14,
|
172 |
+
"spatial_merge_size": 2,
|
173 |
+
"spatial_patch_size": 14,
|
174 |
+
"temporal_patch_size": 2,
|
175 |
+
"video_end_token_id": 151696,
|
176 |
+
"video_place_token_id": 151694,
|
177 |
+
"video_start_token_id": 151695
|
178 |
+
},
|
179 |
+
"visual_config": {
|
180 |
+
"_name_or_path": "",
|
181 |
+
"_attn_implementation": "flash_attention_2",
|
182 |
+
"depth": 32,
|
183 |
+
"diversity_penalty": 0.0,
|
184 |
+
"do_sample": false,
|
185 |
+
"early_stopping": false,
|
186 |
+
"embed_dim": 1280,
|
187 |
+
"enable": false,
|
188 |
+
"hidden_act": "quick_gelu",
|
189 |
+
"hidden_size": 3584,
|
190 |
+
"image_delimiter_token_id": 151688,
|
191 |
+
"image_end_token_id": 151680,
|
192 |
+
"image_line_token_id": 151682,
|
193 |
+
"image_mean": [
|
194 |
+
0.48145466,
|
195 |
+
0.4578275,
|
196 |
+
0.40821073
|
197 |
+
],
|
198 |
+
"image_pad_token_id": 151681,
|
199 |
+
"image_size": 224,
|
200 |
+
"image_start_token_id": 151679,
|
201 |
+
"image_std": [
|
202 |
+
0.26862954,
|
203 |
+
0.26130258,
|
204 |
+
0.27577711
|
205 |
+
],
|
206 |
+
"in_channels": 3,
|
207 |
+
"in_chans": 3,
|
208 |
+
"intermediate_size": 3072,
|
209 |
+
"layer_norm_eps": 1e-05,
|
210 |
+
"length_penalty": 1.0,
|
211 |
+
"max_length": 20,
|
212 |
+
"max_pixels": 3211264,
|
213 |
+
"merge_size": 2,
|
214 |
+
"min_length": 0,
|
215 |
+
"min_pixels": 3136,
|
216 |
+
"mlp_ratio": 4,
|
217 |
+
"model_type": "clip_vision_model",
|
218 |
+
"num_attention_heads": 12,
|
219 |
+
"num_channels": 3,
|
220 |
+
"num_heads": 16,
|
221 |
+
"num_hidden_layers": 12,
|
222 |
+
"patch_size": 14,
|
223 |
+
"projection_dim": 512,
|
224 |
+
"spatial_merge_size": 2,
|
225 |
+
"spatial_patch_size": 14,
|
226 |
+
"temporal_patch_size": 2
|
227 |
+
},
|
228 |
+
"vocab_size": 152064,
|
229 |
+
"vocoder_config":{
|
230 |
+
"enable": true,
|
231 |
+
"enable_multi_scale": true,
|
232 |
+
"max_audio_seconds": 30,
|
233 |
+
"sampling_rate": 16000,
|
234 |
+
"hop_length": 256,
|
235 |
+
"split_overlap": 0.0,
|
236 |
+
"n_fft": 1024,
|
237 |
+
"num_mel_bins": 80,
|
238 |
+
"channels": [256, 256, 256, 256, 256]
|
239 |
+
}
|
240 |
+
}
|
configuration_omni.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Baichuan Inc. All Rights Reserved.
|
2 |
+
|
3 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
6 |
+
# and OPT implementations in this library. It has been modified from its
|
7 |
+
# original forms to accommodate minor architectural differences compared
|
8 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
9 |
+
#
|
10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
11 |
+
# you may not use this file except in compliance with the License.
|
12 |
+
# You may obtain a copy of the License at
|
13 |
+
#
|
14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
15 |
+
#
|
16 |
+
# Unless required by applicable law or agreed to in writing, software
|
17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
19 |
+
# See the License for the specific language governing permissions and
|
20 |
+
# limitations under the License.
|
21 |
+
|
22 |
+
from transformers.configuration_utils import PretrainedConfig
|
23 |
+
from transformers.utils import logging
|
24 |
+
from transformers import WhisperConfig
|
25 |
+
from transformers import CLIPVisionConfig
|
26 |
+
|
27 |
+
logger = logging.get_logger(__name__)
|
28 |
+
|
29 |
+
|
30 |
+
class OmniConfig(PretrainedConfig):
|
31 |
+
model_type = "omni"
|
32 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
33 |
+
|
34 |
+
def __init__(
|
35 |
+
self,
|
36 |
+
vocab_size=125696,
|
37 |
+
hidden_size=4096,
|
38 |
+
intermediate_size=11008,
|
39 |
+
num_hidden_layers=32,
|
40 |
+
num_attention_heads=32,
|
41 |
+
num_key_value_heads=None,
|
42 |
+
sparse_attention_heads=None,
|
43 |
+
sparse_attention_layers=[],
|
44 |
+
head_dim=None,
|
45 |
+
attention_qkv_pack=True,
|
46 |
+
attention_qkv_bias=False,
|
47 |
+
use_norm_head=True,
|
48 |
+
hidden_act="silu",
|
49 |
+
max_position_embeddings=4096,
|
50 |
+
position_embedding_type="rope",
|
51 |
+
initializer_range=0.02,
|
52 |
+
rms_norm_eps=1e-6,
|
53 |
+
use_cache=True,
|
54 |
+
pad_token_id=0,
|
55 |
+
bos_token_id=1,
|
56 |
+
eos_token_id=2,
|
57 |
+
tie_word_embeddings=False,
|
58 |
+
audio_config=None,
|
59 |
+
visual_config=None,
|
60 |
+
video_config=None,
|
61 |
+
vocoder_config=None,
|
62 |
+
flow_matching_config=None,
|
63 |
+
**kwargs,
|
64 |
+
):
|
65 |
+
self.vocab_size = vocab_size
|
66 |
+
self.max_position_embeddings = max_position_embeddings
|
67 |
+
self.hidden_size = hidden_size
|
68 |
+
self.intermediate_size = intermediate_size
|
69 |
+
self.num_hidden_layers = num_hidden_layers
|
70 |
+
self.num_attention_heads = num_attention_heads
|
71 |
+
self.num_key_value_heads = num_key_value_heads or self.num_attention_heads
|
72 |
+
self.sparse_attention_heads = sparse_attention_heads
|
73 |
+
self.sparse_attention_layers = sparse_attention_layers
|
74 |
+
self.head_dim = head_dim or self.hidden_size // self.num_attention_heads
|
75 |
+
self.attention_qkv_pack = attention_qkv_pack
|
76 |
+
self.attention_qkv_bias = attention_qkv_bias
|
77 |
+
self.use_norm_head = use_norm_head
|
78 |
+
self.hidden_act = hidden_act
|
79 |
+
self.position_embedding_type = position_embedding_type
|
80 |
+
self.initializer_range = initializer_range
|
81 |
+
self.rms_norm_eps = rms_norm_eps
|
82 |
+
self.use_cache = use_cache
|
83 |
+
assert self.position_embedding_type.lower() in ("rope", "alibi")
|
84 |
+
super().__init__(
|
85 |
+
pad_token_id=pad_token_id,
|
86 |
+
bos_token_id=bos_token_id,
|
87 |
+
eos_token_id=eos_token_id,
|
88 |
+
tie_word_embeddings=tie_word_embeddings,
|
89 |
+
**kwargs,
|
90 |
+
)
|
91 |
+
if audio_config is not None:
|
92 |
+
self.audio_config = WhisperConfig(**audio_config)
|
93 |
+
if self.audio_config.vq_config is not None:
|
94 |
+
self.audio_config.vq_config = PretrainedConfig(**self.audio_config.vq_config)
|
95 |
+
if vocoder_config is not None:
|
96 |
+
self.vocoder_config = WhisperConfig(**vocoder_config)
|
97 |
+
if flow_matching_config is not None:
|
98 |
+
self.flow_matching_config = PretrainedConfig(**flow_matching_config)
|
99 |
+
self.flow_matching_config.cfm_params = PretrainedConfig(**self.flow_matching_config.cfm_params)
|
100 |
+
if visual_config is not None:
|
101 |
+
self.visual_config = CLIPVisionConfig(**visual_config)
|
102 |
+
if video_config is not None:
|
103 |
+
self.video_config = CLIPVisionConfig(**video_config)
|
104 |
+
|
105 |
+
|
106 |
+
def to_diff_dict(self):
|
107 |
+
data = super().to_diff_dict()
|
108 |
+
data["model_type"] = self.model_type
|
109 |
+
return data
|
110 |
+
|
111 |
+
def get_rotary_base(self):
|
112 |
+
if hasattr(self, "rotary_emb_base"):
|
113 |
+
return self.rotary_emb_base
|
114 |
+
else:
|
115 |
+
return self.rope_theta
|
116 |
+
|
117 |
+
if __name__ == '__main__':
|
118 |
+
from transformers import AutoConfig
|
119 |
+
config = AutoConfig.from_pretrained("./", trust_remote_code=True)
|
120 |
+
print(config)
|
flow_matching.py
ADDED
@@ -0,0 +1,791 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from CosyVoice https://github.com/FunAudioLLM/CosyVoice/tree/main
|
2 |
+
"""
|
3 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
|
18 |
+
from abc import ABC
|
19 |
+
import torch
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from typing import Dict, Optional
|
22 |
+
|
23 |
+
import torch.nn as nn
|
24 |
+
from einops import pack, rearrange, repeat
|
25 |
+
from .matcha_components import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
|
26 |
+
from .matcha_transformer import BasicTransformerBlock
|
27 |
+
from omegaconf import DictConfig
|
28 |
+
|
29 |
+
|
30 |
+
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
31 |
+
assert mask.dtype == torch.bool
|
32 |
+
assert dtype in [torch.float32, torch.bfloat16, torch.float16]
|
33 |
+
mask = mask.to(dtype)
|
34 |
+
# attention mask bias
|
35 |
+
# NOTE(Mddct): torch.finfo jit issues
|
36 |
+
# chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
|
37 |
+
mask = (1.0 - mask) * torch.finfo(dtype).min
|
38 |
+
return mask
|
39 |
+
|
40 |
+
|
41 |
+
def subsequent_chunk_mask(
|
42 |
+
size: int,
|
43 |
+
chunk_size: int,
|
44 |
+
num_left_chunks: int = -1,
|
45 |
+
device: torch.device = torch.device("cpu"),
|
46 |
+
) -> torch.Tensor:
|
47 |
+
"""Create mask for subsequent steps (size, size) with chunk size,
|
48 |
+
this is for streaming encoder
|
49 |
+
|
50 |
+
Args:
|
51 |
+
size (int): size of mask
|
52 |
+
chunk_size (int): size of chunk
|
53 |
+
num_left_chunks (int): number of left chunks
|
54 |
+
<0: use full chunk
|
55 |
+
>=0: use num_left_chunks
|
56 |
+
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
torch.Tensor: mask
|
60 |
+
|
61 |
+
Examples:
|
62 |
+
>>> subsequent_chunk_mask(4, 2)
|
63 |
+
[[1, 1, 0, 0],
|
64 |
+
[1, 1, 0, 0],
|
65 |
+
[1, 1, 1, 1],
|
66 |
+
[1, 1, 1, 1]]
|
67 |
+
"""
|
68 |
+
# NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
|
69 |
+
# actually this is not needed after we have inference cache implemented, will remove it later
|
70 |
+
pos_idx = torch.arange(size, device=device)
|
71 |
+
block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
|
72 |
+
ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
|
73 |
+
return ret
|
74 |
+
|
75 |
+
def subsequent_mask(
|
76 |
+
size: int,
|
77 |
+
device: torch.device = torch.device("cpu"),
|
78 |
+
) -> torch.Tensor:
|
79 |
+
"""Create mask for subsequent steps (size, size).
|
80 |
+
|
81 |
+
This mask is used only in decoder which works in an auto-regressive mode.
|
82 |
+
This means the current step could only do attention with its left steps.
|
83 |
+
|
84 |
+
In encoder, fully attention is used when streaming is not necessary and
|
85 |
+
the sequence is not long. In this case, no attention mask is needed.
|
86 |
+
|
87 |
+
When streaming is need, chunk-based attention is used in encoder. See
|
88 |
+
subsequent_chunk_mask for the chunk-based attention mask.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
size (int): size of mask
|
92 |
+
str device (str): "cpu" or "cuda" or torch.Tensor.device
|
93 |
+
dtype (torch.device): result dtype
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
torch.Tensor: mask
|
97 |
+
|
98 |
+
Examples:
|
99 |
+
>>> subsequent_mask(3)
|
100 |
+
[[1, 0, 0],
|
101 |
+
[1, 1, 0],
|
102 |
+
[1, 1, 1]]
|
103 |
+
"""
|
104 |
+
arange = torch.arange(size, device=device)
|
105 |
+
mask = arange.expand(size, size)
|
106 |
+
arange = arange.unsqueeze(-1)
|
107 |
+
mask = mask <= arange
|
108 |
+
return mask
|
109 |
+
|
110 |
+
|
111 |
+
def add_optional_chunk_mask(xs: torch.Tensor,
|
112 |
+
masks: torch.Tensor,
|
113 |
+
use_dynamic_chunk: bool,
|
114 |
+
use_dynamic_left_chunk: bool,
|
115 |
+
decoding_chunk_size: int,
|
116 |
+
static_chunk_size: int,
|
117 |
+
num_decoding_left_chunks: int,
|
118 |
+
enable_full_context: bool = True):
|
119 |
+
""" Apply optional mask for encoder.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
xs (torch.Tensor): padded input, (B, L, D), L for max length
|
123 |
+
mask (torch.Tensor): mask for xs, (B, 1, L)
|
124 |
+
use_dynamic_chunk (bool): whether to use dynamic chunk or not
|
125 |
+
use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
|
126 |
+
training.
|
127 |
+
decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
|
128 |
+
0: default for training, use random dynamic chunk.
|
129 |
+
<0: for decoding, use full chunk.
|
130 |
+
>0: for decoding, use fixed chunk size as set.
|
131 |
+
static_chunk_size (int): chunk size for static chunk training/decoding
|
132 |
+
if it's greater than 0, if use_dynamic_chunk is true,
|
133 |
+
this parameter will be ignored
|
134 |
+
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
135 |
+
the chunk size is decoding_chunk_size.
|
136 |
+
>=0: use num_decoding_left_chunks
|
137 |
+
<0: use all left chunks
|
138 |
+
enable_full_context (bool):
|
139 |
+
True: chunk size is either [1, 25] or full context(max_len)
|
140 |
+
False: chunk size ~ U[1, 25]
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
torch.Tensor: chunk mask of the input xs.
|
144 |
+
"""
|
145 |
+
# Whether to use chunk mask or not
|
146 |
+
if use_dynamic_chunk:
|
147 |
+
max_len = xs.size(1)
|
148 |
+
if decoding_chunk_size < 0:
|
149 |
+
chunk_size = max_len
|
150 |
+
num_left_chunks = -1
|
151 |
+
elif decoding_chunk_size > 0:
|
152 |
+
chunk_size = decoding_chunk_size
|
153 |
+
num_left_chunks = num_decoding_left_chunks
|
154 |
+
else:
|
155 |
+
# chunk size is either [1, 25] or full context(max_len).
|
156 |
+
# Since we use 4 times subsampling and allow up to 1s(100 frames)
|
157 |
+
# delay, the maximum frame is 100 / 4 = 25.
|
158 |
+
chunk_size = torch.randint(1, max_len, (1, )).item()
|
159 |
+
num_left_chunks = -1
|
160 |
+
if chunk_size > max_len // 2 and enable_full_context:
|
161 |
+
chunk_size = max_len
|
162 |
+
else:
|
163 |
+
chunk_size = chunk_size % 25 + 1
|
164 |
+
if use_dynamic_left_chunk:
|
165 |
+
max_left_chunks = (max_len - 1) // chunk_size
|
166 |
+
num_left_chunks = torch.randint(0, max_left_chunks,
|
167 |
+
(1, )).item()
|
168 |
+
chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
|
169 |
+
num_left_chunks,
|
170 |
+
xs.device) # (L, L)
|
171 |
+
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
172 |
+
chunk_masks = masks & chunk_masks # (B, L, L)
|
173 |
+
elif static_chunk_size > 0:
|
174 |
+
num_left_chunks = num_decoding_left_chunks
|
175 |
+
chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
|
176 |
+
num_left_chunks,
|
177 |
+
xs.device) # (L, L)
|
178 |
+
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
179 |
+
chunk_masks = masks & chunk_masks # (B, L, L)
|
180 |
+
else:
|
181 |
+
chunk_masks = masks
|
182 |
+
return chunk_masks
|
183 |
+
|
184 |
+
|
185 |
+
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
186 |
+
"""Make mask tensor containing indices of padded part.
|
187 |
+
|
188 |
+
See description of make_non_pad_mask.
|
189 |
+
|
190 |
+
Args:
|
191 |
+
lengths (torch.Tensor): Batch of lengths (B,).
|
192 |
+
Returns:
|
193 |
+
torch.Tensor: Mask tensor containing indices of padded part.
|
194 |
+
|
195 |
+
Examples:
|
196 |
+
>>> lengths = [5, 3, 2]
|
197 |
+
>>> make_pad_mask(lengths)
|
198 |
+
masks = [[0, 0, 0, 0 ,0],
|
199 |
+
[0, 0, 0, 1, 1],
|
200 |
+
[0, 0, 1, 1, 1]]
|
201 |
+
"""
|
202 |
+
batch_size = lengths.size(0)
|
203 |
+
max_len = max_len if max_len > 0 else lengths.max().item()
|
204 |
+
seq_range = torch.arange(0,
|
205 |
+
max_len,
|
206 |
+
dtype=torch.int64,
|
207 |
+
device=lengths.device)
|
208 |
+
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
209 |
+
seq_length_expand = lengths.unsqueeze(-1)
|
210 |
+
mask = seq_range_expand >= seq_length_expand
|
211 |
+
return mask
|
212 |
+
|
213 |
+
# Causal
|
214 |
+
class Transpose(torch.nn.Module):
|
215 |
+
def __init__(self, dim0: int, dim1: int):
|
216 |
+
super().__init__()
|
217 |
+
self.dim0 = dim0
|
218 |
+
self.dim1 = dim1
|
219 |
+
|
220 |
+
def forward(self, x: torch.Tensor):
|
221 |
+
x = torch.transpose(x, self.dim0, self.dim1)
|
222 |
+
return x
|
223 |
+
|
224 |
+
class CausalBlock1D(Block1D):
|
225 |
+
def __init__(self, dim: int, dim_out: int):
|
226 |
+
super(CausalBlock1D, self).__init__(dim, dim_out)
|
227 |
+
self.block = torch.nn.Sequential(
|
228 |
+
CausalConv1d(dim, dim_out, 3),
|
229 |
+
Transpose(1, 2),
|
230 |
+
nn.LayerNorm(dim_out),
|
231 |
+
Transpose(1, 2),
|
232 |
+
nn.Mish(),
|
233 |
+
)
|
234 |
+
|
235 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor):
|
236 |
+
output = self.block(x * mask)
|
237 |
+
return output * mask
|
238 |
+
|
239 |
+
class CausalResnetBlock1D(ResnetBlock1D):
|
240 |
+
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
|
241 |
+
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
|
242 |
+
self.block1 = CausalBlock1D(dim, dim_out)
|
243 |
+
self.block2 = CausalBlock1D(dim_out, dim_out)
|
244 |
+
|
245 |
+
class CausalConv1d(torch.nn.Conv1d):
|
246 |
+
def __init__(
|
247 |
+
self,
|
248 |
+
in_channels: int,
|
249 |
+
out_channels: int,
|
250 |
+
kernel_size: int,
|
251 |
+
stride: int = 1,
|
252 |
+
dilation: int = 1,
|
253 |
+
groups: int = 1,
|
254 |
+
bias: bool = True,
|
255 |
+
padding_mode: str = 'zeros',
|
256 |
+
device=None,
|
257 |
+
dtype=None
|
258 |
+
) -> None:
|
259 |
+
super(CausalConv1d, self).__init__(in_channels, out_channels,
|
260 |
+
kernel_size, stride,
|
261 |
+
padding=0, dilation=dilation,
|
262 |
+
groups=groups, bias=bias,
|
263 |
+
padding_mode=padding_mode,
|
264 |
+
device=device, dtype=dtype)
|
265 |
+
assert stride == 1
|
266 |
+
self.causal_padding = (kernel_size - 1, 0)
|
267 |
+
|
268 |
+
def forward(self, x: torch.Tensor):
|
269 |
+
x = F.pad(x, self.causal_padding)
|
270 |
+
x = super(CausalConv1d, self).forward(x)
|
271 |
+
return x
|
272 |
+
|
273 |
+
|
274 |
+
class BASECFM(torch.nn.Module, ABC):
|
275 |
+
def __init__(
|
276 |
+
self,
|
277 |
+
n_feats,
|
278 |
+
cfm_params,
|
279 |
+
n_spks=1,
|
280 |
+
spk_emb_dim=128,
|
281 |
+
):
|
282 |
+
super().__init__()
|
283 |
+
self.n_feats = n_feats
|
284 |
+
self.n_spks = n_spks
|
285 |
+
self.spk_emb_dim = spk_emb_dim
|
286 |
+
self.solver = cfm_params.solver
|
287 |
+
if hasattr(cfm_params, "sigma_min"):
|
288 |
+
self.sigma_min = cfm_params.sigma_min
|
289 |
+
else:
|
290 |
+
self.sigma_min = 1e-4
|
291 |
+
|
292 |
+
self.estimator = None
|
293 |
+
|
294 |
+
@torch.inference_mode()
|
295 |
+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
296 |
+
"""Forward diffusion
|
297 |
+
|
298 |
+
Args:
|
299 |
+
mu (torch.Tensor): output of encoder
|
300 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
301 |
+
mask (torch.Tensor): output_mask
|
302 |
+
shape: (batch_size, 1, mel_timesteps)
|
303 |
+
n_timesteps (int): number of diffusion steps
|
304 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
305 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
306 |
+
shape: (batch_size, spk_emb_dim)
|
307 |
+
cond: Not used but kept for future purposes
|
308 |
+
|
309 |
+
Returns:
|
310 |
+
sample: generated mel-spectrogram
|
311 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
312 |
+
"""
|
313 |
+
z = torch.randn_like(mu) * temperature
|
314 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
315 |
+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
|
316 |
+
|
317 |
+
def solve_euler(self, x, t_span, mu, mask, spks, cond):
|
318 |
+
"""
|
319 |
+
Fixed euler solver for ODEs.
|
320 |
+
Args:
|
321 |
+
x (torch.Tensor): random noise
|
322 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
323 |
+
shape: (n_timesteps + 1,)
|
324 |
+
mu (torch.Tensor): output of encoder
|
325 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
326 |
+
mask (torch.Tensor): output_mask
|
327 |
+
shape: (batch_size, 1, mel_timesteps)
|
328 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
329 |
+
shape: (batch_size, spk_emb_dim)
|
330 |
+
cond: Not used but kept for future purposes
|
331 |
+
"""
|
332 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
333 |
+
|
334 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
335 |
+
# Or in future might add like a return_all_steps flag
|
336 |
+
sol = []
|
337 |
+
|
338 |
+
for step in range(1, len(t_span)):
|
339 |
+
dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
|
340 |
+
|
341 |
+
x = x + dt * dphi_dt
|
342 |
+
t = t + dt
|
343 |
+
sol.append(x)
|
344 |
+
if step < len(t_span) - 1:
|
345 |
+
dt = t_span[step + 1] - t
|
346 |
+
|
347 |
+
return sol[-1]
|
348 |
+
|
349 |
+
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
350 |
+
"""Computes diffusion loss
|
351 |
+
|
352 |
+
Args:
|
353 |
+
x1 (torch.Tensor): Target
|
354 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
355 |
+
mask (torch.Tensor): target mask
|
356 |
+
shape: (batch_size, 1, mel_timesteps)
|
357 |
+
mu (torch.Tensor): output of encoder
|
358 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
359 |
+
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
|
360 |
+
shape: (batch_size, spk_emb_dim)
|
361 |
+
|
362 |
+
Returns:
|
363 |
+
loss: conditional flow matching loss
|
364 |
+
y: conditional flow
|
365 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
366 |
+
"""
|
367 |
+
b, _, t = mu.shape
|
368 |
+
|
369 |
+
# random timestep
|
370 |
+
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
371 |
+
# sample noise p(x_0)
|
372 |
+
z = torch.randn_like(x1)
|
373 |
+
|
374 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
375 |
+
u = x1 - (1 - self.sigma_min) * z
|
376 |
+
|
377 |
+
loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
|
378 |
+
torch.sum(mask) * u.shape[1]
|
379 |
+
)
|
380 |
+
return loss, y
|
381 |
+
|
382 |
+
|
383 |
+
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
384 |
+
"""Make mask tensor containing indices of padded part.
|
385 |
+
|
386 |
+
See description of make_non_pad_mask.
|
387 |
+
|
388 |
+
Args:
|
389 |
+
lengths (torch.Tensor): Batch of lengths (B,).
|
390 |
+
Returns:
|
391 |
+
torch.Tensor: Mask tensor containing indices of padded part.
|
392 |
+
|
393 |
+
Examples:
|
394 |
+
>>> lengths = [5, 3, 2]
|
395 |
+
>>> make_pad_mask(lengths)
|
396 |
+
masks = [[0, 0, 0, 0 ,0],
|
397 |
+
[0, 0, 0, 1, 1],
|
398 |
+
[0, 0, 1, 1, 1]]
|
399 |
+
"""
|
400 |
+
batch_size = lengths.size(0)
|
401 |
+
max_len = max_len if max_len > 0 else lengths.max().item()
|
402 |
+
seq_range = torch.arange(0,
|
403 |
+
max_len,
|
404 |
+
dtype=torch.int64,
|
405 |
+
device=lengths.device)
|
406 |
+
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
407 |
+
seq_length_expand = lengths.unsqueeze(-1)
|
408 |
+
mask = seq_range_expand >= seq_length_expand
|
409 |
+
return mask
|
410 |
+
|
411 |
+
|
412 |
+
class ConditionalDecoder(nn.Module):
|
413 |
+
def __init__(
|
414 |
+
self,
|
415 |
+
in_channels,
|
416 |
+
out_channels,
|
417 |
+
causal=False,
|
418 |
+
channels=(256, 256),
|
419 |
+
dropout=0.05,
|
420 |
+
attention_head_dim=64,
|
421 |
+
n_blocks=1,
|
422 |
+
num_mid_blocks=2,
|
423 |
+
num_heads=4,
|
424 |
+
act_fn="snake",
|
425 |
+
gradient_checkpointing=True,
|
426 |
+
):
|
427 |
+
"""
|
428 |
+
This decoder requires an input with the same shape of the target. So, if your text content
|
429 |
+
is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
|
430 |
+
"""
|
431 |
+
super().__init__()
|
432 |
+
channels = tuple(channels)
|
433 |
+
self.in_channels = in_channels
|
434 |
+
self.out_channels = out_channels
|
435 |
+
self.causal = causal
|
436 |
+
self.static_chunk_size = 2 * 25 * 2 # 2*input_frame_rate*token_mel_ratio
|
437 |
+
self.gradient_checkpointing = gradient_checkpointing
|
438 |
+
|
439 |
+
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
440 |
+
time_embed_dim = channels[0] * 4
|
441 |
+
self.time_mlp = TimestepEmbedding(
|
442 |
+
in_channels=in_channels,
|
443 |
+
time_embed_dim=time_embed_dim,
|
444 |
+
act_fn="silu",
|
445 |
+
)
|
446 |
+
self.down_blocks = nn.ModuleList([])
|
447 |
+
self.mid_blocks = nn.ModuleList([])
|
448 |
+
self.up_blocks = nn.ModuleList([])
|
449 |
+
|
450 |
+
output_channel = in_channels
|
451 |
+
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
452 |
+
input_channel = output_channel
|
453 |
+
output_channel = channels[i]
|
454 |
+
is_last = i == len(channels) - 1
|
455 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
456 |
+
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
457 |
+
transformer_blocks = nn.ModuleList(
|
458 |
+
[
|
459 |
+
BasicTransformerBlock(
|
460 |
+
dim=output_channel,
|
461 |
+
num_attention_heads=num_heads,
|
462 |
+
attention_head_dim=attention_head_dim,
|
463 |
+
dropout=dropout,
|
464 |
+
activation_fn=act_fn,
|
465 |
+
)
|
466 |
+
for _ in range(n_blocks)
|
467 |
+
]
|
468 |
+
)
|
469 |
+
downsample = (
|
470 |
+
Downsample1D(output_channel) if not is_last else
|
471 |
+
CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
472 |
+
)
|
473 |
+
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
474 |
+
|
475 |
+
for _ in range(num_mid_blocks):
|
476 |
+
input_channel = channels[-1]
|
477 |
+
out_channels = channels[-1]
|
478 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
479 |
+
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
480 |
+
|
481 |
+
transformer_blocks = nn.ModuleList(
|
482 |
+
[
|
483 |
+
BasicTransformerBlock(
|
484 |
+
dim=output_channel,
|
485 |
+
num_attention_heads=num_heads,
|
486 |
+
attention_head_dim=attention_head_dim,
|
487 |
+
dropout=dropout,
|
488 |
+
activation_fn=act_fn,
|
489 |
+
)
|
490 |
+
for _ in range(n_blocks)
|
491 |
+
]
|
492 |
+
)
|
493 |
+
|
494 |
+
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
|
495 |
+
|
496 |
+
channels = channels[::-1] + (channels[0],)
|
497 |
+
for i in range(len(channels) - 1):
|
498 |
+
input_channel = channels[i] * 2
|
499 |
+
output_channel = channels[i + 1]
|
500 |
+
is_last = i == len(channels) - 2
|
501 |
+
resnet = CausalResnetBlock1D(
|
502 |
+
dim=input_channel,
|
503 |
+
dim_out=output_channel,
|
504 |
+
time_emb_dim=time_embed_dim,
|
505 |
+
) if self.causal else ResnetBlock1D(
|
506 |
+
dim=input_channel,
|
507 |
+
dim_out=output_channel,
|
508 |
+
time_emb_dim=time_embed_dim,
|
509 |
+
)
|
510 |
+
transformer_blocks = nn.ModuleList(
|
511 |
+
[
|
512 |
+
BasicTransformerBlock(
|
513 |
+
dim=output_channel,
|
514 |
+
num_attention_heads=num_heads,
|
515 |
+
attention_head_dim=attention_head_dim,
|
516 |
+
dropout=dropout,
|
517 |
+
activation_fn=act_fn,
|
518 |
+
)
|
519 |
+
for _ in range(n_blocks)
|
520 |
+
]
|
521 |
+
)
|
522 |
+
upsample = (
|
523 |
+
Upsample1D(output_channel, use_conv_transpose=True)
|
524 |
+
if not is_last
|
525 |
+
else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
526 |
+
)
|
527 |
+
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
528 |
+
self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
|
529 |
+
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
530 |
+
self.initialize_weights()
|
531 |
+
|
532 |
+
def initialize_weights(self):
|
533 |
+
for m in self.modules():
|
534 |
+
if isinstance(m, nn.Conv1d):
|
535 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
536 |
+
if m.bias is not None:
|
537 |
+
nn.init.constant_(m.bias, 0)
|
538 |
+
elif isinstance(m, nn.GroupNorm):
|
539 |
+
nn.init.constant_(m.weight, 1)
|
540 |
+
nn.init.constant_(m.bias, 0)
|
541 |
+
elif isinstance(m, nn.Linear):
|
542 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
543 |
+
if m.bias is not None:
|
544 |
+
nn.init.constant_(m.bias, 0)
|
545 |
+
|
546 |
+
def forward(self, x, mask, mu, t, spks=None, cond=None):
|
547 |
+
"""Forward pass of the UNet1DConditional model.
|
548 |
+
|
549 |
+
Args:
|
550 |
+
x (torch.Tensor): shape (batch_size, in_channels, time)
|
551 |
+
mask (_type_): shape (batch_size, 1, time)
|
552 |
+
t (_type_): shape (batch_size)
|
553 |
+
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
554 |
+
cond (_type_, optional): placeholder for future use. Defaults to None.
|
555 |
+
|
556 |
+
Raises:
|
557 |
+
ValueError: _description_
|
558 |
+
ValueError: _description_
|
559 |
+
|
560 |
+
Returns:
|
561 |
+
_type_: _description_
|
562 |
+
"""
|
563 |
+
t = self.time_embeddings(t)
|
564 |
+
t = t.to(x.dtype)
|
565 |
+
t = self.time_mlp(t)
|
566 |
+
x = pack([x, mu], "b * t")[0]
|
567 |
+
mask = mask.to(x.dtype)
|
568 |
+
if spks is not None:
|
569 |
+
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
|
570 |
+
x = pack([x, spks], "b * t")[0]
|
571 |
+
if cond is not None:
|
572 |
+
x = pack([x, cond], "b * t")[0]
|
573 |
+
|
574 |
+
hiddens = []
|
575 |
+
masks = [mask]
|
576 |
+
for resnet, transformer_blocks, downsample in self.down_blocks:
|
577 |
+
mask_down = masks[-1]
|
578 |
+
x = resnet(x, mask_down, t)
|
579 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
580 |
+
# attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
|
581 |
+
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
|
582 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
583 |
+
for transformer_block in transformer_blocks:
|
584 |
+
if self.gradient_checkpointing and self.training:
|
585 |
+
def create_custom_forward(module):
|
586 |
+
def custom_forward(*inputs):
|
587 |
+
return module(*inputs)
|
588 |
+
return custom_forward
|
589 |
+
x = torch.utils.checkpoint.checkpoint(
|
590 |
+
create_custom_forward(transformer_block),
|
591 |
+
x,
|
592 |
+
attn_mask,
|
593 |
+
t,
|
594 |
+
)
|
595 |
+
else:
|
596 |
+
x = transformer_block(
|
597 |
+
hidden_states=x,
|
598 |
+
attention_mask=attn_mask,
|
599 |
+
timestep=t,
|
600 |
+
)
|
601 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
602 |
+
hiddens.append(x) # Save hidden states for skip connections
|
603 |
+
x = downsample(x * mask_down)
|
604 |
+
masks.append(mask_down[:, :, ::2])
|
605 |
+
masks = masks[:-1]
|
606 |
+
mask_mid = masks[-1]
|
607 |
+
|
608 |
+
for resnet, transformer_blocks in self.mid_blocks:
|
609 |
+
x = resnet(x, mask_mid, t)
|
610 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
611 |
+
# attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
|
612 |
+
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
|
613 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
614 |
+
for transformer_block in transformer_blocks:
|
615 |
+
if self.gradient_checkpointing and self.training:
|
616 |
+
def create_custom_forward(module):
|
617 |
+
def custom_forward(*inputs):
|
618 |
+
return module(*inputs)
|
619 |
+
return custom_forward
|
620 |
+
x = torch.utils.checkpoint.checkpoint(
|
621 |
+
create_custom_forward(transformer_block),
|
622 |
+
x,
|
623 |
+
attn_mask,
|
624 |
+
t,
|
625 |
+
)
|
626 |
+
else:
|
627 |
+
x = transformer_block(
|
628 |
+
hidden_states=x,
|
629 |
+
attention_mask=attn_mask,
|
630 |
+
timestep=t,
|
631 |
+
)
|
632 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
633 |
+
|
634 |
+
for resnet, transformer_blocks, upsample in self.up_blocks:
|
635 |
+
mask_up = masks.pop()
|
636 |
+
skip = hiddens.pop()
|
637 |
+
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
638 |
+
x = resnet(x, mask_up, t)
|
639 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
640 |
+
# attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
|
641 |
+
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
|
642 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
643 |
+
for transformer_block in transformer_blocks:
|
644 |
+
if self.gradient_checkpointing and self.training:
|
645 |
+
def create_custom_forward(module):
|
646 |
+
def custom_forward(*inputs):
|
647 |
+
return module(*inputs)
|
648 |
+
return custom_forward
|
649 |
+
x = torch.utils.checkpoint.checkpoint(
|
650 |
+
create_custom_forward(transformer_block),
|
651 |
+
x,
|
652 |
+
attn_mask,
|
653 |
+
t,
|
654 |
+
)
|
655 |
+
else:
|
656 |
+
x = transformer_block(
|
657 |
+
hidden_states=x,
|
658 |
+
attention_mask=attn_mask,
|
659 |
+
timestep=t,
|
660 |
+
)
|
661 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
662 |
+
x = upsample(x * mask_up)
|
663 |
+
x = self.final_block(x, mask_up)
|
664 |
+
output = self.final_proj(x * mask_up)
|
665 |
+
return output * mask
|
666 |
+
|
667 |
+
|
668 |
+
class ConditionalCFM(BASECFM):
|
669 |
+
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64):
|
670 |
+
super().__init__(
|
671 |
+
n_feats=in_channels,
|
672 |
+
cfm_params=cfm_params,
|
673 |
+
n_spks=n_spks,
|
674 |
+
spk_emb_dim=spk_emb_dim,
|
675 |
+
)
|
676 |
+
self.t_scheduler = cfm_params.t_scheduler
|
677 |
+
self.training_cfg_rate = cfm_params.training_cfg_rate
|
678 |
+
self.inference_cfg_rate = cfm_params.inference_cfg_rate
|
679 |
+
|
680 |
+
@torch.inference_mode()
|
681 |
+
def forward(self, estimator, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
682 |
+
"""Forward diffusion
|
683 |
+
|
684 |
+
Args:
|
685 |
+
mu (torch.Tensor): output of encoder
|
686 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
687 |
+
mask (torch.Tensor): output_mask
|
688 |
+
shape: (batch_size, 1, mel_timesteps)
|
689 |
+
n_timesteps (int): number of diffusion steps
|
690 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
691 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
692 |
+
shape: (batch_size, spk_emb_dim)
|
693 |
+
cond: Not used but kept for future purposes
|
694 |
+
|
695 |
+
Returns:
|
696 |
+
sample: generated mel-spectrogram
|
697 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
698 |
+
"""
|
699 |
+
z = torch.randn_like(mu) * temperature
|
700 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
701 |
+
if self.t_scheduler == 'cosine':
|
702 |
+
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
703 |
+
return self.solve_euler(estimator, z, t_span=t_span.to(mu.dtype), mu=mu, mask=mask, spks=spks, cond=cond)
|
704 |
+
|
705 |
+
def solve_euler(self, estimator, x, t_span, mu, mask, spks, cond):
|
706 |
+
"""
|
707 |
+
Fixed euler solver for ODEs.
|
708 |
+
Args:
|
709 |
+
x (torch.Tensor): random noise
|
710 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
711 |
+
shape: (n_timesteps + 1,)
|
712 |
+
mu (torch.Tensor): output of encoder
|
713 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
714 |
+
mask (torch.Tensor): output_mask
|
715 |
+
shape: (batch_size, 1, mel_timesteps)
|
716 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
717 |
+
shape: (batch_size, spk_emb_dim)
|
718 |
+
cond: Not used but kept for future purposes
|
719 |
+
"""
|
720 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
721 |
+
|
722 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
723 |
+
# Or in future might add like a return_all_steps flag
|
724 |
+
sol = []
|
725 |
+
|
726 |
+
for step in range(1, len(t_span)):
|
727 |
+
dphi_dt = estimator(x, mask, mu, t, spks, cond)
|
728 |
+
# Classifier-Free Guidance inference introduced in VoiceBox
|
729 |
+
if self.inference_cfg_rate > 0:
|
730 |
+
cfg_dphi_dt = estimator(
|
731 |
+
x, mask,
|
732 |
+
torch.zeros_like(mu), t,
|
733 |
+
torch.zeros_like(spks) if spks is not None else None,
|
734 |
+
cond=cond
|
735 |
+
)
|
736 |
+
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
|
737 |
+
self.inference_cfg_rate * cfg_dphi_dt)
|
738 |
+
x = x + dt * dphi_dt
|
739 |
+
t = t + dt
|
740 |
+
sol.append(x)
|
741 |
+
if step < len(t_span) - 1:
|
742 |
+
dt = t_span[step + 1] - t
|
743 |
+
|
744 |
+
return sol[-1]
|
745 |
+
|
746 |
+
def compute_loss(self, estimator, x1, mask, mu, spks=None, cond=None):
|
747 |
+
"""Computes diffusion loss
|
748 |
+
|
749 |
+
Args:
|
750 |
+
x1 (torch.Tensor): Target
|
751 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
752 |
+
mask (torch.Tensor): target mask
|
753 |
+
shape: (batch_size, 1, mel_timesteps)
|
754 |
+
mu (torch.Tensor): output of encoder
|
755 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
756 |
+
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
|
757 |
+
shape: (batch_size, spk_emb_dim)
|
758 |
+
|
759 |
+
Returns:
|
760 |
+
loss: conditional flow matching loss
|
761 |
+
y: conditional flow
|
762 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
763 |
+
"""
|
764 |
+
org_dtype = x1.dtype
|
765 |
+
|
766 |
+
b, _, t = mu.shape
|
767 |
+
# random timestep
|
768 |
+
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
769 |
+
if self.t_scheduler == 'cosine':
|
770 |
+
t = 1 - torch.cos(t * 0.5 * torch.pi)
|
771 |
+
# sample noise p(x_0)
|
772 |
+
z = torch.randn_like(x1)
|
773 |
+
|
774 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
775 |
+
u = x1 - (1 - self.sigma_min) * z
|
776 |
+
|
777 |
+
# during training, we randomly drop condition to trade off mode coverage and sample fidelity
|
778 |
+
if self.training_cfg_rate > 0:
|
779 |
+
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
|
780 |
+
mu = mu * cfg_mask.view(-1, 1, 1)
|
781 |
+
if spks is not None:
|
782 |
+
spks = spks * cfg_mask.view(-1, 1)
|
783 |
+
if cond is not None:
|
784 |
+
cond = cond * cfg_mask.view(-1, 1, 1)
|
785 |
+
|
786 |
+
pred = estimator(y, mask, mu, t.squeeze(), spks, cond)
|
787 |
+
pred = pred.float()
|
788 |
+
u = u.float()
|
789 |
+
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
|
790 |
+
loss = loss.to(org_dtype)
|
791 |
+
return loss, y
|
generation_config.json
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token_id": 151643,
|
3 |
+
"eos_token_id": 151643,
|
4 |
+
"max_new_tokens": 2048,
|
5 |
+
"transformers_version": "4.45.0.dev0"
|
6 |
+
}
|
generation_utils.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
from queue import Queue
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
def build_chat_input(model, tokenizer, messages: List[dict], max_new_tokens: int=0):
|
8 |
+
def _parse_messages(messages, split_role="user"):
|
9 |
+
system, rounds = "", []
|
10 |
+
round = []
|
11 |
+
for i, message in enumerate(messages):
|
12 |
+
if message["role"] == "system":
|
13 |
+
assert i == 0
|
14 |
+
system = message["content"]
|
15 |
+
continue
|
16 |
+
if message["role"] == split_role and round:
|
17 |
+
rounds.append(round)
|
18 |
+
round = []
|
19 |
+
round.append(message)
|
20 |
+
if round:
|
21 |
+
rounds.append(round)
|
22 |
+
return system, rounds
|
23 |
+
|
24 |
+
max_new_tokens = max_new_tokens or model.generation_config.max_new_tokens
|
25 |
+
max_input_tokens = model.config.model_max_length - max_new_tokens
|
26 |
+
system, rounds = _parse_messages(messages, split_role="user")
|
27 |
+
system_tokens = tokenizer.encode(system)
|
28 |
+
max_history_tokens = max_input_tokens - len(system_tokens)
|
29 |
+
|
30 |
+
history_tokens = []
|
31 |
+
for round in rounds[::-1]:
|
32 |
+
round_tokens = []
|
33 |
+
for message in round:
|
34 |
+
if message["role"] == "user":
|
35 |
+
round_tokens.append(model.generation_config.user_token_id)
|
36 |
+
else:
|
37 |
+
round_tokens.append(model.generation_config.assistant_token_id)
|
38 |
+
round_tokens.extend(tokenizer.encode(message["content"]))
|
39 |
+
if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
|
40 |
+
history_tokens = round_tokens + history_tokens # concat left
|
41 |
+
if len(history_tokens) < max_history_tokens:
|
42 |
+
continue
|
43 |
+
break
|
44 |
+
|
45 |
+
input_tokens = system_tokens + history_tokens
|
46 |
+
if messages[-1]["role"] != "assistant":
|
47 |
+
input_tokens.append(model.generation_config.assistant_token_id)
|
48 |
+
input_tokens = input_tokens[-max_input_tokens:] # truncate left
|
49 |
+
return torch.LongTensor([input_tokens]).to(model.device)
|
50 |
+
|
51 |
+
|
52 |
+
class TextIterStreamer:
|
53 |
+
def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
|
54 |
+
self.tokenizer = tokenizer
|
55 |
+
self.skip_prompt = skip_prompt
|
56 |
+
self.skip_special_tokens = skip_special_tokens
|
57 |
+
self.tokens = []
|
58 |
+
self.text_queue = Queue()
|
59 |
+
self.next_tokens_are_prompt = True
|
60 |
+
|
61 |
+
def put(self, value):
|
62 |
+
if self.skip_prompt and self.next_tokens_are_prompt:
|
63 |
+
self.next_tokens_are_prompt = False
|
64 |
+
else:
|
65 |
+
if len(value.shape) > 1:
|
66 |
+
value = value[0]
|
67 |
+
self.tokens.extend(value.tolist())
|
68 |
+
self.text_queue.put(
|
69 |
+
self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))
|
70 |
+
|
71 |
+
def end(self):
|
72 |
+
self.text_queue.put(None)
|
73 |
+
|
74 |
+
def __iter__(self):
|
75 |
+
return self
|
76 |
+
|
77 |
+
def __next__(self):
|
78 |
+
value = self.text_queue.get()
|
79 |
+
if value is None:
|
80 |
+
raise StopIteration()
|
81 |
+
else:
|
82 |
+
return value
|
83 |
+
|
matcha_components.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from Matcha-TTS https://github.com/shivammehta25/Matcha-TTS
|
2 |
+
"""
|
3 |
+
MIT License
|
4 |
+
|
5 |
+
Copyright (c) 2023 Shivam Mehta
|
6 |
+
|
7 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
8 |
+
of this software and associated documentation files (the "Software"), to deal
|
9 |
+
in the Software without restriction, including without limitation the rights
|
10 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
11 |
+
copies of the Software, and to permit persons to whom the Software is
|
12 |
+
furnished to do so, subject to the following conditions:
|
13 |
+
|
14 |
+
The above copyright notice and this permission notice shall be included in all
|
15 |
+
copies or substantial portions of the Software.
|
16 |
+
|
17 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
18 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
19 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
20 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
21 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
22 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
23 |
+
SOFTWARE.
|
24 |
+
"""
|
25 |
+
|
26 |
+
import math
|
27 |
+
from typing import Optional
|
28 |
+
|
29 |
+
import torch
|
30 |
+
import torch.nn as nn
|
31 |
+
import torch.nn.functional as F
|
32 |
+
|
33 |
+
from diffusers.models.activations import get_activation
|
34 |
+
|
35 |
+
|
36 |
+
class SinusoidalPosEmb(torch.nn.Module):
|
37 |
+
def __init__(self, dim):
|
38 |
+
super().__init__()
|
39 |
+
self.dim = dim
|
40 |
+
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
|
41 |
+
|
42 |
+
def forward(self, x, scale=1000):
|
43 |
+
if x.ndim < 1:
|
44 |
+
x = x.unsqueeze(0)
|
45 |
+
device = x.device
|
46 |
+
half_dim = self.dim // 2
|
47 |
+
emb = math.log(10000) / (half_dim - 1)
|
48 |
+
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
|
49 |
+
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
50 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
51 |
+
return emb
|
52 |
+
|
53 |
+
|
54 |
+
class Block1D(torch.nn.Module):
|
55 |
+
def __init__(self, dim, dim_out, groups=8):
|
56 |
+
super().__init__()
|
57 |
+
self.block = torch.nn.Sequential(
|
58 |
+
torch.nn.Conv1d(dim, dim_out, 3, padding=1),
|
59 |
+
torch.nn.GroupNorm(groups, dim_out),
|
60 |
+
nn.Mish(),
|
61 |
+
)
|
62 |
+
|
63 |
+
def forward(self, x, mask):
|
64 |
+
output = self.block(x * mask)
|
65 |
+
return output * mask
|
66 |
+
|
67 |
+
|
68 |
+
class ResnetBlock1D(torch.nn.Module):
|
69 |
+
def __init__(self, dim, dim_out, time_emb_dim, groups=8):
|
70 |
+
super().__init__()
|
71 |
+
self.mlp = torch.nn.Sequential(
|
72 |
+
nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)
|
73 |
+
)
|
74 |
+
|
75 |
+
self.block1 = Block1D(dim, dim_out, groups=groups)
|
76 |
+
self.block2 = Block1D(dim_out, dim_out, groups=groups)
|
77 |
+
|
78 |
+
self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
|
79 |
+
|
80 |
+
def forward(self, x, mask, time_emb):
|
81 |
+
h = self.block1(x, mask)
|
82 |
+
h += self.mlp(time_emb).unsqueeze(-1)
|
83 |
+
h = self.block2(h, mask)
|
84 |
+
output = h + self.res_conv(x * mask)
|
85 |
+
return output
|
86 |
+
|
87 |
+
|
88 |
+
class Downsample1D(nn.Module):
|
89 |
+
def __init__(self, dim):
|
90 |
+
super().__init__()
|
91 |
+
self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
|
92 |
+
|
93 |
+
def forward(self, x):
|
94 |
+
return self.conv(x)
|
95 |
+
|
96 |
+
|
97 |
+
class TimestepEmbedding(nn.Module):
|
98 |
+
def __init__(
|
99 |
+
self,
|
100 |
+
in_channels: int,
|
101 |
+
time_embed_dim: int,
|
102 |
+
act_fn: str = "silu",
|
103 |
+
out_dim: int = None,
|
104 |
+
post_act_fn: Optional[str] = None,
|
105 |
+
cond_proj_dim=None,
|
106 |
+
):
|
107 |
+
super().__init__()
|
108 |
+
|
109 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
110 |
+
|
111 |
+
if cond_proj_dim is not None:
|
112 |
+
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
113 |
+
else:
|
114 |
+
self.cond_proj = None
|
115 |
+
|
116 |
+
self.act = get_activation(act_fn)
|
117 |
+
|
118 |
+
if out_dim is not None:
|
119 |
+
time_embed_dim_out = out_dim
|
120 |
+
else:
|
121 |
+
time_embed_dim_out = time_embed_dim
|
122 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
123 |
+
|
124 |
+
if post_act_fn is None:
|
125 |
+
self.post_act = None
|
126 |
+
else:
|
127 |
+
self.post_act = get_activation(post_act_fn)
|
128 |
+
|
129 |
+
def forward(self, sample, condition=None):
|
130 |
+
if condition is not None:
|
131 |
+
sample = sample + self.cond_proj(condition)
|
132 |
+
sample = self.linear_1(sample)
|
133 |
+
|
134 |
+
if self.act is not None:
|
135 |
+
sample = self.act(sample)
|
136 |
+
|
137 |
+
sample = self.linear_2(sample)
|
138 |
+
|
139 |
+
if self.post_act is not None:
|
140 |
+
sample = self.post_act(sample)
|
141 |
+
return sample
|
142 |
+
|
143 |
+
|
144 |
+
class Upsample1D(nn.Module):
|
145 |
+
"""A 1D upsampling layer with an optional convolution.
|
146 |
+
|
147 |
+
Parameters:
|
148 |
+
channels (`int`):
|
149 |
+
number of channels in the inputs and outputs.
|
150 |
+
use_conv (`bool`, default `False`):
|
151 |
+
option to use a convolution.
|
152 |
+
use_conv_transpose (`bool`, default `False`):
|
153 |
+
option to use a convolution transpose.
|
154 |
+
out_channels (`int`, optional):
|
155 |
+
number of output channels. Defaults to `channels`.
|
156 |
+
"""
|
157 |
+
|
158 |
+
def __init__(
|
159 |
+
self,
|
160 |
+
channels,
|
161 |
+
use_conv=False,
|
162 |
+
use_conv_transpose=True,
|
163 |
+
out_channels=None,
|
164 |
+
name="conv",
|
165 |
+
):
|
166 |
+
super().__init__()
|
167 |
+
self.channels = channels
|
168 |
+
self.out_channels = out_channels or channels
|
169 |
+
self.use_conv = use_conv
|
170 |
+
self.use_conv_transpose = use_conv_transpose
|
171 |
+
self.name = name
|
172 |
+
|
173 |
+
self.conv = None
|
174 |
+
if use_conv_transpose:
|
175 |
+
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
|
176 |
+
elif use_conv:
|
177 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
|
178 |
+
|
179 |
+
def forward(self, inputs):
|
180 |
+
assert inputs.shape[1] == self.channels
|
181 |
+
if self.use_conv_transpose:
|
182 |
+
return self.conv(inputs)
|
183 |
+
|
184 |
+
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
|
185 |
+
|
186 |
+
if self.use_conv:
|
187 |
+
outputs = self.conv(outputs)
|
188 |
+
|
189 |
+
return outputs
|
matcha_feat.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from Matcha-TTS https://github.com/shivammehta25/Matcha-TTS
|
2 |
+
"""
|
3 |
+
MIT License
|
4 |
+
|
5 |
+
Copyright (c) 2023 Shivam Mehta
|
6 |
+
|
7 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
8 |
+
of this software and associated documentation files (the "Software"), to deal
|
9 |
+
in the Software without restriction, including without limitation the rights
|
10 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
11 |
+
copies of the Software, and to permit persons to whom the Software is
|
12 |
+
furnished to do so, subject to the following conditions:
|
13 |
+
|
14 |
+
The above copyright notice and this permission notice shall be included in all
|
15 |
+
copies or substantial portions of the Software.
|
16 |
+
|
17 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
18 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
19 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
20 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
21 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
22 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
23 |
+
SOFTWARE.
|
24 |
+
"""
|
25 |
+
|
26 |
+
import numpy as np
|
27 |
+
import torch
|
28 |
+
import torch.utils.data
|
29 |
+
from librosa.filters import mel as librosa_mel_fn
|
30 |
+
from scipy.io.wavfile import read
|
31 |
+
|
32 |
+
MAX_WAV_VALUE = 32768.0
|
33 |
+
|
34 |
+
|
35 |
+
def load_wav(full_path):
|
36 |
+
sampling_rate, data = read(full_path)
|
37 |
+
return data, sampling_rate
|
38 |
+
|
39 |
+
|
40 |
+
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
41 |
+
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
42 |
+
|
43 |
+
|
44 |
+
def dynamic_range_decompression(x, C=1):
|
45 |
+
return np.exp(x) / C
|
46 |
+
|
47 |
+
|
48 |
+
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
49 |
+
return torch.log(torch.clamp(x, min=clip_val) * C)
|
50 |
+
|
51 |
+
|
52 |
+
def dynamic_range_decompression_torch(x, C=1):
|
53 |
+
return torch.exp(x) / C
|
54 |
+
|
55 |
+
|
56 |
+
def spectral_normalize_torch(magnitudes):
|
57 |
+
output = dynamic_range_compression_torch(magnitudes)
|
58 |
+
return output
|
59 |
+
|
60 |
+
|
61 |
+
def spectral_de_normalize_torch(magnitudes):
|
62 |
+
output = dynamic_range_decompression_torch(magnitudes)
|
63 |
+
return output
|
64 |
+
|
65 |
+
|
66 |
+
mel_basis = {}
|
67 |
+
hann_window = {}
|
68 |
+
|
69 |
+
|
70 |
+
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
|
71 |
+
if torch.min(y) < -1.0:
|
72 |
+
print("min value is ", torch.min(y))
|
73 |
+
if torch.max(y) > 1.0:
|
74 |
+
print("max value is ", torch.max(y))
|
75 |
+
|
76 |
+
global mel_basis, hann_window # pylint: disable=global-statement
|
77 |
+
if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
|
78 |
+
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
79 |
+
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
|
80 |
+
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
|
81 |
+
|
82 |
+
y = torch.nn.functional.pad(
|
83 |
+
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
|
84 |
+
)
|
85 |
+
y = y.squeeze(1)
|
86 |
+
|
87 |
+
spec = torch.view_as_real(
|
88 |
+
torch.stft(
|
89 |
+
y,
|
90 |
+
n_fft,
|
91 |
+
hop_length=hop_size,
|
92 |
+
win_length=win_size,
|
93 |
+
window=hann_window[str(y.device)],
|
94 |
+
center=center,
|
95 |
+
pad_mode="reflect",
|
96 |
+
normalized=False,
|
97 |
+
onesided=True,
|
98 |
+
return_complex=True,
|
99 |
+
)
|
100 |
+
)
|
101 |
+
|
102 |
+
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
|
103 |
+
|
104 |
+
spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
|
105 |
+
spec = spectral_normalize_torch(spec)
|
106 |
+
|
107 |
+
return spec
|
matcha_transformer.py
ADDED
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from Matcha-TTS https://github.com/shivammehta25/Matcha-TTS
|
2 |
+
"""
|
3 |
+
MIT License
|
4 |
+
|
5 |
+
Copyright (c) 2023 Shivam Mehta
|
6 |
+
|
7 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
8 |
+
of this software and associated documentation files (the "Software"), to deal
|
9 |
+
in the Software without restriction, including without limitation the rights
|
10 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
11 |
+
copies of the Software, and to permit persons to whom the Software is
|
12 |
+
furnished to do so, subject to the following conditions:
|
13 |
+
|
14 |
+
The above copyright notice and this permission notice shall be included in all
|
15 |
+
copies or substantial portions of the Software.
|
16 |
+
|
17 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
18 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
19 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
20 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
21 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
22 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
23 |
+
SOFTWARE.
|
24 |
+
"""
|
25 |
+
|
26 |
+
from typing import Any, Dict, Optional
|
27 |
+
|
28 |
+
import torch
|
29 |
+
import torch.nn as nn
|
30 |
+
from diffusers.models.attention import (
|
31 |
+
GEGLU,
|
32 |
+
GELU,
|
33 |
+
AdaLayerNorm,
|
34 |
+
AdaLayerNormZero,
|
35 |
+
ApproximateGELU,
|
36 |
+
)
|
37 |
+
from diffusers.models.attention_processor import Attention
|
38 |
+
from diffusers.models.lora import LoRACompatibleLinear
|
39 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
40 |
+
|
41 |
+
import torch.nn.functional as F
|
42 |
+
from flash_attn import flash_attn_varlen_func
|
43 |
+
|
44 |
+
|
45 |
+
def get_sequence_mask(inputs, inputs_length):
|
46 |
+
if inputs.dim() == 3:
|
47 |
+
bsz, tgt_len, _ = inputs.size()
|
48 |
+
else:
|
49 |
+
bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
|
50 |
+
sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
|
51 |
+
sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(
|
52 |
+
bsz, tgt_len, 1
|
53 |
+
)
|
54 |
+
unpacking_index = (
|
55 |
+
torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1
|
56 |
+
) # 转成下标
|
57 |
+
return sequence_mask, unpacking_index
|
58 |
+
|
59 |
+
|
60 |
+
class OmniWhisperAttention(nn.Module):
|
61 |
+
def __init__(self, embed_dim, num_heads, causal=False):
|
62 |
+
super().__init__()
|
63 |
+
self.embed_dim = embed_dim
|
64 |
+
self.num_heads = num_heads
|
65 |
+
self.head_dim = embed_dim // num_heads
|
66 |
+
|
67 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
|
68 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
69 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
70 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
71 |
+
|
72 |
+
self.causal = causal
|
73 |
+
|
74 |
+
def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
|
75 |
+
bsz, _ = hidden_states.size()
|
76 |
+
|
77 |
+
query_states = self.q_proj(hidden_states).view(
|
78 |
+
bsz, self.num_heads, self.head_dim
|
79 |
+
)
|
80 |
+
key_states = self.k_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
|
81 |
+
value_states = self.v_proj(hidden_states).view(
|
82 |
+
bsz, self.num_heads, self.head_dim
|
83 |
+
)
|
84 |
+
|
85 |
+
cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(
|
86 |
+
torch.int32
|
87 |
+
)
|
88 |
+
max_seqlen = torch.max(seq_len).to(torch.int32).detach()
|
89 |
+
attn_output = flash_attn_varlen_func(
|
90 |
+
query_states,
|
91 |
+
key_states,
|
92 |
+
value_states,
|
93 |
+
cu_len,
|
94 |
+
cu_len,
|
95 |
+
max_seqlen,
|
96 |
+
max_seqlen,
|
97 |
+
causal=self.causal,
|
98 |
+
) # (bsz * qlen, nheads, headdim)
|
99 |
+
attn_output = attn_output.reshape(bsz, self.embed_dim)
|
100 |
+
attn_output = self.out_proj(attn_output)
|
101 |
+
return attn_output
|
102 |
+
|
103 |
+
|
104 |
+
class SnakeBeta(nn.Module):
|
105 |
+
"""
|
106 |
+
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
107 |
+
Shape:
|
108 |
+
- Input: (B, C, T)
|
109 |
+
- Output: (B, C, T), same shape as the input
|
110 |
+
Parameters:
|
111 |
+
- alpha - trainable parameter that controls frequency
|
112 |
+
- beta - trainable parameter that controls magnitude
|
113 |
+
References:
|
114 |
+
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
115 |
+
https://arxiv.org/abs/2006.08195
|
116 |
+
Examples:
|
117 |
+
>>> a1 = snakebeta(256)
|
118 |
+
>>> x = torch.randn(256)
|
119 |
+
>>> x = a1(x)
|
120 |
+
"""
|
121 |
+
|
122 |
+
def __init__(
|
123 |
+
self,
|
124 |
+
in_features,
|
125 |
+
out_features,
|
126 |
+
alpha=1.0,
|
127 |
+
alpha_trainable=True,
|
128 |
+
alpha_logscale=True,
|
129 |
+
):
|
130 |
+
"""
|
131 |
+
Initialization.
|
132 |
+
INPUT:
|
133 |
+
- in_features: shape of the input
|
134 |
+
- alpha - trainable parameter that controls frequency
|
135 |
+
- beta - trainable parameter that controls magnitude
|
136 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
137 |
+
beta is initialized to 1 by default, higher values = higher-magnitude.
|
138 |
+
alpha will be trained along with the rest of your model.
|
139 |
+
"""
|
140 |
+
super().__init__()
|
141 |
+
self.in_features = (
|
142 |
+
out_features if isinstance(out_features, list) else [out_features]
|
143 |
+
)
|
144 |
+
self.proj = LoRACompatibleLinear(in_features, out_features)
|
145 |
+
|
146 |
+
# initialize alpha
|
147 |
+
self.alpha_logscale = alpha_logscale
|
148 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
149 |
+
self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
|
150 |
+
self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
|
151 |
+
else: # linear scale alphas initialized to ones
|
152 |
+
self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
|
153 |
+
self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
|
154 |
+
|
155 |
+
self.alpha.requires_grad = alpha_trainable
|
156 |
+
self.beta.requires_grad = alpha_trainable
|
157 |
+
|
158 |
+
self.no_div_by_zero = 0.000000001
|
159 |
+
|
160 |
+
def forward(self, x):
|
161 |
+
"""
|
162 |
+
Forward pass of the function.
|
163 |
+
Applies the function to the input elementwise.
|
164 |
+
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
165 |
+
"""
|
166 |
+
x = self.proj(x)
|
167 |
+
if self.alpha_logscale:
|
168 |
+
alpha = torch.exp(self.alpha)
|
169 |
+
beta = torch.exp(self.beta)
|
170 |
+
else:
|
171 |
+
alpha = self.alpha
|
172 |
+
beta = self.beta
|
173 |
+
|
174 |
+
x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
|
175 |
+
torch.sin(x * alpha), 2
|
176 |
+
)
|
177 |
+
|
178 |
+
return x
|
179 |
+
|
180 |
+
|
181 |
+
class FeedForward(nn.Module):
|
182 |
+
r"""
|
183 |
+
A feed-forward layer.
|
184 |
+
|
185 |
+
Parameters:
|
186 |
+
dim (`int`): The number of channels in the input.
|
187 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
188 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
189 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
190 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
191 |
+
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
192 |
+
"""
|
193 |
+
|
194 |
+
def __init__(
|
195 |
+
self,
|
196 |
+
dim: int,
|
197 |
+
dim_out: Optional[int] = None,
|
198 |
+
mult: int = 4,
|
199 |
+
dropout: float = 0.0,
|
200 |
+
activation_fn: str = "geglu",
|
201 |
+
final_dropout: bool = False,
|
202 |
+
):
|
203 |
+
super().__init__()
|
204 |
+
inner_dim = int(dim * mult)
|
205 |
+
dim_out = dim_out if dim_out is not None else dim
|
206 |
+
|
207 |
+
if activation_fn == "gelu":
|
208 |
+
act_fn = GELU(dim, inner_dim)
|
209 |
+
if activation_fn == "gelu-approximate":
|
210 |
+
act_fn = GELU(dim, inner_dim, approximate="tanh")
|
211 |
+
elif activation_fn == "geglu":
|
212 |
+
act_fn = GEGLU(dim, inner_dim)
|
213 |
+
elif activation_fn == "geglu-approximate":
|
214 |
+
act_fn = ApproximateGELU(dim, inner_dim)
|
215 |
+
elif activation_fn == "snakebeta":
|
216 |
+
act_fn = SnakeBeta(dim, inner_dim)
|
217 |
+
|
218 |
+
self.net = nn.ModuleList([])
|
219 |
+
# project in
|
220 |
+
self.net.append(act_fn)
|
221 |
+
# project dropout
|
222 |
+
self.net.append(nn.Dropout(dropout))
|
223 |
+
# project out
|
224 |
+
self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
|
225 |
+
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
226 |
+
if final_dropout:
|
227 |
+
self.net.append(nn.Dropout(dropout))
|
228 |
+
|
229 |
+
def forward(self, hidden_states):
|
230 |
+
for module in self.net:
|
231 |
+
hidden_states = module(hidden_states)
|
232 |
+
return hidden_states
|
233 |
+
|
234 |
+
|
235 |
+
@maybe_allow_in_graph
|
236 |
+
class BasicTransformerBlock(nn.Module):
|
237 |
+
r"""
|
238 |
+
A basic Transformer block.
|
239 |
+
|
240 |
+
Parameters:
|
241 |
+
dim (`int`): The number of channels in the input and output.
|
242 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
243 |
+
attention_head_dim (`int`): The number of channels in each head.
|
244 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
245 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
246 |
+
only_cross_attention (`bool`, *optional*):
|
247 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
248 |
+
double_self_attention (`bool`, *optional*):
|
249 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
250 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
251 |
+
num_embeds_ada_norm (:
|
252 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
253 |
+
attention_bias (:
|
254 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
255 |
+
"""
|
256 |
+
|
257 |
+
def __init__(
|
258 |
+
self,
|
259 |
+
dim: int,
|
260 |
+
num_attention_heads: int,
|
261 |
+
attention_head_dim: int,
|
262 |
+
dropout=0.0,
|
263 |
+
cross_attention_dim: Optional[int] = None,
|
264 |
+
activation_fn: str = "geglu",
|
265 |
+
num_embeds_ada_norm: Optional[int] = None,
|
266 |
+
attention_bias: bool = False,
|
267 |
+
only_cross_attention: bool = False,
|
268 |
+
double_self_attention: bool = False,
|
269 |
+
upcast_attention: bool = False,
|
270 |
+
norm_elementwise_affine: bool = True,
|
271 |
+
norm_type: str = "layer_norm",
|
272 |
+
final_dropout: bool = False,
|
273 |
+
use_omni_attn: bool = False,
|
274 |
+
):
|
275 |
+
super().__init__()
|
276 |
+
|
277 |
+
self.use_omni_attn = use_omni_attn
|
278 |
+
self.dim = dim
|
279 |
+
|
280 |
+
self.only_cross_attention = only_cross_attention
|
281 |
+
|
282 |
+
self.use_ada_layer_norm_zero = (
|
283 |
+
num_embeds_ada_norm is not None
|
284 |
+
) and norm_type == "ada_norm_zero"
|
285 |
+
self.use_ada_layer_norm = (
|
286 |
+
num_embeds_ada_norm is not None
|
287 |
+
) and norm_type == "ada_norm"
|
288 |
+
|
289 |
+
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
290 |
+
raise ValueError(
|
291 |
+
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
292 |
+
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
293 |
+
)
|
294 |
+
|
295 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
296 |
+
# 1. Self-Attn
|
297 |
+
if self.use_ada_layer_norm:
|
298 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
299 |
+
elif self.use_ada_layer_norm_zero:
|
300 |
+
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
301 |
+
else:
|
302 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
303 |
+
|
304 |
+
if self.use_omni_attn:
|
305 |
+
if only_cross_attention:
|
306 |
+
raise NotImplementedError
|
307 |
+
print(
|
308 |
+
"Use OmniWhisperAttention with flash attention. Dropout is ignored."
|
309 |
+
)
|
310 |
+
self.attn1 = OmniWhisperAttention(
|
311 |
+
embed_dim=dim, num_heads=num_attention_heads, causal=False
|
312 |
+
)
|
313 |
+
else:
|
314 |
+
self.attn1 = Attention(
|
315 |
+
query_dim=dim,
|
316 |
+
heads=num_attention_heads,
|
317 |
+
dim_head=attention_head_dim,
|
318 |
+
dropout=dropout,
|
319 |
+
bias=attention_bias,
|
320 |
+
cross_attention_dim=(
|
321 |
+
cross_attention_dim if only_cross_attention else None
|
322 |
+
),
|
323 |
+
upcast_attention=upcast_attention,
|
324 |
+
)
|
325 |
+
|
326 |
+
# 2. Cross-Attn
|
327 |
+
if cross_attention_dim is not None or double_self_attention:
|
328 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
329 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
330 |
+
# the second cross attention block.
|
331 |
+
self.norm2 = (
|
332 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
333 |
+
if self.use_ada_layer_norm
|
334 |
+
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
335 |
+
)
|
336 |
+
self.attn2 = Attention(
|
337 |
+
query_dim=dim,
|
338 |
+
cross_attention_dim=(
|
339 |
+
cross_attention_dim if not double_self_attention else None
|
340 |
+
),
|
341 |
+
heads=num_attention_heads,
|
342 |
+
dim_head=attention_head_dim,
|
343 |
+
dropout=dropout,
|
344 |
+
bias=attention_bias,
|
345 |
+
upcast_attention=upcast_attention,
|
346 |
+
# scale_qk=False, # uncomment this to not to use flash attention
|
347 |
+
) # is self-attn if encoder_hidden_states is none
|
348 |
+
else:
|
349 |
+
self.norm2 = None
|
350 |
+
self.attn2 = None
|
351 |
+
|
352 |
+
# 3. Feed-forward
|
353 |
+
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
354 |
+
self.ff = FeedForward(
|
355 |
+
dim,
|
356 |
+
dropout=dropout,
|
357 |
+
activation_fn=activation_fn,
|
358 |
+
final_dropout=final_dropout,
|
359 |
+
)
|
360 |
+
|
361 |
+
# let chunk size default to None
|
362 |
+
self._chunk_size = None
|
363 |
+
self._chunk_dim = 0
|
364 |
+
|
365 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
|
366 |
+
# Sets chunk feed-forward
|
367 |
+
self._chunk_size = chunk_size
|
368 |
+
self._chunk_dim = dim
|
369 |
+
|
370 |
+
def forward(
|
371 |
+
self,
|
372 |
+
hidden_states: torch.FloatTensor,
|
373 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
374 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
375 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
376 |
+
timestep: Optional[torch.LongTensor] = None,
|
377 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
378 |
+
class_labels: Optional[torch.LongTensor] = None,
|
379 |
+
):
|
380 |
+
|
381 |
+
bsz, tgt_len, d_model = hidden_states.shape
|
382 |
+
|
383 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
384 |
+
# 1. Self-Attention
|
385 |
+
if self.use_ada_layer_norm:
|
386 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
387 |
+
elif self.use_ada_layer_norm_zero:
|
388 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
389 |
+
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
390 |
+
)
|
391 |
+
else:
|
392 |
+
norm_hidden_states = self.norm1(hidden_states)
|
393 |
+
|
394 |
+
cross_attention_kwargs = (
|
395 |
+
cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
396 |
+
)
|
397 |
+
|
398 |
+
if self.use_omni_attn:
|
399 |
+
seq_len = attention_mask[:, 0, :].float().long().sum(dim=1)
|
400 |
+
var_len_attention_mask, unpacking_index = get_sequence_mask(
|
401 |
+
norm_hidden_states, seq_len
|
402 |
+
)
|
403 |
+
norm_hidden_states = torch.masked_select(
|
404 |
+
norm_hidden_states, var_len_attention_mask
|
405 |
+
)
|
406 |
+
norm_hidden_states = norm_hidden_states.view(torch.sum(seq_len), self.dim)
|
407 |
+
attn_output = self.attn1(norm_hidden_states, seq_len)
|
408 |
+
# unpacking
|
409 |
+
attn_output = torch.index_select(attn_output, 0, unpacking_index).view(
|
410 |
+
bsz, tgt_len, d_model
|
411 |
+
)
|
412 |
+
attn_output = torch.where(var_len_attention_mask, attn_output, 0)
|
413 |
+
else:
|
414 |
+
attn_output = self.attn1(
|
415 |
+
norm_hidden_states,
|
416 |
+
encoder_hidden_states=(
|
417 |
+
encoder_hidden_states if self.only_cross_attention else None
|
418 |
+
),
|
419 |
+
attention_mask=(
|
420 |
+
encoder_attention_mask
|
421 |
+
if self.only_cross_attention
|
422 |
+
else attention_mask
|
423 |
+
),
|
424 |
+
**cross_attention_kwargs,
|
425 |
+
)
|
426 |
+
|
427 |
+
if self.use_ada_layer_norm_zero:
|
428 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
429 |
+
hidden_states = attn_output + hidden_states
|
430 |
+
|
431 |
+
# 2. Cross-Attention
|
432 |
+
if self.attn2 is not None:
|
433 |
+
norm_hidden_states = (
|
434 |
+
self.norm2(hidden_states, timestep)
|
435 |
+
if self.use_ada_layer_norm
|
436 |
+
else self.norm2(hidden_states)
|
437 |
+
)
|
438 |
+
|
439 |
+
attn_output = self.attn2(
|
440 |
+
norm_hidden_states,
|
441 |
+
encoder_hidden_states=encoder_hidden_states,
|
442 |
+
attention_mask=encoder_attention_mask,
|
443 |
+
**cross_attention_kwargs,
|
444 |
+
)
|
445 |
+
hidden_states = attn_output + hidden_states
|
446 |
+
|
447 |
+
# 3. Feed-forward
|
448 |
+
norm_hidden_states = self.norm3(hidden_states)
|
449 |
+
|
450 |
+
if self.use_ada_layer_norm_zero:
|
451 |
+
norm_hidden_states = (
|
452 |
+
norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
453 |
+
)
|
454 |
+
|
455 |
+
if self._chunk_size is not None:
|
456 |
+
# "feed_forward_chunk_size" can be used to save memory
|
457 |
+
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
458 |
+
raise ValueError(
|
459 |
+
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
460 |
+
)
|
461 |
+
|
462 |
+
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
463 |
+
ff_output = torch.cat(
|
464 |
+
[
|
465 |
+
self.ff(hid_slice)
|
466 |
+
for hid_slice in norm_hidden_states.chunk(
|
467 |
+
num_chunks, dim=self._chunk_dim
|
468 |
+
)
|
469 |
+
],
|
470 |
+
dim=self._chunk_dim,
|
471 |
+
)
|
472 |
+
else:
|
473 |
+
ff_output = self.ff(norm_hidden_states)
|
474 |
+
|
475 |
+
if self.use_ada_layer_norm_zero:
|
476 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
477 |
+
|
478 |
+
hidden_states = ff_output + hidden_states
|
479 |
+
|
480 |
+
return hidden_states
|
model-00001-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:887a4aafba70ac6740debcf22c58c4f40555f584c702a85776901991498ce59a
|
3 |
+
size 4877656728
|
model-00002-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3d959e03801a1794b6b6e4382c2ea49a4070789ab60d94734274cb7923604547
|
3 |
+
size 4932746496
|
model-00003-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3876a179fd44ca371bade65d23887555ba0fa6945bfeb1809ba901cd296ae735
|
3 |
+
size 4999921608
|
model-00004-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d90afae3003cab9992d3f6ccf75f76d7c99779439049a230b898f7a63eb19f39
|
3 |
+
size 4677721496
|
model-00005-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d81a0faa09bbed4856c1c86a3800317232888d1326fa0a3854cbc91febc67139
|
3 |
+
size 1640609776
|
model.safetensors.index.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
modeling_omni.py
ADDED
@@ -0,0 +1,1011 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Baichuan Inc. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
6 |
+
# and OPT implementations in this library. It has been modified from its
|
7 |
+
# original forms to accommodate minor architectural differences compared
|
8 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
9 |
+
#
|
10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
11 |
+
# you may not use this file except in compliance with the License.
|
12 |
+
# You may obtain a copy of the License at
|
13 |
+
#
|
14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
15 |
+
#
|
16 |
+
# Unless required by applicable law or agreed to in writing, software
|
17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
19 |
+
# See the License for the specific language governing permissions and
|
20 |
+
# limitations under the License.
|
21 |
+
""" PyTorch omni model."""
|
22 |
+
import os
|
23 |
+
import time
|
24 |
+
import json
|
25 |
+
import math
|
26 |
+
import numpy as np
|
27 |
+
from typing import List, Optional, Tuple, Union, Any
|
28 |
+
from threading import Thread
|
29 |
+
from easydict import EasyDict
|
30 |
+
|
31 |
+
import torch
|
32 |
+
import torch.distributed
|
33 |
+
import torch.utils.checkpoint
|
34 |
+
from torch import nn
|
35 |
+
from torch.nn import CrossEntropyLoss
|
36 |
+
from torch.nn import functional as F
|
37 |
+
import torch.distributed as dist
|
38 |
+
from transformers import PreTrainedModel
|
39 |
+
from transformers.activations import ACT2FN
|
40 |
+
from dataclasses import dataclass
|
41 |
+
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
42 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
|
43 |
+
from transformers.generation.utils import GenerationConfig
|
44 |
+
from transformers.utils import logging
|
45 |
+
# import for dynamic import not used in this file
|
46 |
+
from .vector_quantize import VectorQuantize, EuclideanCodebook
|
47 |
+
from .matcha_components import (
|
48 |
+
SinusoidalPosEmb,
|
49 |
+
Block1D,
|
50 |
+
ResnetBlock1D,
|
51 |
+
Downsample1D,
|
52 |
+
TimestepEmbedding,
|
53 |
+
Upsample1D,
|
54 |
+
)
|
55 |
+
from .matcha_transformer import BasicTransformerBlock
|
56 |
+
from .flow_matching import ConditionalDecoder, ConditionalCFM
|
57 |
+
|
58 |
+
from .configuration_omni import OmniConfig
|
59 |
+
from .audio_modeling_omni import (RMSNorm,
|
60 |
+
OmniAudioEncoder,
|
61 |
+
OmniAudioDecoder,
|
62 |
+
OmniAudioVQBridgeTokenizer,
|
63 |
+
OmniAudioFlowMatchingDecoder)
|
64 |
+
from .visual_modeling_omni import OmniVisualEncoder, OmniVisualBridge
|
65 |
+
from .processor_omni import OmniMMProcessor
|
66 |
+
|
67 |
+
# support model path contain point(.)
|
68 |
+
try:
|
69 |
+
# step1: copy relative imports to transformers_modules
|
70 |
+
from .generation_utils import build_chat_input, TextIterStreamer
|
71 |
+
from .sequence_parallel_utils import (
|
72 |
+
create_attention_layer,
|
73 |
+
get_sequence_parallel_size,
|
74 |
+
get_sequence_parallel_chunk,
|
75 |
+
)
|
76 |
+
except ModuleNotFoundError:
|
77 |
+
# step2: direct import from transformers_modules
|
78 |
+
try: # bypass check_imports failure
|
79 |
+
import sys
|
80 |
+
sys.path.append(os.path.dirname(__file__))
|
81 |
+
from generation_utils import build_chat_input, TextIterStreamer
|
82 |
+
from sequence_parallel_utils import (
|
83 |
+
create_attention_layer,
|
84 |
+
get_sequence_parallel_size,
|
85 |
+
get_sequence_parallel_chunk,
|
86 |
+
)
|
87 |
+
except Exception:
|
88 |
+
raise
|
89 |
+
|
90 |
+
logger = logging.get_logger(__name__)
|
91 |
+
|
92 |
+
def get_slopes(n):
|
93 |
+
def get_slopes_power_of_2(n):
|
94 |
+
start = (2 ** (-2 ** -(math.log2(n) - 3)))
|
95 |
+
ratio = start
|
96 |
+
return [start * ratio ** i for i in range(n)]
|
97 |
+
|
98 |
+
if math.log2(n).is_integer():
|
99 |
+
return get_slopes_power_of_2(
|
100 |
+
n) # In the paper, we only train models that have 2^a heads for some a. This function has
|
101 |
+
else: # some good properties that only occur when the input is a power of 2. To maintain that even
|
102 |
+
closest_power_of_2 = 2 ** math.floor(
|
103 |
+
math.log2(n)) # when the number of heads is not a power of 2, we use this workaround.
|
104 |
+
return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][
|
105 |
+
:n - closest_power_of_2]
|
106 |
+
|
107 |
+
|
108 |
+
class RotaryEmbedding(torch.nn.Module):
|
109 |
+
def __init__(self, dim, max_position_embeddings=2048, base=5e6, device=None):
|
110 |
+
super().__init__()
|
111 |
+
# 修复RePE初始化精度问题 https://zhuanlan.zhihu.com/p/678963442
|
112 |
+
# DeepSpeed 会 Hack torch.arange 强制在 GPU 上运行,这里使用原生的 torch.arange
|
113 |
+
try:
|
114 |
+
import deepspeed
|
115 |
+
self.arange = deepspeed.runtime.zero.partition_parameters._orig_torch_arange
|
116 |
+
except:
|
117 |
+
self.arange = torch.arange
|
118 |
+
|
119 |
+
self.inv_freq = 1.0 / (base ** (self.arange(0, dim, 2).float().to(device) / dim))
|
120 |
+
self.max_seq_len_cached = max_position_embeddings
|
121 |
+
t = self.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
|
122 |
+
freqs = torch.outer(t, self.inv_freq)
|
123 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
124 |
+
self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32)
|
125 |
+
self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32)
|
126 |
+
|
127 |
+
def forward(self, x, seq_len=None):
|
128 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
129 |
+
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
130 |
+
if seq_len > self.max_seq_len_cached:
|
131 |
+
self.max_seq_len_cached = seq_len
|
132 |
+
t = self.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
|
133 |
+
freqs = torch.outer(t, self.inv_freq)
|
134 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
135 |
+
self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32).to(x.device)
|
136 |
+
self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32).to(x.device)
|
137 |
+
return (
|
138 |
+
self.cos_cached[:, :, :seq_len, ...].to(torch.float32).to(x.device),
|
139 |
+
self.sin_cached[:, :, :seq_len, ...].to(torch.float32).to(x.device),
|
140 |
+
)
|
141 |
+
|
142 |
+
|
143 |
+
def rotate_half(x):
|
144 |
+
"""Rotates half the hidden dims of the input."""
|
145 |
+
x1 = x[..., : x.shape[-1] // 2]
|
146 |
+
x2 = x[..., x.shape[-1] // 2:]
|
147 |
+
return torch.cat((-x2, x1), dim=-1)
|
148 |
+
|
149 |
+
|
150 |
+
def apply_rotary_pos_emb(q, k, cos_, sin_, position_ids):
|
151 |
+
cos = cos_.squeeze(1).squeeze(0) # [seq_len, dim]
|
152 |
+
sin = sin_.squeeze(1).squeeze(0) # [seq_len, dim]
|
153 |
+
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
154 |
+
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
155 |
+
q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
|
156 |
+
k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
|
157 |
+
return q_embed.to(q.dtype), k_embed.to(k.dtype)
|
158 |
+
|
159 |
+
|
160 |
+
class MLP(nn.Module):
|
161 |
+
def __init__(
|
162 |
+
self,
|
163 |
+
hidden_size: int,
|
164 |
+
intermediate_size: int,
|
165 |
+
hidden_act: str,
|
166 |
+
):
|
167 |
+
super().__init__()
|
168 |
+
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
169 |
+
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
170 |
+
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
171 |
+
self.act_fn = ACT2FN[hidden_act]
|
172 |
+
|
173 |
+
def forward(self, x):
|
174 |
+
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
175 |
+
|
176 |
+
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
177 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
178 |
+
"""
|
179 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
180 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
181 |
+
"""
|
182 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
183 |
+
if n_rep == 1:
|
184 |
+
return hidden_states
|
185 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
186 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
187 |
+
|
188 |
+
|
189 |
+
class Attention(nn.Module):
|
190 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
191 |
+
def __init__(self, config: OmniConfig, is_sparse=False):
|
192 |
+
super().__init__()
|
193 |
+
self.config = config
|
194 |
+
self.position_embedding_type = config.position_embedding_type.lower()
|
195 |
+
self.num_kv_heads = config.num_key_value_heads
|
196 |
+
self.head_dim = config.head_dim
|
197 |
+
self.hidden_size = config.num_attention_heads * self.head_dim
|
198 |
+
self.hidden_kv_size = self.num_kv_heads * self.head_dim
|
199 |
+
|
200 |
+
if is_sparse:
|
201 |
+
self.num_heads = config.sparse_attention_heads
|
202 |
+
assert self.num_kv_heads == config.num_attention_heads
|
203 |
+
self.W_pack = nn.Linear(self.hidden_size, 3 * self.num_heads * self.head_dim, bias=config.attention_qkv_bias)
|
204 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
205 |
+
else:
|
206 |
+
self.num_heads = config.num_attention_heads
|
207 |
+
if self.config.attention_qkv_pack:
|
208 |
+
self.W_pack = nn.Linear(config.hidden_size, self.hidden_size + self.hidden_kv_size * 2, bias=config.attention_qkv_bias)
|
209 |
+
else:
|
210 |
+
self.q_proj = nn.Linear(config.hidden_size, self.hidden_size, bias=config.attention_qkv_bias)
|
211 |
+
self.k_proj = nn.Linear(config.hidden_size, self.hidden_kv_size, bias=config.attention_qkv_bias)
|
212 |
+
self.v_proj = nn.Linear(config.hidden_size, self.hidden_kv_size, bias=config.attention_qkv_bias)
|
213 |
+
|
214 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
|
215 |
+
|
216 |
+
if self.position_embedding_type == 'rope':
|
217 |
+
self.rotary_emb = RotaryEmbedding(
|
218 |
+
dim=self.head_dim,
|
219 |
+
max_position_embeddings=config.max_position_embeddings,
|
220 |
+
base=config.get_rotary_base()
|
221 |
+
)
|
222 |
+
elif self.position_embedding_type == 'alibi':
|
223 |
+
self.alibi_slopes = get_slopes(self.num_heads)
|
224 |
+
self.attention = create_attention_layer(self.hidden_size, self.num_heads, self.head_dim)
|
225 |
+
|
226 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
227 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
228 |
+
|
229 |
+
def _repeat_kv(self, hidden_states: torch.Tensor, num_heads: int) -> torch.Tensor:
|
230 |
+
assert hidden_states.size(1) <= num_heads and num_heads % hidden_states.size(1) == 0
|
231 |
+
return repeat_kv(hidden_states, num_heads // hidden_states.size(1))
|
232 |
+
|
233 |
+
def forward(
|
234 |
+
self,
|
235 |
+
hidden_states: torch.Tensor,
|
236 |
+
attention_mask: Optional[torch.Tensor] = None,
|
237 |
+
position_ids: Optional[torch.LongTensor] = None,
|
238 |
+
seqlens: Optional[torch.IntTensor] = None,
|
239 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
240 |
+
output_attentions: bool = False,
|
241 |
+
use_cache: bool = False,
|
242 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
243 |
+
bsz, q_len = hidden_states.shape[:2]
|
244 |
+
|
245 |
+
if self.config.attention_qkv_pack:
|
246 |
+
proj = self.W_pack(hidden_states)
|
247 |
+
query_states, key_states, value_states = proj.split([self.hidden_size, self.hidden_kv_size, self.hidden_kv_size], dim=-1)
|
248 |
+
else:
|
249 |
+
query_states = self.q_proj(hidden_states)
|
250 |
+
key_states = self.k_proj(hidden_states)
|
251 |
+
value_states = self.v_proj(hidden_states)
|
252 |
+
|
253 |
+
# (B, S, hidden_size) -> (B, num_heads, S, head_size)
|
254 |
+
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
255 |
+
# (B, S, hidden_size) -> (B, num_kv_heads, S, head_size)
|
256 |
+
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
257 |
+
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
258 |
+
|
259 |
+
kv_seq_len = key_states.shape[-2]
|
260 |
+
if past_key_value is not None:
|
261 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
262 |
+
if self.position_embedding_type == 'rope':
|
263 |
+
max_position = position_ids.max().item()+1 if position_ids is not None else kv_seq_len * get_sequence_parallel_size()
|
264 |
+
cos, sin = self.rotary_emb(value_states, seq_len=max_position)
|
265 |
+
query_states, key_states = apply_rotary_pos_emb(
|
266 |
+
query_states, key_states, cos, sin,
|
267 |
+
get_sequence_parallel_chunk(position_ids)
|
268 |
+
)
|
269 |
+
|
270 |
+
if past_key_value is not None:
|
271 |
+
# reuse k, v, self_attention
|
272 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
273 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
274 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
275 |
+
|
276 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
277 |
+
key_states = self._repeat_kv(key_states, query_states.size(1))
|
278 |
+
value_states = self._repeat_kv(value_states, query_states.size(1))
|
279 |
+
|
280 |
+
if seqlens is not None:
|
281 |
+
seqlens = seqlens.to(dtype=torch.int32)
|
282 |
+
max_seqlen = (seqlens[1:] - seqlens[:-1]).max().item()
|
283 |
+
if self.position_embedding_type == 'alibi':
|
284 |
+
alibi_slopes = torch.tensor(self.alibi_slopes, dtype=torch.float32).to(query_states.device)
|
285 |
+
else:
|
286 |
+
alibi_slopes = None
|
287 |
+
attn_output = self.attention(
|
288 |
+
query_states, key_states, value_states, seqlens, seqlens,
|
289 |
+
max_seqlen, max_seqlen, causal=True, alibi_slopes=alibi_slopes, use_flash=True)
|
290 |
+
else:
|
291 |
+
attn_output = self.attention(
|
292 |
+
query_states, key_states, value_states, attn_mask=attention_mask, use_flash=False)
|
293 |
+
|
294 |
+
attn_output = attn_output.reshape(bsz, q_len, -1)
|
295 |
+
attn_output = self.o_proj(attn_output)
|
296 |
+
|
297 |
+
return attn_output, None, past_key_value
|
298 |
+
|
299 |
+
|
300 |
+
class DecoderLayer(nn.Module):
|
301 |
+
def __init__(self, config: OmniConfig, is_sparse=False):
|
302 |
+
super().__init__()
|
303 |
+
self.hidden_size = config.hidden_size
|
304 |
+
self.self_attn = Attention(config=config, is_sparse=is_sparse)
|
305 |
+
self.mlp = MLP(
|
306 |
+
hidden_size=self.hidden_size,
|
307 |
+
intermediate_size=config.intermediate_size,
|
308 |
+
hidden_act=config.hidden_act,
|
309 |
+
)
|
310 |
+
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
311 |
+
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
312 |
+
|
313 |
+
def forward(
|
314 |
+
self,
|
315 |
+
hidden_states: torch.Tensor,
|
316 |
+
attention_mask: Optional[torch.Tensor] = None,
|
317 |
+
position_ids: Optional[torch.LongTensor] = None,
|
318 |
+
seqlens: Optional[torch.IntTensor] = None,
|
319 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
320 |
+
output_attentions: Optional[bool] = False,
|
321 |
+
use_cache: Optional[bool] = False,
|
322 |
+
group_index=None,
|
323 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
324 |
+
|
325 |
+
residual = hidden_states
|
326 |
+
|
327 |
+
hidden_states = self.input_layernorm(hidden_states)
|
328 |
+
|
329 |
+
# Self Attention
|
330 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
331 |
+
hidden_states=hidden_states,
|
332 |
+
attention_mask=attention_mask,
|
333 |
+
position_ids=position_ids,
|
334 |
+
seqlens=seqlens,
|
335 |
+
past_key_value=past_key_value,
|
336 |
+
output_attentions=output_attentions,
|
337 |
+
use_cache=use_cache,
|
338 |
+
)
|
339 |
+
hidden_states = residual + hidden_states
|
340 |
+
|
341 |
+
# Fully Connected
|
342 |
+
residual = hidden_states
|
343 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
344 |
+
hidden_states = self.mlp(hidden_states)
|
345 |
+
hidden_states = residual + hidden_states
|
346 |
+
|
347 |
+
outputs = (hidden_states,)
|
348 |
+
|
349 |
+
if output_attentions:
|
350 |
+
outputs += (self_attn_weights,)
|
351 |
+
|
352 |
+
if use_cache:
|
353 |
+
outputs += (present_key_value,)
|
354 |
+
|
355 |
+
return outputs
|
356 |
+
|
357 |
+
|
358 |
+
class OmniPreTrainedModel(PreTrainedModel):
|
359 |
+
config_class = OmniConfig
|
360 |
+
base_model_prefix = "model"
|
361 |
+
supports_gradient_checkpointing = True
|
362 |
+
_no_split_modules = ["DecoderLayer"]
|
363 |
+
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
364 |
+
|
365 |
+
def _init_weights(self, module):
|
366 |
+
std = self.config.initializer_range
|
367 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d) or isinstance(module, nn.ConvTranspose1d):
|
368 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
369 |
+
if module.bias is not None:
|
370 |
+
module.bias.data.zero_()
|
371 |
+
elif isinstance(module, nn.Embedding):
|
372 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
373 |
+
if module.padding_idx is not None:
|
374 |
+
module.weight.data[module.padding_idx].zero_()
|
375 |
+
elif isinstance(module, nn.LayerNorm) or isinstance(module, nn.GroupNorm):
|
376 |
+
module.weight.data.fill_(1.0)
|
377 |
+
module.bias.data.zero_()
|
378 |
+
elif isinstance(module, RMSNorm):
|
379 |
+
module.weight.data.fill_(1.0)
|
380 |
+
|
381 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
382 |
+
if isinstance(module, OmniModel):
|
383 |
+
module.gradient_checkpointing = value
|
384 |
+
|
385 |
+
@dataclass
|
386 |
+
class OmniModelOutputWithPast(BaseModelOutputWithPast):
|
387 |
+
audio_encoder_ret: Optional[Any] = None
|
388 |
+
audio_decoder_ret: Optional[Any] = None
|
389 |
+
|
390 |
+
class OmniModel(OmniPreTrainedModel):
|
391 |
+
def __init__(self, config: OmniConfig):
|
392 |
+
super().__init__(config)
|
393 |
+
self.padding_idx = config.pad_token_id
|
394 |
+
self.vocab_size = config.vocab_size
|
395 |
+
|
396 |
+
if config.visual_config.enable:
|
397 |
+
self.visual_model = OmniVisualEncoder(config.visual_config)
|
398 |
+
self.visual_bridge_model = OmniVisualBridge(config.visual_config)
|
399 |
+
if config.video_config.enable and not config.visual_config.enable: # in case 没有visual_config而只有video_config
|
400 |
+
self.visual_model = OmniVisualEncoder(config.video_config)
|
401 |
+
self.visual_bridge_model = OmniVisualBridge(config.video_config)
|
402 |
+
|
403 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
404 |
+
self.layers = nn.ModuleList([
|
405 |
+
DecoderLayer(config, is_sparse=layer_idx in config.sparse_attention_layers)
|
406 |
+
for layer_idx in range(config.num_hidden_layers)
|
407 |
+
])
|
408 |
+
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
409 |
+
|
410 |
+
self.audio_embed_layers = nn.ModuleList([
|
411 |
+
nn.Embedding(codedim + 1, config.hidden_size)
|
412 |
+
for i, codedim in enumerate(config.audio_config.vq_config.codebook_sizes)
|
413 |
+
])
|
414 |
+
|
415 |
+
self.gradient_checkpointing = True
|
416 |
+
# Initialize weights and apply final processing
|
417 |
+
self.post_init()
|
418 |
+
|
419 |
+
def get_input_embeddings(self):
|
420 |
+
return self.embed_tokens
|
421 |
+
|
422 |
+
def set_input_embeddings(self, value):
|
423 |
+
self.embed_tokens = value
|
424 |
+
|
425 |
+
@torch.no_grad()
|
426 |
+
def get_multimodal_mask(self, input_ids, pad_token_id, special_token_list):
|
427 |
+
'''
|
428 |
+
获取任意模态的特殊mask,包含以下
|
429 |
+
1. pad mask 表示文本中图像/语音/视频模态提前留出的token位置
|
430 |
+
2. special token mask 特殊token 例如对理解模型<start> <end> 不需要next token prediction
|
431 |
+
3. embedding mask / lm_head mask 标记出特殊token在embedding中的mask
|
432 |
+
'''
|
433 |
+
pad_mask = torch.eq(input_ids, pad_token_id)
|
434 |
+
sp_mask = torch.zeros_like(input_ids, dtype=torch.bool)
|
435 |
+
lm_head_mask = torch.zeros([self.config.vocab_size, 1], dtype=torch.bool)
|
436 |
+
for sp_id in special_token_list:
|
437 |
+
sp_mask = torch.logical_or(sp_mask, torch.eq(input_ids, sp_id))
|
438 |
+
lm_head_mask[sp_id, 0] = True
|
439 |
+
return pad_mask, sp_mask, lm_head_mask
|
440 |
+
|
441 |
+
def get_multimodal_embed(
|
442 |
+
self,
|
443 |
+
input_ids,
|
444 |
+
text_embedding, # 1. self.embed_tokens(input_ids) 2. 其他模态结果
|
445 |
+
multimodal_embed,
|
446 |
+
pad_token_id,
|
447 |
+
fake_input,
|
448 |
+
group_index=None, # 某种模态的编号
|
449 |
+
):
|
450 |
+
pad_mask, sp_mask, _ = self.get_multimodal_mask(input_ids, pad_token_id, self.config.multimodal_special_token_list)
|
451 |
+
if not self.training: # 推理支持auto map 把多模态模块输出和input_ids 统一到一个device
|
452 |
+
multimodal_embed = multimodal_embed.to(input_ids.device)
|
453 |
+
if not fake_input: # 检查多模态token 和 pad mask数量一致 (不正确的截断会导致该问题)
|
454 |
+
assert pad_mask.sum() == multimodal_embed.shape[0]
|
455 |
+
else:
|
456 |
+
assert pad_mask.sum() <= 0
|
457 |
+
|
458 |
+
# 合并 当前模态embeddings 和text embeddings
|
459 |
+
input_ids = torch.where(pad_mask, torch.cumsum(pad_mask.view(-1).to(input_ids), dim=0).view(input_ids.shape)-1, input_ids)
|
460 |
+
text_embedding = (1 - pad_mask.to(text_embedding)).unsqueeze(-1) * text_embedding # pad token位置填0
|
461 |
+
multimodal_embedding = torch.embedding(multimodal_embed, input_ids * pad_mask) # 非 pad token 位置填idx=0位置结果
|
462 |
+
multimodal_embedding = pad_mask.to(multimodal_embedding).unsqueeze(-1) * multimodal_embedding # 非pad token 位置填0
|
463 |
+
final_embedding = multimodal_embedding.to(text_embedding) + text_embedding
|
464 |
+
|
465 |
+
if group_index is None:
|
466 |
+
group_index = pad_mask.to(torch.int32)
|
467 |
+
else:
|
468 |
+
current_index = torch.max(group_index) + 1
|
469 |
+
group_index += pad_mask.to(torch.int32) * current_index # 假设模态无重叠
|
470 |
+
|
471 |
+
return final_embedding, group_index
|
472 |
+
|
473 |
+
def get_visual_embed(
|
474 |
+
self,
|
475 |
+
input_ids,
|
476 |
+
text_embedding, # 1. self.embed_tokens(input_ids) 2. 其他模态结果
|
477 |
+
images = None,
|
478 |
+
patch_nums = None,
|
479 |
+
images_grid = None,
|
480 |
+
videos = None,
|
481 |
+
videos_patch_nums = None,
|
482 |
+
videos_grid = None,
|
483 |
+
group_index = None, # 某种模态的编号
|
484 |
+
):
|
485 |
+
if images is None or len(images) <= 0:
|
486 |
+
images, images_grid, patch_nums = self.visual_model.fake_input(input_ids.device)
|
487 |
+
image_fake_input = True
|
488 |
+
else:
|
489 |
+
image_fake_input = False
|
490 |
+
|
491 |
+
if videos is None or len(videos) <= 0 :
|
492 |
+
videos, videos_grid, videos_patch_nums = self.visual_model.fake_input(input_ids.device)
|
493 |
+
video_fake_input = True
|
494 |
+
else:
|
495 |
+
video_fake_input = False
|
496 |
+
|
497 |
+
visual_input = images + videos
|
498 |
+
visual_grid = images_grid + videos_grid
|
499 |
+
|
500 |
+
visual_input = torch.cat(visual_input, dim=0)
|
501 |
+
visual_grid = torch.tensor(np.array(visual_grid))
|
502 |
+
|
503 |
+
visual_embed = self.visual_model(visual_input, grid_thw=visual_grid)
|
504 |
+
visual_embed = self.visual_bridge_model(visual_embed)
|
505 |
+
|
506 |
+
assert sum(patch_nums) + sum(videos_patch_nums) == visual_embed.shape[0]
|
507 |
+
images_embed = visual_embed[:sum(patch_nums)]
|
508 |
+
videos_embed = visual_embed[sum(patch_nums):]
|
509 |
+
|
510 |
+
final_embedding, group_index = self.get_multimodal_embed(input_ids, text_embedding, images_embed, self.config.visual_config.image_pad_token_id, image_fake_input, group_index=group_index)
|
511 |
+
final_embedding, group_index = self.get_multimodal_embed(input_ids, final_embedding, videos_embed, self.config.video_config.video_place_token_id, video_fake_input, group_index=group_index)
|
512 |
+
return final_embedding, group_index
|
513 |
+
|
514 |
+
|
515 |
+
@torch.no_grad()
|
516 |
+
def audio_fake_input(self, device):
|
517 |
+
return torch.zeros(5, len(self.config.audio_config.vq_config.codebook_sizes), dtype=torch.int32, device=device)
|
518 |
+
|
519 |
+
def forward(
|
520 |
+
self,
|
521 |
+
input_ids: torch.LongTensor = None,
|
522 |
+
attention_mask: Optional[torch.Tensor] = None,
|
523 |
+
position_ids: Optional[torch.LongTensor] = None,
|
524 |
+
seqlens: Optional[torch.IntTensor] = None,
|
525 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
526 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
527 |
+
audios_tokens: Optional[List|torch.Tensor] = None, # 音频token bs*seqlen*vq_num
|
528 |
+
images: Optional[List|torch.Tensor] = None,
|
529 |
+
patch_nums: Optional[torch.Tensor] = None,
|
530 |
+
images_grid: Optional[List|torch.Tensor] = None,
|
531 |
+
videos: Optional[List|torch.Tensor] = None,
|
532 |
+
videos_patch_nums: Optional[torch.Tensor] = None,
|
533 |
+
videos_grid: Optional[List|torch.Tensor] = None,
|
534 |
+
use_cache: Optional[bool] = None,
|
535 |
+
output_attentions: Optional[bool] = None,
|
536 |
+
output_hidden_states: Optional[bool] = None,
|
537 |
+
return_dict: Optional[bool] = None,
|
538 |
+
) -> Union[Tuple, OmniModelOutputWithPast]:
|
539 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
540 |
+
output_hidden_states = (
|
541 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
542 |
+
)
|
543 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
544 |
+
return_dict = True if (return_dict is not None or self.training) else self.config.use_return_dict
|
545 |
+
|
546 |
+
# retrieve input_ids and inputs_embeds
|
547 |
+
if input_ids is not None and inputs_embeds is not None:
|
548 |
+
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
549 |
+
elif input_ids is not None:
|
550 |
+
batch_size, seq_length = input_ids.shape
|
551 |
+
elif inputs_embeds is not None:
|
552 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
553 |
+
else:
|
554 |
+
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
555 |
+
|
556 |
+
seq_length_with_past = seq_length
|
557 |
+
past_key_values_length = 0
|
558 |
+
|
559 |
+
if past_key_values is not None:
|
560 |
+
past_key_values_length = past_key_values[0][0].shape[2]
|
561 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
562 |
+
|
563 |
+
if position_ids is None:
|
564 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
565 |
+
position_ids = torch.arange(
|
566 |
+
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
567 |
+
)
|
568 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
569 |
+
else:
|
570 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
571 |
+
|
572 |
+
group_index, audio_decoder_ret = None, None
|
573 |
+
if inputs_embeds is None:
|
574 |
+
sp_input_ids = get_sequence_parallel_chunk(input_ids)
|
575 |
+
inputs_embeds = self.embed_tokens(sp_input_ids)
|
576 |
+
if audios_tokens is None or len(audios_tokens) <= 0 :
|
577 |
+
audios_tokens = torch.zeros(5, len(self.config.audio_config.vq_config.codebook_sizes), dtype=torch.int32, device=input_ids.device) # a fake input
|
578 |
+
fake_input = True
|
579 |
+
else:
|
580 |
+
fake_input = False
|
581 |
+
for i, audio_emb_layer in enumerate(self.audio_embed_layers):
|
582 |
+
if i==0:
|
583 |
+
audio_embs = audio_emb_layer(audios_tokens[..., i])
|
584 |
+
else:
|
585 |
+
audio_embs += audio_emb_layer(audios_tokens[..., i])
|
586 |
+
inputs_embeds, group_index = self.get_multimodal_embed(sp_input_ids, inputs_embeds, audio_embs, self.config.audio_config.audio_pad_token_id, fake_input, group_index=group_index)
|
587 |
+
|
588 |
+
if self.config.visual_config.enable or self.config.video_config.enable:
|
589 |
+
inputs_embeds, group_index = self.get_visual_embed(sp_input_ids, inputs_embeds, images, patch_nums, images_grid, videos, videos_patch_nums, videos_grid, group_index=group_index) # 注意更新group index
|
590 |
+
|
591 |
+
if seqlens is not None and seqlens.ndim == 2:
|
592 |
+
cu_seqlens = []
|
593 |
+
offset, seqlen = 0, seqlens.size(1)
|
594 |
+
for lens in seqlens:
|
595 |
+
cu_seqlens.append(offset)
|
596 |
+
cu_seqlens.extend((lens[(lens > 0) & (lens < seqlen)] + offset).tolist())
|
597 |
+
offset += seqlen
|
598 |
+
cu_seqlens.append(offset)
|
599 |
+
seqlens = torch.tensor(cu_seqlens, dtype=seqlens.dtype, device=seqlens.device)
|
600 |
+
elif seqlens is None and self.training:
|
601 |
+
seqlens = torch.arange(
|
602 |
+
end=input_ids.size(0) + 1,
|
603 |
+
dtype=torch.int32,
|
604 |
+
device=input_ids.device
|
605 |
+
) * input_ids.size(1)
|
606 |
+
if seqlens is not None:
|
607 |
+
attention_mask = None # unset attention_mask to save memory
|
608 |
+
|
609 |
+
if seqlens is None and attention_mask is None:
|
610 |
+
attention_mask = torch.ones(
|
611 |
+
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
612 |
+
)
|
613 |
+
if attention_mask is not None:
|
614 |
+
attention_mask = _prepare_4d_causal_attention_mask(
|
615 |
+
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
616 |
+
)
|
617 |
+
|
618 |
+
# embed positions
|
619 |
+
hidden_states = inputs_embeds
|
620 |
+
|
621 |
+
if self.gradient_checkpointing and self.training:
|
622 |
+
if use_cache:
|
623 |
+
logger.warning_once(
|
624 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
625 |
+
)
|
626 |
+
use_cache = False
|
627 |
+
|
628 |
+
# decoder layers
|
629 |
+
all_hidden_states = () if output_hidden_states else None
|
630 |
+
all_self_attns = () if output_attentions else None
|
631 |
+
next_decoder_cache = () if use_cache else None
|
632 |
+
|
633 |
+
for idx, decoder_layer in enumerate(self.layers):
|
634 |
+
if output_hidden_states:
|
635 |
+
all_hidden_states += (hidden_states,)
|
636 |
+
|
637 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
638 |
+
|
639 |
+
if self.gradient_checkpointing and self.training:
|
640 |
+
|
641 |
+
def create_custom_forward(module):
|
642 |
+
def custom_forward(*inputs):
|
643 |
+
# None for past_key_value
|
644 |
+
return module(*inputs, output_attentions, False, group_index)
|
645 |
+
|
646 |
+
return custom_forward
|
647 |
+
|
648 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
649 |
+
create_custom_forward(decoder_layer),
|
650 |
+
hidden_states,
|
651 |
+
attention_mask,
|
652 |
+
position_ids,
|
653 |
+
seqlens,
|
654 |
+
None,
|
655 |
+
)
|
656 |
+
else:
|
657 |
+
layer_outputs = decoder_layer(
|
658 |
+
hidden_states,
|
659 |
+
attention_mask=attention_mask,
|
660 |
+
position_ids=position_ids,
|
661 |
+
seqlens=seqlens,
|
662 |
+
past_key_value=past_key_value,
|
663 |
+
output_attentions=output_attentions,
|
664 |
+
use_cache=use_cache,
|
665 |
+
group_index=group_index,
|
666 |
+
)
|
667 |
+
|
668 |
+
hidden_states = layer_outputs[0]
|
669 |
+
|
670 |
+
if use_cache:
|
671 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
672 |
+
|
673 |
+
if output_attentions:
|
674 |
+
all_self_attns += (layer_outputs[1],)
|
675 |
+
|
676 |
+
hidden_states = self.norm(hidden_states)
|
677 |
+
|
678 |
+
# add hidden states from the last decoder layer
|
679 |
+
if output_hidden_states:
|
680 |
+
all_hidden_states += (hidden_states,)
|
681 |
+
|
682 |
+
next_cache = next_decoder_cache if use_cache else None
|
683 |
+
if not return_dict:
|
684 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
685 |
+
return BaseModelOutputWithPast(
|
686 |
+
last_hidden_state=hidden_states,
|
687 |
+
past_key_values=next_cache,
|
688 |
+
hidden_states=all_hidden_states,
|
689 |
+
attentions=all_self_attns,
|
690 |
+
)
|
691 |
+
|
692 |
+
|
693 |
+
class NormHead(nn.Module):
|
694 |
+
def __init__(self, hidden_size, vocab_size, bias=False):
|
695 |
+
super().__init__()
|
696 |
+
self.hidden_size = hidden_size
|
697 |
+
self.vocab_size = vocab_size
|
698 |
+
self.weight = nn.Parameter(torch.empty((self.vocab_size, self.hidden_size)))
|
699 |
+
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
700 |
+
|
701 |
+
def forward(self, hidden_states, mask=None):
|
702 |
+
norm_weight = nn.functional.normalize(self.weight)
|
703 |
+
if mask is not None:
|
704 |
+
mask = mask.to(norm_weight)
|
705 |
+
norm_weight = norm_weight * mask + (1 - mask) * norm_weight.detach()
|
706 |
+
return nn.functional.linear(hidden_states, norm_weight)
|
707 |
+
|
708 |
+
|
709 |
+
def extra_repr(self) -> str:
|
710 |
+
return f'in_features={self.hidden_size}, out_features={self.vocab_size}'
|
711 |
+
|
712 |
+
@dataclass
|
713 |
+
class OmniMMCausalLMOutputWithPast(ModelOutput):
|
714 |
+
loss: Optional[torch.FloatTensor] = None
|
715 |
+
logits: Optional[torch.FloatTensor] = None
|
716 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
717 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
718 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
719 |
+
audios_emb_for_infer: Optional[torch.FloatTensor] = None # 用于audio head 推理的 embeddings
|
720 |
+
|
721 |
+
|
722 |
+
class CasualDepthTransformerLayer(nn.Module):
|
723 |
+
def __init__(self, config, depth):
|
724 |
+
super().__init__()
|
725 |
+
self.config = config
|
726 |
+
embed_size = config.hidden_size
|
727 |
+
assert embed_size % 128 == 0
|
728 |
+
num_heads = embed_size // 128
|
729 |
+
self.self_attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads,batch_first=True)
|
730 |
+
self.layernorm1 = RMSNorm(embed_size)
|
731 |
+
self.layernorm2 = RMSNorm(embed_size)
|
732 |
+
self.linear1 = nn.Linear(embed_size * depth, 2 * embed_size)
|
733 |
+
self.linear2 = nn.Linear(2 * embed_size * depth, embed_size)
|
734 |
+
|
735 |
+
def forward(self, x):
|
736 |
+
seq_len = x.size(1)
|
737 |
+
res = x
|
738 |
+
x = self.layernorm1(x)
|
739 |
+
src_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(x.device)
|
740 |
+
_x, _ = self.self_attention(x, x, x, is_causal=True, attn_mask=src_mask)
|
741 |
+
res = _x + res # (bs, sl, d)
|
742 |
+
res = self.layernorm2(res)
|
743 |
+
x = torch.einsum('bld,tld->blt', res, torch.reshape(self.linear1.weight, (2 * self.config.hidden_size, -1, self.config.hidden_size)))
|
744 |
+
x = torch.nn.functional.gelu(x)
|
745 |
+
x = torch.einsum('blt,dlt->bld', x, torch.reshape(self.linear2.weight, (self.config.hidden_size, -1, 2 * self.config.hidden_size)))
|
746 |
+
return res + x
|
747 |
+
|
748 |
+
class OmniAudioHead(nn.Module):
|
749 |
+
def __init__(self, config):
|
750 |
+
super().__init__()
|
751 |
+
self.config = config
|
752 |
+
hidden_size = config.hidden_size
|
753 |
+
self.transformer_layers = nn.ModuleList([
|
754 |
+
CasualDepthTransformerLayer(config, len(config.audio_config.vq_config.codebook_sizes))
|
755 |
+
for _ in range(config.audio_config.audio_head_transformer_layers)
|
756 |
+
])
|
757 |
+
self.headnorm = RMSNorm(hidden_size)
|
758 |
+
self.heads = nn.ModuleList([
|
759 |
+
nn.Linear(hidden_size, vq_size+1)
|
760 |
+
for vq_size in config.audio_config.vq_config.codebook_sizes
|
761 |
+
])
|
762 |
+
self.gradient_checkpointing = True
|
763 |
+
|
764 |
+
def forward(self, x, audios_tokens, audio_emb_layers):
|
765 |
+
cumsum_audio_embed = torch.stack([
|
766 |
+
audio_emb_layers[i](audios_tokens[..., i])
|
767 |
+
for i, vq_size in enumerate(self.config.audio_config.vq_config.codebook_sizes[:-1])
|
768 |
+
], dim=1)
|
769 |
+
cumsum_audio_embed = torch.cumsum(cumsum_audio_embed, dim=1) # (bs, depth-1, d)
|
770 |
+
hidden_states = torch.concat([x.reshape(-1, 1, self.config.hidden_size), cumsum_audio_embed], dim=1) # (bs, depth, d)
|
771 |
+
assert hidden_states.size(1) == len(self.config.audio_config.vq_config.codebook_sizes)
|
772 |
+
for i, tlayer in enumerate(self.transformer_layers):
|
773 |
+
hidden_states = tlayer(hidden_states,)
|
774 |
+
hidden_states = self.headnorm(hidden_states)
|
775 |
+
logits = [head(hidden_states[:,i]) for i, head in enumerate(self.heads)]
|
776 |
+
return logits
|
777 |
+
|
778 |
+
|
779 |
+
class OmniForCausalLM(OmniPreTrainedModel):
|
780 |
+
def __init__(self, config):
|
781 |
+
super().__init__(config)
|
782 |
+
self.config = config
|
783 |
+
self.model = OmniModel(config)
|
784 |
+
self.audio_tokenizer = OmniAudioTokenizer(config)
|
785 |
+
self.audio_head = OmniAudioHead(config)
|
786 |
+
if config.use_norm_head:
|
787 |
+
self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False)
|
788 |
+
else:
|
789 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
790 |
+
# Initialize weights and apply final processing
|
791 |
+
self.post_init()
|
792 |
+
|
793 |
+
@property
|
794 |
+
def main_device(self):
|
795 |
+
return self.lm_head.weight.device
|
796 |
+
|
797 |
+
def bind_processor(self, tokenizer, **kwargs):
|
798 |
+
self.processor = OmniMMProcessor(
|
799 |
+
tokenizer=tokenizer,
|
800 |
+
config=self.config,
|
801 |
+
**kwargs,
|
802 |
+
)
|
803 |
+
return self.processor
|
804 |
+
|
805 |
+
def get_input_embeddings(self):
|
806 |
+
return self.model.embed_tokens
|
807 |
+
|
808 |
+
def set_input_embeddings(self, value):
|
809 |
+
self.model.embed_tokens = value
|
810 |
+
|
811 |
+
def get_output_embeddings(self):
|
812 |
+
return self.lm_head
|
813 |
+
|
814 |
+
def set_output_embeddings(self, new_embeddings):
|
815 |
+
self.lm_head = new_embeddings
|
816 |
+
|
817 |
+
def set_decoder(self, decoder):
|
818 |
+
self.model = decoder
|
819 |
+
|
820 |
+
def get_decoder(self):
|
821 |
+
return self.model
|
822 |
+
|
823 |
+
def forward(
|
824 |
+
self,
|
825 |
+
input_ids: torch.LongTensor = None,
|
826 |
+
attention_mask: Optional[torch.Tensor] = None,
|
827 |
+
position_ids: Optional[torch.LongTensor] = None,
|
828 |
+
seqlens: Optional[torch.IntTensor] = None,
|
829 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
830 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
831 |
+
labels: Optional[torch.LongTensor] = None,
|
832 |
+
audios: Optional[List|torch.Tensor] = None,
|
833 |
+
audios_tokens: Optional[List|torch.Tensor] = None,
|
834 |
+
encoder_length: Optional[torch.Tensor] = None,
|
835 |
+
bridge_length: Optional[torch.Tensor] = None,
|
836 |
+
images: Optional[torch.Tensor] = None,
|
837 |
+
patch_nums: Optional[torch.Tensor] = None,
|
838 |
+
images_grid: Optional[torch.Tensor] = None,
|
839 |
+
videos: Optional[torch.Tensor] = None,
|
840 |
+
videos_patch_nums: Optional[torch.Tensor] = None,
|
841 |
+
videos_grid: Optional[torch.Tensor] = None,
|
842 |
+
use_cache: Optional[bool] = None,
|
843 |
+
output_attentions: Optional[bool] = None,
|
844 |
+
output_hidden_states: Optional[bool] = None,
|
845 |
+
return_dict: Optional[bool] = None,
|
846 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
847 |
+
|
848 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
849 |
+
output_hidden_states = (
|
850 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
851 |
+
)
|
852 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
853 |
+
|
854 |
+
if audios_tokens is not None:
|
855 |
+
assert isinstance(audios_tokens, torch.Tensor)
|
856 |
+
else:
|
857 |
+
if audios is None or len(audios) == 0:
|
858 |
+
audios_tokens = None
|
859 |
+
else:
|
860 |
+
audios_tokens = self.audio_tokenizer(audios,encoder_length,bridge_length)
|
861 |
+
|
862 |
+
outputs = self.model(
|
863 |
+
input_ids=input_ids,
|
864 |
+
attention_mask=attention_mask,
|
865 |
+
position_ids=position_ids,
|
866 |
+
seqlens=seqlens,
|
867 |
+
past_key_values=past_key_values,
|
868 |
+
inputs_embeds=inputs_embeds,
|
869 |
+
audios_tokens=audios_tokens,
|
870 |
+
images=images,
|
871 |
+
patch_nums=patch_nums,
|
872 |
+
images_grid=images_grid,
|
873 |
+
videos=videos,
|
874 |
+
videos_patch_nums=videos_patch_nums,
|
875 |
+
videos_grid=videos_grid,
|
876 |
+
use_cache=use_cache,
|
877 |
+
output_attentions=output_attentions,
|
878 |
+
output_hidden_states=output_hidden_states,
|
879 |
+
return_dict=return_dict,
|
880 |
+
)
|
881 |
+
hidden_states = outputs.last_hidden_state
|
882 |
+
audios_emb_for_infer = hidden_states[:,-1,:]
|
883 |
+
logits = self.lm_head(hidden_states)
|
884 |
+
|
885 |
+
return OmniMMCausalLMOutputWithPast(
|
886 |
+
logits=logits,
|
887 |
+
past_key_values=outputs.past_key_values,
|
888 |
+
hidden_states=outputs.hidden_states,
|
889 |
+
attentions=outputs.attentions,
|
890 |
+
audios_emb_for_infer=audios_emb_for_infer
|
891 |
+
)
|
892 |
+
|
893 |
+
def prepare_inputs_for_generation(
|
894 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
895 |
+
):
|
896 |
+
if past_key_values:
|
897 |
+
input_ids = input_ids[:, past_key_values[0][0].shape[-2]:]
|
898 |
+
|
899 |
+
position_ids = kwargs.get("position_ids", None)
|
900 |
+
if attention_mask is not None and position_ids is None:
|
901 |
+
# create position_ids on the fly for batch generation
|
902 |
+
position_ids = attention_mask.long().cumsum(-1)
|
903 |
+
# position_ids.masked_fill_(attention_mask == 0, 1)
|
904 |
+
if past_key_values:
|
905 |
+
position_ids = position_ids[:, past_key_values[0][0].shape[-2]:]
|
906 |
+
|
907 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
908 |
+
if inputs_embeds is not None and past_key_values is None:
|
909 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
910 |
+
elif past_key_values is not None:
|
911 |
+
model_inputs = {"input_ids": input_ids}
|
912 |
+
else:
|
913 |
+
model_inputs = {"input_ids": input_ids,
|
914 |
+
"audios": kwargs.get("audios", None), "encoder_length": kwargs.get("encoder_length", None), "bridge_length": kwargs.get("bridge_length", None),
|
915 |
+
"audios_tokens": kwargs.get("audios_tokens", None),
|
916 |
+
"images": kwargs.get("images", None),
|
917 |
+
"videos": kwargs.get("videos", None)
|
918 |
+
}
|
919 |
+
|
920 |
+
model_inputs.update(
|
921 |
+
{
|
922 |
+
"position_ids": position_ids,
|
923 |
+
"past_key_values": past_key_values,
|
924 |
+
"use_cache": kwargs.get("use_cache"),
|
925 |
+
"attention_mask": attention_mask,
|
926 |
+
"images_grid": kwargs.get("images_grid"),
|
927 |
+
"videos_grid": kwargs.get("videos_grid"),
|
928 |
+
"patch_nums": kwargs.get("patch_nums"),
|
929 |
+
"videos_patch_nums": kwargs.get("videos_patch_nums"),
|
930 |
+
}
|
931 |
+
)
|
932 |
+
return model_inputs
|
933 |
+
|
934 |
+
@staticmethod
|
935 |
+
def _reorder_cache(past_key_values, beam_idx):
|
936 |
+
reordered_past = ()
|
937 |
+
for layer_past in past_key_values:
|
938 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
939 |
+
return reordered_past
|
940 |
+
|
941 |
+
def chat(self, tokenizer, messages: List[dict], stream=False,
|
942 |
+
generation_config: Optional[GenerationConfig]=None):
|
943 |
+
generation_config = generation_config or self.generation_config
|
944 |
+
input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens)
|
945 |
+
if stream:
|
946 |
+
streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
947 |
+
Thread(target=self.generate, kwargs=dict(
|
948 |
+
inputs=input_ids, streamer=streamer,
|
949 |
+
generation_config=generation_config,
|
950 |
+
)).start()
|
951 |
+
return streamer
|
952 |
+
else:
|
953 |
+
outputs = self.generate(input_ids, generation_config=generation_config)
|
954 |
+
response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
|
955 |
+
return response
|
956 |
+
|
957 |
+
|
958 |
+
class OmniAudioTokenizer(OmniPreTrainedModel):
|
959 |
+
"""
|
960 |
+
Construct an audio tokenizer and decoder.
|
961 |
+
"""
|
962 |
+
def __init__(self, config: OmniConfig):
|
963 |
+
super().__init__(config)
|
964 |
+
self.padding_idx = None
|
965 |
+
self.vocab_size = config.vocab_size
|
966 |
+
self.training = False
|
967 |
+
self.eval()
|
968 |
+
self.audio_model = OmniAudioEncoder(config.audio_config)
|
969 |
+
self.audio_bridge_model = OmniAudioVQBridgeTokenizer(config)
|
970 |
+
if config.vocoder_config.enable:
|
971 |
+
self.audio_decoder = OmniAudioDecoder(config)
|
972 |
+
if config.flow_matching_config.enable:
|
973 |
+
self.audio_flow_matching_decoder = OmniAudioFlowMatchingDecoder(config)
|
974 |
+
|
975 |
+
def encode(self, x, encoder_length: Optional[torch.Tensor] = None,
|
976 |
+
bridge_length: Optional[torch.Tensor] = None):
|
977 |
+
audio_emb = self.audio_model(x, encoder_length)
|
978 |
+
audios_tokens = self.audio_bridge_model(audio_emb, bridge_length)
|
979 |
+
return audios_tokens
|
980 |
+
|
981 |
+
def decode(self, audio_code_ids, bridge_length: Optional[torch.Tensor] = None):
|
982 |
+
assert self.config.vocoder_config.enable, "Vocoder is not enabled in config."
|
983 |
+
audio_emb = self.audio_bridge_model.decode(audio_code_ids)
|
984 |
+
audio_dec = self.audio_decoder(
|
985 |
+
audio_emb.to(next(self.audio_decoder.parameters())), bridge_length
|
986 |
+
)
|
987 |
+
if self.config.flow_matching_config.enable:
|
988 |
+
if self.config.flow_matching_config.use_hidden_states_before_dconv2:
|
989 |
+
hidden_states, hidden_states_length = (
|
990 |
+
self.audio_flow_matching_decoder.unpack_hidden_states(
|
991 |
+
audio_dec.hidden_states_before_dconv2,
|
992 |
+
audio_dec.output_length_before_dconv2,
|
993 |
+
)
|
994 |
+
)
|
995 |
+
audio_flow_matching_decoder_ret = self.audio_flow_matching_decoder(
|
996 |
+
hidden_states, hidden_states_length
|
997 |
+
)
|
998 |
+
|
999 |
+
else:
|
1000 |
+
audio_flow_matching_decoder_ret = self.audio_flow_matching_decoder(
|
1001 |
+
audio_dec.refined_mel, audio_dec.mel_length
|
1002 |
+
)
|
1003 |
+
return audio_flow_matching_decoder_ret
|
1004 |
+
else:
|
1005 |
+
return audio_dec
|
1006 |
+
|
1007 |
+
@torch.no_grad()
|
1008 |
+
def forward(self, audios, encoder_length: Optional[torch.Tensor] = None, bridge_length: Optional[torch.Tensor] = None):
|
1009 |
+
self.eval()
|
1010 |
+
audios_tokens = self.encode(audios, encoder_length, bridge_length)
|
1011 |
+
return audios_tokens
|
processor_omni.py
ADDED
@@ -0,0 +1,865 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import re, ujson, os, sys, fire, glob, random, time, json
|
3 |
+
import numpy as np
|
4 |
+
import io
|
5 |
+
import torch
|
6 |
+
from torch.utils.data import default_collate
|
7 |
+
import torchaudio
|
8 |
+
from typing import *
|
9 |
+
from dataclasses import dataclass, field
|
10 |
+
import transformers
|
11 |
+
from transformers.modeling_outputs import ModelOutput
|
12 |
+
from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
|
13 |
+
from functools import lru_cache
|
14 |
+
from io import BytesIO
|
15 |
+
from PIL import Image
|
16 |
+
import concurrent.futures as cf
|
17 |
+
from transformers.image_transforms import resize, center_crop, get_resize_output_image_size
|
18 |
+
from transformers.image_utils import PILImageResampling
|
19 |
+
from PIL import Image, ImageOps
|
20 |
+
from PIL import ImageFile
|
21 |
+
torch.set_num_threads(1) # 限制torch的线程数 否则可能会卡住
|
22 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
23 |
+
import base64
|
24 |
+
from decord import VideoReader, cpu
|
25 |
+
import cv2
|
26 |
+
import av
|
27 |
+
import imagesize
|
28 |
+
import tempfile
|
29 |
+
import math
|
30 |
+
from multiprocessing import Pool
|
31 |
+
from cairosvg import svg2png
|
32 |
+
import hashlib
|
33 |
+
|
34 |
+
IMAGE_FACTOR = 28
|
35 |
+
MIN_PIXELS = 4 * 28 * 28
|
36 |
+
MAX_PIXELS = 16384 * 28 * 28
|
37 |
+
MAX_RATIO = 200
|
38 |
+
|
39 |
+
VIDEO_MIN_PIXELS = 128 * 28 * 28
|
40 |
+
VIDEO_MAX_PIXELS = 768 * 28 * 28
|
41 |
+
VIDEO_TOTAL_PIXELS = 24576 * 28 * 28
|
42 |
+
FRAME_FACTOR = 2
|
43 |
+
FPS = 2.0
|
44 |
+
FPS_MIN_FRAMES = 4
|
45 |
+
FPS_MAX_FRAMES = 768
|
46 |
+
|
47 |
+
def round_by_factor(number: int, factor: int) -> int:
|
48 |
+
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
|
49 |
+
return round(number / factor) * factor
|
50 |
+
|
51 |
+
|
52 |
+
def ceil_by_factor(number: int, factor: int) -> int:
|
53 |
+
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
|
54 |
+
return math.ceil(number / factor) * factor
|
55 |
+
|
56 |
+
|
57 |
+
def floor_by_factor(number: int, factor: int) -> int:
|
58 |
+
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
59 |
+
return math.floor(number / factor) * factor
|
60 |
+
|
61 |
+
|
62 |
+
def smart_resize(
|
63 |
+
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
|
64 |
+
) -> tuple[int, int]:
|
65 |
+
"""
|
66 |
+
Rescales the image so that the following conditions are met:
|
67 |
+
|
68 |
+
1. Both dimensions (height and width) are divisible by 'factor'.
|
69 |
+
|
70 |
+
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
71 |
+
|
72 |
+
3. The aspect ratio of the image is maintained as closely as possible.
|
73 |
+
"""
|
74 |
+
if max(height, width) / min(height, width) > MAX_RATIO:
|
75 |
+
raise ValueError(
|
76 |
+
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
|
77 |
+
)
|
78 |
+
h_bar = max(factor, round_by_factor(height, factor))
|
79 |
+
w_bar = max(factor, round_by_factor(width, factor))
|
80 |
+
if h_bar * w_bar > max_pixels:
|
81 |
+
beta = math.sqrt((height * width) / max_pixels)
|
82 |
+
h_bar = floor_by_factor(height / beta, factor)
|
83 |
+
w_bar = floor_by_factor(width / beta, factor)
|
84 |
+
elif h_bar * w_bar < min_pixels:
|
85 |
+
beta = math.sqrt(min_pixels / (height * width))
|
86 |
+
h_bar = ceil_by_factor(height * beta, factor)
|
87 |
+
w_bar = ceil_by_factor(width * beta, factor)
|
88 |
+
return h_bar, w_bar
|
89 |
+
|
90 |
+
|
91 |
+
def split_text(text, match_regex):
|
92 |
+
matches = list(re.finditer(match_regex, text))
|
93 |
+
# 初始化结果列表
|
94 |
+
result = []
|
95 |
+
match_flag_list = []
|
96 |
+
# 上一个匹配的结束位置
|
97 |
+
last_end = 0
|
98 |
+
# 遍历所有匹配项
|
99 |
+
for match in matches:
|
100 |
+
# 添加匹配项之前的部分
|
101 |
+
if text[last_end:match.start()]:
|
102 |
+
result.append(text[last_end:match.start()])
|
103 |
+
match_flag_list.append(False)
|
104 |
+
# 添加匹配项
|
105 |
+
result.append(match.group(0))
|
106 |
+
match_flag_list.append(True)
|
107 |
+
# 更新上一个匹配的结束位置
|
108 |
+
last_end = match.end()
|
109 |
+
# 添加最后一个匹配项之后的部分
|
110 |
+
if text[last_end:]:
|
111 |
+
result.append(text[last_end:])
|
112 |
+
match_flag_list.append(False)
|
113 |
+
return result, match_flag_list
|
114 |
+
|
115 |
+
|
116 |
+
def read_video(image_path, max_frame_number, decode_way):
|
117 |
+
if decode_way=='1fps':
|
118 |
+
try:
|
119 |
+
# print(image_path)
|
120 |
+
vr = VideoReader(image_path, ctx=cpu(0))
|
121 |
+
total_frame_num = len(vr)
|
122 |
+
fps = round(vr.get_avg_fps())
|
123 |
+
frame_idx = [i for i in range(0, len(vr), fps)]
|
124 |
+
frames = vr.get_batch(frame_idx).asnumpy()
|
125 |
+
cnt = len(frames)
|
126 |
+
frame_times = range(cnt)
|
127 |
+
except Exception as e:
|
128 |
+
print(image_path)
|
129 |
+
print('error is', e)
|
130 |
+
return None
|
131 |
+
elif decode_way=='key':
|
132 |
+
try:
|
133 |
+
with av.open(image_path) as container:
|
134 |
+
stream = container.streams.video[0]
|
135 |
+
stream.codec_context.skip_frame = 'NONKEY'
|
136 |
+
frames = []
|
137 |
+
frame_times = []
|
138 |
+
fps = int(stream.average_rate)
|
139 |
+
cnt = 0
|
140 |
+
for frame in container.decode(stream): # 关键帧存成image patch
|
141 |
+
image = np.array(frame.to_image())
|
142 |
+
frames.append(image)
|
143 |
+
frame_time = int(frame.time)
|
144 |
+
frame_times.append(frame_time)
|
145 |
+
cnt += 1
|
146 |
+
except Exception as e:
|
147 |
+
print('error is', e)
|
148 |
+
return None
|
149 |
+
if frames is None or len(frames)==0:
|
150 |
+
return None
|
151 |
+
if len(frames)>max_frame_number and max_frame_number>0:
|
152 |
+
# 生成14个均匀间隔的索引
|
153 |
+
indices = np.linspace(0, len(frames) - 1, max_frame_number, dtype=int)
|
154 |
+
# 根据索引获取对应元素
|
155 |
+
frames = frames[indices]
|
156 |
+
frame_times = frame_times[indices]
|
157 |
+
return frames, frame_times
|
158 |
+
|
159 |
+
|
160 |
+
class OmniImageProcessor:
|
161 |
+
def __init__(self, config, **kwargs):
|
162 |
+
self.config = config # visual_config
|
163 |
+
self.min_pixels = self.config.min_pixels if hasattr(self.config, 'min_pixels') else 56 * 56
|
164 |
+
self.max_pixels = self.config.max_pixels if hasattr(self.config, 'max_pixels') else 28 * 28 * 1280
|
165 |
+
self.patch_size = self.config.patch_size if hasattr(self.config, 'patch_size') else 14
|
166 |
+
self.temporal_patch_size = self.config.temporal_patch_size if hasattr(self.config, 'temporal_patch_size') else 2
|
167 |
+
self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
|
168 |
+
self.spatial_merge_size = self.config.spatial_merge_size if hasattr(self.config, 'spatial_merge_size') else 2
|
169 |
+
|
170 |
+
def image_transform(self, strseq, return_mm_data = True):
|
171 |
+
image = None
|
172 |
+
if isinstance(strseq, str):
|
173 |
+
if return_mm_data:
|
174 |
+
image = Image.open(strseq).convert("RGB")
|
175 |
+
else:
|
176 |
+
try:
|
177 |
+
image = Image.open(BytesIO(strseq)).convert("RGB")
|
178 |
+
except:
|
179 |
+
image = Image.open(BytesIO(svg2png(bytestring=strseq))).convert("RGB") # interleaved有的是矢量图,需要转换
|
180 |
+
|
181 |
+
image = np.array(image.convert("RGB")) # 这一步首先将图像转换为 RGB 格式,确保图像有三个通道(R、G、B)。然后使用 np.array() 将其转换为 NumPy 数组,方便后续处理。
|
182 |
+
image_org_size = image.shape[:2] # 这里保存了图像的原始大小(高度和宽度),image.shape 返回图像的形状 (高度, 宽度, 通道数),而 image.shape[:2] 提取了前两个值,即原始的高度和宽度。这个信息可以用于后续的对比或其他处理。
|
183 |
+
|
184 |
+
# resize, crop, scale, normalize
|
185 |
+
# 输出一个新的尺寸,这个尺寸通常是 (宽度, 高度) 格式,用于后续的图像调整操作,如缩放或裁剪。
|
186 |
+
resized_height, resized_width = smart_resize(
|
187 |
+
image_org_size[0], image_org_size[1],
|
188 |
+
factor=self.patch_size * self.spatial_merge_size,
|
189 |
+
min_pixels=self.min_pixels,
|
190 |
+
max_pixels=self.max_pixels,
|
191 |
+
)
|
192 |
+
output_size = (resized_height, resized_width)
|
193 |
+
|
194 |
+
# 使用 resize 函数将图像调整到 output_size 大小。PILImageResampling.BICUBIC 指定使用双三次插值法来进行图像缩放,这种方法通常能够提供较好的图像质量。
|
195 |
+
# image: 输入的图像数据,可以是 NumPy 数组或 PIL 图像对象;output_size: 目标大小,通常是一个二元组 (宽度, 高度)。这个尺寸可以是图像的绝对大小,也可以是相对于原始图像的比例;
|
196 |
+
# resample: 可选的重采样方法,通常用于确定如何插值像素。例如,PILImageResampling.BICUBIC 表示使用双三次插值法,这是一种平滑的插值方法,常用于图像缩放。
|
197 |
+
image = resize(image, output_size, PILImageResampling.BICUBIC)
|
198 |
+
img = image.transpose(2, 0, 1)
|
199 |
+
# 对图像进行归一化和标准化处理
|
200 |
+
image = (img / 255.0 - np.array(self.config.image_mean)[:, np.newaxis, np.newaxis]) / np.array(self.config.image_std)[:,np.newaxis,np.newaxis]
|
201 |
+
# 处理成patch
|
202 |
+
patches = image[np.newaxis, :]
|
203 |
+
if patches.shape[0] == 1:
|
204 |
+
patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
|
205 |
+
channel = patches.shape[1]
|
206 |
+
grid_t = patches.shape[0] // self.temporal_patch_size
|
207 |
+
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
|
208 |
+
patches = patches.reshape(
|
209 |
+
grid_t,
|
210 |
+
self.temporal_patch_size,
|
211 |
+
channel,
|
212 |
+
grid_h // self.spatial_merge_size,
|
213 |
+
self.spatial_merge_size,
|
214 |
+
self.patch_size,
|
215 |
+
grid_w // self.spatial_merge_size,
|
216 |
+
self.spatial_merge_size,
|
217 |
+
self.patch_size,
|
218 |
+
)
|
219 |
+
patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
|
220 |
+
flatten_patches = patches.reshape(
|
221 |
+
grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
|
222 |
+
)
|
223 |
+
|
224 |
+
return flatten_patches, image_org_size, (grid_t, grid_h, grid_w)
|
225 |
+
|
226 |
+
|
227 |
+
class OmniAudioProcessor:
|
228 |
+
# 包含基本的音频特征抽取模块 + 输入数据解析模块
|
229 |
+
def __init__(
|
230 |
+
self,
|
231 |
+
config, # audio processor config
|
232 |
+
**kwargs
|
233 |
+
):
|
234 |
+
# make sure you have install 'conda install -c conda-forge 'ffmpeg<7'' for torchaudio
|
235 |
+
assert(len(torchaudio.list_audio_backends()) > 0)
|
236 |
+
self.config = config
|
237 |
+
self.mel_filters = mel_filter_bank(
|
238 |
+
num_frequency_bins=1 + self.config.n_fft // 2,
|
239 |
+
num_mel_filters=self.config.num_mel_bins,
|
240 |
+
min_frequency=0.0,
|
241 |
+
max_frequency=self.config.sampling_rate / 2.0,
|
242 |
+
sampling_rate=self.config.sampling_rate,
|
243 |
+
norm="slaney",
|
244 |
+
mel_scale="slaney",
|
245 |
+
)
|
246 |
+
self.window = torch.hann_window(self.config.n_fft)
|
247 |
+
|
248 |
+
@staticmethod
|
249 |
+
def dynamic_range_compression(x, C=1, clip_val=1e-6):
|
250 |
+
return torch.log(torch.clamp(x, min=clip_val) * C)
|
251 |
+
|
252 |
+
@staticmethod
|
253 |
+
def zero_mean_unit_var_norm(x):
|
254 |
+
return (x - x.mean()) / torch.sqrt(x.var() + 1e-8)
|
255 |
+
|
256 |
+
def load_audio_waveform(self, uri, return_tensors=True, do_normalize=False):
|
257 |
+
metadata = torchaudio.info(uri) # sample_rate, num_frames, num_channels, bits_per_sample, encoding=PCM_S
|
258 |
+
assert(metadata.num_channels <= 2), "acoustic file with {} channels.".format(metadata.num_channels) # whisper only accept mono channel audio
|
259 |
+
waveform_tensor, _ = torchaudio.load(uri, normalize=True)
|
260 |
+
if self.config.sampling_rate != metadata.sample_rate:
|
261 |
+
waveform_tensor = torchaudio.functional.resample(waveform_tensor, metadata.sample_rate, self.config.sampling_rate, lowpass_filter_width=128)
|
262 |
+
|
263 |
+
# downmix to mono channel https://trac.ffmpeg.org/wiki/AudioChannelManipulation
|
264 |
+
if metadata.num_channels > 1:
|
265 |
+
waveform_tensor = torch.mean(waveform_tensor, dim=0, keepdim=True)
|
266 |
+
|
267 |
+
# normalized to zero mean
|
268 |
+
if do_normalize:
|
269 |
+
waveform_tensor = self.zero_mean_unit_var_norm(waveform_tensor)
|
270 |
+
|
271 |
+
if return_tensors: # (channels, samples)
|
272 |
+
return waveform_tensor
|
273 |
+
else:
|
274 |
+
return waveform_tensor.numpy()
|
275 |
+
|
276 |
+
def split_with_overlap(self, waveform): # 如果长度超过最大长度限制 分割为带overlap的多段
|
277 |
+
channels, wave_samples = waveform.shape
|
278 |
+
max_audio_samples = self.config.max_audio_seconds * self.config.sampling_rate
|
279 |
+
if wave_samples <= max_audio_samples or self.config.split_overlap < 0:
|
280 |
+
return [waveform] # 没有超出最大长度or截断逻辑 统一返回list
|
281 |
+
|
282 |
+
split_waveform, start = [], 0
|
283 |
+
while start < wave_samples: # 统一按秒数对齐overlap
|
284 |
+
if start > int(self.config.sampling_rate * self.config.split_overlap):
|
285 |
+
start -= int(self.config.sampling_rate * self.config.split_overlap) # 0表示没有overlap,>0 overlap对应秒数
|
286 |
+
end = min(start + max_audio_samples, wave_samples)
|
287 |
+
if end - start>= self.config.n_fft: # 保证至少有一帧数据
|
288 |
+
split_waveform.append(waveform[:, start:end]) # 注意这里可能会切割出特别短的片段 需要在预处理判断并丢弃
|
289 |
+
start = end
|
290 |
+
return split_waveform
|
291 |
+
|
292 |
+
@classmethod
|
293 |
+
def inference_output_length(cls, config, input_length):
|
294 |
+
# for whisper + bridge
|
295 |
+
kernel_size = config.kernel_size
|
296 |
+
stride_size = config.stride_size
|
297 |
+
avg_pooler = config.avg_pooler
|
298 |
+
encoder_length = (input_length + 2 * (kernel_size // 2) - kernel_size) // 1 + 1 # conv layer1 with pad=1
|
299 |
+
encoder_length = (encoder_length + 2 * (kernel_size // 2) - kernel_size) // stride_size + 1 # conv layer2 with pad=1
|
300 |
+
if avg_pooler > 1:
|
301 |
+
bridge_length = encoder_length // avg_pooler
|
302 |
+
return encoder_length, bridge_length
|
303 |
+
|
304 |
+
def extract_fbank_features(self, waveform):
|
305 |
+
# ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py
|
306 |
+
channels, wave_samples = waveform.shape
|
307 |
+
assert(wave_samples >= self.config.n_fft)
|
308 |
+
valid_frame_nums = min(self.config.max_audio_seconds * self.config.sampling_rate // self.config.hop_length, wave_samples // self.config.hop_length + 1)
|
309 |
+
if wave_samples < self.config.max_audio_seconds * self.config.sampling_rate:
|
310 |
+
waveform = torch.nn.functional.pad(waveform, (0, self.config.max_audio_seconds * self.config.sampling_rate - wave_samples), "constant", 0)
|
311 |
+
else:
|
312 |
+
waveform = waveform[:, :self.config.max_audio_seconds * self.config.sampling_rate]
|
313 |
+
|
314 |
+
# window = torch.hann_window(self.config.n_fft)
|
315 |
+
stft = torch.stft(waveform, self.config.n_fft, self.config.hop_length, window=self.window, return_complex=True) # fft, len(wave) // n_fft // 2 + 1
|
316 |
+
magnitudes = stft[..., :-1].abs() ** 2
|
317 |
+
|
318 |
+
mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
|
319 |
+
mel_spec = mel_filters.T @ magnitudes
|
320 |
+
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
321 |
+
if waveform.dim() == 2:
|
322 |
+
max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
|
323 |
+
log_spec = torch.maximum(log_spec, max_val - 8.0)
|
324 |
+
else:
|
325 |
+
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
326 |
+
log_spec = (log_spec + 4.0) / 4.0
|
327 |
+
|
328 |
+
log_spec = log_spec[0].numpy() # (channel, filters, samples) -> (filters, samples)
|
329 |
+
log_spec[:, valid_frame_nums:] = 0.0 # pad0
|
330 |
+
|
331 |
+
return log_spec, valid_frame_nums
|
332 |
+
|
333 |
+
def data_augment(self, feature: np.array, input_length, training=True):
|
334 |
+
# reference https://arxiv.org/pdf/1904.08779
|
335 |
+
def mask_start_indices(input_length, mask_length, min_masks, mask_prob):
|
336 |
+
num_masked_span = int(mask_prob * input_length / mask_length + random.random())
|
337 |
+
num_masked_span = max(num_masked_span, min_masks)
|
338 |
+
start_indices = list(range(input_length - mask_length))
|
339 |
+
random.shuffle(start_indices)
|
340 |
+
start_indices = start_indices[:num_masked_span]
|
341 |
+
return start_indices
|
342 |
+
|
343 |
+
if not training or (self.config.mask_time_prob <= 0 and self.config.mask_feature_prob <= 0):
|
344 |
+
return feature
|
345 |
+
if input_length < self.config.mask_time_length * self.config.mask_time_min_masks + 1:
|
346 |
+
return feature
|
347 |
+
if self.config.num_mel_bins < self.config.mask_feature_length * self.config.mask_feature_min_masks + 1:
|
348 |
+
return feature
|
349 |
+
|
350 |
+
if self.config.mask_time_prob > 0:
|
351 |
+
start_indices = mask_start_indices(input_length, self.config.mask_time_length, self.config.mask_time_min_masks, self.config.mask_time_prob)
|
352 |
+
for start_idx in start_indices:
|
353 |
+
feature[:, start_idx: start_idx + self.config.mask_time_length] = 0.0
|
354 |
+
if self.config.mask_feature_prob > 0:
|
355 |
+
start_indices = mask_start_indices(self.config.num_mel_bins, self.config.mask_feature_length, self.config.mask_feature_min_masks, self.config.mask_feature_prob)
|
356 |
+
for start_idx in start_indices:
|
357 |
+
feature[start_idx: start_idx + self.config.mask_feature_length, :] = 0.0
|
358 |
+
|
359 |
+
return feature
|
360 |
+
|
361 |
+
@dataclass
|
362 |
+
class OmniProcessorOutput(ModelOutput):
|
363 |
+
input_ids: Optional["List|torch.Tensor"] = None
|
364 |
+
labels: Optional["List|torch.Tensor"] = None
|
365 |
+
attention_mask: Optional["List|torch.Tensor"] = None
|
366 |
+
position_ids: Optional["List|torch.Tensor"] = None
|
367 |
+
seqlens: Optional["List|torch.Tensor"] = None # 需要配合Omni Modeling使用
|
368 |
+
# audio fields
|
369 |
+
audios: Optional["List|torch.Tensor"] = None
|
370 |
+
encoder_length: Optional["List|torch.Tensor"] = None
|
371 |
+
bridge_length: Optional["List|torch.Tensor"] = None
|
372 |
+
# image fields
|
373 |
+
images: Optional["List|torch.Tensor"] = None
|
374 |
+
patch_nums: Optional["List|torch.Tensor"] = None
|
375 |
+
images_size: Optional["List|torch.Tensor"] = None
|
376 |
+
crop_size: Optional["List|torch.Tensor"] = None
|
377 |
+
images_grid: Optional["List|torch.Tensor"] = None
|
378 |
+
# video fields
|
379 |
+
videos: Optional["List|torch.Tensor"] = None
|
380 |
+
videos_patch_nums: Optional["List|torch.Tensor"] = None
|
381 |
+
videos_size: Optional["List|torch.Tensor"] = None
|
382 |
+
videos_crop_size: Optional["List|torch.Tensor"] = None
|
383 |
+
videos_grid: Optional["List|torch.Tensor"] = None
|
384 |
+
# processor fields
|
385 |
+
raw_text: Optional[str] = None
|
386 |
+
index: Optional[int] = None
|
387 |
+
|
388 |
+
def concatenate(self, other): # 仅限list使用
|
389 |
+
def concat_one(a, b):
|
390 |
+
if a is None and b is None:
|
391 |
+
return None
|
392 |
+
elif a is None and b is not None:
|
393 |
+
return b
|
394 |
+
elif a is not None and b is None:
|
395 |
+
return a
|
396 |
+
else:
|
397 |
+
return a + b
|
398 |
+
return OmniProcessorOutput(
|
399 |
+
input_ids=concat_one(self.input_ids, other.input_ids),
|
400 |
+
labels=concat_one(self.labels, other.labels),
|
401 |
+
audios=concat_one(self.audios, other.audios),
|
402 |
+
encoder_length=concat_one(self.encoder_length, other.encoder_length),
|
403 |
+
bridge_length=concat_one(self.bridge_length, other.bridge_length),
|
404 |
+
images=concat_one(self.images, other.images),
|
405 |
+
images_grid=concat_one(self.images_grid, other.images_grid),
|
406 |
+
patch_nums=concat_one(self.patch_nums, other.patch_nums),
|
407 |
+
|
408 |
+
videos=concat_one(self.videos, other.videos),
|
409 |
+
videos_grid=concat_one(self.videos_grid, other.videos_grid),
|
410 |
+
videos_patch_nums=concat_one(self.videos_patch_nums, other.videos_patch_nums),
|
411 |
+
|
412 |
+
position_ids=concat_one(self.position_ids, other.position_ids),
|
413 |
+
seqlens=concat_one(self.seqlens, other.seqlens),
|
414 |
+
images_size=concat_one(self.images_size, other.images_size),
|
415 |
+
videos_size=concat_one(self.videos_size, other.videos_size),
|
416 |
+
index = self.index # concat保持index不变
|
417 |
+
)
|
418 |
+
|
419 |
+
class OmniMMProcessor(object):
|
420 |
+
def __init__(self,
|
421 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
422 |
+
config,
|
423 |
+
training,
|
424 |
+
relative_path=None,
|
425 |
+
parallel=None,
|
426 |
+
**kwargs,
|
427 |
+
):
|
428 |
+
self.tokenizer = tokenizer
|
429 |
+
self.config = config
|
430 |
+
self.audio_processor = OmniAudioProcessor(config.audio_config)
|
431 |
+
self.visual_processor = None
|
432 |
+
if hasattr(config, "visual_config"):
|
433 |
+
self.visual_processor = OmniImageProcessor(config.visual_config)
|
434 |
+
self.video_processor = None
|
435 |
+
if hasattr(config, "video_config"):
|
436 |
+
self.video_processor = OmniImageProcessor(config.video_config)
|
437 |
+
self.training = training
|
438 |
+
self.relative_path = relative_path
|
439 |
+
self.parallel = parallel
|
440 |
+
# audio tag
|
441 |
+
self.audio_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_start_token_id)
|
442 |
+
self.audio_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_end_token_id)
|
443 |
+
self.audio_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_pad_token_id)
|
444 |
+
self.audio_delim_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_delim_token_id)
|
445 |
+
self.audiogen_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audiogen_start_token_id)
|
446 |
+
self.audiogen_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audiogen_end_token_id)
|
447 |
+
# image tag
|
448 |
+
self.image_start_tag = None
|
449 |
+
self.image_end_tag = None
|
450 |
+
self.image_pad_tag = None
|
451 |
+
self.video_start_tag = None
|
452 |
+
self.video_end_tag = None
|
453 |
+
# videoframe tag只是为了兼容图片帧作为输入的情况,没有token id,在抽取视频帧的时候,会将这个替换成image tag的start、end
|
454 |
+
self.videoframe_start_tag = '<videoframe_start_omni>'
|
455 |
+
self.videoframe_end_tag = '<videoframe_end_omni>'
|
456 |
+
if hasattr(self.config, "visual_config"):
|
457 |
+
# special token for start_tag
|
458 |
+
self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_start_token_id)
|
459 |
+
# special token for end_tag
|
460 |
+
self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_end_token_id)
|
461 |
+
# special token for pad_tag
|
462 |
+
self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_pad_token_id)
|
463 |
+
self.image_line_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_line_token_id)
|
464 |
+
self.image_delimiter_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_delimiter_token_id)
|
465 |
+
if hasattr(self.config, "video_config"):
|
466 |
+
self.video_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_start_token_id)
|
467 |
+
self.video_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_end_token_id)
|
468 |
+
self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_start_token_id)
|
469 |
+
self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_end_token_id)
|
470 |
+
self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_pad_token_id)
|
471 |
+
self.video_place_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_place_token_id)
|
472 |
+
|
473 |
+
self.frame_pattern = getattr(self.config.video_config, 'frame_pattern', '<frame>')
|
474 |
+
|
475 |
+
|
476 |
+
# @lru_cache(maxsize=1024)
|
477 |
+
def _get_audio(self, audio_info):
|
478 |
+
try:
|
479 |
+
audio_info = ujson.loads(audio_info)
|
480 |
+
if 'path' in audio_info.keys():
|
481 |
+
audio_uri = None
|
482 |
+
if os.path.exists(audio_info['path']):
|
483 |
+
audio_uri = audio_info['path']
|
484 |
+
elif self.relative_path is not None:
|
485 |
+
audio_uri = os.path.join(self.relative_path, audio_info['path'].lstrip('/'))
|
486 |
+
if not os.path.exists(audio_uri):
|
487 |
+
audio_uri = None
|
488 |
+
if audio_uri is not None:
|
489 |
+
waveform = self.audio_processor.load_audio_waveform(audio_uri, True)
|
490 |
+
waveforms = self.audio_processor.split_with_overlap(waveform)
|
491 |
+
|
492 |
+
ret = OmniProcessorOutput() # 默认初始化 audios字段为None
|
493 |
+
for i, waveform in enumerate(waveforms): #(zip(waveforms,vocoder_waveforms)):
|
494 |
+
audio, input_length = self.audio_processor.extract_fbank_features(waveform)
|
495 |
+
audio = self.audio_processor.data_augment(audio, input_length, self.training)
|
496 |
+
encoder_length, bridge_length = self.audio_processor.inference_output_length(self.config.audio_config, input_length)
|
497 |
+
if bridge_length <= 0:
|
498 |
+
continue
|
499 |
+
current_ret = OmniProcessorOutput(
|
500 |
+
audios=[audio[:,:input_length]],
|
501 |
+
encoder_length=[encoder_length],
|
502 |
+
bridge_length=[bridge_length],
|
503 |
+
)
|
504 |
+
if ret.audios is None:
|
505 |
+
ret = current_ret
|
506 |
+
else:
|
507 |
+
ret = ret.concatenate(current_ret) # 拼接多个切片
|
508 |
+
return ret
|
509 |
+
else:
|
510 |
+
raise ValueError("can not find path in audio_info")
|
511 |
+
except Exception as e:
|
512 |
+
print("**** get audio error: {}, info: {} *****".format(str(e), str(audio_info)))
|
513 |
+
return OmniProcessorOutput()
|
514 |
+
|
515 |
+
# @lru_cache(maxsize=1024)
|
516 |
+
def _get_image(self, image_info):
|
517 |
+
try:
|
518 |
+
try:
|
519 |
+
image_info = ujson.loads(image_info)
|
520 |
+
except:
|
521 |
+
image_info = re.sub(r"(?<!\\)'", '"', image_info)
|
522 |
+
image_info = ujson.loads(image_info)
|
523 |
+
if 'base64' in image_info.keys():
|
524 |
+
image_data = base64.b64decode(image_info['base64'])
|
525 |
+
image_feat, org_size, image_list = self.visual_processor.image_transform(image_data)
|
526 |
+
elif 'local' in image_info.keys():
|
527 |
+
image_feat, org_size, image_list = self.visual_processor.image_transform(image_info['local'])
|
528 |
+
elif 'path' in image_info.keys() and os.path.exists(image_info['path']):
|
529 |
+
image_feat, org_size, image_list = self.visual_processor.image_transform(image_info['path'])
|
530 |
+
elif 'url' in image_info.keys():
|
531 |
+
image_bytes = self._get_vision_obj_byte('url', image_info['url'])
|
532 |
+
image_feat, org_size, image_list = self.visual_processor.image_transform(image_bytes)
|
533 |
+
else:
|
534 |
+
raise ValueError("can not find any path in image_info")
|
535 |
+
|
536 |
+
merge_length = self.visual_processor.merge_size**2
|
537 |
+
patch_nums = np.array(image_list).prod() // merge_length
|
538 |
+
|
539 |
+
if org_size[0] * org_size[1] > 16**2: # 极端小的图过滤
|
540 |
+
return OmniProcessorOutput(
|
541 |
+
images=[image_feat],
|
542 |
+
patch_nums=[patch_nums],
|
543 |
+
crop_size=[image_list],
|
544 |
+
images_size= [org_size],
|
545 |
+
images_grid=[image_list]
|
546 |
+
)
|
547 |
+
else:
|
548 |
+
print("**** image too small: {}, info: {} *****".format(str(org_size), str(image_info)))
|
549 |
+
return OmniProcessorOutput()
|
550 |
+
|
551 |
+
except Exception as e:
|
552 |
+
print("**** get image error: {}, info: {} *****".format(str(e), str(image_info)))
|
553 |
+
return OmniProcessorOutput()
|
554 |
+
|
555 |
+
# @lru_cache(maxsize=1024)
|
556 |
+
def _get_video_frame(self, video_frame_infos):
|
557 |
+
try:
|
558 |
+
pattern = r'\{.*?\}'
|
559 |
+
matches = re.findall(pattern, video_frame_infos)
|
560 |
+
ret = OmniProcessorOutput()
|
561 |
+
# 逐个解析
|
562 |
+
for match in matches:
|
563 |
+
video_frame_info = ujson.loads(match)
|
564 |
+
# video_frame_info = ujson.loads(video_frame_info)
|
565 |
+
if 'local' in video_frame_info.keys():
|
566 |
+
image_feat, org_size, image_list = self.video_processor.image_transform(video_frame_info['local'])
|
567 |
+
elif 'path' in video_frame_info.keys() and os.path.exists(video_frame_info['path']):
|
568 |
+
image_feat, org_size, image_list = self.video_processor.image_transform(video_frame_info['path'])
|
569 |
+
else:
|
570 |
+
raise ValueError("can not find any path in video_info")
|
571 |
+
|
572 |
+
merge_length = self.video_processor.merge_size**2
|
573 |
+
patch_nums = np.array(image_list).prod() // merge_length
|
574 |
+
|
575 |
+
if org_size[0] * org_size[1] > 16**2: # 极端小的图过滤
|
576 |
+
ret = ret.concatenate(
|
577 |
+
OmniProcessorOutput(
|
578 |
+
videos=[image_feat],
|
579 |
+
videos_patch_nums=[patch_nums],
|
580 |
+
videos_crop_size=[image_list],
|
581 |
+
videos_size= [org_size],
|
582 |
+
videos_grid=[image_list]
|
583 |
+
)
|
584 |
+
)
|
585 |
+
else:
|
586 |
+
print("**** video too small: {}, info: {} *****".format(str(org_size), str(video_frame_info)))
|
587 |
+
return ret
|
588 |
+
|
589 |
+
except Exception as e:
|
590 |
+
print("**** get video error: {}, info: {} *****".format(str(e), str(video_frame_info)))
|
591 |
+
return OmniProcessorOutput()
|
592 |
+
|
593 |
+
# 读取视频
|
594 |
+
def _get_vision_obj_byte(self, source, path):
|
595 |
+
vision_obj_byte = None
|
596 |
+
if source == "local":
|
597 |
+
if os.path.exists(path):
|
598 |
+
vision_obj_byte = open(path, "rb").read()
|
599 |
+
else:
|
600 |
+
vision_obj_byte = None
|
601 |
+
if source == "base64":
|
602 |
+
vision_obj_byte = base64.b64decode(path)
|
603 |
+
if source == "url":
|
604 |
+
vision_obj_byte = requests.get(url=path).content
|
605 |
+
return vision_obj_byte
|
606 |
+
|
607 |
+
# 将视频切分为帧,保存至子目录中
|
608 |
+
def _split_video_to_frames(self, video_info, max_frame_number=-1, decode_way="1fps"):
|
609 |
+
if decode_way=='1fps':
|
610 |
+
frame_suffix = f'_frames'
|
611 |
+
elif decode_way=='key':
|
612 |
+
frame_suffix = f'_keyframes'
|
613 |
+
else:
|
614 |
+
raise ValueError('unvalid decode way!!!')
|
615 |
+
|
616 |
+
server = "local"
|
617 |
+
if 'local' in video_info.keys():
|
618 |
+
# 本地路径
|
619 |
+
video_path = video_info['local']
|
620 |
+
# 帧保存本地路径
|
621 |
+
frame_path = video_path[:video_path.rfind('.')] + frame_suffix
|
622 |
+
mm_obj_byte = self._get_vision_obj_byte('local', video_path)
|
623 |
+
elif 'base64' in video_info.keys():
|
624 |
+
md5 = hashlib.md5(video_info['base64'].encode('utf-8')).hexdigest()
|
625 |
+
if self.relative_path is not None:
|
626 |
+
video_path = os.path.join(self.relative_path, md5)
|
627 |
+
else:
|
628 |
+
video_path = os.path.join(os.getcwd(), md5)
|
629 |
+
frame_path = video_path + frame_suffix
|
630 |
+
mm_obj_byte = self._get_vision_obj_byte('base64', video_info['base64'])
|
631 |
+
elif 'url' in video_info.keys():
|
632 |
+
md5 = hashlib.md5(video_info['url'].encode('utf-8')).hexdigest()
|
633 |
+
if self.relative_path is not None:
|
634 |
+
video_path = os.path.join(self.relative_path, md5)
|
635 |
+
else:
|
636 |
+
video_path = os.path.join(os.getcwd(), md5)
|
637 |
+
frame_path = video_path + frame_suffix
|
638 |
+
mm_obj_byte = self._get_vision_obj_byte('url', video_info['url'])
|
639 |
+
else:
|
640 |
+
raise ValueError('unvalid video server !!!')
|
641 |
+
return ""
|
642 |
+
|
643 |
+
if mm_obj_byte is None: # 未读取到视频文件
|
644 |
+
return ""
|
645 |
+
if not os.path.exists(frame_path) or len(os.listdir(frame_path))==0:
|
646 |
+
# 保存帧
|
647 |
+
os.makedirs(frame_path, exist_ok=True)
|
648 |
+
frames, frame_times = read_video(io.BytesIO(mm_obj_byte), max_frame_number=-1, decode_way=decode_way) #读取全部帧
|
649 |
+
for frame_idx, frame in enumerate(frames):
|
650 |
+
output_filename = os.path.join(frame_path, f"{frame_times[frame_idx]}.jpg")
|
651 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
652 |
+
cv2.imwrite(output_filename, frame)
|
653 |
+
frame_paths = os.listdir(frame_path)
|
654 |
+
|
655 |
+
# 选取帧
|
656 |
+
frame_times = [int(filename.split('/')[-1].replace('.jpg', '')) for filename in frame_paths if filename.endswith('.jpg')] # 文件名对应秒数
|
657 |
+
frame_times.sort() #从小到大排序
|
658 |
+
frame_number = len(frame_times)
|
659 |
+
if frame_number > max_frame_number:
|
660 |
+
indices = np.linspace(0, frame_number - 1, max_frame_number, dtype=int)
|
661 |
+
else:
|
662 |
+
indices = np.linspace(0, frame_number - 1, frame_number, dtype=int)
|
663 |
+
# 拼接模式
|
664 |
+
replace_str = ""
|
665 |
+
for frame_idx, idx in enumerate(indices):
|
666 |
+
frame_time = frame_times[idx] # frame_time表示帧对应的时间 单位为s 同时也是存储的文件名
|
667 |
+
frame_dict = {"local": os.path.join(frame_path, f'{frame_time}.jpg')}
|
668 |
+
frame_str = self.frame_pattern.format(frame_idx) if '{}' in self.frame_pattern else self.frame_pattern # {}对应的是第几张图片
|
669 |
+
frame_str = frame_str.replace('<TIMEIDX>', str(frame_time)) # TIMEIDX对应的是第几秒
|
670 |
+
frame_str = frame_str.replace('<TIMESTAMP>', time.strftime("%H:%M:%S", time.gmtime(frame_time))) # TIMESTAMP对应的是时间戳
|
671 |
+
frame_str = frame_str.replace('<frame>', f'{self.image_start_tag}{json.dumps(frame_dict)}{self.image_end_tag}')
|
672 |
+
replace_str += frame_str
|
673 |
+
|
674 |
+
return replace_str
|
675 |
+
|
676 |
+
def sample_frame(self,frames_str,max_frame = 32):
|
677 |
+
def uniform_sample(lst, num_samples):
|
678 |
+
if num_samples > len(lst):
|
679 |
+
return lst
|
680 |
+
interval = len(lst) / num_samples
|
681 |
+
samples = [lst[int(i * interval)] for i in range(num_samples)]
|
682 |
+
return samples
|
683 |
+
p = rf'({self.image_start_tag}.*?{self.image_end_tag})'
|
684 |
+
frames_str_split = re.split(p,frames_str)
|
685 |
+
frame_idxs = [idx for idx in range(len(frames_str_split)) if self.image_start_tag in frames_str_split[idx]]
|
686 |
+
sample_frame_idxs = set(uniform_sample(frame_idxs, max_frame))
|
687 |
+
return ''.join([item for idx,item in enumerate(frames_str_split) if idx in sample_frame_idxs or self.image_start_tag not in frames_str_split[idx]])
|
688 |
+
|
689 |
+
def _get_video_frame_str(self, video_info):
|
690 |
+
try:
|
691 |
+
if self.videoframe_start_tag in video_info:#如果是以视频帧的形式表示一个视频,则替换成image tag
|
692 |
+
frames_str = video_info
|
693 |
+
frames_str = frames_str.replace(self.videoframe_start_tag,self.image_start_tag).replace(self.videoframe_end_tag,self.image_end_tag)
|
694 |
+
return self.sample_frame(frames_str, max_frame = self.config.video_config.max_frame_num)
|
695 |
+
video_info = ujson.loads(video_info)
|
696 |
+
# 获取包含多���图像路径的字符串,最大帧数量max_frame_number
|
697 |
+
frames_str = self._split_video_to_frames(video_info, max_frame_number=self.config.video_config.max_frame_num, decode_way=self.config.video_config.decode_way)
|
698 |
+
return frames_str
|
699 |
+
except Exception as e:
|
700 |
+
print("**** get video error: {}, info: {} *****".format(str(e), str(video_info)))
|
701 |
+
return ""
|
702 |
+
|
703 |
+
def _replace_image(self, image_text):
|
704 |
+
image_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', image_text)
|
705 |
+
ret = self._get_image(image_info) # 重复取结果 cached result
|
706 |
+
if ret.patch_nums is None:
|
707 |
+
return ''
|
708 |
+
return ret, self.image_start_tag + self.image_pad_tag * ret.patch_nums[0] + self.image_end_tag
|
709 |
+
|
710 |
+
def _replace_video_frame(self, video_frame_text):
|
711 |
+
video_frame_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', video_frame_text)
|
712 |
+
ret = self._get_video_frame(video_frame_info) # 重复取结果 cached result
|
713 |
+
if ret.videos_patch_nums is None:
|
714 |
+
return ''
|
715 |
+
video_frame_str = [self.image_start_tag + self.video_place_tag * ret.videos_patch_nums[i] + self.image_end_tag for i in range(len(ret.videos_patch_nums))]
|
716 |
+
return ret, ''.join(video_frame_str)
|
717 |
+
|
718 |
+
|
719 |
+
def split_multimodal_chunk(self, text_list, mm_label_list, trainable_list, mtype='audio'):
|
720 |
+
# 抽取text中的json格式音频/图像信息,读取并转化为特征,同时估计encoder token数,填入对应数量的pad token
|
721 |
+
if (self.audio_start_tag != None) and (mtype == 'audio'):
|
722 |
+
match_regex = re.compile(self.audio_start_tag + '.*?' + self.audio_end_tag,re.S)
|
723 |
+
drop_regex = re.compile(self.audio_start_tag + "|" + self.audio_end_tag,re.S)
|
724 |
+
elif (self.image_start_tag != None) and (mtype == 'image'):
|
725 |
+
match_regex = re.compile(self.image_start_tag + '.*?' + self.image_end_tag,re.S)
|
726 |
+
drop_regex = re.compile(self.image_start_tag + "|" + self.image_end_tag,re.S)
|
727 |
+
elif (self.audiogen_start_tag != None) and (mtype == 'audiogen'):
|
728 |
+
match_regex = re.compile(self.audiogen_start_tag + '.*?' + self.audiogen_end_tag,re.S)
|
729 |
+
drop_regex = re.compile(self.audiogen_start_tag + "|" + self.audiogen_end_tag,re.S)
|
730 |
+
elif (self.video_start_tag != None) and (mtype == 'video'):
|
731 |
+
match_regex = re.compile(self.video_start_tag + '.*?' + self.video_end_tag,re.S)
|
732 |
+
drop_regex = re.compile(self.video_start_tag + "|" + self.video_end_tag,re.S)
|
733 |
+
else:
|
734 |
+
raise ValueError("mtype not supportted!")
|
735 |
+
new_text_list = []
|
736 |
+
new_mm_label_list = []
|
737 |
+
new_trainable_flag_list = []
|
738 |
+
for text,mm_label,trainable in zip(text_list,mm_label_list,trainable_list):
|
739 |
+
for t,m in zip(*split_text(text, match_regex)):
|
740 |
+
new_trainable_flag_list.append(trainable)
|
741 |
+
if m:
|
742 |
+
new_text_list.append(re.sub(drop_regex, '', t))
|
743 |
+
new_mm_label_list.append(mtype)
|
744 |
+
else:
|
745 |
+
new_text_list.append(t)
|
746 |
+
new_mm_label_list.append(mm_label)
|
747 |
+
return new_text_list, new_mm_label_list, new_trainable_flag_list
|
748 |
+
|
749 |
+
def process_multimodal_chunk(self, text, mm_label, trainable):
|
750 |
+
ret = OmniProcessorOutput()
|
751 |
+
if mm_label == 'audio':
|
752 |
+
ret = self._get_audio(text)
|
753 |
+
if ret.bridge_length is not None:
|
754 |
+
ret.input_ids = self.tokenizer.encode(self.audio_start_tag,add_special_tokens=False) + self.tokenizer.encode(self.audio_pad_tag,add_special_tokens=False) * sum(ret.bridge_length) + self.tokenizer.encode(self.audio_end_tag,add_special_tokens=False)
|
755 |
+
else:
|
756 |
+
raise ValueError(f"Get audio data Failed at Process audio chunk {text}")
|
757 |
+
elif mm_label == 'audiogen':
|
758 |
+
ret = self._get_audio(text)
|
759 |
+
if ret.bridge_length is not None:
|
760 |
+
ret.input_ids = self.tokenizer.encode(self.audiogen_start_tag,add_special_tokens=False) + self.tokenizer.encode(self.audio_pad_tag,add_special_tokens=False) * sum(ret.bridge_length) + self.tokenizer.encode(self.audiogen_end_tag,add_special_tokens=False)
|
761 |
+
else:
|
762 |
+
raise ValueError(f"Get audio data Failed at Process audio chunk {text}")
|
763 |
+
elif mm_label == 'image':
|
764 |
+
ret, input_str = self._replace_image(text)
|
765 |
+
if input_str:
|
766 |
+
ret.input_ids = self.tokenizer.encode(input_str, add_special_tokens=False)
|
767 |
+
else:
|
768 |
+
raise ValueError("Get image data Failed at Process image chunk")
|
769 |
+
elif mm_label == 'video':
|
770 |
+
frame_str = self.video_start_tag+self._get_video_frame_str(text)+self.video_end_tag
|
771 |
+
ret, input_str = self._replace_video_frame(frame_str)
|
772 |
+
if input_str:
|
773 |
+
ret.input_ids = self.tokenizer.encode(input_str, add_special_tokens=False)
|
774 |
+
else:
|
775 |
+
raise ValueError("Get video data Failed at Process video chunk")
|
776 |
+
elif mm_label == 'text':
|
777 |
+
ret.input_ids = self.tokenizer.encode(text, add_special_tokens=False)
|
778 |
+
if len(ret.input_ids) > self.tokenizer.model_max_length-1: # 过滤长文本
|
779 |
+
raise ValueError(f"Text too long, please check text length! 【{text[:5]+'...'*6+text[-5:]}】")
|
780 |
+
else:
|
781 |
+
raise ValueError(f"mm_label not supportted! must in ['audio', 'audiogen', 'image', 'video', 'text'] but get {mm_label}")
|
782 |
+
return ret
|
783 |
+
|
784 |
+
def process_one(self, text, index=0, raw_only=False):
|
785 |
+
ret = OmniProcessorOutput(index=index)
|
786 |
+
all_text_list = []
|
787 |
+
all_mm_label_list = []
|
788 |
+
all_trainable_flag_list = []
|
789 |
+
text_list, match_flag = split_text(text, re.compile("<trainable_start>.*?<trainable_end>",re.S))
|
790 |
+
if len(text_list) == 1:
|
791 |
+
text = re.sub(re.compile("<trainable_start>|<trainable_end>",re.S), '', text_list[0])
|
792 |
+
all_text_list.append(text)
|
793 |
+
all_mm_label_list.append('text')
|
794 |
+
all_trainable_flag_list.append(True)
|
795 |
+
else:
|
796 |
+
for text, match in zip(text_list, match_flag):
|
797 |
+
text = re.sub(re.compile("<trainable_start>|<trainable_end>",re.S), '', text)
|
798 |
+
if text.strip() == '':
|
799 |
+
continue # 把多余的空格干掉
|
800 |
+
all_text_list.append(text)
|
801 |
+
all_mm_label_list.append('text')
|
802 |
+
all_trainable_flag_list.append(match)
|
803 |
+
# 处理多模态信息
|
804 |
+
for mtype in self.config.multimodal: # 循环获取音频 图像结果
|
805 |
+
all_text_list, all_mm_label_list, all_trainable_flag_list = self.split_multimodal_chunk(all_text_list, all_mm_label_list, all_trainable_flag_list, mtype)
|
806 |
+
if len(all_text_list) == 0:
|
807 |
+
print(f"Process {text} chunk error: No valid Text data!!!!!")
|
808 |
+
return OmniProcessorOutput(index=index)
|
809 |
+
|
810 |
+
for text, mm_label, trainable in zip(all_text_list, all_mm_label_list, all_trainable_flag_list):
|
811 |
+
try:
|
812 |
+
mret = self.process_multimodal_chunk(text, mm_label, trainable)
|
813 |
+
ret = ret.concatenate(mret)
|
814 |
+
except ValueError as e:
|
815 |
+
tt = text[:24].replace('\n','<LF>')
|
816 |
+
print(f"Process {tt if mm_label == 'text' else text} {mm_label} chunk error: {str(e)}")
|
817 |
+
return OmniProcessorOutput(index=index)
|
818 |
+
|
819 |
+
if raw_only:
|
820 |
+
ret.raw_text = self.tokenizer.decode(ret.input_ids, skip_special_tokens=False)
|
821 |
+
return ret
|
822 |
+
return ret
|
823 |
+
|
824 |
+
@torch.no_grad()
|
825 |
+
def __call__(self, example, parallel=128):
|
826 |
+
if isinstance(example, Dict):
|
827 |
+
pass
|
828 |
+
elif isinstance(example, str):
|
829 |
+
return self.process_one(example)
|
830 |
+
elif isinstance(example, List): # batch推理 异步多线程处理
|
831 |
+
with cf.ThreadPoolExecutor(min(parallel, len(example))) as executor:
|
832 |
+
future_list = [executor.submit(self.process_one, di, idx) for idx, di in enumerate(example)]
|
833 |
+
batch_data = [key.result() for key in cf.as_completed(future_list)]
|
834 |
+
valid_num = sum([1 if x.input_ids is not None else 0 for x in batch_data])
|
835 |
+
assert(valid_num == len(batch_data)) # 推理数据严格要求数量对齐
|
836 |
+
batch_data = sorted(batch_data, key=lambda x: x.index) # 保证顺序不变
|
837 |
+
|
838 |
+
ret = OmniProcessorOutput()
|
839 |
+
for i in range(len(batch_data)):
|
840 |
+
ret = ret.concatenate(batch_data[i])
|
841 |
+
self.tokenizer.padding_side = "left"
|
842 |
+
max_len = min(max([len(x.input_ids) for x in batch_data]),self.tokenizer.model_max_length)
|
843 |
+
padding_result = self.tokenizer.pad({"input_ids": [r.input_ids for r in batch_data]}, return_tensors='pt')
|
844 |
+
ret.input_ids = padding_result["input_ids"]
|
845 |
+
ret.attention_mask = padding_result["attention_mask"] # batch推理不pack 不需要seqlens
|
846 |
+
|
847 |
+
if ret.audios is not None:
|
848 |
+
max_audios_len = max([x.shape[-1] for x in ret.audios])
|
849 |
+
ret.audios = default_collate([np.pad(x, ((0,0),(0,max_audios_len - x.shape[-1])), 'constant', constant_values=0) for x in ret.audios])
|
850 |
+
|
851 |
+
ret.encoder_length = default_collate(ret.encoder_length)
|
852 |
+
ret.bridge_length = default_collate(ret.bridge_length)
|
853 |
+
|
854 |
+
if ret.images is not None:
|
855 |
+
ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.images]
|
856 |
+
ret.patch_nums = default_collate(ret.patch_nums)
|
857 |
+
|
858 |
+
if ret.videos is not None:
|
859 |
+
ret.videos = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.videos]
|
860 |
+
ret.videos_patch_nums = default_collate(ret.videos_patch_nums)
|
861 |
+
|
862 |
+
return ret
|
863 |
+
|
864 |
+
else:
|
865 |
+
raise ValueError("example format supported yet")
|
sequence_parallel_utils.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import Tensor
|
7 |
+
from flash_attn import flash_attn_varlen_func
|
8 |
+
try:
|
9 |
+
import deepspeed.comm as dist
|
10 |
+
except:
|
11 |
+
dist = None
|
12 |
+
|
13 |
+
|
14 |
+
try:
|
15 |
+
from utils import (
|
16 |
+
get_sequence_parallel_group,
|
17 |
+
get_sequence_parallel_size,
|
18 |
+
get_sequence_parallel_rank
|
19 |
+
)
|
20 |
+
except (ModuleNotFoundError, ImportError):
|
21 |
+
# 从 utils 获取seq parallel设置,import不成功默认为不开启
|
22 |
+
get_sequence_parallel_group = lambda : None
|
23 |
+
get_sequence_parallel_size = lambda : 1
|
24 |
+
get_sequence_parallel_rank = lambda : 0
|
25 |
+
|
26 |
+
|
27 |
+
def single_all_to_all(input, scatter_idx, gather_idx, group):
|
28 |
+
seq_world_size = dist.get_world_size(group)
|
29 |
+
inp_shape = list(input.shape)
|
30 |
+
inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size
|
31 |
+
if scatter_idx < 2:
|
32 |
+
input_t = input.reshape(
|
33 |
+
[seq_world_size, inp_shape[scatter_idx]] + \
|
34 |
+
inp_shape[scatter_idx + 1:]
|
35 |
+
).contiguous()
|
36 |
+
else:
|
37 |
+
# transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
|
38 |
+
input_t = input.reshape(
|
39 |
+
[-1, seq_world_size, inp_shape[scatter_idx]] + \
|
40 |
+
inp_shape[scatter_idx + 1:]
|
41 |
+
).transpose(0, 1).contiguous()
|
42 |
+
|
43 |
+
output = torch.empty_like(input_t)
|
44 |
+
dist.all_to_all_single(output, input_t, group=group)
|
45 |
+
|
46 |
+
# if scattering the seq-dim, transpose the heads back to the original dimension
|
47 |
+
# [sp_size, seq_len//sp_size, batch_size, head_num // sp_size, head_dim] -->
|
48 |
+
# [seq_len//sp_size,batch_size, sp_size, head_num // sp_size, head_dim]
|
49 |
+
if scatter_idx < 2:
|
50 |
+
output = output.transpose(0, 1).transpose(1, 2).contiguous()
|
51 |
+
|
52 |
+
return output.reshape(
|
53 |
+
inp_shape[: gather_idx] + \
|
54 |
+
[inp_shape[gather_idx] * seq_world_size,] + \
|
55 |
+
inp_shape[gather_idx + 1:]).contiguous()
|
56 |
+
|
57 |
+
|
58 |
+
class _SeqAllToAll(torch.autograd.Function):
|
59 |
+
|
60 |
+
@staticmethod
|
61 |
+
def forward(ctx: Any, group: 'dist.ProcessGroup', input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
|
62 |
+
ctx.group = group
|
63 |
+
ctx.scatter_idx = scatter_idx
|
64 |
+
ctx.gather_idx = gather_idx
|
65 |
+
|
66 |
+
return single_all_to_all(input, scatter_idx, gather_idx, group)
|
67 |
+
|
68 |
+
@staticmethod
|
69 |
+
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
|
70 |
+
return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None)
|
71 |
+
|
72 |
+
|
73 |
+
# import from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py
|
74 |
+
# but fix some bugs for 符合训练的维度设置
|
75 |
+
class DistributedAttention(nn.Module):
|
76 |
+
"""Initialization.
|
77 |
+
|
78 |
+
Arguments:
|
79 |
+
local_attention (Module): local attention with q,k,v
|
80 |
+
sequence_process_group (ProcessGroup): sequence parallel process group
|
81 |
+
scatter_idx (int): scatter_idx for all2all comm
|
82 |
+
gather_idx (int): gather_idx for all2all comm
|
83 |
+
"""
|
84 |
+
|
85 |
+
def __init__(
|
86 |
+
self,
|
87 |
+
local_attention: nn.Module,
|
88 |
+
sequence_process_group: 'dist.ProcessGroup',
|
89 |
+
scatter_idx: int = 2,
|
90 |
+
gather_idx: int = 0,
|
91 |
+
) -> None:
|
92 |
+
|
93 |
+
super(DistributedAttention, self).__init__()
|
94 |
+
self.local_attn = local_attention
|
95 |
+
self.spg = sequence_process_group
|
96 |
+
self.scatter_idx = scatter_idx
|
97 |
+
self.gather_idx = gather_idx
|
98 |
+
|
99 |
+
def pad_attention_head(self, query: Tensor, key: Tensor, value: Tensor):
|
100 |
+
# 将输入的head 维度pad到sp_size的倍数
|
101 |
+
sp_size = torch.distributed.get_world_size(self.spg)
|
102 |
+
pad_size = (sp_size - query.size(1) % sp_size) % sp_size
|
103 |
+
if pad_size > 0:
|
104 |
+
# [bs, num_head, seq_len, head_dim] -> [bs, num_head+pad_size, seq_len, head_dim]
|
105 |
+
query = torch.nn.functional.pad(query, (0,0,0,0,0,pad_size), value = 0.01)
|
106 |
+
key = torch.nn.functional.pad(key, (0,0,0,0,0,pad_size), value = 0.01)
|
107 |
+
value = torch.nn.functional.pad(value, (0,0,0,0,0,pad_size),value=0.0)
|
108 |
+
return query, key, value
|
109 |
+
|
110 |
+
def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor:
|
111 |
+
""" forward
|
112 |
+
|
113 |
+
Arguments:
|
114 |
+
query (Tensor): query input to the layer [batch_size, num_head, seq_len, head_dim]
|
115 |
+
key (Tensor): key input to the layer
|
116 |
+
value (Tensor): value input to the layer
|
117 |
+
args: other args
|
118 |
+
|
119 |
+
Returns:
|
120 |
+
* output (Tensor): context output
|
121 |
+
"""
|
122 |
+
# TODO Merge three alltoall calls into one
|
123 |
+
# TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
|
124 |
+
# [batch_size,num_head,seq_len, head_dim ]trans to [seq_len,batch_size,num_head,head_dim]
|
125 |
+
origin_num_head = query.size(1)
|
126 |
+
query, key, value = self.pad_attention_head(query,key,value)
|
127 |
+
|
128 |
+
query = query.transpose(1,2).transpose(0,1)
|
129 |
+
key = key.transpose(1,2).transpose(0,1)
|
130 |
+
value = value.transpose(1,2).transpose(0,1)
|
131 |
+
#in shape : e.g., [s/p,bs,h,head_dim]
|
132 |
+
query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx).transpose(0,1).transpose(1,2).contiguous()
|
133 |
+
key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx).transpose(0,1).transpose(1,2).contiguous()
|
134 |
+
value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx).transpose(0,1).transpose(1,2).contiguous()
|
135 |
+
|
136 |
+
context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)
|
137 |
+
context_layer = context_layer.transpose(0,1).contiguous()
|
138 |
+
# [seq_len, batch_size, num_head, head_dim]
|
139 |
+
output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
|
140 |
+
return output.transpose(0,1)[:,:,:origin_num_head,:]
|
141 |
+
|
142 |
+
|
143 |
+
class LocalAttention(nn.Module):
|
144 |
+
def __init__(self, hidden_size, num_heads, head_dim):
|
145 |
+
super().__init__()
|
146 |
+
self.hidden_size = hidden_size
|
147 |
+
self.num_heads = num_heads
|
148 |
+
self.head_dim = head_dim
|
149 |
+
|
150 |
+
def forward(self, q, k, v, *args, use_flash=True, **kwargs):
|
151 |
+
# input q,k,v [batch_size, num_head, seq_len, head_dim]
|
152 |
+
# output [batch_size, seq_len, num_head, head_dim]
|
153 |
+
if use_flash:
|
154 |
+
q_len, num_heads = q.shape[2], q.shape[1]
|
155 |
+
q = q.transpose(1,2).reshape(-1, num_heads, self.head_dim)
|
156 |
+
k = k.transpose(1,2).reshape(-1, num_heads, self.head_dim)
|
157 |
+
v = v.transpose(1,2).reshape(-1, num_heads, self.head_dim)
|
158 |
+
return flash_attn_varlen_func(q,k,v,*args, **kwargs).reshape(-1,q_len, num_heads, self.head_dim)
|
159 |
+
else:
|
160 |
+
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
|
161 |
+
attn_output = F.scaled_dot_product_attention(
|
162 |
+
q,k,v, *args, **kwargs)
|
163 |
+
attn_output = attn_output.transpose(1, 2)
|
164 |
+
return attn_output
|
165 |
+
|
166 |
+
|
167 |
+
def create_attention_layer(hidden_size, num_heads, head_dim):
|
168 |
+
if get_sequence_parallel_group() is None:
|
169 |
+
return LocalAttention(hidden_size, num_heads, head_dim)
|
170 |
+
else:
|
171 |
+
return DistributedAttention(
|
172 |
+
local_attention=LocalAttention(hidden_size, num_heads, head_dim),
|
173 |
+
sequence_process_group=get_sequence_parallel_group()
|
174 |
+
)
|
175 |
+
|
176 |
+
|
177 |
+
def get_sequence_parallel_chunk(tensor, dim=1, shift=0):
|
178 |
+
assert tensor.size(dim) % get_sequence_parallel_size() == 0
|
179 |
+
original_size = tensor.size(dim)
|
180 |
+
if shift:
|
181 |
+
tensor = tensor.split([shift, tensor.size(dim) - shift], dim=dim)[1]
|
182 |
+
if get_sequence_parallel_group() is None:
|
183 |
+
return tensor
|
184 |
+
else:
|
185 |
+
chunk_size = original_size // get_sequence_parallel_size()
|
186 |
+
return tensor.split(chunk_size, dim=dim)[get_sequence_parallel_rank()]
|
special_tokens_map.json
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"additional_special_tokens": [
|
3 |
+
"<|im_start|>",
|
4 |
+
"<|im_end|>",
|
5 |
+
"<|object_ref_start|>",
|
6 |
+
"<|object_ref_end|>",
|
7 |
+
"<|box_start|>",
|
8 |
+
"<|box_end|>",
|
9 |
+
"<|quad_start|>",
|
10 |
+
"<|quad_end|>",
|
11 |
+
"<|vision_start|>",
|
12 |
+
"<|vision_end|>",
|
13 |
+
"<|vision_pad|>",
|
14 |
+
"<|image_pad|>",
|
15 |
+
"<|video_pad|>",
|
16 |
+
"<B_SYS>",
|
17 |
+
"<B_USYS>",
|
18 |
+
"<C_Q>",
|
19 |
+
"<C_A>",
|
20 |
+
"<B_FUNC>",
|
21 |
+
"<B_CODE>",
|
22 |
+
"<B_APE>",
|
23 |
+
"<function_calling>",
|
24 |
+
"<calc_start>",
|
25 |
+
"<calc_end>",
|
26 |
+
"<inner_think>",
|
27 |
+
"<audio_start_baichuan>",
|
28 |
+
"<audio_end_baichuan>",
|
29 |
+
"<audio_pad_baichuan>",
|
30 |
+
"<img_start_baichuan>",
|
31 |
+
"<img_end_baichuan>",
|
32 |
+
"<img_pad_baichuan>",
|
33 |
+
"<img_newline_baichuan>",
|
34 |
+
"<box_start_baichuan>",
|
35 |
+
"<box_end_baichuan>",
|
36 |
+
"<box_delim_baichuan>",
|
37 |
+
"<ref_start_baichuan>",
|
38 |
+
"<ref_end_baichuan>",
|
39 |
+
"<img_delim_baichuan>",
|
40 |
+
"<polygon_start_baichuan>",
|
41 |
+
"<polygon_end_baichuan>",
|
42 |
+
"<baichuan_pad_token>",
|
43 |
+
"<reserved_113>",
|
44 |
+
"<audio_delim_baichuan>",
|
45 |
+
"<video_start_baichuan>",
|
46 |
+
"<video_end_baichuan>",
|
47 |
+
"<video_palce_baichuan>",
|
48 |
+
"<audiotext_start_baichuan>",
|
49 |
+
"<audiotext_end_baichuan>",
|
50 |
+
"<audiotext_pad_baichuan>",
|
51 |
+
"<audiogen_start_baichuan>",
|
52 |
+
"<audiogen_end_baichuan>"
|
53 |
+
],
|
54 |
+
"eos_token": {
|
55 |
+
"content": "<|endoftext|>",
|
56 |
+
"lstrip": false,
|
57 |
+
"normalized": false,
|
58 |
+
"rstrip": false,
|
59 |
+
"single_word": false
|
60 |
+
},
|
61 |
+
"pad_token": {
|
62 |
+
"content": "<|endoftext|>",
|
63 |
+
"lstrip": false,
|
64 |
+
"normalized": false,
|
65 |
+
"rstrip": false,
|
66 |
+
"single_word": false
|
67 |
+
}
|
68 |
+
}
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer_config.json
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_prefix_space": false,
|
3 |
+
"added_tokens_decoder": {
|
4 |
+
"151643": {
|
5 |
+
"content": "<|endoftext|>",
|
6 |
+
"lstrip": false,
|
7 |
+
"normalized": false,
|
8 |
+
"rstrip": false,
|
9 |
+
"single_word": false,
|
10 |
+
"special": true
|
11 |
+
},
|
12 |
+
"151644": {
|
13 |
+
"content": "<|im_start|>",
|
14 |
+
"lstrip": false,
|
15 |
+
"normalized": false,
|
16 |
+
"rstrip": false,
|
17 |
+
"single_word": false,
|
18 |
+
"special": true
|
19 |
+
},
|
20 |
+
"151645": {
|
21 |
+
"content": "<|im_end|>",
|
22 |
+
"lstrip": false,
|
23 |
+
"normalized": false,
|
24 |
+
"rstrip": false,
|
25 |
+
"single_word": false,
|
26 |
+
"special": true
|
27 |
+
},
|
28 |
+
"151646": {
|
29 |
+
"content": "<B_SYS>",
|
30 |
+
"lstrip": false,
|
31 |
+
"normalized": false,
|
32 |
+
"rstrip": false,
|
33 |
+
"single_word": false,
|
34 |
+
"special": true
|
35 |
+
},
|
36 |
+
"151647": {
|
37 |
+
"content": "<B_USYS>",
|
38 |
+
"lstrip": false,
|
39 |
+
"normalized": false,
|
40 |
+
"rstrip": false,
|
41 |
+
"single_word": false,
|
42 |
+
"special": true
|
43 |
+
},
|
44 |
+
"151648": {
|
45 |
+
"content": "<C_Q>",
|
46 |
+
"lstrip": false,
|
47 |
+
"normalized": false,
|
48 |
+
"rstrip": false,
|
49 |
+
"single_word": false,
|
50 |
+
"special": true
|
51 |
+
},
|
52 |
+
"151649": {
|
53 |
+
"content": "<C_A>",
|
54 |
+
"lstrip": false,
|
55 |
+
"normalized": false,
|
56 |
+
"rstrip": false,
|
57 |
+
"single_word": false,
|
58 |
+
"special": true
|
59 |
+
},
|
60 |
+
"151650": {
|
61 |
+
"content": "<B_FUNC>",
|
62 |
+
"lstrip": false,
|
63 |
+
"normalized": false,
|
64 |
+
"rstrip": false,
|
65 |
+
"single_word": false,
|
66 |
+
"special": true
|
67 |
+
},
|
68 |
+
"151651": {
|
69 |
+
"content": "<B_CODE>",
|
70 |
+
"lstrip": false,
|
71 |
+
"normalized": false,
|
72 |
+
"rstrip": false,
|
73 |
+
"single_word": false,
|
74 |
+
"special": true
|
75 |
+
},
|
76 |
+
"151652": {
|
77 |
+
"content": "<B_APE>",
|
78 |
+
"lstrip": false,
|
79 |
+
"normalized": false,
|
80 |
+
"rstrip": false,
|
81 |
+
"single_word": true,
|
82 |
+
"special": true
|
83 |
+
},
|
84 |
+
"151653": {
|
85 |
+
"content": "<function_calling>",
|
86 |
+
"lstrip": false,
|
87 |
+
"normalized": false,
|
88 |
+
"rstrip": false,
|
89 |
+
"single_word": true,
|
90 |
+
"special": true
|
91 |
+
},
|
92 |
+
"151654": {
|
93 |
+
"content": "<calc_start>",
|
94 |
+
"lstrip": false,
|
95 |
+
"normalized": false,
|
96 |
+
"rstrip": false,
|
97 |
+
"single_word": true,
|
98 |
+
"special": true
|
99 |
+
},
|
100 |
+
"151655": {
|
101 |
+
"content": "<calc_end>",
|
102 |
+
"lstrip": false,
|
103 |
+
"normalized": false,
|
104 |
+
"rstrip": false,
|
105 |
+
"single_word": true,
|
106 |
+
"special": true
|
107 |
+
},
|
108 |
+
"151656": {
|
109 |
+
"content": "<inner_think>",
|
110 |
+
"lstrip": false,
|
111 |
+
"normalized": false,
|
112 |
+
"rstrip": false,
|
113 |
+
"single_word": true,
|
114 |
+
"special": true
|
115 |
+
},
|
116 |
+
"151657": {
|
117 |
+
"content": "<audio_start_baichuan>",
|
118 |
+
"lstrip": false,
|
119 |
+
"normalized": false,
|
120 |
+
"rstrip": false,
|
121 |
+
"single_word": false,
|
122 |
+
"special": true
|
123 |
+
},
|
124 |
+
"151658": {
|
125 |
+
"content": "<audio_end_baichuan>",
|
126 |
+
"lstrip": false,
|
127 |
+
"normalized": false,
|
128 |
+
"rstrip": false,
|
129 |
+
"single_word": false,
|
130 |
+
"special": true
|
131 |
+
},
|
132 |
+
"151659": {
|
133 |
+
"content": "<audio_pad_baichuan>",
|
134 |
+
"lstrip": false,
|
135 |
+
"normalized": false,
|
136 |
+
"rstrip": false,
|
137 |
+
"single_word": false,
|
138 |
+
"special": true
|
139 |
+
},
|
140 |
+
"151660": {
|
141 |
+
"content": "<img_start_baichuan>",
|
142 |
+
"lstrip": false,
|
143 |
+
"normalized": false,
|
144 |
+
"rstrip": false,
|
145 |
+
"single_word": false,
|
146 |
+
"special": true
|
147 |
+
},
|
148 |
+
"151661": {
|
149 |
+
"content": "<img_end_baichuan>",
|
150 |
+
"lstrip": false,
|
151 |
+
"normalized": false,
|
152 |
+
"rstrip": false,
|
153 |
+
"single_word": false,
|
154 |
+
"special": true
|
155 |
+
},
|
156 |
+
"151662": {
|
157 |
+
"content": "<img_pad_baichuan>",
|
158 |
+
"lstrip": false,
|
159 |
+
"normalized": false,
|
160 |
+
"rstrip": false,
|
161 |
+
"single_word": false,
|
162 |
+
"special": true
|
163 |
+
},
|
164 |
+
"151663": {
|
165 |
+
"content": "<img_newline_baichuan>",
|
166 |
+
"lstrip": false,
|
167 |
+
"normalized": false,
|
168 |
+
"rstrip": false,
|
169 |
+
"single_word": false,
|
170 |
+
"special": true
|
171 |
+
},
|
172 |
+
"151664": {
|
173 |
+
"content": "<box_start_baichuan>",
|
174 |
+
"lstrip": false,
|
175 |
+
"normalized": false,
|
176 |
+
"rstrip": false,
|
177 |
+
"single_word": false,
|
178 |
+
"special": true
|
179 |
+
},
|
180 |
+
"151665": {
|
181 |
+
"content": "<box_end_baichuan>",
|
182 |
+
"lstrip": false,
|
183 |
+
"normalized": false,
|
184 |
+
"rstrip": false,
|
185 |
+
"single_word": false,
|
186 |
+
"special": true
|
187 |
+
},
|
188 |
+
"151666": {
|
189 |
+
"content": "<box_delim_baichuan>",
|
190 |
+
"lstrip": false,
|
191 |
+
"normalized": false,
|
192 |
+
"rstrip": false,
|
193 |
+
"single_word": false,
|
194 |
+
"special": true
|
195 |
+
},
|
196 |
+
"151667": {
|
197 |
+
"content": "<ref_start_baichuan>",
|
198 |
+
"lstrip": false,
|
199 |
+
"normalized": false,
|
200 |
+
"rstrip": false,
|
201 |
+
"single_word": false,
|
202 |
+
"special": true
|
203 |
+
},
|
204 |
+
"151668": {
|
205 |
+
"content": "<ref_end_baichuan>",
|
206 |
+
"lstrip": false,
|
207 |
+
"normalized": false,
|
208 |
+
"rstrip": false,
|
209 |
+
"single_word": false,
|
210 |
+
"special": true
|
211 |
+
},
|
212 |
+
"151669": {
|
213 |
+
"content": "<img_delim_baichuan>",
|
214 |
+
"lstrip": false,
|
215 |
+
"normalized": false,
|
216 |
+
"rstrip": false,
|
217 |
+
"single_word": false,
|
218 |
+
"special": true
|
219 |
+
},
|
220 |
+
"151670": {
|
221 |
+
"content": "<polygon_start_baichuan>",
|
222 |
+
"lstrip": false,
|
223 |
+
"normalized": false,
|
224 |
+
"rstrip": false,
|
225 |
+
"single_word": false,
|
226 |
+
"special": true
|
227 |
+
},
|
228 |
+
"151671": {
|
229 |
+
"content": "<polygon_end_baichuan>",
|
230 |
+
"lstrip": false,
|
231 |
+
"normalized": false,
|
232 |
+
"rstrip": false,
|
233 |
+
"single_word": false,
|
234 |
+
"special": true
|
235 |
+
},
|
236 |
+
"151672": {
|
237 |
+
"content": "<baichuan_pad_token>",
|
238 |
+
"lstrip": false,
|
239 |
+
"normalized": false,
|
240 |
+
"rstrip": false,
|
241 |
+
"single_word": false,
|
242 |
+
"special": true
|
243 |
+
},
|
244 |
+
"151673": {
|
245 |
+
"content": "<reserved_113>",
|
246 |
+
"lstrip": false,
|
247 |
+
"normalized": false,
|
248 |
+
"rstrip": false,
|
249 |
+
"single_word": false,
|
250 |
+
"special": true
|
251 |
+
},
|
252 |
+
"151674": {
|
253 |
+
"content": "<audio_delim_baichuan>",
|
254 |
+
"lstrip": false,
|
255 |
+
"normalized": false,
|
256 |
+
"rstrip": false,
|
257 |
+
"single_word": false,
|
258 |
+
"special": true
|
259 |
+
},
|
260 |
+
"151675": {
|
261 |
+
"content": "<audiotext_start_baichuan>",
|
262 |
+
"lstrip": false,
|
263 |
+
"normalized": false,
|
264 |
+
"rstrip": false,
|
265 |
+
"single_word": false,
|
266 |
+
"special": true
|
267 |
+
},
|
268 |
+
"151676": {
|
269 |
+
"content": "<audiotext_end_baichuan>",
|
270 |
+
"lstrip": false,
|
271 |
+
"normalized": false,
|
272 |
+
"rstrip": false,
|
273 |
+
"single_word": false,
|
274 |
+
"special": true
|
275 |
+
},
|
276 |
+
"151677": {
|
277 |
+
"content": "<audiotext_pad_baichuan>",
|
278 |
+
"lstrip": false,
|
279 |
+
"normalized": false,
|
280 |
+
"rstrip": false,
|
281 |
+
"single_word": false,
|
282 |
+
"special": true
|
283 |
+
},
|
284 |
+
"151678": {
|
285 |
+
"content": "<audiogen_start_baichuan>",
|
286 |
+
"lstrip": false,
|
287 |
+
"normalized": false,
|
288 |
+
"rstrip": false,
|
289 |
+
"single_word": false,
|
290 |
+
"special": true
|
291 |
+
},
|
292 |
+
"151679": {
|
293 |
+
"content": "<audiogen_end_baichuan>",
|
294 |
+
"lstrip": false,
|
295 |
+
"normalized": false,
|
296 |
+
"rstrip": false,
|
297 |
+
"single_word": false,
|
298 |
+
"special": true
|
299 |
+
}
|
300 |
+
},
|
301 |
+
"additional_special_tokens": [
|
302 |
+
"<|im_start|>",
|
303 |
+
"<|im_end|>",
|
304 |
+
"<B_SYS>",
|
305 |
+
"<B_USYS>",
|
306 |
+
"<C_Q>",
|
307 |
+
"<C_A>",
|
308 |
+
"<B_FUNC>",
|
309 |
+
"<B_CODE>",
|
310 |
+
"<B_APE>",
|
311 |
+
"<function_calling>",
|
312 |
+
"<calc_start>",
|
313 |
+
"<calc_end>",
|
314 |
+
"<inner_think>",
|
315 |
+
"<audio_start_baichuan>",
|
316 |
+
"<audio_end_baichuan>",
|
317 |
+
"<audio_pad_baichuan>",
|
318 |
+
"<img_start_baichuan>",
|
319 |
+
"<img_end_baichuan>",
|
320 |
+
"<img_pad_baichuan>",
|
321 |
+
"<img_newline_baichuan>",
|
322 |
+
"<box_start_baichuan>",
|
323 |
+
"<box_end_baichuan>",
|
324 |
+
"<box_delim_baichuan>",
|
325 |
+
"<ref_start_baichuan>",
|
326 |
+
"<ref_end_baichuan>",
|
327 |
+
"<img_delim_baichuan>",
|
328 |
+
"<polygon_start_baichuan>",
|
329 |
+
"<polygon_end_baichuan>",
|
330 |
+
"<baichuan_pad_token>",
|
331 |
+
"<reserved_113>",
|
332 |
+
"<audio_delim_baichuan>",
|
333 |
+
"<audiotext_start_baichuan>",
|
334 |
+
"<audiotext_end_baichuan>",
|
335 |
+
"<audiotext_pad_baichuan>",
|
336 |
+
"<audiogen_start_baichuan>",
|
337 |
+
"<audiogen_end_baichuan>"
|
338 |
+
],
|
339 |
+
"bos_token": null,
|
340 |
+
"chat_template": "{% for message in messages %}{% if message['role'] == 'system' %}{{'<B_SYS>' + message['content']}}{% elif message['role'] == 'user_system' %}{{'<B_USYS>' + message['content']}}{% elif message['role'] == 'user' %}{{'<H_Q>' + message['content']}}{% elif message['role'] == 'assistant' %}{{'<H_A>' + message['content']}}{% elif message['role'] == 'function' %}{{'<B_FUNC>' + message['content']}}{% elif message['role'] == 'code' %}{{'<B_CODE>' + message['content']}}{% else %}{{ raise_exception('Invalid message role: ' + message['role']) }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{'<H_A>'}}{% endif %}",
|
341 |
+
"clean_up_tokenization_spaces": false,
|
342 |
+
"eos_token": "<|endoftext|>",
|
343 |
+
"errors": "replace",
|
344 |
+
"model_max_length": 8192,
|
345 |
+
"pad_token": "<|endoftext|>",
|
346 |
+
"split_special_tokens": false,
|
347 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
348 |
+
"unk_token": null
|
349 |
+
}
|
vector_quantize.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch, random
|
2 |
+
from torch.nn import functional as F
|
3 |
+
from torch import nn
|
4 |
+
import numpy as np
|
5 |
+
from torch.cuda.amp import autocast
|
6 |
+
|
7 |
+
def uniform_init(*shape):
|
8 |
+
t = torch.zeros(shape)
|
9 |
+
nn.init.kaiming_uniform_(t)
|
10 |
+
return t
|
11 |
+
|
12 |
+
def cdist(x, y):
|
13 |
+
x2 = torch.sum(x ** 2, dim=-1, keepdims=True) # (b, 1)
|
14 |
+
y2 = torch.sum(y ** 2, dim=-1).reshape(1, -1) # (1, c)
|
15 |
+
xy = torch.einsum('bd,cd->bc', x, y) * -2
|
16 |
+
return (x2 + y2 + xy).clamp(min=0).sqrt() # (b, c)
|
17 |
+
|
18 |
+
def get_sequence_mask(inputs, inputs_length):
|
19 |
+
if inputs.dim() == 3:
|
20 |
+
bsz, tgt_len, _ = inputs.size()
|
21 |
+
else:
|
22 |
+
bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
|
23 |
+
sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
|
24 |
+
sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
|
25 |
+
unpacking_index = torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1 # 转成下标
|
26 |
+
return sequence_mask, unpacking_index
|
27 |
+
|
28 |
+
|
29 |
+
class EuclideanCodebook(nn.Module):
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
dim,
|
33 |
+
codebook_size,
|
34 |
+
init_std=0.02,
|
35 |
+
):
|
36 |
+
super().__init__()
|
37 |
+
self.init_std = init_std
|
38 |
+
self.dim = dim
|
39 |
+
self.codebook_size = codebook_size
|
40 |
+
|
41 |
+
embed = uniform_init(codebook_size, dim).to(torch.float32)
|
42 |
+
self.cluster_size = nn.Parameter(torch.ones(codebook_size))
|
43 |
+
self.embed_avg = nn.Parameter(embed.clone())
|
44 |
+
self.embed = nn.Parameter(embed)
|
45 |
+
del embed
|
46 |
+
|
47 |
+
@autocast(enabled=True, dtype=torch.float32)
|
48 |
+
@torch.no_grad()
|
49 |
+
def forward(self, x):
|
50 |
+
assert(len(x.shape) == 2)
|
51 |
+
assert(x.dtype == torch.float32)
|
52 |
+
embed = self.embed.detach().to(x.device)
|
53 |
+
dist = -cdist(x, embed) # dist((bs*sl, d), (c, d)) --> (bs*sl, c)
|
54 |
+
embed_ind = dist.argmax(dim=-1)
|
55 |
+
quantize = embed[embed_ind] # (bs*sl, d)
|
56 |
+
return quantize, embed_ind, dist
|
57 |
+
|
58 |
+
class VectorQuantize(nn.Module):
|
59 |
+
def __init__(self, config, *args, **kwargs):
|
60 |
+
super().__init__(*args, **kwargs)
|
61 |
+
self.config = config
|
62 |
+
self.codebook = EuclideanCodebook(dim=config.dim, codebook_size=config.codebook_size)
|
63 |
+
|
64 |
+
def forward(self, x, input_length):
|
65 |
+
batch_size, seq_len, _ = x.shape
|
66 |
+
mask, unpacking_index = get_sequence_mask(x, input_length)
|
67 |
+
if x.dtype != torch.float32:
|
68 |
+
x = x.to(torch.float32)
|
69 |
+
x = torch.masked_select(x, mask).reshape(-1, self.config.dim) # (bs*sl?, d)
|
70 |
+
quantize, embed_ind, _ = self.codebook(x)
|
71 |
+
quantize = torch.index_select(quantize, 0, unpacking_index).view(batch_size, seq_len, self.config.dim)
|
72 |
+
quantize = torch.where(mask, quantize, 0)
|
73 |
+
embed_ind = torch.index_select(embed_ind.reshape(-1, 1), 0, unpacking_index).view(batch_size, seq_len, 1)
|
74 |
+
embed_ind = torch.where(mask, embed_ind, -1).squeeze()
|
75 |
+
return quantize, embed_ind
|
76 |
+
|
77 |
+
def get_output_from_indices(self, indices):
|
78 |
+
return self.codebook.embed[indices]
|
visual_modeling_omni.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from typing import List, Optional, Tuple, Union
|
3 |
+
import torch, math
|
4 |
+
import torch.utils.checkpoint
|
5 |
+
from torch import nn
|
6 |
+
import transformers
|
7 |
+
from flash_attn import flash_attn_varlen_func
|
8 |
+
from transformers.activations import ACT2FN
|
9 |
+
from PIL import Image
|
10 |
+
import io, fire
|
11 |
+
from torch.nn import functional as F
|
12 |
+
|
13 |
+
class OmniVisualEncoder(transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VisionTransformerPretrainedModel):
|
14 |
+
def __init__(self, config):
|
15 |
+
super().__init__(config)
|
16 |
+
self.config_attn_implementation = 'flash_attention_2'
|
17 |
+
self.gradient_checkpointing = True # 强制开启
|
18 |
+
self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
|
19 |
+
self.merge_size = config.merge_size if hasattr(config, 'merge_size') else 2
|
20 |
+
del self.merger
|
21 |
+
|
22 |
+
def forward(
|
23 |
+
self,
|
24 |
+
pixel_values: torch.Tensor,
|
25 |
+
grid_thw: torch.Tensor,
|
26 |
+
):
|
27 |
+
hidden_states = pixel_values.to(self.get_dtype())
|
28 |
+
grid_thw = grid_thw.to(pixel_values.device)
|
29 |
+
|
30 |
+
hidden_states = self.patch_embed(hidden_states)
|
31 |
+
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
32 |
+
|
33 |
+
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
|
34 |
+
dim=0, dtype=torch.int32
|
35 |
+
)
|
36 |
+
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
37 |
+
|
38 |
+
for blk in self.blocks:
|
39 |
+
if self.gradient_checkpointing and self.training:
|
40 |
+
hidden_states = self._gradient_checkpointing_func(blk.__call__, hidden_states, cu_seqlens, rotary_pos_emb)
|
41 |
+
else:
|
42 |
+
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
|
43 |
+
|
44 |
+
return hidden_states
|
45 |
+
|
46 |
+
@torch.no_grad()
|
47 |
+
def fake_input(self, device):
|
48 |
+
merge_size = max(self.merge_size, self.config.spatial_merge_size)
|
49 |
+
fake_image = torch.zeros([
|
50 |
+
1,
|
51 |
+
self.config.temporal_patch_size,
|
52 |
+
3,
|
53 |
+
merge_size // self.config.spatial_merge_size,
|
54 |
+
self.config.spatial_merge_size,
|
55 |
+
self.config.patch_size,
|
56 |
+
merge_size // self.config.spatial_merge_size,
|
57 |
+
self.config.spatial_merge_size,
|
58 |
+
self.config.patch_size,
|
59 |
+
], dtype=torch.float32, device=device)
|
60 |
+
patches = fake_image.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
|
61 |
+
flatten_patches = patches.reshape(
|
62 |
+
merge_size * merge_size, 3 * self.config.temporal_patch_size * self.config.patch_size * self.config.patch_size
|
63 |
+
)
|
64 |
+
return [flatten_patches], [(1, merge_size, merge_size)], [1]
|
65 |
+
|
66 |
+
|
67 |
+
class OmniVisualBridge(nn.Module):
|
68 |
+
def __init__(self, config):
|
69 |
+
super().__init__()
|
70 |
+
self.config = config
|
71 |
+
self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
|
72 |
+
self.hidden_size = config.embed_dim * (self.merge_size**2)
|
73 |
+
self.ln_q = nn.LayerNorm(config.embed_dim, eps=1e-6)
|
74 |
+
self.mlp = nn.Sequential(
|
75 |
+
nn.Linear(self.hidden_size, self.hidden_size),
|
76 |
+
nn.GELU(),
|
77 |
+
nn.Linear(self.hidden_size, config.hidden_size),
|
78 |
+
)
|
79 |
+
|
80 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
81 |
+
x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
|
82 |
+
return x
|
83 |
+
|
84 |
+
|
85 |
+
if __name__ == '__main__':
|
86 |
+
fire.Fire()
|
87 |
+
|
vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
zero_to_fp32.py
ADDED
@@ -0,0 +1,604 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# Copyright (c) Microsoft Corporation.
|
4 |
+
# SPDX-License-Identifier: Apache-2.0
|
5 |
+
|
6 |
+
# DeepSpeed Team
|
7 |
+
|
8 |
+
# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
|
9 |
+
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
|
10 |
+
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
|
11 |
+
# application.
|
12 |
+
#
|
13 |
+
# example: python zero_to_fp32.py . pytorch_model.bin
|
14 |
+
|
15 |
+
import argparse
|
16 |
+
import torch
|
17 |
+
import glob
|
18 |
+
import math
|
19 |
+
import os
|
20 |
+
import re
|
21 |
+
from collections import OrderedDict
|
22 |
+
from dataclasses import dataclass
|
23 |
+
|
24 |
+
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
|
25 |
+
# DeepSpeed data structures it has to be available in the current python environment.
|
26 |
+
from deepspeed.utils import logger
|
27 |
+
from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
|
28 |
+
FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
|
29 |
+
FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
|
30 |
+
|
31 |
+
|
32 |
+
@dataclass
|
33 |
+
class zero_model_state:
|
34 |
+
buffers: dict()
|
35 |
+
param_shapes: dict()
|
36 |
+
shared_params: list
|
37 |
+
ds_version: int
|
38 |
+
frozen_param_shapes: dict()
|
39 |
+
frozen_param_fragments: dict()
|
40 |
+
|
41 |
+
|
42 |
+
debug = 0
|
43 |
+
|
44 |
+
# load to cpu
|
45 |
+
device = torch.device('cpu')
|
46 |
+
|
47 |
+
|
48 |
+
def atoi(text):
|
49 |
+
return int(text) if text.isdigit() else text
|
50 |
+
|
51 |
+
|
52 |
+
def natural_keys(text):
|
53 |
+
'''
|
54 |
+
alist.sort(key=natural_keys) sorts in human order
|
55 |
+
http://nedbatchelder.com/blog/200712/human_sorting.html
|
56 |
+
(See Toothy's implementation in the comments)
|
57 |
+
'''
|
58 |
+
return [atoi(c) for c in re.split(r'(\d+)', text)]
|
59 |
+
|
60 |
+
|
61 |
+
def get_model_state_file(checkpoint_dir, zero_stage):
|
62 |
+
if not os.path.isdir(checkpoint_dir):
|
63 |
+
raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
|
64 |
+
|
65 |
+
# there should be only one file
|
66 |
+
if zero_stage <= 2:
|
67 |
+
file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
|
68 |
+
elif zero_stage == 3:
|
69 |
+
file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
|
70 |
+
|
71 |
+
if not os.path.exists(file):
|
72 |
+
raise FileNotFoundError(f"can't find model states file at '{file}'")
|
73 |
+
|
74 |
+
return file
|
75 |
+
|
76 |
+
|
77 |
+
def get_checkpoint_files(checkpoint_dir, glob_pattern):
|
78 |
+
# XXX: need to test that this simple glob rule works for multi-node setup too
|
79 |
+
ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
|
80 |
+
|
81 |
+
if len(ckpt_files) == 0:
|
82 |
+
raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
|
83 |
+
|
84 |
+
return ckpt_files
|
85 |
+
|
86 |
+
|
87 |
+
def get_optim_files(checkpoint_dir):
|
88 |
+
return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
|
89 |
+
|
90 |
+
|
91 |
+
def get_model_state_files(checkpoint_dir):
|
92 |
+
return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
|
93 |
+
|
94 |
+
|
95 |
+
def parse_model_states(files):
|
96 |
+
zero_model_states = []
|
97 |
+
for file in files:
|
98 |
+
state_dict = torch.load(file, map_location=device)
|
99 |
+
|
100 |
+
if BUFFER_NAMES not in state_dict:
|
101 |
+
raise ValueError(f"{file} is not a model state checkpoint")
|
102 |
+
buffer_names = state_dict[BUFFER_NAMES]
|
103 |
+
if debug:
|
104 |
+
print("Found buffers:", buffer_names)
|
105 |
+
|
106 |
+
# recover just the buffers while restoring them to fp32 if they were saved in fp16
|
107 |
+
buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
|
108 |
+
param_shapes = state_dict[PARAM_SHAPES]
|
109 |
+
|
110 |
+
# collect parameters that are included in param_shapes
|
111 |
+
param_names = []
|
112 |
+
for s in param_shapes:
|
113 |
+
for name in s.keys():
|
114 |
+
param_names.append(name)
|
115 |
+
|
116 |
+
# update with frozen parameters
|
117 |
+
frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
|
118 |
+
if frozen_param_shapes is not None:
|
119 |
+
if debug:
|
120 |
+
print(f"Found frozen_param_shapes: {frozen_param_shapes}")
|
121 |
+
param_names += list(frozen_param_shapes.keys())
|
122 |
+
|
123 |
+
# handle shared params
|
124 |
+
shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
|
125 |
+
|
126 |
+
ds_version = state_dict.get(DS_VERSION, None)
|
127 |
+
|
128 |
+
frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
|
129 |
+
|
130 |
+
z_model_state = zero_model_state(buffers=buffers,
|
131 |
+
param_shapes=param_shapes,
|
132 |
+
shared_params=shared_params,
|
133 |
+
ds_version=ds_version,
|
134 |
+
frozen_param_shapes=frozen_param_shapes,
|
135 |
+
frozen_param_fragments=frozen_param_fragments)
|
136 |
+
zero_model_states.append(z_model_state)
|
137 |
+
|
138 |
+
return zero_model_states
|
139 |
+
|
140 |
+
|
141 |
+
def parse_optim_states(files, ds_checkpoint_dir):
|
142 |
+
|
143 |
+
total_files = len(files)
|
144 |
+
state_dicts = []
|
145 |
+
for f in files:
|
146 |
+
state_dict = torch.load(f, map_location=device)
|
147 |
+
# immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
|
148 |
+
# and also handle the case where it was already removed by another helper script
|
149 |
+
state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
|
150 |
+
state_dicts.append(state_dict)
|
151 |
+
|
152 |
+
if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
|
153 |
+
raise ValueError(f"{files[0]} is not a zero checkpoint")
|
154 |
+
zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
|
155 |
+
world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
|
156 |
+
|
157 |
+
# For ZeRO-2 each param group can have different partition_count as data parallelism for expert
|
158 |
+
# parameters can be different from data parallelism for non-expert parameters. So we can just
|
159 |
+
# use the max of the partition_count to get the dp world_size.
|
160 |
+
|
161 |
+
if type(world_size) is list:
|
162 |
+
world_size = max(world_size)
|
163 |
+
|
164 |
+
if world_size != total_files:
|
165 |
+
raise ValueError(
|
166 |
+
f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
|
167 |
+
"Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
|
168 |
+
)
|
169 |
+
|
170 |
+
# the groups are named differently in each stage
|
171 |
+
if zero_stage <= 2:
|
172 |
+
fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
|
173 |
+
elif zero_stage == 3:
|
174 |
+
fp32_groups_key = FP32_FLAT_GROUPS
|
175 |
+
else:
|
176 |
+
raise ValueError(f"unknown zero stage {zero_stage}")
|
177 |
+
|
178 |
+
if zero_stage <= 2:
|
179 |
+
fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
|
180 |
+
elif zero_stage == 3:
|
181 |
+
# if there is more than one param group, there will be multiple flattened tensors - one
|
182 |
+
# flattened tensor per group - for simplicity merge them into a single tensor
|
183 |
+
#
|
184 |
+
# XXX: could make the script more memory efficient for when there are multiple groups - it
|
185 |
+
# will require matching the sub-lists of param_shapes for each param group flattened tensor
|
186 |
+
|
187 |
+
fp32_flat_groups = [
|
188 |
+
torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
|
189 |
+
]
|
190 |
+
|
191 |
+
return zero_stage, world_size, fp32_flat_groups
|
192 |
+
|
193 |
+
|
194 |
+
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
|
195 |
+
"""
|
196 |
+
Returns fp32 state_dict reconstructed from ds checkpoint
|
197 |
+
|
198 |
+
Args:
|
199 |
+
- ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
|
200 |
+
|
201 |
+
"""
|
202 |
+
print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
|
203 |
+
|
204 |
+
optim_files = get_optim_files(ds_checkpoint_dir)
|
205 |
+
zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
|
206 |
+
print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
|
207 |
+
|
208 |
+
model_files = get_model_state_files(ds_checkpoint_dir)
|
209 |
+
|
210 |
+
zero_model_states = parse_model_states(model_files)
|
211 |
+
print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
|
212 |
+
|
213 |
+
if zero_stage <= 2:
|
214 |
+
return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
|
215 |
+
exclude_frozen_parameters)
|
216 |
+
elif zero_stage == 3:
|
217 |
+
return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
|
218 |
+
exclude_frozen_parameters)
|
219 |
+
|
220 |
+
|
221 |
+
def _zero2_merge_frozen_params(state_dict, zero_model_states):
|
222 |
+
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
|
223 |
+
return
|
224 |
+
|
225 |
+
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
|
226 |
+
frozen_param_fragments = zero_model_states[0].frozen_param_fragments
|
227 |
+
|
228 |
+
if debug:
|
229 |
+
num_elem = sum(s.numel() for s in frozen_param_shapes.values())
|
230 |
+
print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
|
231 |
+
|
232 |
+
wanted_params = len(frozen_param_shapes)
|
233 |
+
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
|
234 |
+
avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
|
235 |
+
print(f'Frozen params: Have {avail_numel} numels to process.')
|
236 |
+
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
|
237 |
+
|
238 |
+
total_params = 0
|
239 |
+
total_numel = 0
|
240 |
+
for name, shape in frozen_param_shapes.items():
|
241 |
+
total_params += 1
|
242 |
+
unpartitioned_numel = shape.numel()
|
243 |
+
total_numel += unpartitioned_numel
|
244 |
+
|
245 |
+
state_dict[name] = frozen_param_fragments[name]
|
246 |
+
|
247 |
+
if debug:
|
248 |
+
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
|
249 |
+
|
250 |
+
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
|
251 |
+
|
252 |
+
|
253 |
+
def _has_callable(obj, fn):
|
254 |
+
attr = getattr(obj, fn, None)
|
255 |
+
return callable(attr)
|
256 |
+
|
257 |
+
|
258 |
+
def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
|
259 |
+
param_shapes = zero_model_states[0].param_shapes
|
260 |
+
|
261 |
+
# Reconstruction protocol:
|
262 |
+
#
|
263 |
+
# XXX: document this
|
264 |
+
|
265 |
+
if debug:
|
266 |
+
for i in range(world_size):
|
267 |
+
for j in range(len(fp32_flat_groups[0])):
|
268 |
+
print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
|
269 |
+
|
270 |
+
# XXX: memory usage doubles here (zero2)
|
271 |
+
num_param_groups = len(fp32_flat_groups[0])
|
272 |
+
merged_single_partition_of_fp32_groups = []
|
273 |
+
for i in range(num_param_groups):
|
274 |
+
merged_partitions = [sd[i] for sd in fp32_flat_groups]
|
275 |
+
full_single_fp32_vector = torch.cat(merged_partitions, 0)
|
276 |
+
merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
|
277 |
+
avail_numel = sum(
|
278 |
+
[full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
|
279 |
+
|
280 |
+
if debug:
|
281 |
+
wanted_params = sum([len(shapes) for shapes in param_shapes])
|
282 |
+
wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
|
283 |
+
# not asserting if there is a mismatch due to possible padding
|
284 |
+
print(f"Have {avail_numel} numels to process.")
|
285 |
+
print(f"Need {wanted_numel} numels in {wanted_params} params.")
|
286 |
+
|
287 |
+
# params
|
288 |
+
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
|
289 |
+
# out-of-core computing solution
|
290 |
+
total_numel = 0
|
291 |
+
total_params = 0
|
292 |
+
for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
|
293 |
+
offset = 0
|
294 |
+
avail_numel = full_single_fp32_vector.numel()
|
295 |
+
for name, shape in shapes.items():
|
296 |
+
|
297 |
+
unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
|
298 |
+
total_numel += unpartitioned_numel
|
299 |
+
total_params += 1
|
300 |
+
|
301 |
+
if debug:
|
302 |
+
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
|
303 |
+
state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
|
304 |
+
offset += unpartitioned_numel
|
305 |
+
|
306 |
+
# Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
|
307 |
+
# avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
|
308 |
+
# paddings performed in the code it's almost impossible to predict the exact numbers w/o the
|
309 |
+
# live optimizer object, so we are checking that the numbers are within the right range
|
310 |
+
align_to = 2 * world_size
|
311 |
+
|
312 |
+
def zero2_align(x):
|
313 |
+
return align_to * math.ceil(x / align_to)
|
314 |
+
|
315 |
+
if debug:
|
316 |
+
print(f"original offset={offset}, avail_numel={avail_numel}")
|
317 |
+
|
318 |
+
offset = zero2_align(offset)
|
319 |
+
avail_numel = zero2_align(avail_numel)
|
320 |
+
|
321 |
+
if debug:
|
322 |
+
print(f"aligned offset={offset}, avail_numel={avail_numel}")
|
323 |
+
|
324 |
+
# Sanity check
|
325 |
+
if offset != avail_numel:
|
326 |
+
raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
|
327 |
+
|
328 |
+
print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
|
329 |
+
|
330 |
+
|
331 |
+
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
|
332 |
+
exclude_frozen_parameters):
|
333 |
+
state_dict = OrderedDict()
|
334 |
+
|
335 |
+
# buffers
|
336 |
+
buffers = zero_model_states[0].buffers
|
337 |
+
state_dict.update(buffers)
|
338 |
+
if debug:
|
339 |
+
print(f"added {len(buffers)} buffers")
|
340 |
+
|
341 |
+
if not exclude_frozen_parameters:
|
342 |
+
_zero2_merge_frozen_params(state_dict, zero_model_states)
|
343 |
+
|
344 |
+
_zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
|
345 |
+
|
346 |
+
# recover shared parameters
|
347 |
+
for pair in zero_model_states[0].shared_params:
|
348 |
+
if pair[1] in state_dict:
|
349 |
+
state_dict[pair[0]] = state_dict[pair[1]]
|
350 |
+
|
351 |
+
return state_dict
|
352 |
+
|
353 |
+
|
354 |
+
def zero3_partitioned_param_info(unpartitioned_numel, world_size):
|
355 |
+
remainder = unpartitioned_numel % world_size
|
356 |
+
padding_numel = (world_size - remainder) if remainder else 0
|
357 |
+
partitioned_numel = math.ceil(unpartitioned_numel / world_size)
|
358 |
+
return partitioned_numel, padding_numel
|
359 |
+
|
360 |
+
|
361 |
+
def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
|
362 |
+
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
|
363 |
+
return
|
364 |
+
|
365 |
+
if debug:
|
366 |
+
for i in range(world_size):
|
367 |
+
num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
|
368 |
+
print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
|
369 |
+
|
370 |
+
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
|
371 |
+
wanted_params = len(frozen_param_shapes)
|
372 |
+
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
|
373 |
+
avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
|
374 |
+
print(f'Frozen params: Have {avail_numel} numels to process.')
|
375 |
+
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
|
376 |
+
|
377 |
+
total_params = 0
|
378 |
+
total_numel = 0
|
379 |
+
for name, shape in zero_model_states[0].frozen_param_shapes.items():
|
380 |
+
total_params += 1
|
381 |
+
unpartitioned_numel = shape.numel()
|
382 |
+
total_numel += unpartitioned_numel
|
383 |
+
|
384 |
+
param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
|
385 |
+
state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
|
386 |
+
|
387 |
+
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
|
388 |
+
|
389 |
+
if debug:
|
390 |
+
print(
|
391 |
+
f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
|
392 |
+
)
|
393 |
+
|
394 |
+
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
|
395 |
+
|
396 |
+
|
397 |
+
def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
|
398 |
+
param_shapes = zero_model_states[0].param_shapes
|
399 |
+
avail_numel = fp32_flat_groups[0].numel() * world_size
|
400 |
+
# Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
|
401 |
+
# param, re-consolidating each param, while dealing with padding if any
|
402 |
+
|
403 |
+
# merge list of dicts, preserving order
|
404 |
+
param_shapes = {k: v for d in param_shapes for k, v in d.items()}
|
405 |
+
|
406 |
+
if debug:
|
407 |
+
for i in range(world_size):
|
408 |
+
print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
|
409 |
+
|
410 |
+
wanted_params = len(param_shapes)
|
411 |
+
wanted_numel = sum(shape.numel() for shape in param_shapes.values())
|
412 |
+
# not asserting if there is a mismatch due to possible padding
|
413 |
+
avail_numel = fp32_flat_groups[0].numel() * world_size
|
414 |
+
print(f"Trainable params: Have {avail_numel} numels to process.")
|
415 |
+
print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
|
416 |
+
|
417 |
+
# params
|
418 |
+
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
|
419 |
+
# out-of-core computing solution
|
420 |
+
offset = 0
|
421 |
+
total_numel = 0
|
422 |
+
total_params = 0
|
423 |
+
for name, shape in param_shapes.items():
|
424 |
+
|
425 |
+
unpartitioned_numel = shape.numel()
|
426 |
+
total_numel += unpartitioned_numel
|
427 |
+
total_params += 1
|
428 |
+
|
429 |
+
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
|
430 |
+
|
431 |
+
if debug:
|
432 |
+
print(
|
433 |
+
f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
|
434 |
+
)
|
435 |
+
|
436 |
+
# XXX: memory usage doubles here
|
437 |
+
state_dict[name] = torch.cat(
|
438 |
+
tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
|
439 |
+
0).narrow(0, 0, unpartitioned_numel).view(shape)
|
440 |
+
offset += partitioned_numel
|
441 |
+
|
442 |
+
offset *= world_size
|
443 |
+
|
444 |
+
# Sanity check
|
445 |
+
if offset != avail_numel:
|
446 |
+
raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
|
447 |
+
|
448 |
+
print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
|
449 |
+
|
450 |
+
|
451 |
+
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
|
452 |
+
exclude_frozen_parameters):
|
453 |
+
state_dict = OrderedDict()
|
454 |
+
|
455 |
+
# buffers
|
456 |
+
buffers = zero_model_states[0].buffers
|
457 |
+
state_dict.update(buffers)
|
458 |
+
if debug:
|
459 |
+
print(f"added {len(buffers)} buffers")
|
460 |
+
|
461 |
+
if not exclude_frozen_parameters:
|
462 |
+
_zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
|
463 |
+
|
464 |
+
_zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
|
465 |
+
|
466 |
+
# recover shared parameters
|
467 |
+
for pair in zero_model_states[0].shared_params:
|
468 |
+
if pair[1] in state_dict:
|
469 |
+
state_dict[pair[0]] = state_dict[pair[1]]
|
470 |
+
|
471 |
+
return state_dict
|
472 |
+
|
473 |
+
|
474 |
+
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
|
475 |
+
"""
|
476 |
+
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
|
477 |
+
``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
|
478 |
+
via a model hub.
|
479 |
+
|
480 |
+
Args:
|
481 |
+
- ``checkpoint_dir``: path to the desired checkpoint folder
|
482 |
+
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
|
483 |
+
- ``exclude_frozen_parameters``: exclude frozen parameters
|
484 |
+
|
485 |
+
Returns:
|
486 |
+
- pytorch ``state_dict``
|
487 |
+
|
488 |
+
Note: this approach may not work if your application doesn't have sufficient free CPU memory and
|
489 |
+
you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
|
490 |
+
the checkpoint.
|
491 |
+
|
492 |
+
A typical usage might be ::
|
493 |
+
|
494 |
+
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
|
495 |
+
# do the training and checkpoint saving
|
496 |
+
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
|
497 |
+
model = model.cpu() # move to cpu
|
498 |
+
model.load_state_dict(state_dict)
|
499 |
+
# submit to model hub or save the model to share with others
|
500 |
+
|
501 |
+
In this example the ``model`` will no longer be usable in the deepspeed context of the same
|
502 |
+
application. i.e. you will need to re-initialize the deepspeed engine, since
|
503 |
+
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
|
504 |
+
|
505 |
+
If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
|
506 |
+
|
507 |
+
"""
|
508 |
+
if tag is None:
|
509 |
+
latest_path = os.path.join(checkpoint_dir, 'latest')
|
510 |
+
if os.path.isfile(latest_path):
|
511 |
+
with open(latest_path, 'r') as fd:
|
512 |
+
tag = fd.read().strip()
|
513 |
+
else:
|
514 |
+
raise ValueError(f"Unable to find 'latest' file at {latest_path}")
|
515 |
+
|
516 |
+
ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
|
517 |
+
|
518 |
+
if not os.path.isdir(ds_checkpoint_dir):
|
519 |
+
raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
|
520 |
+
|
521 |
+
return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
|
522 |
+
|
523 |
+
|
524 |
+
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False):
|
525 |
+
"""
|
526 |
+
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
|
527 |
+
loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
|
528 |
+
|
529 |
+
Args:
|
530 |
+
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
|
531 |
+
- ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
|
532 |
+
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
|
533 |
+
- ``exclude_frozen_parameters``: exclude frozen parameters
|
534 |
+
"""
|
535 |
+
|
536 |
+
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
|
537 |
+
print(f"Saving fp32 state dict to {output_file}")
|
538 |
+
torch.save(state_dict, output_file)
|
539 |
+
|
540 |
+
|
541 |
+
def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
|
542 |
+
"""
|
543 |
+
1. Put the provided model to cpu
|
544 |
+
2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
|
545 |
+
3. Load it into the provided model
|
546 |
+
|
547 |
+
Args:
|
548 |
+
- ``model``: the model object to update
|
549 |
+
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
|
550 |
+
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
|
551 |
+
|
552 |
+
Returns:
|
553 |
+
- ``model`: modified model
|
554 |
+
|
555 |
+
Make sure you have plenty of CPU memory available before you call this function. If you don't
|
556 |
+
have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
|
557 |
+
conveniently placed for you in the checkpoint folder.
|
558 |
+
|
559 |
+
A typical usage might be ::
|
560 |
+
|
561 |
+
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
|
562 |
+
model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
|
563 |
+
# submit to model hub or save the model to share with others
|
564 |
+
|
565 |
+
Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
|
566 |
+
of the same application. i.e. you will need to re-initialize the deepspeed engine, since
|
567 |
+
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
|
568 |
+
|
569 |
+
"""
|
570 |
+
logger.info(f"Extracting fp32 weights")
|
571 |
+
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
|
572 |
+
|
573 |
+
logger.info(f"Overwriting model with fp32 weights")
|
574 |
+
model = model.cpu()
|
575 |
+
model.load_state_dict(state_dict, strict=False)
|
576 |
+
|
577 |
+
return model
|
578 |
+
|
579 |
+
|
580 |
+
if __name__ == "__main__":
|
581 |
+
|
582 |
+
parser = argparse.ArgumentParser()
|
583 |
+
parser.add_argument("checkpoint_dir",
|
584 |
+
type=str,
|
585 |
+
help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
|
586 |
+
parser.add_argument(
|
587 |
+
"output_file",
|
588 |
+
type=str,
|
589 |
+
help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
|
590 |
+
parser.add_argument("-t",
|
591 |
+
"--tag",
|
592 |
+
type=str,
|
593 |
+
default=None,
|
594 |
+
help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
|
595 |
+
parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
|
596 |
+
parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
|
597 |
+
args = parser.parse_args()
|
598 |
+
|
599 |
+
debug = args.debug
|
600 |
+
|
601 |
+
convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
|
602 |
+
args.output_file,
|
603 |
+
tag=args.tag,
|
604 |
+
exclude_frozen_parameters=args.exclude_frozen_parameters)
|