import sys

sys.path.append('./')

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from torchvision import transforms
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import (
    AutoTokenizer,
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)

from src.unet_hacked_tryon import UNet2DConditionModel
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline


# Configuration arguments for the try-on pipeline
class Args:
    def __init__(self):
        self.pretrained_model_name_or_path = "yisol/IDM-VTON"
        self.width = 768
        self.height = 1024
        self.num_inference_steps = 10
        self.seed = 42
        self.guidance_scale = 2.0
        self.mixed_precision = None


# Use CUDA if available, otherwise fall back to CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

logger = get_logger(__name__, log_level="INFO")


def pil_to_tensor(images):
    images = np.array(images).astype(np.float32) / 255.0
    images = torch.from_numpy(images.transpose(2, 0, 1))
    return images


args = Args()

# Data type for model weights
weight_dtype = torch.float16

if args.seed is not None:
    set_seed(args.seed)

# Load scheduler, tokenizers and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
vae = AutoencoderKL.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="vae",
    torch_dtype=torch.float16,
)
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="unet",
    torch_dtype=torch.float16,
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="image_encoder",
    torch_dtype=torch.float16,
)
unet_encoder = UNet2DConditionModel_ref.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="unet_encoder",
    torch_dtype=torch.float16,
)
text_encoder_one = CLIPTextModel.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="text_encoder_2",
    torch_dtype=torch.float16,
)
tokenizer_one = AutoTokenizer.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="tokenizer",
    revision=None,
    use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="tokenizer_2",
    revision=None,
    use_fast=False,
)

# Freeze all model weights; the pipeline is used for inference only
unet.requires_grad_(False)
vae.requires_grad_(False)
image_encoder.requires_grad_(False)
unet_encoder.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
unet_encoder.to(device, weight_dtype)
unet.eval()
unet_encoder.eval()

pipe = TryonPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    unet=unet,
    vae=vae,
    feature_extractor=CLIPImageProcessor(),
    text_encoder=text_encoder_one,
    text_encoder_2=text_encoder_two,
    tokenizer=tokenizer_one,
    tokenizer_2=tokenizer_two,
    scheduler=noise_scheduler,
    image_encoder=image_encoder,
    unet_encoder=unet_encoder,
    torch_dtype=torch.float16,
).to(device)
# Optional memory optimizations:
# pipe.enable_sequential_cpu_offload()
# pipe.enable_model_cpu_offload()
# pipe.enable_vae_slicing()


# Generate the try-on image from the person, garment, mask and pose inputs
def generate_virtual_try_on(person_image, cloth_image, mask_image, pose_image, cloth_des):
    # Resize the inputs to the resolution the pipeline expects
    person_image = person_image.resize((args.width, args.height))
    cloth_image = cloth_image.resize((args.width, args.height))
    mask_image = mask_image.resize((args.width, args.height))
    pose_image = pose_image.resize((args.width, args.height))

    # Transformations: to tensor, normalized to [-1, 1]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    guidance_scale = args.guidance_scale
    seed = args.seed
    to_tensor = transforms.ToTensor()

    person_tensor = transform(person_image).unsqueeze(0).to(device)  # add batch dimension
    cloth_pure = transform(cloth_image).unsqueeze(0).to(device)
    mask_tensor = to_tensor(mask_image)[:1].unsqueeze(0).to(device)  # keep a single channel
    pose_tensor = transform(pose_image).unsqueeze(0).to(device)

    # Prepare text prompts
    prompt = ["A person wearing the cloth " + cloth_des]
    negative_prompt = ["monochrome, lowres, bad anatomy, worst quality, low quality"]

    # Encode the try-on prompt (with classifier-free guidance)
    with torch.inference_mode():
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipe.encode_prompt(
            prompt,
            num_images_per_prompt=1,
            do_classifier_free_guidance=True,
            negative_prompt=negative_prompt,
        )

    # Encode the garment description prompt (no classifier-free guidance)
    prompt_cloth = ["a photo of " + cloth_des]
    with torch.inference_mode():
        (
            prompt_embeds_c,
            _,
            _,
            _,
        ) = pipe.encode_prompt(
            prompt_cloth,
            num_images_per_prompt=1,
            do_classifier_free_guidance=False,
            negative_prompt=negative_prompt,
        )

    # Preprocess the garment image for the IP-Adapter image encoder
    clip_processor = CLIPImageProcessor()
    image_embeds = clip_processor(images=cloth_image, return_tensors="pt").pixel_values.to(device)

    # Generate the image
    generator = torch.Generator(pipe.device).manual_seed(seed) if seed is not None else None
    with torch.no_grad():
        images = pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=args.num_inference_steps,
            generator=generator,
            strength=1.0,
            pose_img=pose_tensor,
            text_embeds_cloth=prompt_embeds_c,
            cloth=cloth_pure,
            mask_image=mask_tensor,
            image=(person_tensor + 1.0) / 2.0,
            height=args.height,
            width=args.width,
            guidance_scale=guidance_scale,
            ip_adapter_image=image_embeds,
        )[0]

    # The pipeline returns PIL images by default, so no further conversion is needed
    return images[0]


# Create the Gradio interface
iface = gr.Interface(
    fn=generate_virtual_try_on,
    inputs=[
        gr.Image(type="pil", label="Person Image"),
        gr.Image(type="pil", label="Cloth Image"),
        gr.Image(type="pil", label="Mask Image"),
        gr.Image(type="pil", label="Pose Image"),
        gr.Textbox(label="Cloth Description"),
    ],
    outputs=gr.Image(type="pil", label="Generated Image"),
)

# Launch the interface
iface.launch()
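# ---------------------------------------------------------------------------
# Optional: a minimal sketch of calling generate_virtual_try_on directly,
# without the Gradio UI. The file paths below are hypothetical placeholders,
# not part of this project; point them at your own test images. Left
# commented out because iface.launch() above blocks until the server stops.
# ---------------------------------------------------------------------------
# person = Image.open("examples/person.jpg").convert("RGB")
# cloth = Image.open("examples/cloth.jpg").convert("RGB")
# mask = Image.open("examples/mask.png").convert("L")
# pose = Image.open("examples/pose.jpg").convert("RGB")
# result = generate_virtual_try_on(person, cloth, mask, pose, "a red t-shirt")
# result.save("tryon_result.png")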