Spaces:
Sleeping
Sleeping
import gradio as gr | |
from utils import SafeProgress, format_categories_html | |
from embeddings import create_product_embeddings | |
from similarity import compute_similarities | |
# Module-level store of the reference embeddings that product embeddings are
# compared against: both categorize_* functions pass it as the first argument
# to compute_similarities(). It starts empty and nothing in this section fills
# it — presumably the application entry point populates it at startup
# (TODO confirm against the caller).
embeddings = {}
def categorize_products_from_text(product_text, top_n=5, confidence_threshold=0.5, progress=None):
    """Categorize products given as free text, one product name per line.

    Args:
        product_text: Newline-separated product names. Blank lines and
            surrounding whitespace are ignored; None is treated as empty.
        top_n: Maximum number of matches to show per product. Coerced to a
            non-negative int so a float value such as 5.0 still works.
        confidence_threshold: Matches scoring below this value are dropped.
        progress: Progress callback handed to SafeProgress.

    Returns:
        An HTML string with one section per product, a red notice when no
        similarities were produced, or the plain message
        "No product names provided." when the input contains no names.
    """
    progress_tracker = SafeProgress(progress)
    progress_tracker(0, desc="Starting...")

    # Guard against None so a missing value is reported as empty input
    # instead of raising AttributeError on .split().
    product_names = [
        line.strip() for line in (product_text or "").split("\n") if line.strip()
    ]
    if not product_names:
        return "No product names provided."

    # A slice bound must be an int, and a negative bound would silently drop
    # matches from the end of the list rather than limit the count.
    max_matches = max(int(top_n), 0)

    progress_tracker(0.1, desc="Generating product embeddings...")
    products_embeddings = create_product_embeddings(product_names)

    progress_tracker(0.6, desc="Computing similarities...")
    all_similarities = compute_similarities(embeddings, products_embeddings)

    progress_tracker(0.9, desc="Formatting results...")
    if not all_similarities:
        # Nothing to render: skip building the result markup entirely.
        output_html = "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>No results found. Please check your input or try different products.</div>"
    else:
        sections = ["<div style='font-family: Arial, sans-serif;'>"]
        for product, similarities in all_similarities.items():
            # NOTE(review): keeping the first N entries assumes each list is
            # already ordered best-first — confirm against compute_similarities.
            kept = [
                (ingredient, score)
                for ingredient, score in similarities
                if score >= confidence_threshold
            ]
            sections.append(format_categories_html(product, kept[:max_matches]))
            sections.append("<hr style='margin: 15px 0; border: 0; border-top: 1px solid #eee;'>")
        sections.append("</div>")
        output_html = "".join(sections)

    progress_tracker(1.0, desc="Done!")
    return output_html
def categorize_products_from_file(file, top_n=5, confidence_threshold=0.5, progress=None):
    """Categorize products listed in an uploaded JSON or text file.

    Args:
        file: The uploaded file, either a filesystem path string or an object
            exposing the path as ``.name``. None (nothing uploaded) yields a
            friendly notice instead of an attribute error.
        top_n: Maximum number of matches to show per product. Coerced to a
            non-negative int so a float value such as 5.0 still works.
        confidence_threshold: Matches scoring below this value are dropped.
        progress: Progress callback handed to SafeProgress.

    Returns:
        An HTML string: a summary banner followed by one section per product,
        or a red notice when no file was given, the file could not be parsed,
        or it contained no product names.
    """
    import html

    from utils import parse_product_file

    progress_tracker = SafeProgress(progress)
    progress_tracker(0.1, desc="Reading file...")

    if file is None:
        return "<div style='color: #d32f2f;'>No file uploaded. Please upload a JSON or text file.</div>"

    try:
        # The upload may arrive as a plain path string or as a file-like
        # wrapper whose .name holds the path; accept both.
        file_path = file if isinstance(file, str) else file.name
        product_names = parse_product_file(file_path)
    except Exception as e:
        # UI boundary: report any parsing failure to the user. The message is
        # escaped because it may echo file contents or paths into the page.
        return f"<div style='color: #d32f2f; font-weight: bold;'>Error: {html.escape(str(e))}</div>"

    if not product_names:
        return "<div style='color: #d32f2f;'>No product names found in the file.</div>"

    # A slice bound must be an int, and a negative bound would silently drop
    # matches from the end of the list rather than limit the count.
    max_matches = max(int(top_n), 0)

    progress_tracker(0.2, desc="Generating product embeddings...")
    products_embeddings = create_product_embeddings(product_names)

    progress_tracker(0.7, desc="Computing similarities...")
    all_similarities = compute_similarities(embeddings, products_embeddings)

    progress_tracker(0.9, desc="Formatting results...")
    sections = [
        "<div style='font-family: Arial, sans-serif;'>",
        "<div style='margin-bottom: 20px; padding: 10px; background-color: #e8f5e9; border-radius: 5px;'>",
        f"Found <b>{len(product_names)}</b> products in file. Showing results with confidence ≥ {confidence_threshold}.",
        "</div>",
    ]
    for product, similarities in all_similarities.items():
        # NOTE(review): keeping the first N entries assumes each list is
        # already ordered best-first — confirm against compute_similarities.
        kept = [
            (ingredient, score)
            for ingredient, score in similarities
            if score >= confidence_threshold
        ]
        sections.append(format_categories_html(product, kept[:max_matches]))
        sections.append("<hr style='margin: 15px 0; border: 0; border-top: 1px solid #eee;'>")
    sections.append("</div>")
    output_html = "".join(sections)

    progress_tracker(1.0, desc="Done!")
    return output_html
def create_demo():
    """Build and return the Gradio Blocks app for product categorization.

    The page consists of a header banner, two tabs ("Text Input" and
    "File Upload") that each pair input controls with an HTML results pane,
    and a footer line. Each tab's button is wired to the matching
    categorize_products_from_* function.
    """
    # Style sheet handed to gr.Blocks.
    page_css = """
    .container {
        max-width: 1200px;
        margin: auto;
        padding: 0;
    }
    footer {display: none !important;}
    .header {
        background-color: #0d47a1;
        padding: 15px 20px;
        border-radius: 10px;
        color: white;
        margin-bottom: 20px;
        display: flex;
        align-items: center;
    }
    .header svg {
        margin-right: 10px;
        height: 30px;
        width: 30px;
    }
    .header h1 {
        margin: 0;
        font-size: 24px;
    }
    .description {
        margin-bottom: 20px;
        padding: 15px;
        background-color: #f5f5f5;
        border-radius: 5px;
    }
    """

    # Soft base theme, then the button/label overrides applied on top of it.
    base_theme = gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="indigo",
        font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui"],
    )
    app_theme = base_theme.set(
        button_primary_background_fill="*primary_500",
        button_primary_background_fill_hover="*primary_600",
        button_secondary_background_fill="*neutral_200",
        block_title_text_size="lg",
        block_label_text_size="md",
    )

    # Sample inputs offered under the text box; each entry is one example set.
    sample_sets = [
        "Tomato Sauce\nApple Pie\nGreek Yogurt\nChocolate Chip Cookies",
        "Banana Bread\nOrange Juice\nGrilled Chicken\nCaesar Salad",
        "Vanilla Ice Cream\nPizza Dough\nStrawberry Jam\nGrilled Salmon",
    ]

    def placeholder(message):
        """Return the centered grey markup shown before any results exist."""
        return (
            "<div style='height: 450px; display: flex; justify-content: center; "
            "align-items: center; color: #666;'>" + message + "</div>"
        )

    def tuning_sliders():
        """Lay out the top-N and confidence sliders side by side; return both."""
        with gr.Row():
            with gr.Column(scale=1):
                count_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Number of Top Categories",
                )
            with gr.Column(scale=1):
                threshold_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.5,
                    step=0.05,
                    label="Confidence Threshold",
                )
        return count_slider, threshold_slider

    with gr.Blocks(css=page_css, theme=app_theme) as demo:
        # Banner with icon, followed by a short description of the tool.
        gr.HTML("""
        <div class="header">
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="white">
        <path d="M12 2L2 7l10 5 10-5-10-5zM2 17l10 5 10-5M2 12l10 5 10-5"></path>
        </svg>
        <h1>Product Categorization Tool</h1>
        </div>
        <div class="description">
        This tool analyzes products and finds the most similar ingredients using AI embeddings.
        Just enter product names or upload a file to get started.
        </div>
        """)

        with gr.Tabs():
            # --- Tab 1: names typed or pasted into a text box ---
            with gr.TabItem("Text Input"):
                with gr.Row():
                    with gr.Column(scale=2):
                        names_box = gr.Textbox(
                            lines=10,
                            placeholder="Enter product names, one per line",
                            label="Product Names",
                        )
                        gr.Examples(
                            examples=sample_sets,
                            inputs=names_box,
                            label="Example Product Sets",
                        )
                        text_count, text_threshold = tuning_sliders()
                        run_text = gr.Button("Categorize Products", variant="primary")
                    with gr.Column(scale=3):
                        text_results = gr.HTML(
                            label="Categorization Results",
                            value=placeholder("Results will appear here"),
                        )
                run_text.click(
                    fn=categorize_products_from_text,
                    inputs=[names_box, text_count, text_threshold],
                    outputs=text_results,
                )

            # --- Tab 2: names read from an uploaded file ---
            with gr.TabItem("File Upload"):
                with gr.Row():
                    with gr.Column(scale=2):
                        upload = gr.File(
                            label="Upload JSON or text file with products",
                            file_types=[".json", ".txt"],
                        )
                        with gr.Accordion("Help", open=False):
                            gr.Markdown("""
                            - JSON files should contain either:
                            - A list of objects with a 'name' field for each product
                            - A simple array of product name strings
                            - Text files should have one product name per line
                            """)
                        file_count, file_threshold = tuning_sliders()
                        run_file = gr.Button("Process File", variant="primary")
                    with gr.Column(scale=3):
                        file_results = gr.HTML(
                            label="Categorization Results",
                            value=placeholder("Upload a file to see results"),
                        )
                run_file.click(
                    fn=categorize_products_from_file,
                    inputs=[upload, file_count, file_threshold],
                    outputs=file_results,
                )

        # Footer line rendered below the tabs.
        gr.HTML("""
        <div style="margin-top: 20px; text-align: center; color: #666;">
        Powered by Voyage AI embeddings • Built with Gradio
        </div>
        """)

    return demo