import gradio as gr
from utils import SafeProgress
from embeddings import create_product_embeddings
from similarity import compute_similarities
from openai_expansion import expand_product_descriptions
from ui_core import embeddings, parse_input, CATEGORY_EMBEDDINGS_PATH
from ui_formatters import format_reranking_results_html
from api_utils import get_openai_client, process_in_parallel, rank_ingredients_openai, rank_categories_openai
from category_matching import load_categories, load_category_embeddings
import json
# Upper bound on concurrent re-ranking workers.
_MAX_RERANK_WORKERS = 10
# Chat model used for both ingredient and category re-ranking.
_RERANK_MODEL = "gpt-4o-mini"


def _embed_products(product_names, expanded_descriptions, progress):
    """Create product embeddings keyed by the original product names.

    When expanded descriptions are available they are embedded instead of the
    bare names, and the resulting vectors are mapped back onto the original
    names so later lookups use consistent keys.

    Args:
        product_names: Product names, in input order.
        expanded_descriptions: Mapping of product name to expanded text; an
            empty mapping means the names themselves are embedded.
        progress: Progress object forwarded to ``create_product_embeddings``.

    Returns:
        A ``(product_embeddings, embedded_texts)`` tuple; ``embedded_texts``
        lists the text that was embedded for each product, in order.
    """
    if not expanded_descriptions:
        product_embeddings = create_product_embeddings(product_names, progress=progress)
        return product_embeddings, list(product_names)

    embedded_texts = [expanded_descriptions.get(name, name) for name in product_names]
    embeddings_by_text = create_product_embeddings(embedded_texts, progress=progress)
    # Products whose text produced no embedding are simply left out.
    product_embeddings = {
        name: embeddings_by_text[text]
        for name, text in zip(product_names, embedded_texts)
        if text in embeddings_by_text
    }
    return product_embeddings, embedded_texts


def _worker_count(product_names):
    """Return the worker count for parallel re-ranking (always at least 1)."""
    return max(1, min(_MAX_RERANK_WORKERS, len(product_names)))


def _load_category_descriptions(category_ids, path="categories.json"):
    """Load the ``text`` of the given category IDs from the categories JSON file.

    Loading is best-effort: on failure the error is printed and whatever was
    collected so far is returned.

    Args:
        category_ids: Set of category IDs whose descriptions are needed.
        path: Location of the JSON file, read as a list of items with ``id``
            and ``text`` keys.

    Returns:
        Dict mapping category ID to its description text.
    """
    descriptions = {}
    if not category_ids:
        return descriptions
    try:
        with open(path, "r", encoding="utf-8") as f:
            categories_list = json.load(f)
        for item in categories_list:
            if item["id"] in category_ids:
                descriptions[item["id"]] = item["text"]
    except (OSError, ValueError, KeyError, TypeError) as e:
        # ValueError covers both malformed JSON and undecodable bytes.
        print(f"Error loading category descriptions: {e}")
    return descriptions


def _rerank_ingredients(product_names, expanded_descriptions, openai_client,
                        embedding_top_n, top_n, progress, progress_tracker):
    """Match products to ingredients: embedding shortlist, then OpenAI re-ranking.

    Returns:
        A ``(results, error)`` tuple. ``results`` is whatever
        ``process_in_parallel`` returns for ``(product, matches)`` pairs, with
        ``matches`` being ``(ingredient, score)`` pairs; ``error`` is an error
        message string, or ``None`` on success.
    """
    progress_tracker(0.4, desc="Generating product embeddings...")
    product_embeddings, embedded_texts = _embed_products(
        product_names, expanded_descriptions, progress
    )

    progress_tracker(0.6, desc="Computing ingredient similarities...")
    all_similarities = compute_similarities(embeddings, product_embeddings)
    print(f"product_names: {product_names}")
    print(f"products_for_embedding: {embedded_texts}")
    if not all_similarities:
        return None, "Error: No similarities found. Please try different product names."

    progress_tracker(0.7, desc="Re-ranking with OpenAI...")

    def process_reranking(product):
        if product not in all_similarities:
            return product, []
        candidates = all_similarities[product][:embedding_top_n]
        if not candidates:
            return product, []
        try:
            # Threshold 0.0 keeps every result; the real confidence threshold
            # is applied when the results are displayed.
            reranked_ingredients = rank_ingredients_openai(
                product=product,
                candidates=[candidate[0] for candidate in candidates],
                expanded_description=expanded_descriptions.get(product, product),
                client=openai_client,
                model=_RERANK_MODEL,
                max_results=top_n,
                confidence_threshold=0.0,
                debug=True,
            )
            return product, reranked_ingredients
        except Exception as e:
            print(f"Error reranking {product}: {e}")
            # Fall back to the top embedding match rather than dropping the product.
            return product, candidates[:1]

    results = process_in_parallel(
        items=product_names,
        processor_func=process_reranking,
        max_workers=_worker_count(product_names),
        progress_tracker=progress_tracker,
        progress_start=0.7,
        progress_end=0.9,
        progress_desc="Re-ranking",
    )
    return results, None


def _rerank_categories(product_names, expanded_descriptions, openai_client,
                       embedding_top_n, top_n, progress, progress_tracker):
    """Match products to categories: embedding shortlist, then OpenAI re-ranking.

    Returns:
        A ``(results, error)`` tuple. ``results`` is whatever
        ``process_in_parallel`` returns for ``(product, matches)`` pairs, with
        ``matches`` being ``(category_id, description, score)`` triples;
        ``error`` is an error message string, or ``None`` on success.
    """
    progress_tracker(0.5, desc="Loading category embeddings...")
    category_embeddings = load_category_embeddings()
    if not category_embeddings:
        return None, (
            "Error: No category embeddings found. Please check that the "
            "embeddings file exists at data/category_embeddings.pickle."
        )

    progress_tracker(0.6, desc="Generating product embeddings...")
    product_embeddings, _ = _embed_products(product_names, expanded_descriptions, progress)

    progress_tracker(0.7, desc="Computing category similarities...")
    all_similarities = compute_similarities(category_embeddings, product_embeddings)
    if not all_similarities:
        return None, "Error: No category similarities found. Please try different product names."

    # Collect every shortlisted category ID up front (no threshold filtering)
    # so only the descriptions that are actually needed get loaded.
    needed_category_ids = set()
    for similarities in all_similarities.values():
        needed_category_ids.update(
            category_id for category_id, _ in similarities[:embedding_top_n]
        )

    progress_tracker(0.75, desc="Loading category descriptions...")
    category_descriptions = _load_category_descriptions(needed_category_ids)

    def process_category_matching(product):
        if product not in all_similarities:
            return product, []
        candidates = all_similarities[product][:embedding_top_n]
        print(f"candidates: {candidates}")
        if not candidates:
            return product, []

        # Offer the model only this product's own shortlisted categories.
        filtered_categories = {
            cat_id: category_descriptions[cat_id]
            for cat_id, _ in candidates
            if cat_id in category_descriptions
        }
        if not filtered_categories:
            # Nothing to rank against (e.g. descriptions failed to load);
            # skip the API call instead of sending an empty category set.
            return product, []

        try:
            # Threshold 0.0 keeps every result; the real confidence threshold
            # is applied when the results are displayed.
            category_matches = rank_categories_openai(
                product=product,
                categories=filtered_categories,
                expanded_description=expanded_descriptions.get(product, product),
                client=openai_client,
                model=_RERANK_MODEL,
                max_results=top_n,
                confidence_threshold=0.0,
                debug=True,
            )
            formatted_matches = [
                (category_id, category_descriptions.get(category_id, "Unknown category"), score)
                for category_id, score in category_matches
            ]
            return product, formatted_matches
        except Exception as e:
            print(f"Error matching {product} to categories: {e}")
            return product, []

    results = process_in_parallel(
        items=product_names,
        processor_func=process_category_matching,
        max_workers=_worker_count(product_names),
        progress_tracker=progress_tracker,
        # Continue from the last reported step so the bar never moves backwards.
        progress_start=0.75,
        progress_end=0.9,
        progress_desc="Category matching",
    )
    return results, None


def _build_result_records(final_results, expanded_descriptions, match_type):
    """Convert per-product matches into the dict records passed to the HTML formatter.

    Every product is included, even when it has no matches (confidence 0).
    """
    records = []
    for product, matches in final_results.items():
        record = {
            "product_name": product,
            # The score is the last element of every match tuple.
            "confidence": max((match[-1] for match in matches), default=0),
            "matching_items": [],
            "item_scores": [],
            "explanation": expanded_descriptions.get(product, ""),
        }
        if match_type == "ingredients":
            for item, score in matches:
                record["matching_items"].append(item)
                record["item_scores"].append(score)
        else:  # categories
            for cat_id, cat_desc, score in matches:
                record["matching_items"].append(
                    f"{cat_id}: {cat_desc}" if cat_desc else f"{cat_id}"
                )
                record["item_scores"].append(score)
        records.append(record)
    return records


def categorize_products_with_openai_reranking(product_input, is_file=False, use_expansion=False,
                                              embedding_top_n=20, top_n=10, confidence_threshold=0.5,
                                              match_type="ingredients", progress=gr.Progress()):
    """Categorize products using OpenAI re-ranking of embedding candidates.

    Products are parsed from the input, optionally expanded into richer
    descriptions, shortlisted by embedding similarity, re-ranked with an
    OpenAI model, and finally rendered as HTML.

    Args:
        product_input: Raw product input, handed to ``parse_input``.
        is_file: Forwarded to ``parse_input`` together with ``product_input``.
        use_expansion: When true, expand product descriptions first and use
            the expanded text for embedding, re-ranking and the explanation.
        embedding_top_n: Number of top embedding candidates per product that
            are passed on to the re-ranking step.
        top_n: Maximum number of results requested from the re-ranker.
        confidence_threshold: Passed to the HTML formatter; ranking itself is
            run with a threshold of 0.0 so nothing is filtered earlier.
        match_type: ``"ingredients"`` matches against the loaded ingredient
            embeddings; any other value matches against category embeddings.
        progress: Gradio progress object used for status updates.

    Returns:
        The HTML produced by ``format_reranking_results_html``, or an error
        message (from ``parse_input`` or from one of the checks here).
    """
    progress_tracker = SafeProgress(progress)
    progress_tracker(0, desc="Starting OpenAI reranking...")

    product_names, error = parse_input(product_input, is_file)
    if error:
        return error

    # Check before any API work so a missing embeddings file fails fast.
    if match_type == "ingredients" and not embeddings:
        return (
            "Error: No ingredient embeddings loaded. Please check that the "
            "embeddings file exists and is properly formatted."
        )

    # Stays empty unless expansion is requested, so the helpers can use plain
    # ``.get(product, default)`` lookups without re-checking ``use_expansion``.
    expanded_descriptions = {}
    if use_expansion:
        progress_tracker(0.2, desc="Expanding product descriptions...")
        expanded_descriptions = expand_product_descriptions(product_names, progress=progress) or {}

    openai_client = get_openai_client()

    rerank = _rerank_ingredients if match_type == "ingredients" else _rerank_categories
    final_results, error = rerank(
        product_names,
        expanded_descriptions,
        openai_client,
        embedding_top_n,
        top_n,
        progress,
        progress_tracker,
    )
    if error:
        return error

    progress_tracker(0.9, desc="Formatting results...")
    formatted_results = _build_result_records(final_results, expanded_descriptions, match_type)
    if not formatted_results:
        return "No results found. Please check your input or try different products."

    result_html = format_reranking_results_html(
        results=formatted_results,
        match_type=match_type,
        show_scores=True,
        include_explanation=use_expansion,
        method="openai",
        confidence_threshold=confidence_threshold,
    )

    progress_tracker(1.0, desc="Done!")
    return result_html