Men1scus commited on
Commit
88a948a
·
1 Parent(s): 05c79ef

fix: Update process_sr function to return Image instead of List[np.ndarray] and adjust output gallery type

Browse files
Files changed (1) hide show
  1. app.py +4 -6
app.py CHANGED
@@ -322,7 +322,7 @@ def process_sr(
322
  cfg_scale: float,
323
  seed: int,
324
  model_choice: str,
325
- ) -> List[np.ndarray]:
326
  process_size = 512
327
  resize_preproc = transforms.Compose([
328
  transforms.Resize(process_size, interpolation=transforms.InterpolationMode.BILINEAR),
@@ -349,8 +349,6 @@ def process_sr(
349
  input_image = input_image.resize((input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8))
350
  width, height = input_image.size
351
  resize_flag = True #
352
-
353
- images = []
354
 
355
  # Choose pipeline based on model selection - prioritize dit4sr_f
356
  if model_choice == "dit4sr_q" and pipeline_dit4sr_q is not None:
@@ -380,8 +378,8 @@ def process_sr(
380
  except Exception as e:
381
  print(f"Error during inference: {e}")
382
  image = Image.new(mode="RGB", size=(512, 512))
383
- images.append(np.array(image))
384
- return images
385
 
386
 
387
 
@@ -451,7 +449,7 @@ with block:
451
  scale_factor = gr.Number(label="SR Scale", value=4)
452
  gr.Examples(examples=exaple_images, inputs=[input_image])
453
  with gr.Column():
454
- result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery", columns=1)
455
  with gr.Row():
456
  run_llava_button = gr.Button(value="Run LLAVA")
457
  run_sr_button = gr.Button(value="Run DiT4SR")
 
322
  cfg_scale: float,
323
  seed: int,
324
  model_choice: str,
325
+ ) -> Image.Image:
326
  process_size = 512
327
  resize_preproc = transforms.Compose([
328
  transforms.Resize(process_size, interpolation=transforms.InterpolationMode.BILINEAR),
 
349
  input_image = input_image.resize((input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8))
350
  width, height = input_image.size
351
  resize_flag = True #
 
 
352
 
353
  # Choose pipeline based on model selection - prioritize dit4sr_f
354
  if model_choice == "dit4sr_q" and pipeline_dit4sr_q is not None:
 
378
  except Exception as e:
379
  print(f"Error during inference: {e}")
380
  image = Image.new(mode="RGB", size=(512, 512))
381
+
382
+ return image
383
 
384
 
385
 
 
449
  scale_factor = gr.Number(label="SR Scale", value=4)
450
  gr.Examples(examples=exaple_images, inputs=[input_image])
451
  with gr.Column():
452
+ result_gallery = gr.Image(label="Output", show_label=False, elem_id="gallery", type="pil", format="png")
453
  with gr.Row():
454
  run_llava_button = gr.Button(value="Run LLAVA")
455
  run_sr_button = gr.Button(value="Run DiT4SR")