Himanshu-AT committed
Commit 3534d80 · Parent(s): 2f6f08a

update titles in README and requirements, add opencv-python

Files changed (4):
  1. .DS_Store +0 -0
  2. app.py +175 -92
  3. readme.md +1 -1
  4. requirements.txt +1 -0
.DS_Store ADDED
Binary file (6.15 kB)
 
app.py CHANGED
@@ -1,61 +1,148 @@
+
 import gradio as gr
 import numpy as np
-import os
-import spaces
-import random
-import json
-# from image_gen_aux import DepthPreprocessor
-from PIL import Image
 import torch
-from torchvision import transforms
-
-from diffusers import FluxFillPipeline, AutoencoderKL
+import random
 from PIL import Image
+import cv2
+import spaces
 
+# ------------------ Inpainting Pipeline Setup ------------------ #
+from diffusers import FluxFillPipeline
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048
 
-pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16).to("cuda")
-# pipe.load_lora_weights("Himanshu806/testLora")
-# pipe.enable_lora()
-
-with open("lora_models.json", "r") as f:
-    lora_models = json.load(f)
-
-def download_model(model_name, model_path):
-    print(f"Downloading model: {model_name} from {model_path}")
-    try:
-        pipe.load_lora_weights(model_path)
-        print(f"Successfully downloaded model: {model_name}")
-    except Exception as e:
-        print(f"Failed to download model: {model_name}. Error: {e}")
-
-# Iterate through the models and download each one
-for model_name, model_path in lora_models.items():
-    download_model(model_name, model_path)
-
-lora_models["None"] = None
-
-@spaces.GPU(durations=300)
-def infer(edit_images, prompt, width, height, lora_model, seed=42, randomize_seed=False, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
-    # pipe.enable_xformers_memory_efficient_attention()
-
-    if lora_model != "None":
-        pipe.load_lora_weights(lora_models[lora_model])
-        pipe.enable_lora()
-
+pipe = FluxFillPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
+)
+pipe.load_lora_weights("alvdansen/flux-koda")
+pipe.enable_lora()
+
+def calculate_optimal_dimensions(image: Image.Image):
+    # Extract the original dimensions
+    original_width, original_height = image.size
+
+    # Set constants
+    MIN_ASPECT_RATIO = 9 / 16
+    MAX_ASPECT_RATIO = 16 / 9
+    FIXED_DIMENSION = 1024
+
+    # Calculate the aspect ratio of the original image
+    original_aspect_ratio = original_width / original_height
+
+    # Determine which dimension to fix
+    if original_aspect_ratio > 1:  # Wider than tall
+        width = FIXED_DIMENSION
+        height = round(FIXED_DIMENSION / original_aspect_ratio)
+    else:  # Taller than wide
+        height = FIXED_DIMENSION
+        width = round(FIXED_DIMENSION * original_aspect_ratio)
+
+    # Ensure dimensions are multiples of 8
+    width = (width // 8) * 8
+    height = (height // 8) * 8
+
+    # Enforce aspect ratio limits
+    calculated_aspect_ratio = width / height
+    if calculated_aspect_ratio > MAX_ASPECT_RATIO:
+        width = (height * MAX_ASPECT_RATIO // 8) * 8
+    elif calculated_aspect_ratio < MIN_ASPECT_RATIO:
+        height = (width / MIN_ASPECT_RATIO // 8) * 8
+
+    # Ensure minimum dimensions are met
+    width = max(width, 576) if width == FIXED_DIMENSION else width
+    height = max(height, 576) if height == FIXED_DIMENSION else height
+
+    return width, height
+
+# ------------------ SAM (Transformers) Imports and Initialization ------------------ #
+from transformers import SamModel, SamProcessor
+
+# Load the model and processor from Hugging Face.
+sam_model = SamModel.from_pretrained("facebook/sam-vit-base")
+sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
+
+@spaces.GPU(durations=300)
+def generate_mask_with_sam(image: Image.Image, mask_prompt: str):
+    """
+    Generate a segmentation mask using SAM (via Hugging Face Transformers).
+
+    The mask_prompt is expected to be a comma-separated string of two integers,
+    e.g. "450,600", representing an (x, y) coordinate in the image.
+
+    The function converts the coordinate into the proper input format for SAM and returns a binary mask.
+    """
+    if mask_prompt.strip() == "":
+        raise ValueError("No mask prompt provided.")
+
+    try:
+        # Parse the mask_prompt into a coordinate
+        coords = [int(x.strip()) for x in mask_prompt.split(",")]
+        if len(coords) != 2:
+            raise ValueError("Expected two comma-separated integers (x,y).")
+    except Exception as e:
+        raise ValueError("Invalid mask prompt. Please provide coordinates as 'x,y'. Error: " + str(e))
+
+    # The SAM processor expects a list of input points.
+    # Format the point as a list of lists; here we assume one point per image.
+    # (The Transformers SAM expects the points in [x, y] order.)
+    input_points = [coords]  # e.g. [[450, 600]]
+    # Optionally, you can supply input_labels (1 for foreground, 0 for background)
+    input_labels = [1]
+
+    # Prepare the inputs for the SAM processor.
+    inputs = sam_processor(images=image,
+                           input_points=[input_points],
+                           input_labels=[input_labels],
+                           return_tensors="pt")
+
+    # Move tensors to the same device as the model.
+    device = next(sam_model.parameters()).device
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+
+    # Forward pass through SAM.
+    with torch.no_grad():
+        outputs = sam_model(**inputs)
+
+    # The output contains predicted masks; we take the first mask from the first prompt.
+    # (Assuming outputs.pred_masks is of shape (batch_size, num_masks, H, W))
+    pred_masks = outputs.pred_masks  # Tensor of shape (1, num_masks, H, W)
+    mask = pred_masks[0][0].detach().cpu().numpy()
+
+    # Convert the mask to binary (0 or 255) using a threshold.
+    mask_bin = (mask > 0.5).astype(np.uint8) * 255
+    mask_pil = Image.fromarray(mask_bin)
+    return mask_pil
+
+# ------------------ Inference Function ------------------ #
+@spaces.GPU(durations=300)
+def infer(edit_images, prompt, mask_prompt,
+          seed=42, randomize_seed=False, width=1024, height=1024,
+          guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
+    # Get the base image from the "background" layer.
     image = edit_images["background"]
-    # width, height = calculate_optimal_dimensions(image)
-    mask = edit_images["layers"][0]
+    width, height = calculate_optimal_dimensions(image)
+
+    # If a mask prompt is provided, use the SAM-based mask generator.
+    if mask_prompt and mask_prompt.strip() != "":
+        try:
+            mask = generate_mask_with_sam(image, mask_prompt)
+        except Exception as e:
+            raise ValueError("Error generating mask from prompt: " + str(e))
+    else:
+        # Fall back to using a manually drawn mask (from the first layer).
+        try:
+            mask = edit_images["layers"][0]
+        except (TypeError, IndexError):
+            raise ValueError("No mask provided. Please either draw a mask or supply a mask prompt.")
+
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
 
-    # controlImage = processor(image)
-    image = pipe(
-        # mask_image_latent=vae.encode(controlImage),
+    # Run the inpainting diffusion pipeline with the provided prompt and mask.
+    image_out = pipe(
         prompt=prompt,
-        prompt_2=prompt,
         image=image,
         mask_image=mask,
         height=height,
@@ -63,23 +150,14 @@ def infer(edit_images, prompt, width, height, lora_model, seed=42, randomize_see
         guidance_scale=guidance_scale,
         num_inference_steps=num_inference_steps,
         generator=torch.Generator(device='cuda').manual_seed(seed),
-        # lora_scale=0.75 // not supported in this version
     ).images[0]
 
-    output_image_jpg = image.convert("RGB")
+    output_image_jpg = image_out.convert("RGB")
     output_image_jpg.save("output.jpg", "JPEG")
-
     return output_image_jpg, seed
-    # return image, seed
-
-examples = [
-    "photography of a young woman, accent lighting, (front view:1.4), "
-    # "a tiny astronaut hatching from an egg on the moon",
-    # "a cat holding a sign that says hello world",
-    # "an anime illustration of a wiener schnitzel",
-]
 
-css="""
+# ------------------ Gradio UI ------------------ #
+css = """
 #col-container {
     margin: 0 auto;
     max-width: 1000px;
@@ -87,41 +165,51 @@ css="""
 """
 
 with gr.Blocks(css=css) as demo:
-
     with gr.Column(elem_id="col-container"):
-        gr.Markdown(f"""# FLUX.1 [dev]
-        """)
+        gr.Markdown("# FLUX.1 [dev] with SAM (Transformers) Mask Generation")
         with gr.Row():
             with gr.Column():
+                # The image editor now allows you to optionally draw a mask.
                 edit_image = gr.ImageEditor(
-                    label='Upload and draw mask for inpainting',
+                    label='Upload Image (and optionally draw a mask)',
                     type='pil',
                     sources=["upload", "webcam"],
                     image_mode='RGB',
-                    layers=False,
+                    layers=False,  # We will generate a mask automatically if needed.
                     brush=gr.Brush(colors=["#FFFFFF"]),
-                    # height=600
                 )
                 prompt = gr.Text(
-                    label="Prompt",
+                    label="Inpainting Prompt",
                     show_label=False,
                     max_lines=2,
-                    placeholder="Enter your prompt",
+                    placeholder="Enter your inpainting prompt",
                     container=False,
                 )
-
-                lora_model = gr.Dropdown(
-                    label="Select LoRA Model",
-                    choices=list(lora_models.keys()),
-                    value="None",
+                mask_prompt = gr.Text(
+                    label="Mask Prompt (enter a coordinate as 'x,y')",
+                    show_label=True,
+                    placeholder="E.g. 450,600",
+                    container=True,
                 )
-
+                generate_mask_btn = gr.Button("Generate Mask")
+                mask_preview = gr.Image(label="Mask Preview", show_label=True)
                 run_button = gr.Button("Run")
-
             result = gr.Image(label="Result", show_label=False)
+
+            # Button to preview the generated mask.
+            def on_generate_mask(image, mask_prompt):
+                if image is None or mask_prompt.strip() == "":
+                    return None
+                mask = generate_mask_with_sam(image, mask_prompt)
+                return mask
+
+            generate_mask_btn.click(
+                fn=on_generate_mask,
+                inputs=[edit_image, mask_prompt],
+                outputs=[mask_preview]
+            )
 
         with gr.Accordion("Advanced Settings", open=False):
-
             seed = gr.Slider(
                 label="Seed",
                 minimum=0,
@@ -129,50 +217,45 @@ with gr.Blocks(css=css) as demo:
                 step=1,
                 value=0,
             )
-
             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
             with gr.Row():
-
+                width = gr.Slider(
+                    label="Width",
+                    minimum=256,
+                    maximum=MAX_IMAGE_SIZE,
+                    step=32,
+                    value=1024,
+                    visible=False
+                )
+                height = gr.Slider(
+                    label="Height",
+                    minimum=256,
+                    maximum=MAX_IMAGE_SIZE,
+                    step=32,
+                    value=1024,
+                    visible=False
+                )
+            with gr.Row():
                 guidance_scale = gr.Slider(
                     label="Guidance Scale",
                     minimum=1,
                     maximum=30,
                     step=0.5,
-                    value=50,
+                    value=3.5,
                 )
-
                 num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
+                    label="Number of Inference Steps",
                     minimum=1,
                     maximum=50,
                     step=1,
                     value=28,
                 )
 
-            with gr.Row():
-
-                width = gr.Slider(
-                    label="width",
-                    minimum=512,
-                    maximum=3072,
-                    step=1,
-                    value=1024,
-                )
-
-                height = gr.Slider(
-                    label="height",
-                    minimum=512,
-                    maximum=3072,
-                    step=1,
-                    value=1024,
-                )
-
     gr.on(
         triggers=[run_button.click, prompt.submit],
-        fn = infer,
-        inputs = [edit_image, prompt, width, height, lora_model, seed, randomize_seed, guidance_scale, num_inference_steps],
-        outputs = [result, seed]
+        fn=infer,
+        inputs=[edit_image, prompt, mask_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
+        outputs=[result, seed]
     )
 
     # demo.launch()
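
The new SAM mask path can also be exercised outside the Gradio UI. A minimal sketch, assuming app.py is importable as a module, the models have downloaded, and a CUDA device is available; "test.jpg" and the point "450,600" are placeholders:

# Sketch only: exercise calculate_optimal_dimensions -> generate_mask_with_sam -> pipe.
from PIL import Image
from app import calculate_optimal_dimensions, generate_mask_with_sam, pipe

img = Image.open("test.jpg").convert("RGB")
w, h = calculate_optimal_dimensions(img)       # snapped to multiples of 8, clamped to 9:16..16:9
mask = generate_mask_with_sam(img, "450,600")  # SAM point prompt -> binary PIL mask
result = pipe(prompt="a red leather jacket", image=img, mask_image=mask,
              width=w, height=h, guidance_scale=3.5,
              num_inference_steps=28).images[0]
result.save("inpainted.jpg")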
readme.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: Inpainting
+title: Inpainting test
 emoji: 🏆
 colorFrom: blue
 colorTo: purple
requirements.txt CHANGED
@@ -8,3 +8,4 @@ peft
 xformers
 torchvision
 torch
+opencv-python
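
The new opencv-python dependency backs the `import cv2` added in app.py; the visible hunks never call cv2 yet. A hypothetical example of where it typically fits in this kind of app (not part of the commit): dilating the SAM mask so the inpainted region fully covers object edges.

# Hypothetical helper, not in the commit: dilate a binary PIL mask with OpenCV.
import cv2
import numpy as np
from PIL import Image

def dilate_mask(mask_pil: Image.Image, pixels: int = 8) -> Image.Image:
    kernel = np.ones((pixels, pixels), np.uint8)
    dilated = cv2.dilate(np.array(mask_pil), kernel, iterations=1)
    return Image.fromarray(dilated)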