Aduc-sdr committed on
Commit
5fad6fa
·
verified ·
1 Parent(s): f85ca57

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -5
app.py CHANGED
@@ -78,15 +78,18 @@ os.environ["MASTER_PORT"] = "12355"
78
  os.environ["RANK"] = str(0)
79
  os.environ["WORLD_SIZE"] = str(1)
80
 
 
 
81
  subprocess.run(
82
- "pip install flash-attn --no-build-isolation",
83
- env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
84
- shell=True,
85
  )
86
 
87
  apex_wheel_path = os.path.join(repo_dir, "apex-0.1-cp310-cp310-linux_x86_64.whl")
88
  if os.path.exists(apex_wheel_path):
89
- subprocess.run(shlex.split(f"pip install {apex_wheel_path}"))
 
90
  print("✅ Apex setup completed.")
91
 
92
  # --- Core Functions ---
@@ -219,4 +222,93 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
219
  output_dir = os.path.join(output_base_dir, f"{uuid.uuid4()}.mp4")
220
  elif is_image:
221
  img = Image.open(video_path).convert("RGB")
222
- img_tensor = T.ToTensor()(img).uns
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  os.environ["RANK"] = str(0)
79
  os.environ["WORLD_SIZE"] = str(1)
80
 
81
+ # CORREÇÃO: Usar sys.executable para chamar o pip corretamente
82
+ python_executable = sys.executable
83
  subprocess.run(
84
+ [python_executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
85
+ env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
86
+ check=True
87
  )
88
 
89
  apex_wheel_path = os.path.join(repo_dir, "apex-0.1-cp310-cp310-linux_x86_64.whl")
90
  if os.path.exists(apex_wheel_path):
91
+ # CORREÇÃO: Usar sys.executable aqui também
92
+ subprocess.run([python_executable, "-m", "pip", "install", apex_wheel_path], check=True)
93
  print("✅ Apex setup completed.")
94
 
95
  # --- Core Functions ---
 
222
  output_dir = os.path.join(output_base_dir, f"{uuid.uuid4()}.mp4")
223
  elif is_image:
224
  img = Image.open(video_path).convert("RGB")
225
+ img_tensor = T.ToTensor()(img).unsqueeze(0)
226
+ video = img_tensor
227
+ print(f"Read Image size: {video.size()}")
228
+ output_dir = os.path.join(output_base_dir, f"{uuid.uuid4()}.png")
229
+ else:
230
+ raise ValueError("Unsupported file type")
231
+
232
+ cond_latents.append(video_transform(video.to(torch.device("cuda"))))
233
+
234
+ ori_lengths = [v.size(1) for v in cond_latents]
235
+ input_videos = cond_latents
236
+ if is_video:
237
+ cond_latents = [cut_videos(v, sp_size) for v in cond_latents]
238
+
239
+ print(f"Encoding videos: {[v.size() for v in cond_latents]}")
240
+ cond_latents = runner.vae_encode(cond_latents)
241
+
242
+ for i, emb in enumerate(text_embeds["texts_pos"]):
243
+ text_embeds["texts_pos"][i] = emb.to(torch.device("cuda"))
244
+ for i, emb in enumerate(text_embeds["texts_neg"]):
245
+ text_embeds["texts_neg"][i] = emb.to(torch.device("cuda"))
246
+
247
+ samples = generation_step(runner, text_embeds, cond_latents=cond_latents)
248
+ del cond_latents
249
+
250
+ for _, input_tensor, sample, ori_length in zip(videos, input_videos, samples, ori_lengths):
251
+ if ori_length < sample.shape[0]:
252
+ sample = sample[:ori_length]
253
+
254
+ input_tensor = rearrange(input_tensor, "c t h w -> t c h w")
255
+ if use_colorfix:
256
+ sample = wavelet_reconstruction(sample.to("cpu"), input_tensor[:sample.size(0)].to("cpu"))
257
+ else:
258
+ sample = sample.to("cpu")
259
+
260
+ sample = rearrange(sample, "t c h w -> t h w c")
261
+ sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
262
+ sample = sample.to(torch.uint8).numpy()
263
+
264
+ if is_image:
265
+ mediapy.write_image(output_dir, sample[0])
266
+ else:
267
+ mediapy.write_video(output_dir, sample, fps=fps_out)
268
+
269
+ gc.collect()
270
+ torch.cuda.empty_cache()
271
+ if is_image:
272
+ return output_dir, None, output_dir
273
+ else:
274
+ return None, output_dir, output_dir
275
+
276
+ # --- Gradio UI ---
277
+
278
+ with gr.Blocks(title="SeedVR2: One-Step Video Restoration") as demo:
279
+ logo_path = os.path.join(repo_dir, "assets/seedvr_logo.png")
280
+ gr.HTML(f"""
281
+ <div style='text-align:center; margin-bottom: 10px;'>
282
+ <img src='file/{logo_path}' style='height:40px;' alt='SeedVR logo'/>
283
+ </div>
284
+ <p><b>Official Gradio demo</b> for <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'><b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
285
+ 🔥 <b>SeedVR2</b> is a one-step image and video restoration algorithm for real-world and AIGC content.</p>
286
+ """)
287
+
288
+ with gr.Row():
289
+ input_file = gr.File(label="Upload image or video", type="filepath")
290
+ with gr.Column():
291
+ seed = gr.Number(label="Seed", value=666)
292
+ fps = gr.Number(label="Output FPS (for video)", value=24)
293
+
294
+ run_button = gr.Button("Run")
295
+
296
+ with gr.Row():
297
+ output_image = gr.Image(label="Output Image")
298
+ output_video = gr.Video(label="Output Video")
299
+
300
+ download_link = gr.File(label="Download the output")
301
+
302
+ run_button.click(fn=generation_loop, inputs=[input_file, seed, fps], outputs=[output_image, output_video, download_link])
303
+
304
+ gr.HTML("""
305
+ <hr>
306
+ <p>If you find SeedVR helpful, please ⭐ the <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>GitHub repository</a>:
307
+ <a href="https://github.com/ByteDance-Seed/SeedVR" target="_blank"><img src="https://img.shields.io/github/stars/ByteDance-Seed/SeedVR?style=social" alt="GitHub Stars"></a></p>
308
+ <h4>Notice</h4>
309
+ <p>This demo supports up to <b>720p and 121 frames for videos or 2k images</b>. For other use cases, check the <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>GitHub repo</a>.</p>
310
+ <h4>Limitations</h4>
311
+ <p>May fail on heavy degradations or small-motion AIGC clips, causing oversharpening or poor restoration.</p>
312
+ """)
313
+
314
+ demo.queue().launch(share=True)