cc1234 committed
Commit 244b0b6 · 1 Parent(s): 0dc8e91
.deepface/weights/arcface_weights.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6336979c0c602cae08d1122a66f4dfb862d059bbcd8ef80306aef2b2249b0c93
+ size 137026640
.deepface/weights/facenet512_weights.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3f76b5117a9ca574d536af8199e6720089eb4ad3dc7e93534496d88265de864f
+ size 94955648
.deepface/weights/yolov8n-face.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d545bf1add5aa736a4febac4f4f9245a6d596cd0fe70d5d57989fe0cb9e626ca
+ size 6389512
.gitattributes CHANGED
@@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.db filter=lfs diff=lfs merge=lfs -text
+ face.json filter=lfs diff=lfs merge=lfs -text
+ .deepface/weights/yolov8n-face.pt filter=lfs diff=lfs merge=lfs -text
+ .deepface/weights/face_recognition_sface_2021dec.onnx filter=lfs diff=lfs merge=lfs -text
+ .deepface/weights/res10_300x300_ssd_iter_140000.caffemodel filter=lfs diff=lfs merge=lfs -text
+ .deepface/weights/centerface.onnx filter=lfs diff=lfs merge=lfs -text
+ .deepface/weights/deploy.prototxt filter=lfs diff=lfs merge=lfs -text
+ .deepface/weights/facenet512_weights.h5 filter=lfs diff=lfs merge=lfs -text
+ .deepface/weights/retinaface.h5 filter=lfs diff=lfs merge=lfs -text
+ .deepface/weights/face_detection_yunet_2023mar.onnx filter=lfs diff=lfs merge=lfs -text
+ .deepface/weights/arcface_weights.h5 filter=lfs diff=lfs merge=lfs -text
+ face_arc.voy filter=lfs diff=lfs merge=lfs -text
+ face_facenet.voy filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
+ venv
+ flagged
+ temp.jpg
+ __pycache__
README.md CHANGED
@@ -1,7 +1,7 @@
  ---
  title: Stashface
- emoji: 💻
- colorFrom: yellow
+ emoji: 👀
+ colorFrom: indigo
  colorTo: purple
  sdk: gradio
  sdk_version: 5.25.2
app.py ADDED
@@ -0,0 +1,25 @@
+ import os
+
+ # Point DeepFace at the repo-local .deepface directory and force CPU inference
+ os.environ["DEEPFACE_HOME"] = "."
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+
+ from models.data_manager import DataManager
+ from web.interface import WebInterface
+
+ def main():
+     """Main entry point for the application"""
+     # Initialize data manager
+     data_manager = DataManager(
+         faces_path="data/faces.json",
+         performers_zip="data/persons.zip",
+         facenet_index_path="data/face_facenet.voy",
+         arc_index_path="data/face_arc.voy"
+     )
+
+     # Initialize and launch web interface
+     web_interface = WebInterface(data_manager, default_threshold=0.5)
+     web_interface.launch(server_name="0.0.0.0", server_port=7860, share=False)
+
+ if __name__ == "__main__":
+     main()
data/face_arc.voy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1913fae047f492d1b1da7100cb7835d12de5fd527b2f47a1d77996641abec8aa
+ size 56782035
data/face_facenet.voy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d722e23bfeb24a661ae1372b7648b61378c2c11870a46e57bdf06f8fe7fee969
+ size 56781931
data/faces.json ADDED
The diff for this file is too large to render. See raw diff
 
data/persons.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c9eb10173cddf6a4bd0339218eb2cd9b3621e29d7d111dbcb41b8062ee43e8ed
+ size 5031776
models/__init__.py ADDED
@@ -0,0 +1 @@
+ # models package
models/data_manager.py ADDED
@@ -0,0 +1,105 @@
+ import os
+ import json
+ import pyzipper
+ from typing import Dict, Any, Optional
+ from voyager import Index, Space, StorageDataType
+
+ class DataManager:
+     def __init__(self, faces_path: str = "data/faces.json",
+                  performers_zip: str = "data/persons.zip",
+                  facenet_index_path: str = "data/face_facenet.voy",
+                  arc_index_path: str = "data/face_arc.voy"):
+         """
+         Initialize the data manager.
+
+         Parameters:
+             faces_path: Path to the faces.json file
+             performers_zip: Path to the performers zip file
+             facenet_index_path: Path to the Facenet index file
+             arc_index_path: Path to the ArcFace index file
+         """
+         self.faces_path = faces_path
+         self.performers_zip = performers_zip
+         self.facenet_index_path = facenet_index_path
+         self.arc_index_path = arc_index_path
+
+         # Initialize indices (512-d cosine space, compact 8-bit E4M3 storage)
+         self.index_arc = Index(Space.Cosine, num_dimensions=512, storage_data_type=StorageDataType.E4M3)
+         self.index_facenet = Index(Space.Cosine, num_dimensions=512, storage_data_type=StorageDataType.E4M3)
+
+         # Load data
+         self.faces = {}
+         self.performer_db = {}
+         self.load_data()
+
+     def load_data(self):
+         """Load all data from files"""
+         self._load_faces()
+         self._load_performer_db()
+         self._load_indices()
+
+     def _load_faces(self):
+         """Load faces from JSON file"""
+         try:
+             with open(self.faces_path, 'r') as f:
+                 self.faces = json.load(f)
+         except Exception as e:
+             print(f"Error loading faces: {e}")
+             self.faces = {}
+
+     def _load_performer_db(self):
+         """Load performer database from encrypted zip file"""
+         try:
+             with pyzipper.AESZipFile(self.performers_zip) as zf:
+                 password = os.getenv("VISAGE_KEY", "").encode('ascii')
+                 zf.setpassword(password)
+                 self.performer_db = json.loads(zf.read('performers.json'))
+         except Exception as e:
+             print(f"Error loading performer database: {e}")
+             self.performer_db = {}
+
+     def _load_indices(self):
+         """Load face recognition indices"""
+         try:
+             with open(self.arc_index_path, 'rb') as f:
+                 self.index_arc = self.index_arc.load(f)
+
+             with open(self.facenet_index_path, 'rb') as f:
+                 self.index_facenet = self.index_facenet.load(f)
+         except Exception as e:
+             print(f"Error loading indices: {e}")
+
+     def get_performer_info(self, stash_id: str, confidence: float) -> Optional[Dict[str, Any]]:
+         """
+         Get performer information from the database.
+
+         Parameters:
+             stash_id: Stash ID of the performer
+             confidence: Confidence score (0-1)
+
+         Returns:
+             Dictionary with performer information, or None if not found
+         """
+         performer = self.performer_db.get(stash_id)
+         if not performer:
+             return None
+
+         confidence_int = int(confidence * 100)
+         return {
+             'id': stash_id,
+             'name': performer['name'],
+             'confidence': confidence_int,
+             'image': performer['image'],
+             'country': performer['country'],
+             'hits': 1,
+             'distance': confidence_int,
+             'performer_url': f"https://stashdb.org/performers/{stash_id}"
+         }
+
+     def query_facenet_index(self, embedding, limit):
+         """Query the Facenet index with an embedding"""
+         return self.index_facenet.query(embedding, limit)
+
+     def query_arc_index(self, embedding, limit):
+         """Query the ArcFace index with an embedding"""
+         return self.index_arc.query(embedding, limit)
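
Editor's note: the .voy files loaded above are Voyager nearest-neighbour indices. A minimal sketch of how a compatible index could be built and queried, using the same parameters DataManager uses (the vectors here are random, purely for illustration; the real indices ship with the repo):

    import numpy as np
    from voyager import Index, Space, StorageDataType

    # 512-d cosine space with 8-bit E4M3 storage, as in DataManager.__init__
    index = Index(Space.Cosine, num_dimensions=512, storage_data_type=StorageDataType.E4M3)

    # add_item returns the integer id assigned to each embedding
    ids = [index.add_item(np.random.rand(512).astype(np.float32)) for _ in range(10)]

    # query returns (neighbor_ids, distances), the pair the ensemble code consumes
    neighbor_ids, distances = index.query(np.random.rand(512).astype(np.float32), k=3)
    print(neighbor_ids, distances)

    index.save("face_demo.voy")  # hypothetical output path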
models/face_recognition.py ADDED
@@ -0,0 +1,102 @@
+ import os
+ import numpy as np
+ from typing import Dict, List, Tuple
+
+ from deepface import DeepFace
+
+ class EnsembleFaceRecognition:
+     def __init__(self, model_weights: Dict[str, float] = None):
+         """
+         Initialize ensemble face recognition system.
+
+         Parameters:
+             model_weights: Dictionary mapping model names to their weights.
+                 If None, all models are weighted equally.
+         """
+         self.model_weights = model_weights or {}
+         self.boost_factor = 1.8
+
+     def normalize_distances(self, distances: np.ndarray) -> np.ndarray:
+         """Normalize distances to the [0, 1] range within each model's predictions"""
+         min_dist = np.min(distances)
+         max_dist = np.max(distances)
+         if max_dist == min_dist:
+             return np.zeros_like(distances)
+         return (distances - min_dist) / (max_dist - min_dist)
+
+     def compute_model_confidence(self,
+                                  distances: np.ndarray,
+                                  temperature: float = 0.1) -> np.ndarray:
+         """Convert distances to confidence scores for a single model (softmax over negated, normalized distances)"""
+         normalized_distances = self.normalize_distances(distances)
+         exp_distances = np.exp(-normalized_distances / temperature)
+         return exp_distances / np.sum(exp_distances)
+
+     def get_face_embeddings(self, image: np.ndarray) -> Dict[str, np.ndarray]:
+         """Get face embeddings from each model"""
+         return {
+             'facenet': DeepFace.represent(img_path=image, detector_backend='skip', model_name='Facenet512', normalization='Facenet2018', align=True)[0]['embedding'],
+             'arc': DeepFace.represent(img_path=image, detector_backend='skip', model_name='ArcFace', align=True)[0]['embedding'],
+         }
+
+     def ensemble_prediction(self,
+                             model_predictions: Dict[str, Tuple[List[str], List[float]]],
+                             temperature: float = 0.1,
+                             min_agreement: float = 0.5) -> List[Tuple[str, float]]:
+         """
+         Combine predictions from multiple models.
+
+         Parameters:
+             model_predictions: Dictionary mapping model names to their (names, distances) predictions
+             temperature: Temperature parameter for softmax scaling
+             min_agreement: Minimum agreement threshold between models
+
+         Returns:
+             final_predictions: List of (name, confidence) tuples
+         """
+         # Initialize vote counting
+         vote_dict = {}
+         confidence_dict = {}
+
+         # Process each model's predictions
+         for model_name, (names, distances) in model_predictions.items():
+             # Get model weight (default to 1.0 if not specified)
+             model_weight = self.model_weights.get(model_name, 1.0)
+
+             # Compute confidence scores for this model
+             confidences = self.compute_model_confidence(np.array(distances), temperature)
+
+             # Add weighted votes for the top prediction
+             top_name = names[0]
+             top_confidence = confidences[0]
+
+             vote_dict[top_name] = vote_dict.get(top_name, 0) + model_weight
+             confidence_dict[top_name] = confidence_dict.get(top_name, [])
+             confidence_dict[top_name].append(top_confidence)
+
+         # Normalize votes
+         total_weight = sum(self.model_weights.values()) if self.model_weights else len(model_predictions)
+
+         # Compute final results with minimum agreement check
+         final_results = []
+         for name, votes in vote_dict.items():
+             normalized_votes = votes / total_weight
+             # Only include results that meet the minimum agreement threshold
+             if normalized_votes >= min_agreement:
+                 avg_confidence = np.mean(confidence_dict[name])
+                 final_score = normalized_votes * avg_confidence * self.boost_factor
+                 final_score = min(final_score, 1.0)  # Cap at 1.0
+                 final_results.append((name, final_score))
+
+         # Sort by final score, best first
+         final_results.sort(key=lambda x: x[1], reverse=True)
+         return final_results
+
+ def extract_faces(image):
+     """Extract faces from an image using the YOLOv8 detector backend"""
+     return DeepFace.extract_faces(image, detector_backend="yolov8")
+
+ def extract_faces_mediapipe(image, enforce_detection=False, align=False):
+     """Extract faces from an image using the MediaPipe backend"""
+     return DeepFace.extract_faces(image, detector_backend="mediapipe",
+                                   enforce_detection=enforce_detection,
+                                   align=align)
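
Editor's note: the voting logic above can be sanity-checked without any models loaded, since ensemble_prediction only consumes (names, distances) pairs. A toy run (names and distances are made up):

    from models.face_recognition import EnsembleFaceRecognition

    ensemble = EnsembleFaceRecognition({"facenet": 1.0, "arc": 1.0})

    # Hypothetical nearest-neighbour output from each index
    model_predictions = {
        "facenet": (["alice", "bob", "carol"], [0.20, 0.45, 0.60]),
        "arc": (["alice", "carol", "bob"], [0.25, 0.50, 0.55]),
    }

    # Both models' top vote is "alice", so she clears min_agreement=0.5 and
    # comes back with a boosted confidence capped at 1.0
    print(ensemble.ensemble_prediction(model_predictions))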
models/image_processor.py ADDED
@@ -0,0 +1,140 @@
+ import io
+ import base64
+ import numpy as np
+ from uuid import uuid4
+ from PIL import Image as PILImage
+ from typing import List, Dict, Any, Tuple
+
+ from models.face_recognition import EnsembleFaceRecognition, extract_faces, extract_faces_mediapipe
+ from models.data_manager import DataManager
+ from utils.vtt_parser import parse_vtt_offsets
+
+ def get_face_predictions(face, ensemble, data_manager, results):
+     """
+     Get predictions for a single face.
+
+     Parameters:
+         face: Face image array
+         ensemble: EnsembleFaceRecognition instance
+         data_manager: DataManager instance
+         results: Number of results to return
+
+     Returns:
+         List of (name, confidence) tuples
+     """
+     # Get embeddings for the original and horizontally flipped images
+     embeddings_orig = ensemble.get_face_embeddings(face)
+     embeddings_flip = ensemble.get_face_embeddings(np.fliplr(face))
+
+     # Average the embeddings (simple test-time augmentation)
+     facenet = np.mean([embeddings_orig['facenet'], embeddings_flip['facenet']], axis=0)
+     arc = np.mean([embeddings_orig['arc'], embeddings_flip['arc']], axis=0)
+
+     # Get predictions from both models
+     model_predictions = {
+         'facenet': data_manager.query_facenet_index(facenet, max(results, 50)),
+         'arc': data_manager.query_arc_index(arc, max(results, 50)),
+     }
+
+     return ensemble.ensemble_prediction(model_predictions)
+
+ def image_search_performer(image, data_manager, threshold=0.5, results=3):
+     """
+     Search for a performer in an image.
+
+     Parameters:
+         image: PIL Image object
+         data_manager: DataManager instance
+         threshold: Confidence threshold
+         results: Number of results to return
+
+     Returns:
+         List of performer information dictionaries
+     """
+     image_array = np.array(image)
+     ensemble = EnsembleFaceRecognition({"facenet": 1.0, "arc": 1.0})
+
+     try:
+         faces = extract_faces(image_array)
+     except ValueError:
+         raise ValueError("No faces found")
+
+     predictions = get_face_predictions(faces[0]['face'], ensemble, data_manager, results)
+     response = []
+     for name, confidence in predictions:
+         performer_info = data_manager.get_performer_info(data_manager.faces[name], confidence)
+         if performer_info:
+             response.append(performer_info)
+     return response
+
+ def image_search_performers(image, data_manager, threshold=0.5, results=3):
+     """
+     Search for multiple performers in an image.
+
+     Parameters:
+         image: PIL Image object
+         data_manager: DataManager instance
+         threshold: Confidence threshold
+         results: Number of results to return
+
+     Returns:
+         List of dictionaries with face image and performer information
+     """
+     image_array = np.array(image)
+     ensemble = EnsembleFaceRecognition({"facenet": 1.0, "arc": 1.0})
+
+     try:
+         faces = extract_faces(image_array)
+     except ValueError:
+         raise ValueError("No faces found")
+
+     response = []
+     for face in faces:
+         predictions = get_face_predictions(face['face'], ensemble, data_manager, results)
+
+         # Crop the face and encode it as a base64 JPEG
+         area = face['facial_area']
+         cimage = image.crop((area['x'], area['y'], area['x'] + area['w'], area['y'] + area['h']))
+         buf = io.BytesIO()
+         cimage.save(buf, format='JPEG')
+         im_b64 = base64.b64encode(buf.getvalue()).decode('ascii')
+
+         # Get performer information
+         performers = []
+         for name, confidence in predictions:
+             performer_info = data_manager.get_performer_info(data_manager.faces[name], confidence)
+             if performer_info:
+                 performers.append(performer_info)
+
+         response.append({
+             'image': im_b64,
+             'confidence': face['confidence'],
+             'performers': performers
+         })
+     return response
+
+ def find_faces_in_sprite(image, vtt_data):
+     """
+     Find faces in a sprite image using VTT data.
+
+     Parameters:
+         image: Sprite image as a numpy array
+         vtt_data: Base64 encoded VTT data
+
+     Returns:
+         List of dictionaries with face information
+     """
+     vtt = base64.b64decode(vtt_data.replace("data:text/vtt;base64,", ""))
+     sprite = PILImage.fromarray(image)
+
+     results = []
+     # parse_vtt_offsets yields xywh fragments, so (right, bottom) here are
+     # really the crop's width and height, not absolute coordinates
+     for i, (left, top, right, bottom, time_seconds) in enumerate(parse_vtt_offsets(vtt)):
+         cut_frame = sprite.crop((left, top, left + right, top + bottom))
+         faces = extract_faces_mediapipe(np.asarray(cut_frame), enforce_detection=False, align=False)
+         faces = [face for face in faces if face['confidence'] > 0.6]
+         if faces:
+             size = faces[0]['facial_area']['w'] * faces[0]['facial_area']['h']
+             data = {'id': str(uuid4()), 'offset': (left, top, right, bottom), 'frame': i, 'time': time_seconds, 'size': size}
+             results.append(data)
+
+     return results
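
Editor's note: the 'image' field returned by image_search_performers is a base64-encoded JPEG. A quick sketch of turning it back into a PIL image on the client side (assumes a response list as returned above):

    import io
    import base64
    from PIL import Image

    # response = image_search_performers(img, data_manager)  # as above
    face_jpeg = base64.b64decode(response[0]['image'])
    Image.open(io.BytesIO(face_jpeg)).save("face_0.jpg")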
requirements.txt ADDED
@@ -0,0 +1,127 @@
+ absl-py==2.2.2
+ aiofiles==24.1.0
+ annotated-types==0.7.0
+ anyio==4.9.0
+ astunparse==1.6.3
+ beautifulsoup4==4.13.4
+ blinker==1.9.0
+ certifi==2025.1.31
+ charset-normalizer==3.4.1
+ click==8.1.8
+ contourpy==1.3.2
+ cycler==0.12.1
+ deepface @ git+https://github.com/serengil/deepface.git@cc484b54be5188eb47faf132995af16a871d70b9
+ fastapi==0.115.12
+ ffmpy==0.5.0
+ filelock==3.18.0
+ fire==0.7.0
+ flask==3.1.0
+ flask-cors==5.0.1
+ flatbuffers==25.2.10
+ fonttools==4.57.0
+ fsspec==2025.3.2
+ gast==0.6.0
+ gdown==5.2.0
+ google-pasta==0.2.0
+ gradio==5.25.2
+ gradio-client==1.8.0
+ groovy==0.1.2
+ grpcio==1.71.0
+ gunicorn==23.0.0
+ h11==0.14.0
+ h5py==3.13.0
+ httpcore==1.0.8
+ httpx==0.28.1
+ huggingface-hub==0.30.2
+ idna==3.10
+ itsdangerous==2.2.0
+ jinja2==3.1.6
+ joblib==1.4.2
+ keras==3.9.2
+ kiwisolver==1.4.8
+ libclang==18.1.1
+ lz4==4.4.4
+ markdown==3.8
+ markdown-it-py==3.0.0
+ markupsafe==3.0.2
+ matplotlib==3.10.1
+ mdurl==0.1.2
+ ml-dtypes==0.5.1
+ mpmath==1.3.0
+ mtcnn==1.0.0
+ namex==0.0.8
+ networkx==3.4.2
+ numpy==2.1.3
+ nvidia-cublas-cu12==12.4.5.8
+ nvidia-cuda-cupti-cu12==12.4.127
+ nvidia-cuda-nvrtc-cu12==12.4.127
+ nvidia-cuda-runtime-cu12==12.4.127
+ nvidia-cudnn-cu12==9.1.0.70
+ nvidia-cufft-cu12==11.2.1.3
+ nvidia-curand-cu12==10.3.5.147
+ nvidia-cusolver-cu12==11.6.1.9
+ nvidia-cusparse-cu12==12.3.1.170
+ nvidia-cusparselt-cu12==0.6.2
+ nvidia-nccl-cu12==2.21.5
+ nvidia-nvjitlink-cu12==12.4.127
+ nvidia-nvtx-cu12==12.4.127
+ opencv-python==4.11.0.86
+ opt-einsum==3.4.0
+ optree==0.15.0
+ orjson==3.10.16
+ packaging==25.0
+ pandas==2.2.3
+ pillow==11.2.1
+ protobuf==5.29.4
+ psutil==7.0.0
+ py-cpuinfo==9.0.0
+ pycryptodomex==3.22.0
+ pydantic==2.11.3
+ pydantic-core==2.33.1
+ pydub==0.25.1
+ pygments==2.19.1
+ pyparsing==3.2.3
+ pysocks==1.7.1
+ python-dateutil==2.9.0.post0
+ python-multipart==0.0.20
+ pytz==2025.2
+ pyyaml==6.0.2
+ pyzipper==0.3.6
+ requests==2.32.3
+ retina-face==0.0.17
+ rich==14.0.0
+ ruff==0.11.6
+ safehttpx==0.1.6
+ scipy==1.15.2
+ seaborn==0.13.2
+ semantic-version==2.10.0
+ setuptools==78.1.0
+ shellingham==1.5.4
+ six==1.17.0
+ sniffio==1.3.1
+ soupsieve==2.6
+ starlette==0.46.2
+ sympy==1.13.1
+ tensorboard==2.19.0
+ tensorboard-data-server==0.7.2
+ tensorflow==2.19.0
+ termcolor==3.0.1
+ tf-keras==2.19.0
+ tomlkit==0.13.2
+ torch==2.6.0
+ torchvision==0.21.0
+ tqdm==4.67.1
+ triton==3.2.0
+ typer==0.15.2
+ typing-extensions==4.13.2
+ typing-inspection==0.4.0
+ tzdata==2025.2
+ ultralytics==8.3.69
+ ultralytics-thop==2.0.14
+ urllib3==2.4.0
+ uvicorn==0.34.2
+ voyager==2.1.0
+ websockets==15.0.1
+ werkzeug==3.1.3
+ wheel==0.45.1
+ wrapt==1.17.2
tests/__init__.py ADDED
@@ -0,0 +1,3 @@
+ """
+ Test package initialization
+ """
tests/test_vtt_parser.py ADDED
@@ -0,0 +1,85 @@
+ import pytest
+ from utils.vtt_parser import parse_vtt_offsets
+
+ def test_parse_simple_vtt():
+     """Test parsing a simple VTT file with one timestamp and coordinates"""
+     vtt_content = """WEBVTT
+
+ 00:00:05.000 --> 00:00:10.000
+ xywh=100,200,300,400
+ """
+     result = list(parse_vtt_offsets(vtt_content.encode('utf-8')))
+     assert len(result) == 1
+     left, top, right, bottom, time = result[0]
+     assert left == 100
+     assert top == 200
+     assert right == 300
+     assert bottom == 400
+     assert time == 5.0
+
+ def test_parse_multiple_entries():
+     """Test parsing multiple timestamps and coordinates"""
+     vtt_content = """WEBVTT
+
+ 00:00:05.000 --> 00:00:10.000
+ xywh=100,200,300,400
+
+ 00:01:30.500 --> 00:01:35.000
+ xywh=150,250,350,450
+ """
+     result = list(parse_vtt_offsets(vtt_content.encode('utf-8')))
+     assert len(result) == 2
+
+     # First entry
+     left, top, right, bottom, time = result[0]
+     assert (left, top, right, bottom) == (100, 200, 300, 400)
+     assert time == 5.0
+
+     # Second entry
+     left, top, right, bottom, time = result[1]
+     assert (left, top, right, bottom) == (150, 250, 350, 450)
+     assert time == 90.5  # 1 minute 30.5 seconds
+
+ def test_parse_empty_vtt():
+     """Test parsing an empty VTT file"""
+     vtt_content = "WEBVTT\n"
+     result = list(parse_vtt_offsets(vtt_content.encode('utf-8')))
+     assert len(result) == 0
+
+ def test_parse_invalid_format():
+     """Test that parsing VTT with an invalid format yields no results"""
+     vtt_content = """WEBVTT
+
+ 00:00:05.000 --> 00:00:10.000
+ invalid_line
+ """
+     result = list(parse_vtt_offsets(vtt_content.encode('utf-8')))
+     assert len(result) == 0
+
+ def test_parse_hour_timestamp():
+     """Test parsing a timestamp with hours"""
+     vtt_content = """WEBVTT
+
+ 01:30:05.000 --> 01:30:10.000
+ xywh=100,200,300,400
+ """
+     result = list(parse_vtt_offsets(vtt_content.encode('utf-8')))
+     assert len(result) == 1
+     left, top, right, bottom, time = result[0]
+     assert time == 5405.0  # 1 hour + 30 minutes + 5 seconds
+
+ def test_parse_missing_coordinates():
+     """Test that entries without coordinates are skipped"""
+     vtt_content = """WEBVTT
+
+ 00:00:05.000 --> 00:00:10.000
+ Some text content
+
+ 00:00:10.000 --> 00:00:15.000
+ xywh=100,200,300,400
+ """
+     result = list(parse_vtt_offsets(vtt_content.encode('utf-8')))
+     assert len(result) == 1
+     left, top, right, bottom, time = result[0]
+     assert time == 10.0
+     assert (left, top, right, bottom) == (100, 200, 300, 400)
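
Editor's note: since tests/ is a package and the parser imports from utils.vtt_parser, running the suite from the repository root should need no extra path setup:

    pytest tests/ -v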
utils/__init__.py ADDED
@@ -0,0 +1 @@
+ # utils package
utils/vtt_parser.py ADDED
@@ -0,0 +1,44 @@
+ from typing import List, Tuple, Generator
+
+ def parse_vtt_offsets(vtt_content: bytes) -> Generator[Tuple[int, int, int, int, float], None, None]:
+     """
+     Parse VTT file content and extract sprite offsets and timestamps.
+
+     Note: the xywh fragment encodes x, y, width, height, so the third and
+     fourth values yielded are the crop's width and height.
+
+     Parameters:
+         vtt_content: Raw VTT file content as bytes
+
+     Returns:
+         Generator yielding tuples of (left, top, right, bottom, time_seconds)
+     """
+     time_seconds = 0
+     left = top = right = bottom = None
+
+     for line in vtt_content.decode("utf-8").split("\n"):
+         line = line.strip()
+
+         if "-->" in line:
+             # Grab the start time, e.g. "00:00:00.000 --> 00:00:41.000"
+             start = line.split("-->")[0].strip().split(":")
+             # Convert HH:MM:SS.mmm to seconds
+             time_seconds = (
+                 int(start[0]) * 3600
+                 + int(start[1]) * 60
+                 + float(start[2])
+             )
+             left = top = right = bottom = None
+         elif "xywh=" in line:
+             left, top, right, bottom = line.split("xywh=")[-1].split(",")
+             left, top, right, bottom = (
+                 int(left),
+                 int(top),
+                 int(right),
+                 int(bottom),
+             )
+         else:
+             continue
+
+         # "if not left" would also skip a legitimate zero x-offset,
+         # so compare against None explicitly
+         if left is None:
+             continue
+
+         yield left, top, right, bottom, time_seconds
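
Editor's note: a quick interactive check of the parser (the VTT snippet is illustrative, and also exercises the zero-offset case the None comparison above protects):

    sample = b"""WEBVTT

    00:00:00.000 --> 00:00:05.000
    sprite.jpg#xywh=0,0,160,90
    """

    for left, top, width, height, t in parse_vtt_offsets(sample):
        print(t, (left, top, width, height))
    # prints: 0.0 (0, 0, 160, 90)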
web/__init__.py ADDED
@@ -0,0 +1 @@
+ # web package
web/interface.py ADDED
@@ -0,0 +1,174 @@
+ import gradio as gr
+ from typing import Dict, Any
+
+ from models.data_manager import DataManager
+ from models.image_processor import (
+     image_search_performer,
+     image_search_performers,
+     find_faces_in_sprite
+ )
+
+ class WebInterface:
+     def __init__(self, data_manager: DataManager, default_threshold: float = 0.5):
+         """
+         Initialize the web interface.
+
+         Parameters:
+             data_manager: DataManager instance
+             default_threshold: Default confidence threshold
+         """
+         self.data_manager = data_manager
+         self.default_threshold = default_threshold
+
+     def image_search(self, img, threshold, results):
+         """Wrapper for the single-face image search function"""
+         return image_search_performer(img, self.data_manager, threshold, results)
+
+     def multiple_image_search(self, img, threshold, results):
+         """Wrapper for the multiple-face image search function"""
+         return image_search_performers(img, self.data_manager, threshold, results)
+
+     def vector_search(self, vector_json, threshold, results):
+         """Wrapper for the vector search function (deprecated)"""
+         return {'status': 'not implemented'}
+
+     def _create_image_search_interface(self):
+         """Create the single face search interface"""
+         with gr.Blocks() as interface:
+             gr.Markdown("# Who is in the photo?")
+             gr.Markdown("Upload an image of a person and we'll tell you who it is.")
+
+             with gr.Row():
+                 with gr.Column():
+                     img_input = gr.Image()
+                     threshold = gr.Slider(
+                         label="threshold",
+                         minimum=0.0,
+                         maximum=1.0,
+                         value=self.default_threshold
+                     )
+                     results_count = gr.Slider(
+                         label="results",
+                         minimum=0,
+                         maximum=50,
+                         value=3,
+                         step=1
+                     )
+                     search_btn = gr.Button("Search")
+
+                 with gr.Column():
+                     output = gr.JSON(label="Results")
+
+             search_btn.click(
+                 fn=self.image_search,
+                 inputs=[img_input, threshold, results_count],
+                 outputs=output
+             )
+
+         return interface
+
+     def _create_multiple_image_search_interface(self):
+         """Create the multiple face search interface"""
+         with gr.Blocks() as interface:
+             gr.Markdown("# Who is in the photo?")
+             gr.Markdown("Upload an image of one or more people and we'll tell you who they are.")
+
+             with gr.Row():
+                 with gr.Column():
+                     img_input = gr.Image(type="pil")
+                     threshold = gr.Slider(
+                         label="threshold",
+                         minimum=0.0,
+                         maximum=1.0,
+                         value=self.default_threshold
+                     )
+                     results_count = gr.Slider(
+                         label="results",
+                         minimum=0,
+                         maximum=50,
+                         value=3,
+                         step=1
+                     )
+                     search_btn = gr.Button("Search")
+
+                 with gr.Column():
+                     output = gr.JSON(label="Results")
+
+             search_btn.click(
+                 fn=self.multiple_image_search,
+                 inputs=[img_input, threshold, results_count],
+                 outputs=output
+             )
+
+         return interface
+
+     def _create_vector_search_interface(self):
+         """Create the vector search interface (deprecated)"""
+         with gr.Blocks() as interface:
+             gr.Markdown("# Vector Search (deprecated)")
+
+             with gr.Row():
+                 with gr.Column():
+                     vector_input = gr.Textbox()
+                     threshold = gr.Slider(
+                         label="threshold",
+                         minimum=0.0,
+                         maximum=1.0,
+                         value=self.default_threshold
+                     )
+                     results_count = gr.Slider(
+                         label="results",
+                         minimum=0,
+                         maximum=50,
+                         value=3,
+                         step=1
+                     )
+                     search_btn = gr.Button("Search")
+
+                 with gr.Column():
+                     output = gr.JSON(label="Results")
+
+             search_btn.click(
+                 fn=self.vector_search,
+                 inputs=[vector_input, threshold, results_count],
+                 outputs=output
+             )
+
+         return interface
+
+     def _create_faces_in_sprite_interface(self):
+         """Create the faces in sprite interface"""
+         with gr.Blocks() as interface:
+             gr.Markdown("# Find Faces in Sprite")
+
+             with gr.Row():
+                 with gr.Column():
+                     img_input = gr.Image()
+                     vtt_input = gr.Textbox(label="VTT file")
+                     search_btn = gr.Button("Process")
+
+                 with gr.Column():
+                     output = gr.JSON(label="Results")
+
+             search_btn.click(
+                 fn=find_faces_in_sprite,
+                 inputs=[img_input, vtt_input],
+                 outputs=output
+             )
+
+         return interface
+
+     def launch(self, server_name="0.0.0.0", server_port=7860, share=True):
+         """Launch the web interface"""
+         with gr.Blocks() as demo:
+             with gr.Tabs() as tabs:
+                 with gr.TabItem("Single Face Search"):
+                     self._create_image_search_interface()
+                 with gr.TabItem("Multiple Face Search"):
+                     self._create_multiple_image_search_interface()
+                 with gr.TabItem("Vector Search"):
+                     self._create_vector_search_interface()
+                 with gr.TabItem("Faces in Sprite"):
+                     self._create_faces_in_sprite_interface()
+
+         # Pass the server settings through instead of silently ignoring them
+         demo.queue().launch(server_name=server_name, server_port=server_port, share=share, ssr_mode=False)
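
Editor's note: once the app is running, the search endpoints can be exercised programmatically with gradio_client (pinned in requirements.txt). A minimal sketch; the api_name is an assumption derived from the handler's function name, so check client.view_api() for the actual endpoint names:

    from gradio_client import Client, handle_file

    client = Client("http://localhost:7860")
    result = client.predict(
        handle_file("photo.jpg"),   # image
        0.5,                        # threshold
        3,                          # results
        api_name="/image_search",   # assumed; verify with client.view_api()
    )
    print(result)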