Spaces:
Sleeping
Sleeping
import gradio as gr | |
from utils import SafeProgress | |
from embeddings import create_product_embeddings | |
from similarity import compute_similarities | |
from openai_expansion import expand_product_descriptions | |
from ui_core import embeddings, parse_input, CATEGORY_EMBEDDINGS_PATH | |
from ui_formatters import format_expanded_results_html, create_results_container | |
from api_utils import get_openai_client, process_in_parallel, rank_ingredients_openai, rank_categories_openai | |
from category_matching import load_categories, load_category_embeddings | |
import json | |
import os | |
def _error_html(message):
    """Wrap *message* in the red, bold <div> used for the error banners below."""
    return f"<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>{message}</div>"


def _load_category_descriptions(needed_category_ids, path="categories.json"):
    """Return ``{category_id: text}`` for the requested ids, read from *path*.

    The JSON file is read as a list of objects carrying ``"id"`` and
    ``"text"`` keys. Loading is deliberately best-effort: on any failure the
    error is printed and whatever was collected so far (possibly nothing) is
    returned, so the caller can still render a result page.

    Args:
        needed_category_ids: Collection of category ids to keep.
        path: Location of the categories JSON file.

    Returns:
        Dict mapping each found category id to its description text.
    """
    descriptions = {}
    if not needed_category_ids:
        return descriptions
    try:
        # Explicit UTF-8: the platform default encoding is not guaranteed to
        # decode a JSON file correctly.
        with open(path, "r", encoding="utf-8") as f:
            for item in json.load(f):
                if item["id"] in needed_category_ids:
                    descriptions[item["id"]] = item["text"]
    except Exception as e:  # best-effort: report and continue with what we have
        print(f"Error loading category descriptions: {e}")
    return descriptions


def categorize_products_with_expansion(product_input, is_file=False, top_n=10, confidence_threshold=0.5, match_type="ingredients", progress=gr.Progress()):
    """
    Categorize products using expanded descriptions from OpenAI.

    Pipeline: parse the input, expand every product name into a richer
    description, shortlist candidates by embedding similarity, then let an
    OpenAI model re-rank the shortlist using the expanded description.

    Args:
        product_input: Text input with product names (or a file, see is_file)
        is_file: Whether the input is a file
        top_n: Number of top results to show
        confidence_threshold: Confidence threshold for matches
        match_type: "ingredients" selects ingredient matching; any other
            value is treated as "categories"
        progress: Progress tracking object (the gr.Progress() default is the
            pattern Gradio uses to inject progress reporting)

    Returns:
        HTML formatted results, or an HTML error banner when a step fails
        (for input parsing, whatever error value parse_input produced).
    """
    progress_tracker = SafeProgress(progress)
    progress_tracker(0, desc="Starting...")

    # Parse input
    product_names, error = parse_input(product_input, is_file)
    if error:
        return error

    # Validate embeddings are loaded if doing ingredient matching
    if match_type == "ingredients" and not embeddings:
        return _error_html("Error: No ingredient embeddings loaded. Please check that the embeddings file exists and is properly formatted.")

    # Expand product descriptions
    progress_tracker(0.2, desc="Expanding product descriptions...")
    expanded_descriptions = expand_product_descriptions(product_names, progress=progress)
    if not expanded_descriptions:
        return _error_html("Error: Failed to expand product descriptions. Please try again or check your OpenAI API key.")

    # Get shared OpenAI client
    openai_client = get_openai_client()

    # Never request zero workers, even if the product list turns out empty.
    max_workers = max(1, min(10, len(product_names)))

    if match_type == "ingredients":
        # Generate product embeddings
        progress_tracker(0.4, desc="Generating product embeddings...")
        product_embeddings = create_product_embeddings(product_names, progress=progress)

        # Compute embedding similarities for ingredients
        progress_tracker(0.6, desc="Computing ingredient similarities...")
        all_similarities = compute_similarities(embeddings, product_embeddings)
        if not all_similarities:
            return _error_html("Error: No similarities found. Please try different product names.")

        # Setup for OpenAI reranking
        embedding_top_n = 20  # Number of candidates to consider from embeddings
        progress_tracker(0.7, desc="Re-ranking with expanded descriptions...")

        def process_reranking(product):
            """Re-rank one product's embedding shortlist; returns (product, matches)."""
            if product not in all_similarities:
                return product, []
            candidates = all_similarities[product][:embedding_top_n]
            if not candidates:
                return product, []
            candidate_ingredients = [c[0] for c in candidates]
            expanded_text = expanded_descriptions.get(product, "")
            try:
                reranked_ingredients = rank_ingredients_openai(
                    product=product,
                    candidates=candidate_ingredients,
                    expanded_description=expanded_text,
                    client=openai_client,
                    model="o3-mini",
                    max_results=top_n,
                    confidence_threshold=confidence_threshold,
                    debug=True,
                )
                return product, reranked_ingredients
            except Exception as e:
                print(f"Error reranking {product}: {e}")
                # Fall back to the first embedding candidate, but only if it
                # clears the confidence threshold on its own.
                if candidates[0][1] >= confidence_threshold:
                    return product, candidates[:1]
                return product, []

        # Process all products in parallel
        final_results = process_in_parallel(
            items=product_names,
            processor_func=process_reranking,
            max_workers=max_workers,
            progress_tracker=progress_tracker,
            progress_start=0.7,
            progress_end=0.9,
            progress_desc="Re-ranking",
        )
    else:  # categories
        # Load category embeddings instead of JSON categories
        progress_tracker(0.5, desc="Loading category embeddings...")
        category_embeddings = load_category_embeddings()
        if not category_embeddings:
            return _error_html("Error: No category embeddings found. Please check that the embeddings file exists at data/category_embeddings.pickle.")

        # Generate product embeddings
        progress_tracker(0.6, desc="Generating product embeddings...")
        product_embeddings = create_product_embeddings(product_names, progress=progress)

        # Compute embedding similarities for categories
        progress_tracker(0.7, desc="Computing category similarities...")
        all_similarities = compute_similarities(category_embeddings, product_embeddings)
        if not all_similarities:
            return _error_html("Error: No category similarities found. Please try different product names.")

        embedding_top_n = min(20, top_n * 2)  # Number of candidates to consider from embeddings

        # Collect every category id that any product shortlists above the
        # threshold, so the JSON file only has to be scanned once.
        needed_category_ids = {
            category_id
            for similarities in all_similarities.values()
            for category_id, score in similarities[:embedding_top_n]
            if score >= confidence_threshold
        }

        # Load only the needed categories from JSON
        progress_tracker(0.75, desc="Loading category descriptions...")
        category_descriptions = _load_category_descriptions(needed_category_ids)

        def process_category_matching(product):
            """Match one product against the loaded categories; returns (product, matches)."""
            if product not in all_similarities:
                return product, []
            # NOTE(review): the shortlist is only used as an emptiness check;
            # the ranker below receives the descriptions pooled across ALL
            # products, not this product's own candidates - confirm intended.
            candidates = all_similarities[product][:embedding_top_n]
            if not candidates:
                return product, []
            # Get the expanded description
            expanded_text = expanded_descriptions.get(product, "")
            try:
                category_matches = rank_categories_openai(
                    product=product,
                    categories=category_descriptions,
                    expanded_description=expanded_text,
                    client=openai_client,
                    model="gpt-4o-mini",
                    max_results=top_n,
                    confidence_threshold=confidence_threshold,
                    debug=True,
                )
                # Attach the human-readable description to every match
                formatted_matches = [
                    (category_id, category_descriptions.get(category_id, "Unknown category"), score)
                    for category_id, score in category_matches
                ]
                return product, formatted_matches
            except Exception as e:
                print(f"Error matching {product} to categories: {e}")
                return product, []

        # Process all products in parallel
        final_results = process_in_parallel(
            items=product_names,
            processor_func=process_category_matching,
            max_workers=max_workers,
            progress_tracker=progress_tracker,
            # Continue from the 0.75 already reported above rather than
            # stepping the progress value back to 0.7.
            progress_start=0.75,
            progress_end=0.9,
            progress_desc="Category matching",
        )

    # Format results
    progress_tracker(0.9, desc="Formatting results...")
    if not final_results:
        # Nothing to render: skip building the container altogether.
        output_html = _error_html("No results found. Please check your input or try different products.")
    else:
        result_elements = [
            format_expanded_results_html(
                product=product,
                results=matches,
                expanded_description=expanded_descriptions.get(product, ""),
                match_type=match_type,
            )
            for product, matches in final_results.items()
        ]
        output_html = create_results_container(
            result_elements,
            header_text=f"Matched {len(product_names)} products to {match_type} using expanded descriptions.",
        )

    progress_tracker(1.0, desc="Done!")
    return output_html