Spaces:
Running
Running
import json
import logging
import os
from typing import Any, Dict, Optional

import pyzipper
from voyager import Index, Space, StorageDataType
logger = logging.getLogger(__name__)


class DataManager:
    """Load and serve the face-recognition data used by the app.

    At construction time this reads, from disk:

    * ``faces`` - the parsed contents of the faces JSON file;
    * ``performer_db`` - the parsed ``performers.json`` stored inside an
      AES-encrypted zip whose password comes from the ``VISAGE_KEY``
      environment variable;
    * two Voyager indices, ``index_arc`` and ``index_facenet``.

    Every loader is best-effort: on failure the error is logged and the
    attribute keeps an empty default (``{}`` or an empty placeholder index),
    so construction itself does not raise because of missing or bad data files.
    """

    def __init__(self, faces_path: str = "data/faces.json",
                 performers_zip: str = "data/persons.zip",
                 facenet_index_path: str = "data/face_facenet.voy",
                 arc_index_path: str = "data/face_arc.voy"):
        """
        Initialize the data manager and load all data from disk.

        Parameters:
            faces_path: Path to the faces.json file
            performers_zip: Path to the encrypted zip containing performers.json
            facenet_index_path: Path to the serialized facenet Voyager index
            arc_index_path: Path to the serialized arc Voyager index
        """
        self.faces_path = faces_path
        self.performers_zip = performers_zip
        self.facenet_index_path = facenet_index_path
        self.arc_index_path = arc_index_path

        # Empty placeholder indices (cosine space, 512 dimensions, E4M3 storage).
        # They remain in place if loading the serialized indices fails, so the
        # query methods always have an Index object to call.
        self.index_arc = Index(Space.Cosine, num_dimensions=512,
                               storage_data_type=StorageDataType.E4M3)
        self.index_facenet = Index(Space.Cosine, num_dimensions=512,
                                   storage_data_type=StorageDataType.E4M3)

        self.faces = {}
        self.performer_db: Dict[str, Any] = {}
        self.load_data()

    def load_data(self):
        """(Re)load faces, the performer database and both indices from disk."""
        self._load_faces()
        self._load_performer_db()
        self._load_indices()

    def _load_faces(self):
        """Load ``self.faces`` from the JSON file; fall back to ``{}`` on any error."""
        try:
            # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
            with open(self.faces_path, 'r', encoding='utf-8') as f:
                self.faces = json.load(f)
        except Exception as e:  # best-effort loader: log and keep an empty default
            logger.error("Error loading faces: %s", e)
            self.faces = {}

    def _load_performer_db(self):
        """Load ``self.performer_db`` from ``performers.json`` inside the encrypted zip.

        The archive password is read from the ``VISAGE_KEY`` environment
        variable. On any failure the database falls back to ``{}``.
        """
        password = os.getenv("VISAGE_KEY", "")
        if not password:
            # A missing key otherwise only shows up as an opaque decryption error below.
            logger.warning("VISAGE_KEY is not set; the performer database may not be readable")
        try:
            with pyzipper.AESZipFile(self.performers_zip) as zf:
                zf.setpassword(password.encode('ascii'))
                self.performer_db = json.loads(zf.read('performers.json'))
        except Exception as e:  # pyzipper/zipfile raise many unrelated exception types
            logger.error("Error loading performer database: %s", e)
            self.performer_db = {}

    def _load_indices(self):
        """Load the arc and facenet indices independently of each other.

        A failure loading one index no longer prevents loading the other; a
        failed index keeps its current (placeholder) value.
        """
        self.index_arc = self._load_index(self.arc_index_path, self.index_arc)
        self.index_facenet = self._load_index(self.facenet_index_path, self.index_facenet)

    @staticmethod
    def _load_index(path: str, fallback: "Index") -> "Index":
        """Return the Voyager index deserialized from *path*, or *fallback* on error."""
        try:
            with open(path, 'rb') as f:
                return Index.load(f)
        except Exception as e:  # best-effort loader: log and keep the fallback
            logger.error("Error loading index %s: %s", path, e)
            return fallback

    def get_performer_info(self, stash_id: str, confidence: float) -> Optional[Dict[str, Any]]:
        """
        Get performer information from the database.

        Parameters:
            stash_id: Stash ID of the performer
            confidence: Confidence score (0-1)

        Returns:
            Dictionary with performer information, or None if ``stash_id`` is
            not in the database (or maps to an empty record). ``confidence``
            and ``distance`` are the score as an integer percentage, rounded
            to the nearest whole number.

        Raises:
            KeyError: if the performer record lacks 'name', 'image' or 'country'.
        """
        performer = self.performer_db.get(stash_id)
        if not performer:
            return None
        # round() rather than plain int(): int(0.29 * 100) == 28 because the
        # float product is 28.999...; int() also guarantees a plain Python int
        # even if confidence is a numpy scalar.
        confidence_pct = int(round(confidence * 100))
        return {
            'id': stash_id,
            'name': performer['name'],
            'confidence': confidence_pct,
            'image': performer['image'],
            'country': performer['country'],
            'hits': 1,
            # NOTE(review): 'distance' mirrors 'confidence' - confirm this is intended.
            'distance': confidence_pct,
            'performer_url': f"https://stashdb.org/performers/{stash_id}",
        }

    def query_facenet_index(self, embedding, limit):
        """Query the facenet index with an embedding; returns ``Index.query``'s result."""
        return self.index_facenet.query(embedding, limit)

    def query_arc_index(self, embedding, limit):
        """Query the arc index with an embedding; returns ``Index.query``'s result."""
        return self.index_arc.query(embedding, limit)