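# Gradio demo app for IDM-VTON image-based virtual try-on,
# built on a Stable Diffusion XL inpainting pipeline (weights from yisol/IDM-VTON).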
import sys
import os

sys.path.append('./')
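# Install runtime dependencies at startup (a common pattern for Hugging Face Spaces builds).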
os.system("pip install gradio accelerate==0.25.0 torchmetrics==1.2.1 tqdm==4.66.1 fastapi==0.111.0 transformers==4.36.2 diffusers==0.25 einops==0.7.0 bitsandbytes scipy==1.11.1 opencv-python gradio==4.24.0 fvcore cloudpickle omegaconf pycocotools basicsr av onnxruntime==1.16.2 peft==0.11.1 huggingface_hub==0.24.7 --no-deps")
import spaces
from fastapi import FastAPI
app = FastAPI()

from PIL import Image
import gradio as gr
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.unet_hacked_tryon import UNet2DConditionModel
from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    CLIPTextModel,
    CLIPTextModelWithProjection,
)
from diffusers import DDPMScheduler, AutoencoderKL

import torch
from transformers import AutoTokenizer
import numpy as np
from torchvision import transforms


device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

def pil_to_binary_mask(pil_image, threshold=0):
    """Convert a PIL image to a binary 0/255 grayscale mask."""
    grayscale_image = pil_image.convert("L")
    binary_mask = np.array(grayscale_image) > threshold
    mask = (binary_mask * 255).astype(np.uint8)
    return Image.fromarray(mask)


base_path = 'yisol/IDM-VTON'
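# Load all model components in fp16 from the IDM-VTON repo on the Hugging Face Hub:
# the try-on UNet, both SDXL tokenizers and text encoders, the CLIP image encoder,
# the VAE, the noise scheduler, and the garment (reference) UNet encoder.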

unet = UNet2DConditionModel.from_pretrained(
    base_path,
    subfolder="unet",
    torch_dtype=torch.float16,
)
unet.requires_grad_(False)
tokenizer_one = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer",
    revision=None,
    use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer_2",
    revision=None,
    use_fast=False,
)
noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")

text_encoder_one = CLIPTextModel.from_pretrained(
    base_path,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
    base_path,
    subfolder="text_encoder_2",
    torch_dtype=torch.float16,
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    base_path,
    subfolder="image_encoder",
    torch_dtype=torch.float16,
)
vae = AutoencoderKL.from_pretrained(
    base_path,
    subfolder="vae",
    torch_dtype=torch.float16,
)

# "stabilityai/stable-diffusion-xl-base-1.0",
UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
    base_path,
    subfolder="unet_encoder",
    torch_dtype=torch.float16,
)



# Inference only: freeze all components so no gradients are tracked
UNet_Encoder.requires_grad_(False)
image_encoder.requires_grad_(False)
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)

# Map PIL images to tensors scaled to [-1, 1], as expected by the diffusion pipeline
tensor_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

pipe = TryonPipeline.from_pretrained(
    base_path,
    unet=unet,
    vae=vae,
    feature_extractor=CLIPImageProcessor(),
    text_encoder=text_encoder_one,
    text_encoder_2=text_encoder_two,
    tokenizer=tokenizer_one,
    tokenizer_2=tokenizer_two,
    scheduler=noise_scheduler,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
)
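# Attach the garment (reference) UNet so the pipeline can condition on the cloth image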
pipe.unet_encoder = UNet_Encoder

@spaces.GPU
def start_tryon(person_img, pose_img, mask_img, cloth_img, garment_des, denoise_steps, seed):
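    """Run IDM-VTON: dress `person_img` in `cloth_img`, guided by `pose_img`, `mask_img`,
    and the text description `garment_des`. Returns the generated try-on image."""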
    # Move the pipeline and the garment encoder to the target device (set above)
    pipe.to(device)
    pipe.unet_encoder.to(device)

    # Resize and prepare images
    garm_img = cloth_img.convert("RGB").resize((768, 1024))
    human_img = person_img.convert("RGB").resize((768, 1024))
    mask = pil_to_binary_mask(mask_img.convert("RGB").resize((768, 1024)))

    # Prepare pose image (already uploaded)
    pose_img = pose_img.resize((768, 1024))

    
    
    # Embedding generation for prompts
    with torch.no_grad():
        with torch.cuda.amp.autocast():
            # Text embeddings for the try-on prompt (with classifier-free guidance)
            prompt = f"model is wearing {garment_des}"
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
            with torch.inference_mode():
                (
                    prompt_embeds,
                    negative_prompt_embeds,
                    pooled_prompt_embeds,
                    negative_pooled_prompt_embeds,
                ) = pipe.encode_prompt(
                    prompt,
                    num_images_per_prompt=1,
                    do_classifier_free_guidance=True,
                    negative_prompt=negative_prompt,
                )

            # Text embeddings for the garment (cloth) description, no guidance needed
            prompt = "a photo of " + garment_des
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
            if not isinstance(prompt, list):
                prompt = [prompt]
            if not isinstance(negative_prompt, list):
                negative_prompt = [negative_prompt]
            with torch.inference_mode():
                (
                    prompt_embeds_cloth,
                    _,
                    _,
                    _,
                ) = pipe.encode_prompt(
                    prompt,
                    num_images_per_prompt=1,
                    do_classifier_free_guidance=False,
                    negative_prompt=negative_prompt,
                )

            # Convert images to tensors for processing
            pose_img_tensor = tensor_transform(pose_img).unsqueeze(0).to(device, torch.float16)
            garm_tensor = tensor_transform(garm_img).unsqueeze(0).to(device, torch.float16)

            # Prepare the generator with optional seed (cast Gradio's number input to int)
            generator = torch.Generator(device).manual_seed(int(seed)) if seed is not None else None

            # Generate the virtual try-on output image
            images = pipe(
                prompt_embeds=prompt_embeds.to(device, torch.float16),
                negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
                pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
                negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
                num_inference_steps=int(denoise_steps),
                generator=generator,
                strength=1.0,
                pose_img=pose_img_tensor.to(device, torch.float16),
                text_embeds_cloth=prompt_embeds_cloth.to(device, torch.float16),
                cloth=garm_tensor.to(device, torch.float16),
                mask_image=mask,
                image=human_img,
                height=1024,
                width=768,
                ip_adapter_image=garm_img.resize((768, 1024)),
                guidance_scale=2.0,
            )[0]

    return images[0]


# Gradio interface for the virtual try-on model
image_blocks = gr.Blocks().queue()

with image_blocks as demo:
    gr.Markdown("## SmartLuga")    
    with gr.Row():
        with gr.Column():
            person_img = gr.Image(label='Person Image', sources='upload', type="pil")
            pose_img = gr.Image(label='Pose Image', sources='upload', type="pil")
            mask_img = gr.Image(label='Mask Image', sources='upload', type="pil")

        with gr.Column():
            cloth_img = gr.Image(label='Garment Image', sources='upload', type="pil")
            garment_des = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", label="Garment Description")

        with gr.Column():
            denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=30, step=1)
            seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)

        with gr.Column():
            image_out = gr.Image(label="Output Image", elem_id="output-img", show_share_button=False)

    try_button = gr.Button(value="Try-on")
    try_button.click(fn=start_tryon, inputs=[person_img, pose_img, mask_img, cloth_img, garment_des, denoise_steps, seed], outputs=[image_out], api_name='tryon')

image_blocks.launch()