Update app.py

app.py CHANGED
@@ -36,6 +36,7 @@ else:
 from torchvision.transforms import Compose, Lambda, Normalize
 from torchvision.io.video import read_video
 import argparse
+from PIL import Image
 
 from common.distributed import (
     get_device,
@@ -62,6 +63,7 @@ from urllib.parse import urlparse
 from torch.hub import download_url_to_file, get_dir
 import shlex
 import uuid
+import mimetypes
 
 
 os.environ["MASTER_ADDR"] = "127.0.0.1"
@@ -76,7 +78,7 @@ subprocess.run(
 )
 
 def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
-    """Load file form http url, will download models if necessary.
+    """Load file from http url, will download models if necessary.
 
     Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
 
@@ -225,29 +227,6 @@ def generation_step(runner, text_embeds_dict, cond_latents):
 @spaces.GPU(duration=100)
 def generation_loop(video_path='./test_videos', seed=666, fps_out=12, batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, res_h=1280, res_w=720, sp_size=1):
     runner = configure_runner(1)
-    output_dir = 'output/' + str(uuid.uuid4()) + '.mp4'
-    def _build_pos_and_neg_prompt():
-        # read positive prompt
-        positive_text = "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, \
-            hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, \
-            skin pore detailing, hyper sharpness, perfect without deformations."
-        # read negative prompt
-        negative_text = "painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, \
-            CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, \
-            signature, jpeg artifacts, deformed, lowres, over-smooth"
-        return positive_text, negative_text
-
-    def _build_test_prompts(video_path):
-        positive_text, negative_text = _build_pos_and_neg_prompt()
-        original_videos = []
-        prompts = {}
-        video_list = os.listdir(video_path)
-        for f in video_list:
-            # if f.endswith(".mp4"):
-            original_videos.append(f)
-            prompts[f] = positive_text
-        print(f"Total prompts to be generated: {len(original_videos)}")
-        return original_videos, prompts, negative_text
 
     def _extract_text_embeds():
         # Text encoder forward.
@@ -294,7 +273,6 @@ def generation_loop(video_path='./test_videos', seed=666, fps_out=12, batch_size
     # set random seed
     set_seed(seed, same_across_ranks=True)
     os.makedirs('output/', exist_ok=True)
-    tgt_path = 'output/'
 
     # get test prompts
     original_videos = [video_path.split('/')[-1]]
@@ -331,13 +309,24 @@ def generation_loop(video_path='./test_videos', seed=666, fps_out=12, batch_size
     # read condition latents
     cond_latents = []
     for video in videos:
-        video = (
-            read_video(
-                os.path.join(video_path), output_format="TCHW"
-            )[0]
-            / 255.0
-        )
-        print(f"Read video size: {video.size()}")
+        media_type, _ = mimetypes.guess_type(video_path)
+        is_image = media_type and media_type.startswith("image")
+        is_video = media_type and media_type.startswith("video")
+        if is_video:
+            video = (
+                read_video(
+                    os.path.join(video_path), output_format="TCHW"
+                )[0]
+                / 255.0
+            )
+            print(f"Read video size: {video.size()}")
+            output_dir = 'output/' + str(uuid.uuid4()) + '.mp4'
+        else:
+            img = Image.open(video_path).convert("RGB")
+            img_tensor = T.ToTensor()(img).unsqueeze(0)  # (1, C, H, W)
+            video = img_tensor.permute(0, 1, 2, 3)  # (T=1, C, H, W)
+            print(f"Read Image size: {video.size()}")
+            output_dir = 'output/' + str(uuid.uuid4()) + '.png'
         cond_latents.append(video_transform(video.to(torch.device("cuda"))))
 
     ori_lengths = [video.size(1) for video in cond_latents]
@@ -386,14 +375,20 @@ def generation_loop(video_path='./test_videos', seed=666, fps_out=12, batch_size
     sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
     sample = sample.to(torch.uint8).numpy()
 
-    mediapy.write_video(
-        output_dir, sample, fps=fps_out
-    )
+    if is_image:
+        mediapy.write_image(output_dir, sample[0])
+    else:
+        mediapy.write_video(
+            output_dir, sample, fps=fps_out
+        )
 
     # print(f"Generated video size: {sample.shape}")
     gc.collect()
     torch.cuda.empty_cache()
-    return output_dir, output_dir
+    if is_image:
+        return output_dir, None, output_dir
+    else:
+        return None, output_dir, output_dir
 
 
 with gr.Blocks(title="SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training") as demo:
with gr.Blocks(title="SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training") as demo:
|
@@ -411,16 +406,17 @@ with gr.Blocks(title="SeedVR2: One-Step Video Restoration via Diffusion Adversar
|
|
411 |
|
412 |
# Interface
|
413 |
with gr.Row():
|
414 |
-
input_video = gr.
|
415 |
seed = gr.Number(label="Seeds", value=666)
|
416 |
fps = gr.Number(label="fps", value=24)
|
417 |
|
418 |
with gr.Row():
|
419 |
output_video = gr.Video(label="Output")
|
|
|
420 |
download_link = gr.File(label="Download the output")
|
421 |
|
422 |
run_button = gr.Button("Run")
|
423 |
-
run_button.click(fn=generation_loop, inputs=[input_video, seed, fps], outputs=[output_video, download_link])
|
424 |
|
425 |
# Examples
|
426 |
gr.Examples(
|
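The condition-latent loop now keys on mimetypes.guess_type to treat the upload as either an image or a video. Below is a minimal standalone sketch of that branching, assuming PIL and torchvision are available; load_media_tchw is a hypothetical helper, and it uses torchvision.transforms.functional.to_tensor since the hunks shown never import a T alias for torchvision.transforms:

import mimetypes

from PIL import Image
from torchvision.io.video import read_video
from torchvision.transforms.functional import to_tensor


def load_media_tchw(path):
    """Load an image or video at `path` as a float (T, C, H, W) tensor in [0, 1]."""
    media_type, _ = mimetypes.guess_type(path)
    if media_type and media_type.startswith("video"):
        # read_video returns (frames, audio, info); frames are uint8 TCHW here.
        video = read_video(path, output_format="TCHW")[0] / 255.0
        suffix = ".mp4"
    elif media_type and media_type.startswith("image"):
        # to_tensor yields float (C, H, W) in [0, 1]; unsqueeze adds the T=1 axis.
        video = to_tensor(Image.open(path).convert("RGB")).unsqueeze(0)
        suffix = ".png"
    else:
        raise ValueError(f"Cannot infer media type of {path!r} (got {media_type})")
    return video, suffix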
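Writing the result mirrors the read branch: a single PNG frame for an image input, an MP4 clip for a video. A short sketch of the save step, assuming sample is a uint8 array shaped (T, H, W, C) as mediapy.write_video expects (save_sample and the output/ naming are illustrative):

import os
import uuid

import mediapy
import numpy as np


def save_sample(sample: np.ndarray, is_image: bool, fps: float = 24.0) -> str:
    """Write `sample` to output/<uuid>.png (one frame) or output/<uuid>.mp4 (clip)."""
    os.makedirs("output", exist_ok=True)
    suffix = ".png" if is_image else ".mp4"
    out_path = os.path.join("output", str(uuid.uuid4()) + suffix)
    if is_image:
        mediapy.write_image(out_path, sample[0])  # first and only (H, W, C) frame
    else:
        mediapy.write_video(out_path, sample, fps=fps)  # sequence of (H, W, C) frames
    return out_path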
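With both gr.Image and gr.Video outputs wired to the same callback, the function must return one value per output component and use None for the slot that does not apply. A sketch of that routing, matching the outputs=[output_image, output_video, download_link] order in the updated click handler (route_outputs is a hypothetical helper):

def route_outputs(out_path: str):
    """Map a result file to (output_image, output_video, download_link) values."""
    # Gradio assigns returned values to output components positionally;
    # returning None leaves that component empty.
    if out_path.endswith(".png"):
        return out_path, None, out_path
    return None, out_path, out_path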