Update modeling_maira2.py
Browse files- modeling_maira2.py +4 -0
modeling_maira2.py
CHANGED
@@ -88,6 +88,10 @@ class Maira2ForConditionalGeneration(LlavaForConditionalGeneration):
|
|
88 |
image_features = self.multi_modal_projector(selected_image_feature)
|
89 |
return image_features # type: ignore[no-any-return]
|
90 |
|
|
|
|
|
|
|
|
|
91 |
# modification from original, added forward from transformers 4.46 to prevent new preprocessing
|
92 |
def forward(
|
93 |
self,
|
|
|
88 |
image_features = self.multi_modal_projector(selected_image_feature)
|
89 |
return image_features # type: ignore[no-any-return]
|
90 |
|
91 |
+
# modification from original, added get_input_embeddings from transformers 4.52 to prevent issues related llava model structure changes
|
92 |
+
def get_input_embeddings(self):
|
93 |
+
return self.language_model.get_input_embeddings()
|
94 |
+
|
95 |
# modification from original, added forward from transformers 4.46 to prevent new preprocessing
|
96 |
def forward(
|
97 |
self,
|