import os
import subprocess
import sys
import warnings
import logging

if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    class spaces:
        def GPU(*decorator_args, **decorator_kwargs):
            def decorator(func):
                def wrapper(*args, **kwargs):
                    return func(*args, **kwargs)
                return wrapper
            return decorator
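
# Sketch of how the fallback behaves off ZeroGPU hardware: the stub class turns
# `spaces.GPU(...)` into a no-op decorator, so code written for ZeroGPU still runs
# locally, e.g.
#
#     @spaces.GPU(duration=120)
#     def heavy_inference(x):
#         return x  # runs on whatever device is available, without ZeroGPU scheduling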

import difflib

# Configure logging settings
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)

def _get_output(cmd):
    try:
        return subprocess.check_output(cmd).decode("utf-8")
    except Exception as ex:
        logging.exception(ex)
        return None

def install_cuda_toolkit():
    CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
    CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
    print(f"[INFO] Downloading CUDA Toolkit from {CUDA_TOOLKIT_URL} ...")
    subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
    subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])

    print("[INFO] Installing CUDA Toolkit silently ...")
    subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])

    print("[INFO] Setting CUDA environment variables ...")
    os.environ["CUDA_HOME"] = "/usr/local/cuda"
    os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ.get("PATH", ""))
    os.environ["LD_LIBRARY_PATH"] = "%s/lib64:%s" % (
        os.environ["CUDA_HOME"],
        os.environ.get("LD_LIBRARY_PATH", "")
    )
    # Optional: set architecture list for compilation (Ampere and Ada)
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;8.9"

    if os.path.exists(CUDA_TOOLKIT_FILE):
        os.remove(CUDA_TOOLKIT_FILE)
        print(f"[INFO] Removed installer file: {CUDA_TOOLKIT_FILE}")
    else:
        print(f"[WARN] Installer file not found: {CUDA_TOOLKIT_FILE}")

    print(os.listdir("/usr/local/cuda"))
    print("[INFO] CUDA 12.1 installation complete. CUDA_HOME set to /usr/local/cuda")

logging.info("Environment Variables: %s" % os.environ) | |
logging.info("Installing CUDA extensions...") | |
if _get_output(["nvcc", "--version"]) is None: | |
logging.info("Installing CUDA toolkit...") | |
install_cuda_toolkit() | |
logging.info("installCUDA: %s" % _get_output(["nvcc", "--version"])) | |
else: | |
logging.info("Detected CUDA: %s" % _get_output(["nvcc", "--version"])) | |
import torch
import argparse
import json
import random
from datetime import datetime

import numpy as np
import cv2
from PIL import Image
from tqdm import tqdm
from natsort import natsorted, ns
from einops import rearrange
from omegaconf import OmegaConf
from huggingface_hub import snapshot_download
import gradio as gr
import base64
import imageio_ffmpeg as ffmpeg

from different_domain_imge_gen.landmark_generation import generate_annotation
from transformers import (
    Dinov2Model, CLIPImageProcessor, CLIPVisionModelWithProjection, AutoImageProcessor
)
from Next3d.training_avatar_texture.camera_utils import LookAtPoseSampler, FOV_to_intrinsics
import recon.dnnlib as dnnlib
import recon.legacy as legacy
from DiT_VAE.diffusion.utils.misc import read_config
from DiT_VAE.vae.triplane_vae import AutoencoderKL as AutoencoderKLTriplane
from DiT_VAE.diffusion import IDDPM, DPMS
from DiT_VAE.diffusion.model.nets import TriDitCLIPDINO_XL_2
from DiT_VAE.diffusion.data.datasets import get_chunks
# Get the directory of the current script
father_path = os.path.dirname(os.path.abspath(__file__))

# Add necessary paths dynamically
sys.path.extend([
    os.path.join(father_path, 'recon'),
    os.path.join(father_path, 'Next3d'),
    os.path.join(father_path, 'data_process'),
    os.path.join(father_path, 'data_process/lib')
])

from lib.FaceVerse.renderer import Faceverse_manager
from data_process.input_img_align_extract_ldm_demo import Process
from lib.config.config_demo import cfg
import shutil

# Suppress warnings (especially for PyTorch)
warnings.filterwarnings("ignore")
os.environ["MEDIAPIPE_DISABLE_GPU"] = "1" # Disable GPU for MediaPipe | |
# 🔧 Set CUDA_HOME before anything else | |
# os.system("pip uninstall diffusers") | |
# os.system("pip install diffusers==0.20.1") | |
from diffusers import ( | |
StableDiffusionControlNetImg2ImgPipeline, | |
ControlNetModel, | |
DPMSolverMultistepScheduler, | |
) | |
def get_args():
    """Parse and return command-line arguments."""
    parser = argparse.ArgumentParser(description="4D Triplane Generation Arguments")

    # Configuration and model checkpoints
    parser.add_argument("--config", type=str, default="./configs/infer_config.py",
                        help="Path to the configuration file.")

    # Generation parameters
    parser.add_argument("--bs", type=int, default=1,
                        help="Batch size for processing.")
    parser.add_argument("--cfg_scale", type=float, default=4.5,
                        help="CFG scale parameter.")
    parser.add_argument("--sampling_algo", type=str, default="dpm-solver",
                        choices=["iddpm", "dpm-solver"],
                        help="Sampling algorithm to be used.")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducibility.")
    # parser.add_argument("--select_img", type=str, default=None,
    #                     help="Optional: Select a specific image.")
    parser.add_argument('--step', default=-1, type=int)
    # parser.add_argument('--use_demo_cam', action='store_true', help="Enable predefined camera parameters")
    return parser.parse_args()

def set_env(seed=0):
    """Set random seed for reproducibility across multiple frameworks."""
    torch.manual_seed(seed)  # Set PyTorch seed
    torch.cuda.manual_seed_all(seed)  # If using multi-GPU
    np.random.seed(seed)  # Set NumPy seed
    random.seed(seed)  # Set Python built-in random module seed
    torch.set_grad_enabled(False)  # Disable gradients for inference

def to_rgb_image(image: Image.Image):
    """Convert an image to RGB format if necessary."""
    if image.mode == 'RGB':
        return image
    elif image.mode == 'RGBA':
        img = Image.new("RGB", image.size, (127, 127, 127))
        img.paste(image, mask=image.getchannel('A'))
        return img
    else:
        raise ValueError(f"Unsupported image type: {image.mode}")

def image_process(image_path, clip_image_processor, dino_img_processor, device):
    """Preprocess an image for CLIP and DINO models."""
    image = to_rgb_image(Image.open(image_path))
    clip_image = clip_image_processor(images=image, return_tensors="pt").pixel_values.to(device)
    dino_image = dino_img_processor(images=image, return_tensors="pt").pixel_values.to(device)
    return dino_image, clip_image

# def video_gen(frames_dir, output_path, fps=30):
#     """Generate a video from image frames."""
#     frame_files = natsorted(os.listdir(frames_dir), alg=ns.PATH)
#     frames = [cv2.imread(os.path.join(frames_dir, f)) for f in frame_files]
#     H, W = frames[0].shape[:2]
#     video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (W, H))
#     for frame in frames:
#         video_writer.write(frame)
#     video_writer.release()

def trans(tensor_img):
    img = (tensor_img.permute(0, 2, 3, 1) * 0.5 + 0.5).clamp(0, 1) * 255.
    img = img.to(torch.uint8)
    img = img[0].detach().cpu().numpy()
    return img


def get_vert(vert_dir):
    uvcoords_image = np.load(os.path.join(vert_dir))[..., :3]
    uvcoords_image[..., -1][uvcoords_image[..., -1] < 0.5] = 0
    uvcoords_image[..., -1][uvcoords_image[..., -1] >= 0.5] = 1
    return torch.tensor(uvcoords_image.copy()).float().unsqueeze(0)

def generate_samples(DiT_model, cfg_scale, sample_steps, clip_feature, dino_feature, uncond_clip_feature,
                     uncond_dino_feature, device, latent_size, sampling_algo):
    """
    Generate latent samples using the specified diffusion model.

    Args:
        DiT_model (torch.nn.Module): The diffusion model.
        cfg_scale (float): The classifier-free guidance scale.
        sample_steps (int): Number of sampling steps.
        clip_feature (torch.Tensor): CLIP feature tensor.
        dino_feature (torch.Tensor): DINO feature tensor.
        uncond_clip_feature (torch.Tensor): Unconditional CLIP feature tensor.
        uncond_dino_feature (torch.Tensor): Unconditional DINO feature tensor.
        device (str): Device for computation.
        latent_size (tuple): The latent space size.
        sampling_algo (str): The sampling algorithm ('iddpm' or 'dpm-solver').

    Returns:
        torch.Tensor: The generated samples.
    """
    n = 1  # Batch size
    z = torch.randn(n, 8, latent_size[0], latent_size[1], device=device)

    if sampling_algo == 'iddpm':
        z = z.repeat(2, 1, 1, 1)  # Duplicate for classifier-free guidance
        model_kwargs = dict(y=torch.cat([clip_feature, uncond_clip_feature]),
                            img_feature=torch.cat([dino_feature, dino_feature]),
                            cfg_scale=cfg_scale)
        diffusion = IDDPM(str(sample_steps))
        samples = diffusion.p_sample_loop(DiT_model.forward_with_cfg, z.shape, z, clip_denoised=False,
                                          model_kwargs=model_kwargs, progress=True, device=device)
        samples, _ = samples.chunk(2, dim=0)  # Remove unconditional samples
    elif sampling_algo == 'dpm-solver':
        dpm_solver = DPMS(DiT_model.forward_with_dpmsolver,
                          condition=[clip_feature, dino_feature],
                          uncondition=[uncond_clip_feature, dino_feature],
                          cfg_scale=cfg_scale)
        samples = dpm_solver.sample(z, steps=sample_steps, order=2, skip_type="time_uniform", method="multistep")
    else:
        raise ValueError(f"Invalid sampling_algo '{sampling_algo}'. Choose either 'iddpm' or 'dpm-solver'.")

    return samples
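
# Usage sketch (assumes the feature tensors come from `image_process` plus the CLIP/DINO
# encoders, and that latent_size matches the (128, 32) value derived in model_define):
#
#     samples = generate_samples(DiT_model, cfg_scale=4.5, sample_steps=20,
#                                clip_feature=clip_feat, dino_feature=dino_feat,
#                                uncond_clip_feature=uncond_clip_feat,
#                                uncond_dino_feature=uncond_dino_feat,
#                                device="cuda", latent_size=(128, 32),
#                                sampling_algo="dpm-solver")
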
def load_motion_aware_render_model(ckpt_path, device):
    """Load the motion-aware render model from a checkpoint."""
    logging.info("Loading motion-aware render model...")
    with dnnlib.util.open_url(ckpt_path, 'rb') as f:
        network = legacy.load_network_pkl(f)  # type: ignore
    logging.info("Motion-aware render model loaded.")
    return network['G_ema'].to(device)


def load_diffusion_model(ckpt_path, latent_size, device):
    """Load the diffusion model (DiT)."""
    logging.info("Loading diffusion model (DiT)...")
    DiT_model = TriDitCLIPDINO_XL_2(input_size=latent_size).to(device)
    ckpt = torch.load(ckpt_path, map_location="cpu")

    # Remove keys that can cause mismatches
    for key in ['pos_embed', 'base_model.pos_embed', 'model.pos_embed']:
        ckpt['state_dict'].pop(key, None)
        ckpt.get('state_dict_ema', {}).pop(key, None)

    state_dict = ckpt.get('state_dict_ema', ckpt)
    DiT_model.load_state_dict(state_dict, strict=False)
    DiT_model.eval()
    logging.info("Diffusion model (DiT) loaded.")
    return DiT_model

def load_vae_clip_dino(config, device):
    """Load VAE, CLIP, and DINO models."""
    logging.info("Loading VAE, CLIP, and DINO models...")

    # Load CLIP image encoder
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        config.image_encoder_path)
    image_encoder.requires_grad_(False)
    image_encoder.to(device)

    # Load VAE
    config_vae = OmegaConf.load(config.vae_triplane_config_path)
    vae_triplane = AutoencoderKLTriplane(ddconfig=config_vae['ddconfig'], lossconfig=None, embed_dim=8)
    vae_triplane.to(device)
    vae_ckpt_path = os.path.join(config.vae_pretrained, 'pytorch_model.bin')
    if not os.path.isfile(vae_ckpt_path):
        raise RuntimeError(f"VAE checkpoint not found at {vae_ckpt_path}")
    vae_triplane.load_state_dict(torch.load(vae_ckpt_path, map_location="cpu"))
    vae_triplane.requires_grad_(False)

    # Load DINO model
    dinov2 = Dinov2Model.from_pretrained(config.dino_pretrained)
    dinov2.requires_grad_(False)
    dinov2.to(device)

    # Load image processors
    dino_img_processor = AutoImageProcessor.from_pretrained(config.dino_pretrained)
    clip_image_processor = CLIPImageProcessor()

    logging.info("VAE, CLIP, and DINO models loaded.")
    return vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor

def prepare_working_dir(dir, style):
    logging.info("prepare_working_dir: style=%s", style)
    if style:
        return dir
    else:
        import tempfile
        # mkdtemp keeps the directory alive for the rest of the session; a
        # TemporaryDirectory object would be removed as soon as it is garbage collected.
        return tempfile.mkdtemp()

def launch_pretrained():
    from huggingface_hub import snapshot_download
    os.system("pip uninstall -y torch")
    os.system("pip uninstall -y torchvision")
    os.system("pip install https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp310-cp310-linux_x86_64.whl")
    os.system("pip install https://download.pytorch.org/whl/cu121/torchvision-0.19.1%2Bcu121-cp310-cp310-linux_x86_64.whl")
    snapshot_download(
        repo_id="KumaPower/AvatarArtist",
        repo_type="model",
        local_dir="./pretrained_model",
        local_dir_use_symlinks=False
    )
    snapshot_download(
        repo_id="stabilityai/stable-diffusion-2-1-base",
        repo_type="model",
        local_dir="./pretrained_model/sd21",
        local_dir_use_symlinks=False
    )
    logging.info("Deleting unused full checkpoints to save disk space.")
    os.remove('./pretrained_model/sd21/v2-1_512-ema-pruned.ckpt')
    os.remove('./pretrained_model/sd21/v2-1_512-nonema-pruned.ckpt')
    # Download all files from CrucibleAI/ControlNetMediaPipeFace
    snapshot_download(
        repo_id="CrucibleAI/ControlNetMediaPipeFace",
        repo_type="model",
        local_dir="./pretrained_model/control",
        local_dir_use_symlinks=False
    )

def prepare_image_list(img_dir, selected_img):
    """Prepare the list of image paths for processing."""
    if selected_img and selected_img in os.listdir(img_dir):
        return [os.path.join(img_dir, selected_img)]
    return sorted([os.path.join(img_dir, img) for img in os.listdir(img_dir)])

def images_to_video(image_folder, output_video, fps=30):
    # Get all image files and ensure correct order
    images = [img for img in os.listdir(image_folder) if img.endswith((".png", ".jpg", ".jpeg"))]
    images = natsorted(images)  # Sort filenames naturally to preserve frame order

    if not images:
        print("❌ No images found in the directory!")
        return

    # Get the path to the FFmpeg executable
    ffmpeg_exe = ffmpeg.get_ffmpeg_exe()
    print(f"Using FFmpeg from: {ffmpeg_exe}")

    # Define input image pattern (expects images named like "%04d.png")
    image_pattern = os.path.join(image_folder, "%04d.png")

    # FFmpeg command to encode video (with -y to overwrite)
    command = [
        ffmpeg_exe,
        '-y',  # Overwrite output file without asking
        '-framerate', str(fps),
        '-i', image_pattern,
        '-c:v', 'libx264',
        '-preset', 'slow',
        '-crf', '18',
        '-pix_fmt', 'yuv420p',
        '-b:v', '5000k',
        output_video
    ]

    # Run FFmpeg command
    subprocess.run(command, check=True)
    print(f"✅ High-quality MP4 video has been generated: {output_video}")
def model_define():
    args = get_args()
    set_env(args.seed)
    input_process_model = Process(cfg)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    weight_dtype = torch.float32
    logging.info(f"Running inference with {weight_dtype}")

    # Load configuration
    default_config = read_config(args.config)

    # Ensure valid sampling algorithm
    assert args.sampling_algo in ['iddpm', 'dpm-solver', 'sa-solver']

    # Load motion-aware render model
    motion_aware_render_model = load_motion_aware_render_model(default_config.motion_aware_render_model_ckpt, device)

    # Load diffusion model (DiT)
    triplane_size = (256 * 4, 256)
    latent_size = (triplane_size[0] // 8, triplane_size[1] // 8)
    sample_steps = args.step if args.step != -1 else {'iddpm': 100, 'dpm-solver': 20, 'sa-solver': 25}[
        args.sampling_algo]
    DiT_model = load_diffusion_model(default_config.DiT_model_ckpt, latent_size, device)

    # Load VAE, CLIP, and DINO
    vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor = load_vae_clip_dino(default_config,
                                                                                                       device)

    # Load normalization parameters
    triplane_std = torch.load(default_config.std_dir).to(device).reshape(1, -1, 1, 1, 1)
    triplane_mean = torch.load(default_config.mean_dir).to(device).reshape(1, -1, 1, 1, 1)

    # Load average latent vector
    ws_avg = torch.load(default_config.ws_avg_pkl).to(device)[0]

    return motion_aware_render_model, sample_steps, DiT_model, \
        vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor, triplane_std, triplane_mean, ws_avg, device, input_process_model

def duplicate_batch(tensor, batch_size=2):
    if tensor is None:
        return None  # Pass None through unchanged
    return tensor.repeat(batch_size, *([1] * (tensor.dim() - 1)))  # Repeat along the batch dimension
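
# Example: a tensor of shape (1, 3, 512, 512) becomes (2, 3, 512, 512);
# only the batch dimension is repeated, every other dimension is left untouched.
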
def avatar_generation(items, save_path_base, video_path_input, source_type, is_styled, styled_img, image_name_true):
    """
    Generate avatars from input images.

    Args:
        items (str or list): Path(s) to the processed source image.
        save_path_base (str): Base directory for saving results.
        video_path_input (str): Path to the target (driving) video.
        source_type (str): Either 'example' or 'custom'.
        is_styled (bool): Whether the styled image should be used as the source.
        styled_img (str): Path to the styled image (used when is_styled is True).
        image_name_true (str): File name of the processed source image.
    """
    try:
        if is_styled:
            items = [styled_img]
        else:
            items = [items]
        video_folder = "./demo_data/target_video"
        video_name = os.path.basename(video_path_input).split(".")[0]
        target_path = os.path.join(video_folder, 'data_' + video_name)
        exp_base_dir = os.path.join(target_path, 'coeffs')
        exp_img_base_dir = os.path.join(target_path, 'images512x512')
        motion_base_dir = os.path.join(target_path, 'motions')
        label_file_test = os.path.join(target_path, 'images512x512/dataset_realcam.json')
        # render_model.to(device)
        # image_encoder.to(device)
        # vae_triplane.to(device)
        # dinov2.to(device)
        # ws_avg.to(device)
        # DiT_model.to(device)

        # Set up FaceVerse for animation
        if source_type == 'example':
            input_img_fvid = './demo_data/source_img/img_generate_different_domain/coeffs/demo_imgs'
            input_img_motion = './demo_data/source_img/img_generate_different_domain/motions/demo_imgs'
        elif source_type == 'custom':
            input_img_fvid = os.path.join(save_path_base, 'processed_img/dataset/coeffs/input_image')
            input_img_motion = os.path.join(save_path_base, 'processed_img/dataset/motions/input_image')
        else:
            raise ValueError("Wrong type")
        bs = 1
        sample_steps = 20
        cfg_scale = 4.5
        pitch_range = 0.25
        yaw_range = 0.35
        triplane_size = (256 * 4, 256)
        latent_size = (triplane_size[0] // 8, triplane_size[1] // 8)

        for chunk in tqdm(list(get_chunks(items, 1)), unit='batch'):
            if bs != 1:
                raise ValueError("Batch size > 1 not implemented")
            image_dir = chunk[0]
            image_name = os.path.splitext(image_name_true)[0]
            # # image_name = os.path.splitext(os.path.basename(image_dir))[0]
            # if source_type == 'custom':
            #     image_name = os.path.splitext(image_name_true)[0]
            # else:
            #     image_name = os.path.splitext(os.path.basename(image_dir))[0]
            dino_img, clip_image = image_process(image_dir, clip_image_processor, dino_img_processor, device)

            clip_feature = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
            uncond_clip_feature = image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[
                -2]
            dino_feature = dinov2(dino_img).last_hidden_state
            uncond_dino_feature = dinov2(torch.zeros_like(dino_img)).last_hidden_state

            samples = generate_samples(DiT_model, cfg_scale, sample_steps, clip_feature, dino_feature,
                                       uncond_clip_feature, uncond_dino_feature, device, latent_size,
                                       'dpm-solver')
            samples = (samples / 0.3994218)
            samples = rearrange(samples, "b c (f h) w -> b c f h w", f=4)
            samples = vae_triplane.decode(samples)
            samples = rearrange(samples, "b c f h w -> b f c h w")
            samples = samples * std + mean
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

            save_frames_path_out = os.path.join(save_path_base, image_name, video_name, 'out')
            save_frames_path_outshow = os.path.join(save_path_base, image_name, video_name, 'out_show')
            save_frames_path_depth = os.path.join(save_path_base, image_name, video_name, 'depth')
            os.makedirs(save_frames_path_out, exist_ok=True)
            os.makedirs(save_frames_path_outshow, exist_ok=True)
            os.makedirs(save_frames_path_depth, exist_ok=True)

            img_ref = np.array(Image.open(image_dir))
            img_ref_out = img_ref.copy()
            img_ref = torch.from_numpy(img_ref.astype(np.float32) / 127.5 - 1).permute(2, 0, 1).unsqueeze(0).to(device)

            motion_app_dir = os.path.join(input_img_motion, image_name + '.npy')
            motion_app = torch.tensor(np.load(motion_app_dir), dtype=torch.float32).unsqueeze(0).to(device)

            id_motions = os.path.join(input_img_fvid, image_name + '.npy')
            all_pose = json.loads(open(label_file_test).read())['labels']
            all_pose = dict(all_pose)
            if os.path.exists(id_motions):
                coeff = np.load(id_motions).astype(np.float32)
                coeff = torch.from_numpy(coeff).to(device).float().unsqueeze(0)
                Faceverse.id_coeff = Faceverse.recon_model.split_coeffs(coeff)[0]
            motion_dir = os.path.join(motion_base_dir, video_name)
            exp_dir = os.path.join(exp_base_dir, video_name)
            for frame_index, motion_name in enumerate(
                    tqdm(natsorted(os.listdir(motion_dir), alg=ns.PATH), desc="Processing Frames")):
                exp_each_dir_img = os.path.join(exp_img_base_dir, video_name, motion_name.replace('.npy', '.png'))
                exp_each_dir = os.path.join(exp_dir, motion_name)
                motion_each_dir = os.path.join(motion_dir, motion_name)

                # Load pose data
                pose_key = os.path.join(video_name, motion_name.replace('.npy', '.png'))
                cam2world_pose = LookAtPoseSampler.sample(
                    3.14 / 2 + yaw_range * np.sin(2 * 3.14 * frame_index / len(os.listdir(motion_dir))),
                    3.14 / 2 - 0.05 + pitch_range * np.cos(2 * 3.14 * frame_index / len(os.listdir(motion_dir))),
                    torch.tensor([0, 0, 0], device=device), radius=2.7, device=device)
                pose_show = torch.cat([cam2world_pose.reshape(-1, 16),
                                       FOV_to_intrinsics(fov_degrees=18.837, device=device).reshape(-1, 9)], 1).to(device)
                pose = torch.tensor(np.array(all_pose[pose_key]).astype(np.float32)).float().unsqueeze(0).to(device)

                # Load and resize expression image
                exp_img = np.array(Image.open(exp_each_dir_img).resize((512, 512)))

                # Load expression coefficients
                exp_coeff = torch.from_numpy(np.load(exp_each_dir).astype(np.float32)).to(device).float().unsqueeze(0)
                exp_target = Faceverse.make_driven_rendering(exp_coeff, res=256)

                # Load motion data
                motion = torch.tensor(np.load(motion_each_dir)).float().unsqueeze(0).to(device)

                # img_ref_double = duplicate_batch(img_ref, batch_size=2)
                # motion_app_double = duplicate_batch(motion_app, batch_size=2)
                # motion_double = duplicate_batch(motion, batch_size=2)
                # pose_double = torch.cat([pose_show, pose], dim=0)
                # exp_target_double = duplicate_batch(exp_target, batch_size=2)
                # samples_double = duplicate_batch(samples, batch_size=2)

                # Select refine_net processing method
                final_out = render_model(
                    img_ref, None, motion_app, motion, c=pose, mesh=exp_target,
                    triplane_recon=samples,
                    ws_avg=ws_avg, motion_scale=1.
                )

                # Process output image
                final_out_show = trans(final_out['image_sr'][0].unsqueeze(0))
                # final_out_notshow = trans(final_out['image_sr'][0].unsqueeze(0))
                depth = final_out['image_depth'][0].unsqueeze(0)
                depth = -depth
                depth = (depth - depth.min()) / (depth.max() - depth.min()) * 2 - 1
                depth = trans(depth)
                depth = np.repeat(depth[:, :, :], 3, axis=2)

                # Save output images
                frame_name = f'{str(frame_index).zfill(4)}.png'
                Image.fromarray(depth, 'RGB').save(os.path.join(save_frames_path_depth, frame_name))
                Image.fromarray(final_out_show, 'RGB').save(os.path.join(save_frames_path_out, frame_name))
                # Image.fromarray(final_out_show, 'RGB').save(os.path.join(save_frames_path_outshow, frame_name))

            # Generate videos
            images_to_video(save_frames_path_out, os.path.join(save_path_base, image_name + video_name + '_out.mp4'))
            images_to_video(save_frames_path_depth, os.path.join(save_path_base, image_name + video_name + '_depth.mp4'))
            logging.info("✅ Video generation completed successfully!")
            return os.path.join(save_path_base, image_name + video_name + '_out.mp4'), \
                os.path.join(save_path_base, image_name + video_name + '_depth.mp4')
    except Exception as e:
        logging.exception(e)  # Log the failure instead of returning silently
        return None, None

def get_image_base64(path):
    with open(path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode()
    return f"data:image/png;base64,{encoded_string}"


def assert_input_image(input_image):
    if input_image is None:
        raise gr.Error("No image selected or uploaded!")

def process_image(input_image_dir, source_type, is_style, save_dir):
    """🎯 Process input_image; the logic differs depending on whether it is an example image."""
    process_img_input_dir = os.path.join(save_dir, 'input_image')
    process_img_save_dir = os.path.join(save_dir, 'processed_img')
    base_name = os.path.basename(input_image_dir)  # e.g. abc123.jpg
    name_without_ext = os.path.splitext(base_name)[0]  # abc123
    image_name_true = name_without_ext + ".png"
    os.makedirs(process_img_save_dir, exist_ok=True)
    os.makedirs(process_img_input_dir, exist_ok=True)
    if source_type == "example":
        image = Image.open(input_image_dir)
        return image, source_type, image_name_true
    else:
        # input_process_model.inference(input_image, process_img_save_dir)
        shutil.copy(input_image_dir, process_img_input_dir)
        input_process_model.inference(process_img_input_dir, process_img_save_dir, is_img=True, is_video=False)
        files = os.listdir(os.path.join(process_img_save_dir, 'dataset/images512x512/input_image'))
        image_files = [f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp'))]
        # Use difflib to find the file name closest to the expected one
        matches = difflib.get_close_matches(image_name_true, image_files, n=1, cutoff=0.1)
        closest_match = matches[0]
        imge_dir = os.path.join(process_img_save_dir, 'dataset/images512x512/input_image', closest_match)
        image = Image.open(imge_dir)
        image_name_true = closest_match
        return image, source_type, image_name_true  # Branch handling user-uploaded images

def style_transfer(processed_image, style_prompt, cfg, strength, save_base, image_name_true):
    """
    🎭 Apply style transfer to the processed image.
    ✅ Plug your own stylization code in here if needed.
    """
    src_img_pil = Image.open(processed_image)
    img_name = os.path.basename(processed_image)
    save_dir = os.path.join(save_base, 'style_img')
    os.makedirs(save_dir, exist_ok=True)

    control_image = generate_annotation(src_img_pil, max_faces=1)
    print(style_prompt)
    trg_img_pil = pipeline_sd(
        prompt=style_prompt,
        image=src_img_pil,
        strength=strength,
        control_image=Image.fromarray(control_image),
        guidance_scale=cfg,
        negative_prompt='worst quality, normal quality, low quality, low res, blurry',
        num_inference_steps=30,
        controlnet_conditioning_scale=1.5
    )['images'][0]
    trg_img_pil.save(os.path.join(save_dir, image_name_true))
    return trg_img_pil  # 🚨 Replace this with your own style-transfer logic if desired


def reset_flag():
    return False

css = """ | |
/* ✅ 让所有 Image 居中 + 自适应宽度 */ | |
.gr-image img { | |
display: block; | |
margin-left: auto; | |
margin-right: auto; | |
max-width: 100%; | |
height: auto; | |
} | |
/* ✅ 让所有 Video 居中 + 自适应宽度 */ | |
.gr-video video { | |
display: block; | |
margin-left: auto; | |
margin-right: auto; | |
max-width: 100%; | |
height: auto; | |
} | |
/* ✅ 可选:让按钮和 markdown 居中 */ | |
#generate_block { | |
display: flex; | |
flex-direction: column; | |
align-items: center; | |
justify-content: center; | |
margin-top: 1rem; | |
} | |
/* 可选:让整个容器宽一点 */ | |
#main_container { | |
max-width: 1280px; /* ✅ 例如限制在 1280px 内 */ | |
margin-left: auto; /* ✅ 水平居中 */ | |
margin-right: auto; | |
padding-left: 1rem; | |
padding-right: 1rem; | |
} | |
""" | |
def launch_gradio_app():
    styles = {
        "Ghibli": "Ghibli style avatar, anime style",
        "Pixar": "a 3D render of a face in Pixar style",
        "Lego": "a 3D render of a head of a lego man 3D model",
        "Greek Statue": "a FHD photo of a white Greek statue",
        "Elf": "a FHD photo of a face of a beautiful elf with silver hair in live action movie",
        "Zombie": "a FHD photo of a face of a zombie",
        "Tekken": "a 3D render of a Tekken game character",
        "Devil": "a FHD photo of a face of a devil in fantasy movie",
        "Steampunk": "Steampunk style portrait, mechanical, brass and copper tones",
        "Mario": "a 3D render of a face of Super Mario",
        "Orc": "a FHD photo of a face of an orc in fantasy movie",
        "Masque": "a FHD photo of a face of a person in masquerade",
        "Skeleton": "a FHD photo of a face of a skeleton in fantasy movie",
        "Peking Opera": "a FHD photo of face of character in Peking opera with heavy make-up",
        "Yoda": "a FHD photo of a face of Yoda in Star Wars",
        "Hobbit": "a FHD photo of a face of Hobbit in Lord of the Rings",
        "Stained Glass": "Stained glass style, portrait, beautiful, translucent",
        "Graffiti": "Graffiti style portrait, street art, vibrant, urban, detailed, tag",
        "Pixel-art": "pixel art style portrait, low res, blocky, pixel art style",
        "Retro": "Retro game art style portrait, vibrant colors",
        "Ink": "a portrait in ink style, black and white image",
    }
    with gr.Blocks(analytics_enabled=False, delete_cache=[3600, 3600], css=css, elem_id="main_container") as demo:
        logo_url = "./docs/AvatarArtist.png"
        logo_base64 = get_image_base64(logo_url)
        # 🚀 Center the logo and align the title
        gr.HTML(
            f"""
            <div style="display: flex; justify-content: center; align-items: center; text-align: center; margin-bottom: 20px;">
                <img src="{logo_base64}" style="height:50px; margin-right: 15px; display: block;" onerror="this.style.display='none'"/>
                <h1 style="font-size: 32px; font-weight: bold;">AvatarArtist: Open-Domain 4D Avatarization</h1>
            </div>
            """
        )
        # 🚀 Keep the badge links aligned in one row
        gr.HTML(
            """
            <div style="display: flex; justify-content: center; gap: 10px; margin-top: 10px;">
                <a title="Website" href="https://kumapowerliu.github.io/AvatarArtist/" target="_blank" rel="noopener noreferrer">
                    <img src="https://img.shields.io/badge/Website-Visit-blue?style=for-the-badge&logo=GoogleChrome">
                </a>
                <a title="arXiv" href="https://arxiv.org/abs/2503.19906" target="_blank" rel="noopener noreferrer">
                    <img src="https://img.shields.io/badge/arXiv-Paper-red?style=for-the-badge&logo=arXiv">
                </a>
                <a title="Github" href="https://github.com/ant-research/AvatarArtist" target="_blank" rel="noopener noreferrer">
                    <img src="https://img.shields.io/github/stars/ant-research/AvatarArtist?style=for-the-badge&logo=github&logoColor=white&color=orange">
                </a>
            </div>
            """
        )
        gr.HTML(
            """
            <div style="color: inherit; text-align: left; font-size: 16px; line-height: 1.6; margin-top: 20px; padding: 16px; border-radius: 10px; border: 1px solid rgba(0,0,0,0.1); background-color: rgba(240, 240, 240, 0.6); backdrop-filter: blur(2px);">
                <strong>🧑‍🎨 How to use this demo:</strong>
                <ol style="margin-top: 10px; padding-left: 20px;">
                    <li><strong>Select or upload a source image</strong> – this will be the avatar's face.</li>
                    <li><strong>Select or upload a target video</strong> – the avatar will mimic this motion.</li>
                    <li><strong>Click the <em>Process Image</em> button</strong> – this prepares the source image to meet our model's input requirements.</li>
                    <li><strong>(Optional)</strong> Click <em>Apply Style</em> to change the appearance of the processed image – we offer a variety of fun styles to choose from!</li>
                    <li><strong>Click <em>Generate Avatar</em></strong> to create the final animated result driven by the target video.</li>
                </ol>
                <p style="margin-top: 10px;"><strong>🎨 Tip:</strong> Try different styles to get various artistic effects for your avatar!</p>
            </div>
            """
        )
        # 🚀 Add an important-notes box
        gr.HTML(
            """
            <div style="background-color: #FFDDDD; padding: 15px; border-radius: 10px; border: 2px solid red; text-align: center; margin-top: 20px;">
                <h4 style="color: red; font-size: 18px;">
                    🚨 <strong style="color: red;">Important Notes:</strong> Please try to provide a <u>front-facing</u> or <u>full-face</u> image without obstructions.
                </h4>
                <p style="color: black; font-size: 16px;">
                    ❌ Our demo does <strong style="color: black;">not</strong> support uploading videos with specific motions because processing requires time.<br>
                    ✅ Feel free to check out our <a href="https://github.com/ant-research/AvatarArtist" target="_blank" style="color: red; font-weight: bold;">GitHub repository</a> to drive portraits using your desired motions.
                </p>
            </div>
            """
        )
        # DISPLAY
        image_folder = "./demo_data/source_img/img_generate_different_domain/images512x512/demo_imgs"
        video_folder = "./demo_data/target_video"

        examples_images = sorted(
            [os.path.join(image_folder, f) for f in os.listdir(image_folder) if
             f.lower().endswith(('.png', '.jpg', '.jpeg'))]
        )
        examples_videos = sorted(
            [os.path.join(video_folder, f) for f in os.listdir(video_folder) if f.lower().endswith('.mp4')]
        )
        print(examples_videos)

        source_type = gr.State("example")
        is_from_example = gr.State(value=True)
        is_styled = gr.State(value=False)
        working_dir = gr.State()
        image_name_true = gr.State()
        with gr.Row():
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="input_image"):
                    with gr.TabItem('🎨 Upload Image'):
                        input_image = gr.Image(
                            label="Upload Source Image",
                            value=os.path.join(image_folder, '02025.png'),
                            image_mode="RGB", height=512, container=True,
                            sources="upload", type="filepath"
                        )

                        def mark_as_example(example_image):
                            print("✅ mark_as_example called")
                            return "example", True, False

                        def mark_as_custom(user_image, is_from_example_flag):
                            print("✅ mark_as_custom called")
                            if is_from_example_flag:
                                print("⚠️ Ignored mark_as_custom triggered by example")
                                return "example", False, False
                            return "custom", False, False

                        input_image.change(
                            mark_as_custom,
                            inputs=[input_image, is_from_example],
                            outputs=[source_type, is_from_example, is_styled]  # ✅ Only update source_type and flags; do not output input_image
                        )

                        # ✅ Put the `Examples` component on its own row and bind its click event
                        with gr.Row():
                            example_component = gr.Examples(
                                examples=examples_images,
                                inputs=[input_image],
                                examples_per_page=10,
                            )

                        # ✅ Listen for the `Examples` dataset `click` event
                        example_component.dataset.click(
                            fn=mark_as_example,
                            inputs=[input_image],
                            outputs=[source_type, is_from_example, is_styled]
                        )
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="input_video"):
                    with gr.TabItem('🎬 Target Video'):
                        video_input = gr.Video(
                            label="Select Target Motion",
                            height=512, container=True, interactive=False, format="mp4",
                            value=os.path.join(video_folder, 'Obama.mp4')
                        )
                        with gr.Row():
                            gr.Examples(
                                examples=examples_videos,
                                inputs=[video_input],
                                examples_per_page=10,
                            )
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="processed_image"):
                    with gr.TabItem('🖼️ Processed Image'):
                        processed_image = gr.Image(
                            label="Processed Image",
                            image_mode="RGB", type="filepath",
                            elem_id="processed_image",
                            height=512, container=True,
                            interactive=False
                        )
                        processed_image_button = gr.Button("🔧 Process Image", variant="primary")

            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="style_transfer"):
                    with gr.TabItem('🎭 Style Transfer'):
                        style_image = gr.Image(
                            label="Style Image",
                            image_mode="RGB", type="filepath",
                            elem_id="style_image",
                            height=512, container=True,
                            interactive=False
                        )
                        style_choice = gr.Dropdown(
                            choices=list(styles.keys()),
                            label="Choose Style",
                            value="Pixar"
                        )
                        cfg_slider = gr.Slider(
                            minimum=3.0, maximum=10.0, value=7.5, step=0.1,
                            label="CFG Scale"
                        )
                        strength_slider = gr.Slider(
                            minimum=0.45, maximum=0.75, value=0.6, step=0.05,
                            label="SDEdit Strength"
                        )
                        style_button = gr.Button("🎨 Apply Style", interactive=False, elem_id="style_generate", variant='primary')
                        gr.Markdown(
                            """
                            ⚠️ **Please click 'Process Image' first.** Then use **Apply Style** to stylize the image.

                            `SDEdit Strength`: Higher values make the result closer to the target style; lower values preserve more of the original face.
                            Try to keep facial features recognizable and avoid excessive distortion.
                            """
                        )
        with gr.Row():
            with gr.Tabs(elem_id="render_output"):
                with gr.TabItem('🎥 Animation Results'):
                    # ✅ Put the `Generate Avatar` button on its own row
                    with gr.Row():
                        with gr.Column(scale=1, elem_id="generate_block", min_width=200):
                            submit = gr.Button('🚀 Generate Avatar', elem_id="avatarartist_generate", variant='primary',
                                               interactive=False)
                            gr.Markdown("⬇️ Please click **Process Image** first before generating.",
                                        elem_id="generate_tip")

                    # ✅ Show the two `Animation Results` windows side by side
                    with gr.Row():
                        output_video = gr.Video(
                            label="Generated Animation Input Video View",
                            format="mp4", height=512, width=512,
                            autoplay=True
                        )
                        output_video_1 = gr.Video(
                            label="Generated Animation Rotate View Depth",
                            format="mp4", height=512, width=512,
                            autoplay=True
                        )
        def apply_style_and_mark(processed_image, style_choice, cfg, strength, working_dir, image_name_true):
            try:
                styled = style_transfer(processed_image, styles[style_choice], cfg, strength, working_dir, image_name_true)
                return styled, True
            except Exception as e:
                logging.exception(e)
                return None, True

        def process_image_and_enable_style(input_image, source_type, is_styled, wd):
            try:
                processed_result, updated_source_type, image_name_true = process_image(input_image, source_type, is_styled, wd)
                return processed_result, updated_source_type, gr.update(interactive=True), gr.update(interactive=True), image_name_true
            except Exception as e:
                logging.exception(e)
                # Keep the original source_type and leave the buttons disabled on failure.
                return None, source_type, gr.update(interactive=False), gr.update(interactive=False), None

        processed_image_button.click(
            fn=prepare_working_dir,
            inputs=[working_dir, is_styled],
            outputs=[working_dir],
            queue=False,
        ).success(
            fn=process_image_and_enable_style,
            inputs=[input_image, source_type, is_styled, working_dir],
            outputs=[processed_image, source_type, style_button, submit, image_name_true],
            queue=True
        )

        style_button.click(
            fn=apply_style_and_mark,
            inputs=[processed_image, style_choice, cfg_slider, strength_slider, working_dir, image_name_true],
            outputs=[style_image, is_styled]
        )

        submit.click(
            fn=avatar_generation,
            inputs=[processed_image, working_dir, video_input, source_type, is_styled, style_image, image_name_true],
            outputs=[output_video, output_video_1],  # ⏳ Videos are shown once generation finishes
            queue=True
        )

    demo.queue()
    demo.launch(server_name="0.0.0.0")

if __name__ == '__main__':
    import torch.multiprocessing as mp
    import transformers

    mp.set_start_method('spawn', force=True)
    # logging.info("Environment Variables: %s" % os.environ)
    # logging.info("Installing CUDA extensions...")
    # if _get_output(["nvcc", "--version"]) is None:
    #     logging.info("Installing CUDA toolkit...")
    #     install_cuda_toolkit()
    #     logging.info("Installed CUDA: %s" % _get_output(["nvcc", "--version"]))
    # else:
    #     logging.info("Detected CUDA: %s" % _get_output(["nvcc", "--version"]))
    # print("CUDA_HOME =", os.environ.get("CUDA_HOME"))
    # from torch.utils.cpp_extension import CUDA_HOME
    # print("CUDA_HOME from PyTorch:", CUDA_HOME)
    launch_pretrained()

    image_folder = "./demo_data/source_img/img_generate_different_domain/images512x512/demo_imgs"
    example_img_names = os.listdir(image_folder)

    render_model, sample_steps, DiT_model, \
        vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor, std, mean, ws_avg, device, input_process_model = model_define()

    controlnet_path = './pretrained_model/control'
    controlnet = ControlNetModel.from_pretrained(
        controlnet_path, torch_dtype=torch.float16
    )
    sd_path = './pretrained_model/sd21'
    pipeline_sd = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
        sd_path, torch_dtype=torch.float16,
        use_safetensors=True, controlnet=controlnet, variant="fp16"
    ).to(device)
    pipeline_sd.scheduler = DPMSolverMultistepScheduler.from_config(pipeline_sd.scheduler.config, use_karras_sigmas=True)

    demo_cam = False
    base_coff = np.load('pretrained_model/temp.npy').astype(np.float32)
    base_coff = torch.from_numpy(base_coff).float()
    Faceverse = Faceverse_manager(device=device, base_coeff=base_coff)

    launch_gradio_app()