# Spaces:
# Sleeping
# Sleeping
import html
import json
import traceback  # Import traceback for detailed error logging

# import gradio as gr # Removed Gradio import
# from utils import SafeProgress # Removed SafeProgress import
from embeddings import create_product_embeddings
from similarity import compute_similarities
from openai_expansion import expand_product_descriptions
from ui_core import embeddings, parse_input, CATEGORY_EMBEDDINGS_PATH
from ui_formatters import format_reranking_results_html
from api_utils import get_openai_client, process_in_parallel, rank_ingredients_openai, rank_categories_openai
from category_matching import load_categories, load_category_embeddings
def _generate_product_embeddings(product_names, expanded_descriptions, use_expansion):
    """Embed every product and return the embeddings keyed by original product name.

    When expansion is enabled and produced descriptions, each product's expanded
    description (falling back to the bare name) is what gets embedded. The result
    is then re-keyed by the original names; any product missing from the
    embedder's output is dropped.
    """
    if use_expansion and expanded_descriptions:
        texts = [expanded_descriptions.get(name, name) for name in product_names]
        # original_products is passed so the embedder can key its output by the
        # user's product names rather than by the expanded text.
        # NOTE(review): assumed from the name-keyed lookup below — confirm in embeddings.py.
        by_name = create_product_embeddings(texts, original_products=product_names)
        return {name: by_name[name] for name in product_names if name in by_name}
    return create_product_embeddings(product_names)


def _load_category_descriptions(category_ids, path="categories.json"):
    """Return ``{category_id: text}`` for the requested ids, read from the JSON file at ``path``.

    Best-effort: any error (missing file, invalid JSON, item without "id"/"text")
    is printed and whatever was collected before the failure is returned, so
    reranking can continue without descriptions.
    """
    descriptions = {}
    if not category_ids:
        return descriptions
    try:
        # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            categories_list = json.load(f)
        for item in categories_list:
            if item["id"] in category_ids:
                descriptions[item["id"]] = item["text"]
    except Exception as e:
        print(f"Error loading category descriptions: {e}")  # Non-fatal, continue without descriptions
    return descriptions


def _format_results(final_results, match_type, expanded_descriptions, use_expansion):
    """Turn ``{product: matches}`` into the list of dicts consumed by the HTML formatter.

    Ingredient matches are ``(item, score)`` pairs; category matches are
    ``(category_id, description, score)`` triples. In both shapes the score is the
    last element, and a product's ``confidence`` is its highest score (0 if none).
    """
    formatted_results = []
    for product, matches in final_results.items():
        if match_type == "ingredients":
            items = [item for item, _score in matches]
            scores = [score for _item, score in matches]
        else:  # categories
            items = [f"{cat_id}: {cat_desc}" for cat_id, cat_desc, _score in matches]
            scores = [score for _cat_id, _cat_desc, score in matches]
        formatted_results.append({
            "product_name": product,
            "confidence": max(match[-1] for match in matches) if matches else 0,
            "matching_items": items,
            "item_scores": scores,
            "explanation": expanded_descriptions.get(product, "") if use_expansion else "",
        })
    return formatted_results


def categorize_products_with_openai_reranking(product_input, is_file=False, use_expansion=False,
                                              embedding_top_n=20, top_n=10, confidence_threshold=0.5,
                                              match_type="ingredients"):
    """
    Categorize products using OpenAI reranking with optional description expansion.

    Pipeline: parse the input into product names, optionally expand each name
    into a longer description, embed the products, shortlist the
    ``embedding_top_n`` most similar ingredients/categories by embedding
    similarity, rerank each shortlist with OpenAI (``gpt-4o-mini``), and render
    the outcome as HTML.

    Args:
        product_input: Raw product input, interpreted by ``parse_input``.
        is_file: Forwarded to ``parse_input``.
        use_expansion: If true, expand product names via
            ``expand_product_descriptions`` and embed/rerank on the expanded text.
        embedding_top_n: Number of embedding candidates per product sent to the reranker.
        top_n: Maximum number of reranked matches requested per product.
        confidence_threshold: Passed only to the HTML formatter; the reranker is
            always called with a threshold of 0.0.
        match_type: ``"ingredients"`` for ingredient matching; any other value
            selects category matching.

    Returns:
        The HTML produced by ``format_reranking_results_html``, an error
        ``<div>`` string, or whatever ``parse_input`` returned as its error.
    """
    # Parse input
    product_names, error = parse_input(product_input, is_file)
    if error:
        return error

    # Validate embeddings are loaded if doing ingredient matching.
    # NOTE(review): `embeddings` is bound once by `from ui_core import embeddings`;
    # if ui_core rebinds that name later, this check still sees the old value — confirm.
    if match_type == "ingredients" and not embeddings:
        return "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>Error: No ingredient embeddings loaded. Please check that the embeddings file exists and is properly formatted.</div>"

    # Optional description expansion
    expanded_descriptions = {}
    if use_expansion:
        try:
            expanded_descriptions = expand_product_descriptions(product_names)
        except Exception as e:
            print(f"ERROR during description expansion: {e}")
            print(traceback.format_exc())
            # Escape: the exception text may echo user-supplied product names or contain markup.
            return f"<div style='color: red;'>Error during description expansion: {html.escape(str(e))}</div>"

    # Get shared OpenAI client
    openai_client = get_openai_client()

    try:  # Wrap embedding generation and similarity computation
        if match_type == "ingredients":
            product_embeddings = _generate_product_embeddings(product_names, expanded_descriptions, use_expansion)
            if not product_embeddings:
                return "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>Error: Failed to generate product embeddings for ingredients. Please try again.</div>"
            all_similarities = compute_similarities(embeddings, product_embeddings)
        else:  # categories
            category_embeddings = load_category_embeddings()
            if not category_embeddings:
                # NOTE(review): CATEGORY_EMBEDDINGS_PATH is imported by this module; this
                # message hard-codes a path that may drift from it.
                return "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>Error: No category embeddings found. Please check 'data/category_embeddings.pickle'.</div>"
            product_embeddings = _generate_product_embeddings(product_names, expanded_descriptions, use_expansion)
            if not product_embeddings:
                return "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>Error: Failed to generate product embeddings for categories. Please try again.</div>"
            all_similarities = compute_similarities(category_embeddings, product_embeddings)

        # Safety net in case the embedding checks above did not catch an empty result.
        if not all_similarities:
            return "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>Error: No similarities found. Please try different product names.</div>"
    except Exception as e:  # Catch errors during embedding/similarity
        tb_text = traceback.format_exc()
        print(f"ERROR during embedding generation or similarity computation: {e}")
        print(tb_text)
        # Escape both parts: tracebacks routinely contain "<module>", which a browser
        # would otherwise swallow as an unknown tag inside <pre>.
        return f"<div style='color: red;'>Error during processing: {html.escape(str(e))}<br><pre>{html.escape(tb_text)}</pre></div>"

    # Thread pools reject max_workers=0, so never ask for fewer than one worker.
    max_workers = max(1, min(10, len(product_names)))

    # --- Reranking Logic ---
    if match_type == "ingredients":
        def process_reranking_ingredients(product):
            """Rerank one product's ingredient shortlist; fall back to the top embedding hit on error."""
            if product not in all_similarities:
                return product, []
            candidates = all_similarities[product][:embedding_top_n]
            if not candidates:
                return product, []
            candidate_ingredients = [c[0] for c in candidates]
            expanded_text = expanded_descriptions.get(product, product) if use_expansion else product
            try:
                reranked_ingredients = rank_ingredients_openai(
                    product=product, candidates=candidate_ingredients, expanded_description=expanded_text,
                    client=openai_client, model="gpt-4o-mini", max_results=top_n,
                    confidence_threshold=0.0, debug=True
                )
                return product, reranked_ingredients
            except Exception as e:
                print(f"Error reranking ingredients for {product}: {e}")
                return product, candidates[:1]  # Fallback: best embedding match with its embedding score

        final_results = process_in_parallel(
            items=product_names, processor_func=process_reranking_ingredients,
            max_workers=max_workers
        )
    else:  # categories
        # Only load descriptions for categories that appear in some product's shortlist.
        needed_category_ids = {
            category_id
            for similarities in all_similarities.values()
            for category_id, _score in similarities[:embedding_top_n]
        }
        category_descriptions = _load_category_descriptions(needed_category_ids)

        def process_reranking_categories(product):
            """Rerank one product's category shortlist into (id, description, score) triples."""
            if product not in all_similarities:
                return product, []
            candidates = all_similarities[product][:embedding_top_n]
            if not candidates:
                return product, []
            product_category_ids = [cat_id for cat_id, _ in candidates]
            filtered_categories = {cat_id: category_descriptions.get(cat_id, f"Category {cat_id}")
                                   for cat_id in product_category_ids}
            expanded_text = expanded_descriptions.get(product, product) if use_expansion else product
            try:
                category_matches = rank_categories_openai(
                    product=product, categories=filtered_categories, expanded_description=expanded_text,
                    client=openai_client, model="gpt-4o-mini", max_results=top_n,
                    confidence_threshold=0.0, debug=True
                )
                formatted_matches = [
                    (category_id, category_descriptions.get(category_id, "Unknown category"), score)
                    for category_id, score in category_matches
                ]
                return product, formatted_matches
            except Exception as e:
                print(f"Error reranking categories for {product}: {e}")
                # Fallback: top embedding match, carrying its embedding (not reranker) score.
                fallback_matches = [
                    (cat_id, category_descriptions.get(cat_id, "Unknown category"), score)
                    for cat_id, score in candidates[:1]
                ]
                return product, fallback_matches

        final_results = process_in_parallel(
            items=product_names, processor_func=process_reranking_categories,
            max_workers=max_workers
        )

    # --- Format final results ---
    formatted_results = _format_results(final_results, match_type, expanded_descriptions, use_expansion)
    if not formatted_results:
        return "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>No results found after processing.</div>"

    return format_reranking_results_html(
        results=formatted_results,
        match_type=match_type,
        show_scores=True,
        include_explanation=use_expansion,
        method="openai",
        confidence_threshold=confidence_threshold
    )