Spaces: Running on Zero
correct sizing, and limit
app.py CHANGED
@@ -1,6 +1,5 @@
 import gradio as gr
 import numpy as np
-import random
 import spaces
 from PIL import Image
 import torch
@@ -24,63 +23,23 @@ model = AutoModel.from_pretrained(
 pipeline = NextStepPipeline(tokenizer=tokenizer, model=model).to(device=device, dtype=torch.bfloat16)
 
 MAX_SEED = np.iinfo(np.int16).max
-
 DEFAULT_POSITIVE_PROMPT = None
 DEFAULT_NEGATIVE_PROMPT = None
 
 def _ensure_pil(x):
+    """Ensure returned image is a PIL.Image.Image."""
     if isinstance(x, Image.Image):
         return x
-    …
-            return Image.fromarray(x)
-    except Exception:
-        pass
-    raise TypeError("Unsupported image type returned by pipeline; expected PIL or array/torch image.")
-
-def resize_to_target(img: Image.Image, tw: int, th: int, mode: str = "fit"):
-    """Return a PIL image of exactly (tw, th) using the selected mode."""
-    mode = (mode or "fit").lower()
-    # safety
-    tw = int(max(1, tw))
-    th = int(max(1, th))
-
-    if mode == "stretch":
-        return img.resize((tw, th), resample=Image.Resampling.LANCZOS)
-
-    iw, ih = img.size
-    if iw == 0 or ih == 0:
-        return img
-
-    src_ratio = iw / ih
-    tgt_ratio = tw / th
-
-    if mode == "fill":
-        # scale so that image fully covers target, then center-crop
-        scale = max(tw / iw, th / ih)
-        nw, nh = int(round(iw * scale)), int(round(ih * scale))
-        resized = img.resize((nw, nh), resample=Image.Resampling.LANCZOS)
-        left = (nw - tw) // 2
-        top = (nh - th) // 2
-        return resized.crop((left, top, left + tw, top + th))
-    else:
-        # "fit": letterbox to target
-        scale = min(tw / iw, th / ih)
-        nw, nh = int(round(iw * scale)), int(round(ih * scale))
-        resized = img.resize((nw, nh), resample=Image.Resampling.LANCZOS)
-        canvas = Image.new("RGB", (tw, th), (0, 0, 0))
-        left = (tw - nw) // 2
-        top = (th - nh) // 2
-        canvas.paste(resized, (left, top))
-        return canvas
+    import numpy as np
+    if hasattr(x, "detach"):
+        x = x.detach().float().clamp(0, 1).cpu().numpy()
+    if isinstance(x, np.ndarray):
+        if x.dtype != np.uint8:
+            x = (x * 255.0).clip(0, 255).astype(np.uint8)
+        if x.ndim == 3 and x.shape[0] in (1, 3, 4):  # CHW -> HWC
+            x = np.moveaxis(x, 0, -1)
+        return Image.fromarray(x)
+    raise TypeError("Unsupported image type returned by pipeline.")
 
 @spaces.GPU(duration=300)
 def infer(
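For context, the new `_ensure_pil` accepts whatever `pipeline.generate_image` returns: a PIL image, a numpy array, or a torch tensor. A minimal sketch of how it behaves on a CHW float tensor, assuming `_ensure_pil` as defined above is in scope (the tensor here is a stand-in, not real pipeline output):

```python
import torch
from PIL import Image

# Stand-in for a pipeline output: 3x256x256 CHW float tensor in [0, 1].
fake_output = torch.rand(3, 256, 256)

img = _ensure_pil(fake_output)  # detach -> numpy -> uint8 -> CHW->HWC -> PIL
assert isinstance(img, Image.Image)
assert img.size == (256, 256)   # PIL reports (width, height)
```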
@@ -91,14 +50,13 @@ def infer(
     num_inference_steps=28,
     positive_prompt=DEFAULT_POSITIVE_PROMPT,
     negative_prompt=DEFAULT_NEGATIVE_PROMPT,
-    resize_mode="fit (letterbox)",
     progress=gr.Progress(track_tqdm=True),
 ):
+    """Run inference at exactly (width, height)."""
     if prompt in [None, ""]:
         gr.Warning("⚠️ Please enter a prompt!")
         return None
 
-    # Generate at (height, width). Some models may return bucketed sizes.
     with autocast(device_type=("cuda" if device == "cuda" else "cpu"), dtype=torch.bfloat16):
         imgs = pipeline.generate_image(
             prompt,
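The generation call runs inside `autocast` with bfloat16. A standalone sketch of the same mixed-precision pattern (assuming a recent PyTorch; on CPU, autocast-eligible ops also run in bfloat16):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Mirrors the Space's infer(): eligible ops inside the context run in bfloat16.
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    x = torch.randn(8, 8, device=device)
    y = x @ x
print(y.dtype)  # torch.bfloat16 for autocast-eligible ops like matmul
```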
@@ -116,23 +74,18 @@ def infer(
         progress=True,
     )
 
-
-
-    # Force output to exactly Width x Height based on user preference
-    mode_key = "fit" if "fit" in resize_mode else ("fill" if "fill" in resize_mode else "stretch")
-    out = resize_to_target(img, int(width), int(height), mode=mode_key)
-    return out
+    return _ensure_pil(imgs[0])  # Return raw output exactly as generated
 
 css = """
 #col-container {
     margin: 0 auto;
-    max-width:
+    max-width: 800px;
 }
 """
 
 with gr.Blocks(css=css) as demo:
     with gr.Column(elem_id="col-container"):
-        gr.Markdown("# NextStep-1-Large —
+        gr.Markdown("# NextStep-1-Large — Exact Output Size")
 
         with gr.Row():
             prompt = gr.Text(
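With the post-resize gone, the Space now returns the pipeline's raw output. If the deleted fit/fill/stretch behavior is still wanted client-side, Pillow's `ImageOps` covers it; a rough equivalent of the removed `resize_to_target`, with the black letterbox color carried over from the old code as an assumption:

```python
from PIL import Image, ImageOps

def resize_like_removed(img: Image.Image, tw: int, th: int, mode: str = "fit") -> Image.Image:
    """Approximate the deleted resize_to_target using ImageOps."""
    if mode == "fill":
        # Cover the target, then center-crop (the old "fill" branch).
        return ImageOps.fit(img, (tw, th), method=Image.Resampling.LANCZOS)
    if mode == "stretch":
        # Ignore aspect ratio (the old "stretch" branch).
        return img.resize((tw, th), resample=Image.Resampling.LANCZOS)
    # "fit": letterbox onto a black canvas (the old default branch).
    return ImageOps.pad(img.convert("RGB"), (tw, th), method=Image.Resampling.LANCZOS, color=(0, 0, 0))
```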
@@ -180,26 +133,19 @@ with gr.Blocks(css=css) as demo:
            width = gr.Slider(
                label="Width",
                minimum=256,
-                maximum=
+                maximum=512,
                step=64,
-                value=
+                value=512,
            )
            height = gr.Slider(
                label="Height",
                minimum=256,
-                maximum=
+                maximum=512,
                step=64,
-                value=
+                value=512,
            )
-            resize_mode = gr.Radio(
-                label="Resize mode (final output)",
-                choices=["fit (letterbox)", "fill (center-crop)", "stretch"],
-                value="fit (letterbox)",
-            )
 
        with gr.Row():
-            # Remove fixed height so the component can display any size; it will scale in the UI,
-            # but the returned image file is exactly width x height.
            result_1 = gr.Image(
                label="Result",
                show_label=True,
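Both sliders now top out at 512, so with `step=64` only a handful of resolutions are reachable:

```python
# Slider values reachable with minimum=256, maximum=512, step=64.
sizes = list(range(256, 512 + 1, 64))
print(sizes)  # [256, 320, 384, 448, 512]
```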
@@ -208,29 +154,25 @@ with gr.Blocks(css=css) as demo:
                format="png",
            )
 
-        #
+        # Click & Fill Examples (all <=512px)
        examples = [
-            # [prompt, seed, width, height, steps, positive, negative, resize_mode]
            [
-                "
-
-                "
-                "
-                "fit (letterbox)",
+                "A cozy wooden cabin by a frozen lake, northern lights in the sky",
+                123, 512, 512, 28,
+                "photorealistic, cinematic lighting, starry night, glowing reflections",
+                "low-res, distorted, extra objects"
            ],
            [
-                "
-
-                "
-                "
-                "fill (center-crop)",
+                "Futuristic city skyline at sunset, flying cars, neon reflections",
+                456, 512, 384, 30,
+                "detailed, vibrant, cinematic, sharp edges",
+                "washed out, cartoon, blurry"
            ],
            [
-                "
-
-                "
-                "
-                "stretch",
+                "Close-up of a rare orchid in a greenhouse with soft morning light",
+                789, 384, 512, 32,
+                "macro lens effect, ultra-detailed petals, dew drops",
+                "grainy, noisy, oversaturated"
            ],
        ]
 
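Each example row must line up positionally with the `inputs` list wired up below (the `resize_mode` column is gone from both). A minimal sketch of that contract, with illustrative components rather than the Space's real ones:

```python
import gradio as gr

with gr.Blocks() as demo:
    prompt = gr.Text()
    seed = gr.Slider(0, 32767, step=1)
    width = gr.Slider(256, 512, step=64)
    height = gr.Slider(256, 512, step=64)
    steps = gr.Slider(1, 50, step=1)
    positive = gr.Text()
    negative = gr.Text()
    # Row order must match inputs order: [prompt, seed, width, height, steps, positive, negative]
    gr.Examples(
        examples=[["a test prompt", 123, 512, 512, 28, "crisp", "blurry"]],
        inputs=[prompt, seed, width, height, steps, positive, negative],
    )
```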
@@ -244,9 +186,8 @@ with gr.Blocks(css=css) as demo:
                num_inference_steps,
                positive_prompt,
                negative_prompt,
-                resize_mode,
            ],
-            label="Click & Fill Examples",
+            label="Click & Fill Examples (Exact Size)",
        )
 
        def show_result():
@@ -263,7 +204,6 @@ with gr.Blocks(css=css) as demo:
                num_inference_steps,
                positive_prompt,
                negative_prompt,
-                resize_mode,
            ],
            outputs=[result_1],
        )
@@ -271,5 +211,4 @@ with gr.Blocks(css=css) as demo:
    cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[generation_event])
 
 if __name__ == "__main__":
-    # Set share=True if you want a public link
     demo.launch()
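Unrelated to the sizing change, the cancel wiring at the end is the standard Gradio pattern: passing an event handle to `cancels` aborts the in-flight run. A self-contained sketch:

```python
import time
import gradio as gr

def slow_job():
    time.sleep(30)
    return "done"

with gr.Blocks() as demo:
    out = gr.Textbox()
    go = gr.Button("Generate")
    stop = gr.Button("Cancel")
    evt = go.click(fn=slow_job, outputs=out)
    # fn=None means the click only cancels; no Python callback runs.
    stop.click(fn=None, cancels=[evt])

demo.launch()
```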