DrHakase skytnt committed on
Commit 17e9cef
0 Parent(s):

Duplicate from skytnt/full-body-anime-gan


Co-authored-by: skytnt <[email protected]>

Files changed (9)
  1. .gitattributes +27 -0
  2. .gitignore +117 -0
  3. README.md +14 -0
  4. app.py +364 -0
  5. examples/01.jpg +0 -0
  6. examples/02.jpg +0 -0
  7. examples/03.jpg +0 -0
  8. examples/04.jpg +0 -0
  9. requirements.txt +5 -0
.gitattributes ADDED
@@ -0,0 +1,27 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,117 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ .idea/
+ video.mp4
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Full Body Anime GAN
+ emoji: 😇
+ colorFrom: red
+ colorTo: gray
+ sdk: gradio
+ sdk_version: 3.9.1
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ duplicated_from: skytnt/full-body-anime-gan
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,364 @@
+ import random
+ import gradio as gr
+ import imageio
+ import numpy as np
+ import onnx
+ import onnxruntime as rt
+ import huggingface_hub
+ from numpy.random import RandomState
+ from skimage import transform
+
+
+ def get_inter(r1, r2):
+     h_inter = max(min(r1[3], r2[3]) - max(r1[1], r2[1]), 0)
+     w_inter = max(min(r1[2], r2[2]) - max(r1[0], r2[0]), 0)
+     return h_inter * w_inter
+
+
+ def iou(r1, r2):
+     s1 = (r1[2] - r1[0]) * (r1[3] - r1[1])
+     s2 = (r2[2] - r2[0]) * (r2[3] - r2[1])
+     i = get_inter(r1, r2)
+     return i / (s1 + s2 - i)
+
+
+ def letterbox(im, new_shape=(640, 640), color=(0.5, 0.5, 0.5), stride=32):
+     # Resize and pad image while meeting stride-multiple constraints
+     shape = im.shape[:2]  # current shape [height, width]
+
+     # Scale ratio (new / old)
+     r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+
+     # Compute padding
+     new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+     dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
+     dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
+
+     dw /= 2  # divide padding into 2 sides
+     dh /= 2
+
+     if shape[::-1] != new_unpad:  # resize (shape is (h, w); new_unpad is (w, h))
+         im = transform.resize(im, (new_unpad[1], new_unpad[0]))
+     top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+     left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+
+     im_new = np.full((new_unpad[1] + top + bottom, new_unpad[0] + left + right, 3), color, dtype=np.float32)
+     im_new[top:new_unpad[1] + top, left:new_unpad[0] + left] = im
+     return im_new
+
+
+ def nms(pred, conf_thres, iou_thres, max_instance=20):  # pred (anchor_num, 5 + cls_num)
+     nc = pred.shape[1] - 5
+     candidates = [list() for x in range(nc)]
+     for x in pred:
+         if x[4] < conf_thres:
+             continue
+         cls = np.argmax(x[5:])
+         p = x[4] * x[5 + cls]
+         if conf_thres <= p:
+             box = (x[0] - x[2] / 2, x[1] - x[3] / 2, x[0] + x[2] / 2, x[1] + x[3] / 2)  # xywh2xyxy
+             candidates[cls].append([p, box])
+     result = [list() for x in range(nc)]
+     for i, candidate in enumerate(candidates):
+         candidate = sorted(candidate, key=lambda a: a[0], reverse=True)
+         candidate = candidate[:max_instance]
+         for x in candidate:
+             ok = True
+             for r in result[i]:
+                 if iou(r[1], x[1]) > iou_thres:
+                     ok = False
+                     break
+             if ok:
+                 result[i].append(x)
+
+     return result
+
+
+ class Model:
+     def __init__(self):
+         self.detector = None
+         self.encoder = None
+         self.g_synthesis = None
+         self.g_mapping = None
+         self.detector_stride = None
+         self.detector_imgsz = None
+         self.detector_class_names = None
+         self.anime_seg = None
+         self.w_avg = None
+         self.load_models()
+
+     def load_models(self):
+         g_mapping_path = huggingface_hub.hf_hub_download("skytnt/fbanime-gan", "g_mapping.onnx")
+         g_synthesis_path = huggingface_hub.hf_hub_download("skytnt/fbanime-gan", "g_synthesis.onnx")
+         encoder_path = huggingface_hub.hf_hub_download("skytnt/fbanime-gan", "encoder.onnx")
+         detector_path = huggingface_hub.hf_hub_download("skytnt/fbanime-gan", "waifu_dect.onnx")
+         anime_seg_path = huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.onnx")
+
+         providers = ['CPUExecutionProvider']
+         gpu_providers = ['CUDAExecutionProvider']
+         g_mapping = onnx.load(g_mapping_path)
+         w_avg = [x for x in g_mapping.graph.initializer if x.name == "w_avg"][0]
+         w_avg = np.frombuffer(w_avg.raw_data, dtype=np.float32)[np.newaxis, :]
+         w_avg = w_avg.repeat(16, axis=0)[np.newaxis, :]  # broadcast to (1, 16, w_dim)
+         self.w_avg = w_avg
+         self.g_mapping = rt.InferenceSession(g_mapping_path, providers=gpu_providers + providers)
+         self.g_synthesis = rt.InferenceSession(g_synthesis_path, providers=gpu_providers + providers)
+         self.encoder = rt.InferenceSession(encoder_path, providers=providers)
+         self.detector = rt.InferenceSession(detector_path, providers=providers)
+         detector_meta = self.detector.get_modelmeta().custom_metadata_map
+         self.detector_stride = int(detector_meta['stride'])
+         self.detector_imgsz = 1088
+         self.detector_class_names = eval(detector_meta['names'])
+         self.anime_seg = rt.InferenceSession(anime_seg_path, providers=providers)
+
+     def get_img(self, w, noise=0):
+         img = self.g_synthesis.run(None, {'w': w, "noise": np.asarray([noise], dtype=np.float32)})[0]
+         return (img.transpose(0, 2, 3, 1) * 127.5 + 128).clip(0, 255).astype(np.uint8)[0]  # NCHW -> HWC uint8
+
+     def get_w(self, z, psi1, psi2):
+         return self.g_mapping.run(None, {'z': z, 'psi': np.asarray([psi1, psi2], dtype=np.float32)})[0]
+
+     def remove_bg(self, img, s=1024):
+         img0 = img
+         img = (img / 255).astype(np.float32)
+         h, w = h0, w0 = img.shape[:-1]
+         h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
+         ph, pw = s - h, s - w
+         img_input = np.zeros([s, s, 3], dtype=np.float32)
+         img_input[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = transform.resize(img, (h, w))
+         img_input = np.transpose(img_input, (2, 0, 1))
+         img_input = img_input[np.newaxis, :]
+         mask = self.anime_seg.run(None, {'img': img_input})[0][0]
+         mask = np.transpose(mask, (1, 2, 0))
+         mask = mask[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
+         mask = transform.resize(mask, (h0, w0))
+         img0 = (img0 * mask + 255 * (1 - mask)).astype(np.uint8)  # composite onto white
+         return img0
+
+     def encode_img(self, img):
+         img = transform.resize(((img / 255 - 0.5) / 0.5), (256, 256)).transpose(2, 0, 1)[np.newaxis, :].astype(
+             np.float32)
+         return self.encoder.run(None, {'img': img})[0] + self.w_avg
+
+     def detect(self, im0, conf_thres, iou_thres, detail=False):
+         if im0 is None:
+             return []
+         img = letterbox((im0 / 255).astype(np.float32), (self.detector_imgsz, self.detector_imgsz),
+                         stride=self.detector_stride)
+         # Convert
+         img = img.transpose(2, 0, 1)
+         img = img[np.newaxis, :]
+         pred = self.detector.run(None, {'images': img})[0][0]
+         dets = nms(pred, conf_thres, iou_thres)
+         imgs = []
+         # Print results
+         s = '%gx%g ' % img.shape[2:]  # print string
+         for i, det in enumerate(dets):
+             n = len(det)
+             s += f"{n} {self.detector_class_names[i]}{'s' * (n > 1)}, "  # add to string
+         if detail:
+             print(s)
+         waifu_rects = []
+         head_rects = []
+         body_rects = []
+
+         for i, det in enumerate(dets):
+             for x in det:
+                 # Rescale boxes from img_size to im0 size
+                 wr = im0.shape[1] / img.shape[3]
+                 hr = im0.shape[0] / img.shape[2]
+                 x[1] = (int(x[1][0] * wr), int(x[1][1] * hr),
+                         int(x[1][2] * wr), int(x[1][3] * hr))
+                 if i == 0:
+                     head_rects.append(x[1])
+                 elif i == 1:
+                     body_rects.append(x[1])
+                 elif i == 2:
+                     waifu_rects.append(x[1])
+         for j, waifu_rect in enumerate(waifu_rects):
+             msg = f'waifu {j + 1} '
+             head_num = 0
+             body_num = 0
+             hr, br = None, None
+             for r in head_rects:
+                 if get_inter(r, waifu_rect) / ((r[2] - r[0]) * (r[3] - r[1])) > 0.75:
+                     hr = r
+                     head_num += 1
+             if head_num != 1:
+                 if detail:
+                     print(msg + f'head num error: {head_num}')
+                 continue
+             for r in body_rects:
+                 if get_inter(r, waifu_rect) / ((r[2] - r[0]) * (r[3] - r[1])) > 0.65:
+                     br = r
+                     body_num += 1
+             if body_num != 1:
+                 if detail:
+                     print(msg + f'body num error: {body_num}')
+                 continue
+             bounds = (min(waifu_rect[0], hr[0], br[0]),
+                       min(waifu_rect[1], hr[1], br[1]),
+                       max(waifu_rect[2], hr[2], br[2]),
+                       max(waifu_rect[3], hr[3], br[3]))
+             if (bounds[2] - bounds[0]) / (bounds[3] - bounds[1]) > 0.7:
+                 if detail:
+                     print(msg + "ratio out of limit")
+                 continue
+             expand_pixel = (bounds[3] - bounds[1]) // 20
+             bounds = [max(bounds[0] - expand_pixel // 2, 0),
+                       max(bounds[1] - expand_pixel, 0),
+                       min(bounds[2] + expand_pixel // 2, im0.shape[1]),
+                       min(bounds[3] + expand_pixel, im0.shape[0]),
+                       ]
+             # crop and resize
+             w = bounds[2] - bounds[0]
+             h = bounds[3] - bounds[1]
+             bounds[3] += h % 2
+             h += h % 2
+             r = min(512 / w, 1024 / h)
+             pw, ph = int(512 / r - w), int(1024 / r - h)
+             bounds_tmp = (bounds[0] - pw // 2, bounds[1] - ph // 2,
+                           bounds[2] + pw // 2 + pw % 2, bounds[3] + ph // 2 + ph % 2)
+             bounds = (max(0, bounds_tmp[0]), max(0, bounds_tmp[1]),
+                       min(im0.shape[1], bounds_tmp[2]), min(im0.shape[0], bounds_tmp[3]))
+             dl = bounds[0] - bounds_tmp[0]
+             dr = bounds[2] - bounds_tmp[2]
+             dt = bounds[1] - bounds_tmp[1]
+             db = bounds[3] - bounds_tmp[3]
+             w = bounds_tmp[2] - bounds_tmp[0]
+             h = bounds_tmp[3] - bounds_tmp[1]
+             temp_img = np.full((h, w, 3), 255, dtype=np.uint8)
+             temp_img[dt:h + db, dl:w + dr] = im0[bounds[1]:bounds[3], bounds[0]:bounds[2]]
+             temp_img = transform.resize(temp_img, (1024, 512), preserve_range=True).astype(np.uint8)
+             imgs.append(temp_img)
+         return imgs
+
+     def gen_video(self, w1, w2, noise, path, frame_num=10):
+         video = imageio.get_writer(path, mode='I', fps=frame_num // 2, codec='libx264', bitrate='16M')
+         lin = np.linspace(0, 1, frame_num)
+         for i in range(0, frame_num):
+             img = self.get_img(((1 - lin[i]) * w1) + (lin[i] * w2), noise)  # linear interpolation in w space
+             video.append_data(img)
+         video.close()
+
+
+ def get_thumbnail(img):
+     img_new = np.full((256, 384, 3), 200, dtype=np.uint8)
+     img_new[:, 128:256] = transform.resize(img, (256, 128), preserve_range=True)
+     return img_new
+
+
+ def gen_fn(seed, random_seed, psi1, psi2, noise):
+     if random_seed:
+         seed = random.randint(0, 2 ** 32 - 1)
+     z = RandomState(int(seed)).randn(1, 1024)
+     w = model.get_w(z.astype(dtype=np.float32), psi1, psi2)
+     img_out = model.get_img(w, noise)
+     return img_out, seed, w, get_thumbnail(img_out)
+
+
+ def encode_img_fn(img, noise):
+     if img is None:
+         return "please upload an image", None, None, None, None
+     img = model.remove_bg(img)
+     imgs = model.detect(img, 0.2, 0.03)
+     if len(imgs) == 0:
+         return "failed to detect an anime character", None, None, None, None
+     w = model.encode_img(imgs[0])
+     img_out = model.get_img(w, noise)
+     return "success", imgs[0], img_out, w, get_thumbnail(img_out)
+
+
+ def gen_video_fn(w1, w2, noise, frame):
+     if w1 is None or w2 is None:
+         return None
+     model.gen_video(w1, w2, noise, "video.mp4", int(frame))
+     return "video.mp4"
+
+
+ if __name__ == '__main__':
+     model = Model()
+
+     app = gr.Blocks()
+     with app:
+         gr.Markdown("# full-body anime GAN\n\n"
+                     "![visitor badge](https://visitor-badge.glitch.me/badge?page_id=skytnt.full-body-anime-gan)\n\n")
+         with gr.Tabs():
+             with gr.TabItem("generate image"):
+                 with gr.Row():
+                     with gr.Column():
+                         gr.Markdown("generate image")
+                         with gr.Row():
+                             gen_input1 = gr.Slider(minimum=0, maximum=2 ** 32 - 1, step=1, value=0, label="seed")
+                             gen_input2 = gr.Checkbox(label="Random", value=True)
+                         gen_input3 = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.7, label="truncation psi 1")
+                         gen_input4 = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, label="truncation psi 2")
+                         gen_input5 = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, label="noise strength")
+                         with gr.Group():
+                             gen_submit = gr.Button("Generate", variant="primary")
+                     with gr.Column():
+                         gen_output1 = gr.Image(label="output image")
+                 select_img_input_w1 = gr.Variable()
+                 select_img_input_img1 = gr.Variable()
+
+             with gr.TabItem("encode image"):
+                 with gr.Row():
+                     with gr.Column():
+                         gr.Markdown("it's best to upload a standing full-body image")
+                         encode_img_input = gr.Image(label="input image")
+                         examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 5)]
+                         encode_img_examples = gr.Dataset(components=[encode_img_input], samples=examples_data)
+                         with gr.Group():
+                             encode_img_submit = gr.Button("Run", variant="primary")
+                     with gr.Column():
+                         encode_img_output1 = gr.Textbox(label="output message")
+                         with gr.Row():
+                             encode_img_output2 = gr.Image(label="detected")
+                             encode_img_output3 = gr.Image(label="encoded")
+                 select_img_input_w2 = gr.Variable()
+                 select_img_input_img2 = gr.Variable()
+
+             with gr.TabItem("generate video"):
+                 with gr.Row():
+                     with gr.Column():
+                         gr.Markdown("generate video between 2 images")
+                         with gr.Row():
+                             with gr.Column():
+                                 select_img1_dropdown = gr.Radio(label="Select image 1", value="current generated image",
+                                                                 choices=["current generated image",
+                                                                          "current encoded image"], type="index")
+                                 with gr.Group():
+                                     select_img1_button = gr.Button("Select", variant="primary")
+                                 select_img1_output_img = gr.Image(label="selected image 1")
+                                 select_img1_output_w = gr.Variable()
+                             with gr.Column():
+                                 select_img2_dropdown = gr.Radio(label="Select image 2", value="current generated image",
+                                                                 choices=["current generated image",
+                                                                          "current encoded image"], type="index")
+                                 with gr.Group():
+                                     select_img2_button = gr.Button("Select", variant="primary")
+                                 select_img2_output_img = gr.Image(label="selected image 2")
+                                 select_img2_output_w = gr.Variable()
+                         generate_video_frame = gr.Slider(minimum=10, maximum=30, step=1, label="frame", value=15)
+                         with gr.Group():
+                             generate_video_button = gr.Button("Generate", variant="primary")
+                     with gr.Column():
+                         generate_video_output = gr.Video(label="output video")
+         gen_submit.click(gen_fn, [gen_input1, gen_input2, gen_input3, gen_input4, gen_input5],
+                          [gen_output1, gen_input1, select_img_input_w1, select_img_input_img1])
+         encode_img_submit.click(encode_img_fn, [encode_img_input, gen_input5],
+                                 [encode_img_output1, encode_img_output2, encode_img_output3, select_img_input_w2,
+                                  select_img_input_img2])
+         encode_img_examples.click(lambda x: x[0], [encode_img_examples], [encode_img_input])
+         select_img1_button.click(lambda i, img1, img2, w1, w2: (img1, w1) if i == 0 else (img2, w2),
+                                  [select_img1_dropdown, select_img_input_img1, select_img_input_img2,
+                                   select_img_input_w1, select_img_input_w2],
+                                  [select_img1_output_img, select_img1_output_w])
+         select_img2_button.click(lambda i, img1, img2, w1, w2: (img1, w1) if i == 0 else (img2, w2),
+                                  [select_img2_dropdown, select_img_input_img1, select_img_input_img2,
+                                   select_img_input_w1, select_img_input_w2],
+                                  [select_img2_output_img, select_img2_output_w])
+         generate_video_button.click(gen_video_fn,
+                                     [select_img1_output_w, select_img2_output_w, gen_input5, generate_video_frame],
+                                     [generate_video_output])
+     app.launch()
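
The Model class above can also be driven without the Gradio UI. Below is a minimal sketch (not part of this commit) of the same pipeline run as a script: it assumes the checkpoints still download from skytnt/fbanime-gan, and that app.py sits in the working directory so `from app import Model` works — the `if __name__ == '__main__':` guard keeps the UI from launching on import.

    import numpy as np
    from numpy.random import RandomState
    import imageio

    from app import Model  # hypothetical import of the module added in this commit

    model = Model()  # downloads the ONNX checkpoints via huggingface_hub

    # Sample a latent, map it to w space with the two truncation psis, synthesize.
    z = RandomState(0).randn(1, 1024).astype(np.float32)
    w1 = model.get_w(z, 0.7, 1.0)
    img = model.get_img(w1, noise=1.0)  # HxWx3 uint8
    imageio.imwrite("sample.png", img)

    # Interpolate toward a second latent in w space, as gen_video does per frame.
    w2 = model.get_w(RandomState(1).randn(1, 1024).astype(np.float32), 0.7, 1.0)
    model.gen_video(w1, w2, 1.0, "morph.mp4", frame_num=15)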
examples/01.jpg ADDED
examples/02.jpg ADDED
examples/03.jpg ADDED
examples/04.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ onnx
+ onnxruntime-gpu
+ scikit-image
+ imageio-ffmpeg
+ huggingface_hub
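
Note that gradio itself is not pinned here: on Spaces it is supplied by the runtime according to the sdk/sdk_version fields in README.md, so a local run presumably also needs gradio (around 3.9.1) installed before `python app.py` will start the app.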