import inspect
import math
import os
import random
import threading
import time
from datetime import datetime, timedelta
from typing import Any, Callable, Dict, List, Optional, Union

import gradio as gr
import moviepy.editor as mp
import numpy as np
import spaces
import torch
import torch.fft
from tqdm import tqdm

# from diffusers import CogVideoXPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, SD3LoraLoaderMixin
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
    export_to_video,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
)
from diffusers.utils.torch_utils import randn_tensor
from transformers import (
    AutoTokenizer,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    PretrainedConfig,
    T5TokenizerFast,
)

from models.modeling_t5 import T5EncoderModel
from models.pipeline import VchitectXLPipeline
from models.VchitectXL import VchitectXLTransformerModel
# from patch_conv import convert_model
from op_replace import replace_all_layernorms
# from openai import OpenAI

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
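
# Editorial usage note: with the default FlowMatchEulerDiscreteScheduler this
# helper is effectively `scheduler.set_timesteps(num_inference_steps)`, e.g.
#   timesteps, n = retrieve_timesteps(pipe.scheduler, num_inference_steps=50, device=device)
# Custom `timesteps`/`sigmas` lists are only honored when the scheduler's
# `set_timesteps` accepts the corresponding keyword; otherwise a ValueError is raised.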

def myfft(tensor):
    # 2D FFT over the last two (spatial) dimensions
    tensor_fft = torch.fft.fft2(tensor)
    # Shift the zero-frequency (DC) component to the center of the spectrum
    tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
    # Tensor dimensions
    B, C, H, W = tensor.size()
    # Radius used to separate low from high frequencies (tunable)
    radius = min(H, W) // 5
    # Circular mask centered at (H/2, W/2)
    Y, X = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
    center_x, center_y = W // 2, H // 2
    mask = (X - center_x) ** 2 + (Y - center_y) ** 2 <= radius ** 2
    # Low- and high-frequency masks
    low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(tensor.device)
    high_freq_mask = ~low_freq_mask
    # Low-frequency component
    low_freq_fft = tensor_fft_shifted * low_freq_mask
    # low_freq_fft_shifted = torch.fft.ifftshift(low_freq_fft)
    # low_freq = torch.fft.ifft2(low_freq_fft_shifted).real
    # High-frequency component
    high_freq_fft = tensor_fft_shifted * high_freq_mask
    # high_freq_fft_shifted = torch.fft.ifftshift(high_freq_fft)
    # high_freq = torch.fft.ifft2(high_freq_fft_shifted).real
    return low_freq_fft, high_freq_fft
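
# Editorial note: myfft splits the 2D spectrum of a (B, C, H, W) tensor into a
# low-frequency part (inside a circular mask of radius min(H, W) // 5 around the
# centered DC component) and the complementary high-frequency part. Both parts
# are returned still fft-shifted, so callers must ifftshift before ifft2, as
# acc_call does below. A minimal round-trip sketch (illustrative only):
#   lf, hf = myfft(x)
#   x_rec = torch.fft.ifft2(torch.fft.ifftshift(lf + hf)).real  # ~= x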

def acc_call(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_3: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    frames: Optional[int] = None,
    num_inference_steps: int = 28,
    timesteps: List[int] = None,
    guidance_scale: float = 7.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
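    # Editorial summary: this accelerated sampling loop is bound onto
    # VchitectXLPipeline as __call__ further below. Relative to a plain CFG loop
    # it (a) skips the unconditional transformer pass on most late steps and
    # approximates it in the frequency domain from cached low/high-frequency
    # deltas between the unconditional and conditional predictions, and
    # (b) ramps the effective guidance scale with a cosine schedule.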
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor
    frames = frames or 24

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        prompt_3,
        height,
        width,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
    )
    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self.execution_device

    # 3. Encode the prompts
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        device=device,
        clip_skip=self.clip_skip,
        num_images_per_prompt=num_images_per_prompt,
    )

    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # 5. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        frames,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Denoising loop
    # with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
        if self.interrupt:
            continue

        # Expand the latents if we are doing classifier-free guidance
        latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
        # Broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timestep = t.expand(latents.shape[0])
        # Conditional (text) prediction is computed at every step.
        noise_pred_text = self.transformer(
            hidden_states=latent_model_input[1, :].unsqueeze(0),
            timestep=timestep,
            encoder_hidden_states=prompt_embeds[1, :].unsqueeze(0),
            pooled_projections=pooled_prompt_embeds[1, :].unsqueeze(0),
            joint_attention_kwargs=self.joint_attention_kwargs,
            return_dict=False,
            # idx=i,
        )[0]

        if i < 30 or (i > 30 and i % 5 == 0):
            # Unconditional prediction is only computed on these steps.
            noise_pred_uncond = self.transformer(
                hidden_states=latent_model_input[0, :].unsqueeze(0),
                timestep=timestep,
                encoder_hidden_states=prompt_embeds[0, :].unsqueeze(0),
                pooled_projections=pooled_prompt_embeds[0, :].unsqueeze(0),
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
                # idx=i,
            )[0]
            if i >= 28:
                # Cache the low/high-frequency differences between the
                # unconditional and conditional predictions.
                lf_uc, hf_uc = myfft(noise_pred_uncond.float())
                lf_c, hf_c = myfft(noise_pred_text.float())
                delta_lf = lf_uc - lf_c
                delta_hf = hf_uc - hf_c
        else:
            # Skipped unconditional pass: approximate it from the conditional
            # prediction plus the (amplified) cached frequency deltas.
            lf_c, hf_c = myfft(noise_pred_text.float())
            delta_lf = delta_lf * 1.1
            delta_hf = delta_hf * 1.25
            new_lf_uc = delta_lf + lf_c
            new_hf_uc = delta_hf + hf_c
            combine_uc = new_lf_uc + new_hf_uc
            combined_fft = torch.fft.ifftshift(combine_uc)
            noise_pred_uncond = torch.fft.ifft2(combined_fft).real
        self._guidance_scale = 1 + guidance_scale * (
            (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
        )

        # perform guidance
        if self.do_classifier_free_guidance:
            # noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents_dtype = latents.dtype
        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        if latents.dtype != latents_dtype:
            if torch.backends.mps.is_available():
                # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                latents = latents.to(latents_dtype)

        if callback_on_step_end is not None:
            callback_kwargs = {}
            for k in callback_on_step_end_tensor_inputs:
                callback_kwargs[k] = locals()[k]
            callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

            latents = callback_outputs.pop("latents", latents)
            prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
            negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
            negative_pooled_prompt_embeds = callback_outputs.pop(
                "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
            )

        # call the callback, if provided
        # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
        #     progress_bar.update()

        if XLA_AVAILABLE:
            xm.mark_step()

    # if output_type == "latent":
    #     image = latents
    # else:

    # 7. Denormalize the latents and decode them frame by frame with the VAE
    latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

    videos = []
    for v_idx in range(latents.shape[1]):
        image = self.vae.decode(latents[:, v_idx], return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)
        videos.append(image[0])
    return videos
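
# Minimal invocation sketch once acc_call is bound as __call__ below
# (illustrative values only; `infer` further down is the path used by the UI):
#   frames_list = pipe("a cat surfing", num_inference_steps=50,
#                      guidance_scale=7.5, width=768, height=432, frames=16)
#   export_to_video(frames_list, "out.mp4")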

from huggingface_hub import login

login(token=os.getenv('HF_TOKEN'))

dtype = torch.float16
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = VchitectXLPipeline("Vchitect/Vchitect-XL-2B", device)

# Replace the pipeline's __call__ with the accelerated sampling loop above.
# pipe.acc_call = acc_call.__get__(pipe)
# pipe.__call__ = types.MethodType(acc_call, pipe)
pipe.__class__.__call__ = acc_call

os.makedirs("./output", exist_ok=True)
os.makedirs("./gradio_tmp", exist_ok=True)
def infer(prompt: str, progress=gr.Progress(track_tqdm=True)):
    torch.cuda.empty_cache()

    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        video = pipe(
            prompt,
            negative_prompt="",
            num_inference_steps=50,
            guidance_scale=7.5,
            width=768,
            height=432,  # 480x288 624x352 432x240 768x432
            frames=16,
        )
    return video

def save_video(tensor):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    video_path = f"./output/{timestamp}.mp4"
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    export_to_video(tensor, video_path)
    return video_path


def convert_to_gif(video_path):
    clip = mp.VideoFileClip(video_path)
    clip = clip.set_fps(8)
    clip = clip.resize(height=240)
    gif_path = video_path.replace(".mp4", ".gif")
    clip.write_gif(gif_path, fps=8)
    return gif_path
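
# Note: convert_to_gif assumes the moviepy 1.x API (`moviepy.editor`,
# `set_fps`, `resize`); moviepy 2.x renamed these entry points, so pin
# moviepy<2 when reproducing this Space.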

def delete_old_files():
    """Periodically remove generated files older than 10 minutes."""
    while True:
        now = datetime.now()
        cutoff = now - timedelta(minutes=10)
        directories = ["./output", "./gradio_tmp"]
        for directory in directories:
            for filename in os.listdir(directory):
                file_path = os.path.join(directory, filename)
                if os.path.isfile(file_path):
                    file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
                    if file_mtime < cutoff:
                        os.remove(file_path)
        time.sleep(600)


threading.Thread(target=delete_old_files, daemon=True).start()

with gr.Blocks() as demo:
    gr.Markdown("""
    <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
        Vchitect-2.0 Huggingface Space🤗
    </div>
    <div style="text-align: center;">
        <a href="https://huggingface.co/Vchitect-XL/Vchitect-XL-2B">🤗 2B Model Hub</a> |
        <a href="https://vchitect.intern-ai.org.cn/">🌐 Website</a>
    </div>
    <div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
        ⚠️ This demo is for academic research and experimental use only.
        Users should strictly adhere to local laws and ethics.
    </div>
    <div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
        Note: Due to GPU memory limitations, this demo only supports 2-second video generation. For the full version, run the model locally.
    </div>
    """)
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=5)
            # with gr.Row():
            #     gr.Markdown(
            #         "✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one.")
            #     enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
            with gr.Column():
                # gr.Markdown("**Optional Parameters** (default values are recommended)<br>"
                #             "Increasing the number of inference steps will produce more detailed videos, but it will slow down the process.<br>"
                #             "50 steps are recommended for most cases.<br>"
                #             "For the 5B model, 50 steps will take approximately 350 seconds.")
                # with gr.Row():
                #     num_inference_steps = gr.Number(label="Inference Steps", value=50)
                #     guidance_scale = gr.Number(label="Guidance Scale", value=7.5)
                generate_button = gr.Button("🎬 Generate Video")

        with gr.Column():
            video_output = gr.Video(label="Generated Video", width=768, height=432)
            with gr.Row():
                download_video_button = gr.File(label="📥 Download Video", visible=False)
                download_gif_button = gr.File(label="📥 Download GIF", visible=False)

    def generate(prompt, progress=gr.Progress(track_tqdm=True)):
        tensor = infer(prompt, progress=progress)
        video_path = save_video(tensor)
        video_update = gr.update(visible=True, value=video_path)
        gif_path = convert_to_gif(video_path)
        gif_update = gr.update(visible=True, value=gif_path)

        return video_path, video_update, gif_update
    # def enhance_prompt_func(prompt):
    #     return convert_prompt(prompt, retry_times=1)

    generate_button.click(
        generate,
        inputs=[prompt],
        outputs=[video_output, download_video_button, download_gif_button],
    )

    # enhance_button.click(
    #     enhance_prompt_func,
    #     inputs=[prompt],
    #     outputs=[prompt]
    # )

if __name__ == "__main__":
    demo.launch()