Keshabwi66 committed
Commit 8f0759c · verified · 1 Parent(s): d21101d

Update app.py

Files changed (1)
  1. app.py +195 -208
app.py CHANGED
@@ -1,243 +1,230 @@
  import sys
  import os
  sys.path.append('./')
- os.system("pip install huggingface_hub==0.24.7")
  os.system("pip install gradio accelerate==0.25.0 torchmetrics==1.2.1 tqdm==4.66.1 fastapi==0.111.0 transformers==4.36.2 diffusers==0.25 einops==0.7.0 bitsandbytes scipy==1.11.1 opencv-python gradio==4.24.0 fvcore cloudpickle omegaconf pycocotools basicsr av onnxruntime==1.16.2 peft==0.11.1 huggingface_hub==0.24.7 --no-deps")
- import gradio as gr
- import torch
  import spaces
  from fastapi import FastAPI
-
  app = FastAPI()
- from PIL import Image
- import torch.nn.functional as F
- from transformers import CLIPImageProcessor
-
- # Add necessary imports and initialize the model as in your code...
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Literal
- import matplotlib.pyplot as plt

- import torch.utils.data as data
- import torchvision
- import numpy as np
  import torch
- import torch.nn.functional as F
- from accelerate.logging import get_logger
- from accelerate.utils import set_seed
  from torchvision import transforms
-
- from diffusers import AutoencoderKL, DDPMScheduler
- from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection, CLIPTextModelWithProjection, CLIPTextModel
-
-
- from src.unet_hacked_tryon import UNet2DConditionModel
- from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
- from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
- # Define a class to hold configuration arguments
- class Args:
-     def __init__(self):
-         self.pretrained_model_name_or_path = "yisol/IDM-VTON"
-         self.width = 768
-         self.height = 1024
-         self.num_inference_steps = 10
-         self.seed = 42
-         self.guidance_scale = 2.0
-         self.mixed_precision = None

  device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

- def pil_to_tensor(images):
-     images = np.array(images).astype(np.float32) / 255.0
-     images = torch.from_numpy(images.transpose(2, 0, 1))
-     return images
-
-
-
- args = Args()

- # Define the data type for model weights
- weight_dtype = torch.float16

- if args.seed is not None:
-     set_seed(args.seed)
-
-
- # Load scheduler, tokenizer and models.
-
- noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
- vae = AutoencoderKL.from_pretrained(
-     args.pretrained_model_name_or_path,
-     subfolder="vae",
-     torch_dtype=torch.float16,
- )
  unet = UNet2DConditionModel.from_pretrained(
-     args.pretrained_model_name_or_path,
-     subfolder="unet",
-     torch_dtype=torch.float16,
- )
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(
-     args.pretrained_model_name_or_path,
-     subfolder="image_encoder",
-     torch_dtype=torch.float16,
- )
- unet_encoder = UNet2DConditionModel_ref.from_pretrained(
-     args.pretrained_model_name_or_path,
-     subfolder="unet_encoder",
-     torch_dtype=torch.float16,
- )
  text_encoder_one = CLIPTextModel.from_pretrained(
-     args.pretrained_model_name_or_path,
-     subfolder="text_encoder",
-     torch_dtype=torch.float16,
- )
  text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
-     args.pretrained_model_name_or_path,
-     subfolder="text_encoder_2",
-     torch_dtype=torch.float16,
  )

- tokenizer_one = AutoTokenizer.from_pretrained(
-     args.pretrained_model_name_or_path,
-     subfolder="tokenizer",
-     revision=None,
-     use_fast=False,
- )
- tokenizer_two = AutoTokenizer.from_pretrained(
-     args.pretrained_model_name_or_path,
-     subfolder="tokenizer_2",
-     revision=None,
-     use_fast=False,
- )
- # Freeze vae and text_encoder and set unet to trainable
- unet.requires_grad_(False)
- vae.requires_grad_(False)
  image_encoder.requires_grad_(False)
- unet_encoder.requires_grad_(False)
  text_encoder_one.requires_grad_(False)
  text_encoder_two.requires_grad_(False)
- unet_encoder.requires_grad_(False)
- unet.eval()
- unet_encoder.eval()
-

  pipe = TryonPipeline.from_pretrained(
-     args.pretrained_model_name_or_path,
-     unet=unet,
-     vae=vae,
-     feature_extractor=CLIPImageProcessor(),
-     text_encoder=text_encoder_one,
-     text_encoder_2=text_encoder_two,
-     tokenizer=tokenizer_one,
-     tokenizer_2=tokenizer_two,
-     scheduler=noise_scheduler,
-     image_encoder=image_encoder,
-     unet_encoder=unet_encoder,
-     torch_dtype=torch.float16,
- )
  @spaces.GPU
- def generate_virtual_try_on(person_image, cloth_image, mask_image, pose_image, cloth_des):
      pipe.to(device)
-     # Prepare the input images as tensors
-     person_image = person_image.resize((args.width, args.height))
-     cloth_image = cloth_image.resize((args.width, args.height))
-     mask_image = mask_image.resize((args.width, args.height))
-     pose_image = pose_image.resize((args.width, args.height))
-     # Define transformations
-     transform = transforms.Compose([
-         transforms.ToTensor(),
-         transforms.Normalize([0.5], [0.5]),
-     ])
-     guidance_scale = 2.0
-     seed = 42
-
-     to_tensor = transforms.ToTensor()
-
-     person_tensor = transform(person_image).unsqueeze(0).to(device)  # Add batch dimension
-     cloth_pure = transform(cloth_image).unsqueeze(0).to(device)
-     mask_tensor = to_tensor(mask_image)[:1].unsqueeze(0).to(device)  # Keep only one channel
-     pose_tensor = transform(pose_image).unsqueeze(0).to(device)

-     # Prepare text prompts
-     prompt = ["A person wearing the cloth" + cloth_des]  # Example prompt
-     negative_prompt = ["monochrome, lowres, bad anatomy, worst quality, low quality"]
-
-     # Encode prompts
-     with torch.inference_mode():
-         (
-             prompt_embeds,
-             negative_prompt_embeds,
-             pooled_prompt_embeds,
-             negative_pooled_prompt_embeds,
-         ) = pipe.encode_prompt(
-             prompt,
-             num_images_per_prompt=1,
-             do_classifier_free_guidance=True,
-             negative_prompt=negative_prompt,
-         )
-     prompt_cloth = ["a photo of" + cloth_des]
-     with torch.inference_mode():
-         (
-             prompt_embeds_c,
-             _,
-             _,
-             _,
-         ) = pipe.encode_prompt(
-             prompt_cloth,
-             num_images_per_prompt=1,
-             do_classifier_free_guidance=False,
-             negative_prompt=negative_prompt,
-         )

-     # Encode garment using IP-Adapter
-     clip_processor = CLIPImageProcessor()
-     image_embeds = clip_processor(images=cloth_image, return_tensors="pt").pixel_values.to(device)

-     # Generate the image
-     generator = torch.Generator(pipe.device).manual_seed(seed) if seed is not None else None

-     with torch.no_grad():
-         images = pipe(
-             prompt_embeds=prompt_embeds.to(device, torch.float16),
-             negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
-             pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
-             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
-             num_inference_steps=args.num_inference_steps,
-             generator=generator,
-             strength=1.0,
-             pose_img=pose_tensor.to(device, torch.float16),
-             text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
-             cloth=cloth_pure.to(device, torch.float16),
-             mask_image=mask_tensor.to(device, torch.float16),
-             image=(person_tensor + 1.0) / 2.0,
-             height=args.height,
-             width=args.width,
-             guidance_scale=guidance_scale,
-             ip_adapter_image=image_embeds.to(device, torch.float16),
-         )[0]
-
-     # Convert output image to PIL format for display
-     generated_image = transforms.ToPILImage()(images[0])
-     return generated_image
-
- # Create Gradio interface
- iface = gr.Interface(
-     fn=generate_virtual_try_on,
-     inputs=[
-         gr.Image(type="pil", label="Person Image"),
-         gr.Image(type="pil", label="Cloth Image"),
-         gr.Image(type="pil", label="Mask Image"),
-         gr.Image(type="pil", label="Pose Image"),
-         gr.Textbox(label="cloth_des"),  # Add text input
-     ],
-     outputs=gr.Image(type="pil", label="Generated Image"),
- )

- # Launch the interface
- iface.launch()
 
  import sys
  import os
+
  sys.path.append('./')
  os.system("pip install gradio accelerate==0.25.0 torchmetrics==1.2.1 tqdm==4.66.1 fastapi==0.111.0 transformers==4.36.2 diffusers==0.25 einops==0.7.0 bitsandbytes scipy==1.11.1 opencv-python gradio==4.24.0 fvcore cloudpickle omegaconf pycocotools basicsr av onnxruntime==1.16.2 peft==0.11.1 huggingface_hub==0.24.7 --no-deps")
  import spaces
  from fastapi import FastAPI
  app = FastAPI()

+ from PIL import Image
+ import gradio as gr
+ from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
+ from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
+ from src.unet_hacked_tryon import UNet2DConditionModel
+ from transformers import (
+     CLIPImageProcessor,
+     CLIPVisionModelWithProjection,
+     CLIPTextModel,
+     CLIPTextModelWithProjection,
+ )
+ from diffusers import DDPMScheduler, AutoencoderKL
+ from typing import List

  import torch
+ import os
+ from transformers import AutoTokenizer
+ import numpy as np
+ from utils_mask import get_mask_location
  from torchvision import transforms
+ import apply_net
+ # Human-parsing and OpenPose wrappers instantiated below; assumed to ship with
+ # this Space alongside src/, as in the upstream yisol/IDM-VTON code.
+ from preprocess.humanparsing.run_parsing import Parsing
+ from preprocess.openpose.run_openpose import OpenPose

  device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

+ def pil_to_binary_mask(pil_image, threshold=0):
+     np_image = np.array(pil_image)
+     grayscale_image = Image.fromarray(np_image).convert("L")
+     binary_mask = np.array(grayscale_image) > threshold
+     mask = np.zeros(binary_mask.shape, dtype=np.uint8)
+     for i in range(binary_mask.shape[0]):
+         for j in range(binary_mask.shape[1]):
+             if binary_mask[i, j] == True:
+                 mask[i, j] = 1
+     mask = (mask * 255).astype(np.uint8)
+     output_mask = Image.fromarray(mask)
+     return output_mask

+ base_path = 'yisol/IDM-VTON'

  unet = UNet2DConditionModel.from_pretrained(
+     base_path,
+     subfolder="unet",
+     torch_dtype=torch.float16,
+ )
+ unet.requires_grad_(False)
+ tokenizer_one = AutoTokenizer.from_pretrained(
+     base_path,
+     subfolder="tokenizer",
+     revision=None,
+     use_fast=False,
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+     base_path,
+     subfolder="tokenizer_2",
+     revision=None,
+     use_fast=False,
+ )
+ noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
+
  text_encoder_one = CLIPTextModel.from_pretrained(
+     base_path,
+     subfolder="text_encoder",
+     torch_dtype=torch.float16,
+ )
  text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
+     base_path,
+     subfolder="text_encoder_2",
+     torch_dtype=torch.float16,
+ )
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+     base_path,
+     subfolder="image_encoder",
+     torch_dtype=torch.float16,
  )
+ vae = AutoencoderKL.from_pretrained(base_path,
+     subfolder="vae",
+     torch_dtype=torch.float16,
+ )

+ # "stabilityai/stable-diffusion-xl-base-1.0",
+ UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
+     base_path,
+     subfolder="unet_encoder",
+     torch_dtype=torch.float16,
+ )
+
+ parsing_model = Parsing(0)
+ openpose_model = OpenPose(0)
+
+ UNet_Encoder.requires_grad_(False)
  image_encoder.requires_grad_(False)
+ vae.requires_grad_(False)
+ unet.requires_grad_(False)
  text_encoder_one.requires_grad_(False)
  text_encoder_two.requires_grad_(False)
+ tensor_transfrom = transforms.Compose(
+     [
+         transforms.ToTensor(),
+         transforms.Normalize([0.5], [0.5]),
+     ]
+ )

  pipe = TryonPipeline.from_pretrained(
+     base_path,
+     unet=unet,
+     vae=vae,
+     feature_extractor=CLIPImageProcessor(),
+     text_encoder=text_encoder_one,
+     text_encoder_2=text_encoder_two,
+     tokenizer=tokenizer_one,
+     tokenizer_2=tokenizer_two,
+     scheduler=noise_scheduler,
+     image_encoder=image_encoder,
+     torch_dtype=torch.float16,
+ )
+ pipe.unet_encoder = UNet_Encoder
+
  @spaces.GPU
+ def start_tryon(person_img, pose_img, mask_img, cloth_img, garment_des, denoise_steps, seed):
+     # Assuming device is set up (e.g., "cuda" or "cpu")
+     openpose_model.preprocessor.body_estimation.model.to(device)
      pipe.to(device)
+     pipe.unet_encoder.to(device)
+
+     # Resize and prepare images
+     garm_img = cloth_img.convert("RGB").resize((768, 1024))
+     human_img = person_img.convert("RGB").resize((768, 1024))
+     mask = mask_img.convert("RGB").resize((768, 1024))
+
+     # Prepare pose image (already uploaded)
+     pose_img = pose_img.resize((768, 1024))
+
+     # Generate text embeddings for garment description
+     prompt = f"model is wearing {garment_des}"
+     negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

+     # Embedding generation for prompts
+     with torch.no_grad():
+         with torch.cuda.amp.autocast():
+             (
+                 prompt_embeds,
+                 negative_prompt_embeds,
+                 pooled_prompt_embeds,
+                 negative_pooled_prompt_embeds,
+             ) = pipe.encode_prompt(
+                 prompt,
+                 num_images_per_prompt=1,
+                 do_classifier_free_guidance=True,
+                 negative_prompt=negative_prompt,
+             )
+
+             # encode_prompt returns a 4-tuple; only the prompt embeddings are
+             # needed for the garment branch.
+             prompt_embeds_cloth, _, _, _ = pipe.encode_prompt(
+                 f"a photo of {garment_des}",
+                 num_images_per_prompt=1,
+                 do_classifier_free_guidance=False,
+                 negative_prompt=negative_prompt,
+             )
+
+             # Convert images to tensors for processing
+             pose_img_tensor = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
+             garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
+             mask_tensor = tensor_transfrom(mask).unsqueeze(0).to(device, torch.float16)
+
+             # Prepare the generator with optional seed
+             generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
+
+             # Generate the virtual try-on output image
+             images = pipe(
+                 prompt_embeds=prompt_embeds.to(device, torch.float16),
+                 negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
+                 pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
+                 negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
+                 num_inference_steps=denoise_steps,
+                 generator=generator,
+                 strength=1.0,
+                 pose_img=pose_img_tensor.to(device, torch.float16),
+                 text_embeds_cloth=prompt_embeds_cloth.to(device, torch.float16),
+                 cloth=garm_tensor.to(device, torch.float16),
+                 mask_image=mask_tensor,
+                 image=human_img,
+                 height=1024,
+                 width=768,
+                 ip_adapter_image=garm_img.resize((768, 1024)),
+                 guidance_scale=2.0,
+             )[0]

+     return images

+ # Gradio interface for the virtual try-on model
+ image_blocks = gr.Blocks().queue()

+ with image_blocks as demo:
+     gr.Markdown("## SmartLuga")
+     with gr.Row():
+         with gr.Column():
+             imgs = gr.ImageEditor(sources='upload', type="pil", label='Human Image', interactive=True)
+             with gr.Row():
+                 is_checked_crop = gr.Checkbox(label="Use auto-crop & resizing", value=False)

+         with gr.Column():
+             garm_img = gr.Image(label="Garment", sources='upload', type="pil")
+             with gr.Row(elem_id="prompt-container"):
+                 prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
+
+         with gr.Column():
+             masked_img = gr.Image(label="Masked image output", elem_id="masked-img", show_share_button=False)
+
+         with gr.Column():
+             image_out = gr.Image(label="Output", elem_id="output-img", show_share_button=False)
+
+         with gr.Column():
+             try_button = gr.Button(value="Try-on")
+             with gr.Accordion(label="Advanced Settings", open=False):
+                 with gr.Row():
+                     denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=30, step=1)
+                     seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)
+
+     try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, denoise_steps, seed], outputs=[image_out, masked_img], api_name='tryon')

+ image_blocks.launch()
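
Note on the event wiring: as committed, the Blocks layout has no pose or mask upload, try_button.click passes five inputs to the seven-parameter start_tryon, and the function returns the pipeline's image list while two output components are wired up. Below is a minimal sketch of one way to make the wiring consistent. It assumes start_tryon as defined above is in scope; the run_tryon adapter and the pose_in, mask_in, garm_in, and desc component names are hypothetical, and it assumes Gradio 4.x, where gr.ImageEditor (type="pil") yields a dict whose "background" entry holds the uploaded photo.

import gradio as gr

# Hypothetical adapter: maps the Blocks components onto start_tryon's seven
# parameters and returns two values so both output slots are filled.
def run_tryon(person_editor_value, pose_image, mask_image, garment_image,
              garment_desc, denoise_steps, seed):
    # Extract the uploaded photo from the ImageEditor's layered value.
    person_image = person_editor_value["background"]
    result = start_tryon(person_image, pose_image, mask_image, garment_image,
                         garment_desc, int(denoise_steps), int(seed))
    # start_tryon returns the pipeline's image list; show the first image and
    # echo the mask that was supplied.
    return result[0], mask_image

with gr.Blocks() as demo:
    gr.Markdown("## SmartLuga")
    with gr.Row():
        with gr.Column():
            imgs = gr.ImageEditor(sources='upload', type="pil", label='Human Image', interactive=True)
            pose_in = gr.Image(label="Pose Image", sources='upload', type="pil")
            mask_in = gr.Image(label="Mask Image", sources='upload', type="pil")
        with gr.Column():
            garm_in = gr.Image(label="Garment", sources='upload', type="pil")
            desc = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False)
        with gr.Column():
            masked_img = gr.Image(label="Masked image output", show_share_button=False)
            image_out = gr.Image(label="Output", show_share_button=False)
    with gr.Row():
        try_button = gr.Button(value="Try-on")
        denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=30, step=1)
        seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)

    try_button.click(
        fn=run_tryon,
        inputs=[imgs, pose_in, mask_in, garm_in, desc, denoise_steps, seed],
        outputs=[image_out, masked_img],
        api_name="tryon",
    )

demo.launch()

Keeping start_tryon unchanged and adapting its inputs and outputs in a small wrapper keeps the sketch decoupled from the pipeline code; the same reconciliation could equally be done by editing the function signature and the click call directly.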