Aduc-sdr commited on
Commit
6b1cbcf
·
verified ·
1 Parent(s): 35c34e1

Update managers/seedvr_manager.py

Browse files
Files changed (1) hide show
  1. managers/seedvr_manager.py +43 -29
managers/seedvr_manager.py CHANGED
@@ -2,13 +2,14 @@
2
  #
3
  # Copyright (C) 2025 Carlos Rodrigues dos Santos
4
  #
5
- # Version: 2.3.2
6
  #
7
- # Esta versão implementa uma correção robusta para o FileNotFoundError da configuração do VAE,
8
- # antecipando a falha, carregando as configurações manualmente e fundindo-as para
9
- # contornar o caminho fixo problemático na biblioteca externa.
10
 
11
  import torch
 
12
  import os
13
  import gc
14
  import logging
@@ -25,29 +26,29 @@ from tools.tensor_utils import wavelet_reconstruction
25
 
26
  logger = logging.getLogger(__name__)
27
 
28
- # --- Gerenciamento de Dependências ---
29
  DEPS_DIR = Path("./deps")
30
  SEEDVR_REPO_DIR = DEPS_DIR / "SeedVR"
31
  SEEDVR_REPO_URL = "https://github.com/ByteDance-Seed/SeedVR.git"
32
  VAE_CONFIG_URL = "https://raw.githubusercontent.com/ByteDance-Seed/SeedVR/main/models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml"
33
 
34
  def setup_seedvr_dependencies():
35
- """Garante que o repositório do SeedVR seja clonado e esteja disponível no sys.path."""
36
  if not SEEDVR_REPO_DIR.exists():
37
- logger.info(f"Repositório SeedVR não encontrado em '{SEEDVR_REPO_DIR}'. Clonando do GitHub...")
38
  try:
39
  DEPS_DIR.mkdir(exist_ok=True)
40
  subprocess.run(["git", "clone", "--depth", "1", SEEDVR_REPO_URL, str(SEEDVR_REPO_DIR)], check=True, capture_output=True, text=True)
41
- logger.info("Repositório SeedVR clonado com sucesso.")
42
  except subprocess.CalledProcessError as e:
43
- logger.error(f"Falha ao clonar o repositório SeedVR. Git stderr: {e.stderr}")
44
- raise RuntimeError("Não foi possível clonar a dependência necessária do SeedVR do GitHub.")
45
  else:
46
- logger.info("Repositório SeedVR local encontrado.")
47
 
48
  if str(SEEDVR_REPO_DIR.resolve()) not in sys.path:
49
  sys.path.insert(0, str(SEEDVR_REPO_DIR.resolve()))
50
- logger.info(f"Adicionado '{SEEDVR_REPO_DIR.resolve()}' ao sys.path.")
51
 
52
  setup_seedvr_dependencies()
53
 
@@ -61,27 +62,29 @@ from torchvision.transforms import Compose, Lambda, Normalize
61
  from torchvision.io.video import read_video
62
  from omegaconf import OmegaConf
63
 
 
64
  def _load_file_from_url(url, model_dir='./', file_name=None):
65
  os.makedirs(model_dir, exist_ok=True)
66
  filename = file_name or os.path.basename(urlparse(url).path)
67
  cached_file = os.path.abspath(os.path.join(model_dir, filename))
68
  if not os.path.exists(cached_file):
69
- logger.info(f'Baixando: "{url}" para {cached_file}')
70
  download_url_to_file(url, cached_file, hash_prefix=None, progress=True)
71
  return cached_file
72
 
73
  class SeedVrManager:
74
- """Gerencia o modelo SeedVR para tarefas de Masterização HD."""
75
  def __init__(self, workspace_dir="deformes_workspace"):
76
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
77
  self.runner = None
78
  self.workspace_dir = workspace_dir
79
  self.is_initialized = False
80
- logger.info("SeedVrManager inicializado. O modelo será carregado sob demanda.")
 
81
 
82
  def _download_models_and_configs(self):
83
- """Baixa os checkpoints necessários E o arquivo de configuração do VAE que pode estar faltando."""
84
- logger.info("Verificando e baixando modelos e configurações do SeedVR2...")
85
  ckpt_dir = SEEDVR_REPO_DIR / 'ckpts'
86
  config_dir = SEEDVR_REPO_DIR / 'configs' / 'vae'
87
  ckpt_dir.mkdir(exist_ok=True)
@@ -96,13 +99,19 @@ class SeedVrManager:
96
  }
97
  for key, url in pretrain_model_urls.items():
98
  _load_file_from_url(url=url, model_dir=str(ckpt_dir))
99
- logger.info("Modelos e configurações do SeedVR2 baixados com sucesso.")
100
 
101
  def _initialize_runner(self, model_version: str):
102
- """Carrega e configura o modelo SeedVR, com uma correção robusta para o caminho da config do VAE."""
103
  if self.runner is not None: return
104
  self._download_models_and_configs()
105
- logger.info(f"Inicializando o executor do SeedVR2 {model_version}...")
 
 
 
 
 
 
106
  if model_version == '3B':
107
  config_path = SEEDVR_REPO_DIR / 'configs_3b' / 'main.yaml'
108
  checkpoint_path = SEEDVR_REPO_DIR / 'ckpts' / 'seedvr2_ema_3b.pth'
@@ -110,17 +119,17 @@ class SeedVrManager:
110
  config_path = SEEDVR_REPO_DIR / 'configs_7b' / 'main.yaml'
111
  checkpoint_path = SEEDVR_REPO_DIR / 'ckpts' / 'seedvr2_ema_7b.pth'
112
  else:
113
- raise ValueError(f"Versão do modelo SeedVR não suportada: {model_version}")
114
 
115
  try:
116
  config = load_config(str(config_path))
117
  except FileNotFoundError:
118
- logger.warning("FileNotFoundError esperado capturado. Carregando config manualmente.")
119
  config = OmegaConf.load(str(config_path))
120
  correct_vae_config_path = SEEDVR_REPO_DIR / 'configs' / 'vae' / 's8_c16_t4_inflation_sd3.yaml'
121
  vae_config = OmegaConf.load(str(correct_vae_config_path))
122
  config.vae = vae_config
123
- logger.info("Configuração carregada e corrigida manualmente com sucesso.")
124
 
125
  self.runner = VideoDiffusionInfer(config)
126
  OmegaConf.set_readonly(self.runner.config, False)
@@ -129,20 +138,25 @@ class SeedVrManager:
129
  if hasattr(self.runner.vae, "set_memory_limit"):
130
  self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)
131
  self.is_initialized = True
132
- logger.info(f"Executor para SeedVR2 {model_version} inicializado e pronto.")
133
 
134
  def _unload_runner(self):
135
- """Remove o executor da VRAM para liberar recursos."""
136
  if self.runner is not None:
137
  del self.runner; self.runner = None
138
  gc.collect(); torch.cuda.empty_cache()
139
  self.is_initialized = False
140
- logger.info("Executor do SeedVR2 descarregado da VRAM.")
 
 
 
 
 
141
 
142
  def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
143
  model_version: str = '3B', steps: int = 50, seed: int = 666,
144
  progress: gr.Progress = None) -> str:
145
- """Aplica o aprimoramento HD a um vídeo usando a lógica do SeedVR."""
146
  try:
147
  self._initialize_runner(model_version)
148
  set_seed(seed, same_across_ranks=True)
@@ -185,10 +199,10 @@ class SeedVrManager:
185
  final_sample = final_sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
186
  final_sample_np = final_sample.to(torch.uint8).cpu().numpy()
187
  mediapy.write_video(output_video_path, final_sample_np, fps=24)
188
- logger.info(f"Vídeo Masterizado em HD salvo em: {output_video_path}")
189
  return output_video_path
190
  finally:
191
  self._unload_runner()
192
 
193
- # --- Instância Singleton ---
194
  seedvr_manager_singleton = SeedVrManager()
 
2
  #
3
  # Copyright (C) 2025 Carlos Rodrigues dos Santos
4
  #
5
+ # Version: 2.3.3
6
  #
7
+ # This version adds a monkey patch to disable torch.distributed.barrier calls
8
+ # within the SeedVR library, allowing it to run in a single-GPU inference mode
9
+ # without raising a "process group not initialized" error.
10
 
11
  import torch
12
+ import torch.distributed as dist
13
  import os
14
  import gc
15
  import logging
 
26
 
27
  logger = logging.getLogger(__name__)
28
 
29
+ # --- Dependency Management ---
30
  DEPS_DIR = Path("./deps")
31
  SEEDVR_REPO_DIR = DEPS_DIR / "SeedVR"
32
  SEEDVR_REPO_URL = "https://github.com/ByteDance-Seed/SeedVR.git"
33
  VAE_CONFIG_URL = "https://raw.githubusercontent.com/ByteDance-Seed/SeedVR/main/models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml"
34
 
35
  def setup_seedvr_dependencies():
36
+ """Ensures the SeedVR repository is cloned and available in the sys.path."""
37
  if not SEEDVR_REPO_DIR.exists():
38
+ logger.info(f"SeedVR repository not found at '{SEEDVR_REPO_DIR}'. Cloning from GitHub...")
39
  try:
40
  DEPS_DIR.mkdir(exist_ok=True)
41
  subprocess.run(["git", "clone", "--depth", "1", SEEDVR_REPO_URL, str(SEEDVR_REPO_DIR)], check=True, capture_output=True, text=True)
42
+ logger.info("SeedVR repository cloned successfully.")
43
  except subprocess.CalledProcessError as e:
44
+ logger.error(f"Failed to clone SeedVR repository. Git stderr: {e.stderr}")
45
+ raise RuntimeError("Could not clone the required SeedVR dependency from GitHub.")
46
  else:
47
+ logger.info("Found local SeedVR repository.")
48
 
49
  if str(SEEDVR_REPO_DIR.resolve()) not in sys.path:
50
  sys.path.insert(0, str(SEEDVR_REPO_DIR.resolve()))
51
+ logger.info(f"Added '{SEEDVR_REPO_DIR.resolve()}' to sys.path.")
52
 
53
  setup_seedvr_dependencies()
54
 
 
62
  from torchvision.io.video import read_video
63
  from omegaconf import OmegaConf
64
 
65
+
66
  def _load_file_from_url(url, model_dir='./', file_name=None):
67
  os.makedirs(model_dir, exist_ok=True)
68
  filename = file_name or os.path.basename(urlparse(url).path)
69
  cached_file = os.path.abspath(os.path.join(model_dir, filename))
70
  if not os.path.exists(cached_file):
71
+ logger.info(f'Downloading: "{url}" to {cached_file}')
72
  download_url_to_file(url, cached_file, hash_prefix=None, progress=True)
73
  return cached_file
74
 
75
  class SeedVrManager:
76
+ """Manages the SeedVR model for HD Mastering tasks."""
77
  def __init__(self, workspace_dir="deformes_workspace"):
78
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
79
  self.runner = None
80
  self.workspace_dir = workspace_dir
81
  self.is_initialized = False
82
+ self._original_barrier = None # To store the original distributed barrier function
83
+ logger.info("SeedVrManager initialized. Model will be loaded on demand.")
84
 
85
  def _download_models_and_configs(self):
86
+ """Downloads the necessary checkpoints AND the missing VAE config file."""
87
+ logger.info("Verifying and downloading SeedVR2 models and configs...")
88
  ckpt_dir = SEEDVR_REPO_DIR / 'ckpts'
89
  config_dir = SEEDVR_REPO_DIR / 'configs' / 'vae'
90
  ckpt_dir.mkdir(exist_ok=True)
 
99
  }
100
  for key, url in pretrain_model_urls.items():
101
  _load_file_from_url(url=url, model_dir=str(ckpt_dir))
102
+ logger.info("SeedVR2 models and configs downloaded successfully.")
103
 
104
  def _initialize_runner(self, model_version: str):
105
+ """Loads and configures the SeedVR model, with patches for single-GPU inference."""
106
  if self.runner is not None: return
107
  self._download_models_and_configs()
108
+
109
+ if dist.is_available() and not dist.is_initialized():
110
+ logger.info("Applying patch to disable torch.distributed.barrier for single-GPU inference.")
111
+ self._original_barrier = dist.barrier
112
+ dist.barrier = lambda *args, **kwargs: None
113
+
114
+ logger.info(f"Initializing SeedVR2 {model_version} runner...")
115
  if model_version == '3B':
116
  config_path = SEEDVR_REPO_DIR / 'configs_3b' / 'main.yaml'
117
  checkpoint_path = SEEDVR_REPO_DIR / 'ckpts' / 'seedvr2_ema_3b.pth'
 
119
  config_path = SEEDVR_REPO_DIR / 'configs_7b' / 'main.yaml'
120
  checkpoint_path = SEEDVR_REPO_DIR / 'ckpts' / 'seedvr2_ema_7b.pth'
121
  else:
122
+ raise ValueError(f"Unsupported SeedVR model version: {model_version}")
123
 
124
  try:
125
  config = load_config(str(config_path))
126
  except FileNotFoundError:
127
+ logger.warning("Caught expected FileNotFoundError. Loading config manually.")
128
  config = OmegaConf.load(str(config_path))
129
  correct_vae_config_path = SEEDVR_REPO_DIR / 'configs' / 'vae' / 's8_c16_t4_inflation_sd3.yaml'
130
  vae_config = OmegaConf.load(str(correct_vae_config_path))
131
  config.vae = vae_config
132
+ logger.info("Configuration loaded and patched manually.")
133
 
134
  self.runner = VideoDiffusionInfer(config)
135
  OmegaConf.set_readonly(self.runner.config, False)
 
138
  if hasattr(self.runner.vae, "set_memory_limit"):
139
  self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)
140
  self.is_initialized = True
141
+ logger.info(f"Runner for SeedVR2 {model_version} initialized and ready.")
142
 
143
  def _unload_runner(self):
144
+ """Unloads the runner from VRAM and restores any applied patches."""
145
  if self.runner is not None:
146
  del self.runner; self.runner = None
147
  gc.collect(); torch.cuda.empty_cache()
148
  self.is_initialized = False
149
+ logger.info("SeedVR runner unloaded from VRAM.")
150
+
151
+ if self._original_barrier is not None:
152
+ logger.info("Restoring original torch.distributed.barrier function.")
153
+ dist.barrier = self._original_barrier
154
+ self._original_barrier = None
155
 
156
  def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
157
  model_version: str = '3B', steps: int = 50, seed: int = 666,
158
  progress: gr.Progress = None) -> str:
159
+ """Applies HD enhancement to a video using the SeedVR logic."""
160
  try:
161
  self._initialize_runner(model_version)
162
  set_seed(seed, same_across_ranks=True)
 
199
  final_sample = final_sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
200
  final_sample_np = final_sample.to(torch.uint8).cpu().numpy()
201
  mediapy.write_video(output_video_path, final_sample_np, fps=24)
202
+ logger.info(f"HD Mastered video saved to: {output_video_path}")
203
  return output_video_path
204
  finally:
205
  self._unload_runner()
206
 
207
+ # --- Singleton Instance ---
208
  seedvr_manager_singleton = SeedVrManager()