# stashface/models/data_manager.py
# cc1234
# init
# 244b0b6
import os
import json
import pyzipper
from typing import Dict, Any, Optional
from voyager import Index, Space, StorageDataType
class DataManager:
    """Load and serve the face data, performer database and search indices.

    On construction the manager reads three kinds of data from disk: a JSON
    file of faces, a password-protected zip holding ``performers.json``, and
    two Voyager indices (ArcFace and FaceNet). Every loader is best-effort:
    a failure is printed and the corresponding attribute keeps its empty
    default, so construction itself does not raise on missing data files.
    """

    def __init__(self, faces_path: str = "data/faces.json",
                 performers_zip: str = "data/persons.zip",
                 facenet_index_path: str = "data/face_facenet.voy",
                 arc_index_path: str = "data/face_arc.voy"):
        """
        Initialize the data manager and load all data immediately.

        Parameters:
            faces_path: Path to the faces.json file
            performers_zip: Path to the performers zip file
            facenet_index_path: Path to the facenet index file
            arc_index_path: Path to the arc index file
        """
        self.faces_path = faces_path
        self.performers_zip = performers_zip
        self.facenet_index_path = facenet_index_path
        self.arc_index_path = arc_index_path

        # Start with empty indices; they are kept as fallbacks when the
        # on-disk index files cannot be loaded.
        self.index_arc = Index(Space.Cosine, num_dimensions=512, storage_data_type=StorageDataType.E4M3)
        self.index_facenet = Index(Space.Cosine, num_dimensions=512, storage_data_type=StorageDataType.E4M3)

        # Empty defaults, replaced by load_data() when the files are readable.
        self.faces = {}
        self.performer_db = {}
        self.load_data()

    def load_data(self):
        """Load all data from files (faces, performer database, indices)."""
        self._load_faces()
        self._load_performer_db()
        self._load_indices()

    def _load_faces(self):
        """Load faces from the JSON file; fall back to an empty dict on error."""
        try:
            # JSON is UTF-8 by specification; do not rely on the locale default.
            with open(self.faces_path, 'r', encoding='utf-8') as f:
                self.faces = json.load(f)
        except Exception as e:
            print(f"Error loading faces: {e}")
            self.faces = {}

    def _load_performer_db(self):
        """Load the performer database from the encrypted zip file.

        The zip password is read from the ``VISAGE_KEY`` environment variable
        (empty if unset). On any error the database is reset to an empty dict.
        """
        try:
            with pyzipper.AESZipFile(self.performers_zip) as zf:
                password = os.getenv("VISAGE_KEY", "").encode('ascii')
                zf.setpassword(password)
                self.performer_db = json.loads(zf.read('performers.json'))
        except Exception as e:
            print(f"Error loading performer database: {e}")
            self.performer_db = {}

    def _load_indices(self):
        """Load the face recognition indices.

        Each index is loaded independently so that a missing or corrupt file
        for one model does not prevent the other from being loaded.
        """
        self.index_arc = self._load_index(self.arc_index_path, self.index_arc)
        self.index_facenet = self._load_index(self.facenet_index_path, self.index_facenet)

    @staticmethod
    def _load_index(path: str, fallback):
        """Return the index stored at *path*, or *fallback* if loading fails."""
        try:
            with open(path, 'rb') as f:
                # Index.load builds a new index from the file-like object.
                return Index.load(f)
        except Exception as e:
            print(f"Error loading indices: {e}")
            return fallback

    def get_performer_info(self, stash_id: str, confidence: float) -> Optional[Dict[str, Any]]:
        """
        Get performer information from the database.

        Parameters:
            stash_id: Stash ID of the performer
            confidence: Confidence score (0-1)

        Returns:
            Dictionary with performer information, or None if the ID is not
            in the database (or its entry is empty).

        Raises:
            KeyError: If the entry lacks 'name', 'image' or 'country'.
        """
        performer = self.performer_db.get(stash_id)
        if not performer:
            return None

        # Round to the nearest percent. Plain int() truncates, so float error
        # such as 0.29 * 100 == 28.999999999999996 would be reported as 28.
        confidence_int = int(round(float(confidence) * 100))
        return {
            'id': stash_id,
            "name": performer['name'],
            "confidence": confidence_int,
            'image': performer['image'],
            'country': performer['country'],
            'hits': 1,
            # 'distance' carries the same percentage value as 'confidence'.
            'distance': confidence_int,
            'performer_url': f"https://stashdb.org/performers/{stash_id}"
        }

    def query_facenet_index(self, embedding, limit: int):
        """Query the facenet index with an embedding, returning up to `limit` results."""
        return self.index_facenet.query(embedding, limit)

    def query_arc_index(self, embedding, limit: int):
        """Query the arc index with an embedding, returning up to `limit` results."""
        return self.index_arc.query(embedding, limit)