Upload modeling_clip_qwen2vl.py
modeling_clip_qwen2vl.py (new file, +275 lines)
# coding=utf-8
# Copyright 2024 oshizo
#
# This implementation is based on:
# 1. Qwen2-VL (https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/)
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
# Originally based on EleutherAI's GPT-NeoX library and GPT-NeoX/OPT implementations.
#
# 2. CLIP (https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/)
# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team.
# CLIP Configuration
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CLIPQwen2VL model implementation."""

from __future__ import annotations

import itertools
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
import transformers
from PIL import Image
from torch import nn
from transformers import LukeConfig, LukeModel, PretrainedConfig, PreTrainedModel
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
    Qwen2VisionTransformerPretrainedModel,
)


class CLIPQwen2VLConfig(PretrainedConfig):
    model_type = "clip_qwen2vl"

    def __init__(
        self,
        text_config: Optional[Dict[str, Any]] = None,
        vision_config: Optional[Dict[str, Any]] = None,
        projection_dim: int = 768,
        logit_scale_init_value: float = 2.6592,
        **kwargs,
    ):
        super().__init__(**kwargs)

        text_config = text_config or {}
        vision_config = vision_config or {}

        self.text_config = LukeConfig(**text_config)
        self.vision_config = Qwen2VLVisionConfig(**vision_config)

        self.projection_dim = projection_dim
        self.logit_scale_init_value = logit_scale_init_value


class CLIPQwen2VLModel(PreTrainedModel):
    config_class = CLIPQwen2VLConfig

    def __init__(self, config: CLIPQwen2VLConfig):
        super().__init__(config)

        self.projection_dim = config.text_config.hidden_size  # 1024
        self.text_embed_dim = config.text_config.hidden_size  # 1024
        self.vision_embed_dim = config.vision_config.hidden_size  # 1536

        # Text encoder
        self.text_model = LukeModel(config.text_config)

        # Vision encoder
        self.vision_model = Qwen2VisionTransformerPretrainedModel(config.vision_config)

        # vision projection (1536 -> 1024)
        self.vision_projection = nn.Linear(
            self.vision_embed_dim, self.projection_dim, bias=False
        )

        self.logit_scale = nn.Parameter(torch.ones([]) * config.logit_scale_init_value)

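    # The default logit_scale_init_value (2.6592) equals ln(1 / 0.07), the same
    # temperature initialization used by CLIP.
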
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Mean pooling
        attention_mask = attention_mask.to(text_outputs.last_hidden_state.dtype)
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(
            text_outputs.last_hidden_state.size()
        )
        sum_embeddings = torch.sum(
            text_outputs.last_hidden_state * input_mask_expanded, 1
        )
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        text_embeds = sum_embeddings / sum_mask

        return text_embeds
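    # In other words, the text embedding is masked mean pooling:
    #     text_embeds[b] = sum_t(mask[b, t] * h[b, t]) / max(sum_t(mask[b, t]), 1e-9)
    # where h is the LUKE encoder's last_hidden_state and mask is attention_mask.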

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        batch_size = image_grid_thw.shape[0]
        spatial_merge_size = 2

        cu_seqlens = torch.repeat_interleave(
            image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        vision_output = self.vision_model(
            hidden_states=pixel_values, grid_thw=image_grid_thw
        )

        merged_patches_per_image = [
            ((h // spatial_merge_size) * (w // spatial_merge_size) * t).item()
            for t, h, w in image_grid_thw
        ]
        merged_cu_seqlens = torch.tensor(
            [0] + list(itertools.accumulate(merged_patches_per_image)),
            device=vision_output.device,
        )

        image_features = []
        for i in range(batch_size):
            start_idx = merged_cu_seqlens[i]
            end_idx = merged_cu_seqlens[i + 1]
            image_features.append(vision_output[start_idx:end_idx].mean(dim=0))

        image_features = torch.stack(image_features)
        image_embeds = self.vision_projection(image_features)
        return image_embeds
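    # The vision tower returns a single flattened sequence of spatially merged patch
    # embeddings for the whole batch, so the loop above pools them back per image.
    # Illustration: an image with grid_thw = (1, 16, 20) contributes
    # (16 // 2) * (20 // 2) * 1 = 80 rows to vision_output; its embedding is the mean
    # of those 80 rows, projected from the vision width (1536) to the text width (1024).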


class CLIPQwen2VLWrapper(nn.Module):
    save_in_root: bool = True

    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        backend: str = "torch",
        enable_text_grad: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()

        self.enable_text_grad = enable_text_grad

        model_args = kwargs.get("model_args", {})
        if "torch_dtype" not in model_args:
            model_args["torch_dtype"] = torch.bfloat16

        self.model = CLIPQwen2VLModel.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **model_args
        )
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            "pkshatech/GLuCoSE-base-ja-v2"
        )
        self.processor = transformers.AutoProcessor.from_pretrained(
            "Qwen/Qwen2-VL-2B-Instruct"
        )

    def __repr__(self) -> str:
        return "CLIPQwen2VLWrapper()"

    def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        image_embeds = []
        text_embeds = []

        if "pixel_values" in features:
            image_embeds = self.model.get_image_features(
                pixel_values=features["pixel_values"],
                image_grid_thw=features["image_grid_thw"],
            )

        if "input_ids" in features:
            text_embeds = self.model.get_text_features(
                input_ids=features["input_ids"],
                attention_mask=features.get("attention_mask", None),
                position_ids=features.get("position_ids", None),
                output_attentions=features.get("output_attentions", None),
                output_hidden_states=features.get("output_hidden_states", None),
            )
            if self.enable_text_grad:
                # Avoid an autograd error during PEFT training when no trainable
                # modules are assigned to the text model.
                text_embeds = text_embeds.detach().requires_grad_()

        sentence_embedding = []
        image_features = iter(image_embeds)
        text_features = iter(text_embeds)

        for idx, input_type in enumerate(features["image_text_info"]):
            if input_type == 0:
                sentence_embedding.append(next(image_features))
            else:
                sentence_embedding.append(next(text_features))

        features["sentence_embedding"] = torch.stack(sentence_embedding).float()

        return features

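    # tokenize() and forward() communicate through image_text_info: tokenize records
    # a 0 for each image and a 1 for each text in input order, and forward uses that
    # list to interleave the image and text embedding streams back into the same order.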
    def tokenize(
        self, texts: List[Union[str, Image.Image]], padding: str | bool = True
    ) -> dict[str, torch.Tensor]:
        images = []
        texts_values = []
        image_text_info = []

        for idx, data in enumerate(texts):
            if isinstance(data, Image.Image):
                images.append(data)
                image_text_info.append(0)
            else:
                texts_values.append(data)
                image_text_info.append(1)

        encoding = {}
        if len(texts_values):
            encoding = self.tokenizer(
                texts_values,
                return_tensors="pt",
                padding=padding,
                truncation=True,
                max_length=512,
            )

        if len(images):
            image_features = self.processor.image_processor(images, return_tensors="pt")
            encoding.update(image_features)

        encoding["image_text_info"] = image_text_info
        return dict(encoding)

    @property
    def processor(self) -> transformers.ProcessorMixin:
        return self._processor

    @processor.setter
    def processor(self, processor):
        self._processor = processor

    def save(self, output_path: str) -> None:
        self.model.save_pretrained(output_path)
        self.tokenizer.save_pretrained(output_path)
        self.processor.save_pretrained(output_path)

    @staticmethod
    def load(input_path: str) -> CLIPQwen2VLWrapper:
        return CLIPQwen2VLWrapper(model_name_or_path=input_path)
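

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how the wrapper is expected to be driven
# for a mixed image/text batch. The checkpoint path and image file below are
# placeholders, and device/dtype handling is simplified.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    wrapper = CLIPQwen2VLWrapper("path/to/checkpoint")  # placeholder path
    wrapper.eval()

    # One image and one text query; image_text_info keeps track of the order.
    batch = wrapper.tokenize([Image.open("example.jpg"), "猫の写真"])  # "a photo of a cat"

    # Move tensors to the model's device; cast floating-point inputs
    # (pixel_values) to the model's dtype (bfloat16 by default).
    for key, value in list(batch.items()):
        if isinstance(value, torch.Tensor):
            if torch.is_floating_point(value):
                batch[key] = value.to(wrapper.model.device, dtype=wrapper.model.dtype)
            else:
                batch[key] = value.to(wrapper.model.device)

    with torch.no_grad():
        out = wrapper(batch)

    embeddings = F.normalize(out["sentence_embedding"], dim=-1)  # (2, projection_dim)
    print("image-text cosine similarity:", (embeddings[0] @ embeddings[1]).item())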