import sys
import os

sys.path.append('./')
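# Install runtime dependencies at startup (a common Hugging Face Spaces pattern);
# versions are pinned and --no-deps skips pulling in transitive dependencies.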
os.system("pip install gradio accelerate==0.25.0 torchmetrics==1.2.1 tqdm==4.66.1 fastapi==0.111.0 transformers==4.36.2 diffusers==0.25 einops==0.7.0 bitsandbytes scipy==1.11.1 opencv-python gradio==4.24.0 fvcore cloudpickle omegaconf pycocotools basicsr av onnxruntime==1.16.2 peft==0.11.1 huggingface_hub==0.24.7 --no-deps")
import spaces
from fastapi import FastAPI
app = FastAPI()

from PIL import Image
import gradio as gr
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.unet_hacked_tryon import UNet2DConditionModel
from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    CLIPTextModel,
    CLIPTextModelWithProjection,
)
from diffusers import DDPMScheduler,AutoencoderKL
import torch
from typing import List
from transformers import AutoTokenizer
import numpy as np
from utils_mask import get_mask_location
from torchvision import transforms
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.openpose.run_openpose import OpenPose
from torchvision.transforms.functional import to_pil_image
import apply_net
from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation


def pil_to_binary_mask(pil_image, threshold=0):
    # Convert an arbitrary PIL image to a 0/255 binary mask:
    # grayscale pixels above `threshold` become white (255), the rest black (0).
    np_image = np.array(pil_image)
    grayscale_image = Image.fromarray(np_image).convert("L")
    binary_mask = np.array(grayscale_image) > threshold
    mask = binary_mask.astype(np.uint8) * 255
    return Image.fromarray(mask)


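# All try-on pipeline components below are loaded from this single Hugging Face Hub checkpoint.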
base_path = 'Keshabwi66/SmartLugaModel'

unet = UNet2DConditionModel.from_pretrained(
    base_path,
    subfolder="unet",
    torch_dtype=torch.float16,
)
unet.requires_grad_(False)
tokenizer_one = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer",
    revision=None,
    use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer_2",
    revision=None,
    use_fast=False,
)
noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")

text_encoder_one = CLIPTextModel.from_pretrained(
    base_path,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
    base_path,
    subfolder="text_encoder_2",
    torch_dtype=torch.float16,
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    base_path,
    subfolder="image_encoder",
    torch_dtype=torch.float16,
    )
vae = AutoencoderKL.from_pretrained(base_path,
                                    subfolder="vae",
                                    torch_dtype=torch.float16,
)

# "stabilityai/stable-diffusion-xl-base-1.0",
UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
    base_path,
    subfolder="unet_encoder",
    torch_dtype=torch.float16,
)

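# Preprocessing models: human parsing and OpenPose keypoint estimation,
# used later to build the garment mask and pose conditioning.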
parsing_model = Parsing(0)
openpose_model = OpenPose(0)

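# Inference only: freeze all model weights.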
UNet_Encoder.requires_grad_(False)
image_encoder.requires_grad_(False)
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
tensor_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

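# Assemble the SDXL inpainting-based try-on pipeline; the garment UNet encoder is attached below.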
pipe = TryonPipeline.from_pretrained(
        base_path,
        unet=unet,
        vae=vae,
        feature_extractor= CLIPImageProcessor(),
        text_encoder = text_encoder_one,
        text_encoder_2 = text_encoder_two,
        tokenizer = tokenizer_one,
        tokenizer_2 = tokenizer_two,
        scheduler = noise_scheduler,
        image_encoder=image_encoder,
        torch_dtype=torch.float16,
)
pipe.unet_encoder = UNet_Encoder

@spaces.GPU
def start_tryon(boy,girl,person_img,cloth_img, garment_des, denoise_steps=10, seed=42):
    # @spaces.GPU allocates a GPU for this call on Hugging Face Spaces, so CUDA is assumed.
    device = "cuda"
    openpose_model.preprocessor.body_estimation.model.to(device)
    pipe.to(device)
    pipe.unet_encoder.to(device)

    # Resize and prepare images
    garm_img = cloth_img.convert("RGB").resize((768, 1024))
    human_img = person_img.convert("RGB").resize((768,1024))

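    # Build the upper-body inpainting mask from human parsing + OpenPose keypoints.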
    is_checked = True
    if is_checked:
        keypoints = openpose_model(human_img.resize((384,512)))
        model_parse, _ = parsing_model(human_img.resize((384,512)))
        mask, mask_gray= get_mask_location('hd', "upper_body", model_parse, keypoints)
        mask = mask.resize((768,1024))

    human_img_arg = _apply_exif_orientation(human_img.resize((384,512)))
    human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")

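    # Run DensePose ('dp_segm') on the person image to produce the pose conditioning image.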
    args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
    # verbosity = getattr(args, "verbosity", None)
    pose_img = args.func(args, human_img_arg)
    pose_img = pose_img[:, :, ::-1]  # BGR -> RGB
    pose_img = Image.fromarray(pose_img).resize((768, 1024))

    # Build the prompt from the selected checkbox; start from a neutral phrasing
    # so `prompt` is always defined even if neither box is checked.
    prompt = "A person is wearing " + garment_des
    if boy:
        prompt = "A boy is wearing " + garment_des
    if girl:
        prompt = "A girl is wearing " + garment_des

    # Embedding generation for prompts
    with torch.no_grad():
        with torch.cuda.amp.autocast():
            # Text embeddings for the person prompt (with classifier-free guidance).
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
            with torch.inference_mode():
                (
                    prompt_embeds,
                    negative_prompt_embeds,
                    pooled_prompt_embeds,
                    negative_pooled_prompt_embeds,
                ) = pipe.encode_prompt(
                    prompt,
                    num_images_per_prompt=1,
                    do_classifier_free_guidance=True,
                    negative_prompt=negative_prompt,
                )

            # Text embeddings for the garment-only prompt (no classifier-free guidance).
            prompt = "A photo of " + garment_des
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
            if not isinstance(prompt, List):
                prompt = [prompt] * 1
            if not isinstance(negative_prompt, List):
                negative_prompt = [negative_prompt] * 1
            with torch.inference_mode():
                (
                    prompt_embeds_cloth,
                    _,
                    _,
                    _,
                ) = pipe.encode_prompt(
                    prompt,
                    num_images_per_prompt=1,
                    do_classifier_free_guidance=False,
                    negative_prompt=negative_prompt,
                )

            # Convert images to tensors for processing
            pose_img_tensor = tensor_transform(pose_img).unsqueeze(0).to(device, torch.float16)
            garm_tensor = tensor_transform(garm_img).unsqueeze(0).to(device, torch.float16)

            # Prepare the generator with optional seed
            generator = torch.Generator(device).manual_seed(seed) if seed is not None else None

            # Generate the virtual try-on output image
            images = pipe(
                prompt_embeds=prompt_embeds.to(device, torch.float16),
                negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
                pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
                negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
                num_inference_steps=denoise_steps,
                generator=generator,
                strength=1.0,
                pose_img=pose_img_tensor.to(device, torch.float16),
                text_embeds_cloth=prompt_embeds_cloth.to(device, torch.float16),
                cloth=garm_tensor.to(device, torch.float16),
                mask_image=mask,
                image=human_img,
                height=1024,
                width=768,
                ip_adapter_image=garm_img.resize((768, 1024)),
                guidance_scale=2.0,
            )[0]

    return images[0].resize(person_img.size)


# Gradio interface for the virtual try-on model
image_blocks = gr.Blocks().queue()

with image_blocks as demo:
    gr.Markdown("## SmartLuga")    
    with gr.Row():
        with gr.Column():
            person_img = gr.Image(label='Person Image', sources='upload', type="pil")
            boy = gr.Checkbox(label="Yes", info="Boy", value=True)
            girl = gr.Checkbox(label="Yes", info="Girl", value=False)

        with gr.Column():
            cloth_img = gr.Image(label='Garment Image', sources='upload', type="pil")
            garment_des = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", label="Garment Description")


        with gr.Column():
            image_out = gr.Image(label="Output Image", elem_id="output-img", show_share_button=False)

    try_button = gr.Button(value="Try-on")
    try_button.click(fn=start_tryon, inputs=[boy, girl, person_img, cloth_img, garment_des], outputs=[image_out], api_name='tryon')

image_blocks.launch()