# stashface / web/interface.py
# cc1234 · init · 244b0b6
import gradio as gr
from typing import Dict, Any
from models.data_manager import DataManager
from models.image_processor import (
image_search_performer,
image_search_performers,
find_faces_in_sprite
)
class WebInterface:
    """Gradio front end that exposes the face-search functions as tabs."""

    def __init__(self, data_manager: DataManager, default_threshold: float = 0.5):
        """
        Initialize the web interface.

        Parameters:
            data_manager: DataManager instance, passed through to the image
                search functions.
            default_threshold: Initial value of every "threshold" slider.
        """
        self.data_manager = data_manager
        self.default_threshold = default_threshold

    # ------------------------------------------------------------------
    # Callbacks wired to the "Search" buttons
    # ------------------------------------------------------------------

    def image_search(self, img, threshold, results):
        """
        Wrapper for the single-face image search function.

        Parameters:
            img: Image from the single-face tab's ``gr.Image`` component.
            threshold: Value of the "threshold" slider.
            results: Value of the "results" slider.

        Returns:
            Whatever ``image_search_performer`` returns; it is shown in a
            ``gr.JSON`` component.
        """
        return image_search_performer(img, self.data_manager, threshold, results)

    def multiple_image_search(self, img, threshold, results):
        """
        Wrapper for the multiple-face image search function.

        Parameters:
            img: PIL image from the multiple-face tab (``gr.Image(type="pil")``).
            threshold: Value of the "threshold" slider.
            results: Value of the "results" slider.

        Returns:
            Whatever ``image_search_performers`` returns; it is shown in a
            ``gr.JSON`` component.
        """
        return image_search_performers(img, self.data_manager, threshold, results)

    def vector_search(self, vector_json, threshold, results) -> Dict[str, Any]:
        """
        Wrapper for the vector search function (deprecated).

        All arguments are ignored; the endpoint only reports that it is not
        implemented.
        """
        return {'status': 'not implemented'}

    # ------------------------------------------------------------------
    # Shared UI components
    # ------------------------------------------------------------------

    def _threshold_slider(self):
        """Create the 0..1 "threshold" slider, initialised to the default threshold."""
        return gr.Slider(
            label="threshold",
            minimum=0.0,
            maximum=1.0,
            value=self.default_threshold,
        )

    @staticmethod
    def _results_slider():
        """Create the integer "results" slider (0..50, default 3)."""
        return gr.Slider(
            label="results",
            minimum=0,
            maximum=50,
            value=3,
            step=1,
        )

    # ------------------------------------------------------------------
    # Tab builders
    # ------------------------------------------------------------------

    def _create_image_search_interface(self):
        """Create the single face search interface."""
        with gr.Blocks() as interface:
            gr.Markdown("# Who is in the photo?")
            gr.Markdown("Upload an image of a person and we'll tell you who it is.")
            with gr.Row():
                with gr.Column():
                    img_input = gr.Image()
                    threshold = self._threshold_slider()
                    results_count = self._results_slider()
                    search_btn = gr.Button("Search")
                with gr.Column():
                    output = gr.JSON(label="Results")
            search_btn.click(
                fn=self.image_search,
                inputs=[img_input, threshold, results_count],
                outputs=output,
            )
        return interface

    def _create_multiple_image_search_interface(self):
        """Create the multiple face search interface."""
        with gr.Blocks() as interface:
            gr.Markdown("# Who is in the photo?")
            gr.Markdown("Upload an image of a person(s) and we'll tell you who it is.")
            with gr.Row():
                with gr.Column():
                    # This tab hands a PIL image to the callback (the
                    # single-face tab uses the component default).
                    img_input = gr.Image(type="pil")
                    threshold = self._threshold_slider()
                    results_count = self._results_slider()
                    search_btn = gr.Button("Search")
                with gr.Column():
                    output = gr.JSON(label="Results")
            search_btn.click(
                fn=self.multiple_image_search,
                inputs=[img_input, threshold, results_count],
                outputs=output,
            )
        return interface

    def _create_vector_search_interface(self):
        """Create the vector search interface (deprecated)."""
        with gr.Blocks() as interface:
            gr.Markdown("# Vector Search (deprecated)")
            with gr.Row():
                with gr.Column():
                    vector_input = gr.Textbox()
                    threshold = self._threshold_slider()
                    results_count = self._results_slider()
                    search_btn = gr.Button("Search")
                with gr.Column():
                    output = gr.JSON(label="Results")
            search_btn.click(
                fn=self.vector_search,
                inputs=[vector_input, threshold, results_count],
                outputs=output,
            )
        return interface

    def _create_faces_in_sprite_interface(self):
        """Create the faces in sprite interface."""
        with gr.Blocks() as interface:
            gr.Markdown("# Find Faces in Sprite")
            with gr.Row():
                with gr.Column():
                    img_input = gr.Image()
                    vtt_input = gr.Textbox(label="VTT file")
                    search_btn = gr.Button("Process")
                with gr.Column():
                    output = gr.JSON(label="Results")
            # Wired directly to find_faces_in_sprite: unlike the search tabs,
            # no DataManager is passed to this function.
            search_btn.click(
                fn=find_faces_in_sprite,
                inputs=[img_input, vtt_input],
                outputs=output,
            )
        return interface

    # ------------------------------------------------------------------
    # Entry point
    # ------------------------------------------------------------------

    def launch(self, server_name="0.0.0.0", server_port=7860, share=True):
        """
        Build the tabbed UI and start the Gradio server.

        Parameters:
            server_name: Host/interface the server binds to.
            server_port: Port the server listens on.
            share: Whether Gradio should also create a public share link.
        """
        with gr.Blocks() as demo:
            with gr.Tabs():
                with gr.TabItem("Single Face Search"):
                    self._create_image_search_interface()
                with gr.TabItem("Multiple Face Search"):
                    self._create_multiple_image_search_interface()
                with gr.TabItem("Vector Search"):
                    self._create_vector_search_interface()
                with gr.TabItem("Faces in Sprite"):
                    self._create_faces_in_sprite_interface()
        # server_name/server_port were previously accepted but never
        # forwarded, so callers' host/port choices were silently ignored.
        demo.queue().launch(
            server_name=server_name,
            server_port=server_port,
            share=share,
            ssr_mode=False,
        )