Aduc-sdr commited on
Commit
397e9af
·
verified ·
1 Parent(s): 9e3a7d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -32
app.py CHANGED
@@ -2,7 +2,7 @@
2
  # //
3
  # // Licensed under the Apache License, Version 2.0 (the "License");
4
  # // you may not use this file except in compliance with the License.
5
- # // You may not obtain a copy of the License at
6
  # //
7
  # // http://www.apache.org/licenses/LICENSE-2.0
8
  # //
@@ -11,22 +11,12 @@
11
  # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # // See the License for the specific language governing permissions and
13
  # // limitations under the License.
14
- import torch.distributed as dist
 
15
  import os
16
- import gc
17
- import logging
18
  import sys
19
- import subprocess
20
- from pathlib import Path
21
- from urllib.parse import urlparse
22
- from torch.hub import download_url_to_file
23
- import gradio as gr
24
- import mediapy
25
- from einops import rearrange
26
- import shutil
27
- from omegaconf import OmegaConf
28
 
29
- # --- ETAPA 1: Clonar o Repositório Oficial do GitHub ---
30
  repo_name = "SeedVR"
31
  if not os.path.exists(repo_name):
32
  print(f"Clonando o repositório {repo_name} do GitHub...")
@@ -36,14 +26,22 @@ if not os.path.exists(repo_name):
36
  os.chdir(repo_name)
37
  print(f"Diretório de trabalho alterado para: {os.getcwd()}")
38
 
39
- # Adicionar o diretório ao path do Python para que as importações funcionem
40
  sys.path.insert(0, os.path.abspath('.'))
41
  print(f"Diretório atual adicionado ao sys.path.")
42
 
43
- # --- ETAPA 3: Instalar Dependências Conforme as Instruções ---
44
  python_executable = sys.executable
45
- print("Instalando dependências do requirements.txt...")
46
- subprocess.run([python_executable, "-m", "pip", "install", "-r", "requirements.txt"], check=True)
 
 
 
 
 
 
 
 
 
47
 
48
  print("Instalando flash-attn...")
49
  subprocess.run([python_executable, "-m", "pip", "install", "flash-attn==2.5.9.post1", "--no-build-isolation"], check=True)
@@ -52,7 +50,6 @@ from pathlib import Path
52
  from urllib.parse import urlparse
53
  from torch.hub import download_url_to_file, get_dir
54
 
55
- # Função auxiliar para downloads
56
  def load_file_from_url(url, model_dir='.', progress=True, file_name=None):
57
  os.makedirs(model_dir, exist_ok=True)
58
  if not file_name:
@@ -64,7 +61,6 @@ def load_file_from_url(url, model_dir='.', progress=True, file_name=None):
64
  download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
65
  return cached_file
66
 
67
- # Baixar e instalar Apex pré-compilado (crucial para o ambiente do Spaces)
68
  apex_url = 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl'
69
  apex_wheel_path = load_file_from_url(url=apex_url)
70
  print("Instalando Apex a partir do wheel baixado...")
@@ -73,6 +69,8 @@ print("✅ Configuração do Apex concluída.")
73
 
74
  # --- ETAPA 4: Baixar os Modelos Pré-treinados ---
75
  print("Baixando modelos pré-treinados...")
 
 
76
  pretrain_model_url = {
77
  'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
78
  'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
@@ -85,8 +83,8 @@ for key, url in pretrain_model_url.items():
85
  model_dir = './ckpts' if key in ['vae', 'dit'] else '.'
86
  load_file_from_url(url=url, model_dir=model_dir)
87
 
 
88
  # --- ETAPA 5: Executar a Aplicação Principal ---
89
- import torch
90
  import mediapy
91
  from einops import rearrange
92
  from omegaconf import OmegaConf
@@ -112,16 +110,20 @@ from common.partition import partition_by_size
112
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
113
  from common.distributed.ops import sync_data
114
 
 
 
 
 
 
 
 
 
115
  os.environ["MASTER_ADDR"] = "127.0.0.1"
116
  os.environ["MASTER_PORT"] = "12355"
117
  os.environ["RANK"] = str(0)
118
  os.environ["WORLD_SIZE"] = str(1)
119
 
120
- if os.path.exists("projects/video_diffusion_sr/color_fix.py"):
121
- from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
122
- use_colorfix = True
123
- else:
124
- use_colorfix = False
125
 
126
  def configure_runner():
127
  config = load_config('configs_3b/main.yaml')
@@ -136,10 +138,9 @@ def configure_runner():
136
 
137
  def generation_step(runner, text_embeds_dict, cond_latents):
138
  def _move_to_cuda(x): return [i.to("cuda") for i in x]
139
- noises = [torch.randn_like(latent) for latent in cond_latents]
140
- aug_noises = [torch.randn_like(latent) for latent in cond_latents]
141
  noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
142
- noises, aug_noises, cond_latents = list(map(_move_to_cuda, (noises, aug_noises, cond_latents)))
143
  def _add_noise(x, aug_noise):
144
  t = torch.tensor([100.0], device="cuda")
145
  shape = torch.tensor(x.shape[1:], device="cuda")[None]
@@ -158,9 +159,10 @@ def generation_loop(video_path, seed=666, fps_out=24):
158
  runner.configure_diffusion()
159
  set_seed(int(seed))
160
  os.makedirs("output", exist_ok=True)
161
- video_transform = Compose([NaResize(1024), DivisibleCrop(16), Normalize(0.5, 0.5), Rearrange("t c h w -> c t h w")])
162
  media_type, _ = mimetypes.guess_type(video_path)
163
  is_video = media_type and media_type.startswith("video")
 
164
  if is_video:
165
  video, _, _ = read_video(video_path, output_format="TCHW")
166
  video = video[:121] / 255.0
@@ -168,12 +170,14 @@ def generation_loop(video_path, seed=666, fps_out=24):
168
  else:
169
  video = T.ToTensor()(Image.open(video_path).convert("RGB")).unsqueeze(0)
170
  output_path = os.path.join("output", f"{uuid.uuid4()}.png")
171
- cond_latents = [video_transform(video.to("cuda"))]
 
172
  ori_length = cond_latents[0].size(2)
173
  cond_latents = runner.vae_encode(cond_latents)
174
  samples = generation_step(runner, text_embeds, cond_latents)
175
  sample = samples[0][:ori_length].cpu()
176
  sample = rearrange(sample, "t c h w -> t h w c").clip(-1, 1).add(1).mul(127.5).byte().numpy()
 
177
  if is_video:
178
  mediapy.write_video(output_path, sample, fps=fps_out)
179
  return None, output_path, output_path
@@ -182,7 +186,16 @@ def generation_loop(video_path, seed=666, fps_out=24):
182
  return output_path, None, output_path
183
 
184
  with gr.Blocks(title="SeedVR") as demo:
185
- gr.HTML(f"""<div style='text-align:center; margin-bottom: 10px;'><img src='file/{os.path.abspath("assets/seedvr_logo.png")}' style='height:40px;'/></div>...""")
 
 
 
 
 
 
 
 
 
186
  with gr.Row():
187
  input_file = gr.File(label="Carregar Imagem ou Vídeo")
188
  with gr.Column():
@@ -193,5 +206,11 @@ with gr.Blocks(title="SeedVR") as demo:
193
  output_video = gr.Video(label="Vídeo de Saída")
194
  download_link = gr.File(label="Baixar Resultado")
195
  run_button.click(fn=generation_loop, inputs=[input_file, seed, fps], outputs=[output_image, output_video, download_link])
 
 
 
 
 
 
196
 
197
  demo.queue().launch(share=True)
 
2
  # //
3
  # // Licensed under the Apache License, Version 2.0 (the "License");
4
  # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
  # //
7
  # // http://www.apache.org/licenses/LICENSE-2.0
8
  # //
 
11
  # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # // See the License for the specific language governing permissions and
13
  # // limitations under the License.
14
+ import spaces
15
+ import subprocess
16
  import os
 
 
17
  import sys
 
 
 
 
 
 
 
 
 
18
 
19
+ # --- ETAPA 1: Clonar o Repositório do GitHub ---
20
  repo_name = "SeedVR"
21
  if not os.path.exists(repo_name):
22
  print(f"Clonando o repositório {repo_name} do GitHub...")
 
26
  os.chdir(repo_name)
27
  print(f"Diretório de trabalho alterado para: {os.getcwd()}")
28
 
 
29
  sys.path.insert(0, os.path.abspath('.'))
30
  print(f"Diretório atual adicionado ao sys.path.")
31
 
32
+ # --- ETAPA 3: Instalar Dependências Corretamente ---
33
  python_executable = sys.executable
34
+
35
+ # CORREÇÃO CRÍTICA: Filtrar requirements.txt para evitar conflitos com torch/torchvision
36
+ print("Filtrando requirements.txt para evitar conflitos de versão...")
37
+ with open("requirements.txt", "r") as f_in, open("filtered_requirements.txt", "w") as f_out:
38
+ for line in f_in:
39
+ # Ignora as linhas que podem causar conflitos
40
+ if not line.strip().startswith(('torch', 'torchvision')):
41
+ f_out.write(line)
42
+
43
+ print("Instalando dependências filtradas...")
44
+ subprocess.run([python_executable, "-m", "pip", "install", "-r", "filtered_requirements.txt"], check=True)
45
 
46
  print("Instalando flash-attn...")
47
  subprocess.run([python_executable, "-m", "pip", "install", "flash-attn==2.5.9.post1", "--no-build-isolation"], check=True)
 
50
  from urllib.parse import urlparse
51
  from torch.hub import download_url_to_file, get_dir
52
 
 
53
  def load_file_from_url(url, model_dir='.', progress=True, file_name=None):
54
  os.makedirs(model_dir, exist_ok=True)
55
  if not file_name:
 
61
  download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
62
  return cached_file
63
 
 
64
  apex_url = 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl'
65
  apex_wheel_path = load_file_from_url(url=apex_url)
66
  print("Instalando Apex a partir do wheel baixado...")
 
69
 
70
  # --- ETAPA 4: Baixar os Modelos Pré-treinados ---
71
  print("Baixando modelos pré-treinados...")
72
+ import torch
73
+
74
  pretrain_model_url = {
75
  'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
76
  'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
 
83
  model_dir = './ckpts' if key in ['vae', 'dit'] else '.'
84
  load_file_from_url(url=url, model_dir=model_dir)
85
 
86
+
87
  # --- ETAPA 5: Executar a Aplicação Principal ---
 
88
  import mediapy
89
  from einops import rearrange
90
  from omegaconf import OmegaConf
 
110
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
111
  from common.distributed.ops import sync_data
112
 
113
+ torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/23_1_lq.mp4', '01.mp4')
114
+ torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/28_1_lq.mp4', '02.mp4')
115
+ torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/2_1_lq.mp4', '03.mp4')
116
+ print("✅ Setup completo. Iniciando a aplicação...")
117
+
118
+
119
+
120
+
121
  os.environ["MASTER_ADDR"] = "127.0.0.1"
122
  os.environ["MASTER_PORT"] = "12355"
123
  os.environ["RANK"] = str(0)
124
  os.environ["WORLD_SIZE"] = str(1)
125
 
126
+ use_colorfix = os.path.exists("projects/video_diffusion_sr/color_fix.py")
 
 
 
 
127
 
128
  def configure_runner():
129
  config = load_config('configs_3b/main.yaml')
 
138
 
139
  def generation_step(runner, text_embeds_dict, cond_latents):
140
  def _move_to_cuda(x): return [i.to("cuda") for i in x]
141
+ noises, aug_noises = [torch.randn_like(l) for l in cond_latents], [torch.randn_like(l) for l in cond_latents]
 
142
  noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
143
+ noises, aug_noises, cond_latents = map(_move_to_cuda, (noises, aug_noises, cond_latents))
144
  def _add_noise(x, aug_noise):
145
  t = torch.tensor([100.0], device="cuda")
146
  shape = torch.tensor(x.shape[1:], device="cuda")[None]
 
159
  runner.configure_diffusion()
160
  set_seed(int(seed))
161
  os.makedirs("output", exist_ok=True)
162
+ transform = Compose([NaResize(1024), DivisibleCrop(16), Normalize(0.5, 0.5), Rearrange("t c h w -> c t h w")])
163
  media_type, _ = mimetypes.guess_type(video_path)
164
  is_video = media_type and media_type.startswith("video")
165
+
166
  if is_video:
167
  video, _, _ = read_video(video_path, output_format="TCHW")
168
  video = video[:121] / 255.0
 
170
  else:
171
  video = T.ToTensor()(Image.open(video_path).convert("RGB")).unsqueeze(0)
172
  output_path = os.path.join("output", f"{uuid.uuid4()}.png")
173
+
174
+ cond_latents = [transform(video.to("cuda"))]
175
  ori_length = cond_latents[0].size(2)
176
  cond_latents = runner.vae_encode(cond_latents)
177
  samples = generation_step(runner, text_embeds, cond_latents)
178
  sample = samples[0][:ori_length].cpu()
179
  sample = rearrange(sample, "t c h w -> t h w c").clip(-1, 1).add(1).mul(127.5).byte().numpy()
180
+
181
  if is_video:
182
  mediapy.write_video(output_path, sample, fps=fps_out)
183
  return None, output_path, output_path
 
186
  return output_path, None, output_path
187
 
188
  with gr.Blocks(title="SeedVR") as demo:
189
+ gr.HTML(f"""
190
+ <div style='text-align:center; margin-bottom: 10px;'>
191
+ <img src='file/{os.path.abspath("assets/seedvr_logo.png")}' style='height:40px;' alt='SeedVR logo'/>
192
+ </div>
193
+ <p><b>Demonstração oficial do Gradio</b> para
194
+ <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>
195
+ <b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
196
+ 🔥 <b>SeedVR2</b> é um algoritmo de restauração de imagem e vídeo em um passo para conteúdo do mundo real e AIGC.
197
+ </p>
198
+ """)
199
  with gr.Row():
200
  input_file = gr.File(label="Carregar Imagem ou Vídeo")
201
  with gr.Column():
 
206
  output_video = gr.Video(label="Vídeo de Saída")
207
  download_link = gr.File(label="Baixar Resultado")
208
  run_button.click(fn=generation_loop, inputs=[input_file, seed, fps], outputs=[output_image, output_video, download_link])
209
+ gr.Examples(examples=[["01.mp4", 42, 24], ["02.mp4", 42, 24], ["03.mp4", 42, 24]], inputs=[input_file, seed, fps])
210
+ gr.HTML("""
211
+ <hr>
212
+ <p>Se você achou o SeedVR útil, por favor ⭐ o
213
+ <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>repositório no GitHub</a>.</p>
214
+ """)
215
 
216
  demo.queue().launch(share=True)