Update hunyuan.py
hunyuan.py CHANGED (+1, -29)
@@ -41,7 +41,6 @@ from transformers.utils.import_utils import is_torch_fx_available
 from transformers.generation.utils import GenerateOutput
 from .configuration_hunyuan import HunYuanConfig
 from .modeling_hunyuan import HunYuanDecoderLayer, HunYuanRMSNorm
-from .vit_model import NaVitForward, VitForward, Vit
 
 
 if is_flash_attn_2_available():
@@ -363,16 +362,7 @@ class HunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel):
 
     def __init__(self, config: HunYuanConfig):
         super().__init__(config)
-
-        if "-tp" in config.vit_type:
-            config.vit_type = config.vit_type.replace("-tp", "")
-        self.vit_type = config.vit_type
-        if self.vit_type not in ['NaVit', 'EvaVit']:
-            if config.vit_mapping_type == 'mlp':
-                self.vit_linear_encoder = torch.nn.Linear(config.hidden_size, config.hidden_size)
-            self.vit = Vit(config)
-        else:
-            self.vit = None
+
         self.config = config
         self.model = HunYuanModel(config)
         self.add_classification_head = config.add_classification_head
@@ -643,15 +633,6 @@ class MultimodelHunYuanForCausalLM(HunYuanMoEV1ForCausalLM):
         video_start_id = self.config.video_start_id
         video_end_id = self.config.video_end_id
 
-        if self.vit is not None and imgs is not None:
-            encoder_input = self.model.embed_tokens(input_ids)
-            if self.vit_type in ['NaVit', 'EvaVit', 'AnyResVit']:
-                inputs_embeds, input_ids = NaVitForward(input_ids, encoder_input, self.vit, imgs, imgs_pos, self.config.vit_input_resolution, \
-                    im_start_id, im_end_id, image_token_id, self.config.anyres_vit_two_views, self.config.torch_dtype)
-            else:
-                inputs_embeds, input_ids = VitForward(input_ids, encoder_input, self.vit, self.vit_linear_encoder, imgs, imgs_pos, \
-                    self.config.vit_input_resolution, self.config.vit_mapping_type, self.config.vit_patch, self.config.vit_token)
-
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -738,15 +719,6 @@ class MultimodelHunYuanForCausalLM(HunYuanMoEV1ForCausalLM):
         if "inputs_embeds" in kwargs:
             raise NotImplementedError("`inputs_embeds` is not supported")
 
-        if self.vit is not None:
-            encoder_input = self.model.embed_tokens(inputs)
-            if self.vit_type in ['NaVit', 'EvaVit', 'AnyResVit']:
-                inputs_embeds, input_ids = NaVitForward(inputs, encoder_input, self.vit, imgs, imgs_pos, self.config.vit_input_resolution, \
-                    self.config.im_start_id, self.config.im_end_id, self.config.image_token_id, self.config.anyres_vit_two_views, self.config.torch_dtype)
-            else:
-                inputs_embeds, input_ids = VitForward(inputs, encoder_input, self.vit, self.vit_linear_encoder, imgs, imgs_pos, \
-                    self.config.vit_input_resolution, self.config.vit_mapping_type, self.config.vit_patch, self.config.vit_token)
-
         return super().generate(
             inputs=input_ids,
             position_ids=position_ids,
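For context, the net effect of this diff is that hunyuan.py no longer imports or constructs anything from .vit_model: the self.vit tower is never built, and neither forward nor generate splices image embeddings into the token stream. Below is a minimal, illustrative loading sketch under that assumption; the repo id is a placeholder (not named in this commit), and only the trust_remote_code mechanism, which makes transformers import this hunyuan.py from the checkpoint repo, is standard transformers behavior.

from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "org/hunyuan-checkpoint"  # placeholder repo id, substitute the real one

# trust_remote_code=True makes transformers load the repo's own
# hunyuan.py / configuration_hunyuan.py instead of a built-in class.
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

# After this change nothing in hunyuan.py consumes imgs/imgs_pos, so the
# model is driven with plain token ids only.
batch = tokenizer("Hello", return_tensors="pt")
out = model.generate(**batch, max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))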