BUAADreamer commited on
Commit
909a86b
·
verified ·
1 Parent(s): 0f4c447

Update modeling_minicpmo.py

Browse files
Files changed (1) hide show
  1. modeling_minicpmo.py +10 -6
modeling_minicpmo.py CHANGED
@@ -377,10 +377,12 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
377
  else:
378
  vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])
379
 
 
 
380
  vision_hidden_states = [
381
  i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
382
  ]
383
-
384
  bs = len(data["input_ids"])
385
  for i in range(bs):
386
  cur_vs_hs = vision_hidden_states[i]
@@ -392,15 +394,16 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
392
  [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
393
  ).to(vllm_embedding.device)
394
 
395
- cur_vllm_emb.scatter_(
396
  0,
397
  image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
398
  cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
399
  )
 
400
  elif self.training:
401
- cur_vllm_emb += cur_vs_hs[0].mean() * 0
402
 
403
- return vllm_embedding, vision_hidden_states
404
 
405
  def get_audio_embedding_streaming(self, data):
406
  r"""
@@ -595,7 +598,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
595
  elif self.training:
596
  for i in range(bs):
597
  # dummy audio_embeddings
598
- input_embeddings += audio_embeddings[0].mean() * 0
599
 
600
  return input_embeddings
601
 
@@ -751,7 +754,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
751
  input_ids=None,
752
  pixel_values=None,
753
  tgt_sizes=None,
754
- audio_features=None,
755
  audio_feature_lens=None,
756
  image_bound=None,
757
  audio_bounds=None,
@@ -2655,6 +2658,7 @@ class ConditionalChatTTS(PreTrainedModel):
2655
  """
2656
 
2657
  config_class = ConditionalChatTTSConfig
 
2658
 
2659
  def __init__(self, config: ConditionalChatTTSConfig):
2660
  super().__init__(config)
 
377
  else:
378
  vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])
379
 
380
+ new_vllm_embedding = vllm_embedding.clone()
381
+
382
  vision_hidden_states = [
383
  i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
384
  ]
385
+
386
  bs = len(data["input_ids"])
387
  for i in range(bs):
388
  cur_vs_hs = vision_hidden_states[i]
 
394
  [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
395
  ).to(vllm_embedding.device)
396
 
397
+ new_vllm_embedding[i] = cur_vllm_emb.scatter(
398
  0,
399
  image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
400
  cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
401
  )
402
+
403
  elif self.training:
404
+ new_vllm_embedding[i] += cur_vs_hs[0].mean() * 0
405
 
406
+ return new_vllm_embedding, vision_hidden_states
407
 
408
  def get_audio_embedding_streaming(self, data):
409
  r"""
 
598
  elif self.training:
599
  for i in range(bs):
600
  # dummy audio_embeddings
601
+ input_embeddings = input_embeddings + audio_embeddings[0].mean() * 0
602
 
603
  return input_embeddings
604
 
 
754
  input_ids=None,
755
  pixel_values=None,
756
  tgt_sizes=None,
757
+ audio_features=[],
758
  audio_feature_lens=None,
759
  image_bound=None,
760
  audio_bounds=None,
 
2658
  """
2659
 
2660
  config_class = ConditionalChatTTSConfig
2661
+ _no_split_modules = []
2662
 
2663
  def __init__(self, config: ConditionalChatTTSConfig):
2664
  super().__init__(config)