|
import gradio as gr |
|
from typing import Dict, Any |
|
|
|
from models.data_manager import DataManager |
|
from models.image_processor import ( |
|
image_search_performer, |
|
image_search_performers, |
|
find_faces_in_sprite |
|
) |
|
|
|
class WebInterface: |
|
def __init__(self, data_manager: DataManager, default_threshold: float = 0.5): |
|
""" |
|
Initialize the web interface. |
|
|
|
Parameters: |
|
data_manager: DataManager instance |
|
default_threshold: Default confidence threshold |
|
""" |
|
self.data_manager = data_manager |
|
self.default_threshold = default_threshold |
|
|
|
def image_search(self, img, threshold, results): |
|
"""Wrapper for the image search function""" |
|
return image_search_performer(img, self.data_manager, threshold, results) |
|
|
|
def multiple_image_search(self, img, threshold, results): |
|
"""Wrapper for the multiple image search function""" |
|
return image_search_performers(img, self.data_manager, threshold, results) |
|
|
|
def vector_search(self, vector_json, threshold, results): |
|
"""Wrapper for the vector search function (deprecated)""" |
|
return {'status': 'not implemented'} |
|
|
|
def _create_image_search_interface(self): |
|
"""Create the single face search interface""" |
|
with gr.Blocks() as interface: |
|
gr.Markdown("# Who is in the photo?") |
|
gr.Markdown("Upload an image of a person and we'll tell you who it is.") |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
img_input = gr.Image() |
|
threshold = gr.Slider( |
|
label="threshold", |
|
minimum=0.0, |
|
maximum=1.0, |
|
value=self.default_threshold |
|
) |
|
results_count = gr.Slider( |
|
label="results", |
|
minimum=0, |
|
maximum=50, |
|
value=3, |
|
step=1 |
|
) |
|
search_btn = gr.Button("Search") |
|
|
|
with gr.Column(): |
|
output = gr.JSON(label="Results") |
|
|
|
search_btn.click( |
|
fn=self.image_search, |
|
inputs=[img_input, threshold, results_count], |
|
outputs=output |
|
) |
|
|
|
return interface |
|
|
|
def _create_multiple_image_search_interface(self): |
|
"""Create the multiple face search interface""" |
|
with gr.Blocks() as interface: |
|
gr.Markdown("# Who is in the photo?") |
|
gr.Markdown("Upload an image of a person(s) and we'll tell you who it is.") |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
img_input = gr.Image(type="pil") |
|
threshold = gr.Slider( |
|
label="threshold", |
|
minimum=0.0, |
|
maximum=1.0, |
|
value=self.default_threshold |
|
) |
|
results_count = gr.Slider( |
|
label="results", |
|
minimum=0, |
|
maximum=50, |
|
value=3, |
|
step=1 |
|
) |
|
search_btn = gr.Button("Search") |
|
|
|
with gr.Column(): |
|
output = gr.JSON(label="Results") |
|
|
|
search_btn.click( |
|
fn=self.multiple_image_search, |
|
inputs=[img_input, threshold, results_count], |
|
outputs=output |
|
) |
|
|
|
return interface |
|
|
|
def _create_vector_search_interface(self): |
|
"""Create the vector search interface (deprecated)""" |
|
with gr.Blocks() as interface: |
|
gr.Markdown("# Vector Search (deprecated)") |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
vector_input = gr.Textbox() |
|
threshold = gr.Slider( |
|
label="threshold", |
|
minimum=0.0, |
|
maximum=1.0, |
|
value=self.default_threshold |
|
) |
|
results_count = gr.Slider( |
|
label="results", |
|
minimum=0, |
|
maximum=50, |
|
value=3, |
|
step=1 |
|
) |
|
search_btn = gr.Button("Search") |
|
|
|
with gr.Column(): |
|
output = gr.JSON(label="Results") |
|
|
|
search_btn.click( |
|
fn=self.vector_search, |
|
inputs=[vector_input, threshold, results_count], |
|
outputs=output |
|
) |
|
|
|
return interface |
|
|
|
def _create_faces_in_sprite_interface(self): |
|
"""Create the faces in sprite interface""" |
|
with gr.Blocks() as interface: |
|
gr.Markdown("# Find Faces in Sprite") |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
img_input = gr.Image() |
|
vtt_input = gr.Textbox(label="VTT file") |
|
search_btn = gr.Button("Process") |
|
|
|
with gr.Column(): |
|
output = gr.JSON(label="Results") |
|
|
|
search_btn.click( |
|
fn=find_faces_in_sprite, |
|
inputs=[img_input, vtt_input], |
|
outputs=output |
|
) |
|
|
|
return interface |
|
|
|
def launch(self, server_name="0.0.0.0", server_port=7860, share=True): |
|
"""Launch the web interface""" |
|
with gr.Blocks() as demo: |
|
with gr.Tabs() as tabs: |
|
with gr.TabItem("Single Face Search"): |
|
self._create_image_search_interface() |
|
with gr.TabItem("Multiple Face Search"): |
|
self._create_multiple_image_search_interface() |
|
with gr.TabItem("Vector Search"): |
|
self._create_vector_search_interface() |
|
with gr.TabItem("Faces in Sprite"): |
|
self._create_faces_in_sprite_interface() |
|
|
|
demo.queue().launch(share=share, ssr_mode=False) |
|
|