# stashface/models/image_processor.py
import io
import base64
import numpy as np
from uuid import uuid4
from PIL import Image as PILImage
from typing import List, Dict, Any, Tuple
from models.face_recognition import EnsembleFaceRecognition, extract_faces, extract_faces_mediapipe
from models.data_manager import DataManager
from utils.vtt_parser import parse_vtt_offsets


def get_face_predictions(face, ensemble, data_manager, results):
    """
    Get predictions for a single face.

    Parameters:
        face: Face image array
        ensemble: EnsembleFaceRecognition instance
        data_manager: DataManager instance
        results: Number of results to return

    Returns:
        List of (name, confidence) tuples
    """
    # Get embeddings for original and flipped images
    embeddings_orig = ensemble.get_face_embeddings(face)
    embeddings_flip = ensemble.get_face_embeddings(np.fliplr(face))

    # Average the embeddings
    facenet = np.mean([embeddings_orig['facenet'], embeddings_flip['facenet']], axis=0)
    arc = np.mean([embeddings_orig['arc'], embeddings_flip['arc']], axis=0)

    # Get predictions from both models
    model_predictions = {
        'facenet': data_manager.query_facenet_index(facenet, max(results, 50)),
        'arc': data_manager.query_arc_index(arc, max(results, 50)),
    }
    return ensemble.ensemble_prediction(model_predictions)


def image_search_performer(image, data_manager, threshold=0.5, results=3):
    """
    Search for a single performer in an image.

    Only the first face returned by the detector is considered.

    Parameters:
        image: PIL Image object
        data_manager: DataManager instance
        threshold: Confidence threshold
        results: Number of results to return

    Returns:
        List of performer information dictionaries
    """
    image_array = np.array(image)
    ensemble = EnsembleFaceRecognition({"facenet": 1.0, "arc": 1.0})

    try:
        faces = extract_faces(image_array)
    except ValueError:
        raise ValueError("No faces found")

    predictions = get_face_predictions(faces[0]['face'], ensemble, data_manager, results)

    response = []
    for name, confidence in predictions:
        performer_info = data_manager.get_performer_info(data_manager.faces[name], confidence)
        if performer_info:
            response.append(performer_info)
    return response
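

# Illustrative usage sketch, not part of the original module: how a caller
# might invoke image_search_performer. The file name "query.jpg" and the
# zero-argument DataManager() construction are assumptions for demonstration;
# the real project may wire up DataManager differently.
def _example_single_search():
    manager = DataManager()  # hypothetical construction
    with PILImage.open("query.jpg") as img:  # hypothetical input image
        try:
            return image_search_performer(img.convert("RGB"), manager, results=3)
        except ValueError:
            return []  # raised when no face is detected in the image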


def image_search_performers(image, data_manager, threshold=0.5, results=3):
    """
    Search for multiple performers in an image.

    Parameters:
        image: PIL Image object
        data_manager: DataManager instance
        threshold: Confidence threshold
        results: Number of results to return

    Returns:
        List of dictionaries with face image and performer information
    """
    image_array = np.array(image)
    ensemble = EnsembleFaceRecognition({"facenet": 1.0, "arc": 1.0})

    try:
        faces = extract_faces(image_array)
    except ValueError:
        raise ValueError("No faces found")

    response = []
    for face in faces:
        predictions = get_face_predictions(face['face'], ensemble, data_manager, results)

        # Crop and encode face image
        area = face['facial_area']
        cimage = image.crop((area['x'], area['y'], area['x'] + area['w'], area['y'] + area['h']))
        buf = io.BytesIO()
        cimage.save(buf, format='JPEG')
        im_b64 = base64.b64encode(buf.getvalue()).decode('ascii')

        # Get performer information
        performers = []
        for name, confidence in predictions:
            performer_info = data_manager.get_performer_info(data_manager.faces[name], confidence)
            if performer_info:
                performers.append(performer_info)

        response.append({
            'image': im_b64,
            'confidence': face['confidence'],
            'performers': performers
        })
    return response
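

# Illustrative sketch, not part of the original module: the 'image' field
# returned by image_search_performers is a base64-encoded JPEG, so a caller
# can turn it back into a PIL image like this. The helper name is an
# assumption for demonstration only.
def _decode_face_thumbnail(result_entry):
    """Decode the base64 JPEG crop from one image_search_performers entry."""
    jpeg_bytes = base64.b64decode(result_entry['image'])
    return PILImage.open(io.BytesIO(jpeg_bytes))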


def find_faces_in_sprite(image, vtt_data):
    """
    Find faces in a sprite image using VTT data.

    Parameters:
        image: Sprite image as a numpy array
        vtt_data: Base64 encoded VTT data, optionally prefixed with the
            "data:text/vtt;base64," scheme

    Returns:
        List of dictionaries with face information
    """
    vtt = base64.b64decode(vtt_data.replace("data:text/vtt;base64,", ""))
    sprite = PILImage.fromarray(image)

    results = []
    # parse_vtt_offsets yields x/y/width/height offsets plus the frame timestamp
    for i, (left, top, width, height, time_seconds) in enumerate(parse_vtt_offsets(vtt)):
        cut_frame = sprite.crop((left, top, left + width, top + height))
        faces = extract_faces_mediapipe(np.asarray(cut_frame), enforce_detection=False, align=False)
        faces = [face for face in faces if face['confidence'] > 0.6]
        if faces:
            size = faces[0]['facial_area']['w'] * faces[0]['facial_area']['h']
            data = {
                'id': str(uuid4()),
                'offset': (left, top, width, height),
                'frame': i,
                'time': time_seconds,
                'size': size,
            }
            results.append(data)
    return results
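

# Illustrative sketch, not part of the original module: find_faces_in_sprite
# accepts the VTT data as a base64 string, optionally carrying the
# "data:text/vtt;base64," prefix that it strips off. The helper below shows
# one way a caller might build that value from a .vtt file; the function name
# and file handling are assumptions for demonstration only.
def _encode_vtt_file(path):
    """Read a WebVTT file and return it as a data URI accepted by find_faces_in_sprite."""
    with open(path, "rb") as fh:
        return "data:text/vtt;base64," + base64.b64encode(fh.read()).decode("ascii")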