init

- .deepface/weights/arcface_weights.h5 +3 -0
- .deepface/weights/facenet512_weights.h5 +3 -0
- .deepface/weights/yolov8n-face.pt +3 -0
- .gitattributes +13 -0
- .gitignore +4 -0
- README.md +2 -2
- app.py +25 -0
- data/faces.json +0 -0
- data/persons.zip +3 -0
- models/__init__.py +1 -0
- models/data_manager.py +105 -0
- models/face_recognition.py +102 -0
- models/image_processor.py +140 -0
- requirements.txt +127 -0
- tests/__init__.py +3 -0
- tests/test_vtt_parser.py +85 -0
- utils/__init__.py +1 -0
- utils/vtt_parser.py +44 -0
- web/__init__.py +1 -0
- web/interface.py +174 -0

.deepface/weights/arcface_weights.h5
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6336979c0c602cae08d1122a66f4dfb862d059bbcd8ef80306aef2b2249b0c93
+size 137026640

.deepface/weights/facenet512_weights.h5
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f76b5117a9ca574d536af8199e6720089eb4ad3dc7e93534496d88265de864f
+size 94955648

.deepface/weights/yolov8n-face.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d545bf1add5aa736a4febac4f4f9245a6d596cd0fe70d5d57989fe0cb9e626ca
+size 6389512

.gitattributes
CHANGED
@@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.db filter=lfs diff=lfs merge=lfs -text
+face.json filter=lfs diff=lfs merge=lfs -text
+.deepface/weights/yolov8n-face.pt filter=lfs diff=lfs merge=lfs -text
+.deepface/weights/face_recognition_sface_2021dec.onnx filter=lfs diff=lfs merge=lfs -text
+.deepface/weights/res10_300x300_ssd_iter_140000.caffemodel filter=lfs diff=lfs merge=lfs -text
+.deepface/weights/centerface.onnx filter=lfs diff=lfs merge=lfs -text
+.deepface/weights/deploy.prototxt filter=lfs diff=lfs merge=lfs -text
+.deepface/weights/facenet512_weights.h5 filter=lfs diff=lfs merge=lfs -text
+.deepface/weights/retinaface.h5 filter=lfs diff=lfs merge=lfs -text
+.deepface/weights/face_detection_yunet_2023mar.onnx filter=lfs diff=lfs merge=lfs -text
+.deepface/weights/arcface_weights.h5 filter=lfs diff=lfs merge=lfs -text
+face_arc.voy filter=lfs diff=lfs merge=lfs -text
+face_facenet.voy filter=lfs diff=lfs merge=lfs -text

.gitignore
ADDED
@@ -0,0 +1,4 @@
+venv
+flagged
+temp.jpg
+__pycache__

README.md
CHANGED
@@ -1,7 +1,7 @@
 ---
 title: Stashface
-emoji:
-colorFrom:
+emoji: 👀
+colorFrom: indigo
 colorTo: purple
 sdk: gradio
 sdk_version: 5.25.2

app.py
ADDED
@@ -0,0 +1,25 @@
+import os
+
+# Set DeepFace home directory and force CPU-only inference
+os.environ["DEEPFACE_HOME"] = "."
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+
+from models.data_manager import DataManager
+from web.interface import WebInterface
+
+def main():
+    """Main entry point for the application"""
+    # Initialize data manager
+    data_manager = DataManager(
+        faces_path="data/faces.json",
+        performers_zip="data/persons.zip",
+        facenet_index_path="data/face_facenet.voy",
+        arc_index_path="data/face_arc.voy"
+    )
+
+    # Initialize and launch web interface
+    web_interface = WebInterface(data_manager, default_threshold=0.5)
+    web_interface.launch(server_name="0.0.0.0", server_port=7860, share=False)
+
+if __name__ == "__main__":
+    main()
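
Note: DataManager reads the VISAGE_KEY environment variable (see models/data_manager.py below) to decrypt data/persons.zip, so the Space needs that secret configured; without it, decryption fails and the performer database falls back to empty.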

data/faces.json
ADDED
(The diff for this file is too large to render.)

data/persons.zip
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c9eb10173cddf6a4bd0339218eb2cd9b3621e29d7d111dbcb41b8062ee43e8ed
+size 5031776

models/__init__.py
ADDED
@@ -0,0 +1 @@
+# models package

models/data_manager.py
ADDED
@@ -0,0 +1,105 @@
+import os
+import json
+import pyzipper
+from typing import Dict, Any, Optional
+from voyager import Index, Space, StorageDataType
+
+class DataManager:
+    def __init__(self, faces_path: str = "data/faces.json",
+                 performers_zip: str = "data/persons.zip",
+                 facenet_index_path: str = "data/face_facenet.voy",
+                 arc_index_path: str = "data/face_arc.voy"):
+        """
+        Initialize the data manager.
+
+        Parameters:
+            faces_path: Path to the faces.json file
+            performers_zip: Path to the performers zip file
+            facenet_index_path: Path to the Facenet index file
+            arc_index_path: Path to the ArcFace index file
+        """
+        self.faces_path = faces_path
+        self.performers_zip = performers_zip
+        self.facenet_index_path = facenet_index_path
+        self.arc_index_path = arc_index_path
+
+        # Initialize indices
+        self.index_arc = Index(Space.Cosine, num_dimensions=512, storage_data_type=StorageDataType.E4M3)
+        self.index_facenet = Index(Space.Cosine, num_dimensions=512, storage_data_type=StorageDataType.E4M3)
+
+        # Load data
+        self.faces = {}
+        self.performer_db = {}
+        self.load_data()
+
+    def load_data(self):
+        """Load all data from files"""
+        self._load_faces()
+        self._load_performer_db()
+        self._load_indices()
+
+    def _load_faces(self):
+        """Load faces from JSON file"""
+        try:
+            with open(self.faces_path, 'r') as f:
+                self.faces = json.load(f)
+        except Exception as e:
+            print(f"Error loading faces: {e}")
+            self.faces = {}
+
+    def _load_performer_db(self):
+        """Load performer database from encrypted zip file"""
+        try:
+            with pyzipper.AESZipFile(self.performers_zip) as zf:
+                password = os.getenv("VISAGE_KEY", "").encode('ascii')
+                zf.setpassword(password)
+                self.performer_db = json.loads(zf.read('performers.json'))
+        except Exception as e:
+            print(f"Error loading performer database: {e}")
+            self.performer_db = {}
+
+    def _load_indices(self):
+        """Load face recognition indices"""
+        try:
+            with open(self.arc_index_path, 'rb') as f:
+                self.index_arc = self.index_arc.load(f)
+
+            with open(self.facenet_index_path, 'rb') as f:
+                self.index_facenet = self.index_facenet.load(f)
+        except Exception as e:
+            print(f"Error loading indices: {e}")
+
+    def get_performer_info(self, stash_id: str, confidence: float) -> Optional[Dict[str, Any]]:
+        """
+        Get performer information from the database.
+
+        Parameters:
+            stash_id: Stash ID of the performer
+            confidence: Confidence score (0-1)
+
+        Returns:
+            Dictionary with performer information, or None if not found
+        """
+        performer = self.performer_db.get(stash_id, [])
+        if not performer:
+            return None
+
+        confidence_int = int(confidence * 100)
+        return {
+            'id': stash_id,
+            'name': performer['name'],
+            'confidence': confidence_int,
+            'image': performer['image'],
+            'country': performer['country'],
+            'hits': 1,
+            'distance': confidence_int,
+            'performer_url': f"https://stashdb.org/performers/{stash_id}"
+        }
+
+    def query_facenet_index(self, embedding, limit):
+        """Query the Facenet index with an embedding"""
+        return self.index_facenet.query(embedding, limit)
+
+    def query_arc_index(self, embedding, limit):
+        """Query the ArcFace index with an embedding"""
+        return self.index_arc.query(embedding, limit)
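
The commit ships the prebuilt face_arc.voy and face_facenet.voy index files but no build script. For reference, a minimal sketch of how an index in the expected format could be produced with voyager; this builder is hypothetical (not part of the commit), and it assumes 512-d embeddings whose insertion order lines up with the entries in faces.json:

    import numpy as np
    from voyager import Index, Space, StorageDataType

    def build_index(embeddings: np.ndarray, output_path: str) -> Index:
        """Build and save a cosine index matching DataManager's settings."""
        index = Index(Space.Cosine, num_dimensions=512,
                      storage_data_type=StorageDataType.E4M3)
        # IDs default to 0..n-1; the app maps them back through faces.json
        index.add_items(embeddings)
        index.save(output_path)
        return index

    # Example with random vectors standing in for real face embeddings
    build_index(np.random.rand(1000, 512).astype(np.float32), "data/face_facenet.voy")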

models/face_recognition.py
ADDED
@@ -0,0 +1,102 @@
+import os
+import numpy as np
+from typing import Dict, List, Tuple
+
+from deepface import DeepFace
+
+class EnsembleFaceRecognition:
+    def __init__(self, model_weights: Dict[str, float] = None):
+        """
+        Initialize ensemble face recognition system.
+
+        Parameters:
+            model_weights: Dictionary mapping model names to their weights.
+                           If None, all models are weighted equally.
+        """
+        self.model_weights = model_weights or {}
+        self.boost_factor = 1.8
+
+    def normalize_distances(self, distances: np.ndarray) -> np.ndarray:
+        """Normalize distances to the [0, 1] range within each model's predictions"""
+        min_dist = np.min(distances)
+        max_dist = np.max(distances)
+        if max_dist == min_dist:
+            return np.zeros_like(distances)
+        return (distances - min_dist) / (max_dist - min_dist)
+
+    def compute_model_confidence(self,
+                                 distances: np.ndarray,
+                                 temperature: float = 0.1) -> np.ndarray:
+        """Convert distances to confidence scores for a single model (softmax over negated normalized distances)"""
+        normalized_distances = self.normalize_distances(distances)
+        exp_distances = np.exp(-normalized_distances / temperature)
+        return exp_distances / np.sum(exp_distances)
+
+    def get_face_embeddings(self, image: np.ndarray) -> Dict[str, np.ndarray]:
+        """Get face embeddings for each model"""
+        return {
+            'facenet': DeepFace.represent(img_path=image, detector_backend='skip', model_name='Facenet512', normalization='Facenet2018', align=True)[0]['embedding'],
+            'arc': DeepFace.represent(img_path=image, detector_backend='skip', model_name='ArcFace', align=True)[0]['embedding']
+        }
+
+    def ensemble_prediction(self,
+                            model_predictions: Dict[str, Tuple[List[str], List[float]]],
+                            temperature: float = 0.1,
+                            min_agreement: float = 0.5) -> List[Tuple[str, float]]:
+        """
+        Combine predictions from multiple models.
+
+        Parameters:
+            model_predictions: Dictionary mapping model names to their (names, distances) predictions
+            temperature: Temperature parameter for softmax scaling
+            min_agreement: Minimum agreement threshold between models
+
+        Returns:
+            final_predictions: List of (name, confidence) tuples
+        """
+        # Initialize vote counting
+        vote_dict = {}
+        confidence_dict = {}
+
+        # Process each model's predictions
+        for model_name, (names, distances) in model_predictions.items():
+            # Get model weight (default to 1.0 if not specified)
+            model_weight = self.model_weights.get(model_name, 1.0)
+
+            # Compute confidence scores for this model
+            confidences = self.compute_model_confidence(np.array(distances), temperature)
+
+            # Add weighted votes for top prediction
+            top_name = names[0]
+            top_confidence = confidences[0]
+
+            vote_dict[top_name] = vote_dict.get(top_name, 0) + model_weight
+            confidence_dict[top_name] = confidence_dict.get(top_name, [])
+            confidence_dict[top_name].append(top_confidence)
+
+        # Normalize votes
+        total_weight = sum(self.model_weights.values()) if self.model_weights else len(model_predictions)
+
+        # Compute final results with minimum agreement check
+        final_results = []
+        for name, votes in vote_dict.items():
+            normalized_votes = votes / total_weight
+            # Only include results that meet the minimum agreement threshold
+            if normalized_votes >= min_agreement:
+                avg_confidence = np.mean(confidence_dict[name])
+                final_score = normalized_votes * avg_confidence * self.boost_factor
+                final_score = min(final_score, 1.0)  # Cap at 1.0
+                final_results.append((name, final_score))
+
+        # Sort by final score
+        final_results.sort(key=lambda x: x[1], reverse=True)
+        return final_results
+
+def extract_faces(image):
+    """Extract faces from an image using the YOLOv8 detector backend"""
+    return DeepFace.extract_faces(image, detector_backend="yolov8")
+
+def extract_faces_mediapipe(image, enforce_detection=False, align=False):
+    """Extract faces from an image using the MediaPipe backend"""
+    return DeepFace.extract_faces(image, detector_backend="mediapipe",
+                                  enforce_detection=enforce_detection,
+                                  align=align)
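
To make the scoring concrete, here is a small standalone run of the distance-to-confidence step (illustrative numbers): with the default temperature of 0.1, min-max normalization followed by a softmax over negated distances concentrates nearly all confidence on the closest match, which is what ensemble_prediction then votes with.

    import numpy as np

    distances = np.array([0.30, 0.55, 0.80])  # nearest-neighbour distances, smaller = closer
    norm = (distances - distances.min()) / (distances.max() - distances.min())  # [0.0, 0.5, 1.0]
    exp = np.exp(-norm / 0.1)                 # softmax over negated distances
    conf = exp / exp.sum()
    print(conf.round(4))                      # [0.9933 0.0067 0.    ]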

models/image_processor.py
ADDED
@@ -0,0 +1,140 @@
+import io
+import base64
+import numpy as np
+from uuid import uuid4
+from PIL import Image as PILImage
+from typing import List, Dict, Any, Tuple
+
+from models.face_recognition import EnsembleFaceRecognition, extract_faces, extract_faces_mediapipe
+from models.data_manager import DataManager
+from utils.vtt_parser import parse_vtt_offsets
+
+def get_face_predictions(face, ensemble, data_manager, results):
+    """
+    Get predictions for a single face.
+
+    Parameters:
+        face: Face image array
+        ensemble: EnsembleFaceRecognition instance
+        data_manager: DataManager instance
+        results: Number of results to return
+
+    Returns:
+        List of (name, confidence) tuples
+    """
+    # Get embeddings for the original and horizontally flipped images
+    embeddings_orig = ensemble.get_face_embeddings(face)
+    embeddings_flip = ensemble.get_face_embeddings(np.fliplr(face))
+
+    # Average the embeddings
+    facenet = np.mean([embeddings_orig['facenet'], embeddings_flip['facenet']], axis=0)
+    arc = np.mean([embeddings_orig['arc'], embeddings_flip['arc']], axis=0)
+
+    # Get predictions from both models
+    model_predictions = {
+        'facenet': data_manager.query_facenet_index(facenet, max(results, 50)),
+        'arc': data_manager.query_arc_index(arc, max(results, 50)),
+    }
+
+    return ensemble.ensemble_prediction(model_predictions)
+
+def image_search_performer(image, data_manager, threshold=0.5, results=3):
+    """
+    Search for a performer in an image.
+
+    Parameters:
+        image: PIL Image object
+        data_manager: DataManager instance
+        threshold: Confidence threshold
+        results: Number of results to return
+
+    Returns:
+        List of performer information dictionaries
+    """
+    image_array = np.array(image)
+    ensemble = EnsembleFaceRecognition({"facenet": 1.0, "arc": 1.0})
+
+    try:
+        faces = extract_faces(image_array)
+    except ValueError:
+        raise ValueError("No faces found")
+
+    predictions = get_face_predictions(faces[0]['face'], ensemble, data_manager, results)
+    response = []
+    for name, confidence in predictions:
+        performer_info = data_manager.get_performer_info(data_manager.faces[name], confidence)
+        if performer_info:
+            response.append(performer_info)
+    return response
+
+def image_search_performers(image, data_manager, threshold=0.5, results=3):
+    """
+    Search for multiple performers in an image.
+
+    Parameters:
+        image: PIL Image object
+        data_manager: DataManager instance
+        threshold: Confidence threshold
+        results: Number of results to return
+
+    Returns:
+        List of dictionaries with face image and performer information
+    """
+    image_array = np.array(image)
+    ensemble = EnsembleFaceRecognition({"facenet": 1.0, "arc": 1.0})
+
+    try:
+        faces = extract_faces(image_array)
+    except ValueError:
+        raise ValueError("No faces found")
+
+    response = []
+    for face in faces:
+        predictions = get_face_predictions(face['face'], ensemble, data_manager, results)
+
+        # Crop the detected facial area (x, y, w, h) and encode it as base64 JPEG
+        area = face['facial_area']
+        cimage = image.crop((area['x'], area['y'], area['x'] + area['w'], area['y'] + area['h']))
+        buf = io.BytesIO()
+        cimage.save(buf, format='JPEG')
+        im_b64 = base64.b64encode(buf.getvalue()).decode('ascii')
+
+        # Get performer information
+        performers = []
+        for name, confidence in predictions:
+            performer_info = data_manager.get_performer_info(data_manager.faces[name], confidence)
+            if performer_info:
+                performers.append(performer_info)
+
+        response.append({
+            'image': im_b64,
+            'confidence': face['confidence'],
+            'performers': performers
+        })
+    return response
+
+def find_faces_in_sprite(image, vtt_data):
+    """
+    Find faces in a sprite image using VTT data.
+
+    Parameters:
+        image: Sprite image as a numpy array
+        vtt_data: Base64-encoded VTT data
+
+    Returns:
+        List of dictionaries with face information
+    """
+    vtt = base64.b64decode(vtt_data.replace("data:text/vtt;base64,", ""))
+    sprite = PILImage.fromarray(image)
+
+    results = []
+    for i, (left, top, right, bottom, time_seconds) in enumerate(parse_vtt_offsets(vtt)):
+        # "right" and "bottom" hold the width and height from the xywh fragment
+        cut_frame = sprite.crop((left, top, left + right, top + bottom))
+        faces = extract_faces_mediapipe(np.asarray(cut_frame), enforce_detection=False, align=False)
+        faces = [face for face in faces if face['confidence'] > 0.6]
+        if faces:
+            size = faces[0]['facial_area']['w'] * faces[0]['facial_area']['h']
+            data = {'id': str(uuid4()), "offset": (left, top, right, bottom), "frame": i, "time": time_seconds, 'size': size}
+            results.append(data)
+
+    return results
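
Averaging each embedding with that of the horizontally flipped crop (get_face_predictions above) is a cheap test-time augmentation. A minimal usage sketch of this module, assuming the data files from this commit are present and VISAGE_KEY is set; the image path is a placeholder:

    from PIL import Image
    from models.data_manager import DataManager
    from models.image_processor import image_search_performers

    dm = DataManager()  # default data/ paths from this commit
    img = Image.open("group_photo.jpg")  # placeholder path

    for face in image_search_performers(img, dm, threshold=0.5, results=3):
        print(face['confidence'], [p['name'] for p in face['performers']])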

requirements.txt
ADDED
@@ -0,0 +1,127 @@
+absl-py==2.2.2
+aiofiles==24.1.0
+annotated-types==0.7.0
+anyio==4.9.0
+astunparse==1.6.3
+beautifulsoup4==4.13.4
+blinker==1.9.0
+certifi==2025.1.31
+charset-normalizer==3.4.1
+click==8.1.8
+contourpy==1.3.2
+cycler==0.12.1
+deepface @ git+https://github.com/serengil/deepface.git@cc484b54be5188eb47faf132995af16a871d70b9
+fastapi==0.115.12
+ffmpy==0.5.0
+filelock==3.18.0
+fire==0.7.0
+flask==3.1.0
+flask-cors==5.0.1
+flatbuffers==25.2.10
+fonttools==4.57.0
+fsspec==2025.3.2
+gast==0.6.0
+gdown==5.2.0
+google-pasta==0.2.0
+gradio==5.25.2
+gradio-client==1.8.0
+groovy==0.1.2
+grpcio==1.71.0
+gunicorn==23.0.0
+h11==0.14.0
+h5py==3.13.0
+httpcore==1.0.8
+httpx==0.28.1
+huggingface-hub==0.30.2
+idna==3.10
+itsdangerous==2.2.0
+jinja2==3.1.6
+joblib==1.4.2
+keras==3.9.2
+kiwisolver==1.4.8
+libclang==18.1.1
+lz4==4.4.4
+markdown==3.8
+markdown-it-py==3.0.0
+markupsafe==3.0.2
+matplotlib==3.10.1
+mdurl==0.1.2
+ml-dtypes==0.5.1
+mpmath==1.3.0
+mtcnn==1.0.0
+namex==0.0.8
+networkx==3.4.2
+numpy==2.1.3
+nvidia-cublas-cu12==12.4.5.8
+nvidia-cuda-cupti-cu12==12.4.127
+nvidia-cuda-nvrtc-cu12==12.4.127
+nvidia-cuda-runtime-cu12==12.4.127
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-cufft-cu12==11.2.1.3
+nvidia-curand-cu12==10.3.5.147
+nvidia-cusolver-cu12==11.6.1.9
+nvidia-cusparse-cu12==12.3.1.170
+nvidia-cusparselt-cu12==0.6.2
+nvidia-nccl-cu12==2.21.5
+nvidia-nvjitlink-cu12==12.4.127
+nvidia-nvtx-cu12==12.4.127
+opencv-python==4.11.0.86
+opt-einsum==3.4.0
+optree==0.15.0
+orjson==3.10.16
+packaging==25.0
+pandas==2.2.3
+pillow==11.2.1
+protobuf==5.29.4
+psutil==7.0.0
+py-cpuinfo==9.0.0
+pycryptodomex==3.22.0
+pydantic==2.11.3
+pydantic-core==2.33.1
+pydub==0.25.1
+pygments==2.19.1
+pyparsing==3.2.3
+pysocks==1.7.1
+python-dateutil==2.9.0.post0
+python-multipart==0.0.20
+pytz==2025.2
+pyyaml==6.0.2
+pyzipper==0.3.6
+requests==2.32.3
+retina-face==0.0.17
+rich==14.0.0
+ruff==0.11.6
+safehttpx==0.1.6
+scipy==1.15.2
+seaborn==0.13.2
+semantic-version==2.10.0
+setuptools==78.1.0
+shellingham==1.5.4
+six==1.17.0
+sniffio==1.3.1
+soupsieve==2.6
+starlette==0.46.2
+sympy==1.13.1
+tensorboard==2.19.0
+tensorboard-data-server==0.7.2
+tensorflow==2.19.0
+termcolor==3.0.1
+tf-keras==2.19.0
+tomlkit==0.13.2
+torch==2.6.0
+torchvision==0.21.0
+tqdm==4.67.1
+triton==3.2.0
+typer==0.15.2
+typing-extensions==4.13.2
+typing-inspection==0.4.0
+tzdata==2025.2
+ultralytics==8.3.69
+ultralytics-thop==2.0.14
+urllib3==2.4.0
+uvicorn==0.34.2
+voyager==2.1.0
+websockets==15.0.1
+werkzeug==3.1.3
+wheel==0.45.1
+wrapt==1.17.2

tests/__init__.py
ADDED
@@ -0,0 +1,3 @@
+"""
+Test package initialization
+"""

tests/test_vtt_parser.py
ADDED
@@ -0,0 +1,85 @@
+import pytest
+from utils.vtt_parser import parse_vtt_offsets
+
+def test_parse_simple_vtt():
+    """Test parsing a simple VTT file with one timestamp and coordinates"""
+    vtt_content = """WEBVTT
+
+00:00:05.000 --> 00:00:10.000
+xywh=100,200,300,400
+"""
+    result = list(parse_vtt_offsets(vtt_content.encode('utf-8')))
+    assert len(result) == 1
+    left, top, right, bottom, time = result[0]
+    assert left == 100
+    assert top == 200
+    assert right == 300
+    assert bottom == 400
+    assert time == 5.0
+
+def test_parse_multiple_entries():
+    """Test parsing multiple timestamps and coordinates"""
+    vtt_content = """WEBVTT
+
+00:00:05.000 --> 00:00:10.000
+xywh=100,200,300,400
+
+00:01:30.500 --> 00:01:35.000
+xywh=150,250,350,450
+"""
+    result = list(parse_vtt_offsets(vtt_content.encode('utf-8')))
+    assert len(result) == 2
+
+    # First entry
+    left, top, right, bottom, time = result[0]
+    assert (left, top, right, bottom) == (100, 200, 300, 400)
+    assert time == 5.0
+
+    # Second entry
+    left, top, right, bottom, time = result[1]
+    assert (left, top, right, bottom) == (150, 250, 350, 450)
+    assert time == 90.5  # 1 minute 30.5 seconds
+
+def test_parse_empty_vtt():
+    """Test parsing an empty VTT file"""
+    vtt_content = "WEBVTT\n"
+    result = list(parse_vtt_offsets(vtt_content.encode('utf-8')))
+    assert len(result) == 0
+
+def test_parse_invalid_format():
+    """Test that parsing a VTT with an invalid format yields no results"""
+    vtt_content = """WEBVTT
+
+00:00:05.000 --> 00:00:10.000
+invalid_line
+"""
+    result = list(parse_vtt_offsets(vtt_content.encode('utf-8')))
+    assert len(result) == 0
+
+def test_parse_hour_timestamp():
+    """Test parsing a timestamp with hours"""
+    vtt_content = """WEBVTT
+
+01:30:05.000 --> 01:30:10.000
+xywh=100,200,300,400
+"""
+    result = list(parse_vtt_offsets(vtt_content.encode('utf-8')))
+    assert len(result) == 1
+    left, top, right, bottom, time = result[0]
+    assert time == 5405.0  # 1 hour + 30 minutes + 5 seconds
+
+def test_parse_missing_coordinates():
+    """Test that entries without coordinates are skipped"""
+    vtt_content = """WEBVTT
+
+00:00:05.000 --> 00:00:10.000
+Some text content
+
+00:00:10.000 --> 00:00:15.000
+xywh=100,200,300,400
+"""
+    result = list(parse_vtt_offsets(vtt_content.encode('utf-8')))
+    assert len(result) == 1
+    left, top, right, bottom, time = result[0]
+    assert time == 10.0
+    assert (left, top, right, bottom) == (100, 200, 300, 400)

utils/__init__.py
ADDED
@@ -0,0 +1 @@
+# utils package

utils/vtt_parser.py
ADDED
@@ -0,0 +1,44 @@
+from typing import List, Tuple, Generator
+
+def parse_vtt_offsets(vtt_content: bytes) -> Generator[Tuple[int, int, int, int, float], None, None]:
+    """
+    Parse VTT file content and extract sprite offsets and timestamps.
+
+    Parameters:
+        vtt_content: Raw VTT file content as bytes
+
+    Returns:
+        Generator yielding tuples of (left, top, right, bottom, time_seconds),
+        where "right" and "bottom" are the width and height from the xywh fragment
+    """
+    time_seconds = 0
+    left = top = right = bottom = None
+
+    for line in vtt_content.decode("utf-8").split("\n"):
+        line = line.strip()
+
+        if "-->" in line:
+            # Grab the start time, e.g. "00:00:00.000 --> 00:00:41.000"
+            start = line.split("-->")[0].strip().split(":")
+            # Convert to seconds
+            time_seconds = (
+                int(start[0]) * 3600
+                + int(start[1]) * 60
+                + float(start[2])
+            )
+            left = top = right = bottom = None
+        elif "xywh=" in line:
+            left, top, right, bottom = line.split("xywh=")[-1].split(",")
+            left, top, right, bottom = (
+                int(left),
+                int(top),
+                int(right),
+                int(bottom),
+            )
+        else:
+            continue
+
+        # Skip cues whose coordinates have not been seen yet; compare against
+        # None rather than truthiness so a cue at x=0 is not dropped
+        if left is None:
+            continue
+
+        yield left, top, right, bottom, time_seconds
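
For context, the input here is the thumbnail-sprite VTT format, where each cue carries an xywh media fragment. A quick illustrative run (cue values made up); note the first cue starts at x=0, which is why the parser compares against None rather than truthiness:

    from utils.vtt_parser import parse_vtt_offsets

    vtt = b"""WEBVTT

    00:00:00.000 --> 00:00:41.000
    sprite.jpg#xywh=0,0,160,90

    00:00:41.000 --> 00:01:22.000
    sprite.jpg#xywh=160,0,160,90
    """

    for left, top, w, h, t in parse_vtt_offsets(vtt):
        print(t, (left, top, w, h))
    # 0.0 (0, 0, 160, 90)
    # 41.0 (160, 0, 160, 90)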

web/__init__.py
ADDED
@@ -0,0 +1 @@
+# web package

web/interface.py
ADDED
@@ -0,0 +1,174 @@
+import gradio as gr
+from typing import Dict, Any
+
+from models.data_manager import DataManager
+from models.image_processor import (
+    image_search_performer,
+    image_search_performers,
+    find_faces_in_sprite
+)
+
+class WebInterface:
+    def __init__(self, data_manager: DataManager, default_threshold: float = 0.5):
+        """
+        Initialize the web interface.
+
+        Parameters:
+            data_manager: DataManager instance
+            default_threshold: Default confidence threshold
+        """
+        self.data_manager = data_manager
+        self.default_threshold = default_threshold
+
+    def image_search(self, img, threshold, results):
+        """Wrapper for the image search function"""
+        return image_search_performer(img, self.data_manager, threshold, results)
+
+    def multiple_image_search(self, img, threshold, results):
+        """Wrapper for the multiple image search function"""
+        return image_search_performers(img, self.data_manager, threshold, results)
+
+    def vector_search(self, vector_json, threshold, results):
+        """Wrapper for the vector search function (deprecated)"""
+        return {'status': 'not implemented'}
+
+    def _create_image_search_interface(self):
+        """Create the single face search interface"""
+        with gr.Blocks() as interface:
+            gr.Markdown("# Who is in the photo?")
+            gr.Markdown("Upload an image of a person and we'll tell you who it is.")
+
+            with gr.Row():
+                with gr.Column():
+                    img_input = gr.Image()
+                    threshold = gr.Slider(
+                        label="threshold",
+                        minimum=0.0,
+                        maximum=1.0,
+                        value=self.default_threshold
+                    )
+                    results_count = gr.Slider(
+                        label="results",
+                        minimum=0,
+                        maximum=50,
+                        value=3,
+                        step=1
+                    )
+                    search_btn = gr.Button("Search")
+
+                with gr.Column():
+                    output = gr.JSON(label="Results")
+
+            search_btn.click(
+                fn=self.image_search,
+                inputs=[img_input, threshold, results_count],
+                outputs=output
+            )
+
+        return interface
+
+    def _create_multiple_image_search_interface(self):
+        """Create the multiple face search interface"""
+        with gr.Blocks() as interface:
+            gr.Markdown("# Who is in the photo?")
+            gr.Markdown("Upload an image of one or more people and we'll tell you who they are.")
+
+            with gr.Row():
+                with gr.Column():
+                    img_input = gr.Image(type="pil")
+                    threshold = gr.Slider(
+                        label="threshold",
+                        minimum=0.0,
+                        maximum=1.0,
+                        value=self.default_threshold
+                    )
+                    results_count = gr.Slider(
+                        label="results",
+                        minimum=0,
+                        maximum=50,
+                        value=3,
+                        step=1
+                    )
+                    search_btn = gr.Button("Search")
+
+                with gr.Column():
+                    output = gr.JSON(label="Results")
+
+            search_btn.click(
+                fn=self.multiple_image_search,
+                inputs=[img_input, threshold, results_count],
+                outputs=output
+            )
+
+        return interface
+
+    def _create_vector_search_interface(self):
+        """Create the vector search interface (deprecated)"""
+        with gr.Blocks() as interface:
+            gr.Markdown("# Vector Search (deprecated)")
+
+            with gr.Row():
+                with gr.Column():
+                    vector_input = gr.Textbox()
+                    threshold = gr.Slider(
+                        label="threshold",
+                        minimum=0.0,
+                        maximum=1.0,
+                        value=self.default_threshold
+                    )
+                    results_count = gr.Slider(
+                        label="results",
+                        minimum=0,
+                        maximum=50,
+                        value=3,
+                        step=1
+                    )
+                    search_btn = gr.Button("Search")
+
+                with gr.Column():
+                    output = gr.JSON(label="Results")
+
+            search_btn.click(
+                fn=self.vector_search,
+                inputs=[vector_input, threshold, results_count],
+                outputs=output
+            )
+
+        return interface
+
+    def _create_faces_in_sprite_interface(self):
+        """Create the faces-in-sprite interface"""
+        with gr.Blocks() as interface:
+            gr.Markdown("# Find Faces in Sprite")
+
+            with gr.Row():
+                with gr.Column():
+                    img_input = gr.Image()
+                    vtt_input = gr.Textbox(label="VTT file")
+                    search_btn = gr.Button("Process")
+
+                with gr.Column():
+                    output = gr.JSON(label="Results")
+
+            search_btn.click(
+                fn=find_faces_in_sprite,
+                inputs=[img_input, vtt_input],
+                outputs=output
+            )
+
+        return interface
+
+    def launch(self, server_name="0.0.0.0", server_port=7860, share=True):
+        """Launch the web interface"""
+        with gr.Blocks() as demo:
+            with gr.Tabs() as tabs:
+                with gr.TabItem("Single Face Search"):
+                    self._create_image_search_interface()
+                with gr.TabItem("Multiple Face Search"):
+                    self._create_multiple_image_search_interface()
+                with gr.TabItem("Vector Search"):
+                    self._create_vector_search_interface()
+                with gr.TabItem("Faces in Sprite"):
+                    self._create_faces_in_sprite_interface()
+
+        # Honor the server settings passed in from app.py
+        demo.queue().launch(server_name=server_name, server_port=server_port,
+                            share=share, ssr_mode=False)