Upload 11 files
Browse files- .gitattributes +35 -35
- README.md +96 -0
- config.json +30 -0
- configuration_intern_vit.py +119 -0
- flash_attention.py +76 -0
- model.safetensors +3 -0
- modeling_intern_vit.py +363 -0
- preprocessor_config.json +19 -0
- triton-test.py +26 -0
- triton_bert_pading.py +224 -0
- triton_flash_atn.py +654 -0
.gitattributes
CHANGED
@@ -1,35 +1,35 @@
|
|
1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: mit
|
3 |
+
datasets:
|
4 |
+
- laion/laion2B-en
|
5 |
+
- laion/laion-coco
|
6 |
+
- laion/laion2B-multi
|
7 |
+
- kakaobrain/coyo-700m
|
8 |
+
- conceptual_captions
|
9 |
+
- wanng/wukong100m
|
10 |
+
pipeline_tag: image-feature-extraction
|
11 |
+
---
|
12 |
+
|
13 |
+
# Model Card for InternViT-300M-448px
|
14 |
+
<p align="center">
|
15 |
+
<img src="https://cdn-uploads.huggingface.co/production/uploads/64119264f0f81eb569e0d569/AUE-3OBtfr9vDA7Elgkhd.webp" alt="Image Description" width="300" height="300">
|
16 |
+
</p>
|
17 |
+
|
18 |
+
[\[🆕 Blog\]](https://internvl.github.io/blog/) [\[📜 InternVL 1.0 Paper\]](https://arxiv.org/abs/2312.14238) [\[📜 InternVL 1.5 Report\]](https://arxiv.org/abs/2404.16821) [\[🗨️ Chat Demo\]](https://internvl.opengvlab.com/)
|
19 |
+
|
20 |
+
[\[🤗 HF Demo\]](https://huggingface.co/spaces/OpenGVLab/InternVL) [\[🚀 Quick Start\]](#model-usage) [\[🌐 Community-hosted API\]](https://rapidapi.com/adushar1320/api/internvl-chat) [\[📖 中文解读\]](https://zhuanlan.zhihu.com/p/675877376)
|
21 |
+
|
22 |
+
This update primarily focuses on enhancing the efficiency of the vision foundation model. We developed InternViT-300M-448px by distilling knowledge from the robust vision foundation model, [InternViT-6B-448px-V1-5](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V1-5). Like its predecessor, InternViT-300M-448px features a dynamic input resolution of 448×448, with a basic tile size of 448×448. During training, it allows for 1 to 12 tiles, and expands to 1 to 40 tiles during testing. Additionally, it inherits the powerful robustness, OCR capability, and high-resolution processing capacity from InternViT-6B-448px-V1-5.
|
23 |
+
|
24 |
+
## Model Details
|
25 |
+
- **Model Type:** vision foundation model, feature backbone
|
26 |
+
- **Model Stats:**
|
27 |
+
- Params (M): 304
|
28 |
+
- Image size: 448 x 448, training with 1 - 12 tiles
|
29 |
+
- **Pretrain Dataset:** LAION-en, LAION-zh, COYO, GRIT, COCO, TextCaps, Objects365, OpenImages, All-Seeing, Wukong-OCR, LaionCOCO-OCR, and other OCR-related datasets.
|
30 |
+
To enhance the OCR capability of the model, we have incorporated additional OCR data alongside the general caption datasets. Specifically, we utilized PaddleOCR to perform Chinese OCR on images from Wukong and English OCR on images from LAION-COCO.
|
31 |
+
|
32 |
+
## Released Models
|
33 |
+
### Vision Foundation model
|
34 |
+
| Model | Date | Download | Note |
|
35 |
+
| ----------------------- | ---------- | ---------------------------------------------------------------------- | -------------------------------- |
|
36 |
+
| InternViT-6B-448px-V1-5 | 2024.04.20 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V1-5) | support dynamic resolution, super strong OCR (🔥new) |
|
37 |
+
| InternViT-6B-448px-V1-2 | 2024.02.11 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V1-2) | 448 resolution |
|
38 |
+
| InternViT-6B-448px-V1-0 | 2024.01.30 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V1-0) | 448 resolution |
|
39 |
+
| InternViT-6B-224px | 2023.12.22 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternViT-6B-224px) | vision foundation model |
|
40 |
+
| InternVL-14B-224px | 2023.12.22 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-14B-224px) | vision-language foundation model |
|
41 |
+
|
42 |
+
### Multimodal Large Language Model (MLLM)
|
43 |
+
| Model | Date | Download | Note |
|
44 |
+
| ----------------------- | ---------- | --------------------------------------------------------------------------- | ---------------------------------- |
|
45 |
+
| InternVL-Chat-V1-5 | 2024.04.18 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-Chat-V1-5) | support 4K image; super strong OCR; Approaching the performance of GPT-4V and Gemini Pro on various benchmarks like MMMU, DocVQA, ChartQA, MathVista, etc. (🔥new)|
|
46 |
+
| InternVL-Chat-V1-2-Plus | 2024.02.21 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-Chat-V1-2-Plus) | more SFT data and stronger |
|
47 |
+
| InternVL-Chat-V1-2 | 2024.02.11 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-Chat-V1-2) | scaling up LLM to 34B |
|
48 |
+
| InternVL-Chat-V1-1 | 2024.01.24 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-Chat-V1-1) | support Chinese and stronger OCR |
|
49 |
+
|
50 |
+
## Model Usage (Image Embeddings)
|
51 |
+
|
52 |
+
```python
|
53 |
+
import torch
|
54 |
+
from PIL import Image
|
55 |
+
from transformers import AutoModel, CLIPImageProcessor
|
56 |
+
|
57 |
+
model = AutoModel.from_pretrained(
|
58 |
+
'OpenGVLab/InternViT-300M-448px',
|
59 |
+
torch_dtype=torch.bfloat16,
|
60 |
+
low_cpu_mem_usage=True,
|
61 |
+
trust_remote_code=True).cuda().eval()
|
62 |
+
|
63 |
+
image = Image.open('./examples/image1.jpg').convert('RGB')
|
64 |
+
|
65 |
+
image_processor = CLIPImageProcessor.from_pretrained('OpenGVLab/InternViT-300M-448px')
|
66 |
+
|
67 |
+
pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
|
68 |
+
pixel_values = pixel_values.to(torch.bfloat16).cuda()
|
69 |
+
|
70 |
+
outputs = model(pixel_values)
|
71 |
+
```
|
72 |
+
|
73 |
+
## Citation
|
74 |
+
|
75 |
+
If you find this project useful in your research, please consider citing:
|
76 |
+
|
77 |
+
```BibTeX
|
78 |
+
@article{chen2023internvl,
|
79 |
+
title={InternVL: Scaling up Vision Foundation Models and Aligning for Generic Visual-Linguistic Tasks},
|
80 |
+
author={Chen, Zhe and Wu, Jiannan and Wang, Wenhai and Su, Weijie and Chen, Guo and Xing, Sen and Zhong, Muyan and Zhang, Qinglong and Zhu, Xizhou and Lu, Lewei and Li, Bin and Luo, Ping and Lu, Tong and Qiao, Yu and Dai, Jifeng},
|
81 |
+
journal={arXiv preprint arXiv:2312.14238},
|
82 |
+
year={2023}
|
83 |
+
}
|
84 |
+
@article{chen2024far,
|
85 |
+
title={How Far Are We to GPT-4V? Closing the Gap to Commercial Multimodal Models with Open-Source Suites},
|
86 |
+
author={Chen, Zhe and Wang, Weiyun and Tian, Hao and Ye, Shenglong and Gao, Zhangwei and Cui, Erfei and Tong, Wenwen and Hu, Kongzhi and Luo, Jiapeng and Ma, Zheng and others},
|
87 |
+
journal={arXiv preprint arXiv:2404.16821},
|
88 |
+
year={2024}
|
89 |
+
}
|
90 |
+
|
91 |
+
```
|
92 |
+
|
93 |
+
|
94 |
+
## Acknowledgement
|
95 |
+
|
96 |
+
InternVL is built with reference to the code of the following projects: [OpenAI CLIP](https://github.com/openai/CLIP), [Open CLIP](https://github.com/mlfoundations/open_clip), [CLIP Benchmark](https://github.com/LAION-AI/CLIP_benchmark), [EVA](https://github.com/baaivision/EVA/tree/master), [InternImage](https://github.com/OpenGVLab/InternImage), [ViT-Adapter](https://github.com/czczup/ViT-Adapter), [MMSegmentation](https://github.com/open-mmlab/mmsegmentation), [Transformers](https://github.com/huggingface/transformers), [DINOv2](https://github.com/facebookresearch/dinov2), [BLIP-2](https://github.com/salesforce/LAVIS/tree/main/projects/blip2), [Qwen-VL](https://github.com/QwenLM/Qwen-VL/tree/master/eval_mm), and [LLaVA-1.5](https://github.com/haotian-liu/LLaVA). Thanks for their awesome work!
|
config.json
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"InternVisionModel"
|
4 |
+
],
|
5 |
+
"auto_map": {
|
6 |
+
"AutoConfig": "configuration_intern_vit.InternVisionConfig",
|
7 |
+
"AutoModel": "modeling_intern_vit.InternVisionModel"
|
8 |
+
},
|
9 |
+
"attention_dropout": 0.0,
|
10 |
+
"drop_path_rate": 0.1,
|
11 |
+
"dropout": 0.0,
|
12 |
+
"hidden_act": "gelu",
|
13 |
+
"hidden_size": 1024,
|
14 |
+
"image_size": 448,
|
15 |
+
"initializer_factor": 1.0,
|
16 |
+
"initializer_range": 0.02,
|
17 |
+
"intermediate_size": 4096,
|
18 |
+
"layer_norm_eps": 1e-06,
|
19 |
+
"model_type": "intern_vit_6b",
|
20 |
+
"norm_type": "layer_norm",
|
21 |
+
"num_attention_heads": 16,
|
22 |
+
"num_channels": 3,
|
23 |
+
"num_hidden_layers": 24,
|
24 |
+
"patch_size": 14,
|
25 |
+
"qk_normalization": false,
|
26 |
+
"qkv_bias": true,
|
27 |
+
"torch_dtype": "bfloat16",
|
28 |
+
"transformers_version": "4.37.2",
|
29 |
+
"use_flash_attn": true
|
30 |
+
}
|
configuration_intern_vit.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# InternVL
|
3 |
+
# Copyright (c) 2023 OpenGVLab
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# --------------------------------------------------------
|
6 |
+
import os
|
7 |
+
from typing import Union
|
8 |
+
|
9 |
+
from transformers.configuration_utils import PretrainedConfig
|
10 |
+
from transformers.utils import logging
|
11 |
+
|
12 |
+
logger = logging.get_logger(__name__)
|
13 |
+
|
14 |
+
|
15 |
+
class InternVisionConfig(PretrainedConfig):
|
16 |
+
r"""
|
17 |
+
This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
|
18 |
+
instantiate a vision encoder according to the specified arguments, defining the model architecture.
|
19 |
+
|
20 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
21 |
+
documentation from [`PretrainedConfig`] for more information.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
num_channels (`int`, *optional*, defaults to 3):
|
25 |
+
Number of color channels in the input images (e.g., 3 for RGB).
|
26 |
+
patch_size (`int`, *optional*, defaults to 14):
|
27 |
+
The size (resolution) of each patch.
|
28 |
+
image_size (`int`, *optional*, defaults to 224):
|
29 |
+
The size (resolution) of each image.
|
30 |
+
qkv_bias (`bool`, *optional*, defaults to `False`):
|
31 |
+
Whether to add a bias to the queries and values in the self-attention layers.
|
32 |
+
hidden_size (`int`, *optional*, defaults to 3200):
|
33 |
+
Dimensionality of the encoder layers and the pooler layer.
|
34 |
+
num_attention_heads (`int`, *optional*, defaults to 25):
|
35 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
36 |
+
intermediate_size (`int`, *optional*, defaults to 12800):
|
37 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
38 |
+
qk_normalization (`bool`, *optional*, defaults to `True`):
|
39 |
+
Whether to normalize the queries and keys in the self-attention layers.
|
40 |
+
num_hidden_layers (`int`, *optional*, defaults to 48):
|
41 |
+
Number of hidden layers in the Transformer encoder.
|
42 |
+
use_flash_attn (`bool`, *optional*, defaults to `True`):
|
43 |
+
Whether to use flash attention mechanism.
|
44 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
45 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
46 |
+
`"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported.
|
47 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
|
48 |
+
The epsilon used by the layer normalization layers.
|
49 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
50 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
51 |
+
drop_path_rate (`float`, *optional*, defaults to 0.0):
|
52 |
+
Dropout rate for stochastic depth.
|
53 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
54 |
+
The dropout ratio for the attention probabilities.
|
55 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
56 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
57 |
+
initializer_factor (`float`, *optional*, defaults to 0.1):
|
58 |
+
A factor for layer scale.
|
59 |
+
"""
|
60 |
+
|
61 |
+
model_type = 'intern_vit_6b'
|
62 |
+
|
63 |
+
def __init__(
|
64 |
+
self,
|
65 |
+
num_channels=3,
|
66 |
+
patch_size=14,
|
67 |
+
image_size=224,
|
68 |
+
qkv_bias=False,
|
69 |
+
hidden_size=3200,
|
70 |
+
num_attention_heads=25,
|
71 |
+
intermediate_size=12800,
|
72 |
+
qk_normalization=True,
|
73 |
+
num_hidden_layers=48,
|
74 |
+
use_flash_attn=True,
|
75 |
+
hidden_act='gelu',
|
76 |
+
norm_type='rms_norm',
|
77 |
+
layer_norm_eps=1e-6,
|
78 |
+
dropout=0.0,
|
79 |
+
drop_path_rate=0.0,
|
80 |
+
attention_dropout=0.0,
|
81 |
+
initializer_range=0.02,
|
82 |
+
initializer_factor=0.1,
|
83 |
+
**kwargs,
|
84 |
+
):
|
85 |
+
super().__init__(**kwargs)
|
86 |
+
|
87 |
+
self.hidden_size = hidden_size
|
88 |
+
self.intermediate_size = intermediate_size
|
89 |
+
self.dropout = dropout
|
90 |
+
self.drop_path_rate = drop_path_rate
|
91 |
+
self.num_hidden_layers = num_hidden_layers
|
92 |
+
self.num_attention_heads = num_attention_heads
|
93 |
+
self.num_channels = num_channels
|
94 |
+
self.patch_size = patch_size
|
95 |
+
self.image_size = image_size
|
96 |
+
self.initializer_range = initializer_range
|
97 |
+
self.initializer_factor = initializer_factor
|
98 |
+
self.attention_dropout = attention_dropout
|
99 |
+
self.layer_norm_eps = layer_norm_eps
|
100 |
+
self.hidden_act = hidden_act
|
101 |
+
self.norm_type = norm_type
|
102 |
+
self.qkv_bias = qkv_bias
|
103 |
+
self.qk_normalization = qk_normalization
|
104 |
+
self.use_flash_attn = use_flash_attn
|
105 |
+
|
106 |
+
@classmethod
|
107 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
|
108 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
109 |
+
|
110 |
+
if 'vision_config' in config_dict:
|
111 |
+
config_dict = config_dict['vision_config']
|
112 |
+
|
113 |
+
if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
|
114 |
+
logger.warning(
|
115 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
116 |
+
f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
|
117 |
+
)
|
118 |
+
|
119 |
+
return cls.from_dict(config_dict, **kwargs)
|
flash_attention.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
try: # v1
|
7 |
+
from flash_attn.flash_attn_interface import \
|
8 |
+
flash_attn_unpadded_qkvpacked_func
|
9 |
+
except: # v2
|
10 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
|
11 |
+
|
12 |
+
from flash_attn.bert_padding import pad_input, unpad_input
|
13 |
+
|
14 |
+
|
15 |
+
class FlashAttention(nn.Module):
|
16 |
+
"""Implement the scaled dot product attention with softmax.
|
17 |
+
Arguments
|
18 |
+
---------
|
19 |
+
softmax_scale: The temperature to use for the softmax attention.
|
20 |
+
(default: 1/sqrt(d_keys) where d_keys is computed at
|
21 |
+
runtime)
|
22 |
+
attention_dropout: The dropout rate to apply to the attention
|
23 |
+
(default: 0.0)
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
|
27 |
+
super().__init__()
|
28 |
+
self.softmax_scale = softmax_scale
|
29 |
+
self.dropout_p = attention_dropout
|
30 |
+
|
31 |
+
def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
|
32 |
+
max_s=None, need_weights=False):
|
33 |
+
"""Implements the multihead softmax attention.
|
34 |
+
Arguments
|
35 |
+
---------
|
36 |
+
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
|
37 |
+
if unpadded: (nnz, 3, h, d)
|
38 |
+
key_padding_mask: a bool tensor of shape (B, S)
|
39 |
+
"""
|
40 |
+
assert not need_weights
|
41 |
+
assert qkv.dtype in [torch.float16, torch.bfloat16]
|
42 |
+
assert qkv.is_cuda
|
43 |
+
|
44 |
+
if cu_seqlens is None:
|
45 |
+
batch_size = qkv.shape[0]
|
46 |
+
seqlen = qkv.shape[1]
|
47 |
+
if key_padding_mask is None:
|
48 |
+
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
|
49 |
+
max_s = seqlen
|
50 |
+
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
|
51 |
+
device=qkv.device)
|
52 |
+
output = flash_attn_unpadded_qkvpacked_func(
|
53 |
+
qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
|
54 |
+
softmax_scale=self.softmax_scale, causal=causal
|
55 |
+
)
|
56 |
+
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
|
57 |
+
else:
|
58 |
+
nheads = qkv.shape[-2]
|
59 |
+
x = rearrange(qkv, 'b s three h d -> b s (three h d)')
|
60 |
+
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
|
61 |
+
x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
|
62 |
+
output_unpad = flash_attn_unpadded_qkvpacked_func(
|
63 |
+
x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
|
64 |
+
softmax_scale=self.softmax_scale, causal=causal
|
65 |
+
)
|
66 |
+
output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
|
67 |
+
indices, batch_size, seqlen),
|
68 |
+
'b s (h d) -> b s h d', h=nheads)
|
69 |
+
else:
|
70 |
+
assert max_s is not None
|
71 |
+
output = flash_attn_unpadded_qkvpacked_func(
|
72 |
+
qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
|
73 |
+
softmax_scale=self.softmax_scale, causal=causal
|
74 |
+
)
|
75 |
+
|
76 |
+
return output, None
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:22f87624b94d869df9a619445357660b4c290970ba977f14e0a1ea2b46da4fda
|
3 |
+
size 608059320
|
modeling_intern_vit.py
ADDED
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# InternVL
|
3 |
+
# Copyright (c) 2023 OpenGVLab
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# --------------------------------------------------------
|
6 |
+
from typing import Optional, Tuple, Union
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torch.utils.checkpoint
|
11 |
+
from einops import rearrange
|
12 |
+
from timm.models.layers import DropPath
|
13 |
+
from torch import nn
|
14 |
+
from transformers.activations import ACT2FN
|
15 |
+
from transformers.modeling_outputs import (BaseModelOutput,
|
16 |
+
BaseModelOutputWithPooling)
|
17 |
+
from transformers.modeling_utils import PreTrainedModel
|
18 |
+
from transformers.utils import logging
|
19 |
+
|
20 |
+
from .configuration_intern_vit import InternVisionConfig
|
21 |
+
|
22 |
+
try:
|
23 |
+
from .flash_attention import FlashAttention
|
24 |
+
has_flash_attn = True
|
25 |
+
except:
|
26 |
+
print('FlashAttention is not installed.')
|
27 |
+
has_flash_attn = False
|
28 |
+
|
29 |
+
|
30 |
+
logger = logging.get_logger(__name__)
|
31 |
+
|
32 |
+
|
33 |
+
class InternRMSNorm(nn.Module):
|
34 |
+
def __init__(self, hidden_size, eps=1e-6):
|
35 |
+
super().__init__()
|
36 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
37 |
+
self.variance_epsilon = eps
|
38 |
+
|
39 |
+
def forward(self, hidden_states):
|
40 |
+
input_dtype = hidden_states.dtype
|
41 |
+
hidden_states = hidden_states.to(torch.float32)
|
42 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
43 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
44 |
+
return self.weight * hidden_states.to(input_dtype)
|
45 |
+
|
46 |
+
|
47 |
+
try:
|
48 |
+
from apex.normalization import FusedRMSNorm
|
49 |
+
|
50 |
+
InternRMSNorm = FusedRMSNorm # noqa
|
51 |
+
|
52 |
+
logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
|
53 |
+
except ImportError:
|
54 |
+
# using the normal InternRMSNorm
|
55 |
+
pass
|
56 |
+
except Exception:
|
57 |
+
logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
|
58 |
+
pass
|
59 |
+
|
60 |
+
|
61 |
+
NORM2FN = {
|
62 |
+
'rms_norm': InternRMSNorm,
|
63 |
+
'layer_norm': nn.LayerNorm,
|
64 |
+
}
|
65 |
+
|
66 |
+
|
67 |
+
class InternVisionEmbeddings(nn.Module):
|
68 |
+
def __init__(self, config: InternVisionConfig):
|
69 |
+
super().__init__()
|
70 |
+
self.config = config
|
71 |
+
self.embed_dim = config.hidden_size
|
72 |
+
self.image_size = config.image_size
|
73 |
+
self.patch_size = config.patch_size
|
74 |
+
|
75 |
+
self.class_embedding = nn.Parameter(
|
76 |
+
torch.randn(1, 1, self.embed_dim),
|
77 |
+
)
|
78 |
+
|
79 |
+
self.patch_embedding = nn.Conv2d(
|
80 |
+
in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
|
81 |
+
)
|
82 |
+
|
83 |
+
self.num_patches = (self.image_size // self.patch_size) ** 2
|
84 |
+
self.num_positions = self.num_patches + 1
|
85 |
+
|
86 |
+
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
|
87 |
+
|
88 |
+
def _get_pos_embed(self, pos_embed, H, W):
|
89 |
+
target_dtype = pos_embed.dtype
|
90 |
+
pos_embed = pos_embed.float().reshape(
|
91 |
+
1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
|
92 |
+
pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False).\
|
93 |
+
reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)
|
94 |
+
return pos_embed
|
95 |
+
|
96 |
+
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
97 |
+
target_dtype = self.patch_embedding.weight.dtype
|
98 |
+
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, channel, width, height]
|
99 |
+
batch_size, _, height, width = patch_embeds.shape
|
100 |
+
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
101 |
+
class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
|
102 |
+
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
103 |
+
position_embedding = torch.cat([
|
104 |
+
self.position_embedding[:, :1, :],
|
105 |
+
self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)
|
106 |
+
], dim=1)
|
107 |
+
embeddings = embeddings + position_embedding.to(target_dtype)
|
108 |
+
return embeddings
|
109 |
+
|
110 |
+
|
111 |
+
class InternAttention(nn.Module):
|
112 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
113 |
+
|
114 |
+
def __init__(self, config: InternVisionConfig):
|
115 |
+
super().__init__()
|
116 |
+
self.config = config
|
117 |
+
self.embed_dim = config.hidden_size
|
118 |
+
self.num_heads = config.num_attention_heads
|
119 |
+
self.use_flash_attn = config.use_flash_attn and has_flash_attn
|
120 |
+
if config.use_flash_attn and not has_flash_attn:
|
121 |
+
print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
|
122 |
+
self.head_dim = self.embed_dim // self.num_heads
|
123 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
124 |
+
raise ValueError(
|
125 |
+
f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
|
126 |
+
f' {self.num_heads}).'
|
127 |
+
)
|
128 |
+
|
129 |
+
self.scale = self.head_dim ** -0.5
|
130 |
+
self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
|
131 |
+
self.attn_drop = nn.Dropout(config.attention_dropout)
|
132 |
+
self.proj_drop = nn.Dropout(config.dropout)
|
133 |
+
|
134 |
+
self.qk_normalization = config.qk_normalization
|
135 |
+
|
136 |
+
if self.qk_normalization:
|
137 |
+
self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
138 |
+
self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
139 |
+
|
140 |
+
if self.use_flash_attn:
|
141 |
+
self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
|
142 |
+
self.proj = nn.Linear(self.embed_dim, self.embed_dim)
|
143 |
+
|
144 |
+
def _naive_attn(self, x):
|
145 |
+
B, N, C = x.shape
|
146 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
147 |
+
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
148 |
+
|
149 |
+
if self.qk_normalization:
|
150 |
+
B_, H_, N_, D_ = q.shape
|
151 |
+
q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
|
152 |
+
k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
|
153 |
+
|
154 |
+
attn = ((q * self.scale) @ k.transpose(-2, -1))
|
155 |
+
attn = attn.softmax(dim=-1)
|
156 |
+
attn = self.attn_drop(attn)
|
157 |
+
|
158 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
159 |
+
x = self.proj(x)
|
160 |
+
x = self.proj_drop(x)
|
161 |
+
return x
|
162 |
+
|
163 |
+
def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
|
164 |
+
qkv = self.qkv(x)
|
165 |
+
qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
|
166 |
+
|
167 |
+
if self.qk_normalization:
|
168 |
+
q, k, v = qkv.unbind(2)
|
169 |
+
q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
|
170 |
+
k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
|
171 |
+
qkv = torch.stack([q, k, v], dim=2)
|
172 |
+
|
173 |
+
context, _ = self.inner_attn(
|
174 |
+
qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
|
175 |
+
)
|
176 |
+
outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
|
177 |
+
outs = self.proj_drop(outs)
|
178 |
+
return outs
|
179 |
+
|
180 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
181 |
+
x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
|
182 |
+
return x
|
183 |
+
|
184 |
+
|
185 |
+
class InternMLP(nn.Module):
|
186 |
+
def __init__(self, config: InternVisionConfig):
|
187 |
+
super().__init__()
|
188 |
+
self.config = config
|
189 |
+
self.act = ACT2FN[config.hidden_act]
|
190 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
191 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
192 |
+
|
193 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
194 |
+
hidden_states = self.fc1(hidden_states)
|
195 |
+
hidden_states = self.act(hidden_states)
|
196 |
+
hidden_states = self.fc2(hidden_states)
|
197 |
+
return hidden_states
|
198 |
+
|
199 |
+
|
200 |
+
class InternVisionEncoderLayer(nn.Module):
|
201 |
+
def __init__(self, config: InternVisionConfig, drop_path_rate: float):
|
202 |
+
super().__init__()
|
203 |
+
self.embed_dim = config.hidden_size
|
204 |
+
self.intermediate_size = config.intermediate_size
|
205 |
+
self.norm_type = config.norm_type
|
206 |
+
|
207 |
+
self.attn = InternAttention(config)
|
208 |
+
self.mlp = InternMLP(config)
|
209 |
+
self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
|
210 |
+
self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
|
211 |
+
|
212 |
+
self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
|
213 |
+
self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
|
214 |
+
self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
215 |
+
self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
216 |
+
|
217 |
+
def forward(
|
218 |
+
self,
|
219 |
+
hidden_states: torch.Tensor,
|
220 |
+
) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
|
221 |
+
"""
|
222 |
+
Args:
|
223 |
+
hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
224 |
+
"""
|
225 |
+
hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
|
226 |
+
|
227 |
+
hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
|
228 |
+
|
229 |
+
return hidden_states
|
230 |
+
|
231 |
+
|
232 |
+
class InternVisionEncoder(nn.Module):
|
233 |
+
"""
|
234 |
+
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
235 |
+
[`InternEncoderLayer`].
|
236 |
+
|
237 |
+
Args:
|
238 |
+
config (`InternConfig`):
|
239 |
+
The corresponding vision configuration for the `InternEncoder`.
|
240 |
+
"""
|
241 |
+
|
242 |
+
def __init__(self, config: InternVisionConfig):
|
243 |
+
super().__init__()
|
244 |
+
self.config = config
|
245 |
+
# stochastic depth decay rule
|
246 |
+
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
|
247 |
+
self.layers = nn.ModuleList([
|
248 |
+
InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
|
249 |
+
self.gradient_checkpointing = True
|
250 |
+
|
251 |
+
def forward(
|
252 |
+
self,
|
253 |
+
inputs_embeds,
|
254 |
+
output_hidden_states: Optional[bool] = None,
|
255 |
+
return_dict: Optional[bool] = None,
|
256 |
+
) -> Union[Tuple, BaseModelOutput]:
|
257 |
+
r"""
|
258 |
+
Args:
|
259 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
260 |
+
Embedded representation of the inputs. Should be float, not int tokens.
|
261 |
+
output_hidden_states (`bool`, *optional*):
|
262 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
263 |
+
for more detail.
|
264 |
+
return_dict (`bool`, *optional*):
|
265 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
266 |
+
"""
|
267 |
+
output_hidden_states = (
|
268 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
269 |
+
)
|
270 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
271 |
+
|
272 |
+
encoder_states = () if output_hidden_states else None
|
273 |
+
hidden_states = inputs_embeds
|
274 |
+
|
275 |
+
for idx, encoder_layer in enumerate(self.layers):
|
276 |
+
if output_hidden_states:
|
277 |
+
encoder_states = encoder_states + (hidden_states,)
|
278 |
+
if self.gradient_checkpointing and self.training:
|
279 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
280 |
+
encoder_layer,
|
281 |
+
hidden_states)
|
282 |
+
else:
|
283 |
+
layer_outputs = encoder_layer(
|
284 |
+
hidden_states,
|
285 |
+
)
|
286 |
+
hidden_states = layer_outputs
|
287 |
+
|
288 |
+
if output_hidden_states:
|
289 |
+
encoder_states = encoder_states + (hidden_states,)
|
290 |
+
|
291 |
+
if not return_dict:
|
292 |
+
return tuple(v for v in [hidden_states, encoder_states] if v is not None)
|
293 |
+
return BaseModelOutput(
|
294 |
+
last_hidden_state=hidden_states, hidden_states=encoder_states
|
295 |
+
)
|
296 |
+
|
297 |
+
|
298 |
+
class InternVisionModel(PreTrainedModel):
|
299 |
+
main_input_name = 'pixel_values'
|
300 |
+
config_class = InternVisionConfig
|
301 |
+
_no_split_modules = ['InternVisionEncoderLayer']
|
302 |
+
|
303 |
+
def __init__(self, config: InternVisionConfig):
|
304 |
+
super().__init__(config)
|
305 |
+
self.config = config
|
306 |
+
|
307 |
+
self.embeddings = InternVisionEmbeddings(config)
|
308 |
+
self.encoder = InternVisionEncoder(config)
|
309 |
+
|
310 |
+
def resize_pos_embeddings(self, old_size, new_size, patch_size):
|
311 |
+
pos_emb = self.embeddings.position_embedding
|
312 |
+
_, num_positions, embed_dim = pos_emb.shape
|
313 |
+
cls_emb = pos_emb[:, :1, :]
|
314 |
+
pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
|
315 |
+
pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
|
316 |
+
pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
|
317 |
+
pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
|
318 |
+
self.embeddings.position_embedding = nn.Parameter(pos_emb)
|
319 |
+
self.embeddings.image_size = new_size
|
320 |
+
logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
|
321 |
+
|
322 |
+
def get_input_embeddings(self):
|
323 |
+
return self.embeddings
|
324 |
+
|
325 |
+
def forward(
|
326 |
+
self,
|
327 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
328 |
+
output_hidden_states: Optional[bool] = None,
|
329 |
+
return_dict: Optional[bool] = None,
|
330 |
+
pixel_embeds: Optional[torch.FloatTensor] = None,
|
331 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
332 |
+
output_hidden_states = (
|
333 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
334 |
+
)
|
335 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
336 |
+
|
337 |
+
if pixel_values is None and pixel_embeds is None:
|
338 |
+
raise ValueError('You have to specify pixel_values or pixel_embeds')
|
339 |
+
|
340 |
+
if pixel_embeds is not None:
|
341 |
+
hidden_states = pixel_embeds
|
342 |
+
else:
|
343 |
+
if len(pixel_values.shape) == 4:
|
344 |
+
hidden_states = self.embeddings(pixel_values)
|
345 |
+
else:
|
346 |
+
raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
|
347 |
+
encoder_outputs = self.encoder(
|
348 |
+
inputs_embeds=hidden_states,
|
349 |
+
output_hidden_states=output_hidden_states,
|
350 |
+
return_dict=return_dict,
|
351 |
+
)
|
352 |
+
last_hidden_state = encoder_outputs.last_hidden_state
|
353 |
+
pooled_output = last_hidden_state[:, 0, :]
|
354 |
+
|
355 |
+
if not return_dict:
|
356 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
357 |
+
|
358 |
+
return BaseModelOutputWithPooling(
|
359 |
+
last_hidden_state=last_hidden_state,
|
360 |
+
pooler_output=pooled_output,
|
361 |
+
hidden_states=encoder_outputs.hidden_states,
|
362 |
+
attentions=encoder_outputs.attentions,
|
363 |
+
)
|
preprocessor_config.json
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"crop_size": 448,
|
3 |
+
"do_center_crop": true,
|
4 |
+
"do_normalize": true,
|
5 |
+
"do_resize": true,
|
6 |
+
"feature_extractor_type": "CLIPFeatureExtractor",
|
7 |
+
"image_mean": [
|
8 |
+
0.485,
|
9 |
+
0.456,
|
10 |
+
0.406
|
11 |
+
],
|
12 |
+
"image_std": [
|
13 |
+
0.229,
|
14 |
+
0.224,
|
15 |
+
0.225
|
16 |
+
],
|
17 |
+
"resample": 3,
|
18 |
+
"size": 448
|
19 |
+
}
|
triton-test.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from triton_flash_atn import _attention
|
3 |
+
|
4 |
+
# Define dimensions
|
5 |
+
batch_size = 2
|
6 |
+
num_heads = 4
|
7 |
+
seq_len = 128
|
8 |
+
head_dim = 64
|
9 |
+
|
10 |
+
# Create random input tensors for Q, K, V
|
11 |
+
q = torch.randn(batch_size, num_heads, seq_len, head_dim,
|
12 |
+
dtype=torch.float16, device='cuda')
|
13 |
+
k = torch.randn(batch_size, num_heads, seq_len, head_dim,
|
14 |
+
dtype=torch.float16, device='cuda')
|
15 |
+
v = torch.randn(batch_size, num_heads, seq_len, head_dim,
|
16 |
+
dtype=torch.float16, device='cuda')
|
17 |
+
|
18 |
+
# Define whether the attention is causal and the scaling factor
|
19 |
+
causal = False
|
20 |
+
sm_scale = 1.0 / (head_dim ** 0.5)
|
21 |
+
|
22 |
+
# Apply flash attention
|
23 |
+
attention = _attention.apply
|
24 |
+
output = attention(q, k, v, causal, sm_scale)
|
25 |
+
|
26 |
+
print(output)
|
triton_bert_pading.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from einops import rearrange, repeat
|
6 |
+
|
7 |
+
|
8 |
+
class IndexFirstAxis(torch.autograd.Function):
|
9 |
+
@staticmethod
|
10 |
+
def forward(ctx, input, indices):
|
11 |
+
ctx.save_for_backward(indices)
|
12 |
+
assert input.ndim >= 2
|
13 |
+
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
|
14 |
+
second_dim = other_shape.numel()
|
15 |
+
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
16 |
+
# return input[indices]
|
17 |
+
return torch.gather(
|
18 |
+
rearrange(input, "b ... -> b (...)"), 0, repeat(indices,
|
19 |
+
"z -> z d", d=second_dim)
|
20 |
+
).reshape(-1, *other_shape)
|
21 |
+
|
22 |
+
@staticmethod
|
23 |
+
def backward(ctx, grad_output):
|
24 |
+
(indices,) = ctx.saved_tensors
|
25 |
+
assert grad_output.ndim >= 2
|
26 |
+
other_shape = grad_output.shape[1:]
|
27 |
+
grad_output = rearrange(grad_output, "b ... -> b (...)")
|
28 |
+
grad_input = torch.zeros(
|
29 |
+
[ctx.first_axis_dim, grad_output.shape[1]],
|
30 |
+
device=grad_output.device,
|
31 |
+
dtype=grad_output.dtype,
|
32 |
+
)
|
33 |
+
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
|
34 |
+
# grad_input[indices] = grad_output
|
35 |
+
grad_input.scatter_(0, repeat(indices, "z -> z d",
|
36 |
+
d=grad_output.shape[1]), grad_output)
|
37 |
+
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
|
38 |
+
|
39 |
+
|
40 |
+
index_first_axis = IndexFirstAxis.apply
|
41 |
+
|
42 |
+
|
43 |
+
class IndexPutFirstAxis(torch.autograd.Function):
|
44 |
+
@staticmethod
|
45 |
+
def forward(ctx, values, indices, first_axis_dim):
|
46 |
+
ctx.save_for_backward(indices)
|
47 |
+
assert indices.ndim == 1
|
48 |
+
assert values.ndim >= 2
|
49 |
+
output = torch.zeros(
|
50 |
+
first_axis_dim, *
|
51 |
+
values.shape[1:], device=values.device, dtype=values.dtype
|
52 |
+
)
|
53 |
+
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
|
54 |
+
output[indices] = values
|
55 |
+
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
|
56 |
+
return output
|
57 |
+
|
58 |
+
@staticmethod
|
59 |
+
def backward(ctx, grad_output):
|
60 |
+
(indices,) = ctx.saved_tensors
|
61 |
+
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
62 |
+
grad_values = grad_output[indices]
|
63 |
+
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
|
64 |
+
return grad_values, None, None
|
65 |
+
|
66 |
+
|
67 |
+
index_put_first_axis = IndexPutFirstAxis.apply
|
68 |
+
|
69 |
+
|
70 |
+
class IndexFirstAxisResidual(torch.autograd.Function):
|
71 |
+
@staticmethod
|
72 |
+
def forward(ctx, input, indices):
|
73 |
+
ctx.save_for_backward(indices)
|
74 |
+
assert input.ndim >= 2
|
75 |
+
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
|
76 |
+
second_dim = other_shape.numel()
|
77 |
+
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
78 |
+
output = input[indices]
|
79 |
+
# We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
|
80 |
+
# memory format to channel_first. In other words, input might not be contiguous.
|
81 |
+
# If we don't detach, Pytorch complains about output being a view and is being modified inplace
|
82 |
+
return output, input.detach()
|
83 |
+
|
84 |
+
@staticmethod
|
85 |
+
def backward(ctx, grad_output, grad_residual):
|
86 |
+
(indices,) = ctx.saved_tensors
|
87 |
+
assert grad_output.ndim >= 2
|
88 |
+
other_shape = grad_output.shape[1:]
|
89 |
+
assert grad_residual.shape[1:] == other_shape
|
90 |
+
grad_input = grad_residual
|
91 |
+
# grad_input[indices] += grad_output
|
92 |
+
indices = indices.reshape(
|
93 |
+
indices.shape[0], *((1,) * (grad_output.ndim - 1)))
|
94 |
+
indices = indices.expand_as(grad_output)
|
95 |
+
grad_input.scatter_add_(0, indices, grad_output)
|
96 |
+
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
|
97 |
+
|
98 |
+
|
99 |
+
index_first_axis_residual = IndexFirstAxisResidual.apply
|
100 |
+
|
101 |
+
|
102 |
+
def unpad_input(hidden_states, attention_mask):
|
103 |
+
"""
|
104 |
+
Arguments:
|
105 |
+
hidden_states: (batch, seqlen, ...)
|
106 |
+
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
|
107 |
+
Return:
|
108 |
+
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
|
109 |
+
indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
|
110 |
+
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
|
111 |
+
max_seqlen_in_batch: int
|
112 |
+
"""
|
113 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
114 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
115 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
116 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0,
|
117 |
+
dtype=torch.torch.int32), (1, 0))
|
118 |
+
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
|
119 |
+
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
|
120 |
+
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
|
121 |
+
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
|
122 |
+
# so we write custom forward and backward to make it a bit faster.
|
123 |
+
return (
|
124 |
+
index_first_axis(
|
125 |
+
rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
|
126 |
+
indices,
|
127 |
+
cu_seqlens,
|
128 |
+
max_seqlen_in_batch,
|
129 |
+
)
|
130 |
+
|
131 |
+
|
132 |
+
def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
|
133 |
+
"""
|
134 |
+
Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model).
|
135 |
+
The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
|
136 |
+
|
137 |
+
For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
|
138 |
+
```
|
139 |
+
[
|
140 |
+
[2, 3, 0, 0, 0, 0],
|
141 |
+
[3, 2, 0, 0, 0, 0],
|
142 |
+
[6, 0, 0, 0, 0, 0]
|
143 |
+
]
|
144 |
+
```
|
145 |
+
, which refers to the 3D-attention mask:
|
146 |
+
```
|
147 |
+
[
|
148 |
+
[
|
149 |
+
[1, 0, 0, 0, 0, 0],
|
150 |
+
[1, 1, 0, 0, 0, 0],
|
151 |
+
[0, 0, 1, 0, 0, 0],
|
152 |
+
[0, 0, 1, 1, 0, 0],
|
153 |
+
[0, 0, 1, 1, 1, 0],
|
154 |
+
[0, 0, 0, 0, 0, 1]
|
155 |
+
],
|
156 |
+
[
|
157 |
+
[1, 0, 0, 0, 0, 0],
|
158 |
+
[1, 1, 0, 0, 0, 0],
|
159 |
+
[1, 1, 1, 0, 0, 0],
|
160 |
+
[0, 0, 0, 1, 0, 0],
|
161 |
+
[0, 0, 0, 1, 1, 0],
|
162 |
+
[0, 0, 0, 0, 0, 1]
|
163 |
+
],
|
164 |
+
[
|
165 |
+
[1, 0, 0, 0, 0, 0],
|
166 |
+
[1, 1, 0, 0, 0, 0],
|
167 |
+
[1, 1, 1, 0, 0, 0],
|
168 |
+
[1, 1, 1, 1, 0, 0],
|
169 |
+
[1, 1, 1, 1, 1, 0],
|
170 |
+
[1, 1, 1, 1, 1, 1]
|
171 |
+
]
|
172 |
+
]
|
173 |
+
```.
|
174 |
+
|
175 |
+
Arguments:
|
176 |
+
hidden_states: (batch, seqlen, ...)
|
177 |
+
attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
|
178 |
+
Return:
|
179 |
+
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
|
180 |
+
indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
|
181 |
+
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
|
182 |
+
max_seqlen_in_batch: int
|
183 |
+
"""
|
184 |
+
length = attention_mask_in_length.sum(dim=-1)
|
185 |
+
seqlen = attention_mask_in_length.size(-1)
|
186 |
+
attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(
|
187 |
+
len(length), seqlen) < length.unsqueeze(1)
|
188 |
+
real_indices_idx = torch.nonzero(
|
189 |
+
attention_mask_in_length.flatten(), as_tuple=False).flatten()
|
190 |
+
seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
|
191 |
+
indices = torch.nonzero(attention_mask_2d.flatten(),
|
192 |
+
as_tuple=False).flatten()
|
193 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
194 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0,
|
195 |
+
dtype=torch.torch.int32), (1, 0))
|
196 |
+
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
|
197 |
+
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
|
198 |
+
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
|
199 |
+
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
|
200 |
+
# so we write custom forward and backward to make it a bit faster.
|
201 |
+
return (
|
202 |
+
index_first_axis(
|
203 |
+
rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
|
204 |
+
indices,
|
205 |
+
cu_seqlens,
|
206 |
+
max_seqlen_in_batch,
|
207 |
+
)
|
208 |
+
|
209 |
+
|
210 |
+
def pad_input(hidden_states, indices, batch, seqlen):
|
211 |
+
"""
|
212 |
+
Arguments:
|
213 |
+
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
|
214 |
+
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
|
215 |
+
batch: int, batch size for the padded sequence.
|
216 |
+
seqlen: int, maximum sequence length for the padded sequence.
|
217 |
+
Return:
|
218 |
+
hidden_states: (batch, seqlen, ...)
|
219 |
+
"""
|
220 |
+
dim = hidden_states.shape[-1]
|
221 |
+
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
|
222 |
+
# output[indices] = hidden_states
|
223 |
+
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
|
224 |
+
return rearrange(output, "(b s) ... -> b s ...", b=batch)
|
triton_flash_atn.py
ADDED
@@ -0,0 +1,654 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Fused Attention
|
3 |
+
===============
|
4 |
+
|
5 |
+
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
|
6 |
+
Credits: OpenAI kernel team
|
7 |
+
|
8 |
+
Extra Credits:
|
9 |
+
- Original flash attention paper (https://arxiv.org/abs/2205.14135)
|
10 |
+
- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf)
|
11 |
+
|
12 |
+
"""
|
13 |
+
|
14 |
+
import pytest
|
15 |
+
import torch
|
16 |
+
|
17 |
+
import triton
|
18 |
+
import triton.language as tl
|
19 |
+
|
20 |
+
# Pick the fp8 data type
|
21 |
+
|
22 |
+
# AMD E4M3B8
|
23 |
+
# Note: When picking this f8 data type, scaling is required when using f8
|
24 |
+
# for the second gemm
|
25 |
+
# TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz')
|
26 |
+
|
27 |
+
# AMD E5M2B16
|
28 |
+
TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz')
|
29 |
+
|
30 |
+
|
31 |
+
@triton.jit
|
32 |
+
def _attn_fwd_inner(acc, l_i, m_i, q,
|
33 |
+
K_block_ptr, V_block_ptr,
|
34 |
+
start_m,
|
35 |
+
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
|
36 |
+
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,
|
37 |
+
N_CTX,
|
38 |
+
pre_load_v: tl.constexpr):
|
39 |
+
# range of values handled by this stage
|
40 |
+
if STAGE == 1:
|
41 |
+
lo, hi = 0, start_m * BLOCK_M
|
42 |
+
elif STAGE == 2:
|
43 |
+
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
|
44 |
+
lo = tl.multiple_of(lo, BLOCK_M)
|
45 |
+
K_block_ptr = tl.advance(K_block_ptr, (0, lo))
|
46 |
+
V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
|
47 |
+
# causal = False
|
48 |
+
else:
|
49 |
+
lo, hi = 0, N_CTX
|
50 |
+
# loop over k, v and update accumulator
|
51 |
+
for start_n in range(lo, hi, BLOCK_N):
|
52 |
+
start_n = tl.multiple_of(start_n, BLOCK_N)
|
53 |
+
# -- compute qk ----
|
54 |
+
k = tl.load(K_block_ptr)
|
55 |
+
if pre_load_v:
|
56 |
+
v = tl.load(V_block_ptr)
|
57 |
+
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
58 |
+
if STAGE == 2:
|
59 |
+
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
|
60 |
+
qk = tl.where(mask, qk, float("-inf"))
|
61 |
+
qk += tl.dot(q, k)
|
62 |
+
m_ij = tl.maximum(m_i, tl.max(qk, 1))
|
63 |
+
qk = qk - m_ij[:, None]
|
64 |
+
p = tl.math.exp2(qk)
|
65 |
+
# -- update output accumulator --
|
66 |
+
alpha = tl.math.exp2(m_i - m_ij)
|
67 |
+
acc = acc * alpha[:, None]
|
68 |
+
if not pre_load_v:
|
69 |
+
v = tl.load(V_block_ptr)
|
70 |
+
acc += tl.dot(p.to(v.dtype), v)
|
71 |
+
# -- update m_i and l_i
|
72 |
+
l_ij = tl.sum(p, 1)
|
73 |
+
l_i = l_i * alpha + l_ij
|
74 |
+
# update m_i and l_i
|
75 |
+
m_i = m_ij
|
76 |
+
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
|
77 |
+
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
|
78 |
+
return acc, l_i, m_i
|
79 |
+
|
80 |
+
|
81 |
+
# We don't run auto-tuning everytime to keep the tutorial fast. Uncommenting
|
82 |
+
# the code below and commenting out the equivalent parameters is convenient for
|
83 |
+
# re-tuning.
|
84 |
+
@triton.autotune(
|
85 |
+
configs=[
|
86 |
+
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 16, 'waves_per_eu': 2,
|
87 |
+
'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=2),
|
88 |
+
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 16, 'waves_per_eu': 2,
|
89 |
+
'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=2),
|
90 |
+
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2,
|
91 |
+
'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=1),
|
92 |
+
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2,
|
93 |
+
'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=1),
|
94 |
+
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'waves_per_eu': 2,
|
95 |
+
'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=2),
|
96 |
+
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 3,
|
97 |
+
'slice_k_tile': 0, 'pre_load_v': True}, num_stages=1, num_warps=1),
|
98 |
+
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 3,
|
99 |
+
'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=1),
|
100 |
+
],
|
101 |
+
key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'],
|
102 |
+
)
|
103 |
+
@triton.jit
|
104 |
+
def _attn_fwd(Q, K, V, sm_scale, M, Out,
|
105 |
+
stride_qz, stride_qh, stride_qm, stride_qk,
|
106 |
+
stride_kz, stride_kh, stride_kn, stride_kk,
|
107 |
+
stride_vz, stride_vh, stride_vk, stride_vn,
|
108 |
+
stride_oz, stride_oh, stride_om, stride_on,
|
109 |
+
Z, H,
|
110 |
+
N_CTX,
|
111 |
+
BLOCK_DMODEL: tl.constexpr,
|
112 |
+
STAGE: tl.constexpr,
|
113 |
+
BLOCK_M: tl.constexpr,
|
114 |
+
BLOCK_N: tl.constexpr,
|
115 |
+
pre_load_v: tl.constexpr,
|
116 |
+
):
|
117 |
+
start_m = tl.program_id(0)
|
118 |
+
off_hz = tl.program_id(1)
|
119 |
+
qvk_offset = off_hz * stride_qh
|
120 |
+
|
121 |
+
# block pointers
|
122 |
+
Q_block_ptr = tl.make_block_ptr(
|
123 |
+
base=Q + qvk_offset,
|
124 |
+
shape=(N_CTX, BLOCK_DMODEL),
|
125 |
+
strides=(stride_qm, stride_qk),
|
126 |
+
offsets=(start_m * BLOCK_M, 0),
|
127 |
+
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
128 |
+
order=(1, 0),
|
129 |
+
)
|
130 |
+
V_block_ptr = tl.make_block_ptr(
|
131 |
+
base=V + qvk_offset,
|
132 |
+
shape=(N_CTX, BLOCK_DMODEL),
|
133 |
+
strides=(stride_vk, stride_vn),
|
134 |
+
offsets=(0, 0),
|
135 |
+
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
136 |
+
order=(1, 0),
|
137 |
+
)
|
138 |
+
K_block_ptr = tl.make_block_ptr(
|
139 |
+
base=K + qvk_offset,
|
140 |
+
shape=(BLOCK_DMODEL, N_CTX),
|
141 |
+
strides=(stride_kk, stride_kn),
|
142 |
+
offsets=(0, 0),
|
143 |
+
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
144 |
+
order=(0, 1),
|
145 |
+
)
|
146 |
+
O_block_ptr = tl.make_block_ptr(
|
147 |
+
base=Out + qvk_offset,
|
148 |
+
shape=(N_CTX, BLOCK_DMODEL),
|
149 |
+
strides=(stride_om, stride_on),
|
150 |
+
offsets=(start_m * BLOCK_M, 0),
|
151 |
+
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
152 |
+
order=(1, 0),
|
153 |
+
)
|
154 |
+
# initialize offsets
|
155 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
156 |
+
offs_n = tl.arange(0, BLOCK_N)
|
157 |
+
# initialize pointer to m and l
|
158 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
159 |
+
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
|
160 |
+
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
161 |
+
# scale sm_scale by log_2(e) and use
|
162 |
+
# 2^x instead of exp in the loop because CSE and LICM
|
163 |
+
# don't work as expected with `exp` in the loop
|
164 |
+
qk_scale = sm_scale * 1.44269504
|
165 |
+
# load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs
|
166 |
+
q = tl.load(Q_block_ptr)
|
167 |
+
q = (q * qk_scale).to(q.dtype)
|
168 |
+
# stage 1: off-band
|
169 |
+
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
|
170 |
+
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
|
171 |
+
if STAGE & 1:
|
172 |
+
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
|
173 |
+
start_m,
|
174 |
+
BLOCK_M, BLOCK_DMODEL, BLOCK_N,
|
175 |
+
4 - STAGE, offs_m, offs_n, N_CTX,
|
176 |
+
pre_load_v,
|
177 |
+
)
|
178 |
+
# stage 2: on-band
|
179 |
+
if STAGE & 2:
|
180 |
+
# barrier makes it easier for compielr to schedule the
|
181 |
+
# two loops independently
|
182 |
+
tl.debug_barrier()
|
183 |
+
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
|
184 |
+
start_m,
|
185 |
+
BLOCK_M, BLOCK_DMODEL, BLOCK_N,
|
186 |
+
2, offs_m, offs_n, N_CTX,
|
187 |
+
pre_load_v,
|
188 |
+
)
|
189 |
+
# epilogue
|
190 |
+
# write back m
|
191 |
+
acc = acc / l_i[:, None]
|
192 |
+
m_ptrs = M + off_hz * N_CTX + offs_m
|
193 |
+
tl.store(m_ptrs, m_i + tl.math.log2(l_i))
|
194 |
+
tl.store(O_block_ptr, acc.to(Out.type.element_ty))
|
195 |
+
|
196 |
+
|
197 |
+
@triton.jit
|
198 |
+
def _attn_bwd_preprocess(O, DO,
|
199 |
+
Delta,
|
200 |
+
Z, H, N_CTX,
|
201 |
+
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
|
202 |
+
):
|
203 |
+
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
204 |
+
off_hz = tl.program_id(1)
|
205 |
+
off_n = tl.arange(0, D_HEAD)
|
206 |
+
o = tl.load(O + off_hz * D_HEAD * N_CTX +
|
207 |
+
off_m[:, None] * D_HEAD + off_n[None, :])
|
208 |
+
do = tl.load(DO + off_hz * D_HEAD * N_CTX +
|
209 |
+
off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
|
210 |
+
delta = tl.sum(o * do, axis=1)
|
211 |
+
tl.store(Delta + off_hz * N_CTX + off_m, delta)
|
212 |
+
|
213 |
+
|
214 |
+
# The main inner-loop logic for computing dK and dV.
|
215 |
+
@triton.jit
|
216 |
+
def _attn_bwd_dkdv(dk, dv,
|
217 |
+
Q, k, v, sm_scale,
|
218 |
+
DO,
|
219 |
+
M, D,
|
220 |
+
# shared by Q/K/V/DO.
|
221 |
+
stride_tok, stride_d,
|
222 |
+
H, N_CTX, BLOCK_M1: tl.constexpr,
|
223 |
+
BLOCK_N1: tl.constexpr,
|
224 |
+
BLOCK_DMODEL: tl.constexpr,
|
225 |
+
# Filled in by the wrapper.
|
226 |
+
start_n, start_m, num_steps,
|
227 |
+
MASK: tl.constexpr):
|
228 |
+
offs_m = start_m + tl.arange(0, BLOCK_M1)
|
229 |
+
offs_n = start_n + tl.arange(0, BLOCK_N1)
|
230 |
+
offs_k = tl.arange(0, BLOCK_DMODEL)
|
231 |
+
QT_block_ptr = tl.make_block_ptr(
|
232 |
+
base=Q,
|
233 |
+
shape=(BLOCK_DMODEL, N_CTX),
|
234 |
+
strides=(stride_d, stride_tok),
|
235 |
+
offsets=(0, start_m),
|
236 |
+
block_shape=(BLOCK_DMODEL, BLOCK_M1),
|
237 |
+
order=(0, 1)
|
238 |
+
)
|
239 |
+
DO_block_ptr = tl.make_block_ptr(
|
240 |
+
base=DO,
|
241 |
+
shape=(N_CTX, BLOCK_DMODEL),
|
242 |
+
strides=(stride_tok, stride_d),
|
243 |
+
offsets=(start_m, 0),
|
244 |
+
block_shape=(BLOCK_M1, BLOCK_DMODEL),
|
245 |
+
order=(1, 0)
|
246 |
+
)
|
247 |
+
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
|
248 |
+
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
|
249 |
+
curr_m = start_m
|
250 |
+
step_m = BLOCK_M1
|
251 |
+
for blk_idx in range(num_steps):
|
252 |
+
qT = tl.load(QT_block_ptr)
|
253 |
+
# Load m before computing qk to reduce pipeline stall.
|
254 |
+
offs_m = curr_m + tl.arange(0, BLOCK_M1)
|
255 |
+
m = tl.load(M + offs_m)
|
256 |
+
qkT = tl.dot(k, qT)
|
257 |
+
pT = tl.math.exp2(qkT - m[None, :])
|
258 |
+
# Autoregressive masking.
|
259 |
+
if MASK:
|
260 |
+
mask = (offs_m[None, :] >= offs_n[:, None])
|
261 |
+
pT = tl.where(mask, pT, 0.0)
|
262 |
+
do = tl.load(DO_block_ptr)
|
263 |
+
# Compute dV.
|
264 |
+
ppT = pT
|
265 |
+
ppT = ppT.to(tl.float16)
|
266 |
+
dv += tl.dot(ppT, do)
|
267 |
+
# D (= delta) is pre-divided by ds_scale.
|
268 |
+
Di = tl.load(D + offs_m)
|
269 |
+
# Compute dP and dS.
|
270 |
+
dpT = tl.dot(v, tl.trans(do))
|
271 |
+
dsT = pT * (dpT - Di[None, :])
|
272 |
+
dsT = dsT.to(tl.float16)
|
273 |
+
dk += tl.dot(dsT, tl.trans(qT))
|
274 |
+
# Increment pointers.
|
275 |
+
curr_m += step_m
|
276 |
+
QT_block_ptr = tl.advance(QT_block_ptr, (0, step_m))
|
277 |
+
DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0))
|
278 |
+
return dk, dv
|
279 |
+
|
280 |
+
|
281 |
+
# the main inner-loop logic for computing dQ
|
282 |
+
@triton.jit
|
283 |
+
def _attn_bwd_dq(dq, q, K, V,
|
284 |
+
do, m, D,
|
285 |
+
# shared by Q/K/V/DO.
|
286 |
+
stride_tok, stride_d,
|
287 |
+
H, N_CTX,
|
288 |
+
BLOCK_M2: tl.constexpr,
|
289 |
+
BLOCK_N2: tl.constexpr,
|
290 |
+
BLOCK_DMODEL: tl.constexpr,
|
291 |
+
# Filled in by the wrapper.
|
292 |
+
start_m, start_n, num_steps,
|
293 |
+
MASK: tl.constexpr):
|
294 |
+
offs_m = start_m + tl.arange(0, BLOCK_M2)
|
295 |
+
offs_n = start_n + tl.arange(0, BLOCK_N2)
|
296 |
+
offs_k = tl.arange(0, BLOCK_DMODEL)
|
297 |
+
KT_block_ptr = tl.make_block_ptr(
|
298 |
+
base=K,
|
299 |
+
shape=(BLOCK_DMODEL, N_CTX),
|
300 |
+
strides=(stride_d, stride_tok),
|
301 |
+
offsets=(0, start_n),
|
302 |
+
block_shape=(BLOCK_DMODEL, BLOCK_N2),
|
303 |
+
order=(0, 1)
|
304 |
+
)
|
305 |
+
VT_block_ptr = tl.make_block_ptr(
|
306 |
+
base=V,
|
307 |
+
shape=(BLOCK_DMODEL, N_CTX),
|
308 |
+
strides=(stride_d, stride_tok),
|
309 |
+
offsets=(0, start_n),
|
310 |
+
block_shape=(BLOCK_DMODEL, BLOCK_N2),
|
311 |
+
order=(0, 1)
|
312 |
+
)
|
313 |
+
# D (= delta) is pre-divided by ds_scale.
|
314 |
+
Di = tl.load(D + offs_m)
|
315 |
+
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
|
316 |
+
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
|
317 |
+
curr_n = start_n
|
318 |
+
step_n = BLOCK_N2
|
319 |
+
for blk_idx in range(num_steps):
|
320 |
+
kT = tl.load(KT_block_ptr)
|
321 |
+
qk = tl.dot(q, kT)
|
322 |
+
p = tl.math.exp2(qk - m)
|
323 |
+
# Autoregressive masking.
|
324 |
+
if MASK:
|
325 |
+
offs_n = curr_n + tl.arange(0, BLOCK_N2)
|
326 |
+
mask = (offs_m[:, None] >= offs_n[None, :])
|
327 |
+
p = tl.where(mask, p, 0.0)
|
328 |
+
# Compute dP and dS.
|
329 |
+
vT = tl.load(VT_block_ptr)
|
330 |
+
dp = tl.dot(do, vT).to(tl.float32)
|
331 |
+
ds = p * (dp - Di[:, None])
|
332 |
+
ds = ds.to(tl.float16)
|
333 |
+
# Compute dQ.
|
334 |
+
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
|
335 |
+
dq += tl.dot(ds, tl.trans(kT))
|
336 |
+
# Increment pointers.
|
337 |
+
curr_n += step_n
|
338 |
+
KT_block_ptr = tl.advance(KT_block_ptr, (0, step_n))
|
339 |
+
VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n))
|
340 |
+
return dq
|
341 |
+
|
342 |
+
|
343 |
+
@triton.autotune(
|
344 |
+
configs=[
|
345 |
+
triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1},
|
346 |
+
num_stages=1, num_warps=4),
|
347 |
+
triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
|
348 |
+
num_stages=1, num_warps=4),
|
349 |
+
triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1},
|
350 |
+
num_stages=1, num_warps=4),
|
351 |
+
triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2},
|
352 |
+
num_stages=1, num_warps=4),
|
353 |
+
triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1},
|
354 |
+
num_stages=1, num_warps=4),
|
355 |
+
triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2},
|
356 |
+
num_stages=1, num_warps=4),
|
357 |
+
triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1},
|
358 |
+
num_stages=1, num_warps=4),
|
359 |
+
triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
|
360 |
+
num_stages=1, num_warps=4),
|
361 |
+
triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
|
362 |
+
num_stages=1, num_warps=8),
|
363 |
+
],
|
364 |
+
key=['H', 'N_CTX', 'BLOCK_DMODEL'],
|
365 |
+
)
|
366 |
+
@triton.jit
|
367 |
+
def _attn_bwd(Q, K, V, sm_scale,
|
368 |
+
DO,
|
369 |
+
DQ, DK, DV,
|
370 |
+
M, D,
|
371 |
+
# shared by Q/K/V/DO.
|
372 |
+
stride_z, stride_h, stride_tok, stride_d,
|
373 |
+
# H = 16, N_CTX = 1024
|
374 |
+
H, N_CTX,
|
375 |
+
BLOCK_DMODEL: tl.constexpr,
|
376 |
+
BLOCK_M1: tl.constexpr,
|
377 |
+
BLOCK_N1: tl.constexpr,
|
378 |
+
BLOCK_M2: tl.constexpr,
|
379 |
+
BLOCK_N2: tl.constexpr,
|
380 |
+
BLK_SLICE_FACTOR: tl.constexpr):
|
381 |
+
LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
|
382 |
+
|
383 |
+
bhid = tl.program_id(2)
|
384 |
+
off_chz = (bhid * N_CTX).to(tl.int64)
|
385 |
+
adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
|
386 |
+
pid = tl.program_id(0)
|
387 |
+
|
388 |
+
# offset pointers for batch/head
|
389 |
+
Q += adj
|
390 |
+
K += adj
|
391 |
+
V += adj
|
392 |
+
DO += adj
|
393 |
+
DQ += adj
|
394 |
+
DK += adj
|
395 |
+
DV += adj
|
396 |
+
M += off_chz
|
397 |
+
D += off_chz
|
398 |
+
|
399 |
+
offs_k = tl.arange(0, BLOCK_DMODEL)
|
400 |
+
|
401 |
+
start_n = pid * BLOCK_N1
|
402 |
+
# This assignment is important. It is what allows us to pick the diagonal
|
403 |
+
# blocks. Later, when we want to do the lower triangular, we update start_m
|
404 |
+
# after the first dkdv call.
|
405 |
+
start_m = start_n
|
406 |
+
|
407 |
+
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
|
408 |
+
offs_n = start_n + tl.arange(0, BLOCK_N1)
|
409 |
+
|
410 |
+
dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
|
411 |
+
dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
|
412 |
+
|
413 |
+
K_block_ptr = tl.make_block_ptr(
|
414 |
+
base=K,
|
415 |
+
shape=(N_CTX, BLOCK_DMODEL),
|
416 |
+
strides=(stride_tok, stride_d),
|
417 |
+
offsets=(start_n, 0),
|
418 |
+
block_shape=(BLOCK_N1, BLOCK_DMODEL),
|
419 |
+
order=(1, 0),
|
420 |
+
)
|
421 |
+
V_block_ptr = tl.make_block_ptr(
|
422 |
+
base=V,
|
423 |
+
shape=(N_CTX, BLOCK_DMODEL),
|
424 |
+
strides=(stride_tok, stride_d),
|
425 |
+
offsets=(start_n, 0),
|
426 |
+
block_shape=(BLOCK_N1, BLOCK_DMODEL),
|
427 |
+
order=(1, 0),
|
428 |
+
)
|
429 |
+
|
430 |
+
# load K and V: they stay in SRAM throughout the inner loop for dkdv.
|
431 |
+
k = tl.load(K_block_ptr)
|
432 |
+
v = tl.load(V_block_ptr)
|
433 |
+
|
434 |
+
num_steps = BLOCK_N1 // MASK_BLOCK_M1
|
435 |
+
|
436 |
+
dk, dv = _attn_bwd_dkdv(dk, dv,
|
437 |
+
Q, k, v, sm_scale,
|
438 |
+
DO,
|
439 |
+
M, D,
|
440 |
+
stride_tok, stride_d,
|
441 |
+
H, N_CTX,
|
442 |
+
MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
|
443 |
+
start_n, start_m, num_steps,
|
444 |
+
MASK=True
|
445 |
+
)
|
446 |
+
|
447 |
+
start_m += num_steps * MASK_BLOCK_M1
|
448 |
+
num_steps = (N_CTX - start_m) // BLOCK_M1
|
449 |
+
|
450 |
+
# Compute dK and dV for non-masked blocks.
|
451 |
+
dk, dv = _attn_bwd_dkdv(
|
452 |
+
dk, dv,
|
453 |
+
Q, k, v, sm_scale,
|
454 |
+
DO,
|
455 |
+
M, D,
|
456 |
+
stride_tok, stride_d,
|
457 |
+
H, N_CTX,
|
458 |
+
BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
|
459 |
+
start_n, start_m, num_steps,
|
460 |
+
MASK=False
|
461 |
+
)
|
462 |
+
|
463 |
+
DV_block_ptrs = tl.make_block_ptr(
|
464 |
+
base=DV,
|
465 |
+
shape=(N_CTX, BLOCK_DMODEL),
|
466 |
+
strides=(stride_tok, stride_d),
|
467 |
+
offsets=(start_n, 0),
|
468 |
+
block_shape=(BLOCK_N1, BLOCK_DMODEL),
|
469 |
+
order=(1, 0)
|
470 |
+
)
|
471 |
+
tl.store(DV_block_ptrs, dv.to(tl.float16))
|
472 |
+
|
473 |
+
# Write back dK.
|
474 |
+
dk *= sm_scale
|
475 |
+
DK_block_ptrs = tl.make_block_ptr(
|
476 |
+
base=DK,
|
477 |
+
shape=(N_CTX, BLOCK_DMODEL),
|
478 |
+
strides=(stride_tok, stride_d),
|
479 |
+
offsets=(start_n, 0),
|
480 |
+
block_shape=(BLOCK_N1, BLOCK_DMODEL),
|
481 |
+
order=(1, 0)
|
482 |
+
)
|
483 |
+
tl.store(DK_block_ptrs, dk.to(tl.float16))
|
484 |
+
|
485 |
+
# THIS BLOCK DOES DQ:
|
486 |
+
start_m = pid * BLOCK_M2
|
487 |
+
end_n = start_m + BLOCK_M2
|
488 |
+
|
489 |
+
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
|
490 |
+
offs_m = start_m + tl.arange(0, BLOCK_M2)
|
491 |
+
|
492 |
+
Q_block_ptr = tl.make_block_ptr(
|
493 |
+
base=Q,
|
494 |
+
shape=(N_CTX, BLOCK_DMODEL),
|
495 |
+
strides=(stride_tok, stride_d),
|
496 |
+
offsets=(start_m, 0),
|
497 |
+
block_shape=(BLOCK_M2, BLOCK_DMODEL),
|
498 |
+
order=(1, 0)
|
499 |
+
)
|
500 |
+
|
501 |
+
DO_block_ptr = tl.make_block_ptr(
|
502 |
+
base=DO,
|
503 |
+
shape=(N_CTX, BLOCK_DMODEL),
|
504 |
+
strides=(stride_tok, stride_d),
|
505 |
+
offsets=(start_m, 0),
|
506 |
+
block_shape=(BLOCK_M2, BLOCK_DMODEL),
|
507 |
+
order=(1, 0)
|
508 |
+
)
|
509 |
+
q = tl.load(Q_block_ptr)
|
510 |
+
do = tl.load(DO_block_ptr)
|
511 |
+
dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
|
512 |
+
|
513 |
+
m = tl.load(M + offs_m)
|
514 |
+
m = m[:, None]
|
515 |
+
|
516 |
+
# Compute dQ for masked (diagonal) blocks.
|
517 |
+
# NOTE: This code scans each row of QK^T backward (from right to left,
|
518 |
+
# but inside each call to _attn_bwd_dq, from left to right), but that's
|
519 |
+
# not due to anything important. I just wanted to reuse the loop
|
520 |
+
# structure for dK & dV above as much as possible.
|
521 |
+
num_steps = BLOCK_M2 // MASK_BLOCK_N2
|
522 |
+
dq = _attn_bwd_dq(dq, q, K, V,
|
523 |
+
do, m, D,
|
524 |
+
stride_tok, stride_d,
|
525 |
+
H, N_CTX,
|
526 |
+
BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL,
|
527 |
+
start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,
|
528 |
+
MASK=True
|
529 |
+
)
|
530 |
+
end_n -= num_steps * MASK_BLOCK_N2
|
531 |
+
# stage 2
|
532 |
+
num_steps = end_n // BLOCK_N2
|
533 |
+
dq = _attn_bwd_dq(dq, q, K, V,
|
534 |
+
do, m, D,
|
535 |
+
stride_tok, stride_d,
|
536 |
+
H, N_CTX,
|
537 |
+
BLOCK_M2, BLOCK_N2, BLOCK_DMODEL,
|
538 |
+
start_m, end_n - num_steps * BLOCK_N2, num_steps,
|
539 |
+
MASK=False
|
540 |
+
)
|
541 |
+
# Write back dQ.
|
542 |
+
DQ_block_ptr = tl.make_block_ptr(
|
543 |
+
base=DQ,
|
544 |
+
shape=(N_CTX, BLOCK_DMODEL),
|
545 |
+
strides=(stride_tok, stride_d),
|
546 |
+
offsets=(start_m, 0),
|
547 |
+
block_shape=(BLOCK_M2, BLOCK_DMODEL),
|
548 |
+
order=(1, 0)
|
549 |
+
)
|
550 |
+
dq *= LN2
|
551 |
+
tl.store(DQ_block_ptr, dq.to(tl.float16))
|
552 |
+
|
553 |
+
|
554 |
+
empty = torch.empty(128, device="cuda")
|
555 |
+
|
556 |
+
|
557 |
+
class _attention(torch.autograd.Function):
|
558 |
+
|
559 |
+
@staticmethod
|
560 |
+
def forward(ctx, q, k, v, causal, sm_scale):
|
561 |
+
# shape constraints
|
562 |
+
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
563 |
+
assert Lq == Lk and Lk == Lv
|
564 |
+
assert Lk in {16, 32, 64, 128}
|
565 |
+
o = torch.empty_like(q, dtype=v.dtype)
|
566 |
+
if torch.version.hip is None:
|
567 |
+
BLOCK_M = 128
|
568 |
+
BLOCK_N = 64 if Lk <= 64 else 32
|
569 |
+
num_stages = 4 if Lk <= 64 else 3
|
570 |
+
num_warps = 4 if Lk <= 64 else 8
|
571 |
+
# Tuning for H100
|
572 |
+
if torch.cuda.get_device_capability()[0] == 9:
|
573 |
+
num_warps = 8
|
574 |
+
num_stages = 7 if Lk >= 64 else 3
|
575 |
+
stage = 3 if causal else 1
|
576 |
+
|
577 |
+
def grid(META): return (
|
578 |
+
triton.cdiv(q.shape[2], META['BLOCK_M']),
|
579 |
+
q.shape[0] * q.shape[1],
|
580 |
+
1
|
581 |
+
)
|
582 |
+
M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]),
|
583 |
+
device=q.device, dtype=torch.float32)
|
584 |
+
_attn_fwd[grid](
|
585 |
+
q, k, v, sm_scale, M, o,
|
586 |
+
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
587 |
+
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
588 |
+
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
589 |
+
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
590 |
+
q.shape[0], q.shape[1],
|
591 |
+
N_CTX=q.shape[2],
|
592 |
+
BLOCK_DMODEL=Lk,
|
593 |
+
STAGE=stage,
|
594 |
+
)
|
595 |
+
|
596 |
+
# restore the grid for bwd kernel
|
597 |
+
best_config = _attn_fwd.get_best_config()
|
598 |
+
block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1])
|
599 |
+
grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1)
|
600 |
+
|
601 |
+
ctx.save_for_backward(q, k, v, o, M)
|
602 |
+
ctx.grid = grid
|
603 |
+
ctx.sm_scale = sm_scale
|
604 |
+
ctx.BLOCK_DMODEL = Lk
|
605 |
+
ctx.causal = causal
|
606 |
+
return o
|
607 |
+
|
608 |
+
@staticmethod
|
609 |
+
def backward(ctx, do):
|
610 |
+
if torch.version.hip is not None:
|
611 |
+
BLOCK = 64
|
612 |
+
else:
|
613 |
+
BLOCK = 128
|
614 |
+
q, k, v, o, M = ctx.saved_tensors
|
615 |
+
assert do.is_contiguous()
|
616 |
+
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
|
617 |
+
dq = torch.empty_like(q)
|
618 |
+
dk = torch.empty_like(k)
|
619 |
+
dv = torch.empty_like(v)
|
620 |
+
BATCH, N_HEAD, N_CTX = q.shape[:3]
|
621 |
+
PRE_BLOCK = 128
|
622 |
+
NUM_WARPS, NUM_STAGES = 4, 1
|
623 |
+
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32
|
624 |
+
BLK_SLICE_FACTOR = 2
|
625 |
+
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
|
626 |
+
arg_k = k
|
627 |
+
arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
|
628 |
+
assert N_CTX % PRE_BLOCK == 0
|
629 |
+
pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
|
630 |
+
delta = torch.empty_like(M)
|
631 |
+
_attn_bwd_preprocess[pre_grid](
|
632 |
+
o, do,
|
633 |
+
delta,
|
634 |
+
BATCH, N_HEAD, N_CTX,
|
635 |
+
BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL
|
636 |
+
)
|
637 |
+
|
638 |
+
def grid(META): return (
|
639 |
+
triton.cdiv(N_CTX, META['BLOCK_N1']),
|
640 |
+
1,
|
641 |
+
BATCH * N_HEAD
|
642 |
+
)
|
643 |
+
_attn_bwd[grid](
|
644 |
+
q, arg_k, v, ctx.sm_scale, do, dq, dk, dv,
|
645 |
+
M, delta,
|
646 |
+
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
647 |
+
N_HEAD, N_CTX,
|
648 |
+
BLOCK_DMODEL=ctx.BLOCK_DMODEL
|
649 |
+
)
|
650 |
+
|
651 |
+
return dq, dk, dv, None, None
|
652 |
+
|
653 |
+
|
654 |
+
attention = _attention.apply
|