Iceclear committed
Commit 657048f (verified) · Parent(s): 09fd3d7

Update app.py

Files changed (1): app.py (+34 -38)
app.py CHANGED
@@ -36,6 +36,7 @@ else:
-from torchvision.transforms import Compose, Lambda, Normalize
+from torchvision.transforms import Compose, Lambda, Normalize, ToTensor
 from torchvision.io.video import read_video
 import argparse
+from PIL import Image
 
 from common.distributed import (
     get_device,
@@ -62,6 +63,7 @@ from urllib.parse import urlparse
 from torch.hub import download_url_to_file, get_dir
 import shlex
 import uuid
+import mimetypes
 
 
 os.environ["MASTER_ADDR"] = "127.0.0.1"
@@ -76,7 +78,7 @@ subprocess.run(
 )
 
 def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
-    """Load file form http url, will download models if necessary.
+    """Load file from http url, will download models if necessary.
 
     Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
 
@@ -225,29 +227,6 @@ def generation_step(runner, text_embeds_dict, cond_latents):
 @spaces.GPU(duration=100)
 def generation_loop(video_path='./test_videos', seed=666, fps_out=12, batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, res_h=1280, res_w=720, sp_size=1):
     runner = configure_runner(1)
-    output_dir = 'output/' + str(uuid.uuid4()) + '.mp4'
-    def _build_pos_and_neg_prompt():
-        # read positive prompt
-        positive_text = "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, \
-            hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, \
-            skin pore detailing, hyper sharpness, perfect without deformations."
-        # read negative prompt
-        negative_text = "painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, \
-            CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, \
-            signature, jpeg artifacts, deformed, lowres, over-smooth"
-        return positive_text, negative_text
-
-    def _build_test_prompts(video_path):
-        positive_text, negative_text = _build_pos_and_neg_prompt()
-        original_videos = []
-        prompts = {}
-        video_list = os.listdir(video_path)
-        for f in video_list:
-            # if f.endswith(".mp4"):
-            original_videos.append(f)
-            prompts[f] = positive_text
-        print(f"Total prompts to be generated: {len(original_videos)}")
-        return original_videos, prompts, negative_text
 
     def _extract_text_embeds():
         # Text encoder forward.
@@ -294,7 +273,6 @@ def generation_loop(video_path='./test_videos', seed=666, fps_out=12, batch_size
     # set random seed
     set_seed(seed, same_across_ranks=True)
     os.makedirs('output/', exist_ok=True)
-    tgt_path = 'output/'
 
     # get test prompts
     original_videos = [video_path.split('/')[-1]]
@@ -331,13 +309,24 @@ def generation_loop(video_path='./test_videos', seed=666, fps_out=12, batch_size
     # read condition latents
     cond_latents = []
     for video in videos:
-        video = (
-            read_video(
-                os.path.join(video_path), output_format="TCHW"
-            )[0]
-            / 255.0
-        )
-        print(f"Read video size: {video.size()}")
+        media_type, _ = mimetypes.guess_type(video_path)
+        is_image = media_type and media_type.startswith("image")
+        is_video = media_type and media_type.startswith("video")
+        if is_video:
+            video = (
+                read_video(
+                    os.path.join(video_path), output_format="TCHW"
+                )[0]
+                / 255.0
+            )
+            print(f"Read video size: {video.size()}")
+            output_dir = 'output/' + str(uuid.uuid4()) + '.mp4'
+        else:
+            img = Image.open(video_path).convert("RGB")
+            img_tensor = ToTensor()(img).unsqueeze(0)  # (1, C, H, W)
+            video = img_tensor  # a single frame treated as a (T=1, C, H, W) clip
+            print(f"Read image size: {video.size()}")
+            output_dir = 'output/' + str(uuid.uuid4()) + '.png'
         cond_latents.append(video_transform(video.to(torch.device("cuda"))))
 
     ori_lengths = [video.size(1) for video in cond_latents]
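
A note on the new routing above: mimetypes.guess_type classifies the upload purely by its file extension, without opening the file, and returns (None, None) for unknown suffixes, so is_image and is_video can both be falsy at once. In the image branch, ToTensor turns the PIL image into a (C, H, W) float tensor in [0, 1], and unsqueeze(0) prepends a frame axis so a still image travels through the pipeline as a one-frame clip. A minimal standalone sketch of the same extension check (the sample paths are hypothetical):

    import mimetypes

    def classify_upload(path):
        # guess_type inspects only the extension; unknown suffixes
        # yield (None, None), leaving both flags False.
        media_type, _ = mimetypes.guess_type(path)
        is_image = bool(media_type and media_type.startswith("image"))
        is_video = bool(media_type and media_type.startswith("video"))
        return is_image, is_video

    print(classify_upload("clip.mp4"))   # (False, True)
    print(classify_upload("frame.png"))  # (True, False)
    print(classify_upload("notes.txt"))  # (False, False)
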
@@ -386,14 +375,20 @@ def generation_loop(video_path='./test_videos', seed=666, fps_out=12, batch_size
         sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
         sample = sample.to(torch.uint8).numpy()
 
-        mediapy.write_video(
-            output_dir, sample, fps=fps_out
-        )
+        if is_image:
+            mediapy.write_image(output_dir, sample[0])
+        else:
+            mediapy.write_video(
+                output_dir, sample, fps=fps_out
+            )
 
         # print(f"Generated video size: {sample.shape}")
         gc.collect()
         torch.cuda.empty_cache()
-    return output_dir, output_dir
+    if is_image:
+        return output_dir, None, output_dir
+    else:
+        return None, output_dir, output_dir
 
 
 with gr.Blocks(title="SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training") as demo:
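
The save branch above assumes sample reaches it as a uint8 array of frames in (T, H, W, C) order, so sample[0] is a single RGB image; stills go through mediapy.write_image and frame stacks through mediapy.write_video. A self-contained sketch of that split with synthetic data (paths and shapes are illustrative, and it branches on frame count rather than the MIME flag):

    import numpy as np
    import mediapy

    # Stand-in for the sampler output: one 128x128 RGB frame.
    sample = np.random.randint(0, 256, size=(1, 128, 128, 3), dtype=np.uint8)

    if sample.shape[0] == 1:
        mediapy.write_image("output/frame.png", sample[0])      # still image
    else:
        mediapy.write_video("output/clip.mp4", sample, fps=24)  # frame stack
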
@@ -411,16 +406,17 @@ with gr.Blocks(title="SeedVR2: One-Step Video Restoration via Diffusion Adversar
 
     # Interface
     with gr.Row():
-        input_video = gr.Video(label="Upload a video")
+        input_video = gr.File(label="Upload image or video", type="filepath")
         seed = gr.Number(label="Seeds", value=666)
         fps = gr.Number(label="fps", value=24)
 
     with gr.Row():
         output_video = gr.Video(label="Output")
+        output_image = gr.Image(label="Output_Image")
         download_link = gr.File(label="Download the output")
 
     run_button = gr.Button("Run")
-    run_button.click(fn=generation_loop, inputs=[input_video, seed, fps], outputs=[output_video, download_link])
+    run_button.click(fn=generation_loop, inputs=[input_video, seed, fps], outputs=[output_image, output_video, download_link])
 
     # Examples
     gr.Examples(
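
One wiring detail: Gradio maps the handler's return values positionally onto the components listed in outputs, and a returned None leaves that slot empty. With outputs=[output_image, output_video, download_link], generation_loop must therefore return three values, filling either the image or the video slot plus the download file. A minimal sketch of that contract (the handler body is hypothetical):

    def handler(path, seed, fps):
        # One return value per output component, in order:
        # (image, video, file); None clears the slot it maps to.
        out = "output/result.png"
        return out, None, out  # show the image, clear the video, offer a download

    # run_button.click(fn=handler,
    #                  inputs=[input_video, seed, fps],
    #                  outputs=[output_image, output_video, download_link])
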
 