# stashface/models/image_processor.py
import io
import base64
import numpy as np
from uuid import uuid4
from PIL import Image as PILImage
from typing import List, Dict, Any, Tuple
from models.face_recognition import EnsembleFaceRecognition, extract_faces, extract_faces_mediapipe
from models.data_manager import DataManager
from utils.vtt_parser import parse_vtt_offsets
def get_face_predictions(face, ensemble, data_manager, results):
    """
    Run both recognition models on a single face and combine their results.

    Parameters:
        face: Face image array
        ensemble: EnsembleFaceRecognition instance
        data_manager: DataManager instance
        results: Number of results to return

    Returns:
        List of (name, confidence) tuples
    """
    # Embed the face and its horizontal mirror, then average the two
    # embeddings per model for a flip-invariant representation.
    orig = ensemble.get_face_embeddings(face)
    mirrored = ensemble.get_face_embeddings(np.fliplr(face))
    averaged = {
        model: np.mean([orig[model], mirrored[model]], axis=0)
        for model in ('facenet', 'arc')
    }

    # Query at least 50 neighbours per model so the ensemble has enough
    # candidates to vote on, regardless of how few results were requested.
    top_k = max(results, 50)
    model_predictions = {
        'facenet': data_manager.query_facenet_index(averaged['facenet'], top_k),
        'arc': data_manager.query_arc_index(averaged['arc'], top_k),
    }
    return ensemble.ensemble_prediction(model_predictions)
def image_search_performer(image, data_manager, threshold=0.5, results=3):
    """
    Search for a performer in an image.

    Only the first detected face is used for identification.

    Parameters:
        image: PIL Image object
        data_manager: DataManager instance
        threshold: Confidence threshold.
            NOTE(review): currently unused — predictions are not filtered
            by confidence here; confirm whether filtering was intended.
        results: Number of results to return

    Returns:
        List of performer information dictionaries

    Raises:
        ValueError: if no faces are detected in the image
    """
    image_array = np.array(image)
    ensemble = EnsembleFaceRecognition({"facenet": 1.0, "arc": 1.0})

    try:
        faces = extract_faces(image_array)
    except ValueError as err:
        # Re-raise with a user-facing message while preserving the original
        # exception as the cause (the bare re-raise discarded the chain).
        raise ValueError("No faces found") from err

    predictions = get_face_predictions(faces[0]['face'], ensemble, data_manager, results)

    # Resolve each predicted name to full performer info; skip empty lookups.
    response = []
    for name, confidence in predictions:
        performer_info = data_manager.get_performer_info(data_manager.faces[name], confidence)
        if performer_info:
            response.append(performer_info)
    return response
def image_search_performers(image, data_manager, threshold=0.5, results=3):
    """
    Search for multiple performers in an image — one result entry per face.

    Parameters:
        image: PIL Image object
        data_manager: DataManager instance
        threshold: Confidence threshold.
            NOTE(review): currently unused — predictions are not filtered
            by confidence here; confirm whether filtering was intended.
        results: Number of results to return

    Returns:
        List of dicts, one per detected face, with keys:
            'image': base64-encoded JPEG crop of the face,
            'confidence': detector confidence for the face,
            'performers': list of performer information dictionaries

    Raises:
        ValueError: if no faces are detected in the image
    """
    image_array = np.array(image)
    ensemble = EnsembleFaceRecognition({"facenet": 1.0, "arc": 1.0})

    try:
        faces = extract_faces(image_array)
    except ValueError as err:
        # Re-raise with a user-facing message while preserving the original
        # exception as the cause (the bare re-raise discarded the chain).
        raise ValueError("No faces found") from err

    response = []
    for face in faces:
        predictions = get_face_predictions(face['face'], ensemble, data_manager, results)

        # Crop the face out of the source image and encode it as base64 JPEG
        # so the API response is self-contained.
        area = face['facial_area']
        cimage = image.crop((area['x'], area['y'], area['x'] + area['w'], area['y'] + area['h']))
        buf = io.BytesIO()
        cimage.save(buf, format='JPEG')
        im_b64 = base64.b64encode(buf.getvalue()).decode('ascii')

        # Resolve each predicted name to full performer info; skip empty lookups.
        performers = []
        for name, confidence in predictions:
            performer_info = data_manager.get_performer_info(data_manager.faces[name], confidence)
            if performer_info:
                performers.append(performer_info)

        response.append({
            'image': im_b64,
            'confidence': face['confidence'],
            'performers': performers
        })
    return response
def find_faces_in_sprite(image, vtt_data):
    """
    Find faces in a sprite sheet using its VTT tile layout.

    Parameters:
        image: Sprite sheet as a numpy array — it is converted with
            PILImage.fromarray, so a PIL Image is NOT accepted (the previous
            docstring incorrectly said "PIL Image object").
        vtt_data: Base64 encoded VTT data, optionally prefixed with a
            "data:text/vtt;base64," data-URI header.

    Returns:
        List of dicts — one per tile containing at least one confident
        face — with keys: 'id' (uuid string), 'offset' (left, top, right,
        bottom), 'frame' (tile index), 'time' (seconds), and 'size'
        (bounding-box area of the first face, in pixels).
    """
    # Strip the data-URI prefix (if present) before base64-decoding.
    vtt = base64.b64decode(vtt_data.replace("data:text/vtt;base64,", ""))
    sprite = PILImage.fromarray(image)

    results = []
    for i, (left, top, right, bottom, time_seconds) in enumerate(parse_vtt_offsets(vtt)):
        # NOTE(review): the crop box is (left, top, left+right, top+bottom),
        # so "right"/"bottom" are treated as width/height — the VTT offsets
        # appear to be x/y/w/h; confirm against parse_vtt_offsets.
        cut_frame = sprite.crop((left, top, left + right, top + bottom))
        faces = extract_faces_mediapipe(np.asarray(cut_frame), enforce_detection=False, align=False)
        # Keep only confident detections.
        faces = [face for face in faces if face['confidence'] > 0.6]
        if faces:
            # Use the first face's bounding-box area as a size heuristic.
            size = faces[0]['facial_area']['w'] * faces[0]['facial_area']['h']
            data = {
                'id': str(uuid4()),
                "offset": (left, top, right, bottom),
                "frame": i,
                "time": time_seconds,
                'size': size,
            }
            results.append(data)
    return results