Abdur123 committed
Commit 65b8939 · Parent: 7673bbf

Add application file

Files changed (5)
  1. .gitignore +201 -0
  2. app.py +145 -0
  3. img/example.jpg +0 -0
  4. inference.py +315 -0
  5. requirements.txt +7 -0
.gitignore ADDED
@@ -0,0 +1,201 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ Pipfile.lock
+
+ # poetry
+ poetry.lock
+
+ # pdm
+ .pdm.toml
+
+ # PEP 582
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ .idea/
+
+ # VS Code
+ .vscode/
+
+ # Temporary files
+ *.tmp
+ *.temp
+ *.swp
+ *.swo
+ *~
+
+ # OS generated files
+ .DS_Store
+ .DS_Store?
+ ._*
+ .Spotlight-V100
+ .Trashes
+ ehthumbs.db
+ Thumbs.db
+
+ # User-specific directories and files
+ jlwkkuvw7a/
+ *.coco.json
+
+ # Model files (these will be downloaded automatically)
+ *.pt
+ *.pth
+ *.onnx
+ *.safetensors
+
+ # Lock files
+ uv.lock
+ poetry.lock
+ Pipfile.lock
+
+ # Project configuration files
+ pyproject.toml
+ setup.py
+ setup.cfg
+
+ # Logs
+ *.log
+ logs/
+
+ # Cache directories
+ .cache/
+ cache/
+
+ # Temporary directories
+ tmp/
+ temp/
+
+ # Large files that shouldn't be in git
+ *.zip
+ *.tar.gz
+ *.rar
+ *.7z
+
+ # Keep only essential files for deployment
+ # The following files are essential and should NOT be ignored:
+ # - app.py
+ # - inference.py
+ # - requirements.txt
+ # - README.md
+ # - img/example.jpg
+ # - .gitattributes
app.py ADDED
@@ -0,0 +1,145 @@
+ import gradio as gr
+ from pathlib import Path
+ import secrets
+ import shutil
+ from inference import detector, detector_processor, segment_predictor, ModelInference
+
+ current_dir = Path(__file__).parent
+
+
+ def process_images(image_path, files, slider_value, request: gr.Request):
+
+     user_dir: Path = current_dir / str(request.session_hash)
+     user_dir.mkdir(exist_ok=True)
+
+     annotation_path = user_dir / f"{secrets.token_hex(nbytes=8)}_annotations.coco.json"
+     class_names = list(inferencer.id2label.values())
+
+     if image_path:
+         print(f"Processing image: {image_path}")
+         seg_detections, annotated_frame = inferencer.predict_one(image_path)
+         if seg_detections is not None:  # predict_one returns (None, None) when nothing is detected
+             inferencer.save_annotations([image_path], [seg_detections], class_names, annotation_path)
+     elif files:
+         print(f"Processing files: {files}")
+         print(f"Batch size: {slider_value}")
+         all_image_paths, all_results, annotated_frame, detector_failed_list, segmentor_failed_list = inferencer.predict_folder(files, slider_value)
+
+         print(f"Detector failed list: {detector_failed_list}")
+         print(f"Segmentor failed list: {segmentor_failed_list}")
+
+         inferencer.save_annotations(all_image_paths, all_results, class_names, annotation_path)
+
+     return [
+         gr.UploadButton(visible=False),
+         gr.Button("Run", visible=False),
+         gr.DownloadButton("Download annotation results", value=annotation_path, visible=True),
+         gr.Image(value=annotated_frame, label="Annotated Image", visible=True),
+     ]
+
+
+ def upload_file():
+
+     return [
+         None,
+         gr.UploadButton(visible=False),
+         gr.Slider(1, 6, step=1, label="Batch size", interactive=True, value=4, visible=True),
+         gr.Button("Run", visible=True),
+         gr.DownloadButton(visible=False),
+         gr.Image(value=None, label="Annotated Image", visible=True),
+     ]
+
+
+ def upload_image(image_path):
+
+     return [
+         gr.UploadButton(visible=False),
+         gr.Slider(1, 6, step=1, label="Batch size", interactive=True, value=4, visible=False),
+         gr.Button("Run", visible=True),
+         gr.DownloadButton(visible=False),
+         gr.Image(value=None, label="Annotated Image", visible=True),
+     ]
+
+
+ def download_file():
+     return [
+         gr.Image(value=None),
+         gr.UploadButton(visible=True),
+         gr.Slider(1, 6, step=1, label="Batch size", interactive=True, value=4, visible=False),
+         gr.Button("Run", visible=False),
+         gr.DownloadButton(visible=True),
+         gr.Image(value=None, visible=False),
+     ]
+
+
+ def delete_directory(request: gr.Request):
+     """Delete the user-specific directory when the user's session ends."""
+     user_dir = current_dir / str(request.session_hash)
+     if user_dir.exists():
+         shutil.rmtree(user_dir)
+
+
+ def create_gradio_interface():
+     with gr.Blocks(theme=gr.themes.Monochrome(), delete_cache=(60, 3600)) as demo:
+         gr.HTML("""
+         <div style="text-align: center;">
+             <h1>Satellite Image Roofs Auto Annotation</h1>
+             <p>Powered by a <a href="https://huggingface.co/Yifeng-Liu/rt-detr-finetuned-for-satellite-image-roofs-detection" target="_blank">fine-tuned RT-DETR model</a> and a FastSAM model.</p>
+             <p>📤 Upload an image or a folder containing images.</p>
+             <p>🖼️ Images are saved in a user-specific directory and are deleted when the user closes the page.</p>
+             <p>⚙️ Uploads are limited to 200 MB per request.</p>
+             <p>🏷️ Annotation results are saved in COCO format for download.</p>
+             <p>🔧 TODO: Enhance model inference using Intel OpenVINO.</p>
+         </div>
+         """)
+         with gr.Row():
+             with gr.Column(scale=1):
+                 img_input = gr.Image(
+                     interactive=True,
+                     sources=["upload", "clipboard"],
+                     show_share_button=True,
+                     type='filepath',
+                     label="Upload a single image",
+                 )
+                 upload_button = gr.UploadButton("Upload a folder", file_count="directory")
+                 batch_slider = gr.Slider(1, 6, step=1, label="Batch size", interactive=True, value=4, visible=False)
+                 run_button = gr.Button("Run", visible=False)
+             with gr.Column(scale=1):
+                 img_output = gr.Image(label="Annotated Image", visible=False)
+                 download_button = gr.DownloadButton("Download annotation results", visible=False)
+
+         with gr.Row():
+             examples = gr.Examples(
+                 examples=[["./img/example.jpg"]],
+                 inputs=[img_input],
+                 outputs=[upload_button, batch_slider, run_button, download_button, img_output],
+                 fn=upload_image,
+                 run_on_click=True,
+             )
+
+         upload_button.upload(upload_file, None, [img_input, upload_button, batch_slider, run_button, download_button, img_output])
+
+         download_button.click(download_file, None, [img_input, upload_button, batch_slider, run_button, download_button, img_output])
+
+         run_button.click(process_images,
+                          [img_input, upload_button, batch_slider],
+                          [upload_button, run_button, download_button, img_output])
+
+         img_input.upload(upload_image, img_input, [upload_button, batch_slider, run_button, download_button, img_output])
+
+         demo.unload(delete_directory)
+
+     return demo
+
+
+ def inferencer_init():
+     id2label = {0: 'building'}
+     CONFIDENCE_THRESHOLD = 0.5
+     return ModelInference(detector, detector_processor, segment_predictor, id2label, CONFIDENCE_THRESHOLD)
+
+
+ inferencer = inferencer_init()
+
+ if __name__ == "__main__":
+     demo = create_gradio_interface()
+     demo.launch(max_file_size=200 * gr.FileSize.MB)
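
For reference, a minimal sketch of driving the pipeline without the Gradio UI. This script is an illustration, not part of the commit; it assumes the RT-DETR and FastSAM weights can be downloaded at import time and that img/example.jpg exists, and "annotations.coco.json" is a placeholder output path:

    # Hypothetical driver script (not in this commit).
    from app import inferencer  # module import builds the ModelInference instance

    # predict_one returns (None, None) when no roof is detected or segmented
    detections, annotated = inferencer.predict_one("img/example.jpg")
    if detections is not None:
        inferencer.save_annotations(
            ["img/example.jpg"], [detections],
            list(inferencer.id2label.values()),
            "annotations.coco.json",  # placeholder path
        )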
img/example.jpg ADDED
inference.py ADDED
@@ -0,0 +1,315 @@
+ from transformers import AutoModelForObjectDetection, AutoImageProcessor
+ from torch.utils.data import Dataset, DataLoader
+ import os
+ from tqdm import tqdm
+ from PIL import Image
+ from pathlib import Path
+ from ultralytics.models.fastsam import FastSAMPredictor
+ import supervision as sv
+ import torch
+ import numpy as np
+ import cv2
+ from typing import List, Tuple, Dict, Any, Optional
+ from supervision.dataset.utils import approximate_mask_with_polygons
+ from supervision.detection.utils import (
+     contains_holes,
+     contains_multiple_segments,
+ )
+
+ detector = AutoModelForObjectDetection.from_pretrained("Yifeng-Liu/rt-detr-finetuned-for-satellite-image-roofs-detection")
+ detector_processor = AutoImageProcessor.from_pretrained("Yifeng-Liu/rt-detr-finetuned-for-satellite-image-roofs-detection")
+
+
+ overrides = dict(conf=0.25, task="segment", mode="predict", model="FastSAM-x.pt", save=False)
+ segment_predictor = FastSAMPredictor(overrides=overrides)
+
+ # IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"}  # image suffixes
+
+
+ class ImageInferenceDataset(Dataset):
+     def __init__(self, image_paths: List[Path], image_processor):
+         """
+         A custom dataset class for image inference without annotations or masks.
+
+         Args:
+             image_paths (List[Path]): Paths of the images to run inference on.
+             image_processor: A callable that turns a PIL image into model-ready
+                 tensors (typically a transformers image processor).
+         """
+         self.image_processor = image_processor
+         # Paths are supplied by the caller; no directory scanning or format filtering happens here
+         self.image_files = image_paths
+
+     def __len__(self) -> int:
+         return len(self.image_files)
+
+     def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str, Tuple[int, int]]:
+         """
+         Get an image from the dataset at the specified index.
+
+         Args:
+             idx (int): The index of the image.
+
+         Returns:
+             Tuple[torch.Tensor, str, Tuple[int, int]]: The processed image tensor, the image file path, and the original (height, width).
+         """
+         image_path = self.image_files[idx]
+         # Open image using PIL and process it using the provided image processor
+         with Image.open(image_path) as img:
+             orig_size = img.size[::-1]  # PIL gives (width, height); detection post-processing expects (height, width)
+             img = img.convert("RGB")  # Ensure all images are in RGB format for consistency
+             processed_img = self.image_processor(images=img, return_tensors="pt")["pixel_values"].squeeze(0)
+
+         return processed_img, str(image_path), orig_size
+
+
+ def collate_fn_inference(batch: List[Tuple[torch.Tensor, str, Tuple[int, int]]]) -> dict:
+     """
+     Collate function for batching images for inference.
+
+     Args:
+         batch (List[Tuple[torch.Tensor, str, Tuple[int, int]]]): A list of tuples, each containing
+             the processed image tensor, the image path, and the original (height, width).
+
+     Returns:
+         dict: A dictionary containing the batched image tensors, padding masks, image file paths, and original sizes.
+     """
+     pixel_values = [item[0] for item in batch]  # Extract processed images
+     image_paths = [item[1] for item in batch]  # Extract image paths
+     orig_sizes = [item[2] for item in batch]
+
+     # Pad the images to match the largest image in the batch
+     encoding = detector_processor.pad(pixel_values, return_tensors="pt")
+
+     return {
+         'pixel_values': encoding['pixel_values'],
+         'pixel_mask': encoding['pixel_mask'],  # Padding mask (if needed by the model)
+         'image_paths': image_paths,
+         'orig_sizes': orig_sizes
+     }
+
+
+ class ModelInference:
+     def __init__(self, detector, detector_processor, segment_predictor, id2label, CONFIDENCE_THRESHOLD):
+         self.detector = detector
+         self.detector_processor = detector_processor
+         self.segment_predictor = segment_predictor
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.CONFIDENCE_THRESHOLD = CONFIDENCE_THRESHOLD
+         self.id2label = id2label
+         self.mask_annotator = sv.MaskAnnotator()
+         self.detector.to(self.device)
+
+     def predict_one(self, image_path):
+         image = cv2.imread(image_path)
+         with torch.no_grad():
+
+             # load image and predict
+             inputs = self.detector_processor(images=image, return_tensors='pt').to(self.device)
+             outputs = self.detector(**inputs)
+
+             # post-process
+             target_sizes = torch.tensor([image.shape[:2]]).to(self.device)
+             results = self.detector_processor.post_process_object_detection(
+                 outputs=outputs,
+                 threshold=self.CONFIDENCE_THRESHOLD,
+                 target_sizes=target_sizes
+             )[0]
+         if results['boxes'].numel() == 0:
+             print("No bounding box detected")
+             return None, None
+         else:
+             det_detections = sv.Detections.from_transformers(transformers_results=results).with_nms(threshold=0.5)
+
+             everything_results = self.segment_predictor(image)
+             if everything_results[0].masks is not None:
+                 bbox_results = self.segment_predictor.prompt(everything_results, det_detections.xyxy.tolist())[0]
+                 seg_detections = sv.Detections.from_ultralytics(bbox_results)
+                 seg_detections = self.filter_small_masks(seg_detections)
+
+                 max_length = max(len(name) for name in self.id2label.values())
+
+                 # Create a new NumPy array with the appropriate dtype based on the longest string
+                 seg_detections.data['class_name'] = np.array(seg_detections.data['class_name'], dtype=f'<U{max_length}')
+
+                 for idx, class_name in enumerate(seg_detections.data['class_name']):
+                     if class_name == 'object':
+                         seg_detections.data['class_name'][idx] = self.id2label[seg_detections.class_id[idx]]
+
+                 annotated_frame = image.copy()
+                 annotated_frame = self.mask_annotator.annotate(scene=annotated_frame, detections=seg_detections)
+
+                 return seg_detections, annotated_frame
+             else:
+                 print("No segmentation mask generated")
+                 return None, None
+
+     def predict_folder(self, image_paths, batch_size=4):
+         dataset = ImageInferenceDataset(image_paths=image_paths, image_processor=self.detector_processor)
+
+         # Create DataLoader instance with the custom collate function
+         dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn_inference)
+
+         detector_failed_list = []
+         segmentor_failed_list = []
+
+         id2label = self.id2label
+         max_length = max(len(name) for name in id2label.values())
+
+         all_image_paths = []
+
+         all_results = []
+
+         for batch in tqdm(dataloader):
+             pixel_values = batch["pixel_values"].to(self.device)
+             pixel_mask = batch["pixel_mask"].to(self.device)
+             batch_paths = batch["image_paths"]  # distinct name so the function argument is not clobbered
+             orig_sizes = batch["orig_sizes"]
+
+             orig_target_sizes = torch.tensor(orig_sizes, device=self.device)
+
+             with torch.no_grad():
+                 outputs = self.detector(
+                     pixel_values=pixel_values, pixel_mask=pixel_mask)
+
+             detector_results = self.detector_processor.post_process_object_detection(
+                 outputs,
+                 threshold=self.CONFIDENCE_THRESHOLD,
+                 target_sizes=orig_target_sizes)
+
+             detector_detections = []
+             detector_to_remove = []
+
+             for idx, detector_result in enumerate(detector_results):
+                 if detector_result['boxes'].numel() == 0:
+                     # The tensor is empty
+                     detector_to_remove.append(idx)
+                 else:
+                     detector_detections.append(sv.Detections.from_transformers(transformers_results=detector_result))
+
+             if detector_to_remove:
+                 # Remove failed items in reverse index order to avoid index shifting
+                 for idx in sorted(detector_to_remove, reverse=True):
+                     detector_failed_list.append(batch_paths[idx])
+                     del batch_paths[idx]
+
+             images_raw = [cv2.imread(image_path) for image_path in batch_paths]
+
+             boxes = [detections.xyxy.tolist() for detections in detector_detections]
+
+             results = []
+
+             to_remove_seg = []
+
+             for idx, (image_path, image, box) in enumerate(zip(batch_paths, images_raw, boxes)):
+                 try:
+                     with torch.no_grad():
+                         # segmentation_result = segment_model(image, bboxes=box)[0]
+                         everything_results = self.segment_predictor(image)
+
+                         if everything_results[0].masks is not None:
+                             bbox_results = self.segment_predictor.prompt(everything_results, box)[0]
+                             seg_detections = sv.Detections.from_ultralytics(bbox_results)
+                             seg_detections = self.filter_small_masks(seg_detections)
+                             seg_detections.data['class_name'] = np.array(seg_detections.data['class_name'], dtype=f'<U{max_length}')
+                             for name_idx, class_name in enumerate(seg_detections.data['class_name']):
+                                 if class_name == 'object':
+                                     seg_detections.data['class_name'][name_idx] = id2label[seg_detections.class_id[name_idx]]
+                             results.append(seg_detections)
+                         else:
+                             to_remove_seg.append(idx)
+                 except Exception as e:
+                     print(f"An error occurred: {e}")
+                     print(f"box: {box}")
+                     print(f"image id: {image_path}")
+                     # result = sv.Detections.from_ultralytics(segmentation_result)
+                     # results.append(result)
+
+             if to_remove_seg:
+                 for idx in sorted(to_remove_seg, reverse=True):
+                     segmentor_failed_list.append(batch_paths[idx])
+                     del batch_paths[idx]
+
+             if len(results) != len(batch_paths):
+                 print(f"Length of results ({len(results)}) does not match the number of images ({len(batch_paths)})")
+                 continue
+
+             all_image_paths.extend(batch_paths)
+             all_results.extend(results)
+
+         annotated_frame = None
+         if all_image_paths:  # build a preview from the first successfully annotated image
+             annotated_frame = self.mask_annotator.annotate(scene=cv2.imread(all_image_paths[0]), detections=all_results[0])
+
+         return all_image_paths, all_results, annotated_frame, detector_failed_list, segmentor_failed_list
+
+     def filter_small_masks(self, detections: sv.Detections) -> sv.Detections:
+         valid_indices = []
+         min_image_area_percentage = 0.002
+         max_image_area_percentage = 0.80
+         approximation_percentage = 0.75
+         for i, mask in enumerate(detections.mask):
+
+             # Check for structural issues in the mask
+             if not (contains_holes(mask) or contains_multiple_segments(mask)):
+                 # Check if the mask can be approximated to a polygon successfully
+                 if not approximate_mask_with_polygons(mask=mask,
+                                                       min_image_area_percentage=min_image_area_percentage,
+                                                       max_image_area_percentage=max_image_area_percentage,
+                                                       approximation_percentage=approximation_percentage,
+                                                       ):
+                     print(f"Skipping mask {i}: no valid polygon approximation")
+                     continue
+
+                 # If all checks pass, add index to valid_indices
+                 valid_indices.append(i)
+
+         filtered_xyxy = detections.xyxy[valid_indices]
+         filtered_mask = detections.mask[valid_indices]
+         filtered_confidence = detections.confidence[valid_indices]
+         filtered_class_id = detections.class_id[valid_indices]
+         filtered_class_name = detections.data['class_name'][valid_indices]
+
+         detections.xyxy = filtered_xyxy
+         detections.mask = filtered_mask
+         detections.confidence = filtered_confidence
+         detections.class_id = filtered_class_id
+         detections.data['class_name'] = filtered_class_name
+         return detections
+
+     def get_dict(
+         self,
+         image_paths: List[Any],
+         detections: List[Any]
+     ) -> Dict[str, Any]:
+
+         detections_dict = {}
+
+         for idx, image_path in enumerate(image_paths):
+             detections_dict[image_path] = detections[idx]
+
+         return detections_dict
+
+     def save_annotations(self,
+                          image_paths,
+                          detections,
+                          class_names,
+                          annotation_path,
+                          MIN_IMAGE_AREA_PERCENTAGE=0.002,
+                          MAX_IMAGE_AREA_PERCENTAGE=0.80,
+                          APPROXIMATION_PERCENTAGE=0.75):
+         # image_dir = annotation_path.parent
+         detections_dict = self.get_dict(image_paths, detections)
+         sv.DetectionDataset(
+             classes=class_names,
+             images=image_paths,
+             annotations=detections_dict
+         ).as_coco(
+             images_directory_path=None,
+             annotations_path=annotation_path,
+             min_image_area_percentage=MIN_IMAGE_AREA_PERCENTAGE,
+             max_image_area_percentage=MAX_IMAGE_AREA_PERCENTAGE,
+             approximation_percentage=APPROXIMATION_PERCENTAGE
+         )
+
+         return
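
Since save_annotations delegates to supervision's DetectionDataset.as_coco, the output follows the standard COCO layout (images, annotations, categories). A quick sanity check of a downloaded file might look like this sketch, where "annotations.coco.json" is a placeholder path:

    import json

    with open("annotations.coco.json") as f:  # placeholder path
        coco = json.load(f)
    print(len(coco["images"]), "images,", len(coco["annotations"]), "annotations")
    print("categories:", [c["name"] for c in coco["categories"]])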
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ gradio==5.1.0
+ opencv-python==4.10.0.84
+ torch==2.4.0
+ supervision==0.23.0
+ tqdm==4.66.5
+ transformers==4.44.2
+ ultralytics==8.2.85
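
Installing with pip install -r requirements.txt should reproduce the environment; a small import check (a sketch, not part of the commit) confirms the pins resolved:

    # Report the installed versions of the pinned packages.
    import cv2, gradio, supervision, torch, tqdm, transformers, ultralytics

    for mod in (gradio, cv2, torch, supervision, tqdm, transformers, ultralytics):
        print(mod.__name__, getattr(mod, "__version__", "unknown"))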