import json
import numpy as np
import pickle
import os
from typing import Dict, List, Any, Tuple
from embeddings import create_product_embeddings
from similarity import compute_similarities
from utils import SafeProgress
import voyageai

# Default cache location for category embeddings
DEFAULT_CATEGORY_EMBEDDINGS_PATH = "data/category_embeddings.pickle"
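
# Expected categories.json layout, inferred from the parsing in
# load_categories() below (the ids and texts here are hypothetical examples):
# [
#   {"id": "cat_001", "text": "Fresh fruit and vegetables"},
#   {"id": "cat_002", "text": "Dairy and eggs"}
# ]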

def load_categories(file_path="categories.json") -> Dict[str, str]:
    """
    Load categories from a JSON file

    Args:
        file_path: Path to the categories JSON file

    Returns:
        Dictionary mapping category IDs to their descriptions
    """
    try:
        with open(file_path, 'r') as f:
            categories_list = json.load(f)

        # Convert to dictionary format with id as key and text as value
        categories = {item["id"]: item["text"] for item in categories_list}

        print(f"Loaded {len(categories)} categories")
        return categories
    except Exception as e:
        print(f"Error loading categories: {e}")
        return {}

def create_category_embeddings(categories: Dict[str, str], progress=None,
                               pickle_path=DEFAULT_CATEGORY_EMBEDDINGS_PATH,
                               force_regenerate=False) -> Dict[str, Any]:
    """
    Create embeddings for category descriptions

    Args:
        categories: Dictionary mapping category IDs to their descriptions
        progress: Optional progress tracking object
        pickle_path: Path to the pickle file for caching embeddings
        force_regenerate: If True, regenerate embeddings even if a cache exists

    Returns:
        Dictionary mapping category IDs to their embeddings
    """
    progress_tracker = SafeProgress(progress, desc="Generating category embeddings")

    # Load embeddings from the pickle cache if it exists and regeneration isn't forced
    if not force_regenerate and os.path.exists(pickle_path):
        progress_tracker(0.1, desc=f"Loading cached embeddings from {pickle_path}")
        try:
            with open(pickle_path, 'rb') as f:
                category_embeddings = pickle.load(f)
            progress_tracker(1.0, desc=f"Loaded embeddings for {len(category_embeddings)} categories from cache")
            return category_embeddings
        except Exception as e:
            print(f"Error loading cached embeddings: {e}")
            # Fall through and generate new embeddings

    progress_tracker(0.1, desc=f"Processing {len(categories)} categories")

    # Extract descriptions to create embeddings
    category_ids = list(categories.keys())
    category_texts = list(categories.values())

    # Use the same embedding function used for products
    texts_with_embeddings = create_product_embeddings(category_texts, progress=progress)

    # Map embeddings back to category IDs
    category_embeddings = {}
    for i, category_id in enumerate(category_ids):
        if category_texts[i] in texts_with_embeddings:
            category_embeddings[category_id] = texts_with_embeddings[category_texts[i]]

    # Ensure the target directory exists (skip when the path has no directory part,
    # since os.makedirs("") raises an error)
    cache_dir = os.path.dirname(pickle_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)

    # Save embeddings to the pickle cache
    progress_tracker(0.9, desc=f"Saving embeddings to {pickle_path}")
    try:
        with open(pickle_path, 'wb') as f:
            pickle.dump(category_embeddings, f)
    except Exception as e:
        print(f"Error saving embeddings to pickle file: {e}")

    progress_tracker(1.0, desc=f"Completed embeddings for {len(category_embeddings)} categories")
    return category_embeddings
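
# Cache-invalidation sketch (hypothetical usage): after editing categories.json,
# rebuild the cached embeddings rather than reusing the stale pickle:
#   create_category_embeddings(load_categories(), force_regenerate=True)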

def load_category_embeddings(pickle_path=DEFAULT_CATEGORY_EMBEDDINGS_PATH) -> Dict[str, Any]:
    """
    Load pre-computed category embeddings from a pickle file

    Args:
        pickle_path: Path to the pickle file with cached embeddings

    Returns:
        Dictionary mapping category IDs to their embeddings
    """
    if os.path.exists(pickle_path):
        try:
            with open(pickle_path, 'rb') as f:
                category_embeddings = pickle.load(f)
            print(f"Loaded embeddings for {len(category_embeddings)} categories from {pickle_path}")
            return category_embeddings
        except Exception as e:
            print(f"Error loading cached embeddings: {e}")

    print(f"No embeddings found at {pickle_path}")
    return {}

def match_products_to_categories(product_names: List[str], categories: Dict[str, str], top_n=5,
                                 confidence_threshold=0.5, progress=None,
                                 embeddings_path=DEFAULT_CATEGORY_EMBEDDINGS_PATH) -> Dict[str, List]:
    """
    Match products to their most likely categories

    Args:
        product_names: List of product names to categorize
        categories: Dictionary mapping category IDs to their descriptions
        top_n: Number of top categories to return per product
        confidence_threshold: Minimum similarity score to include (accepted for API
            compatibility but not applied here; candidates are kept unfiltered so a
            later re-ranking stage can apply its own threshold)
        progress: Optional progress tracking object
        embeddings_path: Path to pre-computed category embeddings

    Returns:
        Dictionary mapping products to their matched categories with scores
    """
    progress_tracker = SafeProgress(progress, desc="Matching products to categories")

    # Step 1: Load or create category embeddings
    progress_tracker(0.2, desc="Loading category embeddings")
    category_embeddings = load_category_embeddings(embeddings_path)

    # If no embeddings were loaded, create them
    if not category_embeddings:
        progress_tracker(0.3, desc="Creating category embeddings")
        category_embeddings = create_category_embeddings(categories, progress, pickle_path=embeddings_path)

    # Step 2: Create product embeddings
    progress_tracker(0.4, desc="Creating product embeddings")
    product_embeddings = create_product_embeddings(product_names, progress=progress)

    # Step 3: Compute similarities between products and categories
    progress_tracker(0.6, desc="Computing similarities")
    similarities = compute_similarities(category_embeddings, product_embeddings)

    # Process results
    results = {}
    progress_tracker(0.8, desc="Processing results")
    for product, product_similarities in similarities.items():
        # Take the top N without filtering by threshold
        top_categories = product_similarities[:top_n]

        # Attach the category texts to the results
        results[product] = [(category_id, categories.get(category_id, "Unknown"), score)
                            for category_id, score in top_categories]

    progress_tracker(1.0, desc="Completed category matching")
    return results
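
# Shape of the return value (ids, texts, and scores are hypothetical):
# {
#   "organic bananas": [("cat_001", "Fresh fruit and vegetables", 0.83),
#                       ("cat_002", "Dairy and eggs", 0.41), ...]
# }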

def hybrid_category_matching(products: List[str], categories: Dict[str, str],
                             embedding_top_n: int = 20, final_top_n: int = 5,
                             confidence_threshold: float = 0.5,
                             expanded_descriptions=None, progress=None) -> Dict[str, List[Tuple]]:
    """
    Two-stage matching: first use embeddings to find candidates, then apply re-ranking

    Args:
        products: List of product names to categorize
        categories: Dictionary mapping category IDs to their descriptions
        embedding_top_n: Number of top categories to retrieve using embeddings
        final_top_n: Number of final categories to return after re-ranking
        confidence_threshold: Minimum score threshold for final results
        expanded_descriptions: Optional dictionary of expanded product descriptions
        progress: Optional progress tracking object

    Returns:
        Dictionary mapping products to their matched categories with scores
    """
    progress_tracker = SafeProgress(progress, desc="Hybrid category matching")
    progress_tracker(0.1, desc="Stage 1: Finding candidates with embeddings")

    # Stage 1: Use embeddings to find candidate categories
    embedding_results = match_products_to_categories(
        products,
        categories,
        top_n=embedding_top_n,  # Get more candidates from embeddings than we'll ultimately return
        progress=progress_tracker
    )

    progress_tracker(0.4, desc="Stage 2: Re-ranking candidates")

    # Initialize the Voyage AI client
    client = voyageai.Client()

    # Stage 2: Re-rank the candidates for each product
    final_results = {}
    for i, product in enumerate(progress_tracker.tqdm(products, desc="Re-ranking product candidates")):
        progress_tracker((0.4 + 0.5 * i / len(products)), desc=f"Re-ranking: {product}")

        # Get the embedding candidates for this product
        if product not in embedding_results:
            final_results[product] = []
            continue

        candidates = embedding_results[product]
        if not candidates:
            final_results[product] = []
            continue

        # Extract the category IDs and descriptions for re-ranking
        candidate_ids = [c[0] for c in candidates]
        candidate_texts = [c[1] for c in candidates]

        try:
            # Build the re-ranking query, preferring an expanded description if available
            if expanded_descriptions and product in expanded_descriptions:
                query = f"Which category best describes the product: {expanded_descriptions[product]}"
            else:
                query = f"Which category best describes the product: {product}"
            print(f"Query: {query}")

            reranking = client.rerank(
                query=query,
                documents=candidate_texts,
                model="rerank-2",
                top_k=final_top_n
            )

            # Process re-ranking results
            product_categories = []
            print(f"Reranking results: {reranking.results}")
            for result in reranking.results:
                # Map the result back to its category ID via the result's index
                # (more robust than searching by text when descriptions repeat)
                category_id = candidate_ids[result.index]
                score = result.relevance_score

                # Only include results above the confidence threshold
                if score >= confidence_threshold:
                    product_categories.append((category_id, result.document, score))

            print(f"Product: {product}")
            print(f"Top 3 candidates before re-ranking: {candidates[:3]}")
            print(f"Top 3 candidates after re-ranking: {product_categories[:3]}")
            final_results[product] = product_categories
        except Exception as e:
            print(f"Error during re-ranking for '{product}': {e}")
            # Fall back to embedding results if re-ranking fails
            final_results[product] = candidates[:final_top_n]

    progress_tracker(1.0, desc="Hybrid matching complete")
    return final_results
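
# Minimal end-to-end sketch, not part of the module API: assumes a local
# categories.json and a VOYAGE_API_KEY in the environment; the product names
# below are hypothetical examples.
if __name__ == "__main__":
    all_categories = load_categories("categories.json")
    if all_categories:
        sample_products = ["organic bananas", "whole wheat sandwich bread"]
        matches = hybrid_category_matching(sample_products, all_categories,
                                           embedding_top_n=20, final_top_n=5)
        for name, matched in matches.items():
            print(name)
            for category_id, text, score in matched:
                print(f"  {category_id}: {text} ({score:.3f})")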