# (Web-page status banner captured with this file: "Spaces: Sleeping" - not part of the module.)
import html
import json
import logging
import os
import pickle
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import numpy as np
# Configure logging: INFO and above, with timestamp, logger name and level
# in front of every message.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
# Module-wide logger used by every function in this file.
logger = logging.getLogger('ui_core')
# Global variables
# Name -> embedding mapping; starts empty and is reassigned when the module
# loads its embeddings further down.
embeddings: Dict[str, Any] = {}
# Default locations of the embedding files; each can be overridden through
# the environment variable of the same name.
EMBEDDINGS_PATH = os.environ.get(
    'EMBEDDINGS_PATH', 'data/ingredient_embeddings_voyageai.pkl'
)
CATEGORY_EMBEDDINGS_PATH = os.environ.get(
    'CATEGORY_EMBEDDINGS_PATH', 'data/category_embeddings.pickle'
)
def _candidate_embedding_paths(filepath: str) -> List[str]:
    """Return the fallback locations to try when *filepath* does not exist."""
    root, ext = os.path.splitext(filepath)
    candidates = []
    # Swap only the real extension; a plain str.replace would also rewrite a
    # matching substring inside a directory name or the file stem.
    if ext == '.pkl':
        candidates.append(root + '.pickle')
    elif ext == '.pickle':
        candidates.append(root + '.pkl')
    candidates.extend([
        'data/ingredient_embeddings_voyageai.pkl',
        'data/ingredient_embeddings.pickle',
    ])
    # The requested path is already known to be missing.
    return [path for path in candidates if path != filepath]


def load_embeddings(filepath: str = EMBEDDINGS_PATH) -> Dict[str, Any]:
    """
    Load embeddings from a pickle (.pkl/.pickle) or JSON file

    If the file does not exist, a few alternative locations are tried
    (the same name with the other pickle extension, then two default
    files under ``data/``). Failures are logged, never raised.

    Args:
        filepath: Path to the embeddings file

    Returns:
        Dictionary mapping each name to its embedding; values that were
        stored as lists are converted to numpy arrays. An empty dict is
        returned when no file is found, the content is not a non-empty
        dict, or loading fails for any reason.
    """
    try:
        logger.info(f"Attempting to load embeddings from: {filepath}")
        if not os.path.exists(filepath):
            logger.error(f"Embeddings file not found: {filepath}")
            # Try alternative file formats
            for alt_path in _candidate_embedding_paths(filepath):
                if os.path.exists(alt_path):
                    logger.info(f"Found alternative embeddings file: {alt_path}")
                    filepath = alt_path
                    break
            else:
                # No candidate exists either: nothing to load.
                return {}

        # Determine file type and load accordingly
        if filepath.endswith(('.pkl', '.pickle')):
            logger.info(f"Loading pickle file: {filepath}")
            # SECURITY NOTE: unpickling can execute arbitrary code. The path
            # may come from an environment variable, so only point it at
            # files from a trusted source.
            with open(filepath, 'rb') as f:
                loaded_embeddings = pickle.load(f)
        else:
            logger.info(f"Loading JSON file: {filepath}")
            # Explicit encoding so the result does not depend on the locale.
            with open(filepath, 'r', encoding='utf-8') as f:
                loaded_embeddings = json.load(f)

        # Validate the loaded data
        if not isinstance(loaded_embeddings, dict) or not loaded_embeddings:
            logger.error(f"Invalid embeddings format in {filepath}")
            return {}

        # Convert lists to numpy arrays for faster processing; any other
        # value type is kept as stored.
        processed_embeddings = {
            name: np.array(vector) if isinstance(vector, list) else vector
            for name, vector in loaded_embeddings.items()
        }
        logger.info(f"Successfully loaded {len(processed_embeddings)} ingredient embeddings")
        return processed_embeddings
    except json.JSONDecodeError:
        logger.error(f"Invalid JSON format in embeddings file: {filepath}")
        return {}
    except pickle.UnpicklingError:
        logger.error(f"Invalid pickle format in embeddings file: {filepath}")
        return {}
    except Exception as e:
        # Best-effort loader: callers rely on an empty dict instead of an
        # exception, but keep the traceback in the log for diagnosis.
        logger.exception("Error loading embeddings: %s", e)
        return {}
# Load embeddings once, at module import time
embeddings = load_embeddings()
# If nothing was loaded from the ingredient file, fall back to the category
# embeddings file instead.
if not embeddings:
    logger.info("No ingredient embeddings found, trying category embeddings...")
    embeddings = load_embeddings(CATEGORY_EMBEDDINGS_PATH)
# Sample product names for the example button: one product per line.
# Repeated names are kept exactly as they appear in the sample.
EXAMPLE_PRODUCTS = """Nature's Promise Spring Water Multipack
Red's Burritos
Nature's Promise Spring Water Multipack
Schweppes Seltzer 12 Pack
Hunt's Pasta Sauce
Buitoni Filled Pasta
Buitoni Filled Pasta
Samuel Adams or Blue Moon 12 Pack
Mrs. T's Pierogies
Buitoni Filled Pasta
Pillsbury Dough
Nature's Promise Organic Celery Hearts
MorningStar Farms Meatless Nuggets, Patties or Crumbles
Nature's Promise Organic Celery Hearts
Boar's Head Mild Provolone Cheese
Athenos Feta Crumbles"""
def load_examples() -> str:
    """Return the sample product names (one per line) for the text input"""
    return EXAMPLE_PRODUCTS
# Obsolete theme/CSS imports and functions that used to live here were removed.
# Theme is now handled by Streamlit via .streamlit/config.toml
def parse_input(input_text, is_file=False) -> Tuple[List[str], Optional[str]]:
    """Parse user input into a list of product names

    Args:
        input_text: Product names, one per line. Normally a ``str``;
            ``bytes`` are decoded as UTF-8 and ``None`` is treated as
            empty input.
        is_file: Whether the text came from a file rather than the text
            box. Both sources are parsed the same way; the flag is kept
            so existing callers keep working.

    Returns:
        ``(product_names, error_html)``. On success ``error_html`` is
        ``None``. When nothing usable is found, or parsing fails,
        ``product_names`` is ``[]`` and ``error_html`` is an HTML snippet
        describing the problem.
    """
    try:
        if input_text is None:
            # Report missing input as "no product names", not as an
            # internal attribute error.
            input_text = ''
        elif isinstance(input_text, (bytes, bytearray)):
            # Raw file content: decode it, dropping a UTF-8 BOM if present.
            input_text = bytes(input_text).decode('utf-8-sig', errors='replace')
        # A BOM left in already-decoded text would otherwise stick to the
        # first product name.
        input_text = input_text.lstrip('\ufeff')

        # splitlines() copes with \n, \r\n and bare \r line endings alike.
        stripped_lines = (line.strip() for line in input_text.splitlines())
        product_names = [name for name in stripped_lines if name]
        if not product_names:
            return [], "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>No valid product names found. Please check your input.</div>"
        return product_names, None
    except Exception as e:
        # UI boundary: never raise, but keep the traceback in the log.
        logger.exception("Error parsing input: %s", e)
        # Escape the message so it cannot inject markup into the page.
        safe_message = html.escape(str(e), quote=False)
        return [], f"<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>Error parsing input: {safe_message}</div>"