Aduc-sdr commited on
Commit
746b66d
·
verified ·
1 Parent(s): f4d4a28

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -48
app.py CHANGED
@@ -15,28 +15,6 @@ import spaces
15
  import subprocess
16
  import os
17
  import sys
18
-
19
- # --- Setup: Clone repository, Change Directory, and Update Python Path ---
20
- # Esta é a abordagem definitiva para corrigir todos os problemas de caminho.
21
-
22
- # 1. Clone o repositório
23
- repo_dir_name = "SeedVR2-3B"
24
- if not os.path.exists(repo_dir_name):
25
- print(f"Clonando o repositório {repo_dir_name}...")
26
- subprocess.run(f"git clone https://huggingface.co/spaces/ByteDance-Seed/{repo_dir_name}", shell=True, check=True)
27
-
28
- # 2. Mude o diretório de trabalho atual para a raiz do repositório.
29
- os.chdir(repo_dir_name)
30
- print(f"Diretório de trabalho alterado para: {os.getcwd()}")
31
-
32
- # 3. Adicione explicitamente o novo diretório de trabalho ao caminho do sistema do Python.
33
- sys.path.insert(0, os.path.abspath('.'))
34
- print(f"Diretório atual adicionado ao sys.path: {os.path.abspath('.')}")
35
-
36
-
37
- # --- Código Principal da Aplicação ---
38
- # Agora, todas as importações e cargas de arquivos devem funcionar corretamente.
39
-
40
  import torch
41
  import mediapy
42
  from einops import rearrange
@@ -47,6 +25,8 @@ import gc
47
  from PIL import Image
48
  import gradio as gr
49
  from pathlib import Path
 
 
50
  import shlex
51
  import uuid
52
  import mimetypes
@@ -54,33 +34,69 @@ import torchvision.transforms as T
54
  from torchvision.transforms import Compose, Lambda, Normalize
55
  from torchvision.io.video import read_video
56
 
57
- # Importações do repositório (agora funcionarão)
58
- from data.image.transforms.divisible_crop import DivisibleCrop
59
- from data.image.transforms.na_resize import NaResize
60
- from data.video.transforms.rearrange import Rearrange
61
- from common.config import load_config
62
- from common.distributed import init_torch
63
- from common.distributed.advanced import init_sequence_parallel
64
- from common.seed import set_seed
65
- from common.partition import partition_by_size
66
- from projects.video_diffusion_sr.infer import VideoDiffusionInfer
67
- from common.distributed.ops import sync_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- # Verifica o utilitário color_fix (usando caminho relativo)
70
- if os.path.exists("projects/video_diffusion_sr/color_fix.py"):
71
- from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
72
- use_colorfix = True
73
- else:
74
- use_colorfix = False
75
- print('Atenção!!!!!! A correção de cor não está disponível!')
76
 
77
  # --- Configuração de Ambiente e Dependências ---
 
78
  os.environ["MASTER_ADDR"] = "127.0.0.1"
79
  os.environ["MASTER_PORT"] = "12355"
80
  os.environ["RANK"] = str(0)
81
  os.environ["WORLD_SIZE"] = str(1)
82
 
83
- # Use sys.executable para garantir que estamos usando o pip correto
84
  python_executable = sys.executable
85
  subprocess.run(
86
  [python_executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
@@ -91,25 +107,40 @@ subprocess.run(
91
  apex_wheel_path = "apex-0.1-cp310-cp310-linux_x86_64.whl"
92
  if os.path.exists(apex_wheel_path):
93
  print("Instalando o Apex a partir do arquivo wheel...")
94
- # CORREÇÃO: Usar --force-reinstall e --no-cache-dir para garantir uma instalação limpa.
95
  subprocess.run(
96
  [python_executable, "-m", "pip", "install", "--force-reinstall", "--no-cache-dir", apex_wheel_path],
97
  check=True
98
  )
99
  print("✅ Configuração do Apex concluída.")
100
 
101
- # --- Funções Principais ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  def configure_sequence_parallel(sp_size):
104
  if sp_size > 1:
105
  init_sequence_parallel(sp_size)
106
 
107
  def configure_runner(sp_size):
108
- # Os caminhos agora são simples e relativos à raiz do repositório
109
  config_path = 'configs_3b/main.yaml'
110
  checkpoint_path = 'ckpts/seedvr2_ema_3b.pth'
111
 
112
- config = load_config(config_path) # Isto agora funcionará corretamente
113
  runner = VideoDiffusionInfer(config)
114
  OmegaConf.set_readonly(runner.config, False)
115
 
@@ -165,7 +196,6 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
165
  def _extract_text_embeds():
166
  positive_prompts_embeds = []
167
  for _ in original_videos_local:
168
- # Os caminhos agora são simples
169
  text_pos_embeds = torch.load('pos_emb.pt')
170
  text_neg_embeds = torch.load('neg_emb.pt')
171
  positive_prompts_embeds.append({"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]})
@@ -284,8 +314,7 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
284
  # --- UI do Gradio ---
285
 
286
  with gr.Blocks(title="SeedVR2: Restauração de Vídeo em Um Passo") as demo:
287
- # Use um caminho absoluto para o arquivo de logo do Gradio para segurança
288
- logo_path = os.path.abspath("assets/seedvr_logo.png")
289
  gr.HTML(f"""
290
  <div style='text-align:center; margin-bottom: 10px;'>
291
  <img src='file/{logo_path}' style='height:40px;' alt='SeedVR logo'/>
@@ -310,6 +339,16 @@ with gr.Blocks(title="SeedVR2: Restauração de Vídeo em Um Passo") as demo:
310
 
311
  run_button.click(fn=generation_loop, inputs=[input_file, seed, fps], outputs=[output_image, output_video, download_link])
312
 
 
 
 
 
 
 
 
 
 
 
313
  gr.HTML("""
314
  <hr>
315
  <p>Se você achou o SeedVR útil, por favor ⭐ o <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>repositório no GitHub</a>:
 
15
  import subprocess
16
  import os
17
  import sys
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  import torch
19
  import mediapy
20
  from einops import rearrange
 
25
  from PIL import Image
26
  import gradio as gr
27
  from pathlib import Path
28
+ from urllib.parse import urlparse
29
+ from torch.hub import download_url_to_file, get_dir
30
  import shlex
31
  import uuid
32
  import mimetypes
 
34
  from torchvision.transforms import Compose, Lambda, Normalize
35
  from torchvision.io.video import read_video
36
 
37
+ # --- Lógica de Download de Arquivos (do script original) ---
38
+
39
+ def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
40
+ """Carrega um arquivo de um URL http, baixando modelos se necessário."""
41
+ if model_dir is None:
42
+ hub_dir = get_dir()
43
+ model_dir = os.path.join(hub_dir, 'checkpoints')
44
+
45
+ os.makedirs(model_dir, exist_ok=True)
46
+
47
+ parts = urlparse(url)
48
+ filename = os.path.basename(parts.path)
49
+ if file_name is not None:
50
+ filename = file_name
51
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
52
+ if not os.path.exists(cached_file):
53
+ print(f'Baixando: "{url}" para {cached_file}\n')
54
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
55
+ return cached_file
56
+
57
+ ckpt_dir = Path('./ckpts')
58
+ if not ckpt_dir.exists():
59
+ ckpt_dir.mkdir()
60
+
61
+ pretrain_model_url = {
62
+ 'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
63
+ 'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
64
+ 'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
65
+ 'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt',
66
+ 'apex': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl'
67
+ }
68
+
69
+ # Baixa os pesos e dependências se não existirem
70
+ if not os.path.exists('./ckpts/seedvr2_ema_3b.pth'):
71
+ load_file_from_url(url=pretrain_model_url['dit'], model_dir='./ckpts/', progress=True)
72
+ if not os.path.exists('./ckpts/ema_vae.pth'):
73
+ load_file_from_url(url=pretrain_model_url['vae'], model_dir='./ckpts/', progress=True)
74
+ if not os.path.exists('./pos_emb.pt'):
75
+ load_file_from_url(url=pretrain_model_url['pos_emb'], model_dir='./', progress=True)
76
+ if not os.path.exists('./neg_emb.pt'):
77
+ load_file_from_url(url=pretrain_model_url['neg_emb'], model_dir='./', progress=True)
78
+ if not os.path.exists('./apex-0.1-cp310-cp310-linux_x86_64.whl'):
79
+ load_file_from_url(url=pretrain_model_url['apex'], model_dir='./', progress=True)
80
+
81
+ # Baixa os vídeos de exemplo
82
+ torch.hub.download_url_to_file(
83
+ 'https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/23_1_lq.mp4',
84
+ '01.mp4')
85
+ torch.hub.download_url_to_file(
86
+ 'https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/28_1_lq.mp4',
87
+ '02.mp4')
88
+ torch.hub.download_url_to_file(
89
+ 'https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/2_1_lq.mp4',
90
+ '03.mp4')
91
 
 
 
 
 
 
 
 
92
 
93
  # --- Configuração de Ambiente e Dependências ---
94
+
95
  os.environ["MASTER_ADDR"] = "127.0.0.1"
96
  os.environ["MASTER_PORT"] = "12355"
97
  os.environ["RANK"] = str(0)
98
  os.environ["WORLD_SIZE"] = str(1)
99
 
 
100
  python_executable = sys.executable
101
  subprocess.run(
102
  [python_executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
 
107
  apex_wheel_path = "apex-0.1-cp310-cp310-linux_x86_64.whl"
108
  if os.path.exists(apex_wheel_path):
109
  print("Instalando o Apex a partir do arquivo wheel...")
 
110
  subprocess.run(
111
  [python_executable, "-m", "pip", "install", "--force-reinstall", "--no-cache-dir", apex_wheel_path],
112
  check=True
113
  )
114
  print("✅ Configuração do Apex concluída.")
115
 
116
+ # --- Código Principal da Aplicação ---
117
+
118
+ from data.image.transforms.divisible_crop import DivisibleCrop
119
+ from data.image.transforms.na_resize import NaResize
120
+ from data.video.transforms.rearrange import Rearrange
121
+ if os.path.exists("./projects/video_diffusion_sr/color_fix.py"):
122
+ from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
123
+ use_colorfix=True
124
+ else:
125
+ use_colorfix = False
126
+ print('Atenção!!!!!! A correção de cor não está disponível!')
127
+ from common.config import load_config
128
+ from common.distributed import init_torch
129
+ from common.distributed.advanced import init_sequence_parallel
130
+ from common.seed import set_seed
131
+ from common.partition import partition_by_size
132
+ from projects.video_diffusion_sr.infer import VideoDiffusionInfer
133
+ from common.distributed.ops import sync_data
134
 
135
  def configure_sequence_parallel(sp_size):
136
  if sp_size > 1:
137
  init_sequence_parallel(sp_size)
138
 
139
  def configure_runner(sp_size):
 
140
  config_path = 'configs_3b/main.yaml'
141
  checkpoint_path = 'ckpts/seedvr2_ema_3b.pth'
142
 
143
+ config = load_config(config_path)
144
  runner = VideoDiffusionInfer(config)
145
  OmegaConf.set_readonly(runner.config, False)
146
 
 
196
  def _extract_text_embeds():
197
  positive_prompts_embeds = []
198
  for _ in original_videos_local:
 
199
  text_pos_embeds = torch.load('pos_emb.pt')
200
  text_neg_embeds = torch.load('neg_emb.pt')
201
  positive_prompts_embeds.append({"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]})
 
314
  # --- UI do Gradio ---
315
 
316
  with gr.Blocks(title="SeedVR2: Restauração de Vídeo em Um Passo") as demo:
317
+ logo_path = "assets/seedvr_logo.png"
 
318
  gr.HTML(f"""
319
  <div style='text-align:center; margin-bottom: 10px;'>
320
  <img src='file/{logo_path}' style='height:40px;' alt='SeedVR logo'/>
 
339
 
340
  run_button.click(fn=generation_loop, inputs=[input_file, seed, fps], outputs=[output_image, output_video, download_link])
341
 
342
+ # Seção de Exemplos, que agora funcionará pois os vídeos são baixados
343
+ gr.Examples(
344
+ examples=[
345
+ ["./01.mp4", 4, 24],
346
+ ["./02.mp4", 4, 24],
347
+ ["./03.mp4", 4, 24],
348
+ ],
349
+ inputs=[input_file, seed, fps]
350
+ )
351
+
352
  gr.HTML("""
353
  <hr>
354
  <p>Se você achou o SeedVR útil, por favor ⭐ o <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>repositório no GitHub</a>: