manaestras committed on
Commit
8e9085f
·
verified ·
1 Parent(s): a7d6ead

Update hunyuan.py

Browse files
Files changed (1) hide show
  1. hunyuan.py +1 -29
hunyuan.py CHANGED
@@ -41,7 +41,6 @@ from transformers.utils.import_utils import is_torch_fx_available
41
  from transformers.generation.utils import GenerateOutput
42
  from .configuration_hunyuan import HunYuanConfig
43
  from .modeling_hunyuan import HunYuanDecoderLayer, HunYuanRMSNorm
44
- from .vit_model import NaVitForward, VitForward, Vit
45
 
46
 
47
  if is_flash_attn_2_available():
@@ -363,16 +362,7 @@ class HunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel):
363
 
364
  def __init__(self, config: HunYuanConfig):
365
  super().__init__(config)
366
- if config.vit_path is not None:
367
- if "-tp" in config.vit_type:
368
- config.vit_type = config.vit_type.replace("-tp", "")
369
- self.vit_type = config.vit_type
370
- if self.vit_type not in ['NaVit', 'EvaVit']:
371
- if config.vit_mapping_type == 'mlp':
372
- self.vit_linear_encoder = torch.nn.Linear(config.hidden_size, config.hidden_size)
373
- self.vit = Vit(config)
374
- else:
375
- self.vit = None
376
  self.config = config
377
  self.model = HunYuanModel(config)
378
  self.add_classification_head = config.add_classification_head
@@ -643,15 +633,6 @@ class MultimodelHunYuanForCausalLM(HunYuanMoEV1ForCausalLM):
643
  video_start_id = self.config.video_start_id
644
  video_end_id = self.config.video_end_id
645
 
646
- if self.vit is not None and imgs is not None:
647
- encoder_input = self.model.embed_tokens(input_ids)
648
- if self.vit_type in ['NaVit', 'EvaVit', 'AnyResVit']:
649
- inputs_embeds, input_ids = NaVitForward(input_ids, encoder_input, self.vit, imgs, imgs_pos, self.config.vit_input_resolution, \
650
- im_start_id, im_end_id, image_token_id, self.config.anyres_vit_two_views, self.config.torch_dtype)
651
- else:
652
- inputs_embeds, input_ids = VitForward(input_ids, encoder_input, self.vit, self.vit_linear_encoder, imgs, imgs_pos, \
653
- self.config.vit_input_resolution, self.config.vit_mapping_type, self.config.vit_patch, self.config.vit_token)
654
-
655
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
656
  output_hidden_states = (
657
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -738,15 +719,6 @@ class MultimodelHunYuanForCausalLM(HunYuanMoEV1ForCausalLM):
738
  if "inputs_embeds" in kwargs:
739
  raise NotImplementedError("`inputs_embeds` is not supported")
740
 
741
- if self.vit is not None:
742
- encoder_input = self.model.embed_tokens(inputs)
743
- if self.vit_type in ['NaVit', 'EvaVit', 'AnyResVit']:
744
- inputs_embeds, input_ids = NaVitForward(inputs, encoder_input, self.vit, imgs, imgs_pos, self.config.vit_input_resolution, \
745
- self.config.im_start_id, self.config.im_end_id, self.config.image_token_id, self.config.anyres_vit_two_views, self.config.torch_dtype)
746
- else:
747
- inputs_embeds, input_ids = VitForward(inputs, encoder_input, self.vit, self.vit_linear_encoder, imgs, imgs_pos, \
748
- self.config.vit_input_resolution, self.config.vit_mapping_type, self.config.vit_patch, self.config.vit_token)
749
-
750
  return super().generate(
751
  inputs=input_ids,
752
  position_ids=position_ids,
 
41
  from transformers.generation.utils import GenerateOutput
42
  from .configuration_hunyuan import HunYuanConfig
43
  from .modeling_hunyuan import HunYuanDecoderLayer, HunYuanRMSNorm
 
44
 
45
 
46
  if is_flash_attn_2_available():
 
362
 
363
  def __init__(self, config: HunYuanConfig):
364
  super().__init__(config)
365
+
 
 
 
 
 
 
 
 
 
366
  self.config = config
367
  self.model = HunYuanModel(config)
368
  self.add_classification_head = config.add_classification_head
 
633
  video_start_id = self.config.video_start_id
634
  video_end_id = self.config.video_end_id
635
 
 
 
 
 
 
 
 
 
 
636
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
637
  output_hidden_states = (
638
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
719
  if "inputs_embeds" in kwargs:
720
  raise NotImplementedError("`inputs_embeds` is not supported")
721
 
 
 
 
 
 
 
 
 
 
722
  return super().generate(
723
  inputs=input_ids,
724
  position_ids=position_ids,