svjack commited on
Commit
f7a83c6
·
1 Parent(s): 647f002

Upload with huggingface_hub

Browse files
app.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from summary_reverse_pred_native import *
2
+ import gradio as gr
3
+ import os
4
+
5
+ text0 = "飓风格特是1993年9月在墨西哥和整个中美洲引发严重洪灾的大规模热带气旋,源于9月14日西南加勒比海上空一股东风波。次日从尼加拉瓜登岸,经过洪都拉斯后于9月17日在洪都拉斯湾再次达到热带风暴标准,但次日进入伯利兹上空后就减弱成热带低气压。穿过尤卡坦半岛后,在9月20日强化成二级飓风,从韦拉克鲁斯州的图斯潘附近登陆墨西哥。9月21日从纳亚里特州进入太平洋时已降级成热带低气压,最终于5天后在开放水域上空消散。"
6
+ text1 = "珊瑚坝是长江中的一处河漫滩,位于长江重庆市渝中区区段主航道左侧[1],靠近渝中半岛,原分属重庆市市中区菜园坝街道和石板坡街道[2],现属渝中区菜园坝街道石板坡社区[3],是长江上游缓冲地段自然冲积沙洲,略呈纺锤形[4]或椭圆形,长约1800米,宽约600米,坝上遍布鹅卵石和水草。每年夏季洪水时均被淹没,其余时间常露水面,枯水期则与长江左岸相连[5]。"
7
+
8
+ example_sample = [
9
+ [text0, False],
10
+ [text1, False],
11
+ ]
12
+
13
+ def demo_func(prefix, do_sample):
14
+ l = simple_pred(prefix, do_sample = do_sample)
15
+ return {
16
+ "Dialogue Context": l
17
+ }
18
+
19
+ demo = gr.Interface(
20
+ fn=demo_func,
21
+ inputs=[gr.Text(label = "Context"),
22
+ gr.Checkbox(label="do sample"),
23
+ ],
24
+ outputs="json",
25
+ title=f"Chinese Context Dialogue Generator 🐰 demonstration",
26
+ examples=example_sample if example_sample else None,
27
+ cache_examples = False
28
+ )
29
+
30
+ demo.launch(server_name=None, server_port=None)
component/argument.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+
3
+ @dataclass
4
+ class CaptionArguments:
5
+ """
6
+ 自定义的一些参数
7
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
8
+ """
9
+ max_seq_length: int = field(metadata={"help": "输入最大长度"})
10
+ train_caption_file: str = field(metadata={"help": "训练集"})
11
+ train_image_file: str = field(metadata={"help": "训练集"})
12
+ test_caption_file: str = field(metadata={"help": "测试集"})
13
+ test_image_file: str = field(metadata={"help": "测试集"})
14
+ model_name_or_path: str = field(metadata={"help": "预训练权重路径"})
15
+ freeze_encoder: bool = field(metadata={"help": "是否将encoder的权重冻结,仅对decoder进行finetune"})
16
+ freeze_word_embed: bool = field(
17
+ metadata={"help": "是否将encoder的词向量的权重冻结,由于OFA模型的enocder与decoder共享词向量权重,所以freeze_encoder会将词向量冻结。当freeze_word_embed=False时,词向量会一起训练"})
18
+
component/datacollator.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+ import torch
3
+
4
+
5
+ class CaptionCollator(object):
6
+ def __init__(self, tokenizer, max_seq_length):
7
+ self.tokenizer = tokenizer
8
+ self.max_seq_length = max_seq_length
9
+
10
+ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
11
+ captions, patch_images = [], []
12
+ for data in features:
13
+ # 如果图片预处理失败,则跳过该图片
14
+ if data['patch_image'] is None:
15
+ continue
16
+ captions.append(data['caption'])
17
+ patch_images.append(data['patch_image'])
18
+ # 获得encoder的输入
19
+ input_ids = self.tokenizer(
20
+ ['图片描述了什么?']*len(captions), return_tensors="pt", max_length=self.max_seq_length, truncation=True, padding=True
21
+ ).input_ids
22
+ patch_images = torch.concat(patch_images, dim=0)
23
+
24
+ # 获得decoder的输入
25
+ inputs = self.tokenizer(
26
+ captions, return_tensors="pt", max_length=self.max_seq_length, truncation=True, padding=True
27
+ )
28
+ decoder_input_ids = inputs.input_ids
29
+ attention_mask = inputs.attention_mask
30
+
31
+ inputs = {
32
+ 'input_ids': input_ids,
33
+ 'patch_images': patch_images,
34
+ 'decoder_input_ids': decoder_input_ids,
35
+ 'attention_mask': attention_mask,
36
+ 'return_loss': True
37
+ }
38
+ return inputs
component/dataset.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ from torch.utils.data import Dataset
4
+ from tqdm import tqdm
5
+ import base64
6
+ from io import BytesIO
7
+ from PIL import Image
8
+ from torchvision import transforms
9
+ from loguru import logger
10
+
11
+
12
+ class CaptionDataset(Dataset):
13
+
14
+ def __init__(self, caption_file, image_file):
15
+ logger.info('loading data from:{} and {}'.format(caption_file, image_file))
16
+ # 读取每个图片的内容
17
+ image_id2content = {}
18
+ with open(image_file, 'r', encoding='utf8') as f:
19
+ lines = f.readlines()
20
+ for line in tqdm(lines):
21
+ image_id, image_content = line.split('\t')
22
+ image_id2content[image_id] = image_content
23
+
24
+ # 读取每个图片的所有caption,得到所有训练数据
25
+ data_list = []
26
+ with open(caption_file, 'r', encoding='utf8') as f:
27
+ lines = f.readlines()
28
+ for line in tqdm(lines):
29
+ line = json.loads(line)
30
+ image_id = line['image_id']
31
+ captions = line['text']
32
+ for caption in captions:
33
+ data = {'caption': caption, 'image_base64': image_id2content[image_id], 'image_id': image_id}
34
+ data_list.append(data)
35
+
36
+ logger.info('len of data:{}'.format(len(data_list)))
37
+ self.data_list = data_list
38
+
39
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
40
+ resolution = 256
41
+ patch_resize_transform = transforms.Compose([
42
+ lambda image: image.convert("RGB"),
43
+ transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC),
44
+ transforms.ToTensor(),
45
+ transforms.Normalize(mean=mean, std=std)
46
+ ])
47
+ self.patch_resize_transform = patch_resize_transform
48
+
49
+ def __len__(self):
50
+ return len(self.data_list)
51
+
52
+ def __getitem__(self, index):
53
+ row = self.data_list[index]
54
+ caption = row['caption'].strip()
55
+ image_base64 = row['image_base64']
56
+ image_id = row['image_id']
57
+
58
+ # 加载图片,并进行预处理
59
+ try:
60
+ image = Image.open(BytesIO(base64.urlsafe_b64decode(image_base64)))
61
+ patch_image = self.patch_resize_transform(image).unsqueeze(0)
62
+ except Exception as e:
63
+ # 图片加载失败
64
+ logger.info('open image error, image_id: {}'.format(image_id))
65
+ logger.info(e)
66
+ patch_image = None
67
+
68
+ data = {'patch_image': patch_image, 'caption': caption}
69
+ return data
component/ofa/configuration_ofa.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The OFA-Sys Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ OFA model configuration"""
16
+ import warnings
17
+ from transformers import PretrainedConfig
18
+ from transformers.utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+ OFA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
24
+ "OFA-Sys/OFA-tiny": "https://huggingface.co/OFA-Sys/OFA-tiny/blob/main/config.json",
25
+ "OFA-Sys/OFA-medium": "https://huggingface.co/OFA-Sys/OFA-medium/blob/main/config.json",
26
+ "OFA-Sys/OFA-base": "https://huggingface.co/OFA-Sys/OFA-base/blob/main/config.json",
27
+ "OFA-Sys/OFA-large": "https://huggingface.co/OFA-Sys/OFA-large/blob/main/config.json",
28
+ # See all OFA models at https://huggingface.co/models?filter=ofa
29
+ }
30
+
31
+
32
+ class OFAConfig(PretrainedConfig):
33
+ r"""
34
+ This is the configuration class to store the configuration of a [`~OFAModel`]. It is used to instantiate an OFA
35
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
36
+ defaults will yield a similar configuration to that of the OFA [ofa-base](https://huggingface.co/ofa-base)
37
+ architecture.
38
+
39
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
40
+ documentation from [`PretrainedConfig`] for more information.
41
+
42
+
43
+ Args:
44
+ vocab_size (`int`, *optional*, defaults to 50265):
45
+ Vocabulary size of the OFA model. Defines the number of different tokens that can be represented by the
46
+ `inputs_ids` passed when calling [`~OFAModel`] or [`~TFOFAModel`].
47
+ d_model (`int`, *optional*, defaults to 1024):
48
+ Dimension of the layers and the pooler layer.
49
+ encoder_layers (`int`, *optional*, defaults to 12):
50
+ Number of encoder layers.
51
+ decoder_layers (`int`, *optional*, defaults to 12):
52
+ Number of decoder layers.
53
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
54
+ Number of attention heads for each attention layer in the Transformer encoder.
55
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
56
+ Number of attention heads for each attention layer in the Transformer decoder.
57
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
58
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
59
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
60
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
61
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
62
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
63
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
64
+ dropout (`float`, *optional*, defaults to 0.1):
65
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
66
+ attention_dropout (`float`, *optional*, defaults to 0.0):
67
+ The dropout ratio for the attention probabilities.
68
+ activation_dropout (`float`, *optional*, defaults to 0.0):
69
+ The dropout ratio for activations inside the fully connected layer.
70
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
71
+ The dropout ratio for classifier.
72
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
73
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
74
+ just in case (e.g., 512 or 1024 or 2048).
75
+ init_std (`float`, *optional*, defaults to 0.02):
76
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
77
+ encoder_layerdrop: (`float`, *optional*, defaults to 0.0):
78
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
79
+ for more details.
80
+ decoder_layerdrop: (`float`, *optional*, defaults to 0.0):
81
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
82
+ for more details.
83
+ use_cache (`bool`, *optional*, defaults to `True`):
84
+ Whether or not the model should return the last key/values attentions (not used by all models).
85
+ """
86
+
87
+ model_type = "ofa"
88
+ keys_to_ignore_at_inference = ["past_key_values"]
89
+
90
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
91
+
92
+ def __init__(
93
+ self,
94
+ vocab_size=59457,
95
+ max_position_embeddings=1024,
96
+ encoder_layers=4,
97
+ encoder_ffn_dim=512 * 4,
98
+ encoder_attention_heads=8,
99
+ decoder_layers=4,
100
+ decoder_ffn_dim=512 * 4,
101
+ decoder_attention_heads=8,
102
+ encoder_layerdrop=0.0,
103
+ decoder_layerdrop=0.0,
104
+ use_cache=True,
105
+ is_encoder_decoder=True,
106
+ activation_function="gelu",
107
+ d_model=512,
108
+ dropout=0.1,
109
+ attention_dropout=0.0,
110
+ activation_dropout=0.0,
111
+ init_std=0.02,
112
+ classifier_dropout=0.0,
113
+ scale_embedding=False,
114
+ pad_token_id=1,
115
+ bos_token_id=0,
116
+ decoder_start_token_id=0,
117
+ eos_token_id=2,
118
+ forced_eos_token_id=2,
119
+ encoder_normalize_before=True,
120
+ decoder_normalize_before=True,
121
+ normformer=True,
122
+ encoder_drop_path_rate=0.0,
123
+ decoder_drop_path_rate=0.0,
124
+ layernorm_embedding=True,
125
+ patch_layernorm_embedding=True,
126
+ resnet_type="resnet101",
127
+ resnet_model_path=None,
128
+ resnet_drop_path_rate=0.0,
129
+ token_bucket_size=256,
130
+ image_bucket_size=42,
131
+ add_type_embedding=True,
132
+ share_decoder_input_output_embed=True,
133
+ attn_scale_factor=2.0,
134
+ code_layernorm_embedding=True,
135
+ code_image_size=128,
136
+ entangle_position_embedding=False,
137
+ **kwargs
138
+ ):
139
+ self.vocab_size = vocab_size
140
+ self.max_position_embeddings = max_position_embeddings
141
+ self.d_model = d_model
142
+ self.encoder_ffn_dim = encoder_ffn_dim
143
+ self.encoder_layers = encoder_layers
144
+ self.encoder_attention_heads = encoder_attention_heads
145
+ self.decoder_ffn_dim = decoder_ffn_dim
146
+ self.decoder_layers = decoder_layers
147
+ self.decoder_attention_heads = decoder_attention_heads
148
+ self.dropout = dropout
149
+ self.attention_dropout = attention_dropout
150
+ self.activation_dropout = activation_dropout
151
+ self.activation_function = activation_function
152
+ self.init_std = init_std
153
+ self.encoder_layerdrop = encoder_layerdrop
154
+ self.decoder_layerdrop = decoder_layerdrop
155
+ self.classifier_dropout = classifier_dropout
156
+ self.use_cache = use_cache
157
+ self.num_hidden_layers = encoder_layers
158
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
159
+ self.encoder_normalize_before = encoder_normalize_before
160
+ self.decoder_normalize_before = decoder_normalize_before
161
+ self.normformer = normformer
162
+ self.encoder_drop_path_rate = encoder_drop_path_rate
163
+ self.decoder_drop_path_rate = decoder_drop_path_rate
164
+ self.layernorm_embedding = layernorm_embedding
165
+ self.patch_layernorm_embedding = patch_layernorm_embedding
166
+ self.resnet_type = resnet_type
167
+ self.resnet_model_path = resnet_model_path
168
+ self.resnet_drop_path_rate = resnet_drop_path_rate
169
+ self.token_bucket_size = token_bucket_size
170
+ self.image_bucket_size = image_bucket_size
171
+ self.add_type_embedding = add_type_embedding
172
+ self.share_decoder_input_output_embed = share_decoder_input_output_embed
173
+ self.attn_scale_factor = attn_scale_factor
174
+ self.code_layernorm_embedding = code_layernorm_embedding
175
+ self.code_image_size = code_image_size
176
+ self.entangle_position_embedding = entangle_position_embedding
177
+
178
+ super().__init__(
179
+ pad_token_id=pad_token_id,
180
+ bos_token_id=bos_token_id,
181
+ eos_token_id=eos_token_id,
182
+ is_encoder_decoder=is_encoder_decoder,
183
+ decoder_start_token_id=bos_token_id,
184
+ forced_eos_token_id=forced_eos_token_id,
185
+ **kwargs,
186
+ )
187
+
188
+ # ensure backward compatibility for BART CNN models
189
+ if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
190
+ self.forced_bos_token_id = self.bos_token_id
191
+ warnings.warn(
192
+ f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
193
+ "The config can simply be saved and uploaded again to be fixed."
194
+ )
component/ofa/modeling_ofa.py ADDED
@@ -0,0 +1,2139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The OFA-Sys Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch OFA model."""
16
+
17
+ import math
18
+ import random
19
+ from typing import Optional, Tuple
20
+ from dataclasses import dataclass
21
+
22
+ import torch
23
+ from torch import nn
24
+ from torch.nn import functional as F
25
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
+
27
+ # start fixing
28
+ # from ...activations import ACT2FN
29
+ # from ...file_utils import (
30
+ # add_code_sample_docstrings,
31
+ # add_end_docstrings,
32
+ # add_start_docstrings,
33
+ # add_start_docstrings_to_model_forward,
34
+ # replace_return_docstrings,
35
+ # )
36
+ # from ...file_utils import ModelOutput
37
+ # from ...modeling_outputs import (
38
+ # BaseModelOutputWithPastAndCrossAttentions,
39
+ # Seq2SeqLMOutput,
40
+ # Seq2SeqModelOutput,
41
+ # )
42
+ # from ...modeling_utils import PreTrainedModel
43
+ # from ...utils import logging
44
+ from transformers.activations import ACT2FN
45
+ from transformers.file_utils import (
46
+ add_code_sample_docstrings,
47
+ add_end_docstrings,
48
+ add_start_docstrings,
49
+ add_start_docstrings_to_model_forward,
50
+ replace_return_docstrings,
51
+ ModelOutput
52
+ )
53
+ from transformers.modeling_outputs import (
54
+ BaseModelOutputWithPastAndCrossAttentions,
55
+ Seq2SeqLMOutput,
56
+ Seq2SeqModelOutput,
57
+ )
58
+ from transformers.modeling_utils import PreTrainedModel
59
+ from transformers.utils import logging
60
+
61
+ # end fixing
62
+
63
+ from .configuration_ofa import OFAConfig
64
+ from .resnet import ResNet
65
+ from torch import Tensor
66
+ from typing import Dict, List, Optional, Tuple
67
+
68
+ logger = logging.get_logger(__name__)
69
+
70
+ _CHECKPOINT_FOR_DOC = "OFA-Sys/OFA-base"
71
+ _CONFIG_FOR_DOC = "OFAConfig"
72
+ _TOKENIZER_FOR_DOC = "OFATokenizer"
73
+
74
+ DEFAULT_MAX_SOURCE_POSITIONS = 1024
75
+ DEFAULT_MAX_TARGET_POSITIONS = 1024
76
+
77
+ DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)
78
+
79
+ OFA_PRETRAINED_MODEL_ARCHIVE_LIST = [
80
+ "OFA-Sys/OFA-tiny",
81
+ "OFA-Sys/OFA-medium",
82
+ "OFA-Sys/OFA-base",
83
+ "OFA-Sys/OFA-large",
84
+ ]
85
+
86
+ try:
87
+ from apex.normalization import FusedLayerNorm as _FusedLayerNorm
88
+
89
+ has_fused_layernorm = True
90
+
91
+ class FusedLayerNorm(_FusedLayerNorm):
92
+ @torch.jit.unused
93
+ def forward(self, x):
94
+ if not x.is_cuda:
95
+ return super().forward(x)
96
+ else:
97
+ with torch.cuda.device(x.device):
98
+ return super().forward(x)
99
+
100
+ except ImportError:
101
+ has_fused_layernorm = False
102
+
103
+
104
+ def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
105
+ r"""
106
+ Layer normalization.
107
+ If apex is available, use `FusedLayerNorm` instead.
108
+ """
109
+ if torch.jit.is_scripting():
110
+ export = True
111
+ if not export and torch.cuda.is_available() and has_fused_layernorm:
112
+ return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
113
+ return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
114
+
115
+
116
+ def make_token_bucket_position(bucket_size, max_position=DEFAULT_MAX_SOURCE_POSITIONS):
117
+ r"""
118
+ Make relative position indices for the text.
119
+ """
120
+ context_pos = torch.arange(max_position, dtype=torch.long)[:, None]
121
+ memory_pos = torch.arange(max_position, dtype=torch.long)[None, :]
122
+ relative_pos = context_pos - memory_pos
123
+ sign = torch.sign(relative_pos)
124
+ mid = bucket_size // 2
125
+ abs_pos = torch.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, torch.abs(relative_pos))
126
+ log_pos = torch.ceil(torch.log(abs_pos / mid) / math.log((max_position - 1) / mid) * (mid - 1)) + mid
127
+ log_pos = log_pos.int()
128
+ bucket_pos = torch.where(abs_pos.le(mid), relative_pos, log_pos * sign).long()
129
+ return bucket_pos + bucket_size - 1
130
+
131
+
132
+ def make_image_bucket_position(bucket_size, num_relative_distance):
133
+ r"""
134
+ Make relative position indices for the image.
135
+ """
136
+ coords_h = torch.arange(bucket_size)
137
+ coords_w = torch.arange(bucket_size)
138
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
139
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
140
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
141
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
142
+ relative_coords[:, :, 0] += bucket_size - 1 # shift to start from 0
143
+ relative_coords[:, :, 1] += bucket_size - 1
144
+ relative_coords[:, :, 0] *= 2 * bucket_size - 1
145
+ relative_position_index = torch.zeros(size=(bucket_size * bucket_size + 1,) * 2, dtype=relative_coords.dtype)
146
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
147
+ relative_position_index[0, 0:] = num_relative_distance - 3
148
+ relative_position_index[0:, 0] = num_relative_distance - 2
149
+ relative_position_index[0, 0] = num_relative_distance - 1
150
+ return relative_position_index
151
+
152
+
153
+ def new_arange(x, *size):
154
+ r"""
155
+ Return a Tensor of `size` filled with a range function on the device of x.
156
+ If size is empty, using the size of the variable x.
157
+ """
158
+ if len(size) == 0:
159
+ size = x.size()
160
+ return torch.arange(size[-1], device=x.device).expand(*size).contiguous()
161
+
162
+
163
+ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
164
+ r"""
165
+ Shift input ids one token to the right.
166
+ """
167
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
168
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
169
+ shifted_input_ids[:, 0] = decoder_start_token_id
170
+
171
+ assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
172
+ # replace possible -100 values in labels by `pad_token_id`
173
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
174
+
175
+ return shifted_input_ids
176
+
177
+
178
+ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
179
+ r"""
180
+ Make causal mask used for uni-directional self-attention.
181
+ """
182
+ bsz, tgt_len = input_ids_shape
183
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
184
+ mask_cond = torch.arange(mask.size(-1))
185
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
186
+ mask = mask.to(dtype)
187
+
188
+ if past_key_values_length > 0:
189
+ mask = torch.cat([torch.ones(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
190
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
191
+
192
+
193
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
194
+ r"""
195
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
196
+ """
197
+ bsz, src_len = mask.size()
198
+ tgt_len = tgt_len if tgt_len is not None else src_len
199
+
200
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
201
+ inverted_mask = 1.0 - expanded_mask
202
+
203
+ return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
204
+
205
+
206
+ def Embedding(num_embeddings, embedding_dim, padding_idx=None, zero_init=False):
207
+ r"""
208
+ Embedding for tokens
209
+ """
210
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
211
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
212
+ if padding_idx is not None:
213
+ nn.init.constant_(m.weight[padding_idx], 0)
214
+ if zero_init:
215
+ nn.init.constant_(m.weight, 0)
216
+ return m
217
+
218
+
219
+ def Linear(in_features, out_features, bias=True):
220
+ r"""
221
+ Implementation of linear projection with xavier initialization
222
+ """
223
+ m = nn.Linear(in_features, out_features, bias)
224
+ nn.init.xavier_uniform_(m.weight)
225
+ if bias:
226
+ nn.init.constant_(m.bias, 0.0)
227
+ return m
228
+
229
+
230
+ class LayerDropModuleList(nn.ModuleList):
231
+ r"""
232
+ A LayerDrop implementation based on :class:`torch.nn.ModuleList`.
233
+
234
+ Args:
235
+ p (float): probability of dropping out each layer
236
+ modules (iterable, optional): an iterable of modules to add
237
+ """
238
+
239
+ def __init__(self, p, modules=None):
240
+ super().__init__(modules)
241
+ self.p = p
242
+
243
+ def __iter__(self):
244
+ dropout_probs = torch.empty(len(self)).uniform_()
245
+ for i, m in enumerate(super().__iter__()):
246
+ if not self.training or (dropout_probs[i] > self.p):
247
+ yield m
248
+
249
+
250
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
251
+ r"""
252
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
253
+
254
+ Args:
255
+ x (`nn.Modules`): input nn layers.
256
+ drop_prob (`float`): drop path ratio.
257
+ training (`bool`): whether is training or inference.
258
+ """
259
+ if drop_prob == 0.0 or not training:
260
+ return x
261
+ keep_prob = 1 - drop_prob
262
+ shape = (1, x.shape[1], 1)
263
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
264
+ random_tensor.floor_() # binarize
265
+ output = x.div(keep_prob) * random_tensor
266
+ return output
267
+
268
+
269
+ class DropPath(nn.Module):
270
+ r"""
271
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
272
+
273
+ Args:
274
+ drop_prob: drop path ratio.
275
+ """
276
+
277
+ def __init__(self, drop_prob=None):
278
+ super().__init__()
279
+ self.drop_prob = drop_prob
280
+
281
+ def forward(self, x):
282
+ return drop_path(x, self.drop_prob, self.training)
283
+
284
+ def extra_repr(self) -> str:
285
+ return "p={}".format(self.drop_prob)
286
+
287
+
288
+ class OFAAttention(nn.Module):
289
+ r"""
290
+ Multi-headed attention, with additional implementation for NormFormer.
291
+
292
+ Args:
293
+ embed_dim (`int`): embedding dimension.
294
+ num_heads (`int`): the number of attention heads.
295
+ dropout (`float32`): the ratio for dropout.
296
+ is_decoder (`bool`): whether or not decoder attention.
297
+ bias (`bool`): whether to add bias.
298
+ scale_heads (`bool`): whether to learn scaling heads, only for Normformer.
299
+ """
300
+
301
+ def __init__(
302
+ self,
303
+ embed_dim: int,
304
+ num_heads: int,
305
+ dropout: float = 0.0,
306
+ is_decoder: bool = False,
307
+ bias: bool = True,
308
+ scale_heads: bool = True,
309
+ ):
310
+ super().__init__()
311
+ self.embed_dim = embed_dim
312
+ self.num_heads = num_heads
313
+ self.dropout = dropout
314
+ self.head_dim = embed_dim // num_heads
315
+ assert (
316
+ self.head_dim * num_heads == self.embed_dim
317
+ ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
318
+ scale_factor=2
319
+ self.scaling = float(self.head_dim * scale_factor) ** -0.5
320
+ self.is_decoder = is_decoder
321
+
322
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
323
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
324
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
325
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
326
+ self.attn_dropout = nn.Dropout(p=dropout)
327
+ self.c_attn = nn.Parameter(torch.ones((self.num_heads,)), requires_grad=True) if scale_heads else None
328
+
329
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
330
+ r"""
331
+ Reshape tensors for multi-head attention.
332
+ """
333
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
334
+
335
+ def forward(
336
+ self,
337
+ hidden_states: torch.Tensor,
338
+ key_value_states: Optional[torch.Tensor] = None,
339
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
340
+ attention_mask: Optional[torch.Tensor] = None,
341
+ output_attentions: bool = False,
342
+ attn_bias: Optional[torch.Tensor] = None,
343
+ ):
344
+ r"""
345
+ Args:
346
+ hidden_states (`torch.FloatTensor` of shape `(bsz, tgt_len, embed_dim)`)`: input states.
347
+ key_value_states (`torch.FloatTensor` of shape (bsz, tgt_len, embed_dim), *optional*): key value states.
348
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*):
349
+ cached past key value states for fast inference.
350
+ attention_mask (`torch.FloatTensor` of shape `(bsz, 1, tgt_len, seq_len)`, *optional*): attention mask.
351
+ output_attentions (`bool`, *optional*): whether to output attention weights of all layers.
352
+ attn_bias (`torch.FloatTensor` of shape `(bsz, 1, tgt_len, src_len)`, *optional*):
353
+ the attention bias for positional information.
354
+
355
+ Returns:
356
+ attn_output (`torch.FloatTensor` of shape `(bsz, tgt_len, embed_dim)`): attention outputs.
357
+ attn_weights_reshaped (`torch.FloatTensor`, *optional*): attention weights of all layers.
358
+ past_key_value (`torch.FloatTensor`, *optional*): cached key value states for fast inference.
359
+ """
360
+
361
+ # if key_value_states are provided this layer is used as a cross-attention layer
362
+ # for the decoder
363
+ is_cross_attention = key_value_states is not None
364
+ bsz, tgt_len, embed_dim = hidden_states.size()
365
+
366
+ # get query proj
367
+ query_states = self.q_proj(hidden_states) * self.scaling
368
+ # get key, value proj
369
+ if is_cross_attention and past_key_value is not None:
370
+ # reuse k,v, cross_attentions
371
+ key_states = past_key_value[0]
372
+ value_states = past_key_value[1]
373
+ elif is_cross_attention:
374
+ # cross_attentions
375
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
376
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
377
+ elif past_key_value is not None:
378
+ # reuse k, v, self_attention
379
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
380
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
381
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
382
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
383
+ else:
384
+ # self_attention
385
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
386
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
387
+
388
+ if self.is_decoder:
389
+ past_key_value = (key_states, value_states)
390
+
391
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
392
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
393
+ key_states = key_states.view(*proj_shape)
394
+ value_states = value_states.view(*proj_shape)
395
+
396
+ src_len = key_states.size(1)
397
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
398
+
399
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
400
+ raise ValueError(
401
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
402
+ )
403
+
404
+ # Add attention bias for positional information
405
+ if attn_bias is not None:
406
+ attn_weights += attn_bias
407
+
408
+ if attention_mask is not None:
409
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
410
+ raise ValueError(
411
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
412
+ )
413
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
414
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
415
+
416
+ attn_weights = F.softmax(attn_weights, dim=-1)
417
+
418
+ if output_attentions:
419
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
420
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
421
+ else:
422
+ attn_weights_reshaped = None
423
+
424
+ attn_probs = self.attn_dropout(attn_weights)
425
+
426
+ attn_output = torch.bmm(attn_probs, value_states)
427
+
428
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
429
+ raise ValueError(
430
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
431
+ )
432
+
433
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
434
+ attn_output = attn_output.transpose(1, 2)
435
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
436
+
437
+ if self.c_attn is not None:
438
+ attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim)
439
+ attn_output = torch.einsum("bthd,h->bthd", attn_output, self.c_attn)
440
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
441
+
442
+ attn_output = self.out_proj(attn_output)
443
+
444
+ return attn_output, attn_weights_reshaped, past_key_value
445
+
446
+
447
+ class OFAEncoderLayer(nn.Module):
448
+ r"""
449
+ OFA encoder layer implementation.
450
+
451
+ Args:
452
+ config: configuration for OFA.
453
+ drop_path_rate: the ratio for drop path.
454
+ """
455
+
456
+ def __init__(self, config: OFAConfig, drop_path_rate=0.0):
457
+ super().__init__()
458
+ self.embed_dim = config.d_model
459
+ self.self_attn = OFAAttention(
460
+ embed_dim=self.embed_dim,
461
+ num_heads=config.encoder_attention_heads,
462
+ dropout=config.attention_dropout,
463
+ )
464
+ self.self_attn_layer_norm = LayerNorm(self.embed_dim)
465
+ self.self_attn_mid_layer_norm = LayerNorm(self.embed_dim) if config.normformer else None
466
+ self.dropout = nn.Dropout(config.dropout)
467
+ self.activation_fn = ACT2FN[config.activation_function]
468
+ self.activation_dropout = nn.Dropout(config.activation_dropout)
469
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
470
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
471
+ self.ffn_layer_norm = LayerNorm(config.encoder_ffn_dim) if config.normformer else None
472
+ self.final_layer_norm = LayerNorm(self.embed_dim)
473
+ self.normalize_before = config.encoder_normalize_before
474
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
475
+
476
+ def residual_connection(self, x, residual):
477
+ r"""
478
+ Residual connection with drop path.
479
+ """
480
+ return residual + self.drop_path(x)
481
+
482
+ def forward(
483
+ self,
484
+ hidden_states: torch.Tensor,
485
+ attention_mask: torch.Tensor,
486
+ output_attentions: bool = False,
487
+ attn_bias: Optional[torch.Tensor] = None,
488
+ ):
489
+ r"""
490
+ Args:
491
+ hidden_states (`torch.FloatTensor`): input to the layer of shape *(bsz, src_len, embed_dim)*
492
+ attention_mask (`torch.FloatTensor`): attention mask of size
493
+ *(bsz, 1, src_len, src_len)* where padding elements are indicated by very large negative values.
494
+ output_attentions (`bool`, *optional*):
495
+ whether to return the attentions tensors of all attention layers. See `attentions` under
496
+ returned tensors for more detail.
497
+ attn_bias (`torch.FloatTensor`): bias for positional information.
498
+
499
+ Returns:
500
+ outputs (`tuple(torch.FloatTensor)`):
501
+ output hidden states of size (bsz, src_len, embed_dim), optionally with attention weights.
502
+ """
503
+
504
+ residual = hidden_states
505
+ if self.normalize_before:
506
+ hidden_states = self.self_attn_layer_norm(hidden_states)
507
+ hidden_states, attn_weights, _ = self.self_attn(
508
+ hidden_states=hidden_states,
509
+ attention_mask=attention_mask,
510
+ output_attentions=output_attentions,
511
+ attn_bias=attn_bias,
512
+ )
513
+ if self.self_attn_mid_layer_norm:
514
+ hidden_states = self.self_attn_mid_layer_norm(hidden_states)
515
+ hidden_states = self.dropout(hidden_states)
516
+ hidden_states = self.residual_connection(hidden_states, residual)
517
+ if not self.normalize_before:
518
+ hidden_states = self.self_attn_layer_norm(hidden_states)
519
+
520
+ residual = hidden_states
521
+
522
+ if self.normalize_before:
523
+ hidden_states = self.final_layer_norm(hidden_states)
524
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
525
+ hidden_states = self.activation_dropout(hidden_states)
526
+ if self.ffn_layer_norm:
527
+ hidden_states = self.ffn_layer_norm(hidden_states)
528
+ hidden_states = self.fc2(hidden_states)
529
+ hidden_states = self.dropout(hidden_states)
530
+ hidden_states = self.residual_connection(hidden_states, residual)
531
+ if not self.normalize_before:
532
+ hidden_states = self.final_layer_norm(hidden_states)
533
+
534
+ if hidden_states.dtype == torch.float16 and (
535
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
536
+ ):
537
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
538
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
539
+
540
+ outputs = (hidden_states,)
541
+
542
+ if output_attentions:
543
+ outputs += (attn_weights,)
544
+
545
+ return outputs
546
+
547
+
548
+ class OFADecoderLayer(nn.Module):
549
+ r"""
550
+ OFA decoder layer implementation.
551
+
552
+ Args:
553
+ config: configuration for OFA.
554
+ drop_path_rate: the ratio for drop path.
555
+ """
556
+
557
+ def __init__(self, config: OFAConfig, drop_path_rate=0.0):
558
+ super().__init__()
559
+ self.embed_dim = config.d_model
560
+
561
+ self.self_attn = OFAAttention(
562
+ embed_dim=self.embed_dim,
563
+ num_heads=config.decoder_attention_heads,
564
+ dropout=config.attention_dropout,
565
+ is_decoder=True,
566
+ )
567
+ self.dropout = nn.Dropout(p=config.dropout)
568
+ self.activation_fn = ACT2FN[config.activation_function]
569
+ self.activation_dropout = nn.Dropout(p=config.activation_dropout)
570
+
571
+ self.self_attn_layer_norm = LayerNorm(self.embed_dim)
572
+ self.self_attn_mid_layer_norm = LayerNorm(self.embed_dim) if config.normformer else None
573
+ self.cross_attn = OFAAttention(
574
+ self.embed_dim,
575
+ config.decoder_attention_heads,
576
+ dropout=config.attention_dropout,
577
+ is_decoder=True,
578
+ )
579
+ self.cross_attn_layer_norm = LayerNorm(self.embed_dim)
580
+ self.cross_attn_mid_layer_norm = LayerNorm(self.embed_dim) if config.normformer else None
581
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
582
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
583
+ self.ffn_layer_norm = LayerNorm(config.decoder_ffn_dim) if config.normformer else None
584
+ self.final_layer_norm = LayerNorm(self.embed_dim)
585
+ self.normalize_before = config.decoder_normalize_before
586
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
587
+
588
+ def residual_connection(self, x, residual):
589
+ r"""
590
+ Residual connection with drop path.
591
+ """
592
+ return residual + self.drop_path(x)
593
+
594
+ def forward(
595
+ self,
596
+ hidden_states: torch.Tensor,
597
+ attention_mask: Optional[torch.Tensor] = None,
598
+ encoder_hidden_states: Optional[torch.Tensor] = None,
599
+ encoder_attention_mask: Optional[torch.Tensor] = None,
600
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
601
+ output_attentions: Optional[bool] = False,
602
+ use_cache: Optional[bool] = False,
603
+ self_attn_bias: Optional[torch.Tensor] = None,
604
+ cross_attn_bias: Optional[torch.Tensor] = None,
605
+ ):
606
+ r"""
607
+ Args:
608
+ hidden_states (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): input to the layer.
609
+ attention_mask (`torch.FloatTensor` of shape `(bsz, 1, tgt_len, src_len)`):
610
+ attention mask where padding elements are indicated by very large negative values.
611
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, embed_dim)`):
612
+ cross attention input to the layer.
613
+ encoder_attention_mask (`torch.FloatTensor` of shape `(bsz, 1, tgt_len, src_len)`):
614
+ encoder attention mask where padding elements are indicated by very large negative values.
615
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
616
+ output_attentions (`bool`, *optional*): whether to return the attentions tensors of all attention layers.
617
+ use_cache (`bool`, *optional*): whether to use cache
618
+ self_attn_bias (`torch.FloatTensor`): self attention bias for positional information.
619
+ cross_attn_bias (`torch.FloatTensor`): cross attention bias for positional information.
620
+ """
621
+
622
+ # Self attention with intermediate layernorm
623
+ residual = hidden_states
624
+ if self.normalize_before:
625
+ hidden_states = self.self_attn_layer_norm(hidden_states)
626
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
627
+ # add present self-attn cache to position 1,2 of present_key_value tuple
628
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
629
+ hidden_states=hidden_states,
630
+ past_key_value=self_attn_past_key_value,
631
+ attention_mask=attention_mask,
632
+ output_attentions=output_attentions,
633
+ attn_bias=self_attn_bias,
634
+ )
635
+ if self.self_attn_mid_layer_norm:
636
+ hidden_states = self.self_attn_mid_layer_norm(hidden_states)
637
+ hidden_states = self.dropout(hidden_states)
638
+ hidden_states = self.residual_connection(hidden_states, residual)
639
+ if not self.normalize_before:
640
+ hidden_states = self.self_attn_layer_norm(hidden_states)
641
+
642
+ # Cross attention with intermediate layernorm
643
+ cross_attn_present_key_value = None
644
+ cross_attn_weights = None
645
+ if encoder_hidden_states is not None:
646
+ residual = hidden_states
647
+ if self.normalize_before:
648
+ hidden_states = self.cross_attn_layer_norm(hidden_states)
649
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
650
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
651
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.cross_attn(
652
+ hidden_states=hidden_states,
653
+ key_value_states=encoder_hidden_states,
654
+ attention_mask=encoder_attention_mask,
655
+ past_key_value=cross_attn_past_key_value,
656
+ output_attentions=output_attentions,
657
+ attn_bias=cross_attn_bias,
658
+ )
659
+ if self.cross_attn_mid_layer_norm:
660
+ hidden_states = self.cross_attn_mid_layer_norm(hidden_states)
661
+ hidden_states = self.dropout(hidden_states)
662
+ hidden_states = self.residual_connection(hidden_states, residual)
663
+ if not self.normalize_before:
664
+ hidden_states = self.cross_attn_layer_norm(hidden_states)
665
+
666
+ # add cross-attn to positions 3,4 of present_key_value tuple
667
+ present_key_value = present_key_value + cross_attn_present_key_value
668
+
669
+ # FFN with intermediate layernorm
670
+ residual = hidden_states
671
+ if self.normalize_before:
672
+ hidden_states = self.final_layer_norm(hidden_states)
673
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
674
+ hidden_states = self.activation_dropout(hidden_states)
675
+ if self.ffn_layer_norm:
676
+ hidden_states = self.ffn_layer_norm(hidden_states)
677
+ hidden_states = self.fc2(hidden_states)
678
+ hidden_states = self.dropout(hidden_states)
679
+ hidden_states = self.residual_connection(hidden_states, residual)
680
+ if not self.normalize_before:
681
+ hidden_states = self.final_layer_norm(hidden_states)
682
+
683
+ outputs = (hidden_states,)
684
+
685
+ if output_attentions:
686
+ outputs += (self_attn_weights, cross_attn_weights)
687
+
688
+ if use_cache:
689
+ outputs += (present_key_value,)
690
+
691
+ return outputs
692
+
693
+
694
+ class OFAPreTrainedModel(PreTrainedModel):
695
+ r"""
696
+ Base class OFA
697
+ """
698
+
699
+ config_class = OFAConfig
700
+ base_model_prefix = "model"
701
+ supports_gradient_checkpointing = True
702
+
703
+ def _init_weights(self, module):
704
+ r"""
705
+ Weight initialization which follows BERT.
706
+ """
707
+ std = self.config.init_std
708
+ if isinstance(module, nn.Linear):
709
+ module.weight.data.normal_(mean=0.0, std=std)
710
+ if module.bias is not None:
711
+ module.bias.data.zero_()
712
+ elif isinstance(module, nn.Embedding):
713
+ module.weight.data.normal_(mean=0.0, std=std)
714
+ if module.padding_idx is not None:
715
+ module.weight.data[module.padding_idx].zero_()
716
+
717
+ def _set_gradient_checkpointing(self, module, value=False):
718
+ r"""
719
+ Turn on the switch of gradient checkpointing.
720
+ """
721
+ if isinstance(module, (OFADecoder, OFAEncoder)):
722
+ module.gradient_checkpointing = value
723
+
724
+
725
+ @dataclass
726
+ class OFAEncoderOutput(ModelOutput):
727
+ r"""
728
+ Base class for OFA's outputs.
729
+
730
+ Args:
731
+ last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`):
732
+ Sequence of hidden-states at the output of the last layer of the model.
733
+
734
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed
735
+ or when `config.output_hidden_states=True`):
736
+
737
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
738
+ shape `(bsz, seq_len, hidden)`.
739
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
740
+
741
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed
742
+ or when `config.output_attentions=True`):
743
+
744
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(bsz, num_heads, seq_len, seq_len)`.
745
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
746
+ heads.
747
+
748
+ position_embedding (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`):
749
+ postional embeddings of the inputs.
750
+ """
751
+
752
+ last_hidden_state: torch.FloatTensor = None
753
+ padding_mask: torch.Tensor = None
754
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
755
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
756
+ position_embedding: Optional[torch.FloatTensor] = None
757
+
758
+
759
+ OFA_START_DOCSTRING = r"""
760
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
761
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
762
+ etc.)
763
+
764
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
765
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
766
+ and behavior.
767
+
768
+ Parameters:
769
+ config ([`~OFAConfig`]):
770
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
771
+ load the weights associated with the model, only the configuration. Check out the
772
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
773
+ """
774
+
775
+
776
+ OFA_GENERATION_EXAMPLE = r"""
777
+ Image captioning example:
778
+
779
+ ```python
780
+ >>> from PIL import Image
781
+ >>> from torchvision import transforms
782
+ >>> from transformers import OFATokenizer, OFAForConditionalGeneration
783
+
784
+ >>> mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
785
+ >>> resolution = 256
786
+ >>> patch_resize_transform = transforms.Compose([
787
+ lambda image: image.convert("RGB"),
788
+ transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC),
789
+ transforms.ToTensor(),
790
+ transforms.Normalize(mean=mean, std=std)
791
+ ])
792
+
793
+ >>> model = OFAForConditionalGeneration.from_pretrained(ckpt_dir)
794
+ >>> tokenizer = OFATokenizer.from_pretrained(ckpt_dir)
795
+
796
+ >>> txt = " what is the description of the image?"
797
+ >>> inputs = tokenizer([txt], max_length=1024, return_tensors="pt")["input_ids"]
798
+ >>> img = Image.open(path_to_image)
799
+ >>> patch_img = patch_resize_transform(img).unsqueeze(0)
800
+
801
+ >>> gen = model.generate(inputs, patch_img=patch_img, num_beams=4)
802
+ >>> print(tokenizer.decode(gen, skip_special_tokens=True, clean_up_tokenization_spaces=False))
803
+ ```
804
+ """
805
+
806
+
807
+ OFA_INPUTS_DOCSTRING = r"""
808
+ Args:
809
+ input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`):
810
+ indices of input sequence tokens in the vocabular, and padding will be ignored by default;
811
+
812
+ indices can be obtained using [`~OFATokenizer`].
813
+
814
+ patch_images (`torch.FloatTensor` of shape `(bsz, 3, height, width)`):
815
+ the resized image, which are transformed by the default operations.
816
+ patch_images_2 (`torch.FloatTensor` of shape `(bsz, 3, height, width)`):
817
+ the second (if it exists) image.
818
+ patch_masks (`torch.BoolTensor`): the patches to be masked.
819
+ token_embeddings (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): token embeddings.
820
+ sample_patch_num (`int`): the number of patches to sample.
821
+ decoder_input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of the sequence in the vocabulary.
822
+ code_masks (`torch.Tensor` of shape `(bsz, seq_len)`): masks only for code generation.
823
+ attention_mask (`torch.Tensor` of shape `(bsz, seq_len)`): attention mask for decoding.
824
+ encoder_outputs (`OFAEncoderOutput`):
825
+ encoder outputs with hidden states, positional embeddings, and padding masks.
826
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed):
827
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
828
+ shape `(bsz, num_heads, tgt_len, head_size)`) and 2 additional tensors of
829
+ shape `(bsz, num_heads, src_len, head_size)`.
830
+ use_cache (`bool`): whether to use cache for faster inference.
831
+ output_attentions (`bool`): whether to output attention weights.
832
+ output_hidden_states (`bool`): whether to output hidden states.
833
+ return_dict (`bool`): unused. Keep it for generation only.
834
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
835
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
836
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
837
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
838
+ """
839
+
840
+
841
+ class OFAEncoder(OFAPreTrainedModel):
842
+ r"""
843
+ OFA encoder consisting of layers of [`OFAEncoderLayer`].
844
+
845
+ Args:
846
+ config: OFAConfig
847
+ embed_tokens (`nn.Embedding`, *optional*): output embedding
848
+ """
849
+
850
+ def __init__(self, config: OFAConfig, embed_tokens: Optional[nn.Embedding] = None):
851
+ super().__init__(config)
852
+
853
+ self.dropout = nn.Dropout(config.dropout)
854
+ self.encoder_layerdrop = config.encoder_layerdrop
855
+
856
+ embed_dim = config.d_model
857
+ self.padding_idx = config.pad_token_id
858
+ self.max_source_positions = config.max_position_embeddings
859
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
860
+ self.num_attention_heads = config.encoder_attention_heads
861
+
862
+ if getattr(config, "layernorm_embedding", False):
863
+ self.layernorm_embedding = LayerNorm(embed_dim)
864
+ else:
865
+ self.layernorm_embedding = None
866
+
867
+ if embed_tokens is not None:
868
+ self.embed_tokens = embed_tokens
869
+ else:
870
+ self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
871
+
872
+ if config.add_type_embedding:
873
+ self.type_embedding = Embedding(2, embed_dim, padding_idx=None)
874
+ else:
875
+ self.type_embedding = None
876
+
877
+ if config.resnet_type == "resnet18":
878
+ self.embed_images = ResNet([2, 2, 2], drop_path_rate=config.resnet_drop_path_rate)
879
+ elif config.resnet_type == "resnet34":
880
+ self.embed_images = ResNet([3, 4, 6], drop_path_rate=config.resnet_drop_path_rate)
881
+ elif config.resnet_type == "resnet50":
882
+ self.embed_images = ResNet([3, 4, 6], drop_path_rate=config.resnet_drop_path_rate)
883
+ elif config.resnet_type == "resnet101":
884
+ self.embed_images = ResNet([3, 4, 23], drop_path_rate=config.resnet_drop_path_rate)
885
+ elif config.resnet_type == "resnet152":
886
+ self.embed_images = ResNet([3, 8, 36], drop_path_rate=config.resnet_drop_path_rate)
887
+ else:
888
+ raise NotImplementedError
889
+ self.image_proj = Linear(1024, embed_dim)
890
+
891
+ if config.resnet_model_path:
892
+ resnet_state_dict = torch.load(config.resnet_model_path)
893
+ self.embed_images.load_state_dict(resnet_state_dict)
894
+ if config.patch_layernorm_embedding:
895
+ self.patch_layernorm_embedding = LayerNorm(embed_dim)
896
+ else:
897
+ self.patch_layernorm_embedding = None
898
+
899
+ self.embed_positions = Embedding(self.max_source_positions + 2, embed_dim)
900
+ self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1, embed_dim)
901
+ self.pos_ln = LayerNorm(embed_dim)
902
+ self.image_pos_ln = LayerNorm(embed_dim)
903
+ self.pos_scaling = float(embed_dim / self.num_attention_heads * config.attn_scale_factor) ** -0.5
904
+ self.pos_q_linear = nn.Linear(embed_dim, embed_dim)
905
+ self.pos_k_linear = nn.Linear(embed_dim, embed_dim)
906
+
907
+ if self.encoder_layerdrop > 0.0:
908
+ self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
909
+ else:
910
+ self.layers = nn.ModuleList([])
911
+
912
+ dpr = [x.item() for x in torch.linspace(0, config.encoder_drop_path_rate, config.encoder_layers)]
913
+ self.layers.extend(
914
+ [OFAEncoderLayer(config, drop_path_rate=dpr[i]) for i in range(config.encoder_layers)]
915
+ )
916
+ self.num_layers = len(self.layers)
917
+
918
+ if config.encoder_normalize_before:
919
+ self.layer_norm = LayerNorm(embed_dim)
920
+ else:
921
+ self.layer_norm = None
922
+
923
+ self.token_bucket_size = config.token_bucket_size
924
+ token_num_rel_dis = 2 * config.token_bucket_size - 1
925
+ token_rp_bucket = make_token_bucket_position(config.token_bucket_size)
926
+ self.token_rel_pos_table_list = nn.ModuleList(
927
+ [Embedding(token_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in
928
+ range(config.encoder_layers)]
929
+ )
930
+
931
+ self.image_bucket_size = config.image_bucket_size
932
+ image_num_rel_dis = (2 * config.image_bucket_size - 1) * (2 * config.image_bucket_size - 1) + 3
933
+ image_rp_bucket = make_image_bucket_position(config.image_bucket_size, image_num_rel_dis)
934
+ self.image_rel_pos_table_list = nn.ModuleList(
935
+ [Embedding(image_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in
936
+ range(config.encoder_layers)]
937
+ )
938
+
939
+ if config.layernorm_embedding:
940
+ self.layernorm_embedding = LayerNorm(embed_dim)
941
+ else:
942
+ self.layernorm_embedding = None
943
+
944
+ self.register_buffer("token_rp_bucket", token_rp_bucket)
945
+ self.register_buffer("image_rp_bucket", image_rp_bucket)
946
+ self.entangle_position_embedding = config.entangle_position_embedding
947
+
948
+ self.gradient_checkpointing = False
949
+ # Initialize weights and apply final processing
950
+ self.post_init()
951
+
952
+ def get_input_embeddings(self):
953
+ r"""
954
+ Get the embedding weight.
955
+ """
956
+ return self.embed_tokens
957
+
958
+ def set_input_embeddings(self, value):
959
+ r"""
960
+ Set the weight of embedding with the given tensor.
961
+ """
962
+ self.embed_tokens = value
963
+
964
+ def get_rel_pos_bias(self, x, idx):
965
+ r"""
966
+ Get the relative positional bias of the text, for attention.
967
+ """
968
+
969
+ seq_len = x.size(1)
970
+ rp_bucket = self.token_rp_bucket[:seq_len, :seq_len]
971
+ values = F.embedding(rp_bucket, self.token_rel_pos_table_list[idx].weight)
972
+ values = values.unsqueeze(0).expand(x.size(0), -1, -1, -1)
973
+ values = values.permute([0, 3, 1, 2])
974
+ return values.contiguous()
975
+
976
+ def get_image_rel_pos_bias(self, image_position_ids, idx):
977
+ r"""
978
+ Get the relative positional bias of the image, for attention.
979
+ """
980
+
981
+ bsz, seq_len = image_position_ids.shape
982
+ rp_bucket_size = self.image_rp_bucket.size(1)
983
+
984
+ rp_bucket = self.image_rp_bucket.unsqueeze(0).expand(
985
+ bsz, rp_bucket_size, rp_bucket_size
986
+ ).gather(1, image_position_ids[:, :, None].expand(bsz, seq_len, rp_bucket_size)
987
+ ).gather(2, image_position_ids[:, None, :].expand(bsz, seq_len, seq_len))
988
+ values = F.embedding(rp_bucket, self.image_rel_pos_table_list[idx].weight)
989
+ values = values.permute(0, 3, 1, 2)
990
+ return values
991
+
992
+ def get_patch_images_info(self, patch_images, sample_patch_num, device):
993
+ r"""
994
+ Get the basic information of the resized image.
995
+
996
+ Args:
997
+ patch_images (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): the resized image.
998
+ sample_patch_num (`int`):
999
+ the number of patches to sample. If it is equal to -1, no sampling will be performed.
1000
+ device: GPU device.
1001
+
1002
+ Returns:
1003
+ image_embed (`torch.FloatTensor` of shape `(bsz, h * w, hidden)`): the output of the visual encoder.
1004
+ image_num_patches (`int`, equal to `h * w`): the number of patches.
1005
+ image_padding_mask (`torch.BooleanTensor` of shape `(bsz, h*w)`): image padding mask.
1006
+ image_position_ids (`torch.LongTensor` of shape `(bsz, h*w)`): image position ids.
1007
+ image_pos_embed (`torch.FloatTensor` of shape (bsz, h*w, hidden)): the positional embedding.
1008
+ """
1009
+
1010
+ image_embed = self.embed_images(patch_images)
1011
+ h, w = image_embed.shape[-2:]
1012
+ image_num_patches = h * w
1013
+ image_padding_mask = patch_images.new_zeros((patch_images.size(0), image_num_patches)).bool()
1014
+ image_position_idx = torch.arange(w).unsqueeze(0).expand(h, w) + \
1015
+ torch.arange(h).unsqueeze(1) * self.image_bucket_size + 1
1016
+ image_position_idx = image_position_idx.view(-1).to(device)
1017
+ image_position_ids = image_position_idx[None, :].expand(patch_images.size(0), image_num_patches)
1018
+
1019
+ image_embed = image_embed.flatten(2).transpose(1, 2)
1020
+ if sample_patch_num is not None:
1021
+ patch_orders = [
1022
+ random.sample(range(image_num_patches), k=sample_patch_num)
1023
+ for _ in range(patch_images.size(0))
1024
+ ]
1025
+ patch_orders = torch.LongTensor(patch_orders).to(device)
1026
+ image_embed = image_embed.gather(
1027
+ 1, patch_orders.unsqueeze(2).expand(-1, -1, image_embed.size(2))
1028
+ )
1029
+ image_num_patches = sample_patch_num
1030
+ image_padding_mask = image_padding_mask.gather(1, patch_orders)
1031
+ image_position_ids = image_position_ids.gather(1, patch_orders)
1032
+ image_pos_embed = self.embed_image_positions(image_position_ids)
1033
+
1034
+ return image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed
1035
+
1036
+ def forward_embedding(
1037
+ self,
1038
+ input_ids,
1039
+ image_embed: Optional[torch.Tensor] = None,
1040
+ image_embed_2: Optional[torch.Tensor] = None,
1041
+ token_embedding: Optional[torch.Tensor] = None,
1042
+ pos_embed: Optional[torch.Tensor] = None,
1043
+ image_pos_embed: Optional[torch.Tensor] = None,
1044
+ image_pos_embed_2: Optional[torch.Tensor] = None
1045
+ ):
1046
+ r"""
1047
+ Generate embeddings of both the image and the text.
1048
+ Actually since OFA unifies both unimodal and multimodal data,
1049
+ image inputs are optional.
1050
+
1051
+ Args:
1052
+ input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of the tokens in the vocabulary.
1053
+ image_embed (`torch.FloatTensor` of shape `(bsz, h*w, embed_dim)`, *optional*): image embeddings.
1054
+ image_embed_2 (`torch.FloatTensor` of shape `(bsz, h*w, embed_dim)`, *optional*):
1055
+ image embeddings of the second image (if it exists).
1056
+ token_embedding (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`, *optional*):
1057
+ input token embeddings to replace the embeddings of input ids.
1058
+ image_pos_embed (`torch.FloatTensor` of shape `(bsz, h*w, embed_dim)`, *optional*):
1059
+ positional embeddings of the image.
1060
+ image_pos_embed_2 (`torch.FloatTensor` of shape `(bsz, h*w, embed_dim)`, *optional*):
1061
+ positional embeddings of the second image.
1062
+
1063
+ Returns:
1064
+ x (`torch.FloatTensor` of shape `(bsz, h*w+seq_len, embed_dim)`): embeddings of the input.
1065
+ embed (`torch.FloatTensor` of shape `(bsz, h*w+seq_len, embed_dim)`):
1066
+ embeddings without adding positional and type embeddings.
1067
+ """
1068
+
1069
+ # embed tokens and positions
1070
+ if token_embedding is None:
1071
+ token_embedding = self.embed_tokens(input_ids)
1072
+ x = embed = self.embed_scale * token_embedding
1073
+ if self.entangle_position_embedding and pos_embed is not None:
1074
+ x += pos_embed
1075
+ if self.type_embedding is not None:
1076
+ x += self.type_embedding(input_ids.new_zeros(x.size()[:2]))
1077
+ if self.layernorm_embedding is not None:
1078
+ x = self.layernorm_embedding(x)
1079
+ x = self.dropout(x)
1080
+
1081
+ # embed raw images
1082
+ if image_embed is not None:
1083
+ image_embed = self.image_proj(image_embed)
1084
+ image_x = image_embed = self.embed_scale * image_embed
1085
+ if self.entangle_position_embedding and image_pos_embed is not None:
1086
+ image_x += image_pos_embed
1087
+ if self.type_embedding is not None:
1088
+ image_x += self.type_embedding(input_ids.new_ones(image_x.size()[:2]))
1089
+ if self.patch_layernorm_embedding is not None:
1090
+ image_x = self.patch_layernorm_embedding(image_x)
1091
+ image_x = self.dropout(image_x)
1092
+ x = torch.cat([image_x, x], dim=1)
1093
+ embed = torch.cat([image_embed, embed], dim=1)
1094
+
1095
+ if image_embed_2 is not None:
1096
+ assert self.type_embedding is not None
1097
+ image_embed_2 = self.image_proj(image_embed_2)
1098
+ image_x_2 = image_embed_2 = self.embed_scale * image_embed_2
1099
+ if self.entangle_position_embedding and image_pos_embed_2 is not None:
1100
+ image_x_2 += image_pos_embed_2
1101
+ if self.type_embedding is not None:
1102
+ image_x_2 += self.type_embedding(input_ids.new_full(image_x_2.size()[:2], fill_value=2))
1103
+ if self.patch_layernorm_embedding is not None:
1104
+ image_x_2 = self.patch_layernorm_embedding(image_x_2)
1105
+ image_x_2 = self.dropout(image_x_2)
1106
+ if self.quant_noise is not None:
1107
+ image_x_2 = self.quant_noise(image_x_2)
1108
+ x = torch.cat([image_x_2, x], dim=1)
1109
+ embed = torch.cat([image_embed_2, embed], dim=1)
1110
+
1111
+ return x, embed
1112
+
1113
+ def reorder_encoder_out(self, encoder_out, new_order):
1114
+ """
1115
+ Reorder encoder output according to *new_order*.
1116
+
1117
+ Args:
1118
+ encoder_out: output from the ``forward()`` method
1119
+ new_order (LongTensor): desired order
1120
+
1121
+ Returns:
1122
+ *encoder_out* rearranged according to *new_order*
1123
+ """
1124
+
1125
+ if "last_hidden_state" not in encoder_out:
1126
+ new_encoder_out = None
1127
+ else:
1128
+ new_encoder_out = encoder_out["last_hidden_state"].index_select(0, new_order)
1129
+
1130
+ if "padding_mask" not in encoder_out:
1131
+ new_encoder_padding_mask = None
1132
+ else:
1133
+ new_encoder_padding_mask = encoder_out["padding_mask"].index_select(0, new_order)
1134
+
1135
+
1136
+ if "position_embedding" not in encoder_out:
1137
+ new_position_embeddings = None
1138
+ else:
1139
+ new_position_embeddings = encoder_out["position_embedding"].index_select(0, new_order)
1140
+
1141
+ if "hidden_states" not in encoder_out:
1142
+ new_encoer_states = None
1143
+ else:
1144
+ encoder_states = encoder_out["hidden_states"]
1145
+ new_encoer_states = ()
1146
+ if len(encoder_states) > 0:
1147
+ for idx, state in enumerate(encoder_states):
1148
+ new_encoer_states += (state.index_select(0, new_order),)
1149
+
1150
+ if "attentions" not in encoder_out:
1151
+ attentions = None
1152
+ else:
1153
+ attentions = encoder_out["attentions"]
1154
+
1155
+ return OFAEncoderOutput(
1156
+ last_hidden_state=new_encoder_out,
1157
+ padding_mask=new_encoder_padding_mask,
1158
+ hidden_states=new_encoer_states,
1159
+ attentions=attentions,
1160
+ position_embedding=new_position_embeddings
1161
+ )
1162
+
1163
+ def forward(
1164
+ self,
1165
+ input_ids=None,
1166
+ patch_images: Optional[torch.Tensor] = None,
1167
+ patch_images_2: Optional[torch.Tensor] = None,
1168
+ patch_masks: Optional[torch.Tensor] = None,
1169
+ output_attentions: bool = False,
1170
+ output_hidden_states: bool = False,
1171
+ token_embeddings: Optional[torch.Tensor] = None,
1172
+ sample_patch_num: Optional[int] = None,
1173
+ ):
1174
+ r"""
1175
+ Args:
1176
+ input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`):
1177
+ indices of input sequence tokens in the vocabular, and padding will be ignored by default;
1178
+
1179
+ indices can be obtained using [`~OFATokenizer`].
1180
+
1181
+ patch_images (`torch.FloatTensor` of shape `(bsz, 3, height, width)`):
1182
+ the resized image, which are transformed by the default operations.
1183
+ patch_images_2 (`torch.FloatTensor` of shape `(bsz, 3, height, width)`):
1184
+ the second (if it exists) image.
1185
+ patch_masks (`torch.BoolTensor`): the patches to be masked.
1186
+ output_attentions (`bool`): whether to return all attention weights,
1187
+ output_hidden_states (`bool`): whether to return all hidden states.
1188
+ token_embeddings (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): token embeddings.
1189
+ sample_patch_num (`int`): the number of patches to sample.
1190
+
1191
+ Returns:
1192
+ [`OFAEncoderOutput`]:
1193
+ last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`):
1194
+ the states of the last layer.
1195
+ padding_mask (`torch.BoolTensor` of shape `(bsz, seq_len)`):
1196
+ the padding mask of the source context.
1197
+ hidden_states (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`):
1198
+ the states of all layers including the embeddings.
1199
+ attentions (`torch.FloatTensor` of shape `(bsz, num_heads, seq_len, seq_len)`):
1200
+ the attention weights of all layers.
1201
+ position_embedding (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`):
1202
+ positional embeddings of the input image and tokens.
1203
+ """
1204
+
1205
+ image_embed = None
1206
+ image_embed_2 = None
1207
+ image_pos_embed = None
1208
+ image_pos_embed_2 = None
1209
+ if patch_images is not None:
1210
+ image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \
1211
+ self.get_patch_images_info(patch_images, sample_patch_num, input_ids.device)
1212
+ # image_padding_mask[~patch_masks] = True # comment the line to temporarily fix the bug of mismatch
1213
+ if patch_images_2 is not None:
1214
+ image_embed_2, image_num_patches_2, image_padding_mask_2, image_position_ids_2, image_pos_embed_2 = \
1215
+ self.get_patch_images_info(patch_images_2, sample_patch_num, input_ids.device)
1216
+ image_padding_mask_2[~patch_masks] = True
1217
+
1218
+ encoder_padding_mask = input_ids.eq(self.padding_idx)
1219
+ if patch_images is not None:
1220
+ encoder_padding_mask = torch.cat([image_padding_mask, encoder_padding_mask], dim=1)
1221
+ if patch_images_2 is not None:
1222
+ encoder_padding_mask = torch.cat([image_padding_mask_2, encoder_padding_mask], dim=1)
1223
+ has_pads = encoder_padding_mask.any()
1224
+
1225
+ pos_embed = self.embed_positions(new_arange(input_ids))
1226
+ x, encoder_embedding = self.forward_embedding(
1227
+ input_ids, image_embed, image_embed_2, token_embeddings,
1228
+ pos_embed, image_pos_embed, image_pos_embed_2
1229
+ )
1230
+
1231
+ # account for padding while computing the representation
1232
+ if has_pads:
1233
+ x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
1234
+
1235
+ pos_embed = self.pos_ln(pos_embed)
1236
+ if patch_images is not None:
1237
+ image_pos_embed = self.image_pos_ln(image_pos_embed)
1238
+ pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1)
1239
+ if patch_images_2 is not None:
1240
+ image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2)
1241
+ pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1)
1242
+
1243
+ pos_q = self.pos_q_linear(pos_embed).view(
1244
+ x.size(0), x.size(1), self.num_attention_heads, -1
1245
+ ).transpose(1, 2) * self.pos_scaling
1246
+ pos_k = self.pos_k_linear(pos_embed).view(
1247
+ x.size(0), x.size(1), self.num_attention_heads, -1
1248
+ ).transpose(1, 2)
1249
+ abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
1250
+
1251
+ # expand attention_mask
1252
+ if has_pads:
1253
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1254
+ attention_mask = _expand_mask(~encoder_padding_mask, dtype=x.dtype)
1255
+
1256
+ encoder_states = () if output_hidden_states else None
1257
+ all_attentions = () if output_attentions else None
1258
+
1259
+ # encoder layers
1260
+ for idx, layer in enumerate(self.layers):
1261
+ if output_hidden_states:
1262
+ encoder_states += (x,)
1263
+ self_attn_bias = abs_pos_bias.clone()
1264
+ self_attn_bias[:, :, -input_ids.size(1):, -input_ids.size(1):] += self.get_rel_pos_bias(input_ids, idx)
1265
+ if patch_images_2 is not None:
1266
+ self_attn_bias[:, :, :image_num_patches_2, :image_num_patches_2] += \
1267
+ self.get_image_rel_pos_bias(image_position_ids_2, idx)
1268
+ self_attn_bias[:, :, image_num_patches_2:image_num_patches_2 + image_num_patches,
1269
+ image_num_patches_2:image_num_patches_2 + image_num_patches] += \
1270
+ self.get_image_rel_pos_bias(image_position_ids, idx)
1271
+ elif patch_images is not None:
1272
+ self_attn_bias[:, :, :x.size(1) - input_ids.size(1), :x.size(1) - input_ids.size(1)] += \
1273
+ self.get_image_rel_pos_bias(image_position_ids, idx)
1274
+ self_attn_bias = self_attn_bias.reshape(-1, x.size(1), x.size(1))
1275
+
1276
+ hidden_outputs = layer(x, attention_mask if has_pads else None, attn_bias=self_attn_bias, output_attentions=output_attentions)
1277
+ x = hidden_outputs[0]
1278
+
1279
+ if output_attentions:
1280
+ attention = hidden_outputs[1]
1281
+ all_attentions = all_attentions + (attention,)
1282
+
1283
+ if output_hidden_states:
1284
+ encoder_states += (x,)
1285
+
1286
+ if self.layer_norm is not None:
1287
+ x = self.layer_norm(x)
1288
+
1289
+ return OFAEncoderOutput(
1290
+ last_hidden_state=x,
1291
+ padding_mask=encoder_padding_mask,
1292
+ hidden_states=encoder_states,
1293
+ attentions=all_attentions,
1294
+ position_embedding=pos_embed,
1295
+ )
1296
+
1297
+
1298
+ class OFADecoder(OFAPreTrainedModel):
1299
+ r"""
1300
+ OFA decoder consisting of layers of [`OFADecoderLayer`]
1301
+
1302
+ Args:
1303
+ config: OFAConfig
1304
+ embed_tokens (`nn.Embedding`, *optional*): output embedding
1305
+ """
1306
+
1307
+ def __init__(self, config: OFAConfig, embed_tokens: Optional[nn.Embedding] = None, output_projection=None):
1308
+ super().__init__(config)
1309
+ self.dropout = nn.Dropout(config.dropout)
1310
+ self.decoder_layerdrop = config.decoder_layerdrop
1311
+ self.padding_idx = config.pad_token_id
1312
+ self.max_target_positions = config.max_position_embeddings
1313
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
1314
+
1315
+ self._future_mask = torch.empty(0)
1316
+ self.share_input_output_embed = config.share_decoder_input_output_embed
1317
+ self.num_attention_heads = config.decoder_attention_heads
1318
+
1319
+ if embed_tokens is not None:
1320
+ self.embed_tokens = embed_tokens
1321
+ else:
1322
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
1323
+
1324
+ self.embed_dim = config.d_model
1325
+ self.output_embed_dim = config.d_model
1326
+
1327
+ self.layers = nn.ModuleList([OFADecoderLayer(config) for _ in range(config.decoder_layers)])
1328
+ if config.layernorm_embedding:
1329
+ self.layernorm_embedding = LayerNorm(self.embed_dim)
1330
+ else:
1331
+ self.layernorm_embedding = None
1332
+
1333
+ self.window_size = config.code_image_size // 8
1334
+
1335
+ self.embed_positions = Embedding(self.max_target_positions + 2, self.embed_dim)
1336
+ self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1, self.embed_dim)
1337
+ self.pos_ln = LayerNorm(self.embed_dim)
1338
+ self.image_pos_ln = LayerNorm(self.embed_dim)
1339
+ self.pos_scaling = float(self.embed_dim / self.num_attention_heads * config.attn_scale_factor) ** -0.5
1340
+ self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim)
1341
+ self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim)
1342
+ self.cross_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim)
1343
+ self.cross_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim)
1344
+
1345
+ if config.code_layernorm_embedding:
1346
+ self.code_layernorm_embedding = LayerNorm(self.embed_dim)
1347
+ else:
1348
+ self.code_layernorm_embedding = None
1349
+
1350
+ if self.decoder_layerdrop > 0.0:
1351
+ self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
1352
+ else:
1353
+ self.layers = nn.ModuleList([])
1354
+
1355
+ dpr = [x.item() for x in torch.linspace(0, config.decoder_drop_path_rate, config.decoder_layers)]
1356
+ self.layers.extend([OFADecoderLayer(config, drop_path_rate=dpr[i]) for i in range(config.decoder_layers)])
1357
+ self.num_layers = len(self.layers)
1358
+
1359
+ if config.decoder_normalize_before:
1360
+ self.layer_norm = LayerNorm(self.embed_dim)
1361
+ else:
1362
+ self.layer_norm = None
1363
+
1364
+ self.adaptive_softmax = None
1365
+ self.output_projection = output_projection
1366
+ if self.output_projection is None:
1367
+ self.build_output_projection(config)
1368
+
1369
+ self.token_bucket_size = config.token_bucket_size
1370
+ token_num_rel_dis = 2 * config.token_bucket_size - 1
1371
+ token_rp_bucket = make_token_bucket_position(config.token_bucket_size)
1372
+ self.token_rel_pos_table_list = nn.ModuleList(
1373
+ [
1374
+ Embedding(token_num_rel_dis, self.num_attention_heads, zero_init=True)
1375
+ for _ in range(config.decoder_layers)
1376
+ ]
1377
+ )
1378
+
1379
+ self.image_bucket_size = config.image_bucket_size
1380
+ image_num_rel_dis = (2 * config.image_bucket_size - 1) * (2 * config.image_bucket_size - 1) + 3
1381
+ image_rp_bucket = make_image_bucket_position(config.image_bucket_size, image_num_rel_dis)
1382
+ image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \
1383
+ torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1
1384
+ image_position_idx = torch.cat([torch.tensor([0]), image_position_idx.view(-1)])
1385
+ image_position_idx = torch.cat([image_position_idx, torch.tensor([1024] * 768)])
1386
+ self.image_rel_pos_table_list = nn.ModuleList(
1387
+ [
1388
+ Embedding(image_num_rel_dis, self.num_attention_heads, zero_init=True)
1389
+ for _ in range(config.decoder_layers)
1390
+ ]
1391
+ )
1392
+
1393
+ self.register_buffer("token_rp_bucket", token_rp_bucket)
1394
+ self.register_buffer("image_rp_bucket", image_rp_bucket)
1395
+ self.register_buffer("image_position_idx", image_position_idx)
1396
+ self.entangle_position_embedding = config.entangle_position_embedding
1397
+
1398
+ self.gradient_checkpointing = False
1399
+ # Initialize weights and apply final processing
1400
+ self.post_init()
1401
+
1402
+ def build_output_projection(self, config):
1403
+ if self.share_input_output_embed:
1404
+ self.output_projection = nn.Linear(
1405
+ self.embed_tokens.weight.shape[1],
1406
+ self.embed_tokens.weight.shape[0],
1407
+ bias=False,
1408
+ )
1409
+ self.output_projection.weight = self.embed_tokens.weight
1410
+ else:
1411
+ self.output_projection = nn.Linear(
1412
+ self.output_embed_dim, config.vocab_size, bias=False
1413
+ )
1414
+ nn.init.normal_(self.output_projection.weight, mean=0, std=self.output_embed_dim**-0.5)
1415
+
1416
+ def get_rel_pos_bias(self, x, idx):
1417
+ r"""
1418
+ Get the relative positional bias of the text, for attention.
1419
+ """
1420
+
1421
+ seq_len = x.size(1)
1422
+ rp_bucket = self.token_rp_bucket[:seq_len, :seq_len]
1423
+ values = F.embedding(rp_bucket, self.token_rel_pos_table_list[idx].weight)
1424
+ values = values.permute([2, 0, 1])
1425
+ return values.contiguous()
1426
+
1427
+ def get_image_rel_pos_bias(self, x, idx):
1428
+ r"""
1429
+ Get the relative positional bias of the image, for attention.
1430
+ """
1431
+
1432
+ seq_len = x.size(1)
1433
+ image_position_idx = self.image_position_idx[:seq_len]
1434
+ rp_bucket = self.image_rp_bucket[image_position_idx][:, image_position_idx]
1435
+ values = F.embedding(rp_bucket, self.image_rel_pos_table_list[idx].weight)
1436
+ values = values.permute(2, 0, 1)
1437
+ return values
1438
+
1439
+ def get_pos_info(self, tgt_pos_embed, src_pos_embed=None, use_image=False):
1440
+ r"""
1441
+ Get the positional information.
1442
+
1443
+ Args:
1444
+ tgt_pos_embed (`torch.FloatTensor` of shape `(bsz, tgt_len, embed_dim)`):
1445
+ the target-side positional embeddings.
1446
+ src_pos_embed (`torch.FloatTensor` of shape `(bsz, src_len, embed_dim)`, *optional*):
1447
+ the source-side positional embeddings.
1448
+ use_image (`bool`): whether to use image.
1449
+
1450
+ Returns:
1451
+ abs_pos_bias (`torch.FloatTensor` of shape `(bsz, src_len, tgt_len, src_len)`):
1452
+ absolute positional bias for attention.
1453
+ """
1454
+
1455
+ batch_size = tgt_pos_embed.size(0)
1456
+ tgt_len = tgt_pos_embed.size(1)
1457
+ tgt_pos_embed = self.image_pos_ln(tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed)
1458
+
1459
+ if src_pos_embed is not None:
1460
+ src_len = src_pos_embed.size(1)
1461
+ pos_q = self.cross_pos_q_linear(tgt_pos_embed).view(
1462
+ batch_size, tgt_len, self.num_attention_heads, -1
1463
+ ).transpose(1, 2) * self.pos_scaling
1464
+ pos_k = self.cross_pos_k_linear(src_pos_embed).view(
1465
+ batch_size, src_len, self.num_attention_heads, -1
1466
+ ).transpose(1, 2)
1467
+ else:
1468
+ src_len = tgt_pos_embed.size(1)
1469
+ pos_q = self.self_pos_q_linear(tgt_pos_embed).view(
1470
+ batch_size, tgt_len, self.num_attention_heads, -1
1471
+ ).transpose(1, 2) * self.pos_scaling
1472
+ pos_k = self.self_pos_k_linear(tgt_pos_embed).view(
1473
+ batch_size, src_len, self.num_attention_heads, -1
1474
+ ).transpose(1, 2)
1475
+ abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
1476
+
1477
+ return abs_pos_bias
1478
+
1479
+ def get_input_embeddings(self):
1480
+ r"""
1481
+ Get the input embeddings
1482
+ """
1483
+ return self.embed_tokens
1484
+
1485
+ def set_input_embeddings(self, value):
1486
+ r"""
1487
+ Set the weights of the embeddings with the given tensor.
1488
+ """
1489
+ self.embed_tokens = value
1490
+
1491
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, dtype, past_key_values_length):
1492
+ r"""
1493
+ Create causal mask for unidirectional decoding.
1494
+ [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1495
+ """
1496
+ combined_attention_mask = None
1497
+ if input_shape[-1] > 1:
1498
+ combined_attention_mask = _make_causal_mask(
1499
+ input_shape, dtype, past_key_values_length=past_key_values_length
1500
+ ).to(self.device)
1501
+
1502
+ if attention_mask is not None:
1503
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1504
+ expanded_attn_mask = _expand_mask(attention_mask, dtype, tgt_len=input_shape[-1])
1505
+ combined_attention_mask = (
1506
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
1507
+ )
1508
+
1509
+ return combined_attention_mask
1510
+
1511
+ def max_positions(self):
1512
+ """Maximum output length supported by the decoder."""
1513
+ if self.embed_positions is None:
1514
+ return self.max_target_positions
1515
+ return self.max_target_positions
1516
+
1517
+ def get_normalized_probs(
1518
+ self,
1519
+ net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
1520
+ log_probs: bool,
1521
+ sample: Optional[Dict[str, Tensor]] = None,
1522
+ ):
1523
+ """Get normalized probabilities (or log probs) from a net's output."""
1524
+ return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
1525
+
1526
+ def get_normalized_probs_scriptable(
1527
+ self,
1528
+ net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
1529
+ log_probs: bool,
1530
+ sample: Optional[Dict[str, Tensor]] = None,
1531
+ ):
1532
+ """Get normalized probabilities (or log probs) from a net's output."""
1533
+
1534
+ if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
1535
+ if sample is not None:
1536
+ assert "target" in sample
1537
+ target = sample["target"]
1538
+ else:
1539
+ target = None
1540
+ out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
1541
+ return out.exp_() if not log_probs else out
1542
+
1543
+ logits = net_output[0]
1544
+ if log_probs:
1545
+ return F.log_softmax(logits, dim=-1)
1546
+ else:
1547
+ return F.softmax(logits, dim=-1)
1548
+
1549
+ def reorder_incremental_state_scripting(
1550
+ self,
1551
+ # incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
1552
+ past_key_values: Optional[torch.Tensor],
1553
+ new_order: Tensor,
1554
+ ):
1555
+ """Main entry point for reordering the incremental state.
1556
+
1557
+ Due to limitations in TorchScript, we call this function in
1558
+ :class:`fairseq.sequence_generator.SequenceGenerator` instead of
1559
+ calling :func:`reorder_incremental_state` directly.
1560
+ """
1561
+ input_buffer = past_key_values
1562
+ new_past_key_values = []
1563
+ if input_buffer is not None:
1564
+ for input_buffer_k in input_buffer:
1565
+ new_input_buffer_k = []
1566
+ for input in input_buffer_k:
1567
+ if input is None:
1568
+ input = None
1569
+ else:
1570
+ input = input.index_select(0, new_order)
1571
+ new_input_buffer_k.append(input)
1572
+ new_past_key_values.append(new_input_buffer_k)
1573
+ return new_past_key_values
1574
+
1575
+ def forward(
1576
+ self,
1577
+ input_ids: torch.Tensor = None,
1578
+ attention_mask: torch.Tensor = None,
1579
+ encoder_hidden_states: torch.Tensor = None,
1580
+ encoder_attention_mask: torch.Tensor = None,
1581
+ code_masks: Optional[torch.Tensor] = None,
1582
+ src_pos_embed: torch.Tensor = None,
1583
+ past_key_values: Optional[torch.Tensor] = None,
1584
+ use_cache: bool = False,
1585
+ output_attentions: bool = False,
1586
+ output_hidden_states: bool = False,
1587
+ ):
1588
+ r"""
1589
+ Args:
1590
+ input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of the sequence in the vocabulary.
1591
+ attention_mask (`torch.Tensor` of shape `(bsz, seq_len)`): mask to avoid attention on padding tokens.
1592
+ encoder_hidden_states (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): the last hidden state of the encoder.
1593
+ encoder_attention_mask (`torch.Tensor` of shape `(bsz, seq_len)`): the padding mask of the source side.
1594
+ code_masks (`torch.Tensor` of shape `(bsz, seq_len)`): masks only for code generation.
1595
+ src_pos_embed (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): the positional embeddings of the source side.
1596
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed):
1597
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1598
+ shape `(bsz, num_heads, tgt_len, head_size)`) and 2 additional tensors of
1599
+ shape `(bsz, num_heads, src_len, head_size)`.
1600
+ use_cache (`bool`): whether to use cache for faster inference.
1601
+ output_attentions (`bool`): whether to output attention weights.
1602
+ output_hidden_states (`bool`): whether to output hidden states.
1603
+
1604
+ Returns:
1605
+ BaseModelOutputWithPastAndCrossAttentions or a plain tuple:
1606
+ last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): the last hidden states.
1607
+ past_key_values (`tuple(tuple(torch.FloatTensor)): past keys and values for faster inference.
1608
+ hidden_states (`tuple(torch.FloatTensor)`): hidden states of all layers.
1609
+ attentions (`tuple(torch.FloatTensor)): self attention weights of all layers.
1610
+ cross_attentions (`tuple(torch.FloatTensor)): cross attention weights of all layers.
1611
+ """
1612
+
1613
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1614
+ output_hidden_states = (
1615
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1616
+ )
1617
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1618
+
1619
+ if past_key_values is not None and len(past_key_values)>0:
1620
+ size = past_key_values[0][0].size()
1621
+ bsz, tgt_len = size[0], size[-2] + 1
1622
+ token_position_idx = torch.arange(tgt_len, device=input_ids.device).expand([bsz, tgt_len]).contiguous()
1623
+ else:
1624
+ bsz, tgt_len = input_ids.shape
1625
+ token_position_idx = new_arange(input_ids)
1626
+ tgt_pos_embed = self.embed_positions(token_position_idx)
1627
+ if code_masks is not None and torch.any(code_masks):
1628
+ image_position_idx = self.image_position_idx[:input_ids.size(1)].unsqueeze(0).expand(bsz, tgt_len)
1629
+ tgt_pos_embed[code_masks] = self.embed_image_positions(image_position_idx)[code_masks]
1630
+
1631
+ # self attn position bias
1632
+ self_abs_pos_bias = self.get_pos_info(tgt_pos_embed, use_image=False)
1633
+ if code_masks is not None and torch.any(code_masks):
1634
+ self_image_abs_pos_bias = self.get_pos_info(tgt_pos_embed, use_image=True)
1635
+ self_abs_pos_bias[code_masks] = self_image_abs_pos_bias[code_masks]
1636
+ # cross attn position bias
1637
+ cross_abs_pos_bias = self.get_pos_info(tgt_pos_embed, src_pos_embed=src_pos_embed)
1638
+ if code_masks is not None and torch.any(code_masks):
1639
+ cross_image_abs_pos_bias = self.get_pos_info(tgt_pos_embed, src_pos_embed=src_pos_embed, use_image=True)
1640
+ cross_abs_pos_bias[code_masks] = cross_image_abs_pos_bias[code_masks]
1641
+ cross_abs_pos_bias = cross_abs_pos_bias.reshape(-1, *cross_abs_pos_bias.size()[-2:])
1642
+
1643
+ all_prev_output_tokens = input_ids.clone()
1644
+ if past_key_values is not None and len(past_key_values)>0:
1645
+ input_ids = input_ids[:, -1:]
1646
+ cross_abs_pos_bias = cross_abs_pos_bias[:, -1:, :]
1647
+ tgt_pos_embed = tgt_pos_embed[:, -1:, :]
1648
+
1649
+ # embed tokens and positions
1650
+ x = self.embed_scale * self.embed_tokens(input_ids)
1651
+
1652
+
1653
+ if self.entangle_position_embedding and not self.disable_entangle:
1654
+ x += tgt_pos_embed
1655
+
1656
+ if self.layernorm_embedding is not None:
1657
+ if code_masks is None or not code_masks.any() or not self.code_layernorm_embedding:
1658
+ x = self.layernorm_embedding(x)
1659
+ elif code_masks is not None and code_masks.all():
1660
+ x = self.code_layernorm_embedding(x)
1661
+ else:
1662
+ x[~code_masks] = self.layernorm_embedding(x[~code_masks])
1663
+ x[code_masks] = self.code_layernorm_embedding(x[code_masks])
1664
+
1665
+ hidden_states = self.dropout(x)
1666
+
1667
+ # past_key_values_length
1668
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None and len(past_key_values)>0 else 0
1669
+
1670
+ shape, dtype = input_ids.shape, hidden_states.dtype
1671
+ attention_mask = self._prepare_decoder_attention_mask(attention_mask, shape, dtype, past_key_values_length)
1672
+
1673
+ # decoder layers
1674
+ all_hidden_states = () if output_hidden_states else None
1675
+ all_self_attns = () if output_attentions else None
1676
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
1677
+ next_decoder_cache = () if use_cache else None
1678
+
1679
+ # decoder layers
1680
+ for idx, layer in enumerate(self.layers):
1681
+ # add hidden states from the last decoder layer
1682
+ if output_hidden_states:
1683
+ all_hidden_states += (hidden_states,)
1684
+
1685
+ past_key_value = past_key_values[idx] if past_key_values is not None and len(past_key_values)>0 else None
1686
+
1687
+ self_attn_bias = self_abs_pos_bias.clone()
1688
+ if code_masks is None or not code_masks.any():
1689
+ self_attn_bias += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
1690
+ elif code_masks is not None and code_masks.all():
1691
+ self_attn_bias += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
1692
+ else:
1693
+ self_attn_bias[~code_masks] += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
1694
+ self_attn_bias[code_masks] += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
1695
+ self_attn_bias = self_attn_bias.reshape(-1, *self_attn_bias.size()[-2:])
1696
+ if past_key_value is not None and len(past_key_values)>0 :
1697
+ self_attn_bias = self_attn_bias[:, -1:, :]
1698
+
1699
+ layer_outputs = layer(
1700
+ hidden_states,
1701
+ attention_mask=attention_mask,
1702
+ encoder_hidden_states=encoder_hidden_states,
1703
+ encoder_attention_mask=encoder_attention_mask,
1704
+ past_key_value=past_key_value,
1705
+ output_attentions=output_attentions,
1706
+ use_cache=use_cache,
1707
+ self_attn_bias=self_attn_bias,
1708
+ cross_attn_bias=cross_abs_pos_bias,
1709
+ )
1710
+ hidden_states = layer_outputs[0]
1711
+
1712
+ if use_cache:
1713
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
1714
+
1715
+ if output_attentions:
1716
+ all_self_attns += (layer_outputs[1],)
1717
+
1718
+ if encoder_hidden_states is not None:
1719
+ all_cross_attentions += (layer_outputs[2],)
1720
+
1721
+ # add hidden states from the last decoder layer
1722
+ if output_hidden_states:
1723
+ all_hidden_states += (hidden_states,)
1724
+
1725
+ next_cache = next_decoder_cache if use_cache else None
1726
+
1727
+ if self.layer_norm is not None:
1728
+ hidden_states = self.layer_norm(hidden_states)
1729
+
1730
+ if self.output_projection is not None:
1731
+ hidden_states = self.output_projection(hidden_states)
1732
+
1733
+ return BaseModelOutputWithPastAndCrossAttentions(
1734
+ last_hidden_state=hidden_states,
1735
+ past_key_values=next_cache,
1736
+ hidden_states=all_hidden_states,
1737
+ attentions=all_self_attns,
1738
+ cross_attentions=all_cross_attentions,
1739
+ )
1740
+
1741
+
1742
+ @add_start_docstrings(
1743
+ "The bare OFA Model outputting raw hidden-states without any specific head on top.",
1744
+ OFA_START_DOCSTRING,
1745
+ )
1746
+ class OFAModel(OFAPreTrainedModel):
1747
+ r"""
1748
+ The OFA model built with an encoder and a decoder only, without any classification head.
1749
+
1750
+ Args:
1751
+ config (OFAConfig): OFA configuration.
1752
+ """
1753
+
1754
+ def __init__(self, config: OFAConfig, **kwargs):
1755
+ super().__init__(config)
1756
+ self.disable_entangle = getattr(kwargs,'disable_entangle',False)
1757
+
1758
+ self.padding_idx, vocab_size = config.pad_token_id, config.vocab_size
1759
+ shared = nn.Embedding(vocab_size, config.d_model, self.padding_idx)
1760
+
1761
+ self.encoder = OFAEncoder(config, shared)
1762
+ self.decoder = OFADecoder(config, shared)
1763
+
1764
+ # Initialize weights and apply final processing
1765
+ self.post_init()
1766
+
1767
+ def get_input_embeddings(self):
1768
+ r"""
1769
+ Retrieve input embeddings.
1770
+ """
1771
+ return self.encoder.get_input_embeddings()
1772
+
1773
+ def set_input_embeddings(self, value):
1774
+ r"""
1775
+ Set values for input embeddings
1776
+ """
1777
+ shared = value
1778
+ self.encoder.embed_tokens = shared
1779
+ self.decoder.embed_tokens = shared
1780
+
1781
+ def get_encoder(self):
1782
+ r"""
1783
+ Retrieve the encoder
1784
+ """
1785
+ return self.encoder
1786
+
1787
+ def get_decoder(self):
1788
+ r"""
1789
+ Retrieve the decoder
1790
+ """
1791
+ return self.decoder
1792
+
1793
+ @add_start_docstrings_to_model_forward(OFA_INPUTS_DOCSTRING)
1794
+ @add_code_sample_docstrings(
1795
+ processor_class=_TOKENIZER_FOR_DOC,
1796
+ checkpoint=_CHECKPOINT_FOR_DOC,
1797
+ output_type=Seq2SeqModelOutput,
1798
+ config_class=_CONFIG_FOR_DOC,
1799
+ )
1800
+
1801
+ def max_decoder_positions(self):
1802
+ """Maximum length supported by the decoder."""
1803
+ return self.decoder.max_positions()
1804
+
1805
+ def get_normalized_probs(
1806
+ self,
1807
+ net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
1808
+ log_probs: bool,
1809
+ sample: Optional[Dict[str, Tensor]] = None,
1810
+ ):
1811
+ """Get normalized probabilities (or log probs) from a net's output."""
1812
+ return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
1813
+
1814
+
1815
+ def get_normalized_probs_scriptable(
1816
+ self,
1817
+ net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
1818
+ log_probs: bool,
1819
+ sample: Optional[Dict[str, Tensor]] = None,
1820
+ ):
1821
+ """Scriptable helper function for get_normalized_probs in ~BaseFairseqModel"""
1822
+ if hasattr(self, "decoder"):
1823
+ return self.decoder.get_normalized_probs(net_output, log_probs, sample)
1824
+ elif torch.is_tensor(net_output):
1825
+ # syntactic sugar for simple models which don't have a decoder
1826
+ # (e.g., the classification tutorial)
1827
+ logits = net_output.float()
1828
+ if log_probs:
1829
+ return F.log_softmax(logits, dim=-1)
1830
+ else:
1831
+ return F.softmax(logits, dim=-1)
1832
+ raise NotImplementedError
1833
+
1834
+ def forward(
1835
+ self,
1836
+ input_ids=None,
1837
+ patch_images=None,
1838
+ patch_images_2=None,
1839
+ patch_masks=None,
1840
+ token_embeddings=None,
1841
+ sample_patch_num=None,
1842
+ decoder_input_ids=None,
1843
+ code_masks=None,
1844
+ attention_mask=None,
1845
+ encoder_outputs=None,
1846
+ past_key_values=None,
1847
+ use_cache=False,
1848
+ output_attentions=False,
1849
+ output_hidden_states=False,
1850
+ return_dict=False
1851
+ ):
1852
+ r"""
1853
+ Args:
1854
+ input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`):
1855
+ indices of input sequence tokens in the vocabular, and padding will be ignored by default;
1856
+
1857
+ indices can be obtained using [`~OFATokenizer`].
1858
+
1859
+ patch_images (`torch.FloatTensor` of shape `(bsz, 3, height, width)`):
1860
+ the resized image, which are transformed by the default operations.
1861
+ patch_images_2 (`torch.FloatTensor` of shape `(bsz, 3, height, width)`):
1862
+ the second (if it exists) image.
1863
+ patch_masks (`torch.BoolTensor`): the patches to be masked.
1864
+ token_embeddings (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): token embeddings.
1865
+ sample_patch_num (`int`): the number of patches to sample.
1866
+ decoder_input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of the sequence in the vocabulary.
1867
+ code_masks (`torch.Tensor` of shape `(bsz, seq_len)`): masks only for code generation.
1868
+ attention_mask (`torch.Tensor` of shape `(bsz, seq_len)`): attention mask for decoding.
1869
+ encoder_outputs (`OFAEncoderOutput`):
1870
+ encoder outputs with hidden states, positional embeddings, and padding masks.
1871
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed):
1872
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1873
+ shape `(bsz, num_heads, tgt_len, head_size)`) and 2 additional tensors of
1874
+ shape `(bsz, num_heads, src_len, head_size)`.
1875
+ use_cache (`bool`): whether to use cache for faster inference.
1876
+ output_attentions (`bool`): whether to output attention weights.
1877
+ output_hidden_states (`bool`): whether to output hidden states.
1878
+ return_dict (`bool`): unused. Keep it for generation only.
1879
+
1880
+ Returns:
1881
+ Seq2SeqLMOutput:
1882
+ logits (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): the last decoder hidden states.
1883
+ past_key_values (`tuple(tuple(torch.FloatTensor)): past keys and values for faster inference.
1884
+ decoder_hidden_states (`tuple(torch.FloatTensor)`): the decoder hidden states of all layers.
1885
+ decoder_attentions (`tuple(torch.FloatTensor)): the decoder self attention weights of all layers.
1886
+ cross_attentions (`tuple(torch.FloatTensor)): cross attention weights of all layers.
1887
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`):
1888
+ the encoder last hidden state.
1889
+ encoder_hidden_states (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`):
1890
+ the encoder states of all layers including the embeddings.
1891
+ encoder_attentions (`torch.FloatTensor` of shape `(bsz, num_heads, seq_len, seq_len)`):
1892
+ the encoder attention weights of all layers.
1893
+ """
1894
+
1895
+ output_attentions = output_attentions if output_attentions else self.config.output_attentions
1896
+ output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
1897
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1898
+
1899
+ if encoder_outputs is None:
1900
+ encoder_outputs = self.encoder(
1901
+ input_ids=input_ids,
1902
+ patch_images=patch_images,
1903
+ patch_images_2=patch_images_2,
1904
+ patch_masks=patch_masks,
1905
+ output_attentions=output_attentions,
1906
+ output_hidden_states=output_hidden_states,
1907
+ token_embeddings=token_embeddings,
1908
+ sample_patch_num=sample_patch_num,
1909
+ )
1910
+
1911
+ # if decoder_input_ids.eq(self.config.pad_token_id).any():
1912
+ # attention_mask = decoder_input_ids.eq(self.padding_idx)
1913
+
1914
+ encoder_hidden_states = encoder_outputs.last_hidden_state
1915
+ if past_key_values is not None and len(past_key_values)>0:
1916
+ encoder_attention_mask = _expand_mask(
1917
+ ~encoder_outputs.padding_mask, encoder_hidden_states.dtype, decoder_input_ids[:, -1:].shape[-1]
1918
+ )
1919
+ else:
1920
+ encoder_attention_mask = _expand_mask(
1921
+ ~encoder_outputs.padding_mask, encoder_hidden_states.dtype, decoder_input_ids.shape[-1]
1922
+ )
1923
+ src_pos_embed = encoder_outputs.position_embedding
1924
+
1925
+ decoder_outputs = self.decoder(
1926
+ input_ids=decoder_input_ids,
1927
+ attention_mask=attention_mask,
1928
+ encoder_hidden_states=encoder_hidden_states,
1929
+ encoder_attention_mask=encoder_attention_mask,
1930
+ code_masks=code_masks,
1931
+ src_pos_embed=src_pos_embed,
1932
+ past_key_values=past_key_values,
1933
+ use_cache=use_cache,
1934
+ output_attentions=output_attentions,
1935
+ output_hidden_states=output_hidden_states,
1936
+ )
1937
+
1938
+ return Seq2SeqLMOutput(
1939
+ logits=decoder_outputs.last_hidden_state,
1940
+ past_key_values=decoder_outputs.past_key_values,
1941
+ decoder_hidden_states=decoder_outputs.hidden_states,
1942
+ decoder_attentions=decoder_outputs.attentions,
1943
+ cross_attentions=decoder_outputs.cross_attentions,
1944
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1945
+ encoder_hidden_states=encoder_outputs.hidden_states,
1946
+ encoder_attentions=encoder_outputs.attentions,
1947
+ )
1948
+
1949
+ def prepare_inputs_for_generation(
1950
+ self,
1951
+ decoder_input_ids=None,
1952
+ past=None,
1953
+ attention_mask=None,
1954
+ code_masks=None,
1955
+ use_cache=False,
1956
+ encoder_outputs=None,
1957
+ **kwargs
1958
+ ):
1959
+ # if attention_mask is None:
1960
+ attention_mask = decoder_input_ids.new_ones(decoder_input_ids.shape)
1961
+
1962
+ # cut decoder_input_ids if past is used
1963
+ # if past is not None:
1964
+ # decoder_input_ids = decoder_input_ids[:, -1:]
1965
+
1966
+ return {
1967
+ "input_ids": None,
1968
+ "patch_images": None,
1969
+ "patch_images_2": None,
1970
+ "patch_masks": None,
1971
+ "token_embeddings": None,
1972
+ "sample_patch_num": None,
1973
+ "attention_mask": attention_mask,
1974
+ "encoder_outputs": encoder_outputs,
1975
+ "past_key_values": past,
1976
+ "decoder_input_ids": decoder_input_ids,
1977
+ "code_masks": code_masks,
1978
+ "use_cache": use_cache,
1979
+ }
1980
+
1981
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1982
+ return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
1983
+
1984
+ def _prepare_encoder_decoder_kwargs_for_generation(
1985
+ self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
1986
+ ):
1987
+ # 1. get encoder
1988
+ encoder = self.get_encoder()
1989
+
1990
+ # 2. prepare encoder args and encoder kwargs from model kwargs
1991
+ irrelevant_prefix = ["decoder_", "cross_attn", "use_cache", "attention_mask"]
1992
+ encoder_kwargs = {
1993
+ argument: value
1994
+ for argument, value in model_kwargs.items()
1995
+ if not any(argument.startswith(p) for p in irrelevant_prefix)
1996
+ }
1997
+
1998
+ if encoder_kwargs.get("patch_masks") is None:
1999
+ encoder_kwargs["patch_masks"] = torch.ones((len(inputs_tensor), 1), dtype=torch.bool, device=inputs_tensor.device)
2000
+
2001
+ # 3. make sure that encoder returns `ModelOutput`
2002
+ model_input_name = model_input_name if model_input_name is not None else self.main_input_name
2003
+ encoder_kwargs[model_input_name] = inputs_tensor
2004
+ model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
2005
+ model_kwargs["attention_mask"] = None
2006
+
2007
+ return model_kwargs
2008
+
2009
+ @staticmethod
2010
+ def _reorder_cache(past, beam_idx):
2011
+ reordered_past = ()
2012
+ for layer_past in past:
2013
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
2014
+ return reordered_past
2015
+
2016
+ @staticmethod
2017
+ def _expand_inputs_for_generation(
2018
+ input_ids: torch.LongTensor,
2019
+ expand_size: int = 1,
2020
+ is_encoder_decoder: bool = False,
2021
+ attention_mask: Optional[torch.LongTensor] = None,
2022
+ encoder_outputs: Optional[ModelOutput] = None,
2023
+ **model_kwargs,
2024
+ ):
2025
+ expanded_return_idx = (
2026
+ torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
2027
+ )
2028
+ input_ids = input_ids.index_select(0, expanded_return_idx)
2029
+
2030
+ if "token_type_ids" in model_kwargs:
2031
+ token_type_ids = model_kwargs["token_type_ids"]
2032
+ model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)
2033
+
2034
+ if attention_mask is not None:
2035
+ model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
2036
+
2037
+ if is_encoder_decoder:
2038
+ if encoder_outputs is None:
2039
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
2040
+ encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
2041
+ 0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
2042
+ )
2043
+ encoder_outputs["position_embedding"] = encoder_outputs.position_embedding.index_select(
2044
+ 0, expanded_return_idx.to(encoder_outputs.position_embedding.device)
2045
+ )
2046
+ encoder_outputs["padding_mask"] = encoder_outputs.padding_mask.index_select(
2047
+ 0, expanded_return_idx.to(encoder_outputs.padding_mask.device)
2048
+ )
2049
+ model_kwargs["encoder_outputs"] = encoder_outputs
2050
+ return input_ids, model_kwargs
2051
+
2052
+
2053
+ class OFAModelForCaption(OFAModel):
2054
+
2055
+ def forward(
2056
+ self,
2057
+ input_ids=None,
2058
+ patch_images=None,
2059
+ patch_images_2=None,
2060
+ patch_masks=None,
2061
+ token_embeddings=None,
2062
+ sample_patch_num=None,
2063
+ decoder_input_ids=None,
2064
+ code_masks=None,
2065
+ attention_mask=None,
2066
+ encoder_outputs=None,
2067
+ past_key_values=None,
2068
+ use_cache=False,
2069
+ output_attentions=False,
2070
+ output_hidden_states=False,
2071
+ return_dict=False,
2072
+ return_loss=False
2073
+ ):
2074
+
2075
+ output_attentions = output_attentions if output_attentions else self.config.output_attentions
2076
+ output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
2077
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
2078
+
2079
+ if encoder_outputs is None:
2080
+ encoder_outputs = self.encoder(
2081
+ input_ids=input_ids,
2082
+ patch_images=patch_images,
2083
+ patch_images_2=patch_images_2,
2084
+ patch_masks=patch_masks,
2085
+ output_attentions=output_attentions,
2086
+ output_hidden_states=output_hidden_states,
2087
+ token_embeddings=token_embeddings,
2088
+ sample_patch_num=sample_patch_num,
2089
+ )
2090
+
2091
+ # if decoder_input_ids.eq(self.config.pad_token_id).any():
2092
+ # attention_mask = decoder_input_ids.eq(self.padding_idx)
2093
+
2094
+ encoder_hidden_states = encoder_outputs.last_hidden_state
2095
+ if past_key_values is not None and len(past_key_values)>0:
2096
+ encoder_attention_mask = _expand_mask(
2097
+ ~encoder_outputs.padding_mask, encoder_hidden_states.dtype, decoder_input_ids[:, -1:].shape[-1]
2098
+ )
2099
+ else:
2100
+ encoder_attention_mask = _expand_mask(
2101
+ ~encoder_outputs.padding_mask, encoder_hidden_states.dtype, decoder_input_ids.shape[-1]
2102
+ )
2103
+ src_pos_embed = encoder_outputs.position_embedding
2104
+
2105
+ decoder_outputs = self.decoder(
2106
+ input_ids=decoder_input_ids,
2107
+ attention_mask=attention_mask,
2108
+ encoder_hidden_states=encoder_hidden_states,
2109
+ encoder_attention_mask=encoder_attention_mask,
2110
+ code_masks=code_masks,
2111
+ src_pos_embed=src_pos_embed,
2112
+ past_key_values=past_key_values,
2113
+ use_cache=use_cache,
2114
+ output_attentions=output_attentions,
2115
+ output_hidden_states=output_hidden_states,
2116
+ )
2117
+
2118
+ loss = None
2119
+ if return_loss:
2120
+ lm_logits = decoder_outputs.last_hidden_state
2121
+ # Shift so that tokens < n predict n
2122
+ shift_logits = lm_logits[..., :-1, :].contiguous()
2123
+ shift_labels = decoder_input_ids[..., 1:].contiguous()
2124
+ # Flatten the tokens
2125
+ loss_fct = CrossEntropyLoss()
2126
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
2127
+
2128
+ return Seq2SeqLMOutput(
2129
+ loss=loss,
2130
+ logits=decoder_outputs.last_hidden_state,
2131
+ past_key_values=decoder_outputs.past_key_values,
2132
+ decoder_hidden_states=decoder_outputs.hidden_states,
2133
+ decoder_attentions=decoder_outputs.attentions,
2134
+ cross_attentions=decoder_outputs.cross_attentions,
2135
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
2136
+ encoder_hidden_states=encoder_outputs.hidden_states,
2137
+ encoder_attentions=encoder_outputs.attentions,
2138
+ )
2139
+
component/ofa/resnet.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The OFA-Sys Team. All rights reserved.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
9
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
10
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
11
+ the original name is misleading as 'Drop Connect' is a.sh different form of dropout in a.sh separate paper...
12
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
13
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a.sh layer name and use
14
+ 'survival rate' as the argument.
15
+ """
16
+ if drop_prob == 0.0 or not training:
17
+ return x
18
+ keep_prob = 1 - drop_prob
19
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
20
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
21
+ random_tensor.floor_() # binarize
22
+ output = x.div(keep_prob) * random_tensor
23
+ return output
24
+
25
+
26
+ class DropPath(nn.Module):
27
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
28
+
29
+ def __init__(self, drop_prob=None):
30
+ super(DropPath, self).__init__()
31
+ self.drop_prob = drop_prob
32
+
33
+ def forward(self, x):
34
+ return drop_path(x, self.drop_prob, self.training)
35
+
36
+
37
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
38
+ """3x3 convolution with padding"""
39
+ return nn.Conv2d(
40
+ in_planes,
41
+ out_planes,
42
+ kernel_size=3,
43
+ stride=stride,
44
+ padding=dilation,
45
+ groups=groups,
46
+ bias=False,
47
+ dilation=dilation,
48
+ )
49
+
50
+
51
+ def conv1x1(in_planes, out_planes, stride=1):
52
+ """1x1 convolution"""
53
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
54
+
55
+
56
+ class BasicBlock(nn.Module):
57
+ expansion = 1
58
+
59
+ def __init__(
60
+ self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None
61
+ ):
62
+ super(BasicBlock, self).__init__()
63
+ if norm_layer is None:
64
+ norm_layer = nn.BatchNorm2d
65
+ if groups != 1 or base_width != 64:
66
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
67
+ if dilation > 1:
68
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
69
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
70
+ self.conv1 = conv3x3(inplanes, planes, stride)
71
+ self.bn1 = norm_layer(planes)
72
+ self.relu = nn.ReLU(inplace=True)
73
+ self.conv2 = conv3x3(planes, planes)
74
+ self.bn2 = norm_layer(planes)
75
+ self.downsample = downsample
76
+ self.stride = stride
77
+
78
+ def forward(self, x):
79
+ assert False
80
+ identity = x
81
+
82
+ out = self.conv1(x)
83
+ out = self.bn1(out)
84
+ out = self.relu(out)
85
+
86
+ out = self.conv2(out)
87
+ out = self.bn2(out)
88
+
89
+ if self.downsample is not None:
90
+ identity = self.downsample(x)
91
+
92
+ out += identity
93
+ out = self.relu(out)
94
+
95
+ return out
96
+
97
+
98
+ class Bottleneck(nn.Module):
99
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
100
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
101
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
102
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
103
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
104
+
105
+ expansion = 4
106
+
107
+ def __init__(
108
+ self,
109
+ inplanes,
110
+ planes,
111
+ stride=1,
112
+ downsample=None,
113
+ groups=1,
114
+ base_width=64,
115
+ dilation=1,
116
+ norm_layer=None,
117
+ drop_path_rate=0.0,
118
+ ):
119
+ super(Bottleneck, self).__init__()
120
+ if norm_layer is None:
121
+ norm_layer = nn.BatchNorm2d
122
+ width = int(planes * (base_width / 64.0)) * groups
123
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
124
+ self.conv1 = conv1x1(inplanes, width)
125
+ self.bn1 = norm_layer(width)
126
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
127
+ self.bn2 = norm_layer(width)
128
+ self.conv3 = conv1x1(width, planes * self.expansion)
129
+ self.bn3 = norm_layer(planes * self.expansion)
130
+ self.relu = nn.ReLU(inplace=True)
131
+ self.downsample = downsample
132
+ self.stride = stride
133
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
134
+
135
+ def forward(self, x):
136
+ identity = x
137
+
138
+ out = self.conv1(x)
139
+ out = self.bn1(out)
140
+ out = self.relu(out)
141
+
142
+ out = self.conv2(out)
143
+ out = self.bn2(out)
144
+ out = self.relu(out)
145
+
146
+ out = self.conv3(out)
147
+ out = self.bn3(out)
148
+
149
+ if self.downsample is not None:
150
+ identity = self.downsample(x)
151
+
152
+ out = identity + self.drop_path(out)
153
+ out = self.relu(out)
154
+
155
+ return out
156
+
157
+
158
+ class ResNet(nn.Module):
159
+ def __init__(
160
+ self,
161
+ layers,
162
+ zero_init_residual=False,
163
+ groups=1,
164
+ width_per_group=64,
165
+ replace_stride_with_dilation=None,
166
+ norm_layer=None,
167
+ drop_path_rate=0.0,
168
+ ):
169
+ super(ResNet, self).__init__()
170
+ if norm_layer is None:
171
+ norm_layer = nn.BatchNorm2d
172
+ self._norm_layer = norm_layer
173
+
174
+ self.inplanes = 64
175
+ self.dilation = 1
176
+ if replace_stride_with_dilation is None:
177
+ # each element in the tuple indicates if we should replace
178
+ # the 2x2 stride with a dilated convolution instead
179
+ replace_stride_with_dilation = [False, False, False]
180
+ if len(replace_stride_with_dilation) != 3:
181
+ raise ValueError(
182
+ "replace_stride_with_dilation should be None "
183
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
184
+ )
185
+ self.groups = groups
186
+ self.base_width = width_per_group
187
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
188
+ self.bn1 = norm_layer(self.inplanes)
189
+ self.relu = nn.ReLU(inplace=True)
190
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
191
+ self.layer1 = self._make_layer(Bottleneck, 64, layers[0], drop_path_rate=drop_path_rate)
192
+ self.layer2 = self._make_layer(
193
+ Bottleneck, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0], drop_path_rate=drop_path_rate
194
+ )
195
+ self.layer3 = self._make_layer(
196
+ Bottleneck, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1], drop_path_rate=drop_path_rate
197
+ )
198
+
199
+ for m in self.modules():
200
+ if isinstance(m, nn.Conv2d):
201
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
202
+ elif isinstance(m, (nn.SyncBatchNorm, nn.BatchNorm2d, nn.GroupNorm)):
203
+ nn.init.constant_(m.weight, 1)
204
+ nn.init.constant_(m.bias, 0)
205
+
206
+ # Zero-initialize the last BN in each residual branch,
207
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
208
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
209
+ if zero_init_residual:
210
+ for m in self.modules():
211
+ if isinstance(m, Bottleneck):
212
+ nn.init.constant_(m.bn3.weight, 0)
213
+ elif isinstance(m, BasicBlock):
214
+ nn.init.constant_(m.bn2.weight, 0)
215
+
216
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False, drop_path_rate=0.0):
217
+ norm_layer = self._norm_layer
218
+ downsample = None
219
+ previous_dilation = self.dilation
220
+ if dilate:
221
+ self.dilation *= stride
222
+ stride = 1
223
+ if stride != 1 or self.inplanes != planes * block.expansion:
224
+ downsample = nn.Sequential(
225
+ conv1x1(self.inplanes, planes * block.expansion, stride),
226
+ norm_layer(planes * block.expansion),
227
+ )
228
+
229
+ layers = []
230
+ layers.append(
231
+ block(
232
+ self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
233
+ )
234
+ )
235
+ self.inplanes = planes * block.expansion
236
+
237
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, blocks)]
238
+ for i in range(1, blocks):
239
+ layers.append(
240
+ block(
241
+ self.inplanes,
242
+ planes,
243
+ groups=self.groups,
244
+ base_width=self.base_width,
245
+ dilation=self.dilation,
246
+ norm_layer=norm_layer,
247
+ drop_path_rate=dpr[i],
248
+ )
249
+ )
250
+
251
+ return nn.Sequential(*layers)
252
+
253
+ def _forward_impl(self, x):
254
+ # See note [TorchScript super()]
255
+ x = self.conv1(x)
256
+ x = self.bn1(x)
257
+ x = self.relu(x)
258
+ x = self.maxpool(x)
259
+
260
+ x = self.layer1(x)
261
+ x = self.layer2(x)
262
+ x = self.layer3(x)
263
+
264
+ return x
265
+
266
+ def forward(self, x):
267
+ return self._forward_impl(x)
ofa.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #### https://github.com/yangjianxin1/OFA-Chinese
2
+
3
+ from component.ofa.modeling_ofa import OFAModelForCaption
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ from transformers import BertTokenizerFast
7
+ import torch
8
+ import pathlib
9
+ import pandas as pd
10
+ import numpy as np
11
+ from IPython.core.display import HTML
12
+ import os
13
+ import requests
14
+
15
+ # 定义图片预处理逻辑
16
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
17
+ resolution = 256
18
+ patch_resize_transform = transforms.Compose([
19
+ lambda image: image.convert("RGB"),
20
+ transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC),
21
+ transforms.ToTensor(),
22
+ transforms.Normalize(mean=mean, std=std)
23
+ ])
24
+
25
+ class OFA(object):
26
+ def __init__(self ,model_path = 'YeungNLP/ofa-cn-base-muge-v2',
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
28
+ ):
29
+ self.device = device
30
+ self.model = OFAModelForCaption.from_pretrained(model_path)
31
+ self.tokenizer = BertTokenizerFast.from_pretrained(model_path)
32
+ self.model = self.model.to(self.device)
33
+
34
+ def predict_to_df(self, image_paths):
35
+ img_caption_pred = self.predict_step(image_paths)
36
+ img_cation_df = pd.DataFrame(list(zip(image_paths, img_caption_pred)))
37
+ img_cation_df.columns = ["img", "caption"]
38
+ return img_cation_df
39
+ #img_cation_df.to_html(escape=False, formatters=dict(Country=path_to_image_html))
40
+
41
+ def predict_step(self ,image_paths):
42
+ images = []
43
+ for image_path in image_paths:
44
+ #i_image = Image.open(image_path)
45
+ if image_path.startswith("http"):
46
+ i_image = Image.open(
47
+ requests.get(image_path, stream=True).raw
48
+ )
49
+ else:
50
+ i_image = Image.open(image_path)
51
+
52
+ if i_image.mode != "RGB":
53
+ i_image = i_image.convert(mode="RGB")
54
+ patch_img = patch_resize_transform(i_image).unsqueeze(0)
55
+ images.append(patch_img)
56
+
57
+ txt = '图片描述了什么?'
58
+ inputs = self.tokenizer([txt], return_tensors="pt").input_ids
59
+ inputs = inputs.to(self.device)
60
+ req = []
61
+ for patch_img in images:
62
+ # 生成caption
63
+ patch_img = patch_img.to(self.device)
64
+ gen = self.model.generate(inputs, patch_images=patch_img, num_beams=5, no_repeat_ngram_size=3)
65
+ gen = self.tokenizer.batch_decode(gen, skip_special_tokens=True)[0]
66
+ gen = gen.replace(" ", "").strip()
67
+ req.append(gen)
68
+ return req
69
+
70
+ def path_to_image_html(path):
71
+ return '<img src="'+ path + '" width="60" >'
72
+
73
+ if __name__ == "__main__":
74
+ #### build too slow
75
+ ofa_obj = OFA()
76
+
77
+ img_path_l = pd.Series(list(pathlib.Path("../../pic").rglob("*"))).map(
78
+ lambda x: x.__fspath__()
79
+ ).map(str).map(lambda x: np.nan if "._" in x else x).dropna().values.tolist()
80
+ img_path_l
81
+
82
+ img_caption_ofa_df = ofa_obj.predict_to_df(img_path_l)
83
+
84
+ HTML(img_caption_ofa_df.to_html(escape=False, formatters=dict(img=path_to_image_html)))
predict.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class Obj:
2
+ def __init__(self, model, tokenizer, device = "cpu"):
3
+ self.model = model
4
+ self.tokenizer = tokenizer
5
+ self.device = device
6
+ self.model = self.model.to(self.device)
7
+
8
+ def predict(
9
+ self,
10
+ source_text: str,
11
+ max_length: int = 512,
12
+ num_return_sequences: int = 1,
13
+ num_beams: int = 2,
14
+ top_k: int = 50,
15
+ top_p: float = 0.95,
16
+ do_sample: bool = True,
17
+ repetition_penalty: float = 2.5,
18
+ length_penalty: float = 1.0,
19
+ early_stopping: bool = True,
20
+ skip_special_tokens: bool = True,
21
+ clean_up_tokenization_spaces: bool = True,
22
+ ):
23
+ input_ids = self.tokenizer.encode(
24
+ source_text, return_tensors="pt", add_special_tokens=True
25
+ )
26
+ input_ids = input_ids.to(self.device)
27
+ generated_ids = self.model.generate(
28
+ input_ids=input_ids,
29
+ num_beams=num_beams,
30
+ max_length=max_length,
31
+ repetition_penalty=repetition_penalty,
32
+ length_penalty=length_penalty,
33
+ early_stopping=early_stopping,
34
+ top_p=top_p,
35
+ top_k=top_k,
36
+ num_return_sequences=num_return_sequences,
37
+ do_sample = do_sample
38
+ )
39
+ preds = [
40
+ self.tokenizer.decode(
41
+ g,
42
+ skip_special_tokens=skip_special_tokens,
43
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
44
+ )
45
+ for g in generated_ids
46
+ ]
47
+ return preds
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ jieba
4
+ rapidfuzz
5
+ ipykernel
summary_reverse_pred_native.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #### Chinese scope
2
+ #device = "cuda:0"
3
+ device = "cpu"
4
+ assert device.startswith("cpu") or device.startswith("cuda")
5
+
6
+ import sys
7
+ from predict import *
8
+
9
+ from transformers import (
10
+ T5ForConditionalGeneration,
11
+ MT5ForConditionalGeneration,
12
+ ByT5Tokenizer,
13
+ PreTrainedTokenizer,
14
+ T5TokenizerFast as T5Tokenizer,
15
+ MT5TokenizerFast as MT5Tokenizer,
16
+ AutoModelForSeq2SeqLM,
17
+ AutoTokenizer,
18
+ BertTokenizer,
19
+ GPT2LMHeadModel,
20
+ )
21
+
22
+ import pandas as pd
23
+ import numpy as np
24
+ import re
25
+ from rapidfuzz import fuzz
26
+ from tqdm import tqdm
27
+ import numpy as np
28
+ import os
29
+
30
+ import jieba
31
+ def repeat_to_one_f(x):
32
+ req = None
33
+ for token in jieba.lcut(x):
34
+ #print("req :", req)
35
+
36
+ if len(set(token)) == 1:
37
+ token = token[0]
38
+ if req is None:
39
+ req = token
40
+ else:
41
+
42
+ if (token in req and token not in [',', ',', '、', ' ']) or (req and token in [',', ',', '、', ' '] and req[-1] in [',', ',', '、', ' ']):
43
+ continue
44
+ else:
45
+ while req.endswith(token[0]):
46
+ token = token[1:]
47
+ req = req + token
48
+ if req is None:
49
+ return ""
50
+ return req.strip()
51
+
52
+ def shorten_exists(l, sim_threshold = 80, slice_size = 5):
53
+ req = []
54
+ for ele in l:
55
+ if not req:
56
+ req.append(ele)
57
+ else:
58
+ if max(map(lambda x: fuzz.ratio(x[:slice_size], ele[:slice_size]), req)) < sim_threshold:
59
+ req.append(ele)
60
+ return req
61
+
62
+ model_path = "svjack/summary-dialogue"
63
+ tokenizer0 = T5Tokenizer.from_pretrained(model_path)
64
+ model0 = T5ForConditionalGeneration.from_pretrained(model_path)
65
+
66
+ if device.startswith("cuda"):
67
+ model = Obj(model0, tokenizer0, device = "cuda:0")
68
+ else:
69
+ model = Obj(model0, tokenizer0, device = "cpu")
70
+
71
+ def loop_add(l, names = ["杰克", "安娜"]):
72
+ req = []
73
+ for i in range(len(l)):
74
+ ii = int(i % len(names))
75
+ req.append(
76
+ "{}:{}".format(names[ii], l[i])
77
+ )
78
+ return req
79
+
80
+ #### need some names drop in context(may not have ":")
81
+ #### '艾米-亚当斯在《沉睡的空洞》中,全身,双色大眼睛,咬牙切齿,恐怖,复杂的细节,电影,史诗,现实,解剖,汤姆-哈努卡,上光,艺术站,逼真,可怕'
82
+ def guess_name_candidates(context, cnt_threshold = 1):
83
+ from copy import deepcopy
84
+ assert type(context) == type("")
85
+ import re
86
+ l = re.findall(r"[\u4e00-\u9fa5a-zA-Z]+:", context)
87
+ l = list(filter(lambda x: x.strip(), l))
88
+ ori_l = deepcopy(l)
89
+ if not l:
90
+ return []
91
+ s = pd.Series(l).value_counts()
92
+ l = pd.Series(s[s > cnt_threshold].index.values.tolist()).map(lambda x: x[:-1]).values.tolist()
93
+ for ele in ori_l:
94
+ if len(ele[:-1]) not in l and (len(ele[:-1]) <= 3 or (
95
+ sum(map(len ,re.findall(r"[a-zA-Z]+:", ele))) == len(ele)
96
+ )):
97
+ l.append(ele[:-1])
98
+ l = list(set(l))
99
+ return l
100
+
101
+ def simple_pred(summary, candidates = ["杰克", "安娜"],
102
+ shorten_it = False, do_sample = True):
103
+ pred_text = model.predict(
104
+ "摘要:{} 候选集:{}".format(summary, " ".join(candidates)),
105
+ do_sample = do_sample
106
+ )[0]
107
+ candidates_ = guess_name_candidates(pred_text)
108
+ l = re.split("{}".format("|".join(map(lambda x: "{}:".format(x), candidates_))) ,pred_text)
109
+ l = list(filter(lambda x: x.strip(), l))
110
+ if shorten_it:
111
+ l = shorten_exists(l)
112
+ l = list(map(repeat_to_one_f, l))
113
+ l = loop_add(l, candidates)
114
+ return l
115
+
116
+ def percentile_sort(df, perc_num = 101):
117
+ score_tuple_s = df["score_tuple"]
118
+ score_array = np.asarray(score_tuple_s.values.tolist())
119
+ perc_list = np.linspace(0, 100, perc_num).tolist()
120
+ low_to_high_perc_array = np.stack(list(map(lambda p: np.percentile(score_array, p, axis = 0), perc_list)))
121
+
122
+ def get_rank(array_):
123
+ lookup_list = pd.DataFrame(array_ - low_to_high_perc_array[::-1]).apply(lambda s: min(s) >= 0, axis = 1).tolist()
124
+ if True not in lookup_list:
125
+ return len(lookup_list)
126
+ return lookup_list.index(True)
127
+
128
+ rank_list = []
129
+ for i in range(score_array.shape[0]):
130
+ rank_list.append(get_rank(score_array[i, :]))
131
+
132
+ rank_s = pd.Series(rank_list)
133
+ return df.iloc[np.argsort(rank_s.values)]
134
+
135
+ def repeat_score(l, slice_size = 200 ,sim_threshold = 70):
136
+ from copy import deepcopy
137
+ assert type(l) == type([])
138
+ l = deepcopy(l)
139
+ l = sorted(l)
140
+ cnt_num = 0
141
+ set0 = set([])
142
+ for ele in l:
143
+ if ":" in ele:
144
+ ele = "".join(ele.split(":")[1:])
145
+ if set0 and max(map(lambda x: fuzz.ratio(x[:slice_size], ele[:slice_size]), set0)) > sim_threshold:
146
+ #if ele in set0:
147
+ cnt_num += 1
148
+ set0.add(ele)
149
+ return cnt_num
150
+
151
+ #### "svjack/prompt-extend-chinese-gpt"
152
+ #model_path = "/home/featurize/zh_p_extend_outputs/simplet5-epoch-3-train-loss-1.2628-val-loss-1.6293"
153
+ model_path = "svjack/prompt-extend-chinese-gpt"
154
+ tokenizer1 = BertTokenizer.from_pretrained(model_path)
155
+ model1 = GPT2LMHeadModel.from_pretrained(model_path)
156
+
157
+ if device.startswith("cuda"):
158
+ zh_pe_model = Obj(model1, tokenizer1, device = "cuda:0")
159
+ else:
160
+ zh_pe_model = Obj(model1, tokenizer1, device = "cpu")
161
+
162
+ def one_ele_trans(x):
163
+ x = x.strip()
164
+ x = x[1:] if x.startswith("'") else x
165
+ x = x[:-1] if x.endswith("'") else x
166
+ x = x[1:] if x.startswith('"') else x
167
+ x = x[:-1] if x.endswith('"') else x
168
+ return x
169
+
170
+ def stdf_prompt_expander(x):
171
+ assert type(x) == type("")
172
+ return zh_pe_model.predict(
173
+ one_ele_trans(x.strip()).strip(),
174
+ max_length = 128
175
+ )[0].replace(" ", "").strip()
176
+
177
+ def sample_pred(context, times = 5, stdf_prompt_expander = lambda _: _):
178
+ df_req = []
179
+ for i in tqdm(range(times)):
180
+ ele = stdf_prompt_expander(context)
181
+ #ele = context
182
+ l = simple_pred(ele, do_sample = True)
183
+ df_req.append(
184
+ [ele, l]
185
+ )
186
+ df = pd.DataFrame(df_req)
187
+ df.columns = ["context", "dialogue"]
188
+ df["fuzz"] = df["dialogue"].map(
189
+ lambda x: fuzz.ratio(context, " ".join(x))
190
+ )
191
+ df["max_fuzz"] = df["dialogue"].map(
192
+ lambda x: max(map(lambda y: fuzz.ratio(y, context), x))
193
+ )
194
+ df["length"] = df["dialogue"].map(len)
195
+ df["rpt_score"] = df["dialogue"].map(repeat_score)
196
+ df["score_tuple"] = df.apply(
197
+ lambda x: (x["fuzz"], -1 * x["max_fuzz"], x["length"], -1 * x["rpt_score"]), axis = 1
198
+ )
199
+ df = percentile_sort(df)
200
+ return df
201
+
202
+ def sample_pred_wrapper(context, i2c_obj, times = 5, extend_by_diffusion = False):
203
+ assert type(context) == type("")
204
+ if any(map(lambda x: context.endswith(x), [".jpg", ".png", ".jpeg"])):
205
+ img_path = context
206
+ i2c_df = i2c_obj.predict_to_df([img_path])
207
+ assert i2c_df.size > 0
208
+ context = i2c_df["caption"].iloc[0]
209
+ else:
210
+ pass
211
+ assert type(context) == type("")
212
+ if extend_by_diffusion:
213
+ req_df = sample_pred(context, times = times, stdf_prompt_expander = stdf_prompt_expander)
214
+ else:
215
+ req_df = sample_pred(context, times = times, stdf_prompt_expander = lambda _:_)
216
+ return req_df
217
+
218
+ from ofa import *
219
+ ofa_obj = OFA()
220
+
221
+ if __name__ == "__main__":
222
+ '''
223
+ from image2caption import *
224
+ i2c_tiny_zh_obj = Image2Caption("svjack/vit-gpt-diffusion-zh",
225
+ overwrite_encoder_checkpoint_path = "google/vit-base-patch16-224",
226
+ overwrite_token_model_path = "IDEA-CCNL/Wenzhong-GPT2-110M",
227
+ device = device
228
+ )
229
+ '''
230
+ from ofa import *
231
+ ofa_obj = OFA()
232
+
233
+ img_path = "../pic/bug.jpg"
234
+ img_path = "../pic/baobao.jpeg"
235
+ img_path = "../pic/cat0.jpg"
236
+ img_path = "../pic/cat.jpg"
237
+ os.path.exists(img_path)
238
+
239
+ df = sample_pred_wrapper(img_path, i2c_obj = ofa_obj)
240
+ df["dialogue"].values.tolist()
241
+
242
+ img_url = "https://datasets-server.huggingface.co/assets/metashift/--/metashift/train/2/image/image.jpg"
243
+ img_url = "https://datasets-server.huggingface.co/assets/metashift/--/metashift/train/6/image/image.jpg"
244
+
245
+ #### diffusion model, general model
246
+ df = sample_pred_wrapper(img_url, i2c_obj = ofa_obj)
247
+ df["dialogue"].values.tolist()
248
+
249
+ ds_en_zh_df = pd.read_csv("../ds_en_zh_df.csv")
250
+
251
+ idx = 3
252
+ ds_en_zh_df.iloc[:, -1].iloc[idx]
253
+
254
+ df = sample_pred(ds_en_zh_df.iloc[:, -1].iloc[idx])
255
+ df["dialogue"].values.tolist()