bluelike commited on
Commit
5bdb809
·
1 Parent(s): bbe5a80

Update modeling_qwen.py

Browse files
Files changed (1) hide show
  1. modeling_qwen.py +7 -7
modeling_qwen.py CHANGED
@@ -564,8 +564,11 @@ class QWenModel(QWenPreTrainedModel):
564
 
565
  images = self.visual.encode(images)
566
  assert images.shape[0] == len(images)
 
567
  else:
568
- images = None
 
 
569
 
570
  output_attentions = (
571
  output_attentions
@@ -623,11 +626,6 @@ class QWenModel(QWenPreTrainedModel):
623
 
624
  if inputs_embeds is None:
625
  inputs_embeds = self.wte(input_ids)
626
- if self.training and images == None: # Compatible with plain text data training
627
- fake_images=torch.zeros(1,3,224,224).to(
628
- dtype=self.visual.conv1.weight.dtype, device=self.visual.conv1.weight.device)
629
- image_embeds = self.visual(fake_images)
630
- inputs_embeds = inputs_embeds + image_embeds.mean()*0
631
 
632
  if batch_size <= 0:
633
  raise ValueError("batch_size has to be defined and > 0")
@@ -657,7 +655,9 @@ class QWenModel(QWenPreTrainedModel):
657
  rotary_pos_emb[idx] = rotary_pos_emb[idx].to(hidden_states.device)
658
 
659
  hidden_states = self.drop(hidden_states).clone()
660
- if images is not None:
 
 
661
  for idx, (i, a, b) in enumerate(img_pos):
662
  hidden_states[i][a + 1 : b] = images[idx]
663
  output_shape = input_shape + (hidden_states.size(-1),)
 
564
 
565
  images = self.visual.encode(images)
566
  assert images.shape[0] == len(images)
567
+ fake_images = None
568
  else:
569
+ fake_images=torch.zeros(1,3,224,224).to(
570
+ dtype=self.visual.conv1.weight.dtype, device=self.visual.conv1.weight.device)
571
+ images = self.visual(fake_images)
572
 
573
  output_attentions = (
574
  output_attentions
 
626
 
627
  if inputs_embeds is None:
628
  inputs_embeds = self.wte(input_ids)
 
 
 
 
 
629
 
630
  if batch_size <= 0:
631
  raise ValueError("batch_size has to be defined and > 0")
 
655
  rotary_pos_emb[idx] = rotary_pos_emb[idx].to(hidden_states.device)
656
 
657
  hidden_states = self.drop(hidden_states).clone()
658
+ if fake_images is not None:
659
+ hidden_states = hidden_states + images.mean()*0
660
+ else:
661
  for idx, (i, a, b) in enumerate(img_pos):
662
  hidden_states[i][a + 1 : b] = images[idx]
663
  output_shape = input_shape + (hidden_states.size(-1),)