# Page-header residue from the hosting site, kept as a comment so the file parses:
# author esilver, commit a318724 ("refactored"), 11 kB.
import gradio as gr
from utils import SafeProgress, format_categories_html
from embeddings import create_product_embeddings
from similarity import compute_similarities
# Module-level mapping handed as the first argument to compute_similarities()
# by both categorize_* functions below. Nothing in this section fills it.
# NOTE(review): presumably populated at startup elsewhere — confirm; otherwise
# every similarity lookup runs against an empty mapping.
embeddings = dict()
def categorize_products_from_text(product_text, top_n=5, confidence_threshold=0.5, progress=None):
    """Categorize products from text input (one product per line).

    Args:
        product_text: Raw text with one product name per line. Surrounding
            whitespace and blank lines are ignored; ``None`` counts as empty.
        top_n: Maximum number of matches shown per product. Coerced to
            ``int`` because Gradio sliders deliver their value as a float,
            which is not a valid slice bound.
        confidence_threshold: Minimum similarity score a match needs in order
            to be shown.
        progress: Optional progress callback, wrapped in ``SafeProgress``.

    Returns:
        An HTML string: the per-product results, or an error message when no
        product names were given or no similarities were produced.
    """
    # Validate the input first so empty submissions never reach the model.
    product_names = [line.strip() for line in (product_text or "").split("\n") if line.strip()]
    if not product_names:
        return "<div style='color: #d32f2f;'>No product names provided.</div>"

    # Create a safe progress tracker
    progress_tracker = SafeProgress(progress)
    progress_tracker(0, desc="Starting...")

    # Create product embeddings
    progress_tracker(0.1, desc="Generating product embeddings...")
    products_embeddings = create_product_embeddings(product_names)

    # Compute similarities
    progress_tracker(0.6, desc="Computing similarities...")
    all_similarities = compute_similarities(embeddings, products_embeddings)

    # Format results
    progress_tracker(0.9, desc="Formatting results...")
    if not all_similarities:
        output_html = "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>No results found. Please check your input or try different products.</div>"
    else:
        limit = int(top_n)
        parts = ["<div style='font-family: Arial, sans-serif;'>"]
        for product, similarities in all_similarities.items():
            # The ordering returned by compute_similarities is not visible
            # here; a stable descending sort by score guarantees the slice
            # below keeps the best matches (and is a no-op if already sorted).
            ranked = sorted(similarities, key=lambda pair: pair[1], reverse=True)
            top_similarities = [
                (ingredient, score)
                for ingredient, score in ranked
                if score >= confidence_threshold
            ][:limit]
            parts.append(format_categories_html(product, top_similarities))
            parts.append("<hr style='margin: 15px 0; border: 0; border-top: 1px solid #eee;'>")
        parts.append("</div>")
        output_html = "".join(parts)

    progress_tracker(1.0, desc="Done!")
    return output_html
def categorize_products_from_file(file, top_n=5, confidence_threshold=0.5, progress=None):
    """Categorize products from a JSON or text file.

    Args:
        file: The upload from ``gr.File``. This is either an object whose
            ``.name`` attribute holds the file path or a plain path string.
            ``None`` (nothing uploaded) yields an error message.
        top_n: Maximum number of matches shown per product. Coerced to
            ``int`` because Gradio sliders deliver their value as a float.
        confidence_threshold: Minimum similarity score a match needs in order
            to be shown.
        progress: Optional progress callback, wrapped in ``SafeProgress``.

    Returns:
        An HTML string with a summary line plus per-product results, or an
        HTML error message.
    """
    import html

    # Clicking "Process File" without uploading anything passes None.
    if file is None:
        return "<div style='color: #d32f2f;'>Please upload a file first.</div>"

    from utils import parse_product_file

    # Create a safe progress tracker
    progress_tracker = SafeProgress(progress)
    progress_tracker(0.1, desc="Reading file...")

    # Accept both an object exposing the path as ``.name`` and a bare path string.
    file_path = getattr(file, "name", file)
    try:
        product_names = parse_product_file(file_path)
    except Exception as e:  # UI boundary: surface any parse failure to the user
        # The message can echo user-controlled text (e.g. the file name), so escape it.
        return f"<div style='color: #d32f2f; font-weight: bold;'>Error: {html.escape(str(e))}</div>"
    if not product_names:
        return "<div style='color: #d32f2f;'>No product names found in the file.</div>"

    # Create product embeddings
    progress_tracker(0.2, desc="Generating product embeddings...")
    products_embeddings = create_product_embeddings(product_names)

    # Compute similarities
    progress_tracker(0.7, desc="Computing similarities...")
    all_similarities = compute_similarities(embeddings, products_embeddings)

    # Format results
    progress_tracker(0.9, desc="Formatting results...")
    if not all_similarities:
        progress_tracker(1.0, desc="Done!")
        return "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>No results found. Please check your input or try different products.</div>"

    limit = int(top_n)
    parts = [
        "<div style='font-family: Arial, sans-serif;'>",
        "<div style='margin-bottom: 20px; padding: 10px; background-color: #e8f5e9; border-radius: 5px;'>",
        f"Found <b>{len(product_names)}</b> products in file. Showing results with confidence ≥ {confidence_threshold}.",
        "</div>",
    ]
    for product, similarities in all_similarities.items():
        # Stable descending sort so the slice really keeps the highest scores
        # regardless of the order compute_similarities returns.
        ranked = sorted(similarities, key=lambda pair: pair[1], reverse=True)
        top_similarities = [
            (ingredient, score)
            for ingredient, score in ranked
            if score >= confidence_threshold
        ][:limit]
        parts.append(format_categories_html(product, top_similarities))
        parts.append("<hr style='margin: 15px 0; border: 0; border-top: 1px solid #eee;'>")
    parts.append("</div>")

    progress_tracker(1.0, desc="Done!")
    return "".join(parts)
def create_demo():
    """Build the Gradio Blocks interface for the product categorization tool.

    The interface contains a header, a "Text Input" tab wired to
    ``categorize_products_from_text``, a "File Upload" tab wired to
    ``categorize_products_from_file``, and a footer. The ``gr.Blocks``
    object is returned without being launched.

    Returns:
        The constructed ``gr.Blocks`` demo.
    """
    # Basic CSS theme
    css = """
    .container {
        max-width: 1200px;
        margin: auto;
        padding: 0;
    }
    footer {display: none !important;}
    .header {
        background-color: #0d47a1;
        padding: 15px 20px;
        border-radius: 10px;
        color: white;
        margin-bottom: 20px;
        display: flex;
        align-items: center;
    }
    .header svg {
        margin-right: 10px;
        height: 30px;
        width: 30px;
    }
    .header h1 {
        margin: 0;
        font-size: 24px;
    }
    .description {
        margin-bottom: 20px;
        padding: 15px;
        background-color: #f5f5f5;
        border-radius: 5px;
    }
    """

    # Custom theme
    theme = gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="indigo",
        font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui"],
    ).set(
        button_primary_background_fill="*primary_500",
        button_primary_background_fill_hover="*primary_600",
        button_secondary_background_fill="*neutral_200",
        # Theme values are emitted as CSS; bare "lg"/"md" are not valid font
        # sizes, so reference the theme's own text-size variables instead.
        block_title_text_size="*text_lg",
        block_label_text_size="*text_md",
    )

    with gr.Blocks(css=css, theme=theme) as demo:
        # Header with icon. The last two paths are open polylines, so the
        # icon is stroked rather than filled (filling turns them into triangles).
        gr.HTML("""
        <div class="header">
            <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="white" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                <path d="M12 2L2 7l10 5 10-5-10-5zM2 17l10 5 10-5M2 12l10 5 10-5"></path>
            </svg>
            <h1>Product Categorization Tool</h1>
        </div>
        <div class="description">
            This tool analyzes products and finds the most similar ingredients using AI embeddings.
            Just enter product names or upload a file to get started.
        </div>
        """)

        # NOTE(review): neither .click() below lists a progress input, so the
        # callbacks' ``progress`` parameter keeps its default (None); presumably
        # no progress bar is shown — confirm if one is wanted.
        with gr.Tabs():
            with gr.TabItem("Text Input"):
                with gr.Row():
                    with gr.Column(scale=2):
                        example_products = [
                            "Tomato Sauce\nApple Pie\nGreek Yogurt\nChocolate Chip Cookies",
                            "Banana Bread\nOrange Juice\nGrilled Chicken\nCaesar Salad",
                            "Vanilla Ice Cream\nPizza Dough\nStrawberry Jam\nGrilled Salmon",
                        ]
                        text_input = gr.Textbox(
                            lines=10,
                            placeholder="Enter product names, one per line",
                            label="Product Names",
                        )
                        gr.Examples(
                            examples=example_products,
                            inputs=text_input,
                            label="Example Product Sets",
                        )
                        with gr.Row():
                            with gr.Column(scale=1):
                                top_n = gr.Slider(
                                    minimum=1,
                                    maximum=10,
                                    value=5,
                                    step=1,
                                    label="Number of Top Categories",
                                )
                            with gr.Column(scale=1):
                                confidence = gr.Slider(
                                    minimum=0.1,
                                    maximum=0.9,
                                    value=0.5,
                                    step=0.05,
                                    label="Confidence Threshold",
                                )
                        submit_button = gr.Button("Categorize Products", variant="primary")
                    with gr.Column(scale=3):
                        text_output = gr.HTML(
                            label="Categorization Results",
                            value="<div style='height: 450px; display: flex; justify-content: center; align-items: center; color: #666;'>Results will appear here</div>",
                        )
                submit_button.click(
                    fn=categorize_products_from_text,
                    inputs=[text_input, top_n, confidence],
                    outputs=text_output,
                )

            with gr.TabItem("File Upload"):
                with gr.Row():
                    with gr.Column(scale=2):
                        file_input = gr.File(
                            label="Upload JSON or text file with products",
                            file_types=[".json", ".txt"],
                        )
                        with gr.Accordion("Help", open=False):
                            gr.Markdown("""
                            - JSON files should contain either:
                                - A list of objects with a 'name' field for each product
                                - A simple array of product name strings
                            - Text files should have one product name per line
                            """)
                        with gr.Row():
                            with gr.Column(scale=1):
                                file_top_n = gr.Slider(
                                    minimum=1,
                                    maximum=10,
                                    value=5,
                                    step=1,
                                    label="Number of Top Categories",
                                )
                            with gr.Column(scale=1):
                                file_confidence = gr.Slider(
                                    minimum=0.1,
                                    maximum=0.9,
                                    value=0.5,
                                    step=0.05,
                                    label="Confidence Threshold",
                                )
                        file_button = gr.Button("Process File", variant="primary")
                    with gr.Column(scale=3):
                        file_output = gr.HTML(
                            label="Categorization Results",
                            value="<div style='height: 450px; display: flex; justify-content: center; align-items: center; color: #666;'>Upload a file to see results</div>",
                        )
                file_button.click(
                    fn=categorize_products_from_file,
                    inputs=[file_input, file_top_n, file_confidence],
                    outputs=file_output,
                )

        # Footer
        gr.HTML("""
        <div style="margin-top: 20px; text-align: center; color: #666;">
            Powered by Voyage AI embeddings • Built with Gradio
        </div>
        """)

    return demo