radna commited on
Commit
2b894e3
1 Parent(s): 6aa34f0

Upload 11 files

Browse files
.gitattributes CHANGED
@@ -1,35 +1,35 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ datasets:
4
+ - laion/laion2B-en
5
+ - laion/laion-coco
6
+ - laion/laion2B-multi
7
+ - kakaobrain/coyo-700m
8
+ - conceptual_captions
9
+ - wanng/wukong100m
10
+ pipeline_tag: image-feature-extraction
11
+ ---
12
+
13
+ # Model Card for InternViT-300M-448px
14
+ <p align="center">
15
+ <img src="https://cdn-uploads.huggingface.co/production/uploads/64119264f0f81eb569e0d569/AUE-3OBtfr9vDA7Elgkhd.webp" alt="Image Description" width="300" height="300">
16
+ </p>
17
+
18
+ [\[🆕 Blog\]](https://internvl.github.io/blog/) [\[📜 InternVL 1.0 Paper\]](https://arxiv.org/abs/2312.14238) [\[📜 InternVL 1.5 Report\]](https://arxiv.org/abs/2404.16821) [\[🗨️ Chat Demo\]](https://internvl.opengvlab.com/)
19
+
20
+ [\[🤗 HF Demo\]](https://huggingface.co/spaces/OpenGVLab/InternVL) [\[🚀 Quick Start\]](#model-usage) [\[🌐 Community-hosted API\]](https://rapidapi.com/adushar1320/api/internvl-chat) [\[📖 中文解读\]](https://zhuanlan.zhihu.com/p/675877376)
21
+
22
+ This update primarily focuses on enhancing the efficiency of the vision foundation model. We developed InternViT-300M-448px by distilling knowledge from the robust vision foundation model, [InternViT-6B-448px-V1-5](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V1-5). Like its predecessor, InternViT-300M-448px features a dynamic input resolution of 448×448, with a basic tile size of 448×448. During training, it allows for 1 to 12 tiles, and expands to 1 to 40 tiles during testing. Additionally, it inherits the powerful robustness, OCR capability, and high-resolution processing capacity from InternViT-6B-448px-V1-5.
23
+
24
+ ## Model Details
25
+ - **Model Type:** vision foundation model, feature backbone
26
+ - **Model Stats:**
27
+ - Params (M): 304
28
+ - Image size: 448 x 448, training with 1 - 12 tiles
29
+ - **Pretrain Dataset:** LAION-en, LAION-zh, COYO, GRIT, COCO, TextCaps, Objects365, OpenImages, All-Seeing, Wukong-OCR, LaionCOCO-OCR, and other OCR-related datasets.
30
+ To enhance the OCR capability of the model, we have incorporated additional OCR data alongside the general caption datasets. Specifically, we utilized PaddleOCR to perform Chinese OCR on images from Wukong and English OCR on images from LAION-COCO.
31
+
32
+ ## Released Models
33
+ ### Vision Foundation model
34
+ | Model | Date | Download | Note |
35
+ | ----------------------- | ---------- | ---------------------------------------------------------------------- | -------------------------------- |
36
+ | InternViT-6B-448px-V1-5 | 2024.04.20 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V1-5) | support dynamic resolution, super strong OCR (🔥new) |
37
+ | InternViT-6B-448px-V1-2 | 2024.02.11 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V1-2) | 448 resolution |
38
+ | InternViT-6B-448px-V1-0 | 2024.01.30 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V1-0) | 448 resolution |
39
+ | InternViT-6B-224px | 2023.12.22 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternViT-6B-224px) | vision foundation model |
40
+ | InternVL-14B-224px | 2023.12.22 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-14B-224px) | vision-language foundation model |
41
+
42
+ ### Multimodal Large Language Model (MLLM)
43
+ | Model | Date | Download | Note |
44
+ | ----------------------- | ---------- | --------------------------------------------------------------------------- | ---------------------------------- |
45
+ | InternVL-Chat-V1-5 | 2024.04.18 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-Chat-V1-5) | support 4K image; super strong OCR; Approaching the performance of GPT-4V and Gemini Pro on various benchmarks like MMMU, DocVQA, ChartQA, MathVista, etc. (🔥new)|
46
+ | InternVL-Chat-V1-2-Plus | 2024.02.21 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-Chat-V1-2-Plus) | more SFT data and stronger |
47
+ | InternVL-Chat-V1-2 | 2024.02.11 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-Chat-V1-2) | scaling up LLM to 34B |
48
+ | InternVL-Chat-V1-1 | 2024.01.24 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-Chat-V1-1) | support Chinese and stronger OCR |
49
+
50
+ ## Model Usage (Image Embeddings)
51
+
52
+ ```python
53
+ import torch
54
+ from PIL import Image
55
+ from transformers import AutoModel, CLIPImageProcessor
56
+
57
+ model = AutoModel.from_pretrained(
58
+ 'OpenGVLab/InternViT-300M-448px',
59
+ torch_dtype=torch.bfloat16,
60
+ low_cpu_mem_usage=True,
61
+ trust_remote_code=True).cuda().eval()
62
+
63
+ image = Image.open('./examples/image1.jpg').convert('RGB')
64
+
65
+ image_processor = CLIPImageProcessor.from_pretrained('OpenGVLab/InternViT-300M-448px')
66
+
67
+ pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
68
+ pixel_values = pixel_values.to(torch.bfloat16).cuda()
69
+
70
+ outputs = model(pixel_values)
71
+ ```
72
+
73
+ ## Citation
74
+
75
+ If you find this project useful in your research, please consider citing:
76
+
77
+ ```BibTeX
78
+ @article{chen2023internvl,
79
+ title={InternVL: Scaling up Vision Foundation Models and Aligning for Generic Visual-Linguistic Tasks},
80
+ author={Chen, Zhe and Wu, Jiannan and Wang, Wenhai and Su, Weijie and Chen, Guo and Xing, Sen and Zhong, Muyan and Zhang, Qinglong and Zhu, Xizhou and Lu, Lewei and Li, Bin and Luo, Ping and Lu, Tong and Qiao, Yu and Dai, Jifeng},
81
+ journal={arXiv preprint arXiv:2312.14238},
82
+ year={2023}
83
+ }
84
+ @article{chen2024far,
85
+ title={How Far Are We to GPT-4V? Closing the Gap to Commercial Multimodal Models with Open-Source Suites},
86
+ author={Chen, Zhe and Wang, Weiyun and Tian, Hao and Ye, Shenglong and Gao, Zhangwei and Cui, Erfei and Tong, Wenwen and Hu, Kongzhi and Luo, Jiapeng and Ma, Zheng and others},
87
+ journal={arXiv preprint arXiv:2404.16821},
88
+ year={2024}
89
+ }
90
+
91
+ ```
92
+
93
+
94
+ ## Acknowledgement
95
+
96
+ InternVL is built with reference to the code of the following projects: [OpenAI CLIP](https://github.com/openai/CLIP), [Open CLIP](https://github.com/mlfoundations/open_clip), [CLIP Benchmark](https://github.com/LAION-AI/CLIP_benchmark), [EVA](https://github.com/baaivision/EVA/tree/master), [InternImage](https://github.com/OpenGVLab/InternImage), [ViT-Adapter](https://github.com/czczup/ViT-Adapter), [MMSegmentation](https://github.com/open-mmlab/mmsegmentation), [Transformers](https://github.com/huggingface/transformers), [DINOv2](https://github.com/facebookresearch/dinov2), [BLIP-2](https://github.com/salesforce/LAVIS/tree/main/projects/blip2), [Qwen-VL](https://github.com/QwenLM/Qwen-VL/tree/master/eval_mm), and [LLaVA-1.5](https://github.com/haotian-liu/LLaVA). Thanks for their awesome work!
config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "InternVisionModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_intern_vit.InternVisionConfig",
7
+ "AutoModel": "modeling_intern_vit.InternVisionModel"
8
+ },
9
+ "attention_dropout": 0.0,
10
+ "drop_path_rate": 0.1,
11
+ "dropout": 0.0,
12
+ "hidden_act": "gelu",
13
+ "hidden_size": 1024,
14
+ "image_size": 448,
15
+ "initializer_factor": 1.0,
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": 4096,
18
+ "layer_norm_eps": 1e-06,
19
+ "model_type": "intern_vit_6b",
20
+ "norm_type": "layer_norm",
21
+ "num_attention_heads": 16,
22
+ "num_channels": 3,
23
+ "num_hidden_layers": 24,
24
+ "patch_size": 14,
25
+ "qk_normalization": false,
26
+ "qkv_bias": true,
27
+ "torch_dtype": "bfloat16",
28
+ "transformers_version": "4.37.2",
29
+ "use_flash_attn": true
30
+ }
configuration_intern_vit.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2023 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ import os
7
+ from typing import Union
8
+
9
+ from transformers.configuration_utils import PretrainedConfig
10
+ from transformers.utils import logging
11
+
12
+ logger = logging.get_logger(__name__)
13
+
14
+
15
+ class InternVisionConfig(PretrainedConfig):
16
+ r"""
17
+ This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
18
+ instantiate a vision encoder according to the specified arguments, defining the model architecture.
19
+
20
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
21
+ documentation from [`PretrainedConfig`] for more information.
22
+
23
+ Args:
24
+ num_channels (`int`, *optional*, defaults to 3):
25
+ Number of color channels in the input images (e.g., 3 for RGB).
26
+ patch_size (`int`, *optional*, defaults to 14):
27
+ The size (resolution) of each patch.
28
+ image_size (`int`, *optional*, defaults to 224):
29
+ The size (resolution) of each image.
30
+ qkv_bias (`bool`, *optional*, defaults to `False`):
31
+ Whether to add a bias to the queries and values in the self-attention layers.
32
+ hidden_size (`int`, *optional*, defaults to 3200):
33
+ Dimensionality of the encoder layers and the pooler layer.
34
+ num_attention_heads (`int`, *optional*, defaults to 25):
35
+ Number of attention heads for each attention layer in the Transformer encoder.
36
+ intermediate_size (`int`, *optional*, defaults to 12800):
37
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
38
+ qk_normalization (`bool`, *optional*, defaults to `True`):
39
+ Whether to normalize the queries and keys in the self-attention layers.
40
+ num_hidden_layers (`int`, *optional*, defaults to 48):
41
+ Number of hidden layers in the Transformer encoder.
42
+ use_flash_attn (`bool`, *optional*, defaults to `True`):
43
+ Whether to use flash attention mechanism.
44
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
45
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
46
+ `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported.
47
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
48
+ The epsilon used by the layer normalization layers.
49
+ dropout (`float`, *optional*, defaults to 0.0):
50
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
51
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
52
+ Dropout rate for stochastic depth.
53
+ attention_dropout (`float`, *optional*, defaults to 0.0):
54
+ The dropout ratio for the attention probabilities.
55
+ initializer_range (`float`, *optional*, defaults to 0.02):
56
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
57
+ initializer_factor (`float`, *optional*, defaults to 0.1):
58
+ A factor for layer scale.
59
+ """
60
+
61
+ model_type = 'intern_vit_6b'
62
+
63
+ def __init__(
64
+ self,
65
+ num_channels=3,
66
+ patch_size=14,
67
+ image_size=224,
68
+ qkv_bias=False,
69
+ hidden_size=3200,
70
+ num_attention_heads=25,
71
+ intermediate_size=12800,
72
+ qk_normalization=True,
73
+ num_hidden_layers=48,
74
+ use_flash_attn=True,
75
+ hidden_act='gelu',
76
+ norm_type='rms_norm',
77
+ layer_norm_eps=1e-6,
78
+ dropout=0.0,
79
+ drop_path_rate=0.0,
80
+ attention_dropout=0.0,
81
+ initializer_range=0.02,
82
+ initializer_factor=0.1,
83
+ **kwargs,
84
+ ):
85
+ super().__init__(**kwargs)
86
+
87
+ self.hidden_size = hidden_size
88
+ self.intermediate_size = intermediate_size
89
+ self.dropout = dropout
90
+ self.drop_path_rate = drop_path_rate
91
+ self.num_hidden_layers = num_hidden_layers
92
+ self.num_attention_heads = num_attention_heads
93
+ self.num_channels = num_channels
94
+ self.patch_size = patch_size
95
+ self.image_size = image_size
96
+ self.initializer_range = initializer_range
97
+ self.initializer_factor = initializer_factor
98
+ self.attention_dropout = attention_dropout
99
+ self.layer_norm_eps = layer_norm_eps
100
+ self.hidden_act = hidden_act
101
+ self.norm_type = norm_type
102
+ self.qkv_bias = qkv_bias
103
+ self.qk_normalization = qk_normalization
104
+ self.use_flash_attn = use_flash_attn
105
+
106
+ @classmethod
107
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
108
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
109
+
110
+ if 'vision_config' in config_dict:
111
+ config_dict = config_dict['vision_config']
112
+
113
+ if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
114
+ logger.warning(
115
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
116
+ f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
117
+ )
118
+
119
+ return cls.from_dict(config_dict, **kwargs)
flash_attention.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py
2
+ import torch
3
+ import torch.nn as nn
4
+ from einops import rearrange
5
+
6
+ try: # v1
7
+ from flash_attn.flash_attn_interface import \
8
+ flash_attn_unpadded_qkvpacked_func
9
+ except: # v2
10
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
11
+
12
+ from flash_attn.bert_padding import pad_input, unpad_input
13
+
14
+
15
+ class FlashAttention(nn.Module):
16
+ """Implement the scaled dot product attention with softmax.
17
+ Arguments
18
+ ---------
19
+ softmax_scale: The temperature to use for the softmax attention.
20
+ (default: 1/sqrt(d_keys) where d_keys is computed at
21
+ runtime)
22
+ attention_dropout: The dropout rate to apply to the attention
23
+ (default: 0.0)
24
+ """
25
+
26
+ def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
27
+ super().__init__()
28
+ self.softmax_scale = softmax_scale
29
+ self.dropout_p = attention_dropout
30
+
31
+ def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
32
+ max_s=None, need_weights=False):
33
+ """Implements the multihead softmax attention.
34
+ Arguments
35
+ ---------
36
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
37
+ if unpadded: (nnz, 3, h, d)
38
+ key_padding_mask: a bool tensor of shape (B, S)
39
+ """
40
+ assert not need_weights
41
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
42
+ assert qkv.is_cuda
43
+
44
+ if cu_seqlens is None:
45
+ batch_size = qkv.shape[0]
46
+ seqlen = qkv.shape[1]
47
+ if key_padding_mask is None:
48
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
49
+ max_s = seqlen
50
+ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
51
+ device=qkv.device)
52
+ output = flash_attn_unpadded_qkvpacked_func(
53
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
54
+ softmax_scale=self.softmax_scale, causal=causal
55
+ )
56
+ output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
57
+ else:
58
+ nheads = qkv.shape[-2]
59
+ x = rearrange(qkv, 'b s three h d -> b s (three h d)')
60
+ x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
61
+ x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
62
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
63
+ x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
64
+ softmax_scale=self.softmax_scale, causal=causal
65
+ )
66
+ output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
67
+ indices, batch_size, seqlen),
68
+ 'b s (h d) -> b s h d', h=nheads)
69
+ else:
70
+ assert max_s is not None
71
+ output = flash_attn_unpadded_qkvpacked_func(
72
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
73
+ softmax_scale=self.softmax_scale, causal=causal
74
+ )
75
+
76
+ return output, None
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22f87624b94d869df9a619445357660b4c290970ba977f14e0a1ea2b46da4fda
3
+ size 608059320
modeling_intern_vit.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2023 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ from typing import Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint
11
+ from einops import rearrange
12
+ from timm.models.layers import DropPath
13
+ from torch import nn
14
+ from transformers.activations import ACT2FN
15
+ from transformers.modeling_outputs import (BaseModelOutput,
16
+ BaseModelOutputWithPooling)
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers.utils import logging
19
+
20
+ from .configuration_intern_vit import InternVisionConfig
21
+
22
+ try:
23
+ from .flash_attention import FlashAttention
24
+ has_flash_attn = True
25
+ except:
26
+ print('FlashAttention is not installed.')
27
+ has_flash_attn = False
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ class InternRMSNorm(nn.Module):
34
+ def __init__(self, hidden_size, eps=1e-6):
35
+ super().__init__()
36
+ self.weight = nn.Parameter(torch.ones(hidden_size))
37
+ self.variance_epsilon = eps
38
+
39
+ def forward(self, hidden_states):
40
+ input_dtype = hidden_states.dtype
41
+ hidden_states = hidden_states.to(torch.float32)
42
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
43
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
44
+ return self.weight * hidden_states.to(input_dtype)
45
+
46
+
47
+ try:
48
+ from apex.normalization import FusedRMSNorm
49
+
50
+ InternRMSNorm = FusedRMSNorm # noqa
51
+
52
+ logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
53
+ except ImportError:
54
+ # using the normal InternRMSNorm
55
+ pass
56
+ except Exception:
57
+ logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
58
+ pass
59
+
60
+
61
+ NORM2FN = {
62
+ 'rms_norm': InternRMSNorm,
63
+ 'layer_norm': nn.LayerNorm,
64
+ }
65
+
66
+
67
+ class InternVisionEmbeddings(nn.Module):
68
+ def __init__(self, config: InternVisionConfig):
69
+ super().__init__()
70
+ self.config = config
71
+ self.embed_dim = config.hidden_size
72
+ self.image_size = config.image_size
73
+ self.patch_size = config.patch_size
74
+
75
+ self.class_embedding = nn.Parameter(
76
+ torch.randn(1, 1, self.embed_dim),
77
+ )
78
+
79
+ self.patch_embedding = nn.Conv2d(
80
+ in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
81
+ )
82
+
83
+ self.num_patches = (self.image_size // self.patch_size) ** 2
84
+ self.num_positions = self.num_patches + 1
85
+
86
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
87
+
88
+ def _get_pos_embed(self, pos_embed, H, W):
89
+ target_dtype = pos_embed.dtype
90
+ pos_embed = pos_embed.float().reshape(
91
+ 1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
92
+ pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False).\
93
+ reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)
94
+ return pos_embed
95
+
96
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
97
+ target_dtype = self.patch_embedding.weight.dtype
98
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, channel, width, height]
99
+ batch_size, _, height, width = patch_embeds.shape
100
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
101
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
102
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
103
+ position_embedding = torch.cat([
104
+ self.position_embedding[:, :1, :],
105
+ self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)
106
+ ], dim=1)
107
+ embeddings = embeddings + position_embedding.to(target_dtype)
108
+ return embeddings
109
+
110
+
111
+ class InternAttention(nn.Module):
112
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
113
+
114
+ def __init__(self, config: InternVisionConfig):
115
+ super().__init__()
116
+ self.config = config
117
+ self.embed_dim = config.hidden_size
118
+ self.num_heads = config.num_attention_heads
119
+ self.use_flash_attn = config.use_flash_attn and has_flash_attn
120
+ if config.use_flash_attn and not has_flash_attn:
121
+ print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
122
+ self.head_dim = self.embed_dim // self.num_heads
123
+ if self.head_dim * self.num_heads != self.embed_dim:
124
+ raise ValueError(
125
+ f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
126
+ f' {self.num_heads}).'
127
+ )
128
+
129
+ self.scale = self.head_dim ** -0.5
130
+ self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
131
+ self.attn_drop = nn.Dropout(config.attention_dropout)
132
+ self.proj_drop = nn.Dropout(config.dropout)
133
+
134
+ self.qk_normalization = config.qk_normalization
135
+
136
+ if self.qk_normalization:
137
+ self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
138
+ self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
139
+
140
+ if self.use_flash_attn:
141
+ self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
142
+ self.proj = nn.Linear(self.embed_dim, self.embed_dim)
143
+
144
+ def _naive_attn(self, x):
145
+ B, N, C = x.shape
146
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
147
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
148
+
149
+ if self.qk_normalization:
150
+ B_, H_, N_, D_ = q.shape
151
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
152
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
153
+
154
+ attn = ((q * self.scale) @ k.transpose(-2, -1))
155
+ attn = attn.softmax(dim=-1)
156
+ attn = self.attn_drop(attn)
157
+
158
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
159
+ x = self.proj(x)
160
+ x = self.proj_drop(x)
161
+ return x
162
+
163
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
164
+ qkv = self.qkv(x)
165
+ qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
166
+
167
+ if self.qk_normalization:
168
+ q, k, v = qkv.unbind(2)
169
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
170
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
171
+ qkv = torch.stack([q, k, v], dim=2)
172
+
173
+ context, _ = self.inner_attn(
174
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
175
+ )
176
+ outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
177
+ outs = self.proj_drop(outs)
178
+ return outs
179
+
180
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
181
+ x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
182
+ return x
183
+
184
+
185
+ class InternMLP(nn.Module):
186
+ def __init__(self, config: InternVisionConfig):
187
+ super().__init__()
188
+ self.config = config
189
+ self.act = ACT2FN[config.hidden_act]
190
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
191
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
192
+
193
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
194
+ hidden_states = self.fc1(hidden_states)
195
+ hidden_states = self.act(hidden_states)
196
+ hidden_states = self.fc2(hidden_states)
197
+ return hidden_states
198
+
199
+
200
+ class InternVisionEncoderLayer(nn.Module):
201
+ def __init__(self, config: InternVisionConfig, drop_path_rate: float):
202
+ super().__init__()
203
+ self.embed_dim = config.hidden_size
204
+ self.intermediate_size = config.intermediate_size
205
+ self.norm_type = config.norm_type
206
+
207
+ self.attn = InternAttention(config)
208
+ self.mlp = InternMLP(config)
209
+ self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
210
+ self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
211
+
212
+ self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
213
+ self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
214
+ self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
215
+ self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
216
+
217
+ def forward(
218
+ self,
219
+ hidden_states: torch.Tensor,
220
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
221
+ """
222
+ Args:
223
+ hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
224
+ """
225
+ hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
226
+
227
+ hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
228
+
229
+ return hidden_states
230
+
231
+
232
+ class InternVisionEncoder(nn.Module):
233
+ """
234
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
235
+ [`InternEncoderLayer`].
236
+
237
+ Args:
238
+ config (`InternConfig`):
239
+ The corresponding vision configuration for the `InternEncoder`.
240
+ """
241
+
242
+ def __init__(self, config: InternVisionConfig):
243
+ super().__init__()
244
+ self.config = config
245
+ # stochastic depth decay rule
246
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
247
+ self.layers = nn.ModuleList([
248
+ InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
249
+ self.gradient_checkpointing = True
250
+
251
+ def forward(
252
+ self,
253
+ inputs_embeds,
254
+ output_hidden_states: Optional[bool] = None,
255
+ return_dict: Optional[bool] = None,
256
+ ) -> Union[Tuple, BaseModelOutput]:
257
+ r"""
258
+ Args:
259
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
260
+ Embedded representation of the inputs. Should be float, not int tokens.
261
+ output_hidden_states (`bool`, *optional*):
262
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
263
+ for more detail.
264
+ return_dict (`bool`, *optional*):
265
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
266
+ """
267
+ output_hidden_states = (
268
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
269
+ )
270
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
271
+
272
+ encoder_states = () if output_hidden_states else None
273
+ hidden_states = inputs_embeds
274
+
275
+ for idx, encoder_layer in enumerate(self.layers):
276
+ if output_hidden_states:
277
+ encoder_states = encoder_states + (hidden_states,)
278
+ if self.gradient_checkpointing and self.training:
279
+ layer_outputs = torch.utils.checkpoint.checkpoint(
280
+ encoder_layer,
281
+ hidden_states)
282
+ else:
283
+ layer_outputs = encoder_layer(
284
+ hidden_states,
285
+ )
286
+ hidden_states = layer_outputs
287
+
288
+ if output_hidden_states:
289
+ encoder_states = encoder_states + (hidden_states,)
290
+
291
+ if not return_dict:
292
+ return tuple(v for v in [hidden_states, encoder_states] if v is not None)
293
+ return BaseModelOutput(
294
+ last_hidden_state=hidden_states, hidden_states=encoder_states
295
+ )
296
+
297
+
298
+ class InternVisionModel(PreTrainedModel):
299
+ main_input_name = 'pixel_values'
300
+ config_class = InternVisionConfig
301
+ _no_split_modules = ['InternVisionEncoderLayer']
302
+
303
+ def __init__(self, config: InternVisionConfig):
304
+ super().__init__(config)
305
+ self.config = config
306
+
307
+ self.embeddings = InternVisionEmbeddings(config)
308
+ self.encoder = InternVisionEncoder(config)
309
+
310
+ def resize_pos_embeddings(self, old_size, new_size, patch_size):
311
+ pos_emb = self.embeddings.position_embedding
312
+ _, num_positions, embed_dim = pos_emb.shape
313
+ cls_emb = pos_emb[:, :1, :]
314
+ pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
315
+ pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
316
+ pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
317
+ pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
318
+ self.embeddings.position_embedding = nn.Parameter(pos_emb)
319
+ self.embeddings.image_size = new_size
320
+ logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
321
+
322
+ def get_input_embeddings(self):
323
+ return self.embeddings
324
+
325
+ def forward(
326
+ self,
327
+ pixel_values: Optional[torch.FloatTensor] = None,
328
+ output_hidden_states: Optional[bool] = None,
329
+ return_dict: Optional[bool] = None,
330
+ pixel_embeds: Optional[torch.FloatTensor] = None,
331
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
332
+ output_hidden_states = (
333
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
334
+ )
335
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
336
+
337
+ if pixel_values is None and pixel_embeds is None:
338
+ raise ValueError('You have to specify pixel_values or pixel_embeds')
339
+
340
+ if pixel_embeds is not None:
341
+ hidden_states = pixel_embeds
342
+ else:
343
+ if len(pixel_values.shape) == 4:
344
+ hidden_states = self.embeddings(pixel_values)
345
+ else:
346
+ raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
347
+ encoder_outputs = self.encoder(
348
+ inputs_embeds=hidden_states,
349
+ output_hidden_states=output_hidden_states,
350
+ return_dict=return_dict,
351
+ )
352
+ last_hidden_state = encoder_outputs.last_hidden_state
353
+ pooled_output = last_hidden_state[:, 0, :]
354
+
355
+ if not return_dict:
356
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
357
+
358
+ return BaseModelOutputWithPooling(
359
+ last_hidden_state=last_hidden_state,
360
+ pooler_output=pooled_output,
361
+ hidden_states=encoder_outputs.hidden_states,
362
+ attentions=encoder_outputs.attentions,
363
+ )
preprocessor_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": 448,
3
+ "do_center_crop": true,
4
+ "do_normalize": true,
5
+ "do_resize": true,
6
+ "feature_extractor_type": "CLIPFeatureExtractor",
7
+ "image_mean": [
8
+ 0.485,
9
+ 0.456,
10
+ 0.406
11
+ ],
12
+ "image_std": [
13
+ 0.229,
14
+ 0.224,
15
+ 0.225
16
+ ],
17
+ "resample": 3,
18
+ "size": 448
19
+ }
triton-test.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from triton_flash_atn import _attention
3
+
4
+ # Define dimensions
5
+ batch_size = 2
6
+ num_heads = 4
7
+ seq_len = 128
8
+ head_dim = 64
9
+
10
+ # Create random input tensors for Q, K, V
11
+ q = torch.randn(batch_size, num_heads, seq_len, head_dim,
12
+ dtype=torch.float16, device='cuda')
13
+ k = torch.randn(batch_size, num_heads, seq_len, head_dim,
14
+ dtype=torch.float16, device='cuda')
15
+ v = torch.randn(batch_size, num_heads, seq_len, head_dim,
16
+ dtype=torch.float16, device='cuda')
17
+
18
+ # Define whether the attention is causal and the scaling factor
19
+ causal = False
20
+ sm_scale = 1.0 / (head_dim ** 0.5)
21
+
22
+ # Apply flash attention
23
+ attention = _attention.apply
24
+ output = attention(q, k, v, causal, sm_scale)
25
+
26
+ print(output)
triton_bert_pading.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from einops import rearrange, repeat
6
+
7
+
8
+ class IndexFirstAxis(torch.autograd.Function):
9
+ @staticmethod
10
+ def forward(ctx, input, indices):
11
+ ctx.save_for_backward(indices)
12
+ assert input.ndim >= 2
13
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
14
+ second_dim = other_shape.numel()
15
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
16
+ # return input[indices]
17
+ return torch.gather(
18
+ rearrange(input, "b ... -> b (...)"), 0, repeat(indices,
19
+ "z -> z d", d=second_dim)
20
+ ).reshape(-1, *other_shape)
21
+
22
+ @staticmethod
23
+ def backward(ctx, grad_output):
24
+ (indices,) = ctx.saved_tensors
25
+ assert grad_output.ndim >= 2
26
+ other_shape = grad_output.shape[1:]
27
+ grad_output = rearrange(grad_output, "b ... -> b (...)")
28
+ grad_input = torch.zeros(
29
+ [ctx.first_axis_dim, grad_output.shape[1]],
30
+ device=grad_output.device,
31
+ dtype=grad_output.dtype,
32
+ )
33
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
34
+ # grad_input[indices] = grad_output
35
+ grad_input.scatter_(0, repeat(indices, "z -> z d",
36
+ d=grad_output.shape[1]), grad_output)
37
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
38
+
39
+
40
+ index_first_axis = IndexFirstAxis.apply
41
+
42
+
43
+ class IndexPutFirstAxis(torch.autograd.Function):
44
+ @staticmethod
45
+ def forward(ctx, values, indices, first_axis_dim):
46
+ ctx.save_for_backward(indices)
47
+ assert indices.ndim == 1
48
+ assert values.ndim >= 2
49
+ output = torch.zeros(
50
+ first_axis_dim, *
51
+ values.shape[1:], device=values.device, dtype=values.dtype
52
+ )
53
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
54
+ output[indices] = values
55
+ # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
56
+ return output
57
+
58
+ @staticmethod
59
+ def backward(ctx, grad_output):
60
+ (indices,) = ctx.saved_tensors
61
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
62
+ grad_values = grad_output[indices]
63
+ # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
64
+ return grad_values, None, None
65
+
66
+
67
+ index_put_first_axis = IndexPutFirstAxis.apply
68
+
69
+
70
+ class IndexFirstAxisResidual(torch.autograd.Function):
71
+ @staticmethod
72
+ def forward(ctx, input, indices):
73
+ ctx.save_for_backward(indices)
74
+ assert input.ndim >= 2
75
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
76
+ second_dim = other_shape.numel()
77
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
78
+ output = input[indices]
79
+ # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
80
+ # memory format to channel_first. In other words, input might not be contiguous.
81
+ # If we don't detach, Pytorch complains about output being a view and is being modified inplace
82
+ return output, input.detach()
83
+
84
+ @staticmethod
85
+ def backward(ctx, grad_output, grad_residual):
86
+ (indices,) = ctx.saved_tensors
87
+ assert grad_output.ndim >= 2
88
+ other_shape = grad_output.shape[1:]
89
+ assert grad_residual.shape[1:] == other_shape
90
+ grad_input = grad_residual
91
+ # grad_input[indices] += grad_output
92
+ indices = indices.reshape(
93
+ indices.shape[0], *((1,) * (grad_output.ndim - 1)))
94
+ indices = indices.expand_as(grad_output)
95
+ grad_input.scatter_add_(0, indices, grad_output)
96
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
97
+
98
+
99
+ index_first_axis_residual = IndexFirstAxisResidual.apply
100
+
101
+
102
+ def unpad_input(hidden_states, attention_mask):
103
+ """
104
+ Arguments:
105
+ hidden_states: (batch, seqlen, ...)
106
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
107
+ Return:
108
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
109
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
110
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
111
+ max_seqlen_in_batch: int
112
+ """
113
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
114
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
115
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
116
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0,
117
+ dtype=torch.torch.int32), (1, 0))
118
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
119
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
120
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
121
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
122
+ # so we write custom forward and backward to make it a bit faster.
123
+ return (
124
+ index_first_axis(
125
+ rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
126
+ indices,
127
+ cu_seqlens,
128
+ max_seqlen_in_batch,
129
+ )
130
+
131
+
132
+ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
133
+ """
134
+ Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model).
135
+ The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
136
+
137
+ For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
138
+ ```
139
+ [
140
+ [2, 3, 0, 0, 0, 0],
141
+ [3, 2, 0, 0, 0, 0],
142
+ [6, 0, 0, 0, 0, 0]
143
+ ]
144
+ ```
145
+ , which refers to the 3D-attention mask:
146
+ ```
147
+ [
148
+ [
149
+ [1, 0, 0, 0, 0, 0],
150
+ [1, 1, 0, 0, 0, 0],
151
+ [0, 0, 1, 0, 0, 0],
152
+ [0, 0, 1, 1, 0, 0],
153
+ [0, 0, 1, 1, 1, 0],
154
+ [0, 0, 0, 0, 0, 1]
155
+ ],
156
+ [
157
+ [1, 0, 0, 0, 0, 0],
158
+ [1, 1, 0, 0, 0, 0],
159
+ [1, 1, 1, 0, 0, 0],
160
+ [0, 0, 0, 1, 0, 0],
161
+ [0, 0, 0, 1, 1, 0],
162
+ [0, 0, 0, 0, 0, 1]
163
+ ],
164
+ [
165
+ [1, 0, 0, 0, 0, 0],
166
+ [1, 1, 0, 0, 0, 0],
167
+ [1, 1, 1, 0, 0, 0],
168
+ [1, 1, 1, 1, 0, 0],
169
+ [1, 1, 1, 1, 1, 0],
170
+ [1, 1, 1, 1, 1, 1]
171
+ ]
172
+ ]
173
+ ```.
174
+
175
+ Arguments:
176
+ hidden_states: (batch, seqlen, ...)
177
+ attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
178
+ Return:
179
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
180
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
181
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
182
+ max_seqlen_in_batch: int
183
+ """
184
+ length = attention_mask_in_length.sum(dim=-1)
185
+ seqlen = attention_mask_in_length.size(-1)
186
+ attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(
187
+ len(length), seqlen) < length.unsqueeze(1)
188
+ real_indices_idx = torch.nonzero(
189
+ attention_mask_in_length.flatten(), as_tuple=False).flatten()
190
+ seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
191
+ indices = torch.nonzero(attention_mask_2d.flatten(),
192
+ as_tuple=False).flatten()
193
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
194
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0,
195
+ dtype=torch.torch.int32), (1, 0))
196
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
197
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
198
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
199
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
200
+ # so we write custom forward and backward to make it a bit faster.
201
+ return (
202
+ index_first_axis(
203
+ rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
204
+ indices,
205
+ cu_seqlens,
206
+ max_seqlen_in_batch,
207
+ )
208
+
209
+
210
+ def pad_input(hidden_states, indices, batch, seqlen):
211
+ """
212
+ Arguments:
213
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
214
+ indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
215
+ batch: int, batch size for the padded sequence.
216
+ seqlen: int, maximum sequence length for the padded sequence.
217
+ Return:
218
+ hidden_states: (batch, seqlen, ...)
219
+ """
220
+ dim = hidden_states.shape[-1]
221
+ # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
222
+ # output[indices] = hidden_states
223
+ output = index_put_first_axis(hidden_states, indices, batch * seqlen)
224
+ return rearrange(output, "(b s) ... -> b s ...", b=batch)
triton_flash_atn.py ADDED
@@ -0,0 +1,654 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fused Attention
3
+ ===============
4
+
5
+ This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
6
+ Credits: OpenAI kernel team
7
+
8
+ Extra Credits:
9
+ - Original flash attention paper (https://arxiv.org/abs/2205.14135)
10
+ - Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf)
11
+
12
+ """
13
+
14
+ import pytest
15
+ import torch
16
+
17
+ import triton
18
+ import triton.language as tl
19
+
20
+ # Pick the fp8 data type
21
+
22
+ # AMD E4M3B8
23
+ # Note: When picking this f8 data type, scaling is required when using f8
24
+ # for the second gemm
25
+ # TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz')
26
+
27
+ # AMD E5M2B16
28
+ TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz')
29
+
30
+
31
+ @triton.jit
32
+ def _attn_fwd_inner(acc, l_i, m_i, q,
33
+ K_block_ptr, V_block_ptr,
34
+ start_m,
35
+ BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
36
+ STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,
37
+ N_CTX,
38
+ pre_load_v: tl.constexpr):
39
+ # range of values handled by this stage
40
+ if STAGE == 1:
41
+ lo, hi = 0, start_m * BLOCK_M
42
+ elif STAGE == 2:
43
+ lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
44
+ lo = tl.multiple_of(lo, BLOCK_M)
45
+ K_block_ptr = tl.advance(K_block_ptr, (0, lo))
46
+ V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
47
+ # causal = False
48
+ else:
49
+ lo, hi = 0, N_CTX
50
+ # loop over k, v and update accumulator
51
+ for start_n in range(lo, hi, BLOCK_N):
52
+ start_n = tl.multiple_of(start_n, BLOCK_N)
53
+ # -- compute qk ----
54
+ k = tl.load(K_block_ptr)
55
+ if pre_load_v:
56
+ v = tl.load(V_block_ptr)
57
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
58
+ if STAGE == 2:
59
+ mask = offs_m[:, None] >= (start_n + offs_n[None, :])
60
+ qk = tl.where(mask, qk, float("-inf"))
61
+ qk += tl.dot(q, k)
62
+ m_ij = tl.maximum(m_i, tl.max(qk, 1))
63
+ qk = qk - m_ij[:, None]
64
+ p = tl.math.exp2(qk)
65
+ # -- update output accumulator --
66
+ alpha = tl.math.exp2(m_i - m_ij)
67
+ acc = acc * alpha[:, None]
68
+ if not pre_load_v:
69
+ v = tl.load(V_block_ptr)
70
+ acc += tl.dot(p.to(v.dtype), v)
71
+ # -- update m_i and l_i
72
+ l_ij = tl.sum(p, 1)
73
+ l_i = l_i * alpha + l_ij
74
+ # update m_i and l_i
75
+ m_i = m_ij
76
+ V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
77
+ K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
78
+ return acc, l_i, m_i
79
+
80
+
81
+ # We don't run auto-tuning everytime to keep the tutorial fast. Uncommenting
82
+ # the code below and commenting out the equivalent parameters is convenient for
83
+ # re-tuning.
84
+ @triton.autotune(
85
+ configs=[
86
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 16, 'waves_per_eu': 2,
87
+ 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=2),
88
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 16, 'waves_per_eu': 2,
89
+ 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=2),
90
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2,
91
+ 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=1),
92
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2,
93
+ 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=1),
94
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'waves_per_eu': 2,
95
+ 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=2),
96
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 3,
97
+ 'slice_k_tile': 0, 'pre_load_v': True}, num_stages=1, num_warps=1),
98
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 3,
99
+ 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=1),
100
+ ],
101
+ key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'],
102
+ )
103
+ @triton.jit
104
+ def _attn_fwd(Q, K, V, sm_scale, M, Out,
105
+ stride_qz, stride_qh, stride_qm, stride_qk,
106
+ stride_kz, stride_kh, stride_kn, stride_kk,
107
+ stride_vz, stride_vh, stride_vk, stride_vn,
108
+ stride_oz, stride_oh, stride_om, stride_on,
109
+ Z, H,
110
+ N_CTX,
111
+ BLOCK_DMODEL: tl.constexpr,
112
+ STAGE: tl.constexpr,
113
+ BLOCK_M: tl.constexpr,
114
+ BLOCK_N: tl.constexpr,
115
+ pre_load_v: tl.constexpr,
116
+ ):
117
+ start_m = tl.program_id(0)
118
+ off_hz = tl.program_id(1)
119
+ qvk_offset = off_hz * stride_qh
120
+
121
+ # block pointers
122
+ Q_block_ptr = tl.make_block_ptr(
123
+ base=Q + qvk_offset,
124
+ shape=(N_CTX, BLOCK_DMODEL),
125
+ strides=(stride_qm, stride_qk),
126
+ offsets=(start_m * BLOCK_M, 0),
127
+ block_shape=(BLOCK_M, BLOCK_DMODEL),
128
+ order=(1, 0),
129
+ )
130
+ V_block_ptr = tl.make_block_ptr(
131
+ base=V + qvk_offset,
132
+ shape=(N_CTX, BLOCK_DMODEL),
133
+ strides=(stride_vk, stride_vn),
134
+ offsets=(0, 0),
135
+ block_shape=(BLOCK_N, BLOCK_DMODEL),
136
+ order=(1, 0),
137
+ )
138
+ K_block_ptr = tl.make_block_ptr(
139
+ base=K + qvk_offset,
140
+ shape=(BLOCK_DMODEL, N_CTX),
141
+ strides=(stride_kk, stride_kn),
142
+ offsets=(0, 0),
143
+ block_shape=(BLOCK_DMODEL, BLOCK_N),
144
+ order=(0, 1),
145
+ )
146
+ O_block_ptr = tl.make_block_ptr(
147
+ base=Out + qvk_offset,
148
+ shape=(N_CTX, BLOCK_DMODEL),
149
+ strides=(stride_om, stride_on),
150
+ offsets=(start_m * BLOCK_M, 0),
151
+ block_shape=(BLOCK_M, BLOCK_DMODEL),
152
+ order=(1, 0),
153
+ )
154
+ # initialize offsets
155
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
156
+ offs_n = tl.arange(0, BLOCK_N)
157
+ # initialize pointer to m and l
158
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
159
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
160
+ acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
161
+ # scale sm_scale by log_2(e) and use
162
+ # 2^x instead of exp in the loop because CSE and LICM
163
+ # don't work as expected with `exp` in the loop
164
+ qk_scale = sm_scale * 1.44269504
165
+ # load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs
166
+ q = tl.load(Q_block_ptr)
167
+ q = (q * qk_scale).to(q.dtype)
168
+ # stage 1: off-band
169
+ # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
170
+ # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
171
+ if STAGE & 1:
172
+ acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
173
+ start_m,
174
+ BLOCK_M, BLOCK_DMODEL, BLOCK_N,
175
+ 4 - STAGE, offs_m, offs_n, N_CTX,
176
+ pre_load_v,
177
+ )
178
+ # stage 2: on-band
179
+ if STAGE & 2:
180
+ # barrier makes it easier for compielr to schedule the
181
+ # two loops independently
182
+ tl.debug_barrier()
183
+ acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
184
+ start_m,
185
+ BLOCK_M, BLOCK_DMODEL, BLOCK_N,
186
+ 2, offs_m, offs_n, N_CTX,
187
+ pre_load_v,
188
+ )
189
+ # epilogue
190
+ # write back m
191
+ acc = acc / l_i[:, None]
192
+ m_ptrs = M + off_hz * N_CTX + offs_m
193
+ tl.store(m_ptrs, m_i + tl.math.log2(l_i))
194
+ tl.store(O_block_ptr, acc.to(Out.type.element_ty))
195
+
196
+
197
+ @triton.jit
198
+ def _attn_bwd_preprocess(O, DO,
199
+ Delta,
200
+ Z, H, N_CTX,
201
+ BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
202
+ ):
203
+ off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
204
+ off_hz = tl.program_id(1)
205
+ off_n = tl.arange(0, D_HEAD)
206
+ o = tl.load(O + off_hz * D_HEAD * N_CTX +
207
+ off_m[:, None] * D_HEAD + off_n[None, :])
208
+ do = tl.load(DO + off_hz * D_HEAD * N_CTX +
209
+ off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
210
+ delta = tl.sum(o * do, axis=1)
211
+ tl.store(Delta + off_hz * N_CTX + off_m, delta)
212
+
213
+
214
+ # The main inner-loop logic for computing dK and dV.
215
+ @triton.jit
216
+ def _attn_bwd_dkdv(dk, dv,
217
+ Q, k, v, sm_scale,
218
+ DO,
219
+ M, D,
220
+ # shared by Q/K/V/DO.
221
+ stride_tok, stride_d,
222
+ H, N_CTX, BLOCK_M1: tl.constexpr,
223
+ BLOCK_N1: tl.constexpr,
224
+ BLOCK_DMODEL: tl.constexpr,
225
+ # Filled in by the wrapper.
226
+ start_n, start_m, num_steps,
227
+ MASK: tl.constexpr):
228
+ offs_m = start_m + tl.arange(0, BLOCK_M1)
229
+ offs_n = start_n + tl.arange(0, BLOCK_N1)
230
+ offs_k = tl.arange(0, BLOCK_DMODEL)
231
+ QT_block_ptr = tl.make_block_ptr(
232
+ base=Q,
233
+ shape=(BLOCK_DMODEL, N_CTX),
234
+ strides=(stride_d, stride_tok),
235
+ offsets=(0, start_m),
236
+ block_shape=(BLOCK_DMODEL, BLOCK_M1),
237
+ order=(0, 1)
238
+ )
239
+ DO_block_ptr = tl.make_block_ptr(
240
+ base=DO,
241
+ shape=(N_CTX, BLOCK_DMODEL),
242
+ strides=(stride_tok, stride_d),
243
+ offsets=(start_m, 0),
244
+ block_shape=(BLOCK_M1, BLOCK_DMODEL),
245
+ order=(1, 0)
246
+ )
247
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
248
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
249
+ curr_m = start_m
250
+ step_m = BLOCK_M1
251
+ for blk_idx in range(num_steps):
252
+ qT = tl.load(QT_block_ptr)
253
+ # Load m before computing qk to reduce pipeline stall.
254
+ offs_m = curr_m + tl.arange(0, BLOCK_M1)
255
+ m = tl.load(M + offs_m)
256
+ qkT = tl.dot(k, qT)
257
+ pT = tl.math.exp2(qkT - m[None, :])
258
+ # Autoregressive masking.
259
+ if MASK:
260
+ mask = (offs_m[None, :] >= offs_n[:, None])
261
+ pT = tl.where(mask, pT, 0.0)
262
+ do = tl.load(DO_block_ptr)
263
+ # Compute dV.
264
+ ppT = pT
265
+ ppT = ppT.to(tl.float16)
266
+ dv += tl.dot(ppT, do)
267
+ # D (= delta) is pre-divided by ds_scale.
268
+ Di = tl.load(D + offs_m)
269
+ # Compute dP and dS.
270
+ dpT = tl.dot(v, tl.trans(do))
271
+ dsT = pT * (dpT - Di[None, :])
272
+ dsT = dsT.to(tl.float16)
273
+ dk += tl.dot(dsT, tl.trans(qT))
274
+ # Increment pointers.
275
+ curr_m += step_m
276
+ QT_block_ptr = tl.advance(QT_block_ptr, (0, step_m))
277
+ DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0))
278
+ return dk, dv
279
+
280
+
281
+ # the main inner-loop logic for computing dQ
282
+ @triton.jit
283
+ def _attn_bwd_dq(dq, q, K, V,
284
+ do, m, D,
285
+ # shared by Q/K/V/DO.
286
+ stride_tok, stride_d,
287
+ H, N_CTX,
288
+ BLOCK_M2: tl.constexpr,
289
+ BLOCK_N2: tl.constexpr,
290
+ BLOCK_DMODEL: tl.constexpr,
291
+ # Filled in by the wrapper.
292
+ start_m, start_n, num_steps,
293
+ MASK: tl.constexpr):
294
+ offs_m = start_m + tl.arange(0, BLOCK_M2)
295
+ offs_n = start_n + tl.arange(0, BLOCK_N2)
296
+ offs_k = tl.arange(0, BLOCK_DMODEL)
297
+ KT_block_ptr = tl.make_block_ptr(
298
+ base=K,
299
+ shape=(BLOCK_DMODEL, N_CTX),
300
+ strides=(stride_d, stride_tok),
301
+ offsets=(0, start_n),
302
+ block_shape=(BLOCK_DMODEL, BLOCK_N2),
303
+ order=(0, 1)
304
+ )
305
+ VT_block_ptr = tl.make_block_ptr(
306
+ base=V,
307
+ shape=(BLOCK_DMODEL, N_CTX),
308
+ strides=(stride_d, stride_tok),
309
+ offsets=(0, start_n),
310
+ block_shape=(BLOCK_DMODEL, BLOCK_N2),
311
+ order=(0, 1)
312
+ )
313
+ # D (= delta) is pre-divided by ds_scale.
314
+ Di = tl.load(D + offs_m)
315
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
316
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
317
+ curr_n = start_n
318
+ step_n = BLOCK_N2
319
+ for blk_idx in range(num_steps):
320
+ kT = tl.load(KT_block_ptr)
321
+ qk = tl.dot(q, kT)
322
+ p = tl.math.exp2(qk - m)
323
+ # Autoregressive masking.
324
+ if MASK:
325
+ offs_n = curr_n + tl.arange(0, BLOCK_N2)
326
+ mask = (offs_m[:, None] >= offs_n[None, :])
327
+ p = tl.where(mask, p, 0.0)
328
+ # Compute dP and dS.
329
+ vT = tl.load(VT_block_ptr)
330
+ dp = tl.dot(do, vT).to(tl.float32)
331
+ ds = p * (dp - Di[:, None])
332
+ ds = ds.to(tl.float16)
333
+ # Compute dQ.
334
+ # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
335
+ dq += tl.dot(ds, tl.trans(kT))
336
+ # Increment pointers.
337
+ curr_n += step_n
338
+ KT_block_ptr = tl.advance(KT_block_ptr, (0, step_n))
339
+ VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n))
340
+ return dq
341
+
342
+
343
+ @triton.autotune(
344
+ configs=[
345
+ triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1},
346
+ num_stages=1, num_warps=4),
347
+ triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
348
+ num_stages=1, num_warps=4),
349
+ triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1},
350
+ num_stages=1, num_warps=4),
351
+ triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2},
352
+ num_stages=1, num_warps=4),
353
+ triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1},
354
+ num_stages=1, num_warps=4),
355
+ triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2},
356
+ num_stages=1, num_warps=4),
357
+ triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1},
358
+ num_stages=1, num_warps=4),
359
+ triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
360
+ num_stages=1, num_warps=4),
361
+ triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
362
+ num_stages=1, num_warps=8),
363
+ ],
364
+ key=['H', 'N_CTX', 'BLOCK_DMODEL'],
365
+ )
366
+ @triton.jit
367
+ def _attn_bwd(Q, K, V, sm_scale,
368
+ DO,
369
+ DQ, DK, DV,
370
+ M, D,
371
+ # shared by Q/K/V/DO.
372
+ stride_z, stride_h, stride_tok, stride_d,
373
+ # H = 16, N_CTX = 1024
374
+ H, N_CTX,
375
+ BLOCK_DMODEL: tl.constexpr,
376
+ BLOCK_M1: tl.constexpr,
377
+ BLOCK_N1: tl.constexpr,
378
+ BLOCK_M2: tl.constexpr,
379
+ BLOCK_N2: tl.constexpr,
380
+ BLK_SLICE_FACTOR: tl.constexpr):
381
+ LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
382
+
383
+ bhid = tl.program_id(2)
384
+ off_chz = (bhid * N_CTX).to(tl.int64)
385
+ adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
386
+ pid = tl.program_id(0)
387
+
388
+ # offset pointers for batch/head
389
+ Q += adj
390
+ K += adj
391
+ V += adj
392
+ DO += adj
393
+ DQ += adj
394
+ DK += adj
395
+ DV += adj
396
+ M += off_chz
397
+ D += off_chz
398
+
399
+ offs_k = tl.arange(0, BLOCK_DMODEL)
400
+
401
+ start_n = pid * BLOCK_N1
402
+ # This assignment is important. It is what allows us to pick the diagonal
403
+ # blocks. Later, when we want to do the lower triangular, we update start_m
404
+ # after the first dkdv call.
405
+ start_m = start_n
406
+
407
+ MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
408
+ offs_n = start_n + tl.arange(0, BLOCK_N1)
409
+
410
+ dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
411
+ dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
412
+
413
+ K_block_ptr = tl.make_block_ptr(
414
+ base=K,
415
+ shape=(N_CTX, BLOCK_DMODEL),
416
+ strides=(stride_tok, stride_d),
417
+ offsets=(start_n, 0),
418
+ block_shape=(BLOCK_N1, BLOCK_DMODEL),
419
+ order=(1, 0),
420
+ )
421
+ V_block_ptr = tl.make_block_ptr(
422
+ base=V,
423
+ shape=(N_CTX, BLOCK_DMODEL),
424
+ strides=(stride_tok, stride_d),
425
+ offsets=(start_n, 0),
426
+ block_shape=(BLOCK_N1, BLOCK_DMODEL),
427
+ order=(1, 0),
428
+ )
429
+
430
+ # load K and V: they stay in SRAM throughout the inner loop for dkdv.
431
+ k = tl.load(K_block_ptr)
432
+ v = tl.load(V_block_ptr)
433
+
434
+ num_steps = BLOCK_N1 // MASK_BLOCK_M1
435
+
436
+ dk, dv = _attn_bwd_dkdv(dk, dv,
437
+ Q, k, v, sm_scale,
438
+ DO,
439
+ M, D,
440
+ stride_tok, stride_d,
441
+ H, N_CTX,
442
+ MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
443
+ start_n, start_m, num_steps,
444
+ MASK=True
445
+ )
446
+
447
+ start_m += num_steps * MASK_BLOCK_M1
448
+ num_steps = (N_CTX - start_m) // BLOCK_M1
449
+
450
+ # Compute dK and dV for non-masked blocks.
451
+ dk, dv = _attn_bwd_dkdv(
452
+ dk, dv,
453
+ Q, k, v, sm_scale,
454
+ DO,
455
+ M, D,
456
+ stride_tok, stride_d,
457
+ H, N_CTX,
458
+ BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
459
+ start_n, start_m, num_steps,
460
+ MASK=False
461
+ )
462
+
463
+ DV_block_ptrs = tl.make_block_ptr(
464
+ base=DV,
465
+ shape=(N_CTX, BLOCK_DMODEL),
466
+ strides=(stride_tok, stride_d),
467
+ offsets=(start_n, 0),
468
+ block_shape=(BLOCK_N1, BLOCK_DMODEL),
469
+ order=(1, 0)
470
+ )
471
+ tl.store(DV_block_ptrs, dv.to(tl.float16))
472
+
473
+ # Write back dK.
474
+ dk *= sm_scale
475
+ DK_block_ptrs = tl.make_block_ptr(
476
+ base=DK,
477
+ shape=(N_CTX, BLOCK_DMODEL),
478
+ strides=(stride_tok, stride_d),
479
+ offsets=(start_n, 0),
480
+ block_shape=(BLOCK_N1, BLOCK_DMODEL),
481
+ order=(1, 0)
482
+ )
483
+ tl.store(DK_block_ptrs, dk.to(tl.float16))
484
+
485
+ # THIS BLOCK DOES DQ:
486
+ start_m = pid * BLOCK_M2
487
+ end_n = start_m + BLOCK_M2
488
+
489
+ MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
490
+ offs_m = start_m + tl.arange(0, BLOCK_M2)
491
+
492
+ Q_block_ptr = tl.make_block_ptr(
493
+ base=Q,
494
+ shape=(N_CTX, BLOCK_DMODEL),
495
+ strides=(stride_tok, stride_d),
496
+ offsets=(start_m, 0),
497
+ block_shape=(BLOCK_M2, BLOCK_DMODEL),
498
+ order=(1, 0)
499
+ )
500
+
501
+ DO_block_ptr = tl.make_block_ptr(
502
+ base=DO,
503
+ shape=(N_CTX, BLOCK_DMODEL),
504
+ strides=(stride_tok, stride_d),
505
+ offsets=(start_m, 0),
506
+ block_shape=(BLOCK_M2, BLOCK_DMODEL),
507
+ order=(1, 0)
508
+ )
509
+ q = tl.load(Q_block_ptr)
510
+ do = tl.load(DO_block_ptr)
511
+ dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
512
+
513
+ m = tl.load(M + offs_m)
514
+ m = m[:, None]
515
+
516
+ # Compute dQ for masked (diagonal) blocks.
517
+ # NOTE: This code scans each row of QK^T backward (from right to left,
518
+ # but inside each call to _attn_bwd_dq, from left to right), but that's
519
+ # not due to anything important. I just wanted to reuse the loop
520
+ # structure for dK & dV above as much as possible.
521
+ num_steps = BLOCK_M2 // MASK_BLOCK_N2
522
+ dq = _attn_bwd_dq(dq, q, K, V,
523
+ do, m, D,
524
+ stride_tok, stride_d,
525
+ H, N_CTX,
526
+ BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL,
527
+ start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,
528
+ MASK=True
529
+ )
530
+ end_n -= num_steps * MASK_BLOCK_N2
531
+ # stage 2
532
+ num_steps = end_n // BLOCK_N2
533
+ dq = _attn_bwd_dq(dq, q, K, V,
534
+ do, m, D,
535
+ stride_tok, stride_d,
536
+ H, N_CTX,
537
+ BLOCK_M2, BLOCK_N2, BLOCK_DMODEL,
538
+ start_m, end_n - num_steps * BLOCK_N2, num_steps,
539
+ MASK=False
540
+ )
541
+ # Write back dQ.
542
+ DQ_block_ptr = tl.make_block_ptr(
543
+ base=DQ,
544
+ shape=(N_CTX, BLOCK_DMODEL),
545
+ strides=(stride_tok, stride_d),
546
+ offsets=(start_m, 0),
547
+ block_shape=(BLOCK_M2, BLOCK_DMODEL),
548
+ order=(1, 0)
549
+ )
550
+ dq *= LN2
551
+ tl.store(DQ_block_ptr, dq.to(tl.float16))
552
+
553
+
554
+ empty = torch.empty(128, device="cuda")
555
+
556
+
557
+ class _attention(torch.autograd.Function):
558
+
559
+ @staticmethod
560
+ def forward(ctx, q, k, v, causal, sm_scale):
561
+ # shape constraints
562
+ Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
563
+ assert Lq == Lk and Lk == Lv
564
+ assert Lk in {16, 32, 64, 128}
565
+ o = torch.empty_like(q, dtype=v.dtype)
566
+ if torch.version.hip is None:
567
+ BLOCK_M = 128
568
+ BLOCK_N = 64 if Lk <= 64 else 32
569
+ num_stages = 4 if Lk <= 64 else 3
570
+ num_warps = 4 if Lk <= 64 else 8
571
+ # Tuning for H100
572
+ if torch.cuda.get_device_capability()[0] == 9:
573
+ num_warps = 8
574
+ num_stages = 7 if Lk >= 64 else 3
575
+ stage = 3 if causal else 1
576
+
577
+ def grid(META): return (
578
+ triton.cdiv(q.shape[2], META['BLOCK_M']),
579
+ q.shape[0] * q.shape[1],
580
+ 1
581
+ )
582
+ M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]),
583
+ device=q.device, dtype=torch.float32)
584
+ _attn_fwd[grid](
585
+ q, k, v, sm_scale, M, o,
586
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
587
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3),
588
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3),
589
+ o.stride(0), o.stride(1), o.stride(2), o.stride(3),
590
+ q.shape[0], q.shape[1],
591
+ N_CTX=q.shape[2],
592
+ BLOCK_DMODEL=Lk,
593
+ STAGE=stage,
594
+ )
595
+
596
+ # restore the grid for bwd kernel
597
+ best_config = _attn_fwd.get_best_config()
598
+ block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1])
599
+ grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1)
600
+
601
+ ctx.save_for_backward(q, k, v, o, M)
602
+ ctx.grid = grid
603
+ ctx.sm_scale = sm_scale
604
+ ctx.BLOCK_DMODEL = Lk
605
+ ctx.causal = causal
606
+ return o
607
+
608
+ @staticmethod
609
+ def backward(ctx, do):
610
+ if torch.version.hip is not None:
611
+ BLOCK = 64
612
+ else:
613
+ BLOCK = 128
614
+ q, k, v, o, M = ctx.saved_tensors
615
+ assert do.is_contiguous()
616
+ assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
617
+ dq = torch.empty_like(q)
618
+ dk = torch.empty_like(k)
619
+ dv = torch.empty_like(v)
620
+ BATCH, N_HEAD, N_CTX = q.shape[:3]
621
+ PRE_BLOCK = 128
622
+ NUM_WARPS, NUM_STAGES = 4, 1
623
+ BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32
624
+ BLK_SLICE_FACTOR = 2
625
+ RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
626
+ arg_k = k
627
+ arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
628
+ assert N_CTX % PRE_BLOCK == 0
629
+ pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
630
+ delta = torch.empty_like(M)
631
+ _attn_bwd_preprocess[pre_grid](
632
+ o, do,
633
+ delta,
634
+ BATCH, N_HEAD, N_CTX,
635
+ BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL
636
+ )
637
+
638
+ def grid(META): return (
639
+ triton.cdiv(N_CTX, META['BLOCK_N1']),
640
+ 1,
641
+ BATCH * N_HEAD
642
+ )
643
+ _attn_bwd[grid](
644
+ q, arg_k, v, ctx.sm_scale, do, dq, dk, dv,
645
+ M, delta,
646
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
647
+ N_HEAD, N_CTX,
648
+ BLOCK_DMODEL=ctx.BLOCK_DMODEL
649
+ )
650
+
651
+ return dq, dk, dv, None, None
652
+
653
+
654
+ attention = _attention.apply