Adapt inputs_merger() Function for Scenarios Without Image Input
An if clause has been added so that the merging step runs only when image_hidden_states is not None, allowing the model to handle batches without any image input.
modeling_vmistral.py  CHANGED  (+30 -28)

```diff
@@ -1372,36 +1372,38 @@ class VMistralModel(VMistralPreTrainedModel):
         batch_size = input_ids.size(0)
 
         if inputs_embeds is not None:
-            vision_pipeline_output_seq_len = image_hidden_states.shape[1]
-            vision_hidden_size = image_hidden_states.shape[2]
             new_inputs_embeds = inputs_embeds.clone()
-            # Get the number of images for each example
-            num_images = (input_ids == self.image_token_id).sum(dim=-1) // self.image_seq_len
-            cum_num_images = num_images.cumsum(dim=-1)
-            for batch_idx in range(batch_size):
-                # Get the number of images for this particular example
-                example_num_images = num_images[batch_idx]
-                # Get the image_hidden_states corresponding to True images for the example, so get rid of the padding images.
-                start = 0 if batch_idx == 0 else cum_num_images[batch_idx - 1]
-                end = cum_num_images[batch_idx]
-                example_true_image_hidden_states = image_hidden_states[start:end]
-                if (
-                    new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id].shape[0]
-                    != example_num_images * vision_pipeline_output_seq_len
-                ):
-                    raise ValueError(
-                        "new_inputs_embeds to replace has shape[0]:"
-                        f" {new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id].shape[0]} but"
-                        " should have shape[0]:"
-                        f" {example_num_images}*{vision_pipeline_output_seq_len}={example_num_images * vision_pipeline_output_seq_len} "
-                    )
-                # Insert the image_hidden_states
-                new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id] = (
-                    example_true_image_hidden_states.view(
-                        example_num_images * vision_pipeline_output_seq_len,
-                        vision_hidden_size,
+
+            if image_hidden_states is not None:
+                vision_pipeline_output_seq_len = image_hidden_states.shape[1]
+                vision_hidden_size = image_hidden_states.shape[2]
+                # Get the number of images for each example
+                num_images = (input_ids == self.image_token_id).sum(dim=-1) // self.image_seq_len
+                cum_num_images = num_images.cumsum(dim=-1)
+                for batch_idx in range(batch_size):
+                    # Get the number of images for this particular example
+                    example_num_images = num_images[batch_idx]
+                    # Get the image_hidden_states corresponding to True images for the example, so get rid of the padding images.
+                    start = 0 if batch_idx == 0 else cum_num_images[batch_idx - 1]
+                    end = cum_num_images[batch_idx]
+                    example_true_image_hidden_states = image_hidden_states[start:end]
+                    if (
+                        new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id].shape[0]
+                        != example_num_images * vision_pipeline_output_seq_len
+                    ):
+                        raise ValueError(
+                            "new_inputs_embeds to replace has shape[0]:"
+                            f" {new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id].shape[0]} but"
+                            " should have shape[0]:"
+                            f" {example_num_images}*{vision_pipeline_output_seq_len}={example_num_images * vision_pipeline_output_seq_len} "
+                        )
+                    # Insert the image_hidden_states
+                    new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id] = (
+                        example_true_image_hidden_states.view(
+                            example_num_images * vision_pipeline_output_seq_len,
+                            vision_hidden_size,
+                        )
                     )
-                )
 
         return_dict = {}
         if inputs_embeds is not None:
```
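For context, here is a minimal, self-contained sketch of the guarded merge. It is not the model's actual method: merge_image_features, IMAGE_TOKEN_ID, IMAGE_SEQ_LEN, and HIDDEN_SIZE are made-up names and toy values, and the per-example loop that drops padding images is collapsed into a single masked scatter. It only illustrates the behavior this commit adds: when image_hidden_states is None, the text embeddings pass through unchanged.

```python
import torch

# Toy stand-ins for values the real model reads from its config/tokenizer.
IMAGE_TOKEN_ID = 32001  # id of the <image> placeholder token (made up)
IMAGE_SEQ_LEN = 3       # placeholder tokens emitted per image (made up)
HIDDEN_SIZE = 8         # embedding width (made up)


def merge_image_features(input_ids, inputs_embeds, image_hidden_states=None):
    """Write vision features into the positions of the image placeholder
    tokens. If no vision features are given (text-only batch), return the
    text embeddings untouched."""
    new_inputs_embeds = inputs_embeds.clone()
    if image_hidden_states is None:
        # The guard added by this commit: nothing to merge for text-only input.
        return new_inputs_embeds

    # image_hidden_states: (total_num_images, IMAGE_SEQ_LEN, HIDDEN_SIZE)
    image_mask = input_ids == IMAGE_TOKEN_ID
    expected = image_hidden_states.shape[0] * image_hidden_states.shape[1]
    if image_mask.sum().item() != expected:
        raise ValueError(
            f"found {image_mask.sum().item()} image-token positions but "
            f"{expected} vision embeddings"
        )
    # Scatter the flattened vision embeddings, in order, into the
    # placeholder positions across the whole batch.
    new_inputs_embeds[image_mask] = image_hidden_states.reshape(-1, HIDDEN_SIZE)
    return new_inputs_embeds


if __name__ == "__main__":
    torch.manual_seed(0)
    # One example: a text token, one image (3 placeholders), then more text.
    input_ids = torch.tensor([[5, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 9, 12]])
    inputs_embeds = torch.randn(1, 6, HIDDEN_SIZE)
    image_hidden_states = torch.randn(1, IMAGE_SEQ_LEN, HIDDEN_SIZE)

    merged = merge_image_features(input_ids, inputs_embeds, image_hidden_states)
    print(torch.equal(merged[0, 1:4], image_hidden_states[0]))  # True

    # Text-only call: with the guard, the embeddings come back unchanged.
    text_only = merge_image_features(input_ids[:, :1], inputs_embeds[:, :1], None)
    print(torch.equal(text_only, inputs_embeds[:, :1]))  # True
```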