import sys

sys.path.append('./')

import gradio as gr
import numpy as np
import torch
from PIL import Image

from accelerate.logging import get_logger
from accelerate.utils import set_seed
from torchvision import transforms
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import (
    AutoTokenizer,
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)

from src.unet_hacked_tryon import UNet2DConditionModel
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline


# Define a class to hold configuration arguments
class Args:
    def __init__(self):
        self.pretrained_model_name_or_path = "yisol/IDM-VTON"
        self.width = 768
        self.height = 1024
        self.num_inference_steps = 10
        self.seed = 42
        self.guidance_scale = 2.0
        self.mixed_precision = None


# Determine the device to be used for computations (CUDA if available)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger = get_logger(__name__, log_level="INFO")


def pil_to_tensor(images):
    # Convert a PIL image to a float tensor in [0, 1] with shape (C, H, W)
    images = np.array(images).astype(np.float32) / 255.0
    images = torch.from_numpy(images.transpose(2, 0, 1))
    return images


args = Args()

# Define the data type for model weights
weight_dtype = torch.float16

if args.seed is not None:
    set_seed(args.seed)

# Load scheduler, tokenizers, and models
noise_scheduler = DDPMScheduler.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="scheduler"
)
vae = AutoencoderKL.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="vae",
    torch_dtype=torch.float16,
)
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="unet",
    torch_dtype=torch.float16,
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="image_encoder",
    torch_dtype=torch.float16,
)
unet_encoder = UNet2DConditionModel_ref.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="unet_encoder",
    torch_dtype=torch.float16,
)
text_encoder_one = CLIPTextModel.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="text_encoder_2",
    torch_dtype=torch.float16,
)
tokenizer_one = AutoTokenizer.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="tokenizer",
    revision=None,
    use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="tokenizer_2",
    revision=None,
    use_fast=False,
)

# Freeze all model components; the demo only runs inference
unet.requires_grad_(False)
vae.requires_grad_(False)
image_encoder.requires_grad_(False)
unet_encoder.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)

unet_encoder.to(device, weight_dtype)
unet.eval()
unet_encoder.eval()

pipe = TryonPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    unet=unet,
    vae=vae,
    feature_extractor=CLIPImageProcessor(),
    text_encoder=text_encoder_one,
    text_encoder_2=text_encoder_two,
    tokenizer=tokenizer_one,
    tokenizer_2=tokenizer_two,
    scheduler=noise_scheduler,
    image_encoder=image_encoder,
    unet_encoder=unet_encoder,
    torch_dtype=torch.float16,
).to(device)
# Optional memory-saving switches:
# pipe.enable_sequential_cpu_offload()
# pipe.enable_model_cpu_offload()
# pipe.enable_vae_slicing()


# Function to generate the try-on image from the person, cloth, mask, and pose inputs
def generate_virtual_try_on(person_image, cloth_image, mask_image, pose_image, cloth_des):
    # Resize all inputs to the working resolution
    person_image = person_image.resize((args.width, args.height))
    cloth_image = cloth_image.resize((args.width, args.height))
    mask_image = mask_image.resize((args.width, args.height))
    pose_image = pose_image.resize((args.width, args.height))

    # Normalize RGB inputs to [-1, 1]; the mask stays in [0, 1]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    to_tensor = transforms.ToTensor()
    guidance_scale = args.guidance_scale
    seed = args.seed

    person_tensor = transform(person_image).unsqueeze(0).to(device)  # Add batch dimension
    cloth_pure = transform(cloth_image).unsqueeze(0).to(device)
    mask_tensor = to_tensor(mask_image)[:1].unsqueeze(0).to(device)  # Keep a single channel
    pose_tensor = transform(pose_image).unsqueeze(0).to(device)

    # Prepare text prompts
    prompt = ["A person wearing the cloth " + cloth_des]
    negative_prompt = ["monochrome, lowres, bad anatomy, worst quality, low quality"]

    # Encode the try-on prompt
    with torch.inference_mode():
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipe.encode_prompt(
            prompt,
            num_images_per_prompt=1,
            do_classifier_free_guidance=True,
            negative_prompt=negative_prompt,
        )

    # Encode the garment prompt for the reference UNet
    prompt_cloth = ["a photo of " + cloth_des]
    with torch.inference_mode():
        (
            prompt_embeds_c,
            _,
            _,
            _,
        ) = pipe.encode_prompt(
            prompt_cloth,
            num_images_per_prompt=1,
            do_classifier_free_guidance=False,
            negative_prompt=negative_prompt,
        )

    # Preprocess the garment image for the IP-Adapter image encoder
    clip_processor = CLIPImageProcessor()
    image_embeds = clip_processor(images=cloth_image, return_tensors="pt").pixel_values.to(device)

    # Generate the image
    generator = torch.Generator(pipe.device).manual_seed(seed) if seed is not None else None
    with torch.no_grad():
        images = pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=args.num_inference_steps,
            generator=generator,
            strength=1.0,
            pose_img=pose_tensor,
            text_embeds_cloth=prompt_embeds_c,
            cloth=cloth_pure,
            mask_image=mask_tensor,
            image=(person_tensor + 1.0) / 2.0,
            height=args.height,
            width=args.width,
            guidance_scale=guidance_scale,
            ip_adapter_image=image_embeds,
        )[0]

    # The pipeline already returns PIL images, so no further conversion is needed
    generated_image = images[0]
    return generated_image


# Create the Gradio interface
iface = gr.Interface(
    fn=generate_virtual_try_on,
    inputs=[
        gr.Image(type="pil", label="Person Image"),
        gr.Image(type="pil", label="Cloth Image"),
        gr.Image(type="pil", label="Mask Image"),
        gr.Image(type="pil", label="Pose Image"),
        gr.Textbox(label="Cloth Description"),
    ],
    outputs=gr.Image(type="pil", label="Generated Image"),
)

# Launch the interface
iface.launch()
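
# Optional: a minimal sketch of calling the try-on function directly, without the
# Gradio UI. The file names and garment description below are placeholders (not
# part of the original demo); uncomment and adapt them to local files as needed.
#
# person = Image.open("person.jpg").convert("RGB")
# cloth = Image.open("cloth.jpg").convert("RGB")
# mask = Image.open("mask.png").convert("L")
# pose = Image.open("pose.jpg").convert("RGB")
# result = generate_virtual_try_on(person, cloth, mask, pose, "a red t-shirt")
# result.save("tryon_result.png")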