product_ingredient_demo / category_matching.py
esilver's picture
More bug fixes
e314c06
raw
history blame
11 kB
import json
import numpy as np
import pickle
import os.path
from typing import Dict, List, Any, Tuple
from embeddings import create_product_embeddings
from similarity import compute_similarities
from utils import SafeProgress
import voyageai
# Update default path to be consistent
DEFAULT_CATEGORY_EMBEDDINGS_PATH = "data/category_embeddings.pickle"
def load_categories(file_path="categories.json") -> Dict[str, str]:
    """
    Load categories from a JSON file.

    The file is expected to hold a JSON array of objects, each carrying
    an "id" and a "text" field.

    Args:
        file_path: Path to the categories JSON file

    Returns:
        Dictionary mapping category IDs to their descriptions; an empty
        dict when the file is missing or malformed (best-effort load).
    """
    try:
        # Explicit encoding avoids platform-dependent default codecs.
        with open(file_path, 'r', encoding='utf-8') as f:
            categories_list = json.load(f)
        # Convert the list of {"id", "text"} objects into an id -> text map.
        categories = {item["id"]: item["text"] for item in categories_list}
        print(f"Loaded {len(categories)} categories")
        return categories
    except (OSError, ValueError, KeyError, TypeError) as e:
        # OSError: file problems; ValueError: bad JSON; KeyError/TypeError:
        # unexpected schema. Keep the original best-effort contract.
        print(f"Error loading categories: {e}")
        return {}
def create_category_embeddings(categories: Dict[str, str], progress=None,
                               pickle_path=DEFAULT_CATEGORY_EMBEDDINGS_PATH,
                               force_regenerate=False) -> Dict[str, Any]:
    """
    Create embeddings for category descriptions, with pickle-file caching.

    Args:
        categories: Dictionary mapping category IDs to their descriptions
        progress: Optional progress tracking object
        pickle_path: Path to the pickle file for caching embeddings
        force_regenerate: If True, regenerate embeddings even if cache exists

    Returns:
        Dictionary mapping category IDs to their embeddings. Categories
        whose text did not receive an embedding are silently omitted.
    """
    progress_tracker = SafeProgress(progress, desc="Generating category embeddings")

    # Serve from the on-disk cache unless the caller forces regeneration.
    if not force_regenerate and os.path.exists(pickle_path):
        progress_tracker(0.1, desc=f"Loading cached embeddings from {pickle_path}")
        try:
            with open(pickle_path, 'rb') as f:
                category_embeddings = pickle.load(f)
            progress_tracker(1.0, desc=f"Loaded embeddings for {len(category_embeddings)} categories from cache")
            return category_embeddings
        except Exception as e:
            # Corrupt/unreadable cache: report and fall through to regenerate.
            print(f"Error loading cached embeddings: {e}")

    progress_tracker(0.1, desc=f"Processing {len(categories)} categories")

    # Embed the description texts with the same helper used for products.
    category_ids = list(categories.keys())
    category_texts = list(categories.values())
    texts_with_embeddings = create_product_embeddings(category_texts, progress=progress)

    # Map embeddings back to category IDs (keyed by the description text).
    category_embeddings = {}
    for i, category_id in enumerate(category_ids):
        if i < len(category_texts) and category_texts[i] in texts_with_embeddings:
            category_embeddings[category_id] = texts_with_embeddings[category_texts[i]]

    # Ensure the cache directory exists. os.path.dirname() returns "" for a
    # bare filename, and os.makedirs("") raises — so only create when
    # there is actually a directory component.
    cache_dir = os.path.dirname(pickle_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)

    # Persist the freshly computed embeddings; failure to cache is non-fatal.
    progress_tracker(0.9, desc=f"Saving embeddings to {pickle_path}")
    try:
        with open(pickle_path, 'wb') as f:
            pickle.dump(category_embeddings, f)
    except Exception as e:
        print(f"Error saving embeddings to pickle file: {e}")

    progress_tracker(1.0, desc=f"Completed embeddings for {len(category_embeddings)} categories")
    return category_embeddings
def load_category_embeddings(pickle_path=DEFAULT_CATEGORY_EMBEDDINGS_PATH) -> Dict[str, Any]:
    """
    Load pre-computed category embeddings from a pickle file.

    Args:
        pickle_path: Path to the pickle file with cached embeddings

    Returns:
        Dictionary mapping category IDs to their embeddings; an empty
        dict when the file is absent or cannot be unpickled.
    """
    # Guard clause: nothing on disk means nothing to load.
    if not os.path.exists(pickle_path):
        print(f"No embeddings found at {pickle_path}")
        return {}

    try:
        with open(pickle_path, 'rb') as f:
            cached = pickle.load(f)
    except Exception as e:
        # Unreadable/corrupt cache behaves like a missing one.
        print(f"Error loading cached embeddings: {e}")
        print(f"No embeddings found at {pickle_path}")
        return {}

    print(f"Loaded embeddings for {len(cached)} categories from {pickle_path}")
    return cached
def match_products_to_categories(product_names: List[str], categories: Dict[str, str], top_n=5,
                                 confidence_threshold=0.5, progress=None,
                                 embeddings_path=DEFAULT_CATEGORY_EMBEDDINGS_PATH) -> Dict[str, List]:
    """
    Match products to their most likely categories via embedding similarity.

    Args:
        product_names: List of product names to categorize
        categories: Dictionary mapping category IDs to their descriptions
        top_n: Number of top categories to return per product
        confidence_threshold: Accepted for API compatibility; this
            implementation returns the top N matches without filtering
            by score.
        progress: Optional progress tracking object
        embeddings_path: Path to pre-computed category embeddings

    Returns:
        Dictionary mapping each product to a list of
        (category_id, category_text, score) tuples.
    """
    tracker = SafeProgress(progress, desc="Matching products to categories")

    # Step 1: category vectors — prefer the on-disk cache, build if absent.
    tracker(0.2, desc="Loading category embeddings")
    category_vectors = load_category_embeddings(embeddings_path)
    if not category_vectors:
        tracker(0.3, desc="Creating category embeddings")
        category_vectors = create_category_embeddings(categories, progress, pickle_path=embeddings_path)

    # Step 2: product vectors via the shared embedding helper.
    tracker(0.4, desc="Creating product embeddings")
    product_vectors = create_product_embeddings(product_names, progress=progress)

    # Step 3: score every product against every category.
    tracker(0.6, desc="Computing similarities")
    similarities = compute_similarities(category_vectors, product_vectors)

    # Keep the N best matches per product and attach each category's text.
    tracker(0.8, desc="Processing results")
    matched = {
        product: [
            (cat_id, categories.get(cat_id, "Unknown"), score)
            for cat_id, score in ranked[:top_n]
        ]
        for product, ranked in similarities.items()
    }

    tracker(1.0, desc="Completed category matching")
    return matched
def hybrid_category_matching(products: List[str], categories: Dict[str, str],
                             embedding_top_n: int = 20, final_top_n: int = 5,
                             confidence_threshold: float = 0.5,
                             expanded_descriptions=None, progress=None) -> Dict[str, List[Tuple]]:
    """
    Two-stage matching: embeddings produce candidates, a reranker orders them.

    Args:
        products: List of product names to categorize
        categories: Dictionary mapping category IDs to their descriptions
        embedding_top_n: Number of top categories to retrieve using embeddings
        final_top_n: Number of final categories to return after re-ranking
        confidence_threshold: Minimum rerank score for final results
        expanded_descriptions: Optional dictionary of expanded product
            descriptions used to build a richer rerank query
        progress: Optional progress tracking object

    Returns:
        Dictionary mapping each product to a list of
        (category_id, category_text, score) tuples. Falls back to the
        raw embedding candidates when re-ranking fails for a product.
    """
    progress_tracker = SafeProgress(progress, desc="Hybrid category matching")
    progress_tracker(0.1, desc="Stage 1: Finding candidates with embeddings")

    # Stage 1: embedding shortlist — deliberately wider than final_top_n so
    # the reranker has room to reorder.
    embedding_results = match_products_to_categories(
        products,
        categories,
        top_n=embedding_top_n,
        progress=progress_tracker
    )

    progress_tracker(0.4, desc="Stage 2: Re-ranking candidates")
    # Initialize Voyage AI client (reads credentials from its environment).
    client = voyageai.Client()

    # Stage 2: re-rank each product's candidates.
    final_results = {}
    for i, product in enumerate(progress_tracker.tqdm(products, desc="Re-ranking product candidates")):
        progress_tracker((0.4 + 0.5 * i / len(products)), desc=f"Re-ranking: {product}")

        # Skip products with no embedding candidates.
        candidates = embedding_results.get(product)
        if not candidates:
            final_results[product] = []
            continue

        candidate_ids = [c[0] for c in candidates]
        candidate_texts = [f"{c[1]}" for c in candidates]

        try:
            # Prefer the expanded description when one is available: it gives
            # the reranker more context than the bare product name.
            if expanded_descriptions and product in expanded_descriptions:
                query = f"Which category best describes the product: {expanded_descriptions[product]}"
            else:
                query = f"Which category best describes the product: {product}"
            print(f"Query: {query}")
            reranking = client.rerank(
                query=query,
                documents=candidate_texts,
                model="rerank-2",
                top_k=final_top_n
            )

            # Process re-ranking results.
            product_categories = []
            print(f"RERANKING RESULTS: {reranking.results}")
            for result in reranking.results:
                # Map the re-ranked document back to its category ID.
                # Prefer the result's own index into `documents`: a text
                # lookup via list.index() returns the FIRST match and would
                # mis-map duplicate category descriptions.
                candidate_index = getattr(result, "index", None)
                if candidate_index is None:
                    candidate_index = candidate_texts.index(result.document)
                category_id = candidate_ids[candidate_index]
                score = result.relevance_score
                # Only include results above the confidence threshold.
                if score >= confidence_threshold:
                    product_categories.append((category_id, result.document, score))

            print(f"Product: {product}")
            print(f"Top 3 candidates before re-ranking: {candidates[:3]}")
            print(f"Top 3 candidates after re-ranking: {product_categories[:3]}")
            final_results[product] = product_categories
        except Exception as e:
            print(f"Error during re-ranking for '{product}': {e}")
            # Fall back to embedding results if re-ranking fails.
            final_results[product] = candidates[:final_top_n]

    progress_tracker(1.0, desc="Hybrid matching complete")
    return final_results