# product_ingredient_demo / ui_expanded_matching.py
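"""Gradio handler for matching products to ingredients or categories using
OpenAI-expanded product descriptions: embedding similarity retrieves candidate
matches, then an OpenAI model re-ranks them against the expanded description."""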
import gradio as gr
from utils import SafeProgress
from embeddings import create_product_embeddings
from similarity import compute_similarities
from openai_expansion import expand_product_descriptions
from ui_core import embeddings, parse_input, CATEGORY_EMBEDDINGS_PATH
from ui_formatters import format_expanded_results_html, create_results_container
from api_utils import get_openai_client, process_in_parallel, rank_ingredients_openai, rank_categories_openai
from category_matching import load_categories, load_category_embeddings
import json
import os
def categorize_products_with_expansion(product_input, is_file=False, top_n=10, confidence_threshold=0.5, match_type="ingredients", progress=gr.Progress()):
"""
Categorize products using expanded descriptions from OpenAI
Args:
product_input: Text input with product names
is_file: Whether the input is a file
top_n: Number of top results to show
confidence_threshold: Confidence threshold for matches
match_type: Either "ingredients" or "categories"
progress: Progress tracking object
Returns:
HTML formatted results
"""
progress_tracker = SafeProgress(progress)
progress_tracker(0, desc="Starting...")
# Parse input
product_names, error = parse_input(product_input, is_file)
if error:
return error
# Validate embeddings are loaded if doing ingredient matching
if match_type == "ingredients" and not embeddings:
return "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>Error: No ingredient embeddings loaded. Please check that the embeddings file exists and is properly formatted.</div>"
# Expand product descriptions
progress_tracker(0.2, desc="Expanding product descriptions...")
expanded_descriptions = expand_product_descriptions(product_names, progress=progress)
if not expanded_descriptions:
return "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>Error: Failed to expand product descriptions. Please try again or check your OpenAI API key.</div>"
# Get shared OpenAI client
openai_client = get_openai_client()
if match_type == "ingredients":
# Generate product embeddings
progress_tracker(0.4, desc="Generating product embeddings...")
product_embeddings = create_product_embeddings(product_names, progress=progress)
# Compute embedding similarities for ingredients
progress_tracker(0.6, desc="Computing ingredient similarities...")
all_similarities = compute_similarities(embeddings, product_embeddings)
if not all_similarities:
return "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>Error: No similarities found. Please try different product names.</div>"
# Setup for OpenAI reranking
embedding_top_n = 20 # Number of candidates to consider from embeddings
progress_tracker(0.7, desc="Re-ranking with expanded descriptions...")
# Function for processing each product
def process_reranking(product):
if product not in all_similarities:
return product, []
candidates = all_similarities[product][:embedding_top_n]
if not candidates:
return product, []
candidate_ingredients = [c[0] for c in candidates]
expanded_text = expanded_descriptions.get(product, "")
try:
# Use the shared utility function
reranked_ingredients = rank_ingredients_openai(
product=product,
candidates=candidate_ingredients,
expanded_description=expanded_text,
client=openai_client,
model="o3-mini",
max_results=top_n,
confidence_threshold=confidence_threshold,
debug=True
)
return product, reranked_ingredients
except Exception as e:
print(f"Error reranking {product}: {e}")
# Fall back to top embedding match
return product, candidates[:1] if candidates[0][1] >= confidence_threshold else []
# Process all products in parallel
final_results = process_in_parallel(
items=product_names,
processor_func=process_reranking,
max_workers=min(10, len(product_names)),
progress_tracker=progress_tracker,
progress_start=0.7,
progress_end=0.9,
progress_desc="Re-ranking"
)
else: # categories
# Load category embeddings instead of JSON categories
progress_tracker(0.5, desc="Loading category embeddings...")
category_embeddings = load_category_embeddings()
if not category_embeddings:
return "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>Error: No category embeddings found. Please check that the embeddings file exists at data/category_embeddings.pickle.</div>"
# Generate product embeddings
progress_tracker(0.6, desc="Generating product embeddings...")
product_embeddings = create_product_embeddings(product_names, progress=progress)
# Compute embedding similarities for categories
progress_tracker(0.7, desc="Computing category similarities...")
all_similarities = compute_similarities(category_embeddings, product_embeddings)
if not all_similarities:
return "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>Error: No category similarities found. Please try different product names.</div>"
embedding_top_n = min(20, top_n * 2) # Number of candidates to consider from embeddings
# Collect all needed category IDs first
needed_category_ids = set()
for product, similarities in all_similarities.items():
for category_id, score in similarities[:embedding_top_n]:
if score >= confidence_threshold:
needed_category_ids.add(category_id)
# Load only the needed categories from JSON
progress_tracker(0.75, desc="Loading category descriptions...")
category_descriptions = {}
if needed_category_ids:
try:
with open("categories.json", 'r') as f:
categories_list = json.load(f)
for item in categories_list:
if item["id"] in needed_category_ids:
category_descriptions[item["id"]] = item["text"]
except Exception as e:
print(f"Error loading category descriptions: {e}")
# Function to process each product
def process_category_matching(product):
if product not in all_similarities:
return product, []
candidates = all_similarities[product][:embedding_top_n]
if not candidates:
return product, []
# Get the expanded description
expanded_text = expanded_descriptions.get(product, "")
try:
# Use rank_categories_openai instead of match_products_to_categories_with_description
category_matches = rank_categories_openai(
product=product,
categories=category_descriptions,
expanded_description=expanded_text,
client=openai_client,
# model="o3-mini",
model="gpt-4o-mini",
# model="gpt-4o",
max_results=top_n,
confidence_threshold=confidence_threshold,
debug=True
)
# Format results with category descriptions if needed
formatted_matches = []
for category_id, score in category_matches:
category_text = category_descriptions.get(category_id, "Unknown category")
formatted_matches.append((category_id, category_text, score))
return product, formatted_matches
except Exception as e:
print(f"Error matching {product} to categories: {e}")
return product, []
# Process all products in parallel
final_results = process_in_parallel(
items=product_names,
processor_func=process_category_matching,
max_workers=min(10, len(product_names)),
progress_tracker=progress_tracker,
            progress_start=0.75,
progress_end=0.9,
progress_desc="Category matching"
)
# Format results
progress_tracker(0.9, desc="Formatting results...")
result_elements = []
for product, matches in final_results.items():
result_elements.append(
format_expanded_results_html(
product=product,
results=matches,
expanded_description=expanded_descriptions.get(product, ""),
match_type=match_type
)
)
output_html = create_results_container(
result_elements,
header_text=f"Matched {len(product_names)} products to {match_type} using expanded descriptions."
)
if not final_results:
output_html = "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>No results found. Please check your input or try different products.</div>"
progress_tracker(1.0, desc="Done!")
return output_html
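

# --- Illustrative local run ---
# A minimal sketch for manual testing, not part of the Gradio UI wiring. It assumes the
# ingredient embeddings file and an OpenAI API key are already configured for this Space;
# the sample product names below are made up.
if __name__ == "__main__":
    sample_input = "organic peanut butter\ngreek yogurt\nsourdough bread"
    html_output = categorize_products_with_expansion(
        sample_input,
        is_file=False,
        top_n=5,
        confidence_threshold=0.5,
        match_type="ingredients",
    )
    print(html_output)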