Commit 244b0b6 · cc1234 committed · Parent(s): 0dc8e91

init
Files changed:
- .deepface/weights/arcface_weights.h5 +3 -0
- .deepface/weights/facenet512_weights.h5 +3 -0
- .deepface/weights/yolov8n-face.pt +3 -0
- .gitattributes +13 -0
- .gitignore +4 -0
- README.md +2 -2
- app.py +25 -0
- data/face_arc.voy +3 -0
- data/face_facenet.voy +3 -0
- data/faces.json +0 -0
- data/persons.zip +3 -0
- models/__init__.py +1 -0
- models/data_manager.py +105 -0
- models/face_recognition.py +102 -0
- models/image_processor.py +140 -0
- requirements.txt +127 -0
- tests/__init__.py +3 -0
- tests/test_vtt_parser.py +85 -0
- utils/__init__.py +1 -0
- utils/vtt_parser.py +44 -0
- web/__init__.py +1 -0
- web/interface.py +174 -0
.deepface/weights/arcface_weights.h5
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6336979c0c602cae08d1122a66f4dfb862d059bbcd8ef80306aef2b2249b0c93
+size 137026640
.deepface/weights/facenet512_weights.h5
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f76b5117a9ca574d536af8199e6720089eb4ad3dc7e93534496d88265de864f
+size 94955648
.deepface/weights/yolov8n-face.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d545bf1add5aa736a4febac4f4f9245a6d596cd0fe70d5d57989fe0cb9e626ca
+size 6389512
.gitattributes
CHANGED
@@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.db filter=lfs diff=lfs merge=lfs -text
+face.json filter=lfs diff=lfs merge=lfs -text
+.deepface/weights/yolov8n-face.pt filter=lfs diff=lfs merge=lfs -text
+.deepface/weights/face_recognition_sface_2021dec.onnx filter=lfs diff=lfs merge=lfs -text
+.deepface/weights/res10_300x300_ssd_iter_140000.caffemodel filter=lfs diff=lfs merge=lfs -text
+.deepface/weights/centerface.onnx filter=lfs diff=lfs merge=lfs -text
+.deepface/weights/deploy.prototxt filter=lfs diff=lfs merge=lfs -text
+.deepface/weights/facenet512_weights.h5 filter=lfs diff=lfs merge=lfs -text
+.deepface/weights/retinaface.h5 filter=lfs diff=lfs merge=lfs -text
+.deepface/weights/face_detection_yunet_2023mar.onnx filter=lfs diff=lfs merge=lfs -text
+.deepface/weights/arcface_weights.h5 filter=lfs diff=lfs merge=lfs -text
+face_arc.voy filter=lfs diff=lfs merge=lfs -text
+face_facenet.voy filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,4 @@
+venv
+flagged
+temp.jpg
+__pycache__
README.md
CHANGED
@@ -1,7 +1,7 @@
 ---
 title: Stashface
-emoji:
-colorFrom:
+emoji: 👀
+colorFrom: indigo
 colorTo: purple
 sdk: gradio
 sdk_version: 5.25.2
app.py
ADDED
@@ -0,0 +1,25 @@
+import os
+
+# Set DeepFace home directory
+os.environ["DEEPFACE_HOME"] = "."
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+
+from models.data_manager import DataManager
+from web.interface import WebInterface
+
+def main():
+    """Main entry point for the application"""
+    # Initialize data manager
+    data_manager = DataManager(
+        faces_path="data/faces.json",
+        performers_zip="data/persons.zip",
+        facenet_index_path="data/face_facenet.voy",
+        arc_index_path="data/face_arc.voy"
+    )
+
+    # Initialize and launch web interface
+    web_interface = WebInterface(data_manager, default_threshold=0.5)
+    web_interface.launch(server_name="0.0.0.0", server_port=7860, share=False)
+
+if __name__ == "__main__":
+    main()
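Both environment variables in app.py must be set before anything imports deepface: DEEPFACE_HOME makes the library resolve its weights under ./.deepface/weights (the LFS files added above), and CUDA_VISIBLE_DEVICES=-1 hides every GPU so TensorFlow stays on CPU. A quick sanity check of the second effect, as a sketch (not part of the commit):

```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # must happen before TensorFlow loads

import tensorflow as tf

# With all GPUs hidden, TensorFlow should report none.
print(tf.config.list_physical_devices("GPU"))  # expected: []
```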
data/face_arc.voy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1913fae047f492d1b1da7100cb7835d12de5fd527b2f47a1d77996641abec8aa
+size 56782035
data/face_facenet.voy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d722e23bfeb24a661ae1372b7648b61378c2c11870a46e57bdf06f8fe7fee969
+size 56781931
data/faces.json
ADDED
The diff for this file is too large to render.
data/persons.zip
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c9eb10173cddf6a4bd0339218eb2cd9b3621e29d7d111dbcb41b8062ee43e8ed
+size 5031776
models/__init__.py
ADDED
@@ -0,0 +1 @@
+# models package
models/data_manager.py
ADDED
@@ -0,0 +1,105 @@
+import os
+import json
+import pyzipper
+from typing import Dict, Any, Optional
+from voyager import Index, Space, StorageDataType
+
+class DataManager:
+    def __init__(self, faces_path: str = "data/faces.json",
+                 performers_zip: str = "data/persons.zip",
+                 facenet_index_path: str = "data/face_facenet.voy",
+                 arc_index_path: str = "data/face_arc.voy"):
+        """
+        Initialize the data manager.
+
+        Parameters:
+            faces_path: Path to the faces.json file
+            performers_zip: Path to the performers zip file
+            facenet_index_path: Path to the facenet index file
+            arc_index_path: Path to the arc index file
+        """
+        self.faces_path = faces_path
+        self.performers_zip = performers_zip
+        self.facenet_index_path = facenet_index_path
+        self.arc_index_path = arc_index_path
+
+        # Initialize indices
+        self.index_arc = Index(Space.Cosine, num_dimensions=512, storage_data_type=StorageDataType.E4M3)
+        self.index_facenet = Index(Space.Cosine, num_dimensions=512, storage_data_type=StorageDataType.E4M3)
+
+        # Load data
+        self.faces = {}
+        self.performer_db = {}
+        self.load_data()
+
+    def load_data(self):
+        """Load all data from files"""
+        self._load_faces()
+        self._load_performer_db()
+        self._load_indices()
+
+    def _load_faces(self):
+        """Load faces from JSON file"""
+        try:
+            with open(self.faces_path, 'r') as f:
+                self.faces = json.load(f)
+        except Exception as e:
+            print(f"Error loading faces: {e}")
+            self.faces = {}
+
+    def _load_performer_db(self):
+        """Load performer database from encrypted zip file"""
+        try:
+            with pyzipper.AESZipFile(self.performers_zip) as zf:
+                password = os.getenv("VISAGE_KEY", "").encode('ascii')
+                zf.setpassword(password)
+                self.performer_db = json.loads(zf.read('performers.json'))
+        except Exception as e:
+            print(f"Error loading performer database: {e}")
+            self.performer_db = {}
+
+    def _load_indices(self):
+        """Load face recognition indices"""
+        try:
+            with open(self.arc_index_path, 'rb') as f:
+                self.index_arc = self.index_arc.load(f)
+
+            with open(self.facenet_index_path, 'rb') as f:
+                self.index_facenet = self.index_facenet.load(f)
+        except Exception as e:
+            print(f"Error loading indices: {e}")
+
+    def get_performer_info(self, stash_id: str, confidence: float) -> Optional[Dict[str, Any]]:
+        """
+        Get performer information from the database
+
+        Parameters:
+            stash_id: Stash ID of the performer
+            confidence: Confidence score (0-1)
+
+        Returns:
+            Dictionary with performer information or None if not found
+        """
+        performer = self.performer_db.get(stash_id, [])
+        if not performer:
+            return None
+
+        confidence_int = int(confidence * 100)
+        return {
+            'id': stash_id,
+            'name': performer['name'],
+            'confidence': confidence_int,
+            'image': performer['image'],
+            'country': performer['country'],
+            'hits': 1,
+            'distance': confidence_int,
+            'performer_url': f"https://stashdb.org/performers/{stash_id}"
+        }
+
+    def query_facenet_index(self, embedding, limit):
+        """Query the facenet index with an embedding"""
+        return self.index_facenet.query(embedding, limit)
+
+    def query_arc_index(self, embedding, limit):
+        """Query the arc index with an embedding"""
+        return self.index_arc.query(embedding, limit)
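app.py above wires this class into the web UI, but the indices can also be queried directly. A minimal sketch, assuming the data/ files from this commit are present and VISAGE_KEY is set so the encrypted performer database decrypts; the zero vector is a placeholder, not a real embedding:

```python
import numpy as np
from models.data_manager import DataManager

dm = DataManager()  # defaults point at the data/ files added in this commit

# Placeholder query; a real vector comes from EnsembleFaceRecognition.get_face_embeddings().
query = np.zeros(512, dtype=np.float32)
labels, distances = dm.query_arc_index(query, 5)

# As in models/image_processor.py, faces.json maps index labels to stash IDs.
for label, dist in zip(labels, distances):
    info = dm.get_performer_info(dm.faces[label], 0.9)  # 0.9 is a stand-in confidence
    if info:
        print(info['name'], info['performer_url'])
```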
models/face_recognition.py
ADDED
@@ -0,0 +1,102 @@
+import os
+import numpy as np
+from typing import Dict, List, Tuple
+
+from deepface import DeepFace
+
+class EnsembleFaceRecognition:
+    def __init__(self, model_weights: Dict[str, float] = None):
+        """
+        Initialize ensemble face recognition system.
+
+        Parameters:
+            model_weights: Dictionary mapping model names to their weights.
+                If None, all models are weighted equally.
+        """
+        self.model_weights = model_weights or {}
+        self.boost_factor = 1.8
+
+    def normalize_distances(self, distances: np.ndarray) -> np.ndarray:
+        """Normalize distances to [0,1] range within each model's predictions"""
+        min_dist = np.min(distances)
+        max_dist = np.max(distances)
+        if max_dist == min_dist:
+            return np.zeros_like(distances)
+        return (distances - min_dist) / (max_dist - min_dist)
+
+    def compute_model_confidence(self,
+                                 distances: np.ndarray,
+                                 temperature: float = 0.1) -> np.ndarray:
+        """Convert distances to confidence scores for a single model"""
+        normalized_distances = self.normalize_distances(distances)
+        exp_distances = np.exp(-normalized_distances / temperature)
+        return exp_distances / np.sum(exp_distances)
+
+    def get_face_embeddings(self, image: np.ndarray) -> Dict[str, np.ndarray]:
+        """Get face embeddings for each model"""
+        return {
+            'facenet': DeepFace.represent(img_path=image, detector_backend='skip', model_name='Facenet512', normalization='Facenet2018', align=True)[0]['embedding'],
+            'arc': DeepFace.represent(img_path=image, detector_backend='skip', model_name='ArcFace', align=True)[0]['embedding']}
+
+    def ensemble_prediction(self,
+                            model_predictions: Dict[str, Tuple[List[str], List[float]]],
+                            temperature: float = 0.1,
+                            min_agreement: float = 0.5) -> List[Tuple[str, float]]:
+        """
+        Combine predictions from multiple models.
+
+        Parameters:
+            model_predictions: Dictionary mapping model names to their (names, distances) predictions
+            temperature: Temperature parameter for softmax scaling
+            min_agreement: Minimum agreement threshold between models
+
+        Returns:
+            final_predictions: List of (name, confidence) tuples
+        """
+        # Initialize vote counting
+        vote_dict = {}
+        confidence_dict = {}
+
+        # Process each model's predictions
+        for model_name, (names, distances) in model_predictions.items():
+            # Get model weight (default to 1.0 if not specified)
+            model_weight = self.model_weights.get(model_name, 1.0)
+
+            # Compute confidence scores for this model
+            confidences = self.compute_model_confidence(np.array(distances), temperature)
+
+            # Add weighted votes for top prediction
+            top_name = names[0]
+            top_confidence = confidences[0]
+
+            vote_dict[top_name] = vote_dict.get(top_name, 0) + model_weight
+            confidence_dict[top_name] = confidence_dict.get(top_name, [])
+            confidence_dict[top_name].append(top_confidence)
+
+        # Normalize votes
+        total_weight = sum(self.model_weights.values()) if self.model_weights else len(model_predictions)
+
+        # Compute final results with minimum agreement check
+        final_results = []
+        for name, votes in vote_dict.items():
+            normalized_votes = votes / total_weight
+            # Only include results that meet minimum agreement threshold
+            if normalized_votes >= min_agreement:
+                avg_confidence = np.mean(confidence_dict[name])
+                final_score = normalized_votes * avg_confidence * self.boost_factor
+                final_score = min(final_score, 1.0)  # Cap at 1.0
+                final_results.append((name, final_score))
+
+        # Sort by final score
+        final_results.sort(key=lambda x: x[1], reverse=True)
+        return final_results
+
+def extract_faces(image):
+    """Extract faces from an image using DeepFace"""
+    return DeepFace.extract_faces(image, detector_backend="yolov8")
+
+def extract_faces_mediapipe(image, enforce_detection=False, align=False):
+    """Extract faces from an image using the MediaPipe backend"""
+    return DeepFace.extract_faces(image, detector_backend="mediapipe",
+                                  enforce_detection=enforce_detection,
+                                  align=align)
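The voting scheme in ensemble_prediction is easiest to see with synthetic numbers. A self-contained sketch (the names and distances below are made up; the import assumes the pinned requirements are installed, since the module pulls in DeepFace at import time):

```python
from models.face_recognition import EnsembleFaceRecognition

ensemble = EnsembleFaceRecognition({"facenet": 1.0, "arc": 1.0})

# Each model contributes (names, distances); smaller distance = closer match.
model_predictions = {
    "facenet": (["alice", "bob", "carol"], [0.30, 0.55, 0.80]),
    "arc": (["alice", "carol", "bob"], [0.25, 0.60, 0.90]),
}

# Both models rank "alice" first: she takes the full normalized vote (1.0),
# her per-model softmax confidences are averaged, then boosted by 1.8 and capped.
print(ensemble.ensemble_prediction(model_predictions))
# -> [('alice', 1.0)]
```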
models/image_processor.py
ADDED
@@ -0,0 +1,140 @@
+import io
+import base64
+import numpy as np
+from uuid import uuid4
+from PIL import Image as PILImage
+from typing import List, Dict, Any, Tuple
+
+from models.face_recognition import EnsembleFaceRecognition, extract_faces, extract_faces_mediapipe
+from models.data_manager import DataManager
+from utils.vtt_parser import parse_vtt_offsets
+
+def get_face_predictions(face, ensemble, data_manager, results):
+    """
+    Get predictions for a single face
+
+    Parameters:
+        face: Face image array
+        ensemble: EnsembleFaceRecognition instance
+        data_manager: DataManager instance
+        results: Number of results to return
+
+    Returns:
+        List of (name, confidence) tuples
+    """
+    # Get embeddings for original and flipped images
+    embeddings_orig = ensemble.get_face_embeddings(face)
+    embeddings_flip = ensemble.get_face_embeddings(np.fliplr(face))
+
+    # Average the embeddings
+    facenet = np.mean([embeddings_orig['facenet'], embeddings_flip['facenet']], axis=0)
+    arc = np.mean([embeddings_orig['arc'], embeddings_flip['arc']], axis=0)
+
+    # Get predictions from both models
+    model_predictions = {
+        'facenet': data_manager.query_facenet_index(facenet, max(results, 50)),
+        'arc': data_manager.query_arc_index(arc, max(results, 50)),
+    }
+
+    return ensemble.ensemble_prediction(model_predictions)
+
+def image_search_performer(image, data_manager, threshold=0.5, results=3):
+    """
+    Search for a performer in an image
+
+    Parameters:
+        image: PIL Image object
+        data_manager: DataManager instance
+        threshold: Confidence threshold
+        results: Number of results to return
+
+    Returns:
+        List of performer information dictionaries
+    """
+    image_array = np.array(image)
+    ensemble = EnsembleFaceRecognition({"facenet": 1.0, "arc": 1.0})
+
+    try:
+        faces = extract_faces(image_array)
+    except ValueError:
+        raise ValueError("No faces found")
+
+    predictions = get_face_predictions(faces[0]['face'], ensemble, data_manager, results)
+    response = []
+    for name, confidence in predictions:
+        performer_info = data_manager.get_performer_info(data_manager.faces[name], confidence)
+        if performer_info:
+            response.append(performer_info)
+    return response
+
+def image_search_performers(image, data_manager, threshold=0.5, results=3):
+    """
+    Search for multiple performers in an image
+
+    Parameters:
+        image: PIL Image object
+        data_manager: DataManager instance
+        threshold: Confidence threshold
+        results: Number of results to return
+
+    Returns:
+        List of dictionaries with face image and performer information
+    """
+    image_array = np.array(image)
+    ensemble = EnsembleFaceRecognition({"facenet": 1.0, "arc": 1.0})
+
+    try:
+        faces = extract_faces(image_array)
+    except ValueError:
+        raise ValueError("No faces found")
+
+    response = []
+    for face in faces:
+        predictions = get_face_predictions(face['face'], ensemble, data_manager, results)
+
+        # Crop and encode face image
+        area = face['facial_area']
+        cimage = image.crop((area['x'], area['y'], area['x'] + area['w'], area['y'] + area['h']))
+        buf = io.BytesIO()
+        cimage.save(buf, format='JPEG')
+        im_b64 = base64.b64encode(buf.getvalue()).decode('ascii')
+
+        # Get performer information
+        performers = []
+        for name, confidence in predictions:
+            performer_info = data_manager.get_performer_info(data_manager.faces[name], confidence)
+            if performer_info:
+                performers.append(performer_info)
+
+        response.append({
+            'image': im_b64,
+            'confidence': face['confidence'],
+            'performers': performers
+        })
+    return response
+
+def find_faces_in_sprite(image, vtt_data):
+    """
+    Find faces in a sprite image using VTT data
+
+    Parameters:
+        image: PIL Image object
+        vtt_data: Base64 encoded VTT data
+
+    Returns:
+        List of dictionaries with face information
+    """
+    vtt = base64.b64decode(vtt_data.replace("data:text/vtt;base64,", ""))
+    sprite = PILImage.fromarray(image)
+
+    results = []
+    for i, (left, top, right, bottom, time_seconds) in enumerate(parse_vtt_offsets(vtt)):
+        cut_frame = sprite.crop((left, top, left + right, top + bottom))
+        faces = extract_faces_mediapipe(np.asarray(cut_frame), enforce_detection=False, align=False)
+        faces = [face for face in faces if face['confidence'] > 0.6]
+        if faces:
+            size = faces[0]['facial_area']['w'] * faces[0]['facial_area']['h']
+            data = {'id': str(uuid4()), "offset": (left, top, right, bottom), "frame": i, "time": time_seconds, 'size': size}
+            results.append(data)
+
+    return results
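Worth noting: get_face_predictions applies a cheap test-time augmentation, embedding both the face and its horizontal mirror and averaging the two vectors before querying, which makes the lookup less sensitive to pose. End to end, a search might look like this sketch (the image path is hypothetical; the data files and model weights from this commit must be in place):

```python
from PIL import Image
from models.data_manager import DataManager
from models.image_processor import image_search_performer

dm = DataManager()  # default data/ paths from this commit

img = Image.open("face.jpg")  # hypothetical local file
try:
    for match in image_search_performer(img, dm, threshold=0.5, results=3):
        print(match['name'], match['confidence'], match['performer_url'])
except ValueError:
    print("No faces found")
```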
requirements.txt
ADDED
@@ -0,0 +1,127 @@
+absl-py==2.2.2
+aiofiles==24.1.0
+annotated-types==0.7.0
+anyio==4.9.0
+astunparse==1.6.3
+beautifulsoup4==4.13.4
+blinker==1.9.0
+certifi==2025.1.31
+charset-normalizer==3.4.1
+click==8.1.8
+contourpy==1.3.2
+cycler==0.12.1
+deepface @ git+https://github.com/serengil/deepface.git@cc484b54be5188eb47faf132995af16a871d70b9
+fastapi==0.115.12
+ffmpy==0.5.0
+filelock==3.18.0
+fire==0.7.0
+flask==3.1.0
+flask-cors==5.0.1
+flatbuffers==25.2.10
+fonttools==4.57.0
+fsspec==2025.3.2
+gast==0.6.0
+gdown==5.2.0
+google-pasta==0.2.0
+gradio==5.25.2
+gradio-client==1.8.0
+groovy==0.1.2
+grpcio==1.71.0
+gunicorn==23.0.0
+h11==0.14.0
+h5py==3.13.0
+httpcore==1.0.8
+httpx==0.28.1
+huggingface-hub==0.30.2
+idna==3.10
+itsdangerous==2.2.0
+jinja2==3.1.6
+joblib==1.4.2
+keras==3.9.2
+kiwisolver==1.4.8
+libclang==18.1.1
+lz4==4.4.4
+markdown==3.8
+markdown-it-py==3.0.0
+markupsafe==3.0.2
+matplotlib==3.10.1
+mdurl==0.1.2
+ml-dtypes==0.5.1
+mpmath==1.3.0
+mtcnn==1.0.0
+namex==0.0.8
+networkx==3.4.2
+numpy==2.1.3
+nvidia-cublas-cu12==12.4.5.8
+nvidia-cuda-cupti-cu12==12.4.127
+nvidia-cuda-nvrtc-cu12==12.4.127
+nvidia-cuda-runtime-cu12==12.4.127
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-cufft-cu12==11.2.1.3
+nvidia-curand-cu12==10.3.5.147
+nvidia-cusolver-cu12==11.6.1.9
+nvidia-cusparse-cu12==12.3.1.170
+nvidia-cusparselt-cu12==0.6.2
+nvidia-nccl-cu12==2.21.5
+nvidia-nvjitlink-cu12==12.4.127
+nvidia-nvtx-cu12==12.4.127
+opencv-python==4.11.0.86
+opt-einsum==3.4.0
+optree==0.15.0
+orjson==3.10.16
+packaging==25.0
+pandas==2.2.3
+pillow==11.2.1
+protobuf==5.29.4
+psutil==7.0.0
+py-cpuinfo==9.0.0
+pycryptodomex==3.22.0
+pydantic==2.11.3
+pydantic-core==2.33.1
+pydub==0.25.1
+pygments==2.19.1
+pyparsing==3.2.3
+pysocks==1.7.1
+python-dateutil==2.9.0.post0
+python-multipart==0.0.20
+pytz==2025.2
+pyyaml==6.0.2
+pyzipper==0.3.6
+requests==2.32.3
+retina-face==0.0.17
+rich==14.0.0
+ruff==0.11.6
+safehttpx==0.1.6
+scipy==1.15.2
+seaborn==0.13.2
+semantic-version==2.10.0
+setuptools==78.1.0
+shellingham==1.5.4
+six==1.17.0
+sniffio==1.3.1
+soupsieve==2.6
+starlette==0.46.2
+sympy==1.13.1
+tensorboard==2.19.0
+tensorboard-data-server==0.7.2
+tensorflow==2.19.0
+termcolor==3.0.1
+tf-keras==2.19.0
+tomlkit==0.13.2
+torch==2.6.0
+torchvision==0.21.0
+tqdm==4.67.1
+triton==3.2.0
+typer==0.15.2
+typing-extensions==4.13.2
+typing-inspection==0.4.0
+tzdata==2025.2
+ultralytics==8.3.69
+ultralytics-thop==2.0.14
+urllib3==2.4.0
+uvicorn==0.34.2
+voyager==2.1.0
+websockets==15.0.1
+werkzeug==3.1.3
+wheel==0.45.1
+wrapt==1.17.2
tests/__init__.py
ADDED
@@ -0,0 +1,3 @@
+"""
+Test package initialization
+"""
tests/test_vtt_parser.py
ADDED
@@ -0,0 +1,85 @@
+import pytest
+from utils.vtt_parser import parse_vtt_offsets
+
+def test_parse_simple_vtt():
+    """Test parsing a simple VTT file with one timestamp and coordinates"""
+    vtt_content = """WEBVTT
+
+00:00:05.000 --> 00:00:10.000
+xywh=100,200,300,400
+"""
+    result = list(parse_vtt_offsets(vtt_content.encode('utf-8')))
+    assert len(result) == 1
+    left, top, right, bottom, time = result[0]
+    assert left == 100
+    assert top == 200
+    assert right == 300
+    assert bottom == 400
+    assert time == 5.0
+
+def test_parse_multiple_entries():
+    """Test parsing multiple timestamps and coordinates"""
+    vtt_content = """WEBVTT
+
+00:00:05.000 --> 00:00:10.000
+xywh=100,200,300,400
+
+00:01:30.500 --> 00:01:35.000
+xywh=150,250,350,450
+"""
+    result = list(parse_vtt_offsets(vtt_content.encode('utf-8')))
+    assert len(result) == 2
+
+    # First entry
+    left, top, right, bottom, time = result[0]
+    assert (left, top, right, bottom) == (100, 200, 300, 400)
+    assert time == 5.0
+
+    # Second entry
+    left, top, right, bottom, time = result[1]
+    assert (left, top, right, bottom) == (150, 250, 350, 450)
+    assert time == 90.5  # 1 minute 30.5 seconds
+
+def test_parse_empty_vtt():
+    """Test parsing an empty VTT file"""
+    vtt_content = "WEBVTT\n"
+    result = list(parse_vtt_offsets(vtt_content.encode('utf-8')))
+    assert len(result) == 0
+
+def test_parse_invalid_format():
+    """Test parsing VTT with invalid format should not yield results"""
+    vtt_content = """WEBVTT
+
+00:00:05.000 --> 00:00:10.000
+invalid_line
+"""
+    result = list(parse_vtt_offsets(vtt_content.encode('utf-8')))
+    assert len(result) == 0
+
+def test_parse_hour_timestamp():
+    """Test parsing timestamp with hours"""
+    vtt_content = """WEBVTT
+
+01:30:05.000 --> 01:30:10.000
+xywh=100,200,300,400
+"""
+    result = list(parse_vtt_offsets(vtt_content.encode('utf-8')))
+    assert len(result) == 1
+    left, top, right, bottom, time = result[0]
+    assert time == 5405.0  # 1 hour + 30 minutes + 5 seconds
+
+def test_parse_missing_coordinates():
+    """Test that entries without coordinates are skipped"""
+    vtt_content = """WEBVTT
+
+00:00:05.000 --> 00:00:10.000
+Some text content
+
+00:00:10.000 --> 00:00:15.000
+xywh=100,200,300,400
+"""
+    result = list(parse_vtt_offsets(vtt_content.encode('utf-8')))
+    assert len(result) == 1
+    left, top, right, bottom, time = result[0]
+    assert time == 10.0
+    assert (left, top, right, bottom) == (100, 200, 300, 400)
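pytest is not pinned in requirements.txt, so it has to come from a dev environment; with it installed, the suite runs from the repo root with `pytest tests/ -q`, or programmatically:

```python
import pytest

# Equivalent to `pytest -q tests/test_vtt_parser.py` from the repo root.
raise SystemExit(pytest.main(["-q", "tests/test_vtt_parser.py"]))
```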
utils/__init__.py
ADDED
@@ -0,0 +1 @@
+# utils package
utils/vtt_parser.py
ADDED
@@ -0,0 +1,44 @@
+from typing import List, Tuple, Generator
+
+def parse_vtt_offsets(vtt_content: bytes) -> Generator[Tuple[int, int, int, int, float], None, None]:
+    """
+    Parse VTT file content and extract offsets and timestamps.
+
+    Parameters:
+        vtt_content: Raw VTT file content as bytes
+
+    Returns:
+        Generator yielding tuples of (left, top, right, bottom, time_seconds)
+    """
+    time_seconds = 0
+    left = top = right = bottom = None
+
+    for line in vtt_content.decode("utf-8").split("\n"):
+        line = line.strip()
+
+        if "-->" in line:
+            # grab the start time
+            # 00:00:00.000 --> 00:00:41.000
+            start = line.split("-->")[0].strip().split(":")
+            # convert to seconds
+            time_seconds = (
+                int(start[0]) * 3600
+                + int(start[1]) * 60
+                + float(start[2])
+            )
+            left = top = right = bottom = None
+        elif "xywh=" in line:
+            left, top, right, bottom = line.split("xywh=")[-1].split(",")
+            left, top, right, bottom = (
+                int(left),
+                int(top),
+                int(right),
+                int(bottom),
+            )
+        else:
+            continue
+
+        if left is None:
+            continue
+
+        yield left, top, right, bottom, time_seconds
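One caveat the names hide: the xywh values in a sprite VTT are x, y, width, height, so the right and bottom slots of each yielded tuple actually hold a width and height; that is why find_faces_in_sprite crops to (left, top, left + right, top + bottom). A self-contained run:

```python
from utils.vtt_parser import parse_vtt_offsets

vtt = b"""WEBVTT

00:01:30.500 --> 00:01:35.000
xywh=150,250,350,450
"""

for left, top, width, height, t in parse_vtt_offsets(vtt):
    print(left, top, width, height, t)  # 150 250 350 450 90.5
```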
web/__init__.py
ADDED
@@ -0,0 +1 @@
+# web package
web/interface.py
ADDED
@@ -0,0 +1,174 @@
+import gradio as gr
+from typing import Dict, Any
+
+from models.data_manager import DataManager
+from models.image_processor import (
+    image_search_performer,
+    image_search_performers,
+    find_faces_in_sprite
+)
+
+class WebInterface:
+    def __init__(self, data_manager: DataManager, default_threshold: float = 0.5):
+        """
+        Initialize the web interface.
+
+        Parameters:
+            data_manager: DataManager instance
+            default_threshold: Default confidence threshold
+        """
+        self.data_manager = data_manager
+        self.default_threshold = default_threshold
+
+    def image_search(self, img, threshold, results):
+        """Wrapper for the image search function"""
+        return image_search_performer(img, self.data_manager, threshold, results)
+
+    def multiple_image_search(self, img, threshold, results):
+        """Wrapper for the multiple image search function"""
+        return image_search_performers(img, self.data_manager, threshold, results)
+
+    def vector_search(self, vector_json, threshold, results):
+        """Wrapper for the vector search function (deprecated)"""
+        return {'status': 'not implemented'}
+
+    def _create_image_search_interface(self):
+        """Create the single face search interface"""
+        with gr.Blocks() as interface:
+            gr.Markdown("# Who is in the photo?")
+            gr.Markdown("Upload an image of a person and we'll tell you who they are.")
+
+            with gr.Row():
+                with gr.Column():
+                    img_input = gr.Image()
+                    threshold = gr.Slider(
+                        label="threshold",
+                        minimum=0.0,
+                        maximum=1.0,
+                        value=self.default_threshold
+                    )
+                    results_count = gr.Slider(
+                        label="results",
+                        minimum=0,
+                        maximum=50,
+                        value=3,
+                        step=1
+                    )
+                    search_btn = gr.Button("Search")
+
+                with gr.Column():
+                    output = gr.JSON(label="Results")
+
+            search_btn.click(
+                fn=self.image_search,
+                inputs=[img_input, threshold, results_count],
+                outputs=output
+            )
+
+        return interface
+
+    def _create_multiple_image_search_interface(self):
+        """Create the multiple face search interface"""
+        with gr.Blocks() as interface:
+            gr.Markdown("# Who is in the photo?")
+            gr.Markdown("Upload an image of one or more people and we'll tell you who they are.")
+
+            with gr.Row():
+                with gr.Column():
+                    img_input = gr.Image(type="pil")
+                    threshold = gr.Slider(
+                        label="threshold",
+                        minimum=0.0,
+                        maximum=1.0,
+                        value=self.default_threshold
+                    )
+                    results_count = gr.Slider(
+                        label="results",
+                        minimum=0,
+                        maximum=50,
+                        value=3,
+                        step=1
+                    )
+                    search_btn = gr.Button("Search")
+
+                with gr.Column():
+                    output = gr.JSON(label="Results")
+
+            search_btn.click(
+                fn=self.multiple_image_search,
+                inputs=[img_input, threshold, results_count],
+                outputs=output
+            )
+
+        return interface
+
+    def _create_vector_search_interface(self):
+        """Create the vector search interface (deprecated)"""
+        with gr.Blocks() as interface:
+            gr.Markdown("# Vector Search (deprecated)")
+
+            with gr.Row():
+                with gr.Column():
+                    vector_input = gr.Textbox()
+                    threshold = gr.Slider(
+                        label="threshold",
+                        minimum=0.0,
+                        maximum=1.0,
+                        value=self.default_threshold
+                    )
+                    results_count = gr.Slider(
+                        label="results",
+                        minimum=0,
+                        maximum=50,
+                        value=3,
+                        step=1
+                    )
+                    search_btn = gr.Button("Search")
+
+                with gr.Column():
+                    output = gr.JSON(label="Results")
+
+            search_btn.click(
+                fn=self.vector_search,
+                inputs=[vector_input, threshold, results_count],
+                outputs=output
+            )
+
+        return interface
+
+    def _create_faces_in_sprite_interface(self):
+        """Create the faces in sprite interface"""
+        with gr.Blocks() as interface:
+            gr.Markdown("# Find Faces in Sprite")
+
+            with gr.Row():
+                with gr.Column():
+                    img_input = gr.Image()
+                    vtt_input = gr.Textbox(label="VTT file")
+                    search_btn = gr.Button("Process")
+
+                with gr.Column():
+                    output = gr.JSON(label="Results")
+
+            search_btn.click(
+                fn=find_faces_in_sprite,
+                inputs=[img_input, vtt_input],
+                outputs=output
+            )
+
+        return interface
+
+    def launch(self, server_name="0.0.0.0", server_port=7860, share=True):
+        """Launch the web interface"""
+        with gr.Blocks() as demo:
+            with gr.Tabs() as tabs:
+                with gr.TabItem("Single Face Search"):
+                    self._create_image_search_interface()
+                with gr.TabItem("Multiple Face Search"):
+                    self._create_multiple_image_search_interface()
+                with gr.TabItem("Vector Search"):
+                    self._create_vector_search_interface()
+                with gr.TabItem("Faces in Sprite"):
+                    self._create_faces_in_sprite_interface()
+
+        demo.queue().launch(server_name=server_name, server_port=server_port, share=share, ssr_mode=False)