ryanzhangfan
committed on
Commit
•
dabb3aa
1
Parent(s):
32e9c18
add support for batch image generation
Browse files- processing_emu3.py +7 -4
- utils_emu3.py +11 -5
processing_emu3.py
CHANGED
@@ -84,7 +84,7 @@ class Emu3Processor(ProcessorMixin):
|
|
84 |
image: Optional[Image.Image | List[Image.Image]] = None,
|
85 |
*,
|
86 |
mode: str = "G",
|
87 |
-
ratio: str = "1:1",
|
88 |
image_area: int = 518400,
|
89 |
**kwargs,
|
90 |
) -> BatchFeature:
|
@@ -129,8 +129,11 @@ class Emu3Processor(ProcessorMixin):
|
|
129 |
if image is not None:
|
130 |
raise ValueError("You have to specify only `text` in generation mode")
|
131 |
|
132 |
-
if
|
133 |
-
|
|
|
|
|
|
|
134 |
else:
|
135 |
if image is None:
|
136 |
raise ValueError("Invalid input image. Please provide exactly one PIL.Image.Image per text.")
|
@@ -165,7 +168,7 @@ class Emu3Processor(ProcessorMixin):
|
|
165 |
)
|
166 |
prompt += self.chat_template.format(image_prompt=image_prompt, text_prompt=text_prompt)
|
167 |
else:
|
168 |
-
h, w = self.calculate_generate_size(ratio, image_area, self.vision_tokenizer.spatial_scale_factor)
|
169 |
image_prompt = (
|
170 |
self.tokenizer.boi_token +
|
171 |
self.prefix_template.format(H=h, W=w) +
|
|
|
84 |
image: Optional[Image.Image | List[Image.Image]] = None,
|
85 |
*,
|
86 |
mode: str = "G",
|
87 |
+
ratio: str | List[str] = "1:1",
|
88 |
image_area: int = 518400,
|
89 |
**kwargs,
|
90 |
) -> BatchFeature:
|
|
|
129 |
if image is not None:
|
130 |
raise ValueError("You have to specify only `text` in generation mode")
|
131 |
|
132 |
+
if isinstance(ratio, str):
|
133 |
+
ratio = [ratio] * len(text)
|
134 |
+
|
135 |
+
if len(ratio) != len(text):
|
136 |
+
raise ValueError("ratio number must match text number")
|
137 |
else:
|
138 |
if image is None:
|
139 |
raise ValueError("Invalid input image. Please provide exactly one PIL.Image.Image per text.")
|
|
|
168 |
)
|
169 |
prompt += self.chat_template.format(image_prompt=image_prompt, text_prompt=text_prompt)
|
170 |
else:
|
171 |
+
h, w = self.calculate_generate_size(ratio[idx], image_area, self.vision_tokenizer.spatial_scale_factor)
|
172 |
image_prompt = (
|
173 |
self.tokenizer.boi_token +
|
174 |
self.prefix_template.format(H=h, W=w) +
|
utils_emu3.py
CHANGED
@@ -47,16 +47,22 @@ class Emu3PrefixConstrainedLogitsHelper:
|
|
47 |
position = torch.nonzero(input_ids == self.img_token, as_tuple=True)[0][0]
|
48 |
self.offset_cache[batch_id] = position
|
49 |
|
|
|
|
|
|
|
50 |
offset = input_ids.shape[0] - self.offset_cache[batch_id]
|
51 |
-
|
|
|
|
|
|
|
52 |
return (self.eol_token, )
|
53 |
-
elif offset == (
|
54 |
return (self.eof_token, )
|
55 |
-
elif offset == (
|
56 |
return (self.eoi_token, )
|
57 |
-
elif offset == (
|
58 |
return (self.eos_token, )
|
59 |
-
elif offset > (
|
60 |
return (self.pad_token, )
|
61 |
else:
|
62 |
return self.visual_tokens
|
|
|
47 |
position = torch.nonzero(input_ids == self.img_token, as_tuple=True)[0][0]
|
48 |
self.offset_cache[batch_id] = position
|
49 |
|
50 |
+
height = self.height[batch_id] if self.height.shape[0] > 1 else self.height[0]
|
51 |
+
width = self.width[batch_id] if self.width.shape[0] > 1 else self.width[0]
|
52 |
+
|
53 |
offset = input_ids.shape[0] - self.offset_cache[batch_id]
|
54 |
+
height = height.to(offset.device)
|
55 |
+
width = width.to(offset.device)
|
56 |
+
|
57 |
+
if offset % (width + 1) == 0:
|
58 |
return (self.eol_token, )
|
59 |
+
elif offset == (width + 1) * height + 1:
|
60 |
return (self.eof_token, )
|
61 |
+
elif offset == (width + 1) * height + 2:
|
62 |
return (self.eoi_token, )
|
63 |
+
elif offset == (width + 1) * height + 3:
|
64 |
return (self.eos_token, )
|
65 |
+
elif offset > (width + 1) * height + 3:
|
66 |
return (self.pad_token, )
|
67 |
else:
|
68 |
return self.visual_tokens
|