tokeron committed on
Commit 0a9b595 · verified · 1 Parent(s): f9cb207

Upload folder using huggingface_hub

Files changed (8)
  1. .gitignore +56 -0
  2. LICENSE +51 -0
  3. README.md +156 -6
  4. app.py +15 -0
  5. data/snoopy.jpg +0 -0
  6. requirements.txt +9 -0
  7. sam_gui.py +921 -0
  8. sam_inference.py +625 -0
.gitignore ADDED
@@ -0,0 +1,56 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # PyTorch
+ *.pth
+ *.pt
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # Environment
+ .env
+ .venv
+ env/
+ venv/
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Gradio temporary files
+ gradio_cached_examples/
+ flagged/
+
+ # SAM outputs (keep structure but ignore content)
+ masks/*/*
+ !masks/.gitkeep
+
+ # Large model files
+ *.bin
+ *.safetensors
LICENSE ADDED
@@ -0,0 +1,51 @@
+ MIT License
+
+ Copyright (c) 2024 SAM GUI Contributors
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
+ ================================================================================
+
+ ACKNOWLEDGMENTS AND THIRD-PARTY LICENSES:
+
+ This software is a GUI wrapper that integrates the following models and research:
+
+ 1. SAM 2.1 (Segment Anything Model 2.1) by Meta AI
+    - Original Paper: "SAM 2: Segment Anything in Images and Videos"
+    - Authors: Nikhila Ravi, Valentin Gabeur, Yuan-Ting Hu, et al.
+    - License: Apache 2.0
+    - Repository: https://github.com/facebookresearch/segment-anything-2
+    - This GUI is NOT affiliated with Meta AI - it's an independent interface
+
+ 2. Grounding DINO by IDEA Research
+    - Original Paper: "Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection"
+    - Authors: Shilong Liu, Zhaoyang Zeng, Tianhe Ren, et al.
+    - License: Apache 2.0
+    - Repository: https://github.com/IDEA-Research/GroundingDINO
+    - This GUI is NOT affiliated with IDEA Research - it's an independent interface
+
+ DISCLAIMER:
+ This is purely a GUI interface to make these powerful AI models easier to use.
+ All credit for the underlying AI technology goes to the original researchers.
+ This project only provides a user-friendly web interface and does not claim
+ any ownership of the underlying models or algorithms.
+
+ The models are downloaded from Hugging Face and used according to their
+ respective licenses. Please refer to the original repositories for detailed
+ license terms and attribution requirements.
README.md CHANGED
@@ -1,12 +1,162 @@
  ---
- title: SAM Grounding DINO
- emoji: πŸ‘€
- colorFrom: green
  colorTo: purple
  sdk: gradio
- sdk_version: 5.44.1
  app_file: app.py
- pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

  ---
+ title: SAM-Grounding-DINO
+ emoji: 🎭
+ colorFrom: indigo
  colorTo: purple
  sdk: gradio
  app_file: app.py
  ---
+ # 🎭 SAM 2.1 + Grounding DINO Interactive Segmentation
+
+ A web application that combines Meta's SAM 2.1 with IDEA Research's Grounding DINO for text-based and point-based image segmentation, so you can create and download exactly the mask you need.
+
+ ## ✨ Features
+
+ - **πŸ” Text-Based Segmentation**: Type what you want to segment (e.g., "snoopy", "person", "car")
+ - **πŸ“ Point-Based Segmentation**: Click on objects for precise manual control
+ - **🎭 Multiple Mask Generation**: Generate 1-5 masks and browse through them
+ - **πŸ€– SAM 2.1 + Grounding DINO**: Powered by Meta's SAM 2.1 and IDEA Research's Grounding DINO
+ - **πŸ“± Smart Auto-Detection**: Automatically chooses between text and point modes
+ - **πŸ’Ύ Multiple Export Formats**: Download masks as PNG, JPG, or PyTorch tensors
+ - **πŸ–ΌοΈ High-Resolution Display**: View images and masks in full detail
+ - **⚑ Real-Time Processing**: Fast inference with GPU acceleration
+
+ ## πŸš€ Quick Start
+
+ ### Installation
+
+ 1. Clone or download the repository
+ 2. Install dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ### Running the App
+
+ ```bash
+ python app.py
+ ```
+
+ The Gradio interface will be available at `http://localhost:7860`
+
+ ## 🎯 How to Use
+
+ ### 1. Upload an Image
+ - Click or drag an image onto the annotation canvas (the bundled `data/snoopy.jpg` is loaded by default)
+ - Supported formats: JPG, JPEG, PNG, BMP
+
+ ### 2. Describe What to Segment (Optional)
+ - Type a text prompt (e.g., "person", "car", "dog") and press Enter
+ - Grounding DINO detects the object and SAM 2.1 segments it automatically
+
+ ### 3. Add Points (Optional)
+ - Click directly on the image to add **positive** (include) points
+ - Enable **βž– NEGATIVE POINT MODE** to add **negative** (exclude) points, shown in red
+ - Use **πŸ—‘οΈ Clear Points** to start over
+ - If both a text prompt and points are provided, the text detection is refined by your points
+
+ ### 4. Generate Segmentation Masks
+ - Choose how many candidate masks to generate (1-5) with the slider
+ - Click "🎯 Generate Mask" (the first run downloads the models and may take longer)
+ - Browse the candidates with the ⬅️ Previous / ➑️ Next buttons
+
+ ### 5. Save and Download Results
+ - Optionally enter a folder name for the copy saved on the server under `masks/`
+ - Pick a download format: **PNG**, **JPG**, or **PT** (PyTorch tensor); a loading snippet for the tensor follows below
+ - Click "πŸ’Ύ Save & Download"
+
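+ The **PT** option stores the mask as a float tensor of 0.0/1.0 values with shape (height, width), mirroring what `save_and_download_mask()` in `sam_gui.py` writes. A minimal loading sketch (the filename below is just an example):
+
+ ```python
+ import numpy as np
+ import torch
+ from PIL import Image
+
+ # Load a mask downloaded in PT format (example filename)
+ mask = torch.load("mask_snoopy_20240101_120000.pt")  # float32 tensor, values 0.0/1.0
+ print(mask.shape, mask.dtype)
+
+ # Fraction of the image covered by the mask
+ print(f"Coverage: {mask.mean().item():.1%}")
+
+ # Convert back to a binary PNG if needed
+ Image.fromarray((mask.numpy() * 255).astype(np.uint8), mode="L").save("mask.png")
+ ```
+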
+ ## πŸŽ›οΈ Advanced Controls
+
+ ### Segmentation Options:
+ - **Negative Point Mode**: Toggle between include (positive) and exclude (negative) clicks
+ - **Number of Masks**: Generate 1-5 candidates and keep the best one
+ - **Clear Points / Clear Text**: Reset the point or text input in one click
+
+ ### Point Management:
+ - View all current points with coordinates in the "πŸ“ Selected Points" panel
+ - Live preview of positive (green) and negative (red) points
+ - Real-time count of positive/negative points
+
+ ## πŸ”§ Technical Details
+
+ ### SAM 2.1 Model
+ - Uses `facebook/sam2-hiera-small` by default
+ - Automatically downloads model weights on first run
+ - Runs on GPU if available, CPU otherwise
+
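+ For scripted use outside the GUI, both models can be driven directly through `transformers`. The sketch below condenses what `sam_gui.py` does in text-prompt mode (model IDs, thresholds, and calls are taken from that code; exact `transformers` APIs may vary between versions, so treat this as an outline rather than a guaranteed recipe):
+
+ ```python
+ import torch
+ from PIL import Image
+ from transformers import (AutoModelForZeroShotObjectDetection, AutoProcessor,
+                           Sam2Model, Sam2Processor)
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ image = Image.open("data/snoopy.jpg").convert("RGB")
+
+ # 1) Grounding DINO: text prompt -> bounding box
+ dino_id = "IDEA-Research/grounding-dino-base"
+ dino_processor = AutoProcessor.from_pretrained(dino_id)
+ dino = AutoModelForZeroShotObjectDetection.from_pretrained(dino_id).to(device)
+ dino_inputs = dino_processor(images=image, text="snoopy.", return_tensors="pt").to(device)
+ with torch.no_grad():
+     dino_outputs = dino(**dino_inputs)
+ detections = dino_processor.post_process_grounded_object_detection(
+     dino_outputs, input_ids=dino_inputs.input_ids,
+     threshold=0.25, text_threshold=0.25, target_sizes=[image.size[::-1]],
+ )[0]
+ box = detections["boxes"][0].tolist()  # [x1, y1, x2, y2] of the top detection
+
+ # 2) SAM 2.1: bounding box -> segmentation masks
+ sam_id = "facebook/sam2-hiera-small"
+ sam = Sam2Model.from_pretrained(sam_id).to(device)
+ sam_processor = Sam2Processor.from_pretrained(sam_id)
+ sam_inputs = sam_processor(images=image, input_boxes=[[box]], return_tensors="pt").to(device)
+ with torch.no_grad():
+     sam_outputs = sam(**sam_inputs, multimask_output=True)
+ masks = sam_processor.post_process_masks(
+     sam_outputs.pred_masks.cpu(), sam_inputs["original_sizes"]
+ )[0]  # candidate masks for the detected box
+ ```
+
+ In the app itself this pipeline is wrapped by `detect_objects_with_text()` and `process_sam_segmentation()` in `sam_gui.py`.
+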
+ ### Dependencies
+ - `gradio`: Web interface
+ - `torch` / `torchvision`: PyTorch for model inference
+ - `transformers`: Hugging Face model loading (SAM 2.1 and Grounding DINO)
+ - `pillow`: Image processing
+ - `matplotlib`: Visualization
+ - `numpy`: Numerical operations
+ - `opencv-python`: Mask post-processing utilities
+
+ ### System Requirements
+ - Python 3.8+
+ - 4GB+ RAM recommended
+ - GPU recommended for faster processing
+
+ ## πŸ› Troubleshooting
+
+ ### Common Issues:
+
+ 1. **Model Download Fails**:
+    - Check your internet connection
+    - Ensure Hugging Face access (a token may be required for some models)
+
+ 2. **CUDA Out of Memory**:
+    - Try a smaller model size
+    - Reduce image resolution
+    - Use CPU mode: set `CUDA_VISIBLE_DEVICES=""` (see the example after this list)
+
+ 3. **Slow Processing**:
+    - Use a GPU if available
+    - Try the `sam2-hiera-tiny` model for faster inference
+
+ 4. **Import Errors**:
+    - Ensure all dependencies are installed: `pip install -r requirements.txt`
+
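+ For example, to force CPU mode when running locally (assuming the standard `app.py` entry point):
+
+ ```bash
+ CUDA_VISIBLE_DEVICES="" python app.py
+ ```
+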
+ ## πŸ“ File Structure
+
+ ```
+ SAM/
+ β”œβ”€β”€ app.py              # Hugging Face Spaces entry point
+ β”œβ”€β”€ sam_gui.py          # Gradio interface
+ β”œβ”€β”€ sam_inference.py    # SAM 2.1 inference utilities
+ β”œβ”€β”€ data/snoopy.jpg     # Example image
+ β”œβ”€β”€ requirements.txt    # Dependencies
+ └── README.md           # This file
+ ```
+
+ ## 🎨 Interface
+
+ The app features a clean, modern interface with:
+ - Full-width image display
+ - Intuitive point and text-prompt controls
+ - Real-time point visualization
+ - Side-by-side result comparison
+ - Comprehensive download options
+
+ ## 🀝 Contributing
+
+ Feel free to submit issues, feature requests, or pull requests!
+
+ ## πŸ“„ License
+
+ This project uses Meta's SAM 2.1 and IDEA Research's Grounding DINO models. Please refer to their respective license terms for the model weights (see the LICENSE file for details).
+
+ ## πŸ™ Acknowledgments
+
+ - Meta AI for the incredible SAM 2.1 model
+ - IDEA Research for Grounding DINO
+ - Gradio for the web app framework
+ - Hugging Face for model hosting
+ - The open-source community for all the dependencies
app.py ADDED
@@ -0,0 +1,15 @@
+ #!/usr/bin/env python3
+ """
+ SAM 2.1 + Grounding DINO Web Interface
+ Hugging Face Spaces Entry Point
+ """
+
+ import gradio as gr
+ from sam_gui import create_interface
+
+ # Create the interface - this is what Hugging Face Spaces will use
+ demo = create_interface()
+
+ # Launch the interface
+ if __name__ == "__main__":
+     demo.launch()
data/snoopy.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ torch>=2.0.0
+ torchvision>=0.15.0
+ transformers>=4.40.0
+ gradio>=4.0.0
+ pillow>=9.0.0
+ numpy>=1.21.0
+ matplotlib>=3.5.0
+ opencv-python>=4.5.0
+ groundingdino-py>=0.4.0
sam_gui.py ADDED
@@ -0,0 +1,921 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ SAM 2.1 Interface
4
+ """
5
+
6
+ import torch
7
+ import numpy as np
8
+ from PIL import Image
9
+ import matplotlib.pyplot as plt
10
+ import gradio as gr
11
+ from transformers import Sam2Model, Sam2Processor
12
+ import warnings
13
+ import io
14
+ import base64
15
+ import os
16
+ from datetime import datetime
17
+ # Grounding DINO will be imported dynamically in the initialization function
18
+
19
+ warnings.filterwarnings("ignore")
20
+
21
+ # Global model instance to avoid reloading
22
+ MODEL = None
23
+ PROCESSOR = None
24
+ DEVICE = None
25
+
26
+ # Global Grounding DINO instance
27
+ GROUNDING_DINO = None
28
+
29
+ # Global state for saving
30
+ CURRENT_MASK = None
31
+ CURRENT_IMAGE_NAME = None
32
+ CURRENT_POINTS = None
33
+
34
+ def initialize_sam(model_size="small"):
35
+ """Initialize SAM model once"""
36
+ global MODEL, PROCESSOR, DEVICE
37
+
38
+ if MODEL is None:
39
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
40
+ print(f"Initializing SAM 2.1 {model_size} on {DEVICE}...")
41
+
42
+ model_name = f"facebook/sam2-hiera-{model_size}"
43
+ MODEL = Sam2Model.from_pretrained(model_name).to(DEVICE)
44
+ PROCESSOR = Sam2Processor.from_pretrained(model_name)
45
+
46
+ print("βœ“ Model loaded successfully!")
47
+
48
+ return MODEL, PROCESSOR, DEVICE
49
+
50
+ def initialize_grounding_dino():
51
+ """Initialize Grounding DINO model once"""
52
+ global GROUNDING_DINO, DEVICE
53
+
54
+ if GROUNDING_DINO is None:
55
+ if DEVICE is None:
56
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
57
+
58
+ print(f"Initializing Grounding DINO on {DEVICE}...")
59
+
60
+ try:
61
+ # Use Hugging Face model for Grounding DINO
62
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
63
+
64
+ model_id = "IDEA-Research/grounding-dino-base"
65
+ GROUNDING_DINO = {
66
+ 'processor': AutoProcessor.from_pretrained(model_id),
67
+ 'model': AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(DEVICE)
68
+ }
69
+ print("βœ“ Grounding DINO loaded successfully!")
70
+ except Exception as e:
71
+ print(f"❌ Failed to load Grounding DINO: {e}")
72
+ print("Note: Falling back to manual point selection only")
73
+ GROUNDING_DINO = None
74
+
75
+ return GROUNDING_DINO
76
+
77
+ def detect_objects_with_text(image, text_prompt, confidence_threshold=0.25):
78
+ """Use Grounding DINO to detect objects based on text prompt"""
79
+ global GROUNDING_DINO
80
+
81
+ try:
82
+ # Initialize Grounding DINO if needed
83
+ grounding_dino = initialize_grounding_dino()
84
+ if grounding_dino is None:
85
+ return None, "❌ Grounding DINO not available"
86
+
87
+ # Fix image format
88
+ pil_image = fix_image_array(image)
89
+
90
+ # Prepare inputs for Grounding DINO
91
+ processor = grounding_dino['processor']
92
+ model = grounding_dino['model']
93
+
94
+ # Process inputs
95
+ inputs = processor(images=pil_image, text=text_prompt, return_tensors="pt").to(DEVICE)
96
+
97
+ # Run inference
98
+ with torch.no_grad():
99
+ outputs = model(**inputs)
100
+
101
+ # Post-process results
102
+ results = processor.post_process_grounded_object_detection(
103
+ outputs,
104
+ input_ids=inputs.input_ids,
105
+ threshold=confidence_threshold,
106
+ text_threshold=0.25,
107
+ target_sizes=[pil_image.size[::-1]] # (height, width)
108
+ )[0]
109
+
110
+ if len(results['boxes']) == 0:
111
+ return None, f"No objects found for prompt: '{text_prompt}'"
112
+
113
+ # Convert boxes to the format expected by SAM [x1, y1, x2, y2]
114
+ detected_boxes = []
115
+ for box in results['boxes']:
116
+ x1, y1, x2, y2 = box.tolist()
117
+ detected_boxes.append([int(x1), int(y1), int(x2), int(y2)])
118
+
119
+ return detected_boxes, f"βœ“ Found {len(detected_boxes)} object(s) for '{text_prompt}'"
120
+
121
+ except Exception as e:
122
+ return None, f"❌ Detection failed: {str(e)}"
123
+
124
+ def fix_image_array(image):
125
+ """Fix image input for SAM processing - handles filepath, numpy array, or PIL Image"""
126
+ if isinstance(image, str):
127
+ # Handle filepath input from Gradio
128
+ return Image.open(image).convert("RGB")
129
+
130
+ elif isinstance(image, np.ndarray):
131
+ # Make sure array is contiguous
132
+ if not image.flags['C_CONTIGUOUS']:
133
+ image = np.ascontiguousarray(image)
134
+
135
+ # Ensure uint8 dtype
136
+ if image.dtype != np.uint8:
137
+ if image.max() <= 1.0:
138
+ image = (image * 255).astype(np.uint8)
139
+ else:
140
+ image = image.astype(np.uint8)
141
+
142
+ # Convert to PIL Image to avoid any stride issues
143
+ return Image.fromarray(image).convert("RGB")
144
+
145
+ elif isinstance(image, Image.Image):
146
+ return image.convert("RGB")
147
+
148
+ else:
149
+ raise ValueError(f"Unsupported image type: {type(image)}")
150
+
151
+ def apply_mask_post_processing(mask, stability_threshold=0.95):
152
+ """Apply post-processing to refine mask size and quality"""
153
+ import cv2
154
+
155
+ # Convert to binary mask
156
+ binary_mask = (mask > 0).astype(np.uint8)
157
+
158
+ # Apply morphological operations to clean up the mask
159
+ kernel_size = max(3, int(mask.shape[0] * 0.01)) # Adaptive kernel size
160
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
161
+
162
+ # Close small holes
163
+ binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
164
+
165
+ # Remove small noise
166
+ binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)
167
+
168
+ return binary_mask.astype(np.float32)
169
+
170
+ def apply_erosion_dilation(mask, erosion_dilation_value):
171
+ """Apply erosion or dilation to adjust mask size"""
172
+ import cv2
173
+
174
+ binary_mask = (mask > 0).astype(np.uint8)
175
+
176
+ if erosion_dilation_value == 0:
177
+ return mask
178
+
179
+ kernel_size = abs(erosion_dilation_value)
180
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
181
+
182
+ if erosion_dilation_value > 0:
183
+ # Dilate (make larger)
184
+ binary_mask = cv2.dilate(binary_mask, kernel, iterations=1)
185
+ else:
186
+ # Erode (make smaller)
187
+ binary_mask = cv2.erode(binary_mask, kernel, iterations=1)
188
+
189
+ return binary_mask.astype(np.float32)
190
+
191
+ def save_binary_mask(mask, image_name, points, mask_threshold, erosion_dilation, save_low_res=False, custom_folder_name=None):
192
+ """Save binary mask to organized folder structure"""
193
+ global CURRENT_MASK, CURRENT_IMAGE_NAME, CURRENT_POINTS
194
+
195
+ try:
196
+ # Store current state for saving
197
+ CURRENT_MASK = mask
198
+ CURRENT_IMAGE_NAME = image_name
199
+ CURRENT_POINTS = points
200
+
201
+ # Extract image name without extension and sanitize
202
+ if image_name:
203
+ base_name = os.path.splitext(os.path.basename(image_name))[0]
204
+ # Remove any path separators and special characters
205
+ base_name = base_name.replace('/', '_').replace('\\', '_').replace(':', '_').replace(' ', '_')
206
+ else:
207
+ base_name = f"image_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
208
+
209
+ # Choose folder tag: user-provided name if available, else 'default'
210
+ folder_tag = None
211
+ if custom_folder_name and str(custom_folder_name).strip():
212
+ folder_tag = str(custom_folder_name).strip().replace(' ', '_')
213
+ else:
214
+ folder_tag = "default"
215
+
216
+
217
+
218
+ # Create folder structure: masks/<image_base>/<folder_tag>/
219
+ folder_name = f"masks/{base_name}/{folder_tag}"
220
+ os.makedirs(folder_name, exist_ok=True)
221
+
222
+ # Create binary mask (0 and 255 values)
223
+ binary_mask = (mask > 0).astype(np.uint8) * 255
224
+
225
+ # Calculate low resolution dimensions if requested
226
+ original_height, original_width = binary_mask.shape
227
+ if save_low_res:
228
+ # Calculate sqrt-based resolution
229
+ sqrt_factor = int(np.sqrt(max(original_width, original_height)))
230
+ low_res_width = sqrt_factor
231
+ low_res_height = sqrt_factor
232
+ print(f"Original mask size: {original_width}x{original_height}")
233
+ print(f"Low-res mask size: {low_res_width}x{low_res_height}")
234
+
235
+ # Save binary mask
236
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
237
+ # Sanitize filename - replace problematic characters
238
+ threshold_str = f"{mask_threshold:.2f}".replace('.', 'p') # 0.30 -> 0p30
239
+ adj_str = f"{erosion_dilation:+d}".replace('+', 'plus').replace('-', 'minus') # +2 -> plus2, -2 -> minus2
240
+
241
+ saved_paths = []
242
+
243
+ # Save full resolution mask as JPEG with a simple filename
244
+ mask_filename = "image.jpg"
245
+ mask_path = os.path.join(folder_name, mask_filename)
246
+
247
+ mask_image = Image.fromarray(binary_mask, mode='L')
248
+ mask_image.save(mask_path, format="JPEG", quality=95, optimize=True)
249
+ saved_paths.append(mask_path)
250
+
251
+ # Save tensor mask (.pt) as float tensor (0.0/1.0)
252
+ tensor_filename = "image.pt"
253
+ tensor_path = os.path.join(folder_name, tensor_filename)
254
+ torch.save(torch.from_numpy((mask > 0).astype(np.float32)), tensor_path)
255
+ saved_paths.append(tensor_path)
256
+
257
+ # Save low resolution mask if requested
258
+ if save_low_res:
259
+ # Resize mask to low resolution
260
+ low_res_mask = mask_image.resize((low_res_width, low_res_height), Image.Resampling.NEAREST)
261
+
262
+ low_res_filename = f"mask_lowres_{sqrt_factor}x{sqrt_factor}_t{threshold_str}_adj{adj_str}_{timestamp}.png"
263
+ low_res_path = os.path.join(folder_name, low_res_filename)
264
+
265
+ low_res_mask.save(low_res_path)
266
+ saved_paths.append(low_res_path)
267
+
268
+ # Also save metadata
269
+ metadata = {
270
+ "timestamp": timestamp,
271
+ "points": points,
272
+ "mask_threshold": mask_threshold,
273
+ "erosion_dilation": erosion_dilation,
274
+ "image_name": image_name,
275
+ "original_resolution": f"{original_width}x{original_height}",
276
+ "saved_paths": saved_paths,
277
+ "low_resolution_saved": save_low_res
278
+ }
279
+
280
+ if save_low_res:
281
+ metadata["low_resolution"] = f"{low_res_width}x{low_res_height}"
282
+ metadata["sqrt_factor"] = sqrt_factor
283
+
284
+ import json
285
+ metadata_path = os.path.join(folder_name, f"metadata_{timestamp}.json")
286
+ with open(metadata_path, 'w') as f:
287
+ json.dump(metadata, f, indent=2)
288
+
289
+ # Return appropriate message
290
+ if save_low_res:
291
+ return f"βœ… Masks saved:\nπŸ“ Full: {os.path.basename(mask_path)}\nπŸ“ Low-res: {os.path.basename(low_res_path)}"
292
+ else:
293
+ return f"βœ… Mask saved to: {os.path.basename(mask_path)}"
294
+
295
+ except Exception as e:
296
+ return f"❌ Save failed: {str(e)}"
297
+
298
+ def process_sam_segmentation(image, points_data, bbox_data, mode, image_name=None, top_k=3, mask_threshold=0.0, stability_score_threshold=0.95, erosion_dilation=0, text_prompt=None, confidence_threshold=0.25):
299
+ """Main processing function with mask size controls - supports points, bounding boxes, and text prompts"""
300
+ global CURRENT_MASK, CURRENT_IMAGE_NAME, CURRENT_POINTS
301
+
302
+ if image is None:
303
+ return None, None, "Please upload an image first."
304
+
305
+ # Check input based on mode
306
+ if mode == "Points":
307
+ if not points_data or len(points_data) == 0:
308
+ return None, None, "Please click on the image to select points."
309
+ elif mode == "Bounding Box":
310
+ if bbox_data is None:
311
+ return None, None, "Please click two corners to define a bounding box."
312
+ elif mode == "Text Prompt":
313
+ if not text_prompt or not text_prompt.strip():
314
+ return None, None, "Please enter a text prompt to detect objects."
315
+
316
+ try:
317
+ # Initialize model
318
+ model, processor, device = initialize_sam()
319
+
320
+ # Fix image
321
+ pil_image = fix_image_array(image)
322
+
323
+ # Prepare SAM inputs based on mode
324
+ input_points = None
325
+ input_labels = None
326
+ input_boxes = None
327
+ points = None
328
+
329
+ if mode == "Points":
330
+ # Extract points with positive/negative labels
331
+ points = []
332
+ labels = []
333
+ for point_info in points_data:
334
+ if isinstance(point_info, dict):
335
+ points.append([point_info.get("x", 0), point_info.get("y", 0)])
336
+ labels.append(1 if point_info.get("positive", True) else 0) # 1 = positive, 0 = negative
337
+ elif isinstance(point_info, (list, tuple)) and len(point_info) >= 2:
338
+ points.append([point_info[0], point_info[1]])
339
+ labels.append(1) # Default to positive for old format
340
+
341
+ if not points:
342
+ return None, None, "No valid points found."
343
+
344
+ print(f"Processing {len(points)} points: {points} with labels: {labels}")
345
+ input_points = [[points]]
346
+ input_labels = [[labels]]
347
+
348
+ elif mode == "Bounding Box":
349
+ # Use bounding box
350
+ bbox = bbox_data # [x1, y1, x2, y2]
351
+ print(f"Processing bounding box: {bbox}")
352
+ input_boxes = [[bbox]]
353
+ # For visualization, store the bbox corners as points
354
+ points = [[bbox[0], bbox[1]], [bbox[2], bbox[3]]]
355
+
356
+ elif mode == "Text Prompt":
357
+ # Use Grounding DINO to detect objects from text prompt
358
+ detected_boxes, detection_status = detect_objects_with_text(pil_image, text_prompt, confidence_threshold)
359
+ if detected_boxes is None:
360
+ return None, None, detection_status
361
+
362
+ # Use the first detected bounding box (highest confidence)
363
+ bbox = detected_boxes[0]
364
+ print(f"Using detected bounding box: {bbox}")
365
+ input_boxes = [[bbox]]
366
+ # For visualization, store the bbox corners as points
367
+ points = [[bbox[0], bbox[1]], [bbox[2], bbox[3]]]
368
+
369
+ # Process with SAM
370
+ processor_inputs = {
371
+ "images": pil_image,
372
+ "return_tensors": "pt"
373
+ }
374
+
375
+ # Add points and/or boxes based on what's available
376
+ if input_points is not None:
377
+ processor_inputs["input_points"] = input_points
378
+ processor_inputs["input_labels"] = input_labels
379
+
380
+ if input_boxes is not None:
381
+ processor_inputs["input_boxes"] = input_boxes
382
+
383
+ inputs = processor(**processor_inputs).to(device)
384
+
385
+ # Generate masks with multiple outputs for better control
386
+ with torch.no_grad():
387
+ outputs = model(**inputs, multimask_output=True)
388
+
389
+ # Get masks and scores
390
+ masks = processor.post_process_masks(
391
+ outputs.pred_masks.cpu(),
392
+ inputs["original_sizes"]
393
+ )[0]
394
+
395
+ scores = outputs.iou_scores.cpu().numpy().flatten()
396
+
397
+ # Get top-k masks and process all of them
398
+ top_indices = np.argsort(scores)[::-1][:top_k]
399
+
400
+ processed_masks = []
401
+ mask_scores = []
402
+
403
+ for i, idx in enumerate(top_indices):
404
+ mask = masks[0, idx].numpy()
405
+ score = scores[idx]
406
+
407
+ # Apply threshold to control mask size
408
+ if mask_threshold > 0:
409
+ mask = (mask > mask_threshold).astype(np.float32)
410
+
411
+ # Additional mask processing for size control
412
+ mask = apply_mask_post_processing(mask, stability_score_threshold)
413
+
414
+ # Apply erosion/dilation for fine size control
415
+ if erosion_dilation != 0:
416
+ mask = apply_erosion_dilation(mask, erosion_dilation)
417
+
418
+ processed_masks.append(mask)
419
+ mask_scores.append(score)
420
+
421
+ # Store current state for saving (use first mask as default)
422
+ CURRENT_MASK = processed_masks[0]
423
+ CURRENT_IMAGE_NAME = image_name
424
+ CURRENT_POINTS = points
425
+
426
+ # Create visualizations for the first mask
427
+ original_with_input = create_original_with_input_visualization(pil_image, points, bbox_data, mode)
428
+ mask_result = create_mask_visualization(pil_image, processed_masks[0], mask_scores[0], mask_threshold)
429
+
430
+ status = f"βœ“ Generated {len(processed_masks)} masks\nπŸ”„ Use navigation to browse masks"
431
+
432
+ # Return multiple masks and related data
433
+ return original_with_input, mask_result, status, processed_masks, mask_scores
434
+
435
+ except Exception as e:
436
+ print(f"Error in processing: {e}")
437
+ return None, None, f"Error: {str(e)}"
438
+
439
+ def create_original_with_input_visualization(pil_image, points, bbox, mode, negative_points=None):
440
+ """Create visualization of original image with input points/bbox overlay"""
441
+ # Convert PIL to numpy for matplotlib
442
+ img_array = np.array(pil_image)
443
+
444
+ fig, ax = plt.subplots(1, 1, figsize=(8, 6))
445
+
446
+ # Show original image only
447
+ ax.imshow(img_array)
448
+
449
+ # Show input visualization based on mode
450
+ if mode == "Points":
451
+ total_points = 0
452
+ # Show positive points (green)
453
+ if points:
454
+ for point in points:
455
+ ax.plot(point[0], point[1], 'go', markersize=12, markeredgewidth=3, markerfacecolor='lime')
456
+ total_points += len(points)
457
+
458
+ # Show negative points (red)
459
+ if negative_points:
460
+ for point in negative_points:
461
+ ax.plot(point[0], point[1], 'ro', markersize=12, markeredgewidth=3, markerfacecolor='red')
462
+ total_points += len(negative_points)
463
+
464
+ pos_count = len(points) if points else 0
465
+ neg_count = len(negative_points) if negative_points else 0
466
+ title_suffix = f"Points: {pos_count}+ {neg_count}-" if neg_count > 0 else f"Points: {pos_count}"
467
+ elif mode == "Bounding Box" and bbox:
468
+ # Show bounding box
469
+ x1, y1, x2, y2 = bbox
470
+ width = x2 - x1
471
+ height = y2 - y1
472
+
473
+ # Draw bounding box rectangle
474
+ from matplotlib.patches import Rectangle
475
+ rect = Rectangle((x1, y1), width, height, linewidth=3, edgecolor='lime', facecolor='none')
476
+ ax.add_patch(rect)
477
+
478
+ # Show corner points
479
+ ax.plot([x1, x2], [y1, y2], 'go', markersize=8, markeredgewidth=2, markerfacecolor='lime')
480
+ title_suffix = f"BBox: {int(width)}Γ—{int(height)}"
481
+ else:
482
+ title_suffix = "No input"
483
+
484
+ ax.set_title(f"Input Selection ({title_suffix})", fontsize=14)
485
+ ax.axis('off')
486
+
487
+ # Convert to numpy array
488
+ fig.canvas.draw()
489
+ buf = fig.canvas.buffer_rgba()
490
+ result_array = np.asarray(buf)
491
+ # Convert RGBA to RGB
492
+ result_array = result_array[:, :, :3]
493
+
494
+ plt.close(fig)
495
+ return result_array
496
+
497
+ def create_mask_visualization(pil_image, mask, score, mask_threshold=0.0):
498
+ """Create clean mask visualization without input overlays"""
499
+ # Convert PIL to numpy for matplotlib
500
+ img_array = np.array(pil_image)
501
+
502
+ fig, ax = plt.subplots(1, 1, figsize=(8, 6))
503
+
504
+ # Show original image
505
+ ax.imshow(img_array)
506
+
507
+ # Overlay mask in red
508
+ mask_overlay = np.zeros((*mask.shape, 4))
509
+ mask_overlay[mask > 0] = [1, 0, 0, 0.6] # Red with transparency
510
+ ax.imshow(mask_overlay)
511
+
512
+ ax.set_title(f"Generated Mask (Score: {float(score):.3f}, Threshold: {mask_threshold:.2f})", fontsize=14)
513
+ ax.axis('off')
514
+
515
+ # Convert to numpy array
516
+ fig.canvas.draw()
517
+ buf = fig.canvas.buffer_rgba()
518
+ result_array = np.asarray(buf)
519
+ # Convert RGBA to RGB
520
+ result_array = result_array[:, :, :3]
521
+
522
+ plt.close(fig)
523
+ return result_array
524
+
525
+ def create_interface():
526
+ """Create a simplified single-image annotator interface."""
527
+
528
+ with gr.Blocks(title="SAM 2.1 - Simple Annotator", theme=gr.themes.Soft(), css="""
529
+ .negative-mode-checkbox label {
530
+ color: #d00000 !important;
531
+ font-weight: 800 !important;
532
+ font-size: 16px !important;
533
+ }
534
+ """) as interface:
535
+ gr.HTML("""
536
+ <div style="text-align: center;">
537
+ <h1>🎯 AI-Powered Image Segmentation</h1>
538
+ <h2>SAM 2.1 + Grounding DINO</h2>
539
+ <p><strong>✨ Just type what you want to segment!</strong> Try "person", "face", "car", "dog" - or click points manually.</p>
540
+ <p>🎭 Generate multiple mask options and pick your favorite!</p>
541
+ <hr style="margin: 20px 0;">
542
+ <p style="font-size: 12px; color: #666;">
543
+ <strong>Acknowledgment:</strong> This is a GUI interface for research by Meta AI (SAM 2.1) and IDEA Research (Grounding DINO).<br>
544
+ All credit goes to the original researchers. This tool only provides an easy-to-use web interface.
545
+ </p>
546
+ </div>
547
+ """)
548
+
549
+ # Image input (single image) - directly annotate; this serves as uploader too
550
+ # Users can upload by clicking the annotatable image component below.
551
+ image_input = gr.Image(
552
+ label=None,
553
+ type="filepath",
554
+ height=0,
555
+ visible=False
556
+ )
557
+
558
+ # Text prompt input with clear button
559
+ with gr.Row():
560
+ text_prompt_input = gr.Textbox(
561
+ label="πŸ” Text Prompt (Optional)",
562
+ placeholder="Type what to segment (e.g., 'person', 'car', 'dog') and press Enter",
563
+ value="snoopy",
564
+ interactive=True,
565
+ info="πŸ’‘ Text = auto-detection | Empty + clicking = manual points | Text takes priority if both provided",
566
+ scale=4
567
+ )
568
+ clear_text_btn = gr.Button("πŸ—‘οΈ Clear Text", variant="secondary", scale=1)
569
+
570
+ # Number of masks to generate
571
+ num_masks = gr.Slider(
572
+ minimum=1,
573
+ maximum=5,
574
+ value=3,
575
+ step=1,
576
+ label="🎭 Number of Masks to Generate",
577
+ info="Generate multiple mask options to choose from"
578
+ )
579
+
580
+ # Main layout: Selected Points on the left, annotatable image in the center, preview on the right
581
+ with gr.Row():
582
+ with gr.Column(scale=1):
583
+ clear_points_btn = gr.Button("πŸ—‘οΈ Clear Points", variant="secondary", size="sm")
584
+ points_display = gr.JSON(label="πŸ“ Selected Points", value=[], visible=True)
585
+ with gr.Column(scale=3):
586
+ # Negative mode toggle with clear red styling
587
+ negative_point_mode = gr.Checkbox(
588
+ label="βž– NEGATIVE POINT MODE",
589
+ value=False,
590
+ info="πŸ”΄ Enable to add negative points (shown in red)",
591
+ interactive=True,
592
+ elem_classes="negative-mode-checkbox"
593
+ )
594
+ original_with_input = gr.Image(
595
+ label="πŸ“ Click to Annotate (toggle negative mode to exclude)",
596
+ height=640,
597
+ interactive=True,
598
+ value="data/snoopy.jpg"
599
+ )
600
+ with gr.Column(scale=1):
601
+ points_overlay = gr.Image(label="πŸ“ Points Preview (green=positive, red=negative)", height=720, interactive=False)
602
+
603
+ # Action buttons
604
+ with gr.Row():
605
+ generate_btn = gr.Button("🎯 Generate Mask", variant="primary", size="lg")
606
+
607
+ # Mask result with navigation
608
+ with gr.Row():
609
+ mask_result = gr.Image(label="🎭 Generated Mask", height=512)
610
+
611
+ # Mask navigation controls
612
+ with gr.Row():
613
+ prev_mask_btn = gr.Button("⬅️ Previous", variant="secondary", size="sm")
614
+ mask_info = gr.Textbox(
615
+ label="Mask Info",
616
+ value="No masks generated yet",
617
+ interactive=False,
618
+ scale=2
619
+ )
620
+ next_mask_btn = gr.Button("➑️ Next", variant="secondary", size="sm")
621
+
622
+ # Save controls under mask
623
+ with gr.Row():
624
+ mask_name_input = gr.Textbox(label="Folder name (optional)", placeholder="e.g., michael_phelps_bottom_left", scale=2)
625
+ format_selector = gr.Radio(
626
+ choices=["PNG", "JPG", "PT"],
627
+ value="PNG",
628
+ label="πŸ“ Download Format",
629
+ scale=1
630
+ )
631
+ save_btn = gr.Button("πŸ’Ύ Save & Download", variant="stop", size="lg", scale=1)
632
+
633
+ # Status and Download
634
+ with gr.Row():
635
+ status_text = gr.Textbox(label="πŸ“Š Status", interactive=False, lines=3, scale=2)
636
+ download_file = gr.File(label="πŸ“₯ Download", visible=False, scale=1)
637
+
638
+ # State to store points and masks
639
+ points_state = gr.State([])
640
+ masks_data = gr.State({"masks": [], "scores": [], "image": None}) # Store all mask data
641
+ current_mask_index = gr.State(0) # Current mask being viewed
642
+
643
+ # Event handlers
644
+ def on_image_click(image, current_points, negative_mode, evt: gr.SelectData):
645
+ """Handle clicks on the image for point annotations only."""
646
+ if evt.index is not None and image is not None:
647
+ x, y = evt.index
648
+ try:
649
+ pil_image = fix_image_array(image)
650
+ is_negative = negative_mode
651
+ new_point = {"x": int(x), "y": int(y), "positive": not is_negative}
652
+ updated_points = current_points + [new_point]
653
+
654
+ positive_points = [[p["x"], p["y"]] for p in updated_points if p.get("positive", True)]
655
+ negative_points = [[p["x"], p["y"]] for p in updated_points if not p.get("positive", True)]
656
+
657
+ updated_visualization = create_original_with_input_visualization(
658
+ pil_image, positive_points, None, "Points", negative_points
659
+ )
660
+
661
+ point_type = "positive" if not is_negative else "negative"
662
+ pos_count = len(positive_points)
663
+ neg_count = len(negative_points)
664
+ return updated_points, updated_points, updated_visualization, (
665
+ f"Added {point_type} point at ({x}, {y}). Total: {pos_count} positive, {neg_count} negative points."
666
+ )
667
+ except Exception as e:
668
+ print(f"Error in visualization: {e}")
669
+ return current_points, current_points, None, f"Error updating visualization: {str(e)}"
670
+ return current_points, current_points, None, "Click on the image to add points."
671
+
672
+ def on_image_upload(image):
673
+ """Handle image upload and show it for annotation."""
674
+ if image is not None:
675
+ try:
676
+ pil_image = fix_image_array(image)
677
+ img_array = np.array(pil_image)
678
+ # Populate both the annotation image (left) and the points preview (right)
679
+ return img_array, img_array, [], [], "Image uploaded. Click on the left image to add points (enable negative mode for exclusion)."
680
+ except Exception as e:
681
+ return None, None, [], [], f"Error loading image: {str(e)}"
682
+ return None, None, [], [], "No image uploaded."
683
+
684
+ def clear_all_points(image):
685
+ """Clear points and keep the image visible for annotation."""
686
+ try:
687
+ if image is not None:
688
+ pil_image = fix_image_array(image)
689
+ img_array = np.array(pil_image)
690
+ return [], [], img_array, img_array, None, "All points cleared. You can continue annotating."
691
+ except Exception:
692
+ pass
693
+ return [], [], None, None, None, "All points cleared."
694
+
695
+ def clear_text_prompt():
696
+ """Clear the text prompt."""
697
+ return "", "Text prompt cleared. You can now use manual points."
698
+
699
+ def generate_segmentation(image, points, text_prompt, num_masks_to_generate):
700
+ """Generate multiple segmentation masks - auto-detects input type."""
701
+ # Determine image name
702
+ if isinstance(image, str):
703
+ image_name = os.path.basename(image)
704
+ else:
705
+ # Prefer an explicit friendly default if metadata lacks a good name
706
+ image_name = None
707
+ if hasattr(image, 'orig_name'):
708
+ image_name = image.orig_name
709
+ elif isinstance(image, dict) and 'orig_name' in image:
710
+ image_name = image['orig_name']
711
+ elif hasattr(image, 'name'):
712
+ image_name = image.name
713
+ if not image_name or 'tmp' in str(image_name).lower() or 'uploaded_image' in str(image_name).lower():
714
+ image_name = "michael_phelps_bottom_left.jpg"
715
+
716
+ # Auto-detect input type and run segmentation
717
+ has_text = text_prompt and text_prompt.strip()
718
+ has_points = points and len(points) > 0
719
+
720
+ if has_text and has_points:
721
+ # Combine text detection with manual point refinement
722
+ status_info = "🎯 Combining text detection with manual point refinement"
723
+
724
+ # First, detect with text to get initial bounding box
725
+ detected_boxes, detection_status = detect_objects_with_text(image, text_prompt, 0.25)
726
+ if detected_boxes:
727
+ # Use the detected bounding box AND manual points together
728
+ bbox = detected_boxes[0] # Use first detection as guidance
729
+
730
+ # Process with both bounding box and points
731
+ # The points will be used to refine the segmentation within the detected area
732
+ _, mask_img, status, masks, scores = process_sam_segmentation(
733
+ image, points, bbox, "Points", image_name, int(num_masks_to_generate), 0.0, 0.95, 0, None, 0.25
734
+ )
735
+ status = f"{status_info}\nβœ“ Text: {detection_status}\nβœ“ Using {len(points)} manual points for refinement\n{status}"
736
+ masks_data_dict = {"masks": masks, "scores": scores, "image": image}
737
+ return mask_img, status, masks_data_dict, 0, f"Mask 1 of {len(masks)} (Score: {scores[0]:.3f})"
738
+ else:
739
+ # Fall back to points only if text detection fails
740
+ _, mask_img, status, masks, scores = process_sam_segmentation(
741
+ image, points, None, "Points", image_name, int(num_masks_to_generate), 0.0, 0.95, 0, None, 0.25
742
+ )
743
+ status = f"πŸ”„ Text detection failed, using {len(points)} manual points only\n{status}"
744
+ masks_data_dict = {"masks": masks, "scores": scores, "image": image}
745
+ return mask_img, status, masks_data_dict, 0, f"Mask 1 of {len(masks)} (Score: {scores[0]:.3f})"
746
+ elif has_text:
747
+ # Use text prompt
748
+ _, mask_img, status, masks, scores = process_sam_segmentation(
749
+ image, None, None, "Text Prompt", image_name, int(num_masks_to_generate), 0.0, 0.95, 0, text_prompt, 0.25
750
+ )
751
+ masks_data_dict = {"masks": masks, "scores": scores, "image": image}
752
+ return mask_img, status, masks_data_dict, 0, f"Mask 1 of {len(masks)} (Score: {scores[0]:.3f})"
753
+ elif has_points:
754
+ # Use points
755
+ _, mask_img, status, masks, scores = process_sam_segmentation(
756
+ image, points, None, "Points", image_name, int(num_masks_to_generate), 0.0, 0.95, 0, None, 0.25
757
+ )
758
+ masks_data_dict = {"masks": masks, "scores": scores, "image": image}
759
+ return mask_img, status, masks_data_dict, 0, f"Mask 1 of {len(masks)} (Score: {scores[0]:.3f})"
760
+ else:
761
+ return None, "❌ Please either enter a text prompt or click points on the image.", {"masks": [], "scores": [], "image": None}, 0, "No masks generated"
762
+
763
+ def navigate_mask(direction, current_index, masks_data):
764
+ """Navigate through generated masks"""
765
+ masks = masks_data.get("masks", [])
766
+ scores = masks_data.get("scores", [])
767
+ image = masks_data.get("image", None)
768
+
769
+ if not masks or len(masks) == 0:
770
+ return None, current_index, "No masks available"
771
+
772
+ # Calculate new index
773
+ if direction == "next":
774
+ new_index = (current_index + 1) % len(masks)
775
+ else: # previous
776
+ new_index = (current_index - 1) % len(masks)
777
+
778
+ # Get the mask at new index
779
+ mask = masks[new_index]
780
+ score = scores[new_index]
781
+
782
+ # Update global state for saving
783
+ global CURRENT_MASK
784
+ CURRENT_MASK = mask
785
+
786
+ # Create visualization
787
+ if image is not None:
788
+ pil_image = fix_image_array(image)
789
+ mask_visualization = create_mask_visualization(pil_image, mask, score, 0.0)
790
+ else:
791
+ mask_visualization = None
792
+
793
+ mask_info_text = f"Mask {new_index + 1} of {len(masks)} (Score: {score:.3f})"
794
+
795
+ return mask_visualization, new_index, mask_info_text
796
+
797
+ def save_and_download_mask(custom_folder_name, download_format):
798
+ """Save mask locally and prepare download for user."""
799
+ global CURRENT_MASK, CURRENT_IMAGE_NAME, CURRENT_POINTS
800
+ if CURRENT_MASK is None:
801
+ return "❌ No mask to save. Generate a mask first.", None
802
+ if CURRENT_POINTS is None:
803
+ return "❌ No points available. Generate a mask first.", None
804
+
805
+ try:
806
+ # Save locally (keep existing hierarchy)
807
+ local_save_status = save_binary_mask(
808
+ CURRENT_MASK, CURRENT_IMAGE_NAME, CURRENT_POINTS,
809
+ 0.0, 0, False, custom_folder_name=(custom_folder_name or None)
810
+ )
811
+
812
+ # Create download file
813
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
814
+ base_name = os.path.splitext(os.path.basename(CURRENT_IMAGE_NAME or "mask"))[0]
815
+
816
+ if download_format == "PNG":
817
+ # Create PNG for download
818
+ binary_mask = (CURRENT_MASK > 0).astype(np.uint8) * 255
819
+ mask_image = Image.fromarray(binary_mask, mode='L')
820
+ download_path = f"/tmp/mask_{base_name}_{timestamp}.png"
821
+ mask_image.save(download_path, format="PNG")
822
+
823
+ elif download_format == "JPG":
824
+ # Create JPG for download
825
+ binary_mask = (CURRENT_MASK > 0).astype(np.uint8) * 255
826
+ mask_image = Image.fromarray(binary_mask, mode='L')
827
+ download_path = f"/tmp/mask_{base_name}_{timestamp}.jpg"
828
+ mask_image.save(download_path, format="JPEG", quality=95)
829
+
830
+ elif download_format == "PT":
831
+ # Create PyTorch tensor for download
832
+ download_path = f"/tmp/mask_{base_name}_{timestamp}.pt"
833
+ torch.save(torch.from_numpy((CURRENT_MASK > 0).astype(np.float32)), download_path)
834
+
835
+ # Make download visible and return file
836
+ download_status = f"βœ… {local_save_status}\nπŸ“₯ Download ready: {download_format} format"
837
+ # gr.File.update() was removed in Gradio 4; return a generic update instead
+ return download_status, gr.update(value=download_path, visible=True)
838
+
839
+ except Exception as e:
840
+ return f"❌ Save/download failed: {str(e)}", None
841
+
842
+ # Wire events
843
+ # Let the annotatable image also handle image uploads (drag & drop / click upload)
844
+ original_with_input.upload(
845
+ on_image_upload,
846
+ inputs=[original_with_input],
847
+ outputs=[original_with_input, points_overlay, points_state, points_display, status_text]
848
+ )
849
+
850
+ original_with_input.select(
851
+ on_image_click,
852
+ inputs=[original_with_input, points_state, negative_point_mode],
853
+ outputs=[points_state, points_display, points_overlay, status_text]
854
+ )
855
+
856
+ # Generate button and Enter key support
857
+ generate_btn.click(
858
+ generate_segmentation,
859
+ inputs=[original_with_input, points_state, text_prompt_input, num_masks],
860
+ outputs=[mask_result, status_text, masks_data, current_mask_index, mask_info]
861
+ )
862
+
863
+ # Enter key support for text prompt
864
+ text_prompt_input.submit(
865
+ generate_segmentation,
866
+ inputs=[original_with_input, points_state, text_prompt_input, num_masks],
867
+ outputs=[mask_result, status_text, masks_data, current_mask_index, mask_info]
868
+ )
869
+
870
+ # Mask navigation
871
+ prev_mask_btn.click(
872
+ lambda idx, data: navigate_mask("prev", idx, data),
873
+ inputs=[current_mask_index, masks_data],
874
+ outputs=[mask_result, current_mask_index, mask_info]
875
+ )
876
+
877
+ next_mask_btn.click(
878
+ lambda idx, data: navigate_mask("next", idx, data),
879
+ inputs=[current_mask_index, masks_data],
880
+ outputs=[mask_result, current_mask_index, mask_info]
881
+ )
882
+
883
+ clear_points_btn.click(
884
+ clear_all_points,
885
+ inputs=[original_with_input],
886
+ outputs=[points_state, points_display, points_overlay, original_with_input, mask_result, status_text]
887
+ )
888
+
889
+ clear_text_btn.click(
890
+ clear_text_prompt,
891
+ outputs=[text_prompt_input, status_text]
892
+ )
893
+
894
+ save_btn.click(
895
+ save_and_download_mask,
896
+ inputs=[mask_name_input, format_selector],
897
+ outputs=[status_text, download_file]
898
+ )
899
+
900
+ return interface
901
+
902
+ def main():
903
+ """Main function"""
904
+ print("πŸš€ Starting Fixed SAM 2.1 Interface...")
905
+
906
+ interface = create_interface()
907
+
908
+ print("🌐 Launching web interface...")
909
+ print("πŸ“ Click on objects in images to segment them!")
910
+
911
+ interface.launch(
912
+ server_port=int(os.environ.get("GRADIO_SERVER_PORT", 7860)),
913
+ share=True, # Enable public sharing
914
+ inbrowser=False, # Don't auto-open browser in server environment
915
+ show_error=True,
916
+ server_name="0.0.0.0", # Allow external connections
917
+ auth=None # No authentication for public access
918
+ )
919
+
920
+ if __name__ == "__main__":
921
+ main()
sam_inference.py ADDED
@@ -0,0 +1,625 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Fixed SAM 2.1 Interface - Handles negative stride issues properly
4
+ """
5
+
6
+ import torch
7
+ import numpy as np
8
+ from PIL import Image
9
+ import matplotlib.pyplot as plt
10
+ import gradio as gr
11
+ from transformers import Sam2Model, Sam2Processor
12
+ import warnings
13
+ import io
14
+ import base64
15
+ import os
16
+ from datetime import datetime
17
+
18
+ warnings.filterwarnings("ignore")
19
+
20
+ # Global model instance to avoid reloading
21
+ MODEL = None
22
+ PROCESSOR = None
23
+ DEVICE = None
24
+
25
+ # Global state for saving
26
+ CURRENT_MASK = None
27
+ CURRENT_IMAGE_NAME = None
28
+ CURRENT_POINTS = None
29
+
30
+ def initialize_sam(model_size="small"):
31
+ """Initialize SAM model once"""
32
+ global MODEL, PROCESSOR, DEVICE
33
+
34
+ if MODEL is None:
35
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
36
+ print(f"Initializing SAM 2.1 {model_size} on {DEVICE}...")
37
+
38
+ model_name = f"facebook/sam2-hiera-{model_size}"
39
+ MODEL = Sam2Model.from_pretrained(model_name).to(DEVICE)
40
+ PROCESSOR = Sam2Processor.from_pretrained(model_name)
41
+
42
+ print("βœ“ Model loaded successfully!")
43
+
44
+ return MODEL, PROCESSOR, DEVICE
45
+
46
+ def fix_image_array(image):
47
+ """Fix image input for SAM processing - handles filepath, numpy array, or PIL Image"""
48
+ if isinstance(image, str):
49
+ # Handle filepath input from Gradio
50
+ return Image.open(image).convert("RGB")
51
+
52
+ elif isinstance(image, np.ndarray):
53
+ # Make sure array is contiguous
54
+ if not image.flags['C_CONTIGUOUS']:
55
+ image = np.ascontiguousarray(image)
56
+
57
+ # Ensure uint8 dtype
58
+ if image.dtype != np.uint8:
59
+ if image.max() <= 1.0:
60
+ image = (image * 255).astype(np.uint8)
61
+ else:
62
+ image = image.astype(np.uint8)
63
+
64
+ # Convert to PIL Image to avoid any stride issues
65
+ return Image.fromarray(image).convert("RGB")
66
+
67
+ elif isinstance(image, Image.Image):
68
+ return image.convert("RGB")
69
+
70
+ else:
71
+ raise ValueError(f"Unsupported image type: {type(image)}")
72
+
73
+ def apply_mask_post_processing(mask, stability_threshold=0.95):
74
+ """Apply post-processing to refine mask size and quality"""
75
+ import cv2
76
+
77
+ # Convert to binary mask
78
+ binary_mask = (mask > 0).astype(np.uint8)
79
+
80
+ # Apply morphological operations to clean up the mask
81
+ kernel_size = max(3, int(mask.shape[0] * 0.01)) # Adaptive kernel size
82
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
83
+
84
+ # Close small holes
85
+ binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
86
+
87
+ # Remove small noise
88
+ binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)
89
+
90
+ return binary_mask.astype(np.float32)
91
+
92
+ def apply_erosion_dilation(mask, erosion_dilation_value):
93
+ """Apply erosion or dilation to adjust mask size"""
94
+ import cv2
95
+
96
+ binary_mask = (mask > 0).astype(np.uint8)
97
+
98
+ if erosion_dilation_value == 0:
99
+ return mask
100
+
101
+ kernel_size = abs(erosion_dilation_value)
102
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
103
+
104
+ if erosion_dilation_value > 0:
105
+ # Dilate (make larger)
106
+ binary_mask = cv2.dilate(binary_mask, kernel, iterations=1)
107
+ else:
108
+ # Erode (make smaller)
109
+ binary_mask = cv2.erode(binary_mask, kernel, iterations=1)
110
+
111
+ return binary_mask.astype(np.float32)
112
+
113
+ def save_binary_mask(mask, image_name, points, mask_threshold, erosion_dilation, save_low_res=False, custom_folder_name=None):
114
+ """Save binary mask to organized folder structure"""
115
+ global CURRENT_MASK, CURRENT_IMAGE_NAME, CURRENT_POINTS
116
+
117
+ try:
118
+ # Store current state for saving
119
+ CURRENT_MASK = mask
120
+ CURRENT_IMAGE_NAME = image_name
121
+ CURRENT_POINTS = points
122
+
123
+ # Extract image name without extension and sanitize
124
+ if image_name:
125
+ base_name = os.path.splitext(os.path.basename(image_name))[0]
126
+ # Remove any path separators and special characters
127
+ base_name = base_name.replace('/', '_').replace('\\', '_').replace(':', '_').replace(' ', '_')
128
+ else:
129
+ base_name = f"image_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
130
+
131
+ # Choose folder tag: user-provided name if available, else 'default'
132
+ folder_tag = None
133
+ if custom_folder_name and str(custom_folder_name).strip():
134
+ folder_tag = str(custom_folder_name).strip().replace(' ', '_')
135
+ else:
136
+ folder_tag = "default"
137
+
138
+ # Create folder structure: masks/<image_base>/<folder_tag>/
139
+ folder_name = f"masks/{base_name}/{folder_tag}"
140
+ os.makedirs(folder_name, exist_ok=True)
141
+
142
+ # Create binary mask (0 and 255 values)
143
+ binary_mask = (mask > 0).astype(np.uint8) * 255
144
+
145
+ # Calculate low resolution dimensions if requested
146
+ original_height, original_width = binary_mask.shape
147
+ if save_low_res:
148
+ # Calculate sqrt-based resolution
149
+ sqrt_factor = int(np.sqrt(max(original_width, original_height)))
150
+ low_res_width = sqrt_factor
151
+ low_res_height = sqrt_factor
152
+ print(f"Original mask size: {original_width}x{original_height}")
153
+ print(f"Low-res mask size: {low_res_width}x{low_res_height}")
154
+
155
+ # Save binary mask
156
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
157
+ # Sanitize filename - replace problematic characters
158
+ threshold_str = f"{mask_threshold:.2f}".replace('.', 'p') # 0.30 -> 0p30
159
+ adj_str = f"{erosion_dilation:+d}".replace('+', 'plus').replace('-', 'minus') # +2 -> plus2, -2 -> minus2
160
+
161
+ saved_paths = []
162
+
163
+ # Save full resolution mask as JPEG with a simple filename
164
+ mask_filename = "image.jpg"
165
+ mask_path = os.path.join(folder_name, mask_filename)
166
+
167
+ mask_image = Image.fromarray(binary_mask, mode='L')
168
+ mask_image.save(mask_path, format="JPEG", quality=95, optimize=True)
169
+ saved_paths.append(mask_path)
170
+
171
+ # Save tensor mask (.pt) as float tensor (0.0/1.0)
172
+ tensor_filename = "image.pt"
173
+ tensor_path = os.path.join(folder_name, tensor_filename)
174
+ torch.save(torch.from_numpy((mask > 0).astype(np.float32)), tensor_path)
175
+ saved_paths.append(tensor_path)
176
+
177
+ # Save low resolution mask if requested
178
+ if save_low_res:
179
+ # Resize mask to low resolution
180
+ low_res_mask = mask_image.resize((low_res_width, low_res_height), Image.Resampling.NEAREST)
181
+
182
+ low_res_filename = f"mask_lowres_{sqrt_factor}x{sqrt_factor}_t{threshold_str}_adj{adj_str}_{timestamp}.png"
183
+ low_res_path = os.path.join(folder_name, low_res_filename)
184
+
185
+ low_res_mask.save(low_res_path)
186
+ saved_paths.append(low_res_path)
187
+
188
+ # Also save metadata
189
+ metadata = {
190
+ "timestamp": timestamp,
191
+ "points": points,
192
+ "mask_threshold": mask_threshold,
193
+ "erosion_dilation": erosion_dilation,
194
+ "image_name": image_name,
195
+ "original_resolution": f"{original_width}x{original_height}",
196
+ "saved_paths": saved_paths,
197
+ "low_resolution_saved": save_low_res
198
+ }
199
+
200
+ if save_low_res:
201
+ metadata["low_resolution"] = f"{low_res_width}x{low_res_height}"
202
+ metadata["sqrt_factor"] = sqrt_factor
203
+
204
+ import json
205
+ metadata_path = os.path.join(folder_name, f"metadata_{timestamp}.json")
206
+ with open(metadata_path, 'w') as f:
207
+ json.dump(metadata, f, indent=2)
208
+
209
+ # Return appropriate message
210
+ if save_low_res:
211
+ return f"βœ… Masks saved:\nπŸ“ Full: {os.path.basename(mask_path)}\nπŸ“ Low-res: {os.path.basename(low_res_path)}"
212
+ else:
213
+ return f"βœ… Mask saved to: {os.path.basename(mask_path)}"
214
+
215
+ except Exception as e:
216
+ return f"❌ Save failed: {str(e)}"
217
+
218
+ def process_sam_segmentation(image, points_data, bbox_data, mode, image_name=None, top_k=3, mask_threshold=0.0, stability_score_threshold=0.95, erosion_dilation=0):
219
+ """Main processing function with mask size controls - supports points and bounding boxes"""
220
+ global CURRENT_MASK, CURRENT_IMAGE_NAME, CURRENT_POINTS
221
+
222
+ if image is None:
223
+ return None, None, "Please upload an image first."
224
+
225
+ # Check input based on mode
226
+ if mode == "Points":
227
+ if not points_data or len(points_data) == 0:
228
+ return None, None, "Please click on the image to select points."
229
+ elif mode == "Bounding Box":
230
+ if bbox_data is None:
231
+ return None, None, "Please click two corners to define a bounding box."
232
+
233
+ try:
234
+ # Initialize model
235
+ model, processor, device = initialize_sam()
236
+
237
+ # Fix image
238
+ pil_image = fix_image_array(image)
239
+
240
+ # Prepare SAM inputs based on mode
241
+ input_points = None
242
+ input_labels = None
243
+ input_boxes = None
244
+ points = None
245
+
246
+ if mode == "Points":
247
+ # Extract points with positive/negative labels
248
+ points = []
249
+ labels = []
250
+ for point_info in points_data:
251
+ if isinstance(point_info, dict):
252
+ points.append([point_info.get("x", 0), point_info.get("y", 0)])
253
+ labels.append(1 if point_info.get("positive", True) else 0) # 1 = positive, 0 = negative
254
+ elif isinstance(point_info, (list, tuple)) and len(point_info) >= 2:
255
+ points.append([point_info[0], point_info[1]])
256
+ labels.append(1) # Default to positive for old format
257
+
258
+ if not points:
259
+ return None, "No valid points found."
260
+
261
+ print(f"Processing {len(points)} points: {points} with labels: {labels}")
262
+ input_points = [[points]]
263
+ input_labels = [[labels]]
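+ # Double nesting gives the (batch, prompt, point, xy) layout, i.e. shape (1, 1, N, 2), which the processor's input_points argument expects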
264
+
265
+ elif mode == "Bounding Box":
266
+ # Use bounding box
267
+ bbox = bbox_data # [x1, y1, x2, y2]
268
+ print(f"Processing bounding box: {bbox}")
269
+ input_boxes = [[bbox]]
270
+ # For visualization, store the bbox corners as points
271
+ points = [[bbox[0], bbox[1]], [bbox[2], bbox[3]]]
272
+
273
+ # Process with SAM
274
+ processor_inputs = {
275
+ "images": pil_image,
276
+ "return_tensors": "pt"
277
+ }
278
+
279
+ # Add points or boxes based on mode
280
+ if mode == "Points":
281
+ processor_inputs["input_points"] = input_points
282
+ processor_inputs["input_labels"] = input_labels
283
+ elif mode == "Bounding Box":
284
+ processor_inputs["input_boxes"] = input_boxes
285
+
286
+ inputs = processor(**processor_inputs).to(device)
287
+
288
+ # Generate masks with multiple outputs for better control
289
+ with torch.no_grad():
290
+ outputs = model(**inputs, multimask_output=True)
291
+
292
+ # Get masks and scores
293
+ masks = processor.post_process_masks(
294
+ outputs.pred_masks.cpu(),
295
+ inputs["original_sizes"]
296
+ )[0]
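+ # post_process_masks returns one tensor per input image, upscaled back to the original image size; [0] selects our single image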
297
+
298
+ scores = outputs.iou_scores.cpu().numpy().flatten()
299
+
300
+ # Get top-k masks
301
+ top_indices = np.argsort(scores)[::-1][:top_k]
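+ # argsort is ascending, so reverse it; top_indices[0] is the index of the highest-IoU mask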
302
+
303
+ # Apply mask threshold to control size
304
+ best_mask = masks[0, top_indices[0]].numpy()
305
+ best_score = scores[top_indices[0]]
306
+
307
+ # Apply threshold to control mask size
308
+ if mask_threshold > 0:
309
+ best_mask = (best_mask > mask_threshold).astype(np.float32)
310
+
311
+ # Additional mask processing for size control
312
+ best_mask = apply_mask_post_processing(best_mask, stability_score_threshold)
313
+
314
+ # Apply erosion/dilation for fine size control
315
+ if erosion_dilation != 0:
316
+ best_mask = apply_erosion_dilation(best_mask, erosion_dilation)
317
+
318
+ # Store current state for saving
319
+ CURRENT_MASK = best_mask
320
+ CURRENT_IMAGE_NAME = image_name
321
+ CURRENT_POINTS = points
322
+
323
+ # Create dual visualizations
324
+ original_with_input = create_original_with_input_visualization(pil_image, points, bbox_data, mode)
325
+ mask_result = create_mask_visualization(pil_image, best_mask, best_score, mask_threshold)
326
+
327
+ status = f"βœ“ Generated mask with score: {float(best_score):.3f}\nπŸ”„ Ready to save!"
328
+ return original_with_input, mask_result, status
329
+
330
+ except Exception as e:
331
+ print(f"Error in processing: {e}")
332
+ return None, None, f"Error: {str(e)}"
333
+
334
+ def create_original_with_input_visualization(pil_image, points, bbox, mode, negative_points=None):
335
+ """Create visualization of original image with input points/bbox overlay"""
336
+ # Convert PIL to numpy for matplotlib
337
+ img_array = np.array(pil_image)
338
+
339
+ fig, ax = plt.subplots(1, 1, figsize=(8, 6))
340
+
341
+ # Show original image only
342
+ ax.imshow(img_array)
343
+
344
+ # Show input visualization based on mode
345
+ if mode == "Points":
346
+ total_points = 0
347
+ # Show positive points (green)
348
+ if points:
349
+ for point in points:
350
+ ax.plot(point[0], point[1], 'go', markersize=12, markeredgewidth=3, markerfacecolor='lime')
351
+ total_points += len(points)
352
+
353
+ # Show negative points (red)
354
+ if negative_points:
355
+ for point in negative_points:
356
+ ax.plot(point[0], point[1], 'ro', markersize=12, markeredgewidth=3, markerfacecolor='red')
357
+ total_points += len(negative_points)
358
+
359
+ pos_count = len(points) if points else 0
360
+ neg_count = len(negative_points) if negative_points else 0
361
+ title_suffix = f"Points: {pos_count}+ {neg_count}-" if neg_count > 0 else f"Points: {pos_count}"
362
+ elif mode == "Bounding Box" and bbox:
363
+ # Show bounding box
364
+ x1, y1, x2, y2 = bbox
365
+ width = x2 - x1
366
+ height = y2 - y1
367
+
368
+ # Draw bounding box rectangle
369
+ from matplotlib.patches import Rectangle
370
+ rect = Rectangle((x1, y1), width, height, linewidth=3, edgecolor='lime', facecolor='none')
371
+ ax.add_patch(rect)
372
+
373
+ # Show corner points
374
+ ax.plot([x1, x2], [y1, y2], 'go', markersize=8, markeredgewidth=2, markerfacecolor='lime')
375
+ title_suffix = f"BBox: {int(width)}Γ—{int(height)}"
376
+ else:
377
+ title_suffix = "No input"
378
+
379
+ ax.set_title(f"Input Selection ({title_suffix})", fontsize=14)
380
+ ax.axis('off')
381
+
382
+ # Convert to numpy array
383
+ fig.canvas.draw()
384
+ buf = fig.canvas.buffer_rgba()
385
+ result_array = np.asarray(buf)
386
+ # Convert RGBA to RGB
387
+ result_array = result_array[:, :, :3]
388
+
389
+ plt.close(fig)
390
+ return result_array
391
+
392
+ def create_mask_visualization(pil_image, mask, score, mask_threshold=0.0):
393
+ """Create clean mask visualization without input overlays"""
394
+ # Convert PIL to numpy for matplotlib
395
+ img_array = np.array(pil_image)
396
+
397
+ fig, ax = plt.subplots(1, 1, figsize=(8, 6))
398
+
399
+ # Show original image
400
+ ax.imshow(img_array)
401
+
402
+ # Overlay mask in red
403
+ mask_overlay = np.zeros((*mask.shape, 4))
404
+ mask_overlay[mask > 0] = [1, 0, 0, 0.6] # Red with transparency
405
+ ax.imshow(mask_overlay)
406
+
407
+ ax.set_title(f"Generated Mask (Score: {float(score):.3f}, Threshold: {mask_threshold:.2f})", fontsize=14)
408
+ ax.axis('off')
409
+
410
+ # Convert to numpy array
411
+ fig.canvas.draw()
412
+ buf = fig.canvas.buffer_rgba()
413
+ result_array = np.asarray(buf)
414
+ # Convert RGBA to RGB
415
+ result_array = result_array[:, :, :3]
416
+
417
+ plt.close(fig)
418
+ return result_array
419
+
420
+ def create_interface():
421
+ """Create a simplified single-image annotator interface."""
422
+
423
+ with gr.Blocks(title="SAM 2.1 - Simple Annotator", theme=gr.themes.Soft(), css="""
424
+ .negative-mode-checkbox label {
425
+ color: #d00000 !important;
426
+ font-weight: 800 !important;
427
+ font-size: 16px !important;
428
+ }
429
+ """) as interface:
430
+ gr.HTML("""
431
+ <div style="text-align: center;">
432
+ <h1>🎯 SAM 2.1 Simple Annotator</h1>
433
+ <p>Upload an image, click to add positive/negative points, generate a mask, and save it.</p>
434
+ </div>
435
+ """)
436
+
437
+ # Hidden image component kept as a placeholder; it is not wired to any event.
438
+ # Uploads are handled directly by the annotatable image component below.
439
+ image_input = gr.Image(
440
+ label=None,
441
+ type="filepath",
442
+ height=0,
443
+ visible=False
444
+ )
445
+
446
+ # Main layout: Selected Points on the left, annotatable image in the center, preview on the right
447
+ with gr.Row():
448
+ with gr.Column(scale=1):
449
+ points_display = gr.JSON(label="πŸ“ Selected Points", value=[], visible=True)
450
+ with gr.Column(scale=3):
451
+ # Negative mode toggle with clear red styling
452
+ negative_point_mode = gr.Checkbox(
453
+ label="βž– NEGATIVE POINT MODE",
454
+ value=False,
455
+ info="πŸ”΄ Enable to add negative points (shown in red)",
456
+ interactive=True,
457
+ elem_classes="negative-mode-checkbox"
458
+ )
459
+ original_with_input = gr.Image(
460
+ label="πŸ“ Click to Annotate (toggle negative mode to exclude)",
461
+ height=640,
462
+ interactive=True
463
+ )
464
+ with gr.Column(scale=1):
465
+ points_overlay = gr.Image(label="πŸ“ Points Preview (green=positive, red=negative)", height=720, interactive=False)
466
+
467
+ # Action buttons
468
+ with gr.Row():
469
+ generate_btn = gr.Button("🎯 Generate Mask", variant="primary", size="lg")
470
+ clear_btn = gr.Button("πŸ—‘οΈ Clear Points", variant="secondary", size="lg")
471
+
472
+ # Mask result under buttons
473
+ with gr.Row():
474
+ mask_result = gr.Image(label="🎭 Generated Mask", height=512)
475
+
476
+ # Save controls under mask
477
+ with gr.Row():
478
+ mask_name_input = gr.Textbox(label="Folder name (optional)", placeholder="e.g., michael_phelps_bottom_left")
479
+ save_btn = gr.Button("πŸ’Ύ Save Mask", variant="stop", size="lg")
480
+
481
+ # Status
482
+ with gr.Row():
483
+ status_text = gr.Textbox(label="πŸ“Š Status", interactive=False, lines=3)
484
+
485
+ # State to store points only
486
+ points_state = gr.State([])
487
+
488
+ # Event handlers
489
+ def on_image_click(image, current_points, negative_mode, evt: gr.SelectData):
490
+ """Handle clicks on the image for point annotations only."""
491
+ if evt.index is not None and image is not None:
492
+ x, y = evt.index
493
+ try:
494
+ pil_image = fix_image_array(image)
495
+ is_negative = negative_mode
496
+ new_point = {"x": int(x), "y": int(y), "positive": not is_negative}
497
+ updated_points = current_points + [new_point]
498
+
499
+ positive_points = [[p["x"], p["y"]] for p in updated_points if p.get("positive", True)]
500
+ negative_points = [[p["x"], p["y"]] for p in updated_points if not p.get("positive", True)]
501
+
502
+ updated_visualization = create_original_with_input_visualization(
503
+ pil_image, positive_points, None, "Points", negative_points
504
+ )
505
+
506
+ point_type = "positive" if not is_negative else "negative"
507
+ pos_count = len(positive_points)
508
+ neg_count = len(negative_points)
509
+ return updated_points, updated_points, updated_visualization, (
510
+ f"Added {point_type} point at ({x}, {y}). Total: {pos_count} positive, {neg_count} negative points."
511
+ )
512
+ except Exception as e:
513
+ print(f"Error in visualization: {e}")
514
+ return current_points, current_points, None, f"Error updating visualization: {str(e)}"
515
+ return current_points, current_points, None, "Click on the image to add points."
516
+
517
+ def on_image_upload(image):
518
+ """Handle image upload and show it for annotation."""
519
+ if image is not None:
520
+ try:
521
+ pil_image = fix_image_array(image)
522
+ img_array = np.array(pil_image)
523
+ # Populate both the annotation image (left) and the points preview (right)
524
+ return img_array, img_array, [], [], "Image uploaded. Click on the annotation image to add points (enable negative mode to exclude regions)."
525
+ except Exception as e:
526
+ return None, None, [], [], f"Error loading image: {str(e)}"
527
+ return None, None, [], [], "No image uploaded."
528
+
529
+ def clear_all_points(image):
530
+ """Clear points and keep the image visible for annotation."""
531
+ try:
532
+ if image is not None:
533
+ pil_image = fix_image_array(image)
534
+ img_array = np.array(pil_image)
535
+ return [], [], img_array, img_array, None, "All points cleared. You can continue annotating."
536
+ except Exception:
537
+ pass
538
+ return [], [], None, None, None, "All points cleared."
539
+
540
+ def generate_segmentation(image, points):
541
+ """Generate a single segmentation mask using points only."""
542
+ # Determine image name
543
+ if isinstance(image, str):
544
+ image_name = os.path.basename(image)
545
+ else:
546
+ # Prefer an explicit friendly default if metadata lacks a good name
547
+ image_name = None
548
+ if hasattr(image, 'orig_name'):
549
+ image_name = image.orig_name
550
+ elif isinstance(image, dict) and 'orig_name' in image:
551
+ image_name = image['orig_name']
552
+ elif hasattr(image, 'name'):
553
+ image_name = image.name
554
+ if not image_name or 'tmp' in str(image_name).lower() or 'uploaded_image' in str(image_name).lower():
555
+ image_name = "michael_phelps_bottom_left.jpg"
556
+
557
+ # Run segmentation (points mode)
558
+ _, mask_img, status = process_sam_segmentation(
559
+ image, points, None, "Points", image_name, 1, 0.0, 0.95, 0
560
+ )
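+ # The fixed positional arguments map to: image_name, top_k=1, mask_threshold=0.0, stability_score_threshold=0.95, erosion_dilation=0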
561
+ if mask_img is not None:
562
+ status += f"\nπŸ“ Image: {os.path.basename(image_name)}"
563
+ return mask_img, status
564
+
565
+ def save_current_mask(custom_folder_name):
566
+ """Save the currently generated mask."""
567
+ global CURRENT_MASK, CURRENT_IMAGE_NAME, CURRENT_POINTS
568
+ if CURRENT_MASK is None:
569
+ return "❌ No mask to save. Generate a mask first."
570
+ if CURRENT_POINTS is None:
571
+ return "❌ No points available. Generate a mask first."
572
+ return save_binary_mask(CURRENT_MASK, CURRENT_IMAGE_NAME, CURRENT_POINTS, 0.0, 0, False, custom_folder_name=(custom_folder_name or None))
573
+
574
+ # Wire events
575
+ # Let the annotatable image also handle image uploads (drag & drop / click upload)
576
+ original_with_input.upload(
577
+ on_image_upload,
578
+ inputs=[original_with_input],
579
+ outputs=[original_with_input, points_overlay, points_state, points_display, status_text]
580
+ )
581
+
582
+ original_with_input.select(
583
+ on_image_click,
584
+ inputs=[original_with_input, points_state, negative_point_mode],
585
+ outputs=[points_state, points_display, points_overlay, status_text]
586
+ )
587
+
588
+ generate_btn.click(
589
+ generate_segmentation,
590
+ inputs=[original_with_input, points_state],
591
+ outputs=[mask_result, status_text]
592
+ )
593
+
594
+ clear_btn.click(
595
+ clear_all_points,
596
+ inputs=[original_with_input],
597
+ outputs=[points_state, points_display, points_overlay, original_with_input, mask_result, status_text]
598
+ )
599
+
600
+ save_btn.click(
601
+ save_current_mask,
602
+ inputs=[mask_name_input],
603
+ outputs=[status_text]
604
+ )
605
+
606
+ return interface
607
+
608
+ def main():
609
+ """Main function"""
610
+ print("πŸš€ Starting Fixed SAM 2.1 Interface...")
611
+
612
+ interface = create_interface()
613
+
614
+ print("🌐 Launching web interface...")
615
+ print("πŸ“ Click on objects in images to segment them!")
616
+
617
+ interface.launch(
618
+ server_port=int(os.environ.get("GRADIO_SERVER_PORT", 7860)),
619
+ share=False,
620
+ inbrowser=False, # Don't auto-open browser in server environment
621
+ show_error=True
622
+ )
623
+
624
+ if __name__ == "__main__":
625
+ main()