ginipick committed on
Commit
a3c8137
·
verified ·
1 Parent(s): a3e6550

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +166 -235
app.py CHANGED
@@ -11,107 +11,45 @@ import traceback
11
  import warnings
12
  import sys
13
 
14
- # Suppress specific warnings
15
  warnings.filterwarnings("ignore", category=FutureWarning)
16
  warnings.filterwarnings("ignore", message=".*_supports_sdpa.*")
17
 
18
- # CRITICAL: Fix Florence2 model before any imports
19
- def fix_florence2_import():
20
- """Pre-patch the Florence2 model class before it's imported"""
21
- import importlib.util
22
- import types
23
-
24
- # Create a custom import hook
25
- class Florence2ImportHook:
26
- def find_spec(self, fullname, path, target=None):
27
- if "florence2" in fullname.lower() or "modeling_florence2" in fullname:
28
- return importlib.util.spec_from_loader(fullname, Florence2Loader())
29
- return None
30
-
31
- class Florence2Loader:
32
- def create_module(self, spec):
33
- return None
34
-
35
- def exec_module(self, module):
36
- # Load the original module
37
- import importlib.machinery
38
- import importlib.util
39
-
40
- # Find the actual florence2 module
41
- for path in sys.path:
42
- florence_path = os.path.join(path, "modeling_florence2.py")
43
- if os.path.exists(florence_path):
44
- spec = importlib.util.spec_from_file_location("modeling_florence2", florence_path)
45
- if spec and spec.loader:
46
- spec.loader.exec_module(module)
47
-
48
- # Patch the module after loading
49
- if hasattr(module, 'Florence2ForConditionalGeneration'):
50
- original_init = module.Florence2ForConditionalGeneration.__init__
51
-
52
- def patched_init(self, config):
53
- # Add the missing attribute before calling super().__init__
54
- self._supports_sdpa = False
55
- original_init(self, config)
56
-
57
- module.Florence2ForConditionalGeneration.__init__ = patched_init
58
- module.Florence2ForConditionalGeneration._supports_sdpa = False
59
- break
60
-
61
- # Install the import hook
62
- hook = Florence2ImportHook()
63
- sys.meta_path.insert(0, hook)
64
-
65
- # Apply the fix before any model imports
66
- try:
67
- fix_florence2_import()
68
- except Exception as e:
69
- print(f"Warning: Could not apply import hook: {e}")
70
-
71
- # Alternative fix: Monkey-patch transformers before importing utils
72
- def monkey_patch_transformers():
73
- """Monkey patch transformers to handle _supports_sdpa"""
74
  try:
75
  import transformers.modeling_utils as modeling_utils
76
 
 
77
  original_check = modeling_utils.PreTrainedModel._check_and_adjust_attn_implementation
78
 
79
  def patched_check(self, *args, **kwargs):
80
- # Add the attribute if missing
81
  if not hasattr(self, '_supports_sdpa'):
82
- self._supports_sdpa = False
 
83
  try:
84
  return original_check(self, *args, **kwargs)
85
  except AttributeError as e:
86
  if '_supports_sdpa' in str(e):
87
- # Return a safe default
88
  return "eager"
89
  raise
90
 
91
  modeling_utils.PreTrainedModel._check_and_adjust_attn_implementation = patched_check
92
-
93
- # Also patch the getter
94
- original_getattr = modeling_utils.PreTrainedModel.__getattribute__
95
-
96
- def patched_getattr(self, name):
97
- if name == '_supports_sdpa' and not hasattr(self, '_supports_sdpa'):
98
- return False
99
- return original_getattr(self, name)
100
-
101
- modeling_utils.PreTrainedModel.__getattribute__ = patched_getattr
102
-
103
- print("Successfully patched transformers for Florence2 compatibility")
104
 
105
  except Exception as e:
106
  print(f"Warning: Could not patch transformers: {e}")
107
 
108
- # Apply the monkey patch
109
- monkey_patch_transformers()
110
 
111
- # Now import the utils after patching
112
- from util.utils import check_ocr_box, get_yolo_model, get_som_labeled_img
113
 
114
- # Download repository (if not already downloaded)
115
  repo_id = "microsoft/OmniParser-v2.0"
116
  local_dir = "weights"
117
 
@@ -121,75 +59,105 @@ if not os.path.exists(local_dir):
121
  else:
122
  print(f"Weights already exist at: {local_dir}")
123
 
124
- # Custom function to load caption model with proper error handling
125
  def load_caption_model_safe(model_name="florence2", model_name_or_path="weights/icon_caption"):
126
- """Safely load caption model with multiple fallback methods"""
127
 
128
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
129
 
 
130
  try:
131
- # Method 1: Try the original function with patching
132
- from util.utils import get_caption_model_processor
133
  return get_caption_model_processor(model_name, model_name_or_path)
134
- except AttributeError as e:
135
- if '_supports_sdpa' in str(e):
136
- print(f"SDPA error detected, trying alternative loading method...")
137
- else:
138
- raise
139
 
140
- # Method 2: Load directly with specific configuration
141
  try:
142
  from transformers import AutoProcessor, AutoModelForCausalLM
143
 
144
- print(f"Loading caption model from {model_name_or_path} with alternative method...")
145
 
146
- # Load processor
147
  processor = AutoProcessor.from_pretrained(
148
  model_name_or_path,
149
- trust_remote_code=True,
150
- revision="main"
151
  )
152
 
153
- # Try to load model with different configurations
154
- configs_to_try = [
155
- {"attn_implementation": "eager", "use_cache": False},
156
- {"use_flash_attention_2": False, "use_cache": False},
157
- {"torch_dtype": torch.float32}, # Try float32 instead of float16
158
- ]
159
-
160
- model = None
161
- for config in configs_to_try:
162
- try:
163
- model = AutoModelForCausalLM.from_pretrained(
164
- model_name_or_path,
165
- trust_remote_code=True,
166
- device_map="auto" if torch.cuda.is_available() else None,
167
- **config
168
- )
169
-
170
- # Ensure the attribute exists
171
- if not hasattr(model, '_supports_sdpa'):
172
- model._supports_sdpa = False
173
-
174
- print(f"Model loaded successfully with config: {config}")
175
- break
176
-
177
- except Exception as e:
178
- print(f"Failed with config {config}: {e}")
179
- continue
180
 
181
- if model is None:
182
- raise RuntimeError("Could not load model with any configuration")
 
183
 
184
- # Move to device if needed
185
- if device.type == 'cuda' and not next(model.parameters()).is_cuda:
186
  model = model.to(device)
187
 
 
188
  return {'model': model, 'processor': processor}
189
 
190
  except Exception as e:
191
- print(f"Error in alternative loading: {e}")
192
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
  # Load models
195
  try:
@@ -205,9 +173,9 @@ except Exception as e:
205
  print(f"Critical error loading models: {e}")
206
  print(traceback.format_exc())
207
  caption_model_processor = None
208
- # Don't raise here, let the UI handle it
209
 
210
- # Markdown header text
211
  MARKDOWN = """
212
  # OmniParser V2 Pro🔥
213
 
@@ -220,7 +188,6 @@ MARKDOWN = """
220
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
221
  print(f"Using device: {DEVICE}")
222
 
223
- # Custom CSS for UI enhancement
224
  custom_css = """
225
  body { background-color: #f0f2f5; }
226
  .gradio-container { font-family: 'Segoe UI', sans-serif; max-width: 1400px; margin: auto; }
@@ -230,8 +197,6 @@ button:hover { transform: translateY(-2px); box-shadow: 0 4px 12px rgba(0,0,0,0.
230
  .output-image { border: 2px solid #e1e4e8; border-radius: 8px; }
231
  #input_image { border: 2px dashed #4a90e2; border-radius: 8px; }
232
  #input_image:hover { border-color: #2c5aa0; }
233
- .gr-box { border-radius: 8px; }
234
- .gr-padded { padding: 16px; }
235
  """
236
 
237
  @spaces.GPU
@@ -243,22 +208,19 @@ def process(
243
  use_paddleocr,
244
  imgsz
245
  ) -> tuple:
246
- """Process image with error handling and validation"""
247
 
248
- # Input validation
249
  if image_input is None:
250
  return None, "⚠️ Please upload an image for processing."
251
 
252
- # Check if caption model is loaded
253
- if caption_model_processor is None:
254
- return None, "⚠️ Caption model not loaded. There was an error during initialization. Please check the logs."
255
 
256
  try:
257
- # Log processing parameters
258
- print(f"Processing with parameters: box_threshold={box_threshold}, "
259
- f"iou_threshold={iou_threshold}, use_paddleocr={use_paddleocr}, imgsz={imgsz}")
260
 
261
- # Calculate overlay ratio based on input image width
262
  image_width = image_input.size[0]
263
  box_overlay_ratio = max(0.5, min(2.0, image_width / 3200))
264
 
@@ -269,7 +231,7 @@ def process(
269
  'thickness': max(int(3 * box_overlay_ratio), 1),
270
  }
271
 
272
- # Run OCR bounding box detection
273
  try:
274
  ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
275
  image_input,
@@ -280,42 +242,37 @@ def process(
280
  use_paddleocr=use_paddleocr
281
  )
282
 
283
- # Handle None result from OCR
284
  if ocr_bbox_rslt is None:
285
- print("OCR returned None, using empty results")
286
  text, ocr_bbox = [], []
287
  else:
288
  text, ocr_bbox = ocr_bbox_rslt
289
 
290
- # Validate OCR results
291
- if text is None:
292
- text = []
293
- if ocr_bbox is None:
294
- ocr_bbox = []
295
-
296
  print(f"OCR found {len(text)} text regions")
297
 
298
  except Exception as e:
299
- print(f"OCR error: {e}, continuing with empty OCR results")
300
  text, ocr_bbox = [], []
301
 
302
- # Get labeled image and parsed content
303
  try:
304
- # Ensure the model has the required attribute
305
  if isinstance(caption_model_processor, dict) and 'model' in caption_model_processor:
306
  model = caption_model_processor['model']
307
  if not hasattr(model, '_supports_sdpa'):
308
- model._supports_sdpa = False
309
 
310
  dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
311
  image_input,
312
  yolo_model,
313
  BOX_TRESHOLD=box_threshold,
314
  output_coord_in_ratio=True,
315
- ocr_bbox=ocr_bbox if ocr_bbox else [],
316
  draw_bbox_config=draw_bbox_config,
317
  caption_model_processor=caption_model_processor,
318
- ocr_text=text if text else [],
319
  iou_threshold=iou_threshold,
320
  imgsz=imgsz
321
  )
@@ -324,121 +281,100 @@ def process(
324
  raise ValueError("Failed to generate labeled image")
325
 
326
  except Exception as e:
327
- print(f"Error in SOM processing: {e}")
328
- print(traceback.format_exc())
329
- return image_input, f"⚠️ Error during element detection: {str(e)}"
330
 
331
- # Decode processed image from base64
332
  try:
333
  image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
334
- print('Successfully decoded processed image')
335
  except Exception as e:
336
- print(f"Error decoding image: {e}")
337
- return image_input, f"⚠️ Error decoding processed image: {str(e)}"
338
 
339
- # Format parsed content list
340
  if parsed_content_list and len(parsed_content_list) > 0:
341
  parsed_text = "🎯 **Detected Elements:**\n\n"
342
  for i, v in enumerate(parsed_content_list):
343
- if v: # Only add non-empty content
344
- parsed_text += f"**Icon {i}:** {v}\n"
345
  else:
346
- parsed_text = "ℹ️ No UI elements detected. Try adjusting the detection thresholds."
347
 
348
- print(f'Finished processing image. Found {len(parsed_content_list)} elements.')
349
  return image, parsed_text
350
 
351
  except Exception as e:
352
- error_msg = f"⚠️ Unexpected error: {str(e)}"
353
- print(f"Error during processing: {e}")
354
  print(traceback.format_exc())
355
- return None, error_msg
356
 
357
- # Build Gradio UI
358
- with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="OmniParser V2 Pro") as demo:
359
  gr.Markdown(MARKDOWN)
360
 
361
- # Check if models loaded successfully
362
- if caption_model_processor is None:
363
- gr.Markdown("### ⚠️ Warning: Caption model failed to load. Some features may not work.")
364
 
365
  with gr.Row():
366
- # Left sidebar: Upload and settings
367
  with gr.Column(scale=1):
368
- with gr.Accordion("📤 Upload Image & Settings", open=True):
369
  image_input_component = gr.Image(
370
  type='pil',
371
- label='Upload Screenshot/UI Image',
372
  elem_id="input_image"
373
  )
374
 
375
  gr.Markdown("### 🎛️ Detection Settings")
376
 
377
- with gr.Group():
378
- box_threshold_component = gr.Slider(
379
- label='📊 Box Threshold',
380
- minimum=0.01,
381
- maximum=1.0,
382
- step=0.01,
383
- value=0.05,
384
- info="Lower values detect more elements"
385
- )
386
-
387
- iou_threshold_component = gr.Slider(
388
- label='🔲 IOU Threshold',
389
- minimum=0.01,
390
- maximum=1.0,
391
- step=0.01,
392
- value=0.1,
393
- info="Controls overlap filtering"
394
- )
395
-
396
- use_paddleocr_component = gr.Checkbox(
397
- label='🔤 Use PaddleOCR',
398
- value=True,
399
- info="✓ PaddleOCR | ✗ EasyOCR"
400
- )
401
-
402
- imgsz_component = gr.Slider(
403
- label='📐 Detection Image Size',
404
- minimum=640,
405
- maximum=1920,
406
- step=32,
407
- value=640,
408
- info="Higher = better accuracy but slower"
409
- )
410
 
411
- submit_button_component = gr.Button(
412
- value='🚀 Process Image',
413
- variant='primary',
414
- size='lg'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
  )
416
 
417
- gr.Markdown("### 💡 Quick Tips")
418
- gr.Markdown("""
419
- - **Mobile apps:** Use default settings
420
- - **Desktop apps:** Try image size 1280
421
- - **Complex UIs:** Lower box threshold to 0.03
422
- - **Too many boxes:** Increase IOU threshold
423
- """)
424
 
425
- # Right main area: Results tabs
426
  with gr.Column(scale=2):
427
  with gr.Tabs():
428
- with gr.Tab("🖼️ Annotated Image"):
429
  image_output_component = gr.Image(
430
  type='pil',
431
- label='Processed Image with Annotations',
432
- elem_classes=["output-image"]
433
  )
434
 
435
- with gr.Tab("📝 Extracted Elements"):
436
  text_output_component = gr.Markdown(
437
- value="*Parsed elements will appear here after processing...*",
438
- elem_classes=["parsed-text"]
439
  )
440
 
441
- # Button click event
442
  submit_button_component.click(
443
  fn=process,
444
  inputs=[
@@ -452,13 +388,9 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="OmniParser V2 Pro"
452
  show_progress=True
453
  )
454
 
455
- # Launch with queue support
456
  if __name__ == "__main__":
457
  try:
458
- # Set environment variables
459
- os.environ['TRANSFORMERS_OFFLINE'] = '0'
460
- os.environ['HF_HUB_OFFLINE'] = '0'
461
-
462
  demo.queue(max_size=10)
463
  demo.launch(
464
  share=False,
@@ -467,5 +399,4 @@ if __name__ == "__main__":
467
  server_port=7860
468
  )
469
  except Exception as e:
470
- print(f"Failed to launch app: {e}")
471
- print(traceback.format_exc())
 
11
  import warnings
12
  import sys
13
 
14
+ # Suppress warnings
15
  warnings.filterwarnings("ignore", category=FutureWarning)
16
  warnings.filterwarnings("ignore", message=".*_supports_sdpa.*")
17
 
18
+ # Simple monkey patch for transformers - avoid recursion
19
+ def simple_patch_transformers():
20
+ """Simple patch to fix _supports_sdpa issue"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  try:
22
  import transformers.modeling_utils as modeling_utils
23
 
24
+ # Store original method
25
  original_check = modeling_utils.PreTrainedModel._check_and_adjust_attn_implementation
26
 
27
  def patched_check(self, *args, **kwargs):
28
+ # Simply set the attribute if it doesn't exist
29
  if not hasattr(self, '_supports_sdpa'):
30
+ object.__setattr__(self, '_supports_sdpa', False)
31
+
32
  try:
33
  return original_check(self, *args, **kwargs)
34
  except AttributeError as e:
35
  if '_supports_sdpa' in str(e):
36
+ # Return default attention implementation
37
  return "eager"
38
  raise
39
 
40
  modeling_utils.PreTrainedModel._check_and_adjust_attn_implementation = patched_check
41
+ print("Applied simple transformers patch")
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  except Exception as e:
44
  print(f"Warning: Could not patch transformers: {e}")
45
 
46
+ # Apply the patch BEFORE importing utils
47
+ simple_patch_transformers()
48
 
49
+ # Now import the utils
50
+ from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
51
 
52
+ # Download repository
53
  repo_id = "microsoft/OmniParser-v2.0"
54
  local_dir = "weights"
55
 
 
59
  else:
60
  print(f"Weights already exist at: {local_dir}")
61
 
62
+ # Custom function to load caption model
63
  def load_caption_model_safe(model_name="florence2", model_name_or_path="weights/icon_caption"):
64
+ """Safely load caption model"""
65
 
66
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
67
 
68
+ # Method 1: Try original function
69
  try:
 
 
70
  return get_caption_model_processor(model_name, model_name_or_path)
71
+ except Exception as e:
72
+ print(f"Original loading failed: {e}, trying alternative...")
 
 
 
73
 
74
+ # Method 2: Load with specific configs
75
  try:
76
  from transformers import AutoProcessor, AutoModelForCausalLM
77
 
78
+ print(f"Loading caption model from {model_name_or_path}...")
79
 
 
80
  processor = AutoProcessor.from_pretrained(
81
  model_name_or_path,
82
+ trust_remote_code=True
 
83
  )
84
 
85
+ # Load model with safer config
86
+ model = AutoModelForCausalLM.from_pretrained(
87
+ model_name_or_path,
88
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
89
+ trust_remote_code=True,
90
+ attn_implementation="eager", # Use eager attention
91
+ low_cpu_mem_usage=True
92
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
+ # Ensure attribute exists (using object.__setattr__ to avoid recursion)
95
+ if not hasattr(model, '_supports_sdpa'):
96
+ object.__setattr__(model, '_supports_sdpa', False)
97
 
98
+ if device.type == 'cuda':
 
99
  model = model.to(device)
100
 
101
+ print("Model loaded successfully with alternative method")
102
  return {'model': model, 'processor': processor}
103
 
104
  except Exception as e:
105
+ print(f"Alternative loading also failed: {e}")
106
+
107
+ # Method 3: Manual loading as last resort
108
+ try:
109
+ print("Attempting manual model loading...")
110
+
111
+ # Import required modules
112
+ from transformers import AutoProcessor, AutoConfig
113
+ import importlib.util
114
+
115
+ # Load processor
116
+ processor = AutoProcessor.from_pretrained(
117
+ model_name_or_path,
118
+ trust_remote_code=True
119
+ )
120
+
121
+ # Load config
122
+ config = AutoConfig.from_pretrained(
123
+ model_name_or_path,
124
+ trust_remote_code=True
125
+ )
126
+
127
+ # Manually import and instantiate model
128
+ model_file = os.path.join(model_name_or_path, "modeling_florence2.py")
129
+ if os.path.exists(model_file):
130
+ spec = importlib.util.spec_from_file_location("modeling_florence2_custom", model_file)
131
+ module = importlib.util.module_from_spec(spec)
132
+ spec.loader.exec_module(module)
133
+
134
+ # Get model class
135
+ if hasattr(module, 'Florence2ForConditionalGeneration'):
136
+ model_class = module.Florence2ForConditionalGeneration
137
+
138
+ # Create model instance
139
+ model = model_class(config)
140
+
141
+ # Set the attribute before loading weights
142
+ object.__setattr__(model, '_supports_sdpa', False)
143
+
144
+ # Load weights
145
+ weight_file = os.path.join(model_name_or_path, "model.safetensors")
146
+ if os.path.exists(weight_file):
147
+ from safetensors.torch import load_file
148
+ state_dict = load_file(weight_file)
149
+ model.load_state_dict(state_dict, strict=False)
150
+
151
+ if device.type == 'cuda':
152
+ model = model.to(device)
153
+ model = model.half() # Use half precision
154
+
155
+ print("Model loaded successfully with manual method")
156
+ return {'model': model, 'processor': processor}
157
+
158
+ except Exception as e:
159
+ print(f"Manual loading failed: {e}")
160
+ raise RuntimeError(f"Could not load model with any method: {e}")
161
 
162
  # Load models
163
  try:
 
173
  print(f"Critical error loading models: {e}")
174
  print(traceback.format_exc())
175
  caption_model_processor = None
176
+ yolo_model = None
177
 
178
+ # UI Configuration
179
  MARKDOWN = """
180
  # OmniParser V2 Pro🔥
181
 
 
188
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
189
  print(f"Using device: {DEVICE}")
190
 
 
191
  custom_css = """
192
  body { background-color: #f0f2f5; }
193
  .gradio-container { font-family: 'Segoe UI', sans-serif; max-width: 1400px; margin: auto; }
 
197
  .output-image { border: 2px solid #e1e4e8; border-radius: 8px; }
198
  #input_image { border: 2px dashed #4a90e2; border-radius: 8px; }
199
  #input_image:hover { border-color: #2c5aa0; }
 
 
200
  """
201
 
202
  @spaces.GPU
 
208
  use_paddleocr,
209
  imgsz
210
  ) -> tuple:
211
+ """Process image with error handling"""
212
 
 
213
  if image_input is None:
214
  return None, "⚠️ Please upload an image for processing."
215
 
216
+ if caption_model_processor is None or yolo_model is None:
217
+ return None, "⚠️ Models not loaded properly. Please restart the application."
 
218
 
219
  try:
220
+ print(f"Processing: box_threshold={box_threshold}, iou_threshold={iou_threshold}, "
221
+ f"use_paddleocr={use_paddleocr}, imgsz={imgsz}")
 
222
 
223
+ # Calculate overlay ratio
224
  image_width = image_input.size[0]
225
  box_overlay_ratio = max(0.5, min(2.0, image_width / 3200))
226
 
 
231
  'thickness': max(int(3 * box_overlay_ratio), 1),
232
  }
233
 
234
+ # OCR processing
235
  try:
236
  ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
237
  image_input,
 
242
  use_paddleocr=use_paddleocr
243
  )
244
 
 
245
  if ocr_bbox_rslt is None:
 
246
  text, ocr_bbox = [], []
247
  else:
248
  text, ocr_bbox = ocr_bbox_rslt
249
 
250
+ text = text if text is not None else []
251
+ ocr_bbox = ocr_bbox if ocr_bbox is not None else []
252
+
 
 
 
253
  print(f"OCR found {len(text)} text regions")
254
 
255
  except Exception as e:
256
+ print(f"OCR error: {e}")
257
  text, ocr_bbox = [], []
258
 
259
+ # Object detection and captioning
260
  try:
261
+ # Ensure model has _supports_sdpa attribute
262
  if isinstance(caption_model_processor, dict) and 'model' in caption_model_processor:
263
  model = caption_model_processor['model']
264
  if not hasattr(model, '_supports_sdpa'):
265
+ object.__setattr__(model, '_supports_sdpa', False)
266
 
267
  dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
268
  image_input,
269
  yolo_model,
270
  BOX_TRESHOLD=box_threshold,
271
  output_coord_in_ratio=True,
272
+ ocr_bbox=ocr_bbox,
273
  draw_bbox_config=draw_bbox_config,
274
  caption_model_processor=caption_model_processor,
275
+ ocr_text=text,
276
  iou_threshold=iou_threshold,
277
  imgsz=imgsz
278
  )
 
281
  raise ValueError("Failed to generate labeled image")
282
 
283
  except Exception as e:
284
+ print(f"Detection error: {e}")
285
+ return image_input, f"⚠️ Error during detection: {str(e)}"
 
286
 
287
+ # Decode image
288
  try:
289
  image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
 
290
  except Exception as e:
291
+ print(f"Image decode error: {e}")
292
+ return image_input, f"⚠️ Error decoding image: {str(e)}"
293
 
294
+ # Format results
295
  if parsed_content_list and len(parsed_content_list) > 0:
296
  parsed_text = "🎯 **Detected Elements:**\n\n"
297
  for i, v in enumerate(parsed_content_list):
298
+ if v:
299
+ parsed_text += f"**Element {i}:** {v}\n"
300
  else:
301
+ parsed_text = "ℹ️ No UI elements detected. Try adjusting the thresholds."
302
 
303
+ print(f'Processing complete. Found {len(parsed_content_list)} elements.')
304
  return image, parsed_text
305
 
306
  except Exception as e:
307
+ print(f"Processing error: {e}")
 
308
  print(traceback.format_exc())
309
+ return None, f"⚠️ Error: {str(e)}"
310
 
311
+ # Build UI
312
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
313
  gr.Markdown(MARKDOWN)
314
 
315
+ if caption_model_processor is None or yolo_model is None:
316
+ gr.Markdown("### ⚠️ Warning: Models failed to load. Please check logs.")
 
317
 
318
  with gr.Row():
 
319
  with gr.Column(scale=1):
320
+ with gr.Accordion("📤 Upload & Settings", open=True):
321
  image_input_component = gr.Image(
322
  type='pil',
323
+ label='Upload Screenshot',
324
  elem_id="input_image"
325
  )
326
 
327
  gr.Markdown("### 🎛️ Detection Settings")
328
 
329
+ box_threshold_component = gr.Slider(
330
+ label='Box Threshold',
331
+ minimum=0.01,
332
+ maximum=1.0,
333
+ step=0.01,
334
+ value=0.05,
335
+ info="Lower = more detections"
336
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
 
338
+ iou_threshold_component = gr.Slider(
339
+ label='IOU Threshold',
340
+ minimum=0.01,
341
+ maximum=1.0,
342
+ step=0.01,
343
+ value=0.1,
344
+ info="Overlap filtering"
345
+ )
346
+
347
+ use_paddleocr_component = gr.Checkbox(
348
+ label='Use PaddleOCR',
349
+ value=True
350
+ )
351
+
352
+ imgsz_component = gr.Slider(
353
+ label='Image Size',
354
+ minimum=640,
355
+ maximum=1920,
356
+ step=32,
357
+ value=640
358
  )
359
 
360
+ submit_button_component = gr.Button(
361
+ value='🚀 Process',
362
+ variant='primary'
363
+ )
 
 
 
364
 
 
365
  with gr.Column(scale=2):
366
  with gr.Tabs():
367
+ with gr.Tab("🖼️ Result"):
368
  image_output_component = gr.Image(
369
  type='pil',
370
+ label='Annotated Image'
 
371
  )
372
 
373
+ with gr.Tab("📝 Elements"):
374
  text_output_component = gr.Markdown(
375
+ value="*Results will appear here...*"
 
376
  )
377
 
 
378
  submit_button_component.click(
379
  fn=process,
380
  inputs=[
 
388
  show_progress=True
389
  )
390
 
391
+ # Launch
392
  if __name__ == "__main__":
393
  try:
 
 
 
 
394
  demo.queue(max_size=10)
395
  demo.launch(
396
  share=False,
 
399
  server_port=7860
400
  )
401
  except Exception as e:
402
+ print(f"Launch failed: {e}")