product_ingredient_demo / category_matching.py
esilver's picture
some consolidation
9a56a50
raw
history blame
11.1 kB
import json
import numpy as np
import pickle
import os.path
from typing import Dict, List, Any, Tuple
from embeddings import create_product_embeddings
from similarity import compute_similarities
from utils import SafeProgress
import voyageai
# Update default path to be consistent
DEFAULT_CATEGORY_EMBEDDINGS_PATH = "data/category_embeddings.pickle"
def load_categories(file_path="categories.json") -> Dict[str, str]:
    """
    Load categories from a JSON file.

    The file is expected to contain a list of objects, each with an
    "id" and a "text" field.

    Args:
        file_path: Path to the categories JSON file

    Returns:
        Dictionary mapping category IDs to their descriptions; an empty
        dict on any read/parse failure (best-effort contract — callers
        treat an empty result as "nothing loaded").
    """
    try:
        # Explicit encoding avoids platform-dependent default codecs.
        with open(file_path, 'r', encoding='utf-8') as f:
            categories_list = json.load(f)
        # Convert list of {"id": ..., "text": ...} records into a dict.
        categories = {item["id"]: item["text"] for item in categories_list}
    except (OSError, json.JSONDecodeError, KeyError, TypeError) as e:
        # Narrow exception set: file problems, malformed JSON, or records
        # missing the expected keys — instead of a blanket Exception.
        print(f"Error loading categories: {e}")
        return {}
    print(f"Loaded {len(categories)} categories")
    return categories
def create_category_embeddings(categories: Dict[str, str], progress=None,
                              pickle_path=DEFAULT_CATEGORY_EMBEDDINGS_PATH,
                              force_regenerate=False) -> Dict[str, Any]:
    """
    Create embeddings for category descriptions.

    Results are cached to a pickle file; subsequent calls load the cache
    unless ``force_regenerate`` is set.

    Args:
        categories: Dictionary mapping category IDs to their descriptions
        progress: Optional progress tracking object
        pickle_path: Path to the pickle file for caching embeddings
        force_regenerate: If True, regenerate embeddings even if cache exists

    Returns:
        Dictionary mapping category IDs to their embeddings
    """
    progress_tracker = SafeProgress(progress, desc="Generating category embeddings")

    # Serve from cache when present and regeneration wasn't requested.
    if not force_regenerate and os.path.exists(pickle_path):
        progress_tracker(0.1, desc=f"Loading cached embeddings from {pickle_path}")
        try:
            with open(pickle_path, 'rb') as f:
                category_embeddings = pickle.load(f)
            progress_tracker(1.0, desc=f"Loaded embeddings for {len(category_embeddings)} categories from cache")
            return category_embeddings
        except Exception as e:
            print(f"Error loading cached embeddings: {e}")
            # Fall through and regenerate from scratch.

    progress_tracker(0.1, desc=f"Processing {len(categories)} categories")
    category_ids = list(categories.keys())
    category_texts = list(categories.values())

    # Reuse the product embedding pipeline for category descriptions;
    # it returns a mapping keyed by the input text.
    texts_with_embeddings = create_product_embeddings(category_texts, progress=progress)

    # Map embeddings back to category IDs. ids and texts come from the
    # same dict, so they pair up one-to-one.
    category_embeddings = {}
    for category_id, text in zip(category_ids, category_texts):
        if text in texts_with_embeddings:
            category_embeddings[category_id] = texts_with_embeddings[text]

    # Ensure the target directory exists. Guard against a bare filename:
    # os.path.dirname would be "" and os.makedirs("") raises.
    directory = os.path.dirname(pickle_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Save embeddings to the cache file (best-effort).
    progress_tracker(0.9, desc=f"Saving embeddings to {pickle_path}")
    try:
        with open(pickle_path, 'wb') as f:
            pickle.dump(category_embeddings, f)
    except Exception as e:
        print(f"Error saving embeddings to pickle file: {e}")

    progress_tracker(1.0, desc=f"Completed embeddings for {len(category_embeddings)} categories")
    return category_embeddings
def load_category_embeddings(pickle_path=DEFAULT_CATEGORY_EMBEDDINGS_PATH) -> Dict[str, Any]:
    """
    Load pre-computed category embeddings from a pickle file.

    Args:
        pickle_path: Path to the pickle file with cached embeddings

    Returns:
        Dictionary mapping category IDs to their embeddings, or an empty
        dict when the file is absent or unreadable.
    """
    # Guard clause: nothing to do when the cache file doesn't exist.
    if not os.path.exists(pickle_path):
        print(f"No embeddings found at {pickle_path}")
        return {}
    try:
        with open(pickle_path, 'rb') as fh:
            embeddings = pickle.load(fh)
    except Exception as e:
        # Unreadable cache is reported, then treated the same as missing.
        print(f"Error loading cached embeddings: {e}")
        print(f"No embeddings found at {pickle_path}")
        return {}
    print(f"Loaded embeddings for {len(embeddings)} categories from {pickle_path}")
    return embeddings
def match_products_to_categories(product_names: List[str], categories: Dict[str, str], top_n=5,
                               confidence_threshold=0.5, progress=None,
                               embeddings_path=DEFAULT_CATEGORY_EMBEDDINGS_PATH) -> Dict[str, List]:
    """
    Match products to their most likely categories.

    Args:
        product_names: List of product names to categorize
        categories: Dictionary mapping category IDs to their descriptions
        top_n: Number of top categories to return per product
        confidence_threshold: Minimum similarity score to include
        progress: Optional progress tracking object
        embeddings_path: Path to pre-computed category embeddings

    Returns:
        Dictionary mapping each product to a list of
        (category_id, category_text, score) tuples.
    """
    tracker = SafeProgress(progress, desc="Matching products to categories")

    # Step 1: obtain category embeddings — from cache when possible,
    # generated fresh otherwise.
    tracker(0.2, desc="Loading category embeddings")
    category_embeddings = load_category_embeddings(embeddings_path)
    if not category_embeddings:
        tracker(0.3, desc="Creating category embeddings")
        category_embeddings = create_category_embeddings(categories, progress, pickle_path=embeddings_path)

    # Step 2: embed the product names with the same pipeline.
    tracker(0.4, desc="Creating product embeddings")
    product_embeddings = create_product_embeddings(product_names, progress=progress)

    # Step 3: similarity scores between each product and every category.
    tracker(0.6, desc="Computing similarities")
    similarities = compute_similarities(category_embeddings, product_embeddings)

    # Keep the first top_n matches per product that clear the threshold,
    # and attach each category's human-readable description.
    tracker(0.8, desc="Processing results")
    results = {}
    for product, scored in similarities.items():
        kept = [(cid, score) for cid, score in scored if score >= confidence_threshold][:top_n]
        results[product] = [(cid, categories.get(cid, "Unknown"), score) for cid, score in kept]

    tracker(1.0, desc="Completed category matching")
    return results
def hybrid_category_matching(products: List[str], categories: Dict[str, str],
                           embedding_top_n: int = 20, final_top_n: int = 5,
                           confidence_threshold: float = 0.5,
                           expanded_descriptions=None, progress=None) -> Dict[str, List[Tuple]]:
    """
    Two-stage matching: first use embeddings to find candidates, then apply re-ranking.

    Args:
        products: List of product names to categorize
        categories: Dictionary mapping category IDs to their descriptions
        embedding_top_n: Number of top categories to retrieve using embeddings
        final_top_n: Number of final categories to return after re-ranking
        confidence_threshold: Minimum score threshold for final results
        expanded_descriptions: Optional dictionary of expanded product descriptions
        progress: Optional progress tracking object

    Returns:
        Dictionary mapping each product to a list of
        (category_id, category_text, score) tuples; falls back to the
        embedding-stage candidates when re-ranking fails for a product.
    """
    progress_tracker = SafeProgress(progress, desc="Hybrid category matching")
    progress_tracker(0.1, desc="Stage 1: Finding candidates with embeddings")

    # Stage 1: embedding-based retrieval of a generous candidate pool.
    embedding_results = match_products_to_categories(
        products,
        categories,
        top_n=embedding_top_n,  # Get more candidates from embeddings than we'll ultimately return
        progress=progress_tracker
    )

    progress_tracker(0.4, desc="Stage 2: Re-ranking candidates")
    # Initialize Voyage AI client
    client = voyageai.Client()

    # Stage 2: Re-rank the candidates for each product
    final_results = {}
    for i, product in enumerate(progress_tracker.tqdm(products, desc="Re-ranking product candidates")):
        progress_tracker((0.4 + 0.5 * i / len(products)), desc=f"Re-ranking: {product}")

        # Get the embedding candidates for this product
        if product not in embedding_results:
            final_results[product] = []
            continue
        candidates = embedding_results[product]
        if not candidates:
            final_results[product] = []
            continue

        # candidates are (category_id, category_text, score) tuples.
        candidate_ids = [c[0] for c in candidates]
        candidate_texts = [f"{c[1]}" for c in candidates]

        try:
            # Prefer the expanded description when one is available.
            if expanded_descriptions and product in expanded_descriptions:
                query = f"Which category best describes the product: {expanded_descriptions[product]}"
            else:
                query = f"Which category best describes the product: {product}"

            reranking = client.rerank(
                query=query,
                documents=candidate_texts,
                model="rerank-2",
                top_k=final_top_n
            )

            # Process re-ranking results
            product_categories = []
            for result in reranking.results:
                # Map the result back to its candidate. Prefer the result's
                # own document index (Voyage rerank results carry one):
                # candidate_texts.index(...) returns the FIRST occurrence,
                # which mis-attributes duplicate description texts.
                candidate_index = getattr(result, "index", None)
                if candidate_index is None:
                    candidate_index = candidate_texts.index(result.document)
                category_id = candidate_ids[candidate_index]
                score = result.relevance_score
                # Only include results above the confidence threshold
                if score >= confidence_threshold:
                    product_categories.append((category_id, result.document, score))

            print(f"Product: {product}")
            print(f"Top 3 candidates before re-ranking: {candidates[:3]}")
            print(f"Top 3 candidates after re-ranking: {product_categories[:3]}")

            final_results[product] = product_categories
        except Exception as e:
            print(f"Error during re-ranking for '{product}': {e}")
            # Fall back to embedding results if re-ranking fails
            final_results[product] = candidates[:final_top_n]

    progress_tracker(1.0, desc="Hybrid matching complete")
    return final_results