bmarci committed on
Commit
3669017
·
1 Parent(s): 940ab95

correct sizing, and limit

Browse files
Files changed (1) hide show
  1. app.py +33 -94
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import gradio as gr
2
  import numpy as np
3
- import random
4
  import spaces
5
  from PIL import Image
6
  import torch
@@ -24,63 +23,23 @@ model = AutoModel.from_pretrained(
24
  pipeline = NextStepPipeline(tokenizer=tokenizer, model=model).to(device=device, dtype=torch.bfloat16)
25
 
26
  MAX_SEED = np.iinfo(np.int16).max
27
-
28
  DEFAULT_POSITIVE_PROMPT = None
29
  DEFAULT_NEGATIVE_PROMPT = None
30
 
31
  def _ensure_pil(x):
 
32
  if isinstance(x, Image.Image):
33
  return x
34
- # try common conversions (numpy / torch -> PIL)
35
- try:
36
- import numpy as np
37
- if hasattr(x, "detach"):
38
- x = x.detach().float().clamp(0, 1).cpu().numpy()
39
- if isinstance(x, np.ndarray):
40
- if x.dtype != np.uint8:
41
- x = (x * 255.0).clip(0, 255).astype(np.uint8)
42
- if x.ndim == 3 and x.shape[0] in (1,3,4): # CHW -> HWC
43
- x = np.moveaxis(x, 0, -1)
44
- return Image.fromarray(x)
45
- except Exception:
46
- pass
47
- raise TypeError("Unsupported image type returned by pipeline; expected PIL or array/torch image.")
48
-
49
- def resize_to_target(img: Image.Image, tw: int, th: int, mode: str = "fit"):
50
- """Return a PIL image of exactly (tw, th) using the selected mode."""
51
- mode = (mode or "fit").lower()
52
- # safety
53
- tw = int(max(1, tw))
54
- th = int(max(1, th))
55
-
56
- if mode == "stretch":
57
- return img.resize((tw, th), resample=Image.Resampling.LANCZOS)
58
-
59
- iw, ih = img.size
60
- if iw == 0 or ih == 0:
61
- return img
62
-
63
- src_ratio = iw / ih
64
- tgt_ratio = tw / th
65
-
66
- if mode == "fill":
67
- # scale so that image fully covers target, then center-crop
68
- scale = max(tw / iw, th / ih)
69
- nw, nh = int(round(iw * scale)), int(round(ih * scale))
70
- resized = img.resize((nw, nh), resample=Image.Resampling.LANCZOS)
71
- left = (nw - tw) // 2
72
- top = (nh - th) // 2
73
- return resized.crop((left, top, left + tw, top + th))
74
- else:
75
- # "fit": letterbox to target
76
- scale = min(tw / iw, th / ih)
77
- nw, nh = int(round(iw * scale)), int(round(ih * scale))
78
- resized = img.resize((nw, nh), resample=Image.Resampling.LANCZOS)
79
- canvas = Image.new("RGB", (tw, th), (0, 0, 0))
80
- left = (tw - nw) // 2
81
- top = (th - nh) // 2
82
- canvas.paste(resized, (left, top))
83
- return canvas
84
 
85
  @spaces.GPU(duration=300)
86
  def infer(
@@ -91,14 +50,13 @@ def infer(
91
  num_inference_steps=28,
92
  positive_prompt=DEFAULT_POSITIVE_PROMPT,
93
  negative_prompt=DEFAULT_NEGATIVE_PROMPT,
94
- resize_mode="fit (letterbox)",
95
  progress=gr.Progress(track_tqdm=True),
96
  ):
 
97
  if prompt in [None, ""]:
98
  gr.Warning("⚠️ Please enter a prompt!")
99
  return None
100
 
101
- # Generate at (height, width). Some models may return bucketed sizes.
102
  with autocast(device_type=("cuda" if device == "cuda" else "cpu"), dtype=torch.bfloat16):
103
  imgs = pipeline.generate_image(
104
  prompt,
@@ -116,23 +74,18 @@ def infer(
116
  progress=True,
117
  )
118
 
119
- img = _ensure_pil(imgs[0])
120
-
121
- # Force output to exactly Width x Height based on user preference
122
- mode_key = "fit" if "fit" in resize_mode else ("fill" if "fill" in resize_mode else "stretch")
123
- out = resize_to_target(img, int(width), int(height), mode=mode_key)
124
- return out
125
 
126
  css = """
127
  #col-container {
128
  margin: 0 auto;
129
- max-width: 900px;
130
  }
131
  """
132
 
133
  with gr.Blocks(css=css) as demo:
134
  with gr.Column(elem_id="col-container"):
135
- gr.Markdown("# NextStep-1-Large — Edit & Size-Adaptive Output")
136
 
137
  with gr.Row():
138
  prompt = gr.Text(
@@ -180,26 +133,19 @@ with gr.Blocks(css=css) as demo:
180
  width = gr.Slider(
181
  label="Width",
182
  minimum=256,
183
- maximum=1536,
184
  step=64,
185
- value=768,
186
  )
187
  height = gr.Slider(
188
  label="Height",
189
  minimum=256,
190
- maximum=1536,
191
  step=64,
192
- value=1024,
193
  )
194
- resize_mode = gr.Radio(
195
- label="Resize mode (final output)",
196
- choices=["fit (letterbox)", "fill (center-crop)", "stretch"],
197
- value="fit (letterbox)",
198
- )
199
 
200
  with gr.Row():
201
- # Remove fixed height so the component can display any size; it will scale in the UI,
202
- # but the returned image file is exactly width x height.
203
  result_1 = gr.Image(
204
  label="Result",
205
  show_label=True,
@@ -208,29 +154,25 @@ with gr.Blocks(css=css) as demo:
208
  format="png",
209
  )
210
 
211
- # --- Click & Fill Examples ---
212
  examples = [
213
- # [prompt, seed, width, height, steps, positive, negative, resize_mode]
214
  [
215
- "Sunrise over terraced rice fields, mist in the valley, lone farmer with conical hat",
216
- 101, 832, 1216, 28,
217
- "soft god rays, crisp details, photorealistic, golden hour",
218
- "blurry, over-saturated, artifacts",
219
- "fit (letterbox)",
220
  ],
221
  [
222
- "Glass lighthouse on a stormy cliff, waves crashing, bioluminescent algae trails",
223
- 777, 1024, 768, 32,
224
- "cinematic lighting, long exposure water, detailed foam",
225
- "cartoon, low-res, extra limbs",
226
- "fill (center-crop)",
227
  ],
228
  [
229
- "Ancient stone bridge in a mossy ravine, waterfalls and hanging lanterns at dusk",
230
- 3407, 1024, 1024, 30,
231
- "volumetric fog, wet stone microtexture, realistic vegetation",
232
- "banding, washed-out, text",
233
- "stretch",
234
  ],
235
  ]
236
 
@@ -244,9 +186,8 @@ with gr.Blocks(css=css) as demo:
244
  num_inference_steps,
245
  positive_prompt,
246
  negative_prompt,
247
- resize_mode,
248
  ],
249
- label="Click & Fill Examples",
250
  )
251
 
252
  def show_result():
@@ -263,7 +204,6 @@ with gr.Blocks(css=css) as demo:
263
  num_inference_steps,
264
  positive_prompt,
265
  negative_prompt,
266
- resize_mode,
267
  ],
268
  outputs=[result_1],
269
  )
@@ -271,5 +211,4 @@ with gr.Blocks(css=css) as demo:
271
  cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[generation_event])
272
 
273
  if __name__ == "__main__":
274
- # Set share=True if you want a public link
275
  demo.launch()
 
1
  import gradio as gr
2
  import numpy as np
 
3
  import spaces
4
  from PIL import Image
5
  import torch
 
23
  pipeline = NextStepPipeline(tokenizer=tokenizer, model=model).to(device=device, dtype=torch.bfloat16)
24
 
25
  MAX_SEED = np.iinfo(np.int16).max
 
26
  DEFAULT_POSITIVE_PROMPT = None
27
  DEFAULT_NEGATIVE_PROMPT = None
28
 
29
  def _ensure_pil(x):
30
+ """Ensure returned image is a PIL.Image.Image."""
31
  if isinstance(x, Image.Image):
32
  return x
33
+ import numpy as np
34
+ if hasattr(x, "detach"):
35
+ x = x.detach().float().clamp(0, 1).cpu().numpy()
36
+ if isinstance(x, np.ndarray):
37
+ if x.dtype != np.uint8:
38
+ x = (x * 255.0).clip(0, 255).astype(np.uint8)
39
+ if x.ndim == 3 and x.shape[0] in (1,3,4): # CHW -> HWC
40
+ x = np.moveaxis(x, 0, -1)
41
+ return Image.fromarray(x)
42
+ raise TypeError("Unsupported image type returned by pipeline.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  @spaces.GPU(duration=300)
45
  def infer(
 
50
  num_inference_steps=28,
51
  positive_prompt=DEFAULT_POSITIVE_PROMPT,
52
  negative_prompt=DEFAULT_NEGATIVE_PROMPT,
 
53
  progress=gr.Progress(track_tqdm=True),
54
  ):
55
+ """Run inference at exactly (width, height)."""
56
  if prompt in [None, ""]:
57
  gr.Warning("⚠️ Please enter a prompt!")
58
  return None
59
 
 
60
  with autocast(device_type=("cuda" if device == "cuda" else "cpu"), dtype=torch.bfloat16):
61
  imgs = pipeline.generate_image(
62
  prompt,
 
74
  progress=True,
75
  )
76
 
77
+ return _ensure_pil(imgs[0]) # Return raw output exactly as generated
 
 
 
 
 
78
 
79
  css = """
80
  #col-container {
81
  margin: 0 auto;
82
+ max-width: 800px;
83
  }
84
  """
85
 
86
  with gr.Blocks(css=css) as demo:
87
  with gr.Column(elem_id="col-container"):
88
+ gr.Markdown("# NextStep-1-Large — Exact Output Size")
89
 
90
  with gr.Row():
91
  prompt = gr.Text(
 
133
  width = gr.Slider(
134
  label="Width",
135
  minimum=256,
136
+ maximum=512,
137
  step=64,
138
+ value=512,
139
  )
140
  height = gr.Slider(
141
  label="Height",
142
  minimum=256,
143
+ maximum=512,
144
  step=64,
145
+ value=512,
146
  )
 
 
 
 
 
147
 
148
  with gr.Row():
 
 
149
  result_1 = gr.Image(
150
  label="Result",
151
  show_label=True,
 
154
  format="png",
155
  )
156
 
157
+ # Click & Fill Examples (all <=512px)
158
  examples = [
 
159
  [
160
+ "A cozy wooden cabin by a frozen lake, northern lights in the sky",
161
+ 123, 512, 512, 28,
162
+ "photorealistic, cinematic lighting, starry night, glowing reflections",
163
+ "low-res, distorted, extra objects"
 
164
  ],
165
  [
166
+ "Futuristic city skyline at sunset, flying cars, neon reflections",
167
+ 456, 512, 384, 30,
168
+ "detailed, vibrant, cinematic, sharp edges",
169
+ "washed out, cartoon, blurry"
 
170
  ],
171
  [
172
+ "Close-up of a rare orchid in a greenhouse with soft morning light",
173
+ 789, 384, 512, 32,
174
+ "macro lens effect, ultra-detailed petals, dew drops",
175
+ "grainy, noisy, oversaturated"
 
176
  ],
177
  ]
178
 
 
186
  num_inference_steps,
187
  positive_prompt,
188
  negative_prompt,
 
189
  ],
190
+ label="Click & Fill Examples (Exact Size)",
191
  )
192
 
193
  def show_result():
 
204
  num_inference_steps,
205
  positive_prompt,
206
  negative_prompt,
 
207
  ],
208
  outputs=[result_1],
209
  )
 
211
  cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[generation_event])
212
 
213
  if __name__ == "__main__":
 
214
  demo.launch()