BAAI
/

ryanzhangfan commited on
Commit
dabb3aa
1 Parent(s): 32e9c18

add support for batch image generation

Browse files
Files changed (2) hide show
  1. processing_emu3.py +7 -4
  2. utils_emu3.py +11 -5
processing_emu3.py CHANGED
@@ -84,7 +84,7 @@ class Emu3Processor(ProcessorMixin):
84
  image: Optional[Image.Image | List[Image.Image]] = None,
85
  *,
86
  mode: str = "G",
87
- ratio: str = "1:1",
88
  image_area: int = 518400,
89
  **kwargs,
90
  ) -> BatchFeature:
@@ -129,8 +129,11 @@ class Emu3Processor(ProcessorMixin):
129
  if image is not None:
130
  raise ValueError("You have to specify only `text` in generation mode")
131
 
132
- if len(text) > 1:
133
- raise ValueError("`text` can only be `str` in generation mode")
 
 
 
134
  else:
135
  if image is None:
136
  raise ValueError("Invalid input image. Please provide exactly one PIL.Image.Image per text.")
@@ -165,7 +168,7 @@ class Emu3Processor(ProcessorMixin):
165
  )
166
  prompt += self.chat_template.format(image_prompt=image_prompt, text_prompt=text_prompt)
167
  else:
168
- h, w = self.calculate_generate_size(ratio, image_area, self.vision_tokenizer.spatial_scale_factor)
169
  image_prompt = (
170
  self.tokenizer.boi_token +
171
  self.prefix_template.format(H=h, W=w) +
 
84
  image: Optional[Image.Image | List[Image.Image]] = None,
85
  *,
86
  mode: str = "G",
87
+ ratio: str | List[str] = "1:1",
88
  image_area: int = 518400,
89
  **kwargs,
90
  ) -> BatchFeature:
 
129
  if image is not None:
130
  raise ValueError("You have to specify only `text` in generation mode")
131
 
132
+ if isinstance(ratio, str):
133
+ ratio = [ratio] * len(text)
134
+
135
+ if len(ratio) != len(text):
136
+ raise ValueError("ratio number must match text number")
137
  else:
138
  if image is None:
139
  raise ValueError("Invalid input image. Please provide exactly one PIL.Image.Image per text.")
 
168
  )
169
  prompt += self.chat_template.format(image_prompt=image_prompt, text_prompt=text_prompt)
170
  else:
171
+ h, w = self.calculate_generate_size(ratio[idx], image_area, self.vision_tokenizer.spatial_scale_factor)
172
  image_prompt = (
173
  self.tokenizer.boi_token +
174
  self.prefix_template.format(H=h, W=w) +
utils_emu3.py CHANGED
@@ -47,16 +47,22 @@ class Emu3PrefixConstrainedLogitsHelper:
47
  position = torch.nonzero(input_ids == self.img_token, as_tuple=True)[0][0]
48
  self.offset_cache[batch_id] = position
49
 
 
 
 
50
  offset = input_ids.shape[0] - self.offset_cache[batch_id]
51
- if offset % (self.width + 1) == 0:
 
 
 
52
  return (self.eol_token, )
53
- elif offset == (self.width + 1) * self.height + 1:
54
  return (self.eof_token, )
55
- elif offset == (self.width + 1) * self.height + 2:
56
  return (self.eoi_token, )
57
- elif offset == (self.width + 1) * self.height + 3:
58
  return (self.eos_token, )
59
- elif offset > (self.width + 1) * self.height + 3:
60
  return (self.pad_token, )
61
  else:
62
  return self.visual_tokens
 
47
  position = torch.nonzero(input_ids == self.img_token, as_tuple=True)[0][0]
48
  self.offset_cache[batch_id] = position
49
 
50
+ height = self.height[batch_id] if self.height.shape[0] > 1 else self.height[0]
51
+ width = self.width[batch_id] if self.width.shape[0] > 1 else self.width[0]
52
+
53
  offset = input_ids.shape[0] - self.offset_cache[batch_id]
54
+ height = height.to(offset.device)
55
+ width = width.to(offset.device)
56
+
57
+ if offset % (width + 1) == 0:
58
  return (self.eol_token, )
59
+ elif offset == (width + 1) * height + 1:
60
  return (self.eof_token, )
61
+ elif offset == (width + 1) * height + 2:
62
  return (self.eoi_token, )
63
+ elif offset == (width + 1) * height + 3:
64
  return (self.eos_token, )
65
+ elif offset > (width + 1) * height + 3:
66
  return (self.pad_token, )
67
  else:
68
  return self.visual_tokens