import sys
sys.path.append('./')

from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Literal

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
from PIL import Image
from torchvision import transforms

from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import (
    AutoTokenizer,
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)

from ip_adapter.ip_adapter import Resampler
from src.unet_hacked_tryon import UNet2DConditionModel
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
# Configuration arguments for the demo.
class Args:
    def __init__(self):
        self.pretrained_model_name_or_path = "yisol/IDM-VTON"
        self.width = 768                 # input/output image width
        self.height = 1024               # input/output image height
        self.num_inference_steps = 10    # denoising steps per generation
        self.seed = 42                   # random seed for reproducibility
        self.guidance_scale = 2.0        # classifier-free guidance scale
        self.mixed_precision = None

# Run on GPU if one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger = get_logger(__name__, log_level="INFO")
def pil_to_tensor(images):
    # Convert a PIL image (H, W, C) in [0, 255] to a float tensor (C, H, W) in [0, 1].
    images = np.array(images).astype(np.float32) / 255.0
    images = torch.from_numpy(images.transpose(2, 0, 1))
    return images
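# Example usage (hypothetical path, not part of the original script):
#   pil_to_tensor(Image.open("person.jpg").convert("RGB")) -> float tensor of shape (3, H, W) in [0, 1].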
args = Args()

# Define the data type for model weights.
weight_dtype = torch.float16

if args.seed is not None:
    set_seed(args.seed)
# Load scheduler, tokenizers and models.
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
vae = AutoencoderKL.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="vae",
    torch_dtype=torch.float16,
)
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="unet",
    torch_dtype=torch.float16,
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="image_encoder",
    torch_dtype=torch.float16,
)
unet_encoder = UNet2DConditionModel_ref.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="unet_encoder",
    torch_dtype=torch.float16,
)
text_encoder_one = CLIPTextModel.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="text_encoder_2",
    torch_dtype=torch.float16,
)
tokenizer_one = AutoTokenizer.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="tokenizer",
    revision=None,
    use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="tokenizer_2",
    revision=None,
    use_fast=False,
)
# Freeze all model weights; this script runs inference only.
unet.requires_grad_(False)
vae.requires_grad_(False)
image_encoder.requires_grad_(False)
unet_encoder.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)

unet_encoder.to(device, weight_dtype)
unet.eval()
unet_encoder.eval()
pipe = TryonPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    unet=unet,
    vae=vae,
    feature_extractor=CLIPImageProcessor(),
    text_encoder=text_encoder_one,
    text_encoder_2=text_encoder_two,
    tokenizer=tokenizer_one,
    tokenizer_2=tokenizer_two,
    scheduler=noise_scheduler,
    image_encoder=image_encoder,
    unet_encoder=unet_encoder,
    torch_dtype=torch.float16,
).to(device)
# pipe.enable_sequential_cpu_offload()
# pipe.enable_model_cpu_offload()
# pipe.enable_vae_slicing()
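# Note (suggestion, not part of the original configuration): if the Space fails with
# CUDA out-of-memory errors, a common remedy with diffusers pipelines is to enable one
# of the offloading modes commented out above instead of moving the whole pipeline to
# the GPU, e.g.:
#
#     pipe = TryonPipeline.from_pretrained(...)   # without the trailing .to(device)
#     pipe.enable_model_cpu_offload()
#
# This trades inference speed for a much smaller peak GPU memory footprint.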
# Generate a try-on image from the person, garment, mask and pose inputs.
def generate_virtual_try_on(person_image, cloth_image, mask_image, pose_image, cloth_des):
    # Resize all inputs to the resolution the model expects.
    person_image = person_image.resize((args.width, args.height))
    cloth_image = cloth_image.resize((args.width, args.height))
    mask_image = mask_image.resize((args.width, args.height))
    pose_image = pose_image.resize((args.width, args.height))

    # ToTensor maps pixels to [0, 1]; Normalize([0.5], [0.5]) then maps them to [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    guidance_scale = args.guidance_scale
    seed = args.seed
    to_tensor = transforms.ToTensor()

    person_tensor = transform(person_image).unsqueeze(0).to(device)  # add batch dimension
    cloth_pure = transform(cloth_image).unsqueeze(0).to(device)
    mask_tensor = to_tensor(mask_image)[:1].unsqueeze(0).to(device)  # keep a single mask channel
    pose_tensor = transform(pose_image).unsqueeze(0).to(device)

    # Prepare text prompts.
    prompt = ["A person wearing the cloth " + cloth_des]
    negative_prompt = ["monochrome, lowres, bad anatomy, worst quality, low quality"]
    # Encode the main prompt (with classifier-free guidance).
    with torch.inference_mode():
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipe.encode_prompt(
            prompt,
            num_images_per_prompt=1,
            do_classifier_free_guidance=True,
            negative_prompt=negative_prompt,
        )

    # Encode a garment-only prompt (passed to the pipeline below as text_embeds_cloth).
    prompt_cloth = ["a photo of " + cloth_des]
    with torch.inference_mode():
        (
            prompt_embeds_c,
            _,
            _,
            _,
        ) = pipe.encode_prompt(
            prompt_cloth,
            num_images_per_prompt=1,
            do_classifier_free_guidance=False,
            negative_prompt=negative_prompt,
        )
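    # Note: encode_prompt follows the SDXL convention from diffusers — it returns
    # per-token and pooled embeddings plus their negative counterparts when
    # classifier-free guidance is enabled. With do_classifier_free_guidance=False
    # the negative outputs are not needed, which is why they are discarded above.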
    # Preprocess the garment image for the IP-Adapter image encoder.
    # Note: these are CLIP pixel values (a preprocessed image), not precomputed embeddings.
    clip_processor = CLIPImageProcessor()
    image_embeds = clip_processor(images=cloth_image, return_tensors="pt").pixel_values.to(device)

    # Run the try-on pipeline.
    generator = torch.Generator(pipe.device).manual_seed(seed) if seed is not None else None
    with torch.no_grad():
        images = pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=args.num_inference_steps,
            generator=generator,
            strength=1.0,
            pose_img=pose_tensor,
            text_embeds_cloth=prompt_embeds_c,
            cloth=cloth_pure,
            mask_image=mask_tensor,
            image=(person_tensor + 1.0) / 2.0,  # map the normalized tensor back from [-1, 1] to [0, 1]
            height=args.height,
            width=args.width,
            guidance_scale=guidance_scale,
            ip_adapter_image=image_embeds,
        )[0]

    # The pipeline returns PIL images by default, so the first one can be returned directly.
    generated_image = images[0]
    return generated_image
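# Optional local smoke test (hypothetical file paths, not part of the original app):
# if __name__ == "__main__":
#     result = generate_virtual_try_on(
#         Image.open("examples/person.jpg").convert("RGB"),
#         Image.open("examples/cloth.jpg").convert("RGB"),
#         Image.open("examples/mask.png").convert("L"),
#         Image.open("examples/pose.jpg").convert("RGB"),
#         "a short-sleeve t-shirt",
#     )
#     result.save("tryon_result.png")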
# Build the Gradio interface.
iface = gr.Interface(
    fn=generate_virtual_try_on,
    inputs=[
        gr.Image(type="pil", label="Person Image"),
        gr.Image(type="pil", label="Cloth Image"),
        gr.Image(type="pil", label="Mask Image"),
        gr.Image(type="pil", label="Pose Image"),
        gr.Textbox(label="Cloth Description"),
    ],
    outputs=gr.Image(type="pil", label="Generated Image"),
)

# Launch the interface.
iface.launch()
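# Note (assumption, not in the original script): on Hugging Face Spaces, launch() with
# default arguments is sufficient. For local testing, iface.launch(share=True) would
# additionally create a temporary public URL.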