Support Transformers AutoModel
- .gitignore +5 -0
- __init__.py +23 -0
- config.json +7 -3
- configuration_v1.py +236 -0
- grounding.py +559 -0
- modeling_v1.py +577 -0
- preprocessor_config.json +1 -1
- processor.py +536 -0
- processor_config.json +1 -1
.gitignore
ADDED
@@ -0,0 +1,5 @@
backup/
__pycache__/
*.pyc
*.swo
*.swp
__init__.py
ADDED
@@ -0,0 +1,23 @@
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoImageProcessor,
    AutoConfig,
)

from .processor import (
    Qwen2VLImagePointerProcessor,
    get_processor,
    V1Processor,
    collate_fn,
)
from .modeling_v1 import V1ForConditionalGeneration
from .configuration_v1 import V1Config

print("Registering V1 model and processor with Transformers")
AutoConfig.register("v1", V1Config)
AutoModelForCausalLM.register(
    V1Config, V1ForConditionalGeneration
)
AutoProcessor.register(V1Config, V1Processor)
AutoImageProcessor.register(V1Config, Qwen2VLImagePointerProcessor)
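For orientation, a minimal sketch of how these registrations are typically exercised. The checkpoint path is a placeholder and the package import name (`v1`) is an assumption, not something fixed by this commit:

    # Assumptions: this repo is importable as a local package named `v1`, and
    # "path/to/checkpoint" contains the config.json below with model_type "v1".
    import v1  # executes the AutoConfig / AutoModelForCausalLM / AutoProcessor registrations above

    from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor

    config = AutoConfig.from_pretrained("path/to/checkpoint")            # model_type "v1" -> V1Config
    model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")   # -> V1ForConditionalGeneration
    processor = AutoProcessor.from_pretrained("path/to/checkpoint")      # -> V1Processor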
config.json
CHANGED
@@ -1,6 +1,6 @@
 {
   "architectures": [
-    "
+    "V1ForConditionalGeneration"
   ],
   "attention_dropout": 0.0,
   "bos_token_id": 151643,
@@ -18,7 +18,7 @@
   "label_smoothing": 0.0,
   "max_position_embeddings": 128000,
   "max_window_layers": 28,
-  "model_type": "
+  "model_type": "v1",
   "normalize_copy_states": false,
   "num_attention_heads": 28,
   "num_hidden_layers": 28,
@@ -75,5 +75,9 @@
   "vision_token_id": 151654,
   "vocab_size": 152064,
   "z_loss_top_k": 40,
-  "z_loss_weight": 1e-05
+  "z_loss_weight": 1e-05,
+  "auto_map": {
+    "AutoConfig": "configuration_v1.V1Config",
+    "AutoModelForConditionalGeneration": "modeling_v1.V1ForConditionalGeneration"
+  }
 }
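The new `auto_map` block is what lets transformers load the custom classes straight from the repo when the package is not installed locally. A minimal sketch (the repo id is a placeholder):

    from transformers import AutoConfig

    # trust_remote_code=True allows transformers to download and execute
    # configuration_v1.py, which the "AutoConfig" entry of auto_map points to.
    config = AutoConfig.from_pretrained("user/v1-checkpoint", trust_remote_code=True)
    print(config.model_type)  # "v1"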
configuration_v1.py
ADDED
@@ -0,0 +1,236 @@
from typing import Optional

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
    Qwen2_5_VLVisionConfig,
)


class V1Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a
    Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 152064):
            Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Qwen2_5_VLModel`]
        hidden_size (`int`, *optional*, defaults to 8192):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 29568):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 80):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 64):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 1000000.0):
            The base period of the RoPE embeddings.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
        max_window_layers (`int`, *optional*, defaults to 80):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        vision_config (`Dict`, *optional*):
            The config for the visual encoder initialization.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE

    ```python
    >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig

    >>> # Initializing a Qwen2_5_VL style configuration
    >>> configuration = Qwen2_5_VLConfig()

    >>> # Initializing a model from the Qwen2-VL-7B style configuration
    >>> model = Qwen2_5_VLForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "v1"
    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig}
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `Qwen2_5_VL`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=152064,
        hidden_size=8192,
        intermediate_size=29568,
        num_hidden_layers=80,
        num_attention_heads=64,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=1000000.0,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=80,
        attention_dropout=0.0,
        vision_config=None,
        rope_scaling=None,
        region_token_id: int = 151662,  # <|fim_pad|>
        copy_token_start: int = 151665,
        copy_token_num: int = 30000,
        copy_scaler: float = 0.1,
        use_embeddings_as_keys: bool = False,
        normalize_copy_states: bool = False,
        copy_extraction_layer: int = -1,
        tie_copy_heads: bool = False,
        use_cfg: bool = False,
        copy_hidden_size: Optional[int] = None,
        z_loss_weight: float = 1e-5,
        z_loss_top_k: int = 40,
        use_gate: bool = False,
        label_smoothing: bool = False,
        separate_copy_loss: bool = False,
        do_copy: bool = True,
        **kwargs,
    ):
        if isinstance(vision_config, dict):
            self.vision_config = self.sub_configs["vision_config"](**vision_config)
        elif vision_config is None:
            self.vision_config = self.sub_configs["vision_config"]()

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.rope_scaling = rope_scaling

        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        # and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
        # one can set it to "linear"/"dynamic" etc. to have scaled RoPE
        # TODO: @raushan update config in the hub
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            if self.rope_scaling["type"] == "mrope":
                self.rope_scaling["type"] = "default"
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self, ignore_keys={"mrope_section"})

        self.region_token_id = region_token_id
        self.copy_token_start = copy_token_start
        self.copy_token_num = copy_token_num
        self.copy_scaler = copy_scaler
        self.use_embeddings_as_keys = use_embeddings_as_keys
        self.normalize_copy_states = normalize_copy_states
        self.copy_extraction_layer = copy_extraction_layer
        self.tie_copy_heads = tie_copy_heads
        self.use_cfg = use_cfg

        if copy_hidden_size is None:
            copy_hidden_size = self.hidden_size
        self.copy_hidden_size = copy_hidden_size
        self.z_loss_weight = z_loss_weight
        self.z_loss_top_k = z_loss_top_k
        self.use_gate = use_gate
        self.label_smoothing = label_smoothing
        self.separate_copy_loss = separate_copy_loss
        self.do_copy = do_copy

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)


__all__ = ["V1Config"]
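A quick sketch of how the copy-related defaults resolve when the config is instantiated with no overrides; the values are read off the signature above, and the flat import path assumes the file sits on the Python path (e.g. the script is run from the repo directory):

    from configuration_v1 import V1Config  # assumption: run from the repo directory

    cfg = V1Config()
    print(cfg.model_type)        # "v1"
    print(cfg.copy_token_start)  # 151665, intended to be the id of <|copy_0|> (see get_processor in grounding.py)
    print(cfg.copy_token_num)    # 30000 copy tokens are reserved
    print(cfg.copy_hidden_size)  # falls back to hidden_size (8192) because copy_hidden_size=None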
grounding.py
ADDED
@@ -0,0 +1,559 @@
from typing import Tuple, List, Optional, Union
import re
import math

from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
from qwen_vl_utils import process_vision_info
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, VideoInput
from transformers.processing_utils import (
    Unpack,
)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
    smart_resize,
    Qwen2VLImageProcessor,
)
from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import (
    Qwen2_5_VLProcessorKwargs,
    Qwen2_5_VLProcessor,
)


"""
Qwen2.5-VL does not use AnyRes to my relief.
Things to take into account:
- smart_resize
- temporal dimension
  - grid_t = patches.shape[0] // self.temporal_patch_size
  - grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
- merge_size (2)


Usage:

model_name = "Qwen/Qwen2.5-VL-7B-Instruct"


processor = Qwen2_5_VLPointerProcessor.from_pretrained(model_name)
processor.image_processor = Qwen2VLImagePointerProcessor.from_pretrained(model_name)

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://example---/demo.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    },
    {
        'role': 'assistant',
        'content': [
            {
                'type': 'text', 'text': '<think>Theres a cat at <|region|>, a dog at <|region|>.</think>A calico cat hanging out with a golden retriever.'
            }
        ]
    }
]

# Preparation for inference
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
regions = [
    [0, 10, 100, 200],
    [300, 0, 600, 250]
]
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    regions=[regions],
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")


# Qwen2VLImageProcessor in a nutshell
'(tl tp) c (hlm hm hp) (wlm wm wp) -> (tl hlm wlm hm wm) (c tp hp wp)'
"""


BBOX = Tuple[int, int, int, int]


class PointerProcessor:
    @staticmethod
    def normalize_bbox(image_size: Tuple[int, int], bbox: BBOX):
        w, h = image_size
        bbox = [
            bbox[0] / w,
            bbox[1] / h,
            bbox[2] / w,
            bbox[3] / h,
        ]
        return "[{}]".format(", ".join([f"{v:.2f}" for v in bbox]))

    def get_masks(self, image_size: Tuple[int, int], indices: List[int]):
        width, height = image_size
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=self.patch_size * self.merge_size,
            min_pixels=self.min_pixels,
            max_pixels=self.max_pixels,
        )

        # grid_h = resized_height // self.patch_size // self.merge_size
        grid_w_m = resized_width // self.patch_size // self.merge_size

        mask = torch.zeros(resized_height, resized_width)
        for index in indices:
            index_h = index // grid_w_m
            index_w = index % grid_w_m
            bbox = (
                max(index_w * self.patch_size * self.merge_size, 0),
                max(index_h * self.patch_size * self.merge_size, 0),
                min((index_w + 1) * self.patch_size * self.merge_size, resized_width),
                min((index_h + 1) * self.patch_size * self.merge_size, resized_height),
            )
            x1, y1, x2, y2 = bbox
            mask[y1:y2, x1:x2] = 1
        # mask = mask.t()  # to width, height
        return mask, (resized_width, resized_height)

    def get_patch_pointers(
        self, image_size: Tuple[int, int], region: Union[BBOX, np.ndarray]
    ):
        if isinstance(region, np.ndarray):
            return self.get_mask_patch_pointers(image_size, region)
        else:
            return self.get_bbox_patch_pointers(image_size, region)

    def get_bbox_patch_pointers(self, image_size: Tuple[int, int], bbox: BBOX):
        factor = self.merge_size
        # factor = 1
        width, height = image_size
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=self.patch_size * self.merge_size,
            min_pixels=self.min_pixels,
            max_pixels=self.max_pixels,
        )
        x0, y0, x1, y1 = bbox
        resized_bbox = [
            max(x0 / width * resized_width, 0),
            max(y0 / height * resized_height, 0),
            min(x1 / width * resized_width, resized_width),
            min(y1 / height * resized_height, resized_height),
        ]
        # patch_bbox = [v / self.patch_size / self.merge_size for v in resized_bbox]
        patch_bbox = [v / self.patch_size / factor for v in resized_bbox]
        x0, y0, x1, y1 = patch_bbox
        boundaries = [
            math.floor(x0),
            math.floor(y0),
            math.ceil(x1),
            math.ceil(y1),
        ]
        x0, y0, x1, y1 = boundaries

        # t, h, w
        grid_w = resized_width // self.patch_size
        grid_w_m = grid_w // factor
        rows, cols = np.meshgrid(np.arange(y0, y1), np.arange(x0, x1), indexing="ij")
        grid_indices = np.column_stack((rows.ravel(), cols.ravel()))
        indices = grid_indices[:, 0] * grid_w_m + grid_indices[:, 1]
        base_ids = list(indices)
        # reorder
        # t, hl, wl, hm, wm
        # ids_map = torch.arange(grid_h * grid_w).reshape(grid_h, grid_w)
        # ids_map = rearrange(
        #     ids_map,
        #     "(hl hm) (wl wm) -> (hl wl) (hm wm)",
        #     hm=self.merge_size,
        #     wm=self.merge_size,
        # ).reshape(-1)
        # inv_map = ids_map.argsort()
        # ids = inv_map[base_ids].numpy()
        ids = np.array(base_ids)
        # ids.sort()
        return ids

    def get_mask_patch_pointers(self, image_size: Tuple[int, int], mask: np.ndarray):
        # mask size: w h
        width, height = image_size
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=self.patch_size * self.merge_size,
            min_pixels=self.min_pixels,
            max_pixels=self.max_pixels,
        )
        grid_w_m = resized_width // self.patch_size // self.merge_size
        grid_h_m = resized_height // self.patch_size // self.merge_size

        m = torch.from_numpy(mask).float()
        m = F.interpolate(
            m[None, None], (grid_h_m, grid_w_m), mode="bilinear", antialias="bilinear"
        )[0, 0]
        # m = m > 0  # upper bound

        grid_indices = m.nonzero(as_tuple=False)
        indices = grid_indices[:, 0] * grid_w_m + grid_indices[:, 1]
        ids = indices.numpy()
        return ids

    def renormalize(self, tensor):
        # crude - non-accurate implementation for the lazy
        mean = np.array(self.image_mean).mean()
        std = np.array(self.image_std).mean()
        return tensor * std + mean


class Qwen2VLImagePointerProcessor(Qwen2VLImageProcessor, PointerProcessor):
    pass


class Qwen2_5_VLPointerProcessor(Qwen2_5_VLProcessor):
    image_processor_class = "Qwen2VLImagePointerProcessor"

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        chat_template=None,
        prepend_raw_region_to_text: bool = True,
        **kwargs,
    ):
        super().__init__(
            image_processor=image_processor,
            tokenizer=tokenizer,
            chat_template=chat_template,
            **kwargs,
        )

        self.region_token = "<|region|>"
        self.copy_token_start = None
        self.prepend_raw_region_to_text = prepend_raw_region_to_text

    def extract_masks(self, image_size: Tuple[int, int], text: str):
        # first, gather region indices from text
        region_pattern = re.compile(r"<region>(.*?)</region>")
        regions = region_pattern.findall(text)

        indices = []
        copy_pattern = re.compile(r"<\|copy_(\d+)\|>")

        for region in regions:
            # Extract all numbers inside <|copy_X|> tags within the region
            numbers = [int(match) for match in copy_pattern.findall(region)]
            indices.append(numbers)

        # Then, convert region indices into masks
        masks = []
        resized_image_size = image_size
        for region in indices:
            mask, resized_image_size = self.image_processor.get_masks(
                image_size, region
            )
            masks.append(mask)
        return masks, resized_image_size

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[
            TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
        ] = None,
        videos: VideoInput = None,
        regions: Optional[List[Union[BBOX, np.ndarray]]] = None,
        **kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to
        Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
                tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
            regions:
                either bboxes: List[Tuple[int, int, int, int]]
                or masks: List[np.ndarray[width, height]]
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
            - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
            - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
        """

        output_kwargs = self._merge_kwargs(
            Qwen2_5_VLProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        obj_ptrs = None
        if images is not None:
            image_inputs = self.image_processor(
                images=images, videos=None, **output_kwargs["images_kwargs"]
            )
            image_grid_thw = image_inputs["image_grid_thw"]

            for image in images:
                assert isinstance(
                    image, Image.Image
                ), "only supporting a single image per row for now"

            if regions is not None:
                obj_ptrs = [
                    [
                        (
                            self.image_processor.get_patch_pointers(image.size, region)
                            if region is not None
                            else np.array([])
                        )
                        for region in image_region
                    ]
                    for image, image_region in zip(images, regions)
                ]
        else:
            image_inputs = {}
            image_grid_thw = None

        assert videos is None, "video inputs are not supported yet"  # TODO
        if videos is not None:
            videos_inputs = self.image_processor(
                images=None, videos=videos, **output_kwargs["images_kwargs"]
            )
            video_grid_thw = videos_inputs["video_grid_thw"]

            fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
            if isinstance(fps, (int, float)):
                second_per_grid_ts = [
                    self.image_processor.temporal_patch_size / fps
                ] * len(video_grid_thw)
            elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
                second_per_grid_ts = [
                    self.image_processor.temporal_patch_size / tmp for tmp in fps
                ]
            else:
                raise ValueError(
                    f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
                )
            videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})

        else:
            videos_inputs = {}
            video_grid_thw = None

        if not isinstance(text, list):
            text = [text]

        if image_grid_thw is not None:
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.image_token in text[i]:
                    text[i] = text[i].replace(
                        self.image_token,
                        "<|placeholder|>"
                        * (image_grid_thw[index].prod() // merge_length),
                        1,
                    )
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.image_token)

        if obj_ptrs is not None:
            assert regions is not None
            for i in range(len(text)):
                ptrs = obj_ptrs[i]
                region = regions[i]
                assert len(ptrs) == text[i].count(self.region_token)
                index = 0
                while self.region_token in text[i]:
                    ptrs_str = "".join([f"<|copy_{j}|>" for j in ptrs[index]])
                    region_str = self.image_processor.normalize_bbox(
                        image.size, region[index]
                    )
                    out_str = "<region>" + ptrs_str + "</region>"
                    if self.prepend_raw_region_to_text:
                        out_str = "<region>" + region_str + ptrs_str + "</region>"

                    text[i] = text[i].replace(
                        self.region_token,
                        out_str,
                        1,
                    )
                    index += 1

            # text[i] = text[i].replace("<|placeholder|>", self.region_token)

        if video_grid_thw is not None:
            # TODO: support video inputs
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.video_token in text[i]:
                    text[i] = text[i].replace(
                        self.video_token,
                        "<patch>"
                        + "<|placeholder|>"
                        * (video_grid_thw[index].prod() // merge_length)
                        + "</patch>",
                        1,
                    )
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.video_token)

        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])

        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})


def get_processor(model_name: str, **kwargs):
    processor = Qwen2_5_VLPointerProcessor.from_pretrained(model_name, **kwargs)
    processor.image_processor = Qwen2VLImagePointerProcessor.from_pretrained(
        model_name, **kwargs
    )
    # max_position_tokens = processor.tokenizer.model_max_length
    # new_tokens = [f"<|copy_{i}|>" for i in range(max_position_tokens)]  # too slow
    processor.tokenizer.orig_vocab_size = len(processor.tokenizer)
    new_tokens = [f"<|copy_{i}|>" for i in range(30000)]
    processor.tokenizer.add_tokens(new_tokens)
    processor.copy_token_start = processor.tokenizer.convert_tokens_to_ids("<|copy_0|>")
    return processor


# Create a data collator to encode text and image pairs
def collate_fn(examples, processor):
    # Get the texts and images, and apply the chat template
    examples, masks = zip(*examples)
    texts = [
        processor.apply_chat_template(example, tokenize=False) for example in examples
    ]  # Prepare texts for processing
    image_inputs = [
        process_vision_info(example)[0][0] for example in examples
    ]  # Process the images to extract inputs

    # Tokenize the texts and process the images
    batch = processor(
        text=texts,
        images=image_inputs,
        videos=None,
        regions=masks,
        padding=True,
        return_tensors="pt",
    )  # Encode texts and images into tensors

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()  # Clone input IDs for labels
    labels[labels == processor.tokenizer.pad_token_id] = (
        -100
    )  # Mask padding tokens in labels

    # Ignore the image token index in the loss computation (model specific)
    if isinstance(
        processor, Qwen2VLImagePointerProcessor
    ):  # Check if the processor is Qwen2VLProcessor
        image_tokens = [
            151652,
            151653,
            151655,
        ]  # Specific image token IDs for Qwen2VLProcessor
    else:
        image_tokens = [
            processor.tokenizer.convert_tokens_to_ids(processor.image_token)
        ]  # Convert image token to ID

    # Mask image token IDs in the labels
    for image_token_id in image_tokens:
        labels[labels == image_token_id] = -100  # Mask image token IDs in labels

    batch["labels"] = labels  # Add labels to the batch

    return batch  # Return the prepared batch


if __name__ == "__main__":
    # processor = Qwen2VLImagePointerProcessor.from_pretrained(
    #     "Qwen/Qwen2.5-VL-7B-Instruct"
    # )

    # image_size = [1036, 756]
    # regions = [[0, 20, 25, 120], [512, 600, 800, 800], [0, 0, 1023, 740]]
    # processor.test(image_size, regions)

    model_name = "Qwen/Qwen2.5-VL-7B-Instruct"
    processor = get_processor(model_name)

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": "https://example---/demo.jpeg",
                },
                {"type": "text", "text": "Describe this image."},
            ],
        },
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "<think>Theres a cat at <|region|>, a dog at <|region|>.</think>A calico cat hanging out with a golden retriever.",
                }
            ],
        },
    ]
    image = Image.new("RGB", (800, 500), "black")
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    bboxes = [[0, 10, 100, 200], [300, 0, 600, 250]]
    inputs = processor(
        text=[text],
        images=[image],
        videos=None,
        regions=[bboxes],
        padding=True,
        return_tensors="pt",
    )
    text = processor.tokenizer.decode(inputs.input_ids[0])
    print(text)
    masks, image_size = processor.extract_masks(image.size, text)
    import ipdb; ipdb.set_trace()  # noqa # fmt: skip
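To make the pointer convention concrete: each <|copy_i|> index produced by get_bbox_patch_pointers is a row-major position in the merged-patch grid, where grid_w_m = resized_width // patch_size // merge_size. The sketch below reproduces that index math with illustrative numbers; the resized dimensions are assumptions, not values computed by smart_resize:

    import numpy as np

    # Assumed (illustrative) values for a 28x28-aligned resize of the input image.
    patch_size, merge_size = 14, 2
    resized_width, resized_height = 812, 504               # assumption, not a smart_resize output
    grid_w_m = resized_width // patch_size // merge_size   # 29 merged patches per row

    # A bbox in resized-image pixels, converted to merged-patch boundaries.
    x0, y0, x1, y1 = 0, 0, 60, 40
    px0, py0 = x0 // (patch_size * merge_size), y0 // (patch_size * merge_size)
    px1 = -(-x1 // (patch_size * merge_size))              # ceil division
    py1 = -(-y1 // (patch_size * merge_size))

    rows, cols = np.meshgrid(np.arange(py0, py1), np.arange(px0, px1), indexing="ij")
    indices = rows.ravel() * grid_w_m + cols.ravel()       # row-major merged-patch ids
    print(indices)  # [ 0  1  2 29 30 31] -> rendered as <|copy_0|><|copy_1|>... in the text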
modeling_v1.py
ADDED
@@ -0,0 +1,577 @@
import math
from typing import Optional, Union, Tuple, List
from dataclasses import dataclass

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VisionTransformerPretrainedModel,
    Qwen2_5_VLModel,
    Qwen2_5_VLForConditionalGeneration,
    Qwen2_5_VLCausalLMOutputWithPast,
)

from .configuration_v1 import V1Config


def init_identity(layer, scale: float = 1):
    if isinstance(layer, nn.Linear):
        with torch.no_grad():
            # Ensure weight matrix is square
            rows, cols = layer.weight.shape
            identity_matrix = (
                torch.eye(rows, cols) * scale
            )  # Creates an identity matrix
            layer.weight.copy_(
                identity_matrix
            )  # Copy identity matrix into layer weights
            if hasattr(layer, "bias"):
                layer.bias.fill_(0)  # Set bias to zero (or another value if needed)


@dataclass
class V1CausalLMOutputWithPast(Qwen2_5_VLCausalLMOutputWithPast):
    z_loss: torch.Tensor = None
    gen_loss: torch.Tensor = None
    copy_loss: torch.Tensor = None


class V1ForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
    config_class = V1Config

    def __init__(self, config):
        super().__init__(config)
        self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(
            config.vision_config
        )
        self.model = Qwen2_5_VLModel(config)
        self.copy_init_scale = 1 / math.sqrt(self.config.hidden_size)

        # self.tokenizer_vocab_size = (
        #     config.tokenizer_vocab_size
        # )  # Qwen2.5-VL: different from embedding_size==vocab_size. 151665 vs. 152064
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.rope_deltas = None  # cache rope_deltas here

        if self.config.do_copy:
            if self.config.tie_copy_heads:
                self._copy_head = nn.Linear(config.hidden_size, config.copy_hidden_size)
            else:
                self._copy_q_head = nn.Linear(
                    config.hidden_size, config.copy_hidden_size
                )
                self._copy_k_head = nn.Linear(
                    config.hidden_size, config.copy_hidden_size
                )
            if self.config.use_gate:
                self.gate = nn.Linear(config.hidden_size, 1, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @torch.no_grad()
    def after_loading(self):
        if self.config.do_copy:
            self.init_heads()
        if self.config.use_gate:
            self.lm_head.weight.data = self.lm_head.weight.data * 2
            self.gate.weight.data.fill_(0)

    @property
    def copy_q_head(self):
        return self._copy_head if self.config.tie_copy_heads else self._copy_q_head

    @property
    def copy_k_head(self):
        return self._copy_head if self.config.tie_copy_heads else self._copy_k_head

    def init_heads(self):
        if hasattr(self, "_copy_head"):
            init_identity(self._copy_head, self.copy_init_scale)
        if hasattr(self, "_copy_k_head"):
            init_identity(self._copy_k_head, self.copy_init_scale)
        if hasattr(self, "_copy_q_head"):
            init_identity(self._copy_q_head, self.copy_init_scale)

    def copy_representations(
        self,
        inputs_embeds: torch.FloatTensor,
        input_ids: torch.LongTensor,
        copy_values: Optional[torch.FloatTensor] = None,
    ):
        if copy_values is None:
            mask = input_ids == self.config.image_token_id
            copy_values, _ = self.extract_image_tokens(inputs_embeds, mask)  # initial
        assert copy_values is not None
        copy_values = copy_values.to(inputs_embeds.device)
        input_ids = input_ids.to(inputs_embeds.device)

        input_ids = input_ids.clone()
        input_ids = input_ids - self.config.copy_token_start
        copy_mask = input_ids >= 0
        input_ids[~copy_mask] = 0

        assert copy_values is not None
        extracted = copy_values.gather(
            1, input_ids[..., None].repeat(1, 1, copy_values.shape[-1])
        )
        copy_mask = copy_mask.to(extracted.dtype)[..., None]
        return copy_mask * extracted + (1 - copy_mask) * inputs_embeds

    def extract_image_tokens(self, features: torch.FloatTensor, mask: torch.Tensor):
        out_feat, out_mask = extract_image_tokens_right_pad(features, mask)
        return out_feat, out_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

        >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
        >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

        >>> messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            },
        ]
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
        ```"""

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        input_ids = input_ids.clone()
        input_ids_with_ptrs = input_ids.clone()
        input_ids[input_ids >= self.config.copy_token_start] = (
            self.config.region_token_id
        )

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.visual.dtype)
                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)

                mask = input_ids == self.config.image_token_id
                mask_unsqueezed = mask.unsqueeze(-1)
                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                image_mask = mask_expanded.to(inputs_embeds.device)

                image_embeds = image_embeds.to(
                    inputs_embeds.device, inputs_embeds.dtype
                )
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            if pixel_values_videos is not None:
                raise NotImplementedError("video inputs are not supported yet.")
                pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
                n_video_features = video_embeds.shape[0]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )

                mask = input_ids == self.config.video_token_id
                mask_unsqueezed = mask.unsqueeze(-1)
                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                video_mask = mask_expanded.to(inputs_embeds.device)

                video_embeds = video_embeds.to(
                    inputs_embeds.device, inputs_embeds.dtype
                )
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        if self.config.do_copy:
            copy_keys, copy_keys_mask = None, None
            copy_values, copy_values_mask = None, None

            has_cache = bool(past_key_values)
            if has_cache:
                copy_keys, copy_values = past_key_values[len(past_key_values) - 2]
                copy_keys_mask, copy_values_mask = past_key_values[
                    len(past_key_values) - 1
                ]
                # we add channel dim to the mask for consistency in tensor shape in cache
                copy_keys_mask = copy_keys_mask[..., 0]
                copy_values_mask = copy_values_mask[..., 0]

            inputs_embeds = self.copy_representations(
                inputs_embeds, input_ids_with_ptrs, copy_values
            )

        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
        if position_ids is None and (
            attention_mask is None or attention_mask.ndim == 2
        ):
            # calculate RoPE index once per generation in the pre-fill stage only
            if (
                (cache_position is not None and cache_position[0] == 0)
                or self.rope_deltas is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts,
                    attention_mask,
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]

        gen_logits = self.lm_head(hidden_states)

        if self.config.do_copy:
            assert (
                self.config.copy_extraction_layer == -1
            ), f"copy_extraction_layer should be -1: {self.config.copy_extraction_layer}"
            copy_hidden_states = hidden_states
            copy_q_states = copy_hidden_states
            if self.config.normalize_copy_states:
                copy_q_states = F.normalize(copy_q_states, 2, -1)
            copy_q_states = self.copy_q_head(copy_q_states)

            present_key_values = outputs.past_key_values

            if not has_cache:
                mask = input_ids == self.config.image_token_id
                copy_k_states = (
                    inputs_embeds
                    if self.config.use_embeddings_as_keys
                    else copy_hidden_states
                )
                if self.config.normalize_copy_states:
                    copy_k_states = F.normalize(copy_k_states, 2, -1)
                copy_k_states, copy_k_mask = self.extract_image_tokens(
                    self.copy_k_head(copy_k_states), mask
                )
                copy_v_states, copy_v_mask = self.extract_image_tokens(
                    inputs_embeds.detach(), mask
                )

                # we add channel dim to the mask for consistency in tensor shape in cache
                copy_memories = [
                    (copy_k_states.detach(), copy_v_states.detach()),
                    (copy_k_mask[..., None], copy_v_mask[..., None]),
                ]

                if use_cache:
                    # only update at the first iteration
                    start = len(present_key_values)
                    for i, mem in enumerate(copy_memories):
                        present_key_values.update(*mem, start + i)
            else:
                copy_k_states = copy_keys
                copy_k_mask = copy_keys_mask

            assert copy_k_states is not None
            assert copy_k_mask is not None
            assert (
                copy_k_states.shape[1] > 0
            ), f"zero image tokens on batch elements: {copy_k_mask.sum(dim=1)}"

            copy_logits = (copy_q_states @ copy_k_states.transpose(-1, -2)).to(
                gen_logits.device
            ) * self.copy_init_scale

            if hasattr(self, "gate"):
                gate = torch.sigmoid(self.gate(hidden_states))
                gen_logits = gen_logits * (1 - gate)
                copy_logits = copy_logits * gate

            copy_logits = copy_logits.masked_fill(
                ~copy_k_mask[:, None, :].to(copy_logits.device),
                torch.finfo(copy_logits.dtype).min,
            )
            logits = torch.cat(
                [gen_logits[..., : self.config.copy_token_start], copy_logits], dim=-1
            )
        else:
            logits = gen_logits
            loss = None
            z_loss = None
            gen_loss = None
            if labels is not None:
                gen_logits = gen_logits.float()
                shift_gen_logits = gen_logits[:, :-1, :].contiguous().float()
                shift_labels = labels[:, 1:].contiguous()
                gen_loss_fct = CrossEntropyLoss(reduction="none")
                gen_logits_flat = shift_gen_logits.view(-1, shift_gen_logits.shape[-1])
                gen_labels_flat = shift_labels.view(-1)

                gen_loss_all = gen_loss_fct(gen_logits_flat, gen_labels_flat)
                gen_loss = gen_loss_all.mean()

                loss = gen_loss

                if self.config.z_loss_weight > 0:
                    valid_mask = shift_labels >= 0
                    # top-k approx z_loss for better memory usage
                    top_logits, _ = torch.topk(
                        shift_gen_logits, k=self.config.z_loss_top_k, dim=-1
                    )
                    lse = torch.logsumexp(top_logits, dim=-1)
                    z_loss = lse[valid_mask].pow(2).mean() * self.config.z_loss_weight

                    # z_loss = (
                    #     torch.logsumexp(shift_logits, dim=-1).pow(2)[valid_mask].mean()
                    #     * self.config.z_loss_weight
                    # )
                    loss = loss + z_loss
                    z_loss = z_loss.detach()

            return V1CausalLMOutputWithPast(
                loss=loss,
                z_loss=z_loss,
                gen_loss=gen_loss,
                copy_loss=None,
                logits=logits,
                # copy_logits=copy_logits,
                # gen_logits=gen_logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                rope_deltas=self.rope_deltas,
            )

        loss = None
        z_loss = None
        gen_loss = None
        copy_loss = None
        if labels is not None:
            if self.config.separate_copy_loss:
                # Shift labels and logits for next-token prediction
                shift_gen_logits = gen_logits[:, :-1, :].contiguous().float()
shift_copy_logits = copy_logits[:, :-1, :].contiguous().float()
|
436 |
+
shift_labels = labels[:, 1:].contiguous()
|
437 |
+
shift_logits = shift_copy_logits
|
438 |
+
|
439 |
+
# Build masks
|
440 |
+
gen_mask = shift_labels < self.config.copy_token_start
|
441 |
+
copy_mask = shift_labels >= self.config.copy_token_start
|
442 |
+
|
443 |
+
# Generation loss
|
444 |
+
if gen_mask.any():
|
445 |
+
gen_loss_fct = CrossEntropyLoss(reduction="none")
|
446 |
+
|
447 |
+
G = shift_gen_logits.shape[-1]
|
448 |
+
gen_logits_flat = shift_gen_logits.view(-1, G)
|
449 |
+
gen_labels_flat = shift_labels.view(-1)
|
450 |
+
gen_mask_flat = gen_mask.view(-1)
|
451 |
+
# mask logits
|
452 |
+
gen_logits_flat_masked = gen_logits_flat[gen_mask_flat]
|
453 |
+
gen_labels_flat_masked = gen_labels_flat[gen_mask_flat]
|
454 |
+
|
455 |
+
gen_loss_all = gen_loss_fct(
|
456 |
+
gen_logits_flat_masked, gen_labels_flat_masked
|
457 |
+
)
|
458 |
+
gen_loss = gen_loss_all.mean()
|
459 |
+
|
460 |
+
# Copy loss (adjust label indices to match copy_logits range)
|
461 |
+
if copy_mask.any():
|
462 |
+
copy_loss_fct = CrossEntropyLoss(reduction="none")
|
463 |
+
C = shift_copy_logits.shape[-1]
|
464 |
+
copy_logits_flat = shift_copy_logits.view(-1, C)
|
465 |
+
copy_labels_flat = (
|
466 |
+
shift_labels.view(-1) - self.config.copy_token_start
|
467 |
+
)
|
468 |
+
copy_mask_flat = copy_mask.view(-1)
|
469 |
+
copy_logits_flat_masked = copy_logits_flat[copy_mask_flat]
|
470 |
+
copy_labels_flat_masked = copy_labels_flat[copy_mask_flat]
|
471 |
+
copy_loss_all = copy_loss_fct(
|
472 |
+
copy_logits_flat_masked, copy_labels_flat_masked
|
473 |
+
)
|
474 |
+
copy_loss = copy_loss_all.mean()
|
475 |
+
else:
|
476 |
+
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
477 |
+
logits = logits.float()
|
478 |
+
# Shift so that tokens < n predict n
|
479 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
480 |
+
shift_labels = labels[..., 1:].contiguous()
|
481 |
+
# Flatten the tokens
|
482 |
+
loss_fct = CrossEntropyLoss(label_smoothing=self.config.label_smoothing)
|
483 |
+
total_vocab_size = logits.shape[-1] # gen + copy
|
484 |
+
shift_logits = shift_logits.view(-1, total_vocab_size)
|
485 |
+
shift_labels = shift_labels.view(-1)
|
486 |
+
# Enable model parallelism
|
487 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
488 |
+
gen_loss = loss_fct(shift_logits, shift_labels)
|
489 |
+
|
490 |
+
loss = 0.0
|
491 |
+
if gen_loss is not None:
|
492 |
+
loss += gen_loss
|
493 |
+
if copy_loss is not None:
|
494 |
+
loss += copy_loss
|
495 |
+
|
496 |
+
if self.config.z_loss_weight > 0:
|
497 |
+
valid_mask = shift_labels >= 0
|
498 |
+
# top-k approx z_loss for better memory usage
|
499 |
+
top_logits, _ = torch.topk(
|
500 |
+
shift_logits, k=self.config.z_loss_top_k, dim=-1
|
501 |
+
)
|
502 |
+
lse = torch.logsumexp(top_logits, dim=-1)
|
503 |
+
z_loss = lse[valid_mask].pow(2).mean() * self.config.z_loss_weight
|
504 |
+
|
505 |
+
# z_loss = (
|
506 |
+
# torch.logsumexp(shift_logits, dim=-1).pow(2)[valid_mask].mean()
|
507 |
+
# * self.config.z_loss_weight
|
508 |
+
# )
|
509 |
+
loss = loss + z_loss
|
510 |
+
z_loss = z_loss.detach()
|
511 |
+
|
512 |
+
if gen_loss is not None:
|
513 |
+
gen_loss = gen_loss.detach()
|
514 |
+
if copy_loss is not None:
|
515 |
+
copy_loss = copy_loss.detach()
|
516 |
+
|
517 |
+
if self.config.use_cfg:
|
518 |
+
# expand as max_size for logit processors
|
519 |
+
extended_vocab_size = self.config.vocab_size + self.config.copy_token_num
|
520 |
+
B, L, V = logits.shape
|
521 |
+
pads = torch.full(
|
522 |
+
(B, L, extended_vocab_size - V),
|
523 |
+
torch.finfo(gen_logits.dtype).min,
|
524 |
+
device=logits.device,
|
525 |
+
).to(logits.dtype)
|
526 |
+
logits = torch.cat([logits, pads], dim=-1)
|
527 |
+
# logits = logits.clamp_min(-1e4)
|
528 |
+
|
529 |
+
if not return_dict:
|
530 |
+
output = (logits,) + outputs[1:]
|
531 |
+
return (loss,) + output if loss is not None else output
|
532 |
+
|
533 |
+
logits = logits.float()
|
534 |
+
return V1CausalLMOutputWithPast(
|
535 |
+
loss=loss,
|
536 |
+
z_loss=z_loss,
|
537 |
+
gen_loss=gen_loss,
|
538 |
+
copy_loss=copy_loss,
|
539 |
+
logits=logits,
|
540 |
+
# copy_logits=copy_logits,
|
541 |
+
# gen_logits=gen_logits,
|
542 |
+
past_key_values=present_key_values,
|
543 |
+
hidden_states=outputs.hidden_states,
|
544 |
+
attentions=outputs.attentions,
|
545 |
+
rope_deltas=self.rope_deltas,
|
546 |
+
)
|
547 |
+
|
548 |
+
|
549 |
+
def extract_image_tokens_right_pad(features: torch.FloatTensor, mask: torch.Tensor):
|
550 |
+
X, M = features, mask.long() # bool is not supported for sort in CUDA
|
551 |
+
B, L, _ = X.shape
|
552 |
+
device = X.device
|
553 |
+
M = M.to(device)
|
554 |
+
|
555 |
+
# Compute number of valid elements per batch
|
556 |
+
valid_counts = M.sum(dim=1) # Shape: [B]
|
557 |
+
# Replace `.item()` with `max()` and `clamp_min()` for Torch Dynamo compatibility
|
558 |
+
R = valid_counts.max().clamp_min(1) # Ensures at least 1 for tensor compatibility
|
559 |
+
# Create index tensors for selection
|
560 |
+
sorted_indices = M.argsort(dim=1, descending=True) # Move True values to front
|
561 |
+
batch_indices = torch.arange(B, device=device).unsqueeze(1).expand(B, L)
|
562 |
+
|
563 |
+
# Gather sorted X based on mask sorting
|
564 |
+
X_sorted = X[batch_indices, sorted_indices] # Shape: [B, L, C]
|
565 |
+
X_selected = X_sorted[:, :R, :] # Select the top valid elements per batch
|
566 |
+
|
567 |
+
# Create new mask M2 using `torch.arange`
|
568 |
+
M2 = torch.arange(L, device=device).expand(B, L) < valid_counts.unsqueeze(1)
|
569 |
+
M2 = M2[:, :R] # Trim to selected size
|
570 |
+
|
571 |
+
# Set out-of-bound values to zero
|
572 |
+
X_selected = torch.where(M2.unsqueeze(-1), X_selected, torch.zeros_like(X_selected))
|
573 |
+
|
574 |
+
return X_selected, M2
|
575 |
+
|
576 |
+
|
577 |
+
__all__ = ["V1ForConditionalGeneration"]
|
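For reference, a minimal sketch (not part of the diff above; dummy tensors, and the copy_token_start value is assumed) of how a token id sampled from the concatenated [generation | copy] logits built in the forward pass maps back to either a regular vocabulary token or an image-patch pointer:

import torch

copy_token_start = 151665                         # assumed value, for illustration only
gen_logits = torch.randn(1, 1, copy_token_start)  # [batch, seq, gen vocab truncated at copy_token_start]
copy_logits = torch.randn(1, 1, 24)               # [batch, seq, number of cached image patches]

# mirrors torch.cat([gen_logits[..., :copy_token_start], copy_logits], dim=-1) above
logits = torch.cat([gen_logits, copy_logits], dim=-1)
token_id = logits[0, -1].argmax().item()

if token_id >= copy_token_start:
    patch_index = token_id - copy_token_start     # pointer into the image patches (a <|copy_i|> token)
    print(f"copy token -> image patch {patch_index}")
else:
    print(f"regular vocabulary token {token_id}")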
preprocessor_config.json
CHANGED
@@ -18,7 +18,7 @@
|
|
18 |
"merge_size": 2,
|
19 |
"min_pixels": 3136,
|
20 |
"patch_size": 14,
|
21 |
-
"processor_class": "
|
22 |
"resample": 3,
|
23 |
"rescale_factor": 0.00392156862745098,
|
24 |
"size": {
|
|
|
18 |
"merge_size": 2,
|
19 |
"min_pixels": 3136,
|
20 |
"patch_size": 14,
|
21 |
+
"processor_class": "V1Processor",
|
22 |
"resample": 3,
|
23 |
"rescale_factor": 0.00392156862745098,
|
24 |
"size": {
|
processor.py
ADDED
@@ -0,0 +1,536 @@
|
1 |
+
from typing import Tuple, List, Optional, Union
|
2 |
+
import re
|
3 |
+
import math
|
4 |
+
from collections import defaultdict
|
5 |
+
|
6 |
+
from PIL import Image
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from qwen_vl_utils import process_vision_info
|
11 |
+
from transformers.feature_extraction_utils import BatchFeature
|
12 |
+
from transformers.image_utils import ImageInput, VideoInput
|
13 |
+
from transformers.processing_utils import (
|
14 |
+
Unpack,
|
15 |
+
)
|
16 |
+
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
17 |
+
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
|
18 |
+
smart_resize,
|
19 |
+
Qwen2VLImageProcessor,
|
20 |
+
)
|
21 |
+
from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import (
|
22 |
+
Qwen2_5_VLProcessorKwargs,
|
23 |
+
Qwen2_5_VLProcessor,
|
24 |
+
)
|
25 |
+
|
26 |
+
"""
|
27 |
+
Qwen2.5-VL does not use AnyRes, to my relief.
|
28 |
+
Things to take into account:
|
29 |
+
- smart_resize
|
30 |
+
- temporal dimension
|
31 |
+
- grid_t = patches.shape[0] // self.temporal_patch_size
|
32 |
+
- grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
|
33 |
+
- merge_size (2)
|
34 |
+
|
35 |
+
|
36 |
+
Usage:
|
37 |
+
|
38 |
+
model_name = "Qwen/Qwen2.5-VL-7B-Instruct"
|
39 |
+
|
40 |
+
|
41 |
+
processor = Qwen2_5_VLPointerProcessor.from_pretrained(model_name)
|
42 |
+
processor.image_processor = Qwen2VLImagePointerProcessor.from_pretrained(model_name)
|
43 |
+
|
44 |
+
messages = [
|
45 |
+
{
|
46 |
+
"role": "user",
|
47 |
+
"content": [
|
48 |
+
{
|
49 |
+
"type": "image",
|
50 |
+
"image": "https://example---/demo.jpeg",
|
51 |
+
},
|
52 |
+
{"type": "text", "text": "Describe this image."},
|
53 |
+
],
|
54 |
+
},
|
55 |
+
{
|
56 |
+
'role': 'assistant',
|
57 |
+
'content': [
|
58 |
+
{
|
59 |
+
'type': 'text', 'text': '<think>There is a cat at <|region|>, a dog at <|region|>.</think>A calico cat hanging out with a golden retriever.'
|
60 |
+
}
|
61 |
+
]
|
62 |
+
}
|
63 |
+
]
|
64 |
+
|
65 |
+
# Preparation for inference
|
66 |
+
text = processor.apply_chat_template(
|
67 |
+
messages, tokenize=False, add_generation_prompt=True
|
68 |
+
)
|
69 |
+
regions = [
|
70 |
+
[0, 10, 100, 200],
|
71 |
+
[300, 0, 600, 250]
|
72 |
+
]
|
73 |
+
image_inputs, video_inputs = process_vision_info(messages)
|
74 |
+
inputs = processor(
|
75 |
+
text=[text],
|
76 |
+
images=image_inputs,
|
77 |
+
videos=video_inputs,
|
78 |
+
regions=[regions],
|
79 |
+
padding=True,
|
80 |
+
return_tensors="pt",
|
81 |
+
)
|
82 |
+
inputs = inputs.to("cuda")
|
83 |
+
|
84 |
+
|
85 |
+
# Qwen2VLImageProcessor in a nutshell
|
86 |
+
'(tl tp) c (hlm hm hp) (wlm wm wp) -> (tl hlm wlm hm wm) (c tp hp wp)'
|
87 |
+
"""
|
88 |
+
|
89 |
+
|
90 |
+
BBOX = Tuple[int, int, int, int]
|
91 |
+
|
92 |
+
|
93 |
+
class PointerProcessor:
|
94 |
+
@staticmethod
|
95 |
+
def normalize_bbox(image_size: Tuple[int, int], bbox: BBOX):
|
96 |
+
w, h = image_size
|
97 |
+
bbox = [
|
98 |
+
bbox[0] / w,
|
99 |
+
bbox[1] / h,
|
100 |
+
bbox[2] / w,
|
101 |
+
bbox[3] / h,
|
102 |
+
]
|
103 |
+
return "[{}]".format(", ".join([f"{v:.2f}" for v in bbox]))
|
104 |
+
|
105 |
+
def get_mask(self, image_size: Tuple[int, int], indices: List[int]):
|
106 |
+
width, height = image_size
|
107 |
+
resized_height, resized_width = smart_resize(
|
108 |
+
height,
|
109 |
+
width,
|
110 |
+
factor=self.patch_size * self.merge_size,
|
111 |
+
min_pixels=self.min_pixels,
|
112 |
+
max_pixels=self.max_pixels,
|
113 |
+
)
|
114 |
+
|
115 |
+
# grid_h = resized_height // self.patch_size // self.merge_size
|
116 |
+
grid_w_m = resized_width // self.patch_size // self.merge_size
|
117 |
+
|
118 |
+
mask = torch.zeros(resized_height, resized_width)
|
119 |
+
for index in indices:
|
120 |
+
index_h = index // grid_w_m
|
121 |
+
index_w = index % grid_w_m
|
122 |
+
bbox = (
|
123 |
+
max(index_w * self.patch_size * self.merge_size, 0),
|
124 |
+
max(index_h * self.patch_size * self.merge_size, 0),
|
125 |
+
min((index_w + 1) * self.patch_size * self.merge_size, resized_width),
|
126 |
+
min((index_h + 1) * self.patch_size * self.merge_size, resized_height),
|
127 |
+
)
|
128 |
+
x1, y1, x2, y2 = bbox
|
129 |
+
mask[y1:y2, x1:x2] = 1
|
130 |
+
# mask = mask.t() # to width, height
|
131 |
+
return mask, (resized_width, resized_height)
|
132 |
+
|
133 |
+
def get_patch_pointers(
|
134 |
+
self, image_size: Tuple[int, int], region: Union[BBOX, np.ndarray]
|
135 |
+
):
|
136 |
+
if isinstance(region, np.ndarray):
|
137 |
+
return self.get_mask_patch_pointers(image_size, region)
|
138 |
+
else:
|
139 |
+
return self.get_bbox_patch_pointers(image_size, region)
|
140 |
+
|
141 |
+
def get_bbox_patch_pointers(self, image_size: Tuple[int, int], bbox: BBOX):
|
142 |
+
factor = self.merge_size
|
143 |
+
# factor = 1
|
144 |
+
width, height = image_size
|
145 |
+
resized_height, resized_width = smart_resize(
|
146 |
+
height,
|
147 |
+
width,
|
148 |
+
factor=self.patch_size * self.merge_size,
|
149 |
+
min_pixels=self.min_pixels,
|
150 |
+
max_pixels=self.max_pixels,
|
151 |
+
)
|
152 |
+
x0, y0, x1, y1 = bbox
|
153 |
+
resized_bbox = [
|
154 |
+
max(x0 / width * resized_width, 0),
|
155 |
+
max(y0 / height * resized_height, 0),
|
156 |
+
min(x1 / width * resized_width, resized_width),
|
157 |
+
min(y1 / height * resized_height, resized_height),
|
158 |
+
]
|
159 |
+
# patch_bbox = [v / self.patch_size / self.merge_size for v in resized_bbox]
|
160 |
+
patch_bbox = [v / self.patch_size / factor for v in resized_bbox]
|
161 |
+
x0, y0, x1, y1 = patch_bbox
|
162 |
+
boundaries = [
|
163 |
+
math.floor(x0),
|
164 |
+
math.floor(y0),
|
165 |
+
math.ceil(x1),
|
166 |
+
math.ceil(y1),
|
167 |
+
]
|
168 |
+
x0, y0, x1, y1 = boundaries
|
169 |
+
|
170 |
+
# t, h, w
|
171 |
+
grid_w = resized_width // self.patch_size
|
172 |
+
grid_w_m = grid_w // factor
|
173 |
+
rows, cols = np.meshgrid(np.arange(y0, y1), np.arange(x0, x1), indexing="ij")
|
174 |
+
grid_indices = np.column_stack((rows.ravel(), cols.ravel()))
|
175 |
+
indices = grid_indices[:, 0] * grid_w_m + grid_indices[:, 1]
|
176 |
+
base_ids = list(indices)
|
177 |
+
ids = np.array(base_ids)
|
178 |
+
return ids
|
179 |
+
|
180 |
+
def get_mask_patch_pointers(self, image_size: Tuple[int, int], mask: np.ndarray):
|
181 |
+
# mask size: w h
|
182 |
+
width, height = image_size
|
183 |
+
resized_height, resized_width = smart_resize(
|
184 |
+
height,
|
185 |
+
width,
|
186 |
+
factor=self.patch_size * self.merge_size,
|
187 |
+
min_pixels=self.min_pixels,
|
188 |
+
max_pixels=self.max_pixels,
|
189 |
+
)
|
190 |
+
grid_w_m = resized_width // self.patch_size // self.merge_size
|
191 |
+
grid_h_m = resized_height // self.patch_size // self.merge_size
|
192 |
+
|
193 |
+
m = torch.from_numpy(mask).float()
|
194 |
+
m = F.interpolate(
|
195 |
+
m[None, None], (grid_h_m, grid_w_m), mode="bilinear", antialias=True
|
196 |
+
)[0, 0]
|
197 |
+
|
198 |
+
grid_indices = m.nonzero(as_tuple=False)
|
199 |
+
indices = grid_indices[:, 0] * grid_w_m + grid_indices[:, 1]
|
200 |
+
ids = indices.numpy()
|
201 |
+
return ids
|
202 |
+
|
203 |
+
def renormalize(self, tensor):
|
204 |
+
# crude, approximate implementation (averages the per-channel mean/std)
|
205 |
+
mean = np.array(self.image_mean).mean()
|
206 |
+
std = np.array(self.image_std).mean()
|
207 |
+
return tensor * std + mean
|
208 |
+
|
209 |
+
class Qwen2VLImagePointerProcessor(Qwen2VLImageProcessor, PointerProcessor):
|
210 |
+
pass
|
211 |
+
|
212 |
+
|
213 |
+
class V1Processor(Qwen2_5_VLProcessor):
|
214 |
+
image_processor_class = "Qwen2VLImagePointerProcessor"
|
215 |
+
|
216 |
+
def __init__(
|
217 |
+
self,
|
218 |
+
image_processor=None,
|
219 |
+
tokenizer=None,
|
220 |
+
chat_template=None,
|
221 |
+
prepend_raw_region_to_text: bool = True,
|
222 |
+
separate_copy_loss: bool = False,
|
223 |
+
**kwargs,
|
224 |
+
):
|
225 |
+
super().__init__(
|
226 |
+
image_processor=image_processor,
|
227 |
+
tokenizer=tokenizer,
|
228 |
+
chat_template=chat_template,
|
229 |
+
**kwargs,
|
230 |
+
)
|
231 |
+
|
232 |
+
self.region_token = "<|region|>"
|
233 |
+
self.copy_token_start = None
|
234 |
+
self.prepend_raw_region_to_text = prepend_raw_region_to_text
|
235 |
+
self.separate_copy_loss = separate_copy_loss
|
236 |
+
self.copy_start_token = "<|box_start|>"
|
237 |
+
self.copy_end_token = "<|box_end|>"
|
238 |
+
|
239 |
+
# def extract_masks(self, image_size: Tuple[int, int], text: str):
|
240 |
+
# # first, gather region indices from text
|
241 |
+
# region_pattern = re.compile(r"<region>(.*?)</region>")
|
242 |
+
# regions = region_pattern.findall(text)
|
243 |
+
|
244 |
+
# indices = []
|
245 |
+
# copy_pattern = re.compile(r"<\|copy_(\d+)\|>")
|
246 |
+
|
247 |
+
# for region in regions:
|
248 |
+
# # Extract all numbers inside <|copy_X|> tags within the region
|
249 |
+
# numbers = [int(match) for match in copy_pattern.findall(region)]
|
250 |
+
# indices.append(numbers)
|
251 |
+
|
252 |
+
# # Then, convert region indices into masks
|
253 |
+
# masks = []
|
254 |
+
# resized_image_size = image_size
|
255 |
+
# for region in indices:
|
256 |
+
# mask, resized_image_size = self.image_processor.get_mask(
|
257 |
+
# image_size, region
|
258 |
+
# )
|
259 |
+
# masks.append(mask)
|
260 |
+
# return masks, resized_image_size
|
261 |
+
#
|
262 |
+
def extract_masks(self, image_size: Tuple[int, int], text: str):
|
263 |
+
# Match full detect(...) blocks and extract their content
|
264 |
+
# detect_pattern = r"detect\([^)]+objects\s*=\s*\[(.*?)\]\)"
|
265 |
+
detect_pattern = r'detect\(\s*query\s*=\s*"([^"]+)"\s*,\s*objects\s*=\s*\["((?:[^"\\]|\\.)*)"\]\s*\)'
|
266 |
+
obj_region_pattern = r"<obj(\d+)><region>\[.*?\](.*?)</region>"
|
267 |
+
copy_pattern = r"<\|copy_(\d+)\|>"
|
268 |
+
|
269 |
+
# results = defaultdict(list)
|
270 |
+
results = {}
|
271 |
+
|
272 |
+
for detect_match in re.finditer(detect_pattern, text, re.DOTALL):
|
273 |
+
query_str = detect_match.group(1)
|
274 |
+
objects_content = detect_match.group(2)
|
275 |
+
|
276 |
+
for obj_match in re.finditer(
|
277 |
+
obj_region_pattern, objects_content, re.DOTALL
|
278 |
+
):
|
279 |
+
obj_index = int(obj_match.group(1))
|
280 |
+
region_content = obj_match.group(2)
|
281 |
+
copy_ids = [int(m) for m in re.findall(copy_pattern, region_content)]
|
282 |
+
obj_key = f"<obj{obj_index}>"
|
283 |
+
results[obj_key] = (query_str, copy_ids)
|
284 |
+
|
285 |
+
results = dict(results)
|
286 |
+
|
287 |
+
masks = {}
|
288 |
+
resized_image_size = image_size
|
289 |
+
for k, (desc, region) in results.items():
|
290 |
+
mask, resized_image_size = self.image_processor.get_mask(image_size, region)
|
291 |
+
masks[k] = (desc, mask)
|
292 |
+
return masks, resized_image_size
|
293 |
+
|
294 |
+
def __call__(
|
295 |
+
self,
|
296 |
+
images: ImageInput = None,
|
297 |
+
text: Union[
|
298 |
+
TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
|
299 |
+
] = None,
|
300 |
+
videos: VideoInput = None,
|
301 |
+
regions: Optional[List[dict[str, Union[BBOX, np.ndarray]]]] = None,
|
302 |
+
**kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
|
303 |
+
) -> BatchFeature:
|
304 |
+
"""
|
305 |
+
Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
|
306 |
+
and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
|
307 |
+
the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
|
308 |
+
Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
|
309 |
+
|
310 |
+
Args:
|
311 |
+
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
312 |
+
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
313 |
+
tensor. Both channels-first and channels-last formats are supported.
|
314 |
+
text (`str`, `List[str]`, `List[List[str]]`):
|
315 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
316 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
317 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
318 |
+
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
319 |
+
The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
|
320 |
+
tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
|
321 |
+
regions:
|
322 |
+
either bboxes: List[dict[str, Tuple[int, int, int, int]]]
|
323 |
+
or masks: List[dict[str, np.ndarray[width, height]]]
|
324 |
+
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
325 |
+
If set, will return tensors of a particular framework. Acceptable values are:
|
326 |
+
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
327 |
+
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
328 |
+
- `'np'`: Return NumPy `np.ndarray` objects.
|
329 |
+
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
330 |
+
|
331 |
+
Returns:
|
332 |
+
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
333 |
+
|
334 |
+
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
335 |
+
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
336 |
+
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
337 |
+
`None`).
|
338 |
+
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
339 |
+
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
|
340 |
+
- **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
|
341 |
+
- **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
|
342 |
+
- **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
|
343 |
+
"""
|
344 |
+
|
345 |
+
output_kwargs = self._merge_kwargs(
|
346 |
+
Qwen2_5_VLProcessorKwargs,
|
347 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
348 |
+
**kwargs,
|
349 |
+
)
|
350 |
+
obj_ptrs = None
|
351 |
+
if images is not None:
|
352 |
+
image_inputs = self.image_processor(
|
353 |
+
images=images, videos=None, **output_kwargs["images_kwargs"]
|
354 |
+
)
|
355 |
+
image_grid_thw = image_inputs["image_grid_thw"]
|
356 |
+
|
357 |
+
for image in images:
|
358 |
+
assert isinstance(
|
359 |
+
image, Image.Image
|
360 |
+
), "only supporting a single image per row for now"
|
361 |
+
|
362 |
+
if regions is not None:
|
363 |
+
obj_ptrs = [
|
364 |
+
{
|
365 |
+
name: (
|
366 |
+
self.image_processor.get_patch_pointers(image.size, region)
|
367 |
+
if region is not None
|
368 |
+
else np.array([])
|
369 |
+
)
|
370 |
+
for name, region in image_region.items()
|
371 |
+
}
|
372 |
+
for image, image_region in zip(images, regions)
|
373 |
+
]
|
374 |
+
else:
|
375 |
+
image_inputs = {}
|
376 |
+
image_grid_thw = None
|
377 |
+
|
378 |
+
assert videos is None, "video inputs are not supported yet" # TODO
|
379 |
+
if videos is not None:
|
380 |
+
videos_inputs = self.image_processor(
|
381 |
+
images=None, videos=videos, **output_kwargs["images_kwargs"]
|
382 |
+
)
|
383 |
+
video_grid_thw = videos_inputs["video_grid_thw"]
|
384 |
+
|
385 |
+
fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
|
386 |
+
if isinstance(fps, (int, float)):
|
387 |
+
second_per_grid_ts = [
|
388 |
+
self.image_processor.temporal_patch_size / fps
|
389 |
+
] * len(video_grid_thw)
|
390 |
+
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
|
391 |
+
second_per_grid_ts = [
|
392 |
+
self.image_processor.temporal_patch_size / tmp for tmp in fps
|
393 |
+
]
|
394 |
+
else:
|
395 |
+
raise ValueError(
|
396 |
+
f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
|
397 |
+
)
|
398 |
+
videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})
|
399 |
+
|
400 |
+
else:
|
401 |
+
videos_inputs = {}
|
402 |
+
video_grid_thw = None
|
403 |
+
|
404 |
+
if not isinstance(text, list):
|
405 |
+
text = [text]
|
406 |
+
|
407 |
+
if image_grid_thw is not None:
|
408 |
+
merge_length = self.image_processor.merge_size**2
|
409 |
+
index = 0
|
410 |
+
for i in range(len(text)):
|
411 |
+
while self.image_token in text[i]:
|
412 |
+
text[i] = text[i].replace(
|
413 |
+
self.image_token,
|
414 |
+
"<|placeholder|>"
|
415 |
+
* (image_grid_thw[index].prod() // merge_length),
|
416 |
+
1,
|
417 |
+
)
|
418 |
+
index += 1
|
419 |
+
text[i] = text[i].replace("<|placeholder|>", self.image_token)
|
420 |
+
|
421 |
+
if obj_ptrs is not None:
|
422 |
+
assert regions is not None
|
423 |
+
for i in range(len(text)):
|
424 |
+
image_ptrs = obj_ptrs[i]
|
425 |
+
image_region = regions[i]
|
426 |
+
|
427 |
+
for name, region in image_region.items():
|
428 |
+
region_ptr = image_ptrs[name]
|
429 |
+
|
430 |
+
assert name in text[i], f"object {name} not found in: {text[i]}"
|
431 |
+
|
432 |
+
ptrs_str = "".join([f"<|copy_{j}|>" for j in region_ptr])
|
433 |
+
region_str = self.image_processor.normalize_bbox(
|
434 |
+
image.size, region
|
435 |
+
)
|
436 |
+
if self.separate_copy_loss:
|
437 |
+
ptrs_str = (
|
438 |
+
self.copy_start_token + ptrs_str + self.copy_end_token
|
439 |
+
)
|
440 |
+
out_str = ("<region>" + ptrs_str + "</region>",)
|
441 |
+
if self.prepend_raw_region_to_text:
|
442 |
+
out_str = "<region>" + region_str + ptrs_str + "</region>"
|
443 |
+
|
444 |
+
text[i] = text[i].replace(name, out_str)
|
445 |
+
|
446 |
+
for name in image_region.keys():
|
447 |
+
assert name not in text[i]
|
448 |
+
|
449 |
+
if video_grid_thw is not None:
|
450 |
+
# TODO: support video inputs
|
451 |
+
raise NotImplementedError("video inputs are not yet supported")
|
452 |
+
merge_length = self.image_processor.merge_size**2
|
453 |
+
index = 0
|
454 |
+
for i in range(len(text)):
|
455 |
+
while self.video_token in text[i]:
|
456 |
+
text[i] = text[i].replace(
|
457 |
+
self.video_token,
|
458 |
+
"<patch>"
|
459 |
+
+ "<|placeholder|>"
|
460 |
+
* (video_grid_thw[index].prod() // merge_length)
|
461 |
+
+ "</patch>",
|
462 |
+
1,
|
463 |
+
)
|
464 |
+
index += 1
|
465 |
+
text[i] = text[i].replace("<|placeholder|>", self.video_token)
|
466 |
+
|
467 |
+
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
468 |
+
|
469 |
+
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
|
470 |
+
|
471 |
+
|
472 |
+
def get_processor(model_name: str, **kwargs):
|
473 |
+
"""Build a V1Processor and extend its tokenizer with <|copy_i|> pointer tokens."""
|
474 |
+
processor = V1Processor.from_pretrained(model_name, **kwargs)
|
475 |
+
processor.image_processor = Qwen2VLImagePointerProcessor.from_pretrained(
|
476 |
+
model_name, **kwargs
|
477 |
+
)
|
478 |
+
# max_position_tokens = processor.tokenizer.model_max_length
|
479 |
+
# new_tokens = [f"<|copy_{i}|>" for i in range(max_position_tokens)] # too slow
|
480 |
+
processor.tokenizer.orig_vocab_size = len(processor.tokenizer)
|
481 |
+
new_tokens = [f"<|copy_{i}|>" for i in range(30000)]
|
482 |
+
processor.tokenizer.add_tokens(new_tokens)
|
483 |
+
processor.copy_token_start = processor.tokenizer.convert_tokens_to_ids("<|copy_0|>")
|
484 |
+
return processor
|
485 |
+
|
486 |
+
|
487 |
+
# Create a data collator to encode text and image pairs
|
488 |
+
def collate_fn(examples, processor):
|
489 |
+
convs = [row["conversation"] for row in examples]
|
490 |
+
regions = [row["region"] for row in examples]
|
491 |
+
image_sizes = [row["image_size"] for row in examples]
|
492 |
+
|
493 |
+
texts = [
|
494 |
+
processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=False)
|
495 |
+
for conv in convs
|
496 |
+
] # Prepare texts for processing
|
497 |
+
image_inputs = [
|
498 |
+
process_vision_info(conv)[0][0] for conv in convs
|
499 |
+
] # Process the images to extract inputs
|
500 |
+
image_inputs = [
|
501 |
+
image.resize(image_size) for image, image_size in zip(image_inputs, image_sizes)
|
502 |
+
]
|
503 |
+
|
504 |
+
# Tokenize the texts and process the images
|
505 |
+
batch = processor(
|
506 |
+
text=texts,
|
507 |
+
images=image_inputs,
|
508 |
+
videos=None,
|
509 |
+
regions=regions,
|
510 |
+
padding=True,
|
511 |
+
return_tensors="pt",
|
512 |
+
) # Encode texts and images into tensors
|
513 |
+
|
514 |
+
# The labels are the input_ids, and we mask the padding tokens in the loss computation
|
515 |
+
labels = batch["input_ids"].clone() # Clone input IDs for labels
|
516 |
+
labels[labels == processor.tokenizer.pad_token_id] = (
|
517 |
+
-100
|
518 |
+
) # Mask padding tokens in labels
|
519 |
+
|
520 |
+
# Ignore the image token index in the loss computation (model specific)
|
521 |
+
image_tokens = [
|
522 |
+
151652,
|
523 |
+
151653,
|
524 |
+
151655,
|
525 |
+
] # Specific image token IDs for Qwen2VLProcessor
|
526 |
+
|
527 |
+
# Mask image token IDs in the labels
|
528 |
+
for image_token_id in image_tokens:
|
529 |
+
labels[labels == image_token_id] = -100 # Mask image token IDs in labels
|
530 |
+
|
531 |
+
batch["labels"] = labels # Add labels to the batch
|
532 |
+
|
533 |
+
return batch # Return the prepared batch
|
534 |
+
|
535 |
+
if __name__ == '__main__':
|
536 |
+
import ipdb; ipdb.set_trace()
|
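For reference, a rough sketch (not part of the diff above; the image size and bounding box are made up) of the arithmetic behind get_bbox_patch_pointers: the box is rescaled to the smart_resize resolution, snapped to the merged 28-pixel patch grid (patch_size 14 x merge_size 2, as in preprocessor_config.json), and flattened into the row-major indices that become <|copy_i|> tokens:

import math
import numpy as np

patch_size, merge_size = 14, 2        # values from preprocessor_config.json
unit = patch_size * merge_size        # one merged patch covers 28 px
width, height = 640, 480              # original image size (example)
resized_w, resized_h = 644, 476       # what smart_resize yields for factor 28

grid_w_m = resized_w // unit          # merged patches per row

x0, y0, x1, y1 = 100, 50, 200, 150    # bbox in original pixel coordinates
px0 = math.floor(x0 / width * resized_w / unit)
py0 = math.floor(y0 / height * resized_h / unit)
px1 = math.ceil(x1 / width * resized_w / unit)
py1 = math.ceil(y1 / height * resized_h / unit)

rows, cols = np.meshgrid(np.arange(py0, py1), np.arange(px0, px1), indexing="ij")
pointer_ids = rows.ravel() * grid_w_m + cols.ravel()
print(pointer_ids)                    # flat patch indices covered by the box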
processor_config.json
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
{
|
2 |
"prepend_raw_region_to_text": true,
|
3 |
-
"processor_class": "
|
4 |
}
|
|
|
1 |
{
|
2 |
"prepend_raw_region_to_text": true,
|
3 |
+
"processor_class": "V1Processor"
|
4 |
}
|
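Finally, a small sketch (dummy token ids; the pad id is assumed) of the label-masking rule used by collate_fn in processor.py above: padding and vision tokens are set to -100 so CrossEntropyLoss ignores them.

import torch

pad_token_id = 151643                     # assumed pad id, for illustration only
image_tokens = [151652, 151653, 151655]   # the ids hard-coded in collate_fn

input_ids = torch.tensor([[151652, 151655, 42, 7, 151643, 151643]])
labels = input_ids.clone()
labels[labels == pad_token_id] = -100     # mask padding
for tok in image_tokens:
    labels[labels == tok] = -100          # mask vision tokens
print(labels)                             # tensor([[-100, -100, 42, 7, -100, -100]])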