TDN-M commited on
Commit
baebf95
·
verified ·
1 Parent(s): 1e2c4b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -16
app.py CHANGED
@@ -26,7 +26,7 @@ def dict2namespace(config):
26
  setattr(namespace, key, new_value)
27
  return namespace
28
 
29
- def load_img (filename, norm=True,):
30
  img = np.array(Image.open(filename).convert("RGB"))
31
  h, w = img.shape[:2]
32
 
@@ -39,20 +39,19 @@ def load_img (filename, norm=True,):
39
  img = img.astype(np.float32)
40
  return img
41
 
42
- def process_img (image):
43
  img = np.array(image)
44
  img = img / 255.
45
  img = img.astype(np.float32)
46
- y = torch.tensor(img).permute(2,0,1).unsqueeze(0).to(device)
47
 
48
  with torch.no_grad():
49
  x_hat = model(y)
50
 
51
- restored_img = x_hat.squeeze().permute(1,2,0).clamp_(0, 1).cpu().detach().numpy()
52
- restored_img = np.clip(restored_img, 0. , 1.)
53
 
54
  restored_img = (restored_img * 255.0).round().astype(np.uint8) # float32 to uint8
55
- #return Image.fromarray(restored_img) #
56
  return (image, Image.fromarray(restored_img))
57
 
58
  def load_network(net, load_path, strict=True, param_key='params'):
@@ -87,23 +86,21 @@ model = seemore.SeemoRe(scale=cfg.model.scale, in_chans=cfg.model.in_chans,
87
  recursive=cfg.model.recursive, lr_space=cfg.model.lr_space, topk=cfg.model.topk)
88
 
89
  model = model.to(device)
90
- print ("IMAGE MODEL CKPT:", MODEL_NAME)
91
  load_network(model, MODEL_NAME, strict=True, param_key='params')
92
 
93
-
94
-
95
-
96
- title = "Enhane Quality"
97
-
98
-
99
 
100
  demo = gr.Interface(
101
  fn=process_img,
102
- inputs=[gr.Image(type="pil", label="Input", value="images/0878x4.png"),],
103
  outputs=ImageSlider(label="Super-Resolved Image",
104
  type="pil",
105
- show_download_button=True,
106
- ), #[gr.Image(type="pil", label="Ouput", min_width=500)],
107
  title=title,
108
  description=description,
109
  article=article,
 
26
  setattr(namespace, key, new_value)
27
  return namespace
28
 
29
+ def load_img(filename, norm=True):
30
  img = np.array(Image.open(filename).convert("RGB"))
31
  h, w = img.shape[:2]
32
 
 
39
  img = img.astype(np.float32)
40
  return img
41
 
42
+ def process_img(image):
43
  img = np.array(image)
44
  img = img / 255.
45
  img = img.astype(np.float32)
46
+ y = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(device)
47
 
48
  with torch.no_grad():
49
  x_hat = model(y)
50
 
51
+ restored_img = x_hat.squeeze().permute(1, 2, 0).clamp_(0, 1).cpu().detach().numpy()
52
+ restored_img = np.clip(restored_img, 0., 1.)
53
 
54
  restored_img = (restored_img * 255.0).round().astype(np.uint8) # float32 to uint8
 
55
  return (image, Image.fromarray(restored_img))
56
 
57
  def load_network(net, load_path, strict=True, param_key='params'):
 
86
  recursive=cfg.model.recursive, lr_space=cfg.model.lr_space, topk=cfg.model.topk)
87
 
88
  model = model.to(device)
89
+ print("IMAGE MODEL CKPT:", MODEL_NAME)
90
  load_network(model, MODEL_NAME, strict=True, param_key='params')
91
 
92
+ title = "Enhance Quality"
93
+ description = "This application enhances the quality of images using a super-resolution model."
94
+ article = "This is an article about the application."
95
+ examples = [["images/0878x4.png"]]
96
+ css = None
 
97
 
98
  demo = gr.Interface(
99
  fn=process_img,
100
+ inputs=[gr.Image(type="pil", label="Input", value="images/0878x4.png")],
101
  outputs=ImageSlider(label="Super-Resolved Image",
102
  type="pil",
103
+ show_download_button=True),
 
104
  title=title,
105
  description=description,
106
  article=article,