weihongliang committed
Commit 5becd44 · verified · 1 Parent(s): d63e46a

Update app.py

Files changed (1)
  1. app.py +62 -30
app.py CHANGED
@@ -20,10 +20,13 @@ import torch.backends.cudnn as cudnn
 from torchvision import transforms as pth_transforms
 import shutil
 import os
+import spaces
 
 
-os.system("wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth")
-os.system("wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth")
+# Download SAM checkpoint if not exists
+import subprocess
+if not os.path.exists("./sam_vit_h_4b8939.pth"):
+    subprocess.run(["wget", "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"])
 
 
 sys.path.append("./segment-anything")
@@ -40,21 +43,29 @@ OBJECT_SAVE_PATH = "./database/Objects/masks"
 FACE_SAVE_PATH = "./database/Faces/masks"
 
 # Initialize SAM model
+sam = None
+predictor = None
+
+@spaces.GPU
 def initialize_sam(sam_checkpoint, model_type="vit_h"):
-    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
-    sam.to(device="cuda" if torch.cuda.is_available() else "cpu")
+    global sam
+    if sam is None:
+        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
+        sam.to(device="cuda" if torch.cuda.is_available() else "cpu")
     return sam
 
-# Path to the SAM checkpoint
-sam_checkpoint = "./sam_vit_h_4b8939.pth"
-sam = initialize_sam(sam_checkpoint)
-predictor = None
-
 
 # Load RADIO model
 model_version = "radio_v2.5-h" # Using RADIOv2.5-H model (ViT-H/16)
-model = torch.hub.load('NVlabs/RADIO', 'radio_model', version=model_version, progress=True, skip_validation=True)
-model.cuda().eval()
+model = None # Initialize as None, will be loaded when needed
+
+@spaces.GPU
+def load_radio_model():
+    global model
+    if model is None:
+        model = torch.hub.load('NVlabs/RADIO', 'radio_model', version=model_version, progress=True, skip_validation=True)
+        model.cuda().eval()
+    return model
 
 def extract_features(image_path):
     """Extract features from an image using the RADIO model."""
@@ -140,11 +151,14 @@ def _robust_collate_fn_for_extract_features(batch):
     return image_data_list, batched_indices
 
 
+@spaces.GPU
 def extract_features(object_dataset, batch_size, num_workers):
     """
    Extracts features from images, handling inputs as paths, PIL Images, or Tensors.
    Assumes `model`, `model_version`, `pil_to_tensor` are in calling scope.
    """
+    # Ensure model is loaded
+    model = load_radio_model()
    dataloader = DataLoader(
        object_dataset,
        batch_size=batch_size,
@@ -473,12 +487,18 @@ def navigate_images(is_same_object=False):
     status_text = state.get_status_text()
     return current_image, mask_display, status_text, state.get_gallery(), None # Return None to clear file upload
 
+@spaces.GPU
 def generate_mask(image, evt: gr.SelectData): # 'image' is the numpy array from the clicked component
-    global predictor
+    global predictor, sam
 
     # Use the image passed by the event!
     if image is None:
         return None, None, "Cannot segment: Image component is empty.", state.get_gallery()
+
+    # Initialize SAM if not already done
+    if sam is None:
+        sam_checkpoint = "./sam_vit_h_4b8939.pth"
+        sam = initialize_sam(sam_checkpoint)
 
     # Ensure the image is a NumPy array in RGB format (Gradio usually provides this)
     if not isinstance(image, np.ndarray):
@@ -935,7 +955,7 @@ imsize = 224
 args = args_parser.parse_args(args=[
     "--train_path", "./database/Objects/masks",
     "--test_path", "temp_path_placeholder", # This will be updated during runtime
-    "--pretrained_weights", "./dinov2_vitl14_reg4_pretrain.pth",
+    "--pretrained_weights", "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth",
     "--output_dir", f"exps/output_RankSelect_{imsize}_mask", # Default tag, will be updated
 ])
 
@@ -943,8 +963,15 @@ args = args_parser.parse_args(args=[
 os.makedirs(args.output_dir, exist_ok=True)
 #model, autocast_dtype = setup_and_build_model(args)
 
+@spaces.GPU
 def detect_objects(input_img, score_threshold=0.52, tag="mask"):
     """Main function to detect objects in an image"""
+    global sam
+
+    # Initialize SAM if not already done
+    if sam is None:
+        sam_checkpoint = "./sam_vit_h_4b8939.pth"
+        sam = initialize_sam(sam_checkpoint)
     # Create temporary file for the input image
     with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as f:
         temp_path = f.name
@@ -1256,6 +1283,7 @@ def detect_objects(input_img, score_threshold=0.52, tag="mask"):
 # ===== FACE DETECTION AND RECOGNITION PART =====
 
 # Initialize face detection and recognition models
+@spaces.GPU
 def initialize_face_models():
     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
     mtcnn = MTCNN(
@@ -1351,6 +1379,7 @@ def get_face_embeddings(face_dir=FACE_SAVE_PATH):
     return embeddings, face_names, face_paths, face_anns
 
 # Detect and recognize faces in an image
+@spaces.GPU
 def detect_faces(input_img, score_threshold=0.7):
     mtcnn, resnet, device = initialize_face_models()
 
@@ -1504,6 +1533,7 @@ def match_faces_stable_matching(face_embeddings, detected_embeddings, score_thre
 
     return matches, similarities
 # 1. Add the combined detection function
+@spaces.GPU
 def combined_detection(img, obj_threshold, face_threshold, tag):
     """
     Run both object detection and face detection on the same image
@@ -1599,37 +1629,40 @@ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
 from qwen_vl_utils import process_vision_info
 
 # Load model and processor at the application level for reuse
+@spaces.GPU
 def load_qwen2vl_model():
     model = Qwen2VLForConditionalGeneration.from_pretrained(
-        #"/mnt/14T-disk/code/Contextual_Referring_Understanding/LLaMA-Factory/models/qwen2_vl_7b_citation_lora_sft_face_3/goodcaption-20000",
-        "/mnt/14T-disk/code/Contextual_Referring_Understanding/LLaMA-Factory/models/qwen2_vl_2b_citation_lora_sft_face_3/goodcaption-20000",
+        "Qwen/Qwen2-VL-2B-Instruct", # Use the base model for HF Spaces
         torch_dtype=torch.bfloat16,
-        device_map="cuda:0"
+        device_map="auto"
     )
     min_pixels = 256 * 28 * 28
     max_pixels = 1280 * 28 * 28
     processor = AutoProcessor.from_pretrained(
         "Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
     )
-    #processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
     return model, processor
 
-# Try to load the model, but handle errors if it fails
-try:
-    qwen_model, qwen_processor = load_qwen2vl_model()
-    qwen_model_loaded = True
-except Exception as e:
-    print(f"Failed to load Qwen2-VL model: {e}")
-    qwen_model_loaded = False
+# Initialize model variables
+qwen_model = None
+qwen_processor = None
+qwen_model_loaded = False
 
 # Function to process detection results and use Qwen2-VL for answering questions
+@spaces.GPU
 def ask_qwen_about_detections(input_image, question, obj_threshold, face_threshold, tag):
     """
     Process an image with detection and use Qwen2-VL to answer questions
     """
-    # Check if the model is loaded
-    if not qwen_model_loaded:
-        return "Qwen2-VL model not loaded. Please check console for errors.", None, None
+    global qwen_model, qwen_processor, qwen_model_loaded
+
+    # Load model if not already loaded
+    if qwen_model is None:
+        try:
+            qwen_model, qwen_processor = load_qwen2vl_model()
+            qwen_model_loaded = True
+        except Exception as e:
+            return f"Failed to load Qwen2-VL model: {e}", None, None
 
     # Get detection results and formatted text
     qwen_input, output_img = process_image_for_qwen(input_image, obj_threshold, face_threshold, tag)
@@ -2146,7 +2179,7 @@ with gr.Blocks() as app:
                 lines=2
             )
 
-            qwen_ask_button = gr.Button("Ask RC-MLLM-7B")
+            qwen_ask_button = gr.Button("Ask RC-MLLM-2B")
 
         with gr.Column():
             qwen_output_image = gr.Image(label="Detection Result")
@@ -2174,8 +2207,7 @@ with gr.Blocks() as app:
 
     # Model status display
     model_status = gr.Markdown(
-        "✅ RC-MLLM model loaded successfully" if qwen_model_loaded else
-        "❌ RC-MLLM model not loaded. Please check console for errors."
+        "🔄 RC-MLLM model will be loaded when first used (ZeroGPU)"
    )
 
    # Instructions for RC-MLLM section
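
The recurring change in this commit is the ZeroGPU lazy-load pattern: each model starts as a module-level None placeholder, and every GPU-bound entry point is decorated with @spaces.GPU and builds the model on first use instead of at import time, so the Space can start without reserving a GPU. Below is a minimal, self-contained sketch of that pattern, not the app's actual code; it assumes the Hugging Face `spaces` package (available on ZeroGPU Spaces), and `build_model` is a hypothetical stand-in for the heavy constructors used in app.py (sam_model_registry, torch.hub.load, from_pretrained).

import torch
import spaces  # Hugging Face Spaces helper that provides the ZeroGPU decorator

def build_model():
    # Hypothetical stand-in for a heavy model constructor.
    return torch.nn.Linear(4, 4)

model = None  # created lazily so startup does not require a GPU

@spaces.GPU  # ZeroGPU attaches a GPU only while this function runs
def get_model():
    global model
    if model is None:
        model = build_model()
        model.to("cuda" if torch.cuda.is_available() else "cpu")
        model.eval()
    return model

@spaces.GPU
def predict(x):
    # Every request-time entry point loads (or reuses) the cached model on demand.
    m = get_model()
    return m(x.to(next(m.parameters()).device))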