update modeling_sarashina2_vision.py
Browse files
modeling_sarashina2_vision.py
CHANGED
@@ -70,7 +70,6 @@ class Sarashina2VisionForCausalLM(Sarashina2VisionPreTrainedModel, GenerationMix
|
|
70 |
self.visual = Qwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
|
71 |
self.norm = nn.LayerNorm(config.text_config.hidden_size)
|
72 |
self.llm = LlamaForCausalLM._from_config(config.text_config)
|
73 |
-
self._attn_implementation = config._attn_implementation
|
74 |
|
75 |
# Initialize weights and apply final processing
|
76 |
self.post_init()
|
@@ -113,6 +112,7 @@ class Sarashina2VisionForCausalLM(Sarashina2VisionPreTrainedModel, GenerationMix
|
|
113 |
pixel_values: torch.FloatTensor = None,
|
114 |
image_grid_thw: Optional[torch.LongTensor] = None,
|
115 |
cache_position: Optional[torch.LongTensor] = None,
|
|
|
116 |
**lm_kwargs,
|
117 |
) -> Union[Tuple, CausalLMOutputWithPast]:
|
118 |
"""
|
@@ -130,6 +130,11 @@ class Sarashina2VisionForCausalLM(Sarashina2VisionPreTrainedModel, GenerationMix
|
|
130 |
pixel_values (torch.FloatTensor, optional): The tensors corresponding to the input images. Defaults to None.
|
131 |
image_grid_thw (Optional[torch.LongTensor], optional): The temporal, height and width of feature shape of each image in LLM. Defaults to None.
|
132 |
cache_position (Optional[torch.LongTensor], optional): Indices depicting the position of the input sequence tokens in the sequence. Defaults to None.
|
|
|
|
|
|
|
|
|
|
|
133 |
Returns:
|
134 |
CausalLMOutputWithPast: The output of the model.
|
135 |
"""
|
@@ -173,6 +178,7 @@ class Sarashina2VisionForCausalLM(Sarashina2VisionPreTrainedModel, GenerationMix
|
|
173 |
output_hidden_states=output_hidden_states,
|
174 |
return_dict=return_dict,
|
175 |
cache_position=cache_position,
|
|
|
176 |
**lm_kwargs,
|
177 |
)
|
178 |
|
|
|
70 |
self.visual = Qwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
|
71 |
self.norm = nn.LayerNorm(config.text_config.hidden_size)
|
72 |
self.llm = LlamaForCausalLM._from_config(config.text_config)
|
|
|
73 |
|
74 |
# Initialize weights and apply final processing
|
75 |
self.post_init()
|
|
|
112 |
pixel_values: torch.FloatTensor = None,
|
113 |
image_grid_thw: Optional[torch.LongTensor] = None,
|
114 |
cache_position: Optional[torch.LongTensor] = None,
|
115 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
116 |
**lm_kwargs,
|
117 |
) -> Union[Tuple, CausalLMOutputWithPast]:
|
118 |
"""
|
|
|
130 |
pixel_values (torch.FloatTensor, optional): The tensors corresponding to the input images. Defaults to None.
|
131 |
image_grid_thw (Optional[torch.LongTensor], optional): The temporal, height and width of feature shape of each image in LLM. Defaults to None.
|
132 |
cache_position (Optional[torch.LongTensor], optional): Indices depicting the position of the input sequence tokens in the sequence. Defaults to None.
|
133 |
+
logits_to_keep (Union[int, torch.Tensor]): If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
134 |
+
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
135 |
+
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
136 |
+
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
137 |
+
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
138 |
Returns:
|
139 |
CausalLMOutputWithPast: The output of the model.
|
140 |
"""
|
|
|
178 |
output_hidden_states=output_hidden_states,
|
179 |
return_dict=return_dict,
|
180 |
cache_position=cache_position,
|
181 |
+
logits_to_keep=logits_to_keep,
|
182 |
**lm_kwargs,
|
183 |
)
|
184 |
|