oshizo committed
Commit b7935fa (verified)
1 Parent(s): e1cc11b

Upload modeling_clip_qwen2vl.py

Files changed (1)
  1. modeling_clip_qwen2vl.py +275 -0
modeling_clip_qwen2vl.py ADDED
@@ -0,0 +1,275 @@
+ # coding=utf-8
+ # Copyright 2024 oshizo
+ #
+ # This implementation is based on:
+ # 1. Qwen2-VL (https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/)
+ #    Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
+ #    Originally based on EleutherAI's GPT-NeoX library and GPT-NeoX/OPT implementations.
+ #
+ # 2. CLIP (https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/)
+ #    Copyright 2021 The OpenAI Team Authors and The HuggingFace Team.
+ #    CLIP Configuration
+ #    Copyright 2021 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """CLIPQwen2VL model implementation."""
+
+ from __future__ import annotations
+
+ import itertools
+ from typing import Any, Dict, List, Optional, Union
+
+ import torch
+ import torch.nn.functional as F
+ import transformers
+ from PIL import Image
+ from torch import nn
+ from transformers import LukeConfig, LukeModel
+ from transformers import PretrainedConfig, PreTrainedModel
+ from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+     Qwen2VisionTransformerPretrainedModel,
+ )
+
+
+ class CLIPQwen2VLConfig(PretrainedConfig):
+     model_type = "clip_qwen2vl"
+
+     def __init__(
+         self,
+         text_config: Optional[Dict[str, Any]] = None,
+         vision_config: Optional[Dict[str, Any]] = None,
+         projection_dim: int = 768,
+         logit_scale_init_value: float = 2.6592,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         text_config = text_config or {}
+         vision_config = vision_config or {}
+
+         self.text_config = LukeConfig(**text_config)
+         self.vision_config = Qwen2VLVisionConfig(**vision_config)
+
+         self.projection_dim = projection_dim
+         self.logit_scale_init_value = logit_scale_init_value
+
+
+ class CLIPQwen2VLModel(PreTrainedModel):
+     config_class = CLIPQwen2VLConfig
+
+     def __init__(self, config: CLIPQwen2VLConfig):
+         super().__init__(config)
+
+         self.projection_dim = config.text_config.hidden_size  # 1024
+         self.text_embed_dim = config.text_config.hidden_size  # 1024
+         self.vision_embed_dim = config.vision_config.hidden_size  # 1536
+
+         # Text encoder
+         self.text_model = LukeModel(config.text_config)
+
+         # Vision encoder
+         self.vision_model = Qwen2VisionTransformerPretrainedModel(config.vision_config)
+
+         # vision projection (1536 -> 1024)
+         self.vision_projection = nn.Linear(
+             self.vision_embed_dim, self.projection_dim, bias=False
+         )
+
+         self.logit_scale = nn.Parameter(torch.ones([]) * config.logit_scale_init_value)
+
+     def get_text_features(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> torch.FloatTensor:
+         text_outputs = self.text_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         # Mean pooling
+         attention_mask = attention_mask.to(text_outputs.last_hidden_state.dtype)
+         input_mask_expanded = attention_mask.unsqueeze(-1).expand(
+             text_outputs.last_hidden_state.size()
+         )
+         sum_embeddings = torch.sum(
+             text_outputs.last_hidden_state * input_mask_expanded, 1
+         )
+         sum_mask = input_mask_expanded.sum(1)
+         sum_mask = torch.clamp(sum_mask, min=1e-9)
+         text_embeds = sum_embeddings / sum_mask
+
+         return text_embeds
+
+     def get_image_features(
+         self,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         image_grid_thw: Optional[torch.LongTensor] = None,
+     ) -> torch.FloatTensor:
+         batch_size = image_grid_thw.shape[0]
+         spatial_merge_size = 2
+
+         cu_seqlens = torch.repeat_interleave(
+             image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]
+         ).cumsum(dim=0, dtype=torch.int32)
+         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+         vision_output = self.vision_model(
+             hidden_states=pixel_values, grid_thw=image_grid_thw
+         )
+
+         merged_patches_per_image = [
+             ((h // spatial_merge_size) * (w // spatial_merge_size) * t).item()
+             for t, h, w in image_grid_thw
+         ]
+         merged_cu_seqlens = torch.tensor(
+             [0] + list(itertools.accumulate(merged_patches_per_image)),
+             device=vision_output.device,
+         )
+
+         image_features = []
+         for i in range(batch_size):
+             start_idx = merged_cu_seqlens[i]
+             end_idx = merged_cu_seqlens[i + 1]
+             image_features.append(vision_output[start_idx:end_idx].mean(dim=0))
+
+         image_features = torch.stack(image_features)
+         image_embeds = self.vision_projection(image_features)
+         return image_embeds
+
160
+ class CLIPQwen2VLWrapper(nn.Module):
161
+ save_in_root: bool = True
162
+
163
+ def __init__(
164
+ self,
165
+ model_name_or_path: str,
166
+ cache_dir: str = None,
167
+ backend: str = "torch",
168
+ enable_text_grad: bool = False,
169
+ **kwargs,
170
+ ) -> None:
171
+ super().__init__()
172
+
173
+ self.enable_text_grad = enable_text_grad
174
+
175
+ model_args = kwargs.get("model_args", {})
176
+ if "torch_dtype" not in model_args:
177
+ model_args["torch_dtype"] = torch.bfloat16
178
+
179
+ self.model = CLIPQwen2VLModel.from_pretrained(
180
+ model_name_or_path, cache_dir=cache_dir, **model_args
181
+ )
182
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(
183
+ "pkshatech/GLuCoSE-base-ja-v2"
184
+ )
185
+ self.processor = transformers.AutoProcessor.from_pretrained(
186
+ "Qwen/Qwen2-VL-2B-Instruct"
187
+ )
188
+
189
+ def __repr__(self) -> str:
190
+ return "CLIPQwen2VLWrapper()"
191
+
192
+ def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
193
+ image_embeds = []
194
+ text_embeds = []
195
+
196
+ if "pixel_values" in features:
197
+ image_embeds = self.model.get_image_features(
198
+ pixel_values=features["pixel_values"],
199
+ image_grid_thw=features["image_grid_thw"],
200
+ )
201
+
202
+ if "input_ids" in features:
203
+ text_embeds = self.model.get_text_features(
204
+ input_ids=features["input_ids"],
205
+ attention_mask=features.get("attention_mask", None),
206
+ position_ids=features.get("position_ids", None),
207
+ output_attentions=features.get("output_attentions", None),
208
+ output_hidden_states=features.get("output_hidden_states", None),
209
+ )
210
+ if self.enable_text_grad:
211
+ # peft銇с伄瀛︾繏鏅傘伀text model銇堡銈掓寚瀹氥仐銇亜鍫村悎銇偍銉┿兗鍥為伩
212
+ text_embeds = text_embeds.detach().requires_grad_()
213
+
214
+ sentence_embedding = []
215
+ image_features = iter(image_embeds)
216
+ text_features = iter(text_embeds)
217
+
218
+ for idx, input_type in enumerate(features["image_text_info"]):
219
+ if input_type == 0:
220
+ sentence_embedding.append(next(image_features))
221
+ else:
222
+ sentence_embedding.append(next(text_features))
223
+
224
+ features["sentence_embedding"] = torch.stack(sentence_embedding).float()
225
+
226
+ return features
227
+
228
+ def tokenize(
229
+ self, texts: List[Union[str, Image.Image]], padding: str | bool = True
230
+ ) -> dict[str, torch.Tensor]:
231
+ images = []
232
+ texts_values = []
233
+ image_text_info = []
234
+
235
+ for idx, data in enumerate(texts):
236
+ if isinstance(data, Image.Image):
237
+ images.append(data)
238
+ image_text_info.append(0)
239
+ else:
240
+ texts_values.append(data)
241
+ image_text_info.append(1)
242
+
243
+ encoding = {}
244
+ if len(texts_values):
245
+ encoding = self.tokenizer(
246
+ texts_values,
247
+ return_tensors="pt",
248
+ padding=padding,
249
+ truncation=True,
250
+ max_length=512,
251
+ )
252
+
253
+ if len(images):
254
+ image_features = self.processor.image_processor(images, return_tensors="pt")
255
+ encoding.update(image_features)
256
+
257
+ encoding["image_text_info"] = image_text_info
258
+ return dict(encoding)
259
+
260
+ @property
261
+ def processor(self) -> transformers.PreTrainedModel:
262
+ return self._processor
263
+
264
+ @processor.setter
265
+ def processor(self, processor):
266
+ self._processor = processor
267
+
268
+ def save(self, output_path: str) -> None:
269
+ self.model.save_pretrained(output_path)
270
+ self.tokenizer.save_pretrained(output_path)
271
+ self.processor.save_pretrained(output_path)
272
+
273
+ @staticmethod
274
+ def load(input_path: str) -> CLIPQwen2VLWrapper:
275
+ return CLIPQwen2VLWrapper(model_name_or_path=input_path)
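
For reference, a minimal usage sketch (not part of the uploaded file) of how the wrapper above might embed a mixed batch of one text and one image. The checkpoint path and image filename are placeholders; the tokenizer and image-processor repos are the ones hard-coded in __init__.

import torch
from PIL import Image

from modeling_clip_qwen2vl import CLIPQwen2VLWrapper

# "path/to/checkpoint" is a placeholder for a trained CLIPQwen2VLModel checkpoint.
wrapper = CLIPQwen2VLWrapper(model_name_or_path="path/to/checkpoint")

# tokenize() accepts a mixed list of strings and PIL images and records the
# ordering in "image_text_info" (0 = image, 1 = text).
features = wrapper.tokenize(["猫の写真", Image.open("cat.jpg")])

# Depending on the environment, tensors may first need to be moved to the
# model's device and pixel values cast to the model dtype (bfloat16 by default).
with torch.no_grad():
    out = wrapper(features)

print(out["sentence_embedding"].shape)  # e.g. (2, 1024)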