toshi-456 committed
Commit 021aa10 · verified · 1 Parent(s): 05f6e7d

update modeling_sarashina2_vision.py

Files changed (1):
  modeling_sarashina2_vision.py (+7 -1)
modeling_sarashina2_vision.py CHANGED

@@ -70,7 +70,6 @@ class Sarashina2VisionForCausalLM(Sarashina2VisionPreTrainedModel, GenerationMixin):
         self.visual = Qwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
         self.norm = nn.LayerNorm(config.text_config.hidden_size)
         self.llm = LlamaForCausalLM._from_config(config.text_config)
-        self._attn_implementation = config._attn_implementation
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -113,6 +112,7 @@ class Sarashina2VisionForCausalLM(Sarashina2VisionPreTrainedModel, GenerationMixin):
         pixel_values: torch.FloatTensor = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
         **lm_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         """
@@ -130,6 +130,11 @@ class Sarashina2VisionForCausalLM(Sarashina2VisionPreTrainedModel, GenerationMixin):
             pixel_values (torch.FloatTensor, optional): The tensors corresponding to the input images. Defaults to None.
             image_grid_thw (Optional[torch.LongTensor], optional): The temporal, height and width of feature shape of each image in LLM. Defaults to None.
             cache_position (Optional[torch.LongTensor], optional): Indices depicting the position of the input sequence tokens in the sequence. Defaults to None.
+            logits_to_keep (Union[int, torch.Tensor]): If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+                If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+                This is useful when using packed tensor format (single dimension for batch and sequence length).
 
         Returns:
             CausalLMOutputWithPast: The output of the model.
         """
@@ -173,6 +178,7 @@ class Sarashina2VisionForCausalLM(Sarashina2VisionPreTrainedModel, GenerationMixin):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
+            logits_to_keep=logits_to_keep,
             **lm_kwargs,
         )
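
For reference, a minimal usage sketch of the logits_to_keep argument that this commit threads through to the underlying LlamaForCausalLM forward call. The checkpoint name and the preprocessing below are assumptions for illustration, not part of this commit.

# Minimal sketch, assuming the public sbintuitions/sarashina2-vision-8b checkpoint
# and a text-only prompt; the exact preprocessing here is illustrative only.
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

model_id = "sbintuitions/sarashina2-vision-8b"  # assumed checkpoint name
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

inputs = processor(text="A short test prompt.", return_tensors="pt")

with torch.no_grad():
    # logits_to_keep=1: logits are computed only for the last position,
    # so out.logits has shape (batch, 1, vocab_size) and peak memory stays low.
    out = model(**inputs, logits_to_keep=1)

print(out.logits.shape)

With the default logits_to_keep=0, logits are computed for every input position; passing an int n keeps only the last n positions, and a 1D tensor of indices selects arbitrary positions along the sequence dimension, as described in the docstring added above.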