YaohuiW commited on
Commit
3e81f41
·
verified ·
1 Parent(s): be1dbf7

Update gradio_tabs/img_edit.py

Browse files
Files changed (1) hide show
  1. gradio_tabs/img_edit.py +55 -18
gradio_tabs/img_edit.py CHANGED
@@ -42,6 +42,7 @@ def load_image(img, size):
42
  w, h = img.size
43
  img = img.resize((size, size))
44
  img = np.asarray(img)
 
45
  img = np.transpose(img, (2, 0, 1)) # 3 x 256 x 256
46
 
47
  return img / 255.0, w, h
@@ -55,40 +56,76 @@ def img_preprocessing(img_path, size):
55
  return imgs_norm, w, h
56
 
57
 
58
- def resize(img, size):
59
- transform = torchvision.transforms.Compose([
60
- torchvision.transforms.Resize((size,size), antialias=True),
61
- ])
62
 
63
- return transform(img)
64
 
65
 
66
- def resize_back(img, w, h):
67
- transform = torchvision.transforms.Compose([
68
- torchvision.transforms.Resize((h, w), antialias=True),
69
- ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
 
 
 
71
  return transform(img)
72
 
73
 
 
 
 
 
 
 
74
  def img_denorm(img):
75
- img = img.clamp(-1, 1).cpu()
76
  img = (img - img.min()) / (img.max() - img.min())
77
 
78
  return img
79
 
80
 
81
- def img_postprocessing(image, w, h):
 
 
 
 
 
 
 
 
 
 
82
 
83
- image = resize_back(image, w, h)
84
- image = image.permute(0, 2, 3, 1)
85
- edited_image = img_denorm(image)
86
- img_output = (edited_image[0].numpy() * 255).astype(np.uint8)
87
 
88
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
89
- imageio.imwrite(temp_file.name, img_output, quality=8)
90
- return temp_file.name
 
91
 
 
 
92
 
93
  def img_edit(gen, device):
94
 
 
42
  w, h = img.size
43
  img = img.resize((size, size))
44
  img = np.asarray(img)
45
+ img = np.copy(img)
46
  img = np.transpose(img, (2, 0, 1)) # 3 x 256 x 256
47
 
48
  return img / 255.0, w, h
 
56
  return imgs_norm, w, h
57
 
58
 
59
+ # def resize(img, size):
60
+ # transform = torchvision.transforms.Compose([
61
+ # torchvision.transforms.Resize((size,size), antialias=True),
62
+ # ])
63
 
64
+ # return transform(img)
65
 
66
 
67
+ # def resize_back(img, w, h):
68
+ # transform = torchvision.transforms.Compose([
69
+ # torchvision.transforms.Resize((h, w), antialias=True),
70
+ # ])
71
+
72
+ # return transform(img)
73
+
74
+ # Pre-compile resize transforms for better performance
75
+ resize_transform_cache = {}
76
+
77
+ def get_resize_transform(size):
78
+ """Get cached resize transform - creates once, reuses many times"""
79
+ if size not in resize_transform_cache:
80
+ # Only create the transform if it doesn't exist in cache
81
+ resize_transform_cache[size] = torchvision.transforms.Resize(
82
+ size,
83
+ interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
84
+ antialias=True
85
+ )
86
+ return resize_transform_cache[size]
87
+
88
 
89
+ def resize(img, size):
90
+ """Use cached resize transform"""
91
+ transform = get_resize_transform((size, size))
92
  return transform(img)
93
 
94
 
95
+ def resize_back(img, w, h):
96
+ """Use cached resize transform for back operation"""
97
+ transform = get_resize_transform((h, w))
98
+ return transform(img)
99
+
100
+
101
  def img_denorm(img):
102
+ img = img.clamp(-1, 1)
103
  img = (img - img.min()) / (img.max() - img.min())
104
 
105
  return img
106
 
107
 
108
+ # def img_postprocessing(image, w, h):
109
+
110
+ # image = resize_back(image, w, h)
111
+ # image = image.permute(0, 2, 3, 1)
112
+ # edited_image = img_denorm(image)
113
+ # img_output = (edited_image[0].numpy() * 255).astype(np.uint8)
114
+
115
+ # with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
116
+ # imageio.imwrite(temp_file.name, img_output, quality=8)
117
+ # return temp_file.name
118
+
119
 
120
+ def img_postprocessing(img, w, h):
 
 
 
121
 
122
+ img = resize_back(img, w, h)
123
+ img = img_denorm(img)
124
+ img = img.squeeze(0).permute(1, 2, 0).contiguous() # contiguous() for fast transfer
125
+ img_output = (img.cpu().numpy() * 255).astype(np.uint8)
126
 
127
+ return img_output
128
+
129
 
130
  def img_edit(gen, device):
131