Spaces:
Sleeping
Sleeping
Commit
·
31ebc8b
0
Parent(s):
Initial commit
Browse files- .gitattributes +35 -0
- .gitignore +6 -0
- README.md +51 -0
- api_utils.py +384 -0
- app.py +29 -0
- categories.json +140 -0
- category_embeddings.pickle +3 -0
- category_matching.py +258 -0
- chicory_api.py +91 -0
- comparison.py +252 -0
- config.py +2 -0
- data/category_embeddings.pickle +3 -0
- data/ingredient_embeddings_voyageai.pkl +3 -0
- debug_embeddings.py +130 -0
- embeddings.py +128 -0
- generate_category_embeddings.py +29 -0
- main.py +48 -0
- openai_expansion.py +91 -0
- requirements.txt +6 -0
- similarity.py +262 -0
- ui.py +262 -0
- ui_category_matching.py +46 -0
- ui_core.py +140 -0
- ui_expanded_matching.py +224 -0
- ui_formatters.py +419 -0
- ui_hybrid_matching.py +86 -0
- ui_ingredient_matching.py +59 -0
- utils.py +156 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
*.pyc
|
3 |
+
*.pem
|
4 |
+
|
5 |
+
.DS_Store
|
6 |
+
run_app.sh
|
README.md
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: mit
|
3 |
+
title: Demo
|
4 |
+
sdk: gradio
|
5 |
+
emoji: 🚀
|
6 |
+
colorFrom: purple
|
7 |
+
colorTo: yellow
|
8 |
+
---
|
9 |
+
# Product Categorization App - One-Click Solution
|
10 |
+
|
11 |
+
This is a turnkey solution for categorizing products based on their similarity to ingredients using Voyage AI.
|
12 |
+
|
13 |
+
## Quick Start
|
14 |
+
|
15 |
+
1. Place your `ingredient_embeddings_voyageai.pkl` file in the same folder as this README
|
16 |
+
2. Run the application:
|
17 |
+
|
18 |
+
```bash
|
19 |
+
bash run_app.sh
|
20 |
+
```
|
21 |
+
|
22 |
+
3. That's it! A browser window will open with the app, and a public URL will be created for sharing
|
23 |
+
|
24 |
+
## What You Can Do
|
25 |
+
|
26 |
+
- **Text Input:** Enter product names one per line
|
27 |
+
- **File Upload:** Upload a JSON file with product data
|
28 |
+
- Adjust the number of categories and Similarity Threshold
|
29 |
+
- View the categorization results with confidence scores
|
30 |
+
|
31 |
+
## Hosting on Hugging Face Spaces
|
32 |
+
|
33 |
+
For permanent, free hosting on Gradio:
|
34 |
+
|
35 |
+
1. Create a free account on [Hugging Face](https://huggingface.co/)
|
36 |
+
2. Go to [Hugging Face Spaces](https://huggingface.co/spaces)
|
37 |
+
3. Click "Create a Space"
|
38 |
+
4. Select "Gradio" as the SDK
|
39 |
+
5. Upload all files (including your embeddings file) to the space
|
40 |
+
6. Your app will be automatically deployed!
|
41 |
+
|
42 |
+
## Files Included
|
43 |
+
|
44 |
+
- `app.py`: The main application code
|
45 |
+
- `requirements.txt`: Required Python packages
|
46 |
+
- `run_app.sh`: One-click deployment script
|
47 |
+
|
48 |
+
## Requirements
|
49 |
+
|
50 |
+
- Python 3.7+
|
51 |
+
- Internet connection (for Voyage AI API)
|
api_utils.py
ADDED
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import concurrent.futures
|
3 |
+
from typing import List, Dict, Callable, Any, Tuple
|
4 |
+
from openai import OpenAI
|
5 |
+
import voyageai
|
6 |
+
from utils import SafeProgress
|
7 |
+
import json
|
8 |
+
|
9 |
+
# Centralized API clients
|
10 |
+
def get_openai_client():
    """Build an OpenAI client keyed by the OPENAI_API_KEY environment variable."""
    # os.environ.get returns None when the variable is unset; the OpenAI SDK
    # then falls back to its own environment lookup / raises on first use.
    return OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
14 |
+
|
15 |
+
def get_voyage_client():
    """Get a configured Voyage AI client"""
    # No explicit key is passed: voyageai.Client() performs its own default
    # configuration lookup (presumably the VOYAGE_API_KEY env var — TODO confirm).
    return voyageai.Client()
|
18 |
+
|
19 |
+
# General batch processing utilities
|
20 |
+
def process_batch(items_batch: List[Any], processor_func: Callable) -> Dict:
    """
    Apply ``processor_func`` to every item of a batch and collect the results.

    Args:
        items_batch: Items to process.
        processor_func: Callable taking one item and returning a (key, value) pair.

    Returns:
        Dict mapping each returned key to its value. An item whose processing
        raises is logged and mapped (by the item itself) to an empty list.
    """
    collected: Dict = {}
    for entry in items_batch:
        try:
            entry_key, entry_value = processor_func(entry)
            collected[entry_key] = entry_value
        except Exception as exc:
            # Keep the rest of the batch going; record an empty fallback result.
            print(f"Error processing batch item '{entry}': {exc}")
            collected[entry] = []
    return collected
|
40 |
+
|
41 |
+
def process_in_parallel(
    items: List[Any],
    processor_func: Callable,
    max_workers: int = 10,
    progress_tracker: Any = None,
    progress_start: float = 0.0,
    progress_end: float = 1.0,
    progress_desc: str = "Processing in parallel"
) -> Dict:
    """
    Process items in parallel using a thread pool.

    Args:
        items: List of items to process
        processor_func: Function that processes a single item and returns (key, value)
        max_workers: Maximum number of threads
        progress_tracker: Optional progress tracking object (called as
            tracker(fraction, desc=...))
        progress_start: Starting progress percentage (0.0-1.0)
        progress_end: Ending progress percentage (0.0-1.0)
        progress_desc: Description for the progress tracker

    Returns:
        Combined results dictionary (empty when ``items`` is empty)
    """
    # BUGFIX: an empty input previously produced max_workers == 0, and
    # ThreadPoolExecutor(max_workers=0) raises ValueError. Short-circuit instead.
    if not items:
        return {}

    # Ensure a reasonable number of workers: at least 1, at most one per item.
    max_workers = max(1, min(max_workers, len(items)))

    # Split items into roughly even batches, one per worker.
    batch_size = max(1, len(items) // max_workers)
    batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

    # Process batches in parallel, merging per-batch dicts as they complete.
    results: Dict = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_batch = {executor.submit(process_batch, batch, processor_func): i
                           for i, batch in enumerate(batches)}

        for i, future in enumerate(concurrent.futures.as_completed(future_to_batch)):
            batch_index = future_to_batch[future]

            # Update progress if a tracker was provided; completion order, not
            # submission order, drives the fraction reported.
            if progress_tracker:
                progress_percent = progress_start + ((progress_end - progress_start) * (i + 1) / len(batches))
                progress_tracker(progress_percent, desc=f"{progress_desc}: batch {batch_index+1}/{len(batches)}")

            try:
                batch_results = future.result()
                results.update(batch_results)
            except Exception as e:
                print(f"Error processing batch {batch_index}: {e}")

    return results
|
93 |
+
|
94 |
+
def openai_structured_query(
    prompt: str,
    system_message: str = "You are a helpful assistant.",
    schema: dict = None,
    model: str = "o3-mini",
    client=None,
    schema_name: str = "structured_output"
) -> dict:
    """
    Run one OpenAI call constrained to a strict JSON-schema output.

    Args:
        prompt: The user prompt.
        system_message: System message guiding the model.
        schema: JSON schema the response must conform to.
        model: OpenAI model to use.
        client: Optional pre-configured client; created on demand otherwise.
        schema_name: Name attached to the schema in the request.

    Returns:
        The model's output parsed from JSON into a dictionary.

    Raises:
        Re-raises any API or JSON-parsing error after logging it.
    """
    if client is None:
        client = get_openai_client()

    # Assemble the request payload up front; only the API call / parse can raise.
    conversation = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": prompt},
    ]
    text_format = {
        "format": {
            "type": "json_schema",
            "name": schema_name,
            "schema": schema,
            "strict": True,
        }
    }

    try:
        response = client.responses.create(model=model, input=conversation, text=text_format)
        return json.loads(response.output_text)
    except Exception as e:
        print(f"Error in OpenAI structured query: {e}")
        raise
|
141 |
+
|
142 |
+
def rank_ingredients_openai(
    product: str,
    candidates: List[str],
    expanded_description: str = None,
    client=None,
    model: str = "o3-mini",
    max_results: int = 3,
    confidence_threshold: float = 0.5,
    debug: bool = False
) -> List[Tuple[str, float]]:
    """
    Rank ingredients for a product using OpenAI

    Args:
        product: Product name
        candidates: List of candidate ingredients
        expanded_description: Optional expanded product description
        client: Optional pre-configured client
        model: OpenAI model to use
        max_results: Maximum number of results to return
        confidence_threshold: Minimum confidence threshold
        debug: Whether to print debug info

    Returns:
        List of (ingredient, confidence) tuples, highest confidence first,
        filtered to scores >= confidence_threshold and capped at max_results.
        Returns [] on any API/parse error.
    """
    if not candidates:
        return []

    if client is None:
        client = get_openai_client()

    if debug:
        print(f"Ranking for product: {product} with {len(candidates)} candidates")

    # Format prompt with expanded description if available
    prompt = f"Product: {product}"
    if expanded_description:
        prompt += f"\n\nExpanded description: {expanded_description}"
    prompt += f"\n\nPotential ingredients: {', '.join(candidates)}"

    # Define the ranking schema
    ranking_schema = {
        "type": "object",
        "properties": {
            "rankings": {
                "type": "array",
                "description": f"Only the top {max_results} most relevant ingredients with scores >= {confidence_threshold}",
                "items": {
                    "type": "object",
                    "properties": {
                        "ingredient": {
                            "type": "string",
                            "description": "The name of the ingredient"
                        },
                        "relevance_score": {
                            "type": "number",
                            "description": "Score between 0 and 1 indicating relevance"
                        },
                        "explanation": {
                            "type": "string",
                            "description": "Brief explanation for the matching"
                        }
                    },
                    "required": ["ingredient", "relevance_score", "explanation"],
                    "additionalProperties": False
                }
            }
        },
        "required": ["rankings"],
        "additionalProperties": False
    }

    try:
        # Make the API call directly for more control
        response = client.responses.create(
            model=model,
            reasoning={"effort": "low"},  # Include effort parameter from ui_expanded_matching
            input=[
                {"role": "system", "content": f"You are a food ingredient matching expert. Rank the top {max_results} ingredient based on how well they match the given product. Only include ingredients with relevance score >= {confidence_threshold}."},
                {"role": "user", "content": prompt}
            ],
            text={
                "format": {
                    "type": "json_schema",
                    "name": "ingredient_ranking",
                    "schema": ranking_schema,
                    "strict": True
                }
            }
        )

        # Parse the response
        result = json.loads(response.output_text)

        # BUGFIX: max_results / confidence_threshold were only *described* in
        # the schema prompt, so the model could ignore them and the raw list
        # was returned unfiltered. Enforce the documented contract here.
        ingredients = [
            (item["ingredient"], float(item["relevance_score"]))
            for item in result["rankings"]
        ]
        ingredients = [pair for pair in ingredients if pair[1] >= confidence_threshold]
        ingredients.sort(key=lambda pair: pair[1], reverse=True)
        ingredients = ingredients[:max_results]

        if debug:
            print(f"Ranking results for {product}: {len(ingredients)} ingredients")
            if ingredients:
                print(f"Top match: {ingredients[0]}")

        return ingredients
    except Exception as e:
        print(f"Error ranking ingredients for '{product}': {e}")
        return []
|
253 |
+
|
254 |
+
def rank_categories_openai(
    product: str,
    categories: dict,
    expanded_description: str = None,
    client=None,
    model: str = "o3-mini",
    max_results: int = 5,
    confidence_threshold: float = 0.5,
    debug: bool = False
) -> List[Tuple[str, float]]:
    """
    Rank food categories for a product using OpenAI.

    Args:
        product: Product name
        categories: Dictionary of category data (values may be description
            strings, dicts with a 'description' key, or anything else)
        expanded_description: Optional expanded product description
            (currently unused in the prompt)
        client: Optional pre-configured client
        model: OpenAI model to use
        max_results: Maximum number of results to return
        confidence_threshold: Minimum confidence threshold (advisory — passed
            to the model in the prompt; not enforced client-side)
        debug: Whether to print debug info

    Returns:
        List of (category, confidence) tuples in model order; [] on error.
    """
    if not categories:
        return []

    if client is None:
        client = get_openai_client()

    if debug:
        print(f"Category ranking for product: {product}")

    # Render each category as a bullet line; values may be plain description
    # strings, dicts carrying a 'description' field, or opaque — in which
    # case only the id is shown.
    lines = []
    for category_id, category_data in categories.items():
        if isinstance(category_data, str):
            lines.append(f"- {category_id}: {category_data}\n")
        elif isinstance(category_data, dict) and 'description' in category_data:
            lines.append(f"- {category_id}: {category_data['description']}\n")
        else:
            lines.append(f"- {category_id}\n")
    categories_text = "".join(lines)

    prompt = f"Product: {product}" + f"\n\nAvailable food categories:\n{categories_text}"

    # Structured-output schema: each ranking item carries reasoning first,
    # then the category name and a 0-1 relevance score.
    item_schema = {
        "type": "object",
        "properties": {
            "reasoning": {
                "type": "string",
                "description": "Reasoning, , step by step, first weigh options, then consider the best match"
            },
            "category": {
                "type": "string",
                "description": "The name of the food category"
            },
            "relevance_score": {
                "type": "number",
                "description": "Score between 0 and 1 indicating relevance"
            },
        },
        "required": ["category", "relevance_score", "reasoning"],
        "additionalProperties": False
    }
    ranking_schema = {
        "type": "object",
        "properties": {
            "rankings": {
                "type": "array",
                "description": f"Only the top most relevant category with scores >= {confidence_threshold}",
                "items": item_schema
            }
        },
        "required": ["rankings"],
        "additionalProperties": False
    }

    try:
        # Make the API call
        response = client.responses.create(
            model=model,
            input=[
                {"role": "system", "content": f"You are a food categorization expert. Think this through step by step: Rank the top category based on how well it match the given product. Only include categories with relevance score >= {confidence_threshold}."},
                {"role": "user", "content": prompt}
            ],
            text={
                "format": {
                    "type": "json_schema",
                    "name": "category_ranking",
                    "schema": ranking_schema,
                    "strict": True
                }
            }
        )

        parsed = json.loads(response.output_text)

        # Convert the model's rankings into (category, score) pairs, keeping
        # the model's own ordering. (Renamed from the original's reuse of the
        # `categories` name, which shadowed the parameter.)
        ranked = [
            (entry["category"], float(entry["relevance_score"]))
            for entry in parsed["rankings"]
        ]

        if debug:
            print(f"Category results for {product}: {len(ranked)} categories")
            if ranked:
                print(f"Top match: {ranked[0]}")

        return ranked

    except Exception as e:
        print(f"Error categorizing {product}: {e}")
        if debug:
            import traceback
            traceback.print_exc()
        return []
|
app.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
import sys

import gradio as gr

from utils import load_embeddings
from ui import categorize_products, create_demo  # Updated imports

# Location of the pre-computed ingredient embeddings pickle.
EMBEDDINGS_PATH = "data/ingredient_embeddings_voyageai.pkl"

# Fail fast when the embeddings file is missing.
if not os.path.exists(EMBEDDINGS_PATH):
    print(f"Error: Embeddings file {EMBEDDINGS_PATH} not found!")
    print(f"Please ensure the file exists at {os.path.abspath(EMBEDDINGS_PATH)}")
    sys.exit(1)

# Load the embeddings once at startup and hand them to the UI module, which
# reads them via its module-level `embeddings` attribute.
try:
    embeddings_data = load_embeddings(EMBEDDINGS_PATH)
    import ui
    ui.embeddings = embeddings_data
except Exception as e:
    print(f"Error loading embeddings: {e}")
    sys.exit(1)

# Launch the Gradio interface
if __name__ == "__main__":
    demo = create_demo()
    demo.launch()
|
categories.json
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{"id": "alcoholic_beverages", "text": "Products containing alcohol for adult consumption. Includes beer (lagers, ales, IPAs), wine (red, white, rosé), spirits (vodka, whiskey, rum), hard seltzers, and pre-mixed alcoholic drinks."},
|
3 |
+
{"id": "beverages", "text": "Non-alcoholic drinks for hydration and refreshment. This parent category includes all drink types such as juices, sodas, water, coffee, tea, and milk alternatives."},
|
4 |
+
{"id": "cocktails_and_mixers", "text": "Products used specifically for creating mixed alcoholic drinks. Includes margarita mix, bloody mary mix, tonic water, sour mix, grenadine, bitters, and non-alcoholic cocktail components."},
|
5 |
+
{"id": "coffee", "text": "Coffee products in various forms and preparations. Includes whole/ground coffee beans, single-serve pods, instant coffee, cold brew, espresso, flavored coffee varieties, and coffee concentrates."},
|
6 |
+
{"id": "fruit_juice", "text": "Beverages made primarily from fruit. Includes 100% juice, juice blends, fresh-squeezed, from concentrate, smoothies, apple juice, orange juice, cranberry juice, and fruit nectars."},
|
7 |
+
{"id": "soft_drinks", "text": "Carbonated non-alcoholic beverages. Includes cola, root beer, ginger ale, lemon-lime soda, diet soda, sparkling water, flavored seltzer, and club soda."},
|
8 |
+
{"id": "specialty_drinks", "text": "Unique or premium non-alcoholic beverages with distinctive ingredients or processes. Includes kombucha, kefir drinks, drinking vinegars, botanical tonics, bubble tea, horchata, and craft sodas."},
|
9 |
+
{"id": "sports_and_energy_drinks", "text": "Beverages formulated for performance enhancement or energy boosting. Includes electrolyte drinks, protein shakes, caffeinated energy drinks, pre-workout beverages, and recovery drinks."},
|
10 |
+
{"id": "tea_and_hot_chocolate", "text": "Tea products and cocoa-based hot beverages. Includes black/green/herbal tea bags, loose leaf tea, chai, matcha, instant tea, hot cocoa mix, drinking chocolate, and cider mixes."},
|
11 |
+
{"id": "water", "text": "Bottled water products of various types. Includes purified water, spring water, mineral water, sparkling water, flavored water, alkaline water, and water with electrolytes."},
|
12 |
+
{"id": "bread_and_bakery", "text": "Bread products and baked goods. This parent category includes all types of bread, rolls, buns, bagels, pastries, and bakery desserts."},
|
13 |
+
{"id": "bagels_english_muffins_and_breakfast", "text": "Breakfast bread products typically served toasted. Includes plain/flavored bagels, English muffins, crumpets, breakfast breads, croissants, and morning buns."},
|
14 |
+
{"id": "desserts", "text": "Sweet baked goods meant as treats or meal finishers. Includes cakes, pies, cookies, brownies, tarts, cheesecakes, parfaits, and bakery dessert items."},
|
15 |
+
{"id": "donuts_and_pastries", "text": "Sweet baked or fried dough products. Includes donuts (glazed, filled, cake), pastries (Danish, croissants, turnovers), bear claws, fritters, and churros."},
|
16 |
+
{"id": "rolls_and_buns", "text": "Individual bread portions shaped for specific uses. Includes dinner rolls, hamburger buns, hot dog buns, submarine rolls, kaiser rolls, brioche buns, and slider buns."},
|
17 |
+
{"id": "sliced_bread", "text": "Pre-sliced loaves of bread for sandwiches and toast. Includes white, whole wheat, multigrain, sourdough, rye, pumpernickel, potato, and specialty grain varieties."},
|
18 |
+
{"id": "snack_cakes", "text": "Pre-packaged individual or small sweet baked goods. Includes cupcakes, mini pies, cream-filled cakes, coffee cakes, muffins, and packaged pastries with extended shelf life."},
|
19 |
+
{"id": "tortillas_and_flatbreads", "text": "Thin, unleavened or minimally leavened bread products. Includes corn/flour tortillas, pita bread, naan, lavash, flatbreads, wraps, and taco shells."},
|
20 |
+
{"id": "deli", "text": "Section featuring prepared ready-to-eat foods and freshly sliced ingredients. This parent category includes prepared meals, sliced meats, fresh dips, and specialty items requiring refrigeration."},
|
21 |
+
{"id": "cured_meats", "text": "Preserved meat products typically eaten uncooked. Includes salami, prosciutto, coppa, pancetta, pepperoni, chorizo, and specialty dried/cured meats."},
|
22 |
+
{"id": "deli_meals_and_sides", "text": "Ready-to-eat prepared foods sold by weight or package. Includes rotisserie chicken, prepared salads (potato, macaroni, coleslaw), heat-and-eat entrees, and prepared side dishes."},
|
23 |
+
{"id": "deli_meats", "text": "Cooked or processed meats sliced to order or pre-sliced. Includes turkey breast, ham, roast beef, pastrami, bologna, chicken breast, and specialty lunch meats."},
|
24 |
+
{"id": "fresh_pastas", "text": "Refrigerated pasta products requiring cooking. Includes fresh ravioli, tortellini, gnocchi, linguine, fettuccine, and stuffed pasta varieties with shorter shelf life than dried pasta."},
|
25 |
+
{"id": "hummus_fresh_dips_and_fresh_salsas", "text": "Refrigerated spreadable dips with limited shelf life. Includes hummus varieties, guacamole, fresh salsa, tzatziki, spinach dip, and refrigerated spreads requiring cold storage."},
|
26 |
+
{"id": "eggs_and_dairy", "text": "Products derived from animal milk and eggs. This parent category includes all dairy products and egg-based items requiring refrigeration."},
|
27 |
+
{"id": "butter_and_margarine", "text": "Spreadable fats for cooking and baking. Includes dairy butter (salted, unsalted, cultured), margarine, plant-based butter alternatives, ghee, and blended spreads."},
|
28 |
+
{"id": "cheese", "text": "Dairy products made from curdled milk. Includes cheddar, mozzarella, Swiss, provolone, American, brie, blue cheese, and various cheese formats (blocks, slices, shredded)."},
|
29 |
+
{"id": "ao_cheese", "text": "Specialty cheese products with artisanal production or organic certification. Includes imported cheeses, raw milk cheeses, aged specialty varieties, organic cheeses, and regional cheese specialties."},
|
30 |
+
{"id": "cream_and_creamers", "text": "Dairy and non-dairy products for coffee and cooking. Includes heavy cream, half & half, whipping cream, coffee creamers (dairy and non-dairy), and cooking creams."},
|
31 |
+
{"id": "dips", "text": "Dairy-based spreads for snacks and appetizers. Includes French onion dip, ranch dip, cream cheese-based dips, sour cream dips, and flavored spreadable dairy products."},
|
32 |
+
{"id": "eggs_and_egg_substitutes", "text": "Chicken eggs and egg alternatives. Includes whole eggs (white, brown, free-range, organic), liquid egg products, egg whites, and plant-based egg substitutes."},
|
33 |
+
{"id": "milk", "text": "Traditional dairy milk products. Includes whole milk, 2% reduced fat, 1% low fat, skim/fat-free milk, lactose-free milk, buttermilk, and flavored dairy milk varieties."},
|
34 |
+
{"id": "plant_based_milks", "text": "Non-dairy milk alternatives made from plants. Includes almond milk, soy milk, oat milk, coconut milk, rice milk, cashew milk, hemp milk, and blended plant milks."},
|
35 |
+
{"id": "pudding_and_gelatins", "text": "Ready-to-eat chilled desserts with soft texture. Includes dairy puddings, gelatin desserts, rice pudding, tapioca pudding, and parfait cups."},
|
36 |
+
{"id": "refrigerated_doughs_and_crusts", "text": "Ready-to-bake fresh dough products requiring refrigeration. Includes cookie dough, biscuit dough, pie crusts, pizza dough, cinnamon rolls, and crescent rolls."},
|
37 |
+
{"id": "sour_cream", "text": "Cultured dairy product with tangy flavor. Includes regular sour cream, light/reduced-fat sour cream, crème fraîche, and sour cream alternatives."},
|
38 |
+
{"id": "yogurt", "text": "Fermented dairy products with live cultures. Includes Greek yogurt, regular yogurt, Icelandic skyr, kefir, yogurt drinks, and varieties with different fat contents and flavors."},
|
39 |
+
{"id": "frozen", "text": "Foods stored and sold in frozen state requiring freezer storage. This parent category includes all items kept frozen until preparation or consumption."},
|
40 |
+
{"id": "frozen_beverages_and_ice", "text": "Frozen drink products and ice. Includes frozen juice concentrate, smoothie bases, frozen coffee drinks, popsicles, ice cubes, crushed ice, and frozen cocktail mixers."},
|
41 |
+
{"id": "frozen_bread_and_potatoes", "text": "Frozen starches requiring heating before serving. Includes frozen garlic bread, dinner rolls, French fries, hash browns, tater tots, potato wedges, and specialty potato products."},
|
42 |
+
{"id": "frozen_desserts", "text": "Sweet frozen treats besides traditional ice cream. Includes frozen yogurt, sherbet, sorbet, gelato, frozen novelties, ice cream cakes, and frozen pies."},
|
43 |
+
{"id": "frozen_family_meals", "text": "Multi-serving frozen entrees to feed multiple people. Includes frozen lasagna, casseroles, pot pies, complete dinners, meal kits, and large-format frozen meals."},
|
44 |
+
{"id": "frozen_fruits_and_vegetables", "text": "Flash-frozen produce for extended storage. Includes frozen berries, mixed fruits, vegetable medleys, stir-fry blends, broccoli, corn, peas, and individually quick frozen (IQF) produce."},
|
45 |
+
{"id": "frozen_meat_and_seafood", "text": "Frozen animal protein products. Includes frozen chicken breasts/tenders, ground beef, fish fillets, shrimp, scallops, meatballs, and specialty meat products requiring freezer storage."},
|
46 |
+
{"id": "frozen_pizza_and_pasta", "text": "Ready-to-heat frozen Italian-style convenience foods. Includes frozen pizza (thin crust, rising crust, specialty), pizza rolls, frozen pasta dishes, ravioli, and Italian entrees."},
|
47 |
+
{"id": "ice_cream", "text": "Frozen dairy desserts with high milk fat content. Includes traditional ice cream, premium ice cream, ice cream bars, ice cream sandwiches, and dairy-based frozen treats."},
|
48 |
+
{"id": "fruits_and_vegetables", "text": "Fresh produce items. This parent category includes all fresh fruits and vegetables, both whole and prepared."},
|
49 |
+
{"id": "fresh_fruit", "text": "Unprocessed whole fruits. Includes apples, bananas, citrus fruits, berries, grapes, stone fruits, tropical fruits, and seasonal fruit varieties."},
|
50 |
+
{"id": "fresh_herbs", "text": "Fresh culinary herbs for flavoring. Includes basil, cilantro, parsley, mint, rosemary, thyme, dill, chives, and other fresh herb varieties."},
|
51 |
+
{"id": "fresh_vegetables", "text": "Unprocessed whole vegetables. Includes tomatoes, peppers, onions, carrots, broccoli, cauliflower, cucumbers, and mainstream vegetable varieties."},
|
52 |
+
{"id": "ao_fresh_vegetables", "text": "Specialty, artisanal or organic vegetables. Includes heirloom tomatoes, organic produce, specialty greens, rare vegetable varieties, and premium vegetable selections."},
|
53 |
+
{"id": "leafy_greens", "text": "Edible plant leaves for salads and cooking. Includes lettuce varieties, spinach, kale, arugula, mixed salad greens, collards, chard, and cooking greens."},
|
54 |
+
{"id": "mushrooms", "text": "Edible fungi varieties. Includes button mushrooms, cremini, portobello, shiitake, oyster, enoki, chanterelle, and specialty mushroom varieties."},
|
55 |
+
{"id": "potatoes_and_starchy_vegetables", "text": "Root vegetables and high-starch produce. Includes potatoes, sweet potatoes, winter squash, yams, turnips, rutabagas, parsnips, and other starchy vegetables."},
|
56 |
+
{"id": "prepared_produce", "text": "Pre-processed fruits and vegetables for convenience. This parent category includes all ready-to-eat cut fruits and vegetables."},
|
57 |
+
{"id": "fresh_prepared_fruit", "text": "Ready-to-eat cut fruit products. Includes fruit salad, cut melon, pineapple chunks, apple slices, fruit platters, and fresh-cut fruit mixes."},
|
58 |
+
{"id": "fresh_prepared_vegetables", "text": "Ready-to-eat cut vegetable products. Includes vegetable trays, pre-cut stir fry mixes, spiralized vegetables, vegetable noodles, and prepared vegetable medleys."},
|
59 |
+
{"id": "vegetarian_protein_and_asian", "text": "Plant-based protein products and Asian ingredients. Includes tofu, tempeh, seitan, meat alternatives, edamame, Asian noodles, and vegetarian protein options."},
|
60 |
+
{"id": "meat", "text": "Animal protein products. This parent category includes all unprocessed and minimally processed animal proteins."},
|
61 |
+
{"id": "bacon_hot_dogs_and_sausage", "text": "Processed and formed meat products. Includes bacon, breakfast sausage, Italian sausage, hot dogs, bratwurst, kielbasa, and specialty sausage varieties."},
|
62 |
+
{"id": "beef", "text": "Meat products from cattle. Includes ground beef, steaks (ribeye, sirloin, filet), roasts, stew meat, brisket, and specialty beef cuts."},
|
63 |
+
{"id": "chicken", "text": "Poultry products from chickens. Includes whole chickens, breasts, thighs, wings, drumsticks, ground chicken, and boneless/bone-in varieties."},
|
64 |
+
{"id": "pork", "text": "Meat products from pigs. Includes pork chops, tenderloin, ribs, shoulder, ground pork, ham, and specialty pork cuts."},
|
65 |
+
{"id": "seafood", "text": "Edible aquatic animals. Includes fish (salmon, tuna, cod, tilapia), shellfish (shrimp, crab, lobster), mollusks (clams, mussels, oysters), and specialty seafood."},
|
66 |
+
{"id": "specialty_and_organic_meat", "text": "Premium meat products with special attributes. Includes organic meats, grass-fed beef, free-range poultry, heritage breed pork, halal/kosher meats, and specialty game meats."},
|
67 |
+
{"id": "turkey", "text": "Poultry products from turkeys. Includes whole turkeys, turkey breasts, ground turkey, turkey thighs, turkey sausage, and other turkey parts."},
|
68 |
+
{"id": "pantry", "text": "Shelf-stable foods stored at room temperature. This parent category includes all non-perishable food items with extended shelf life."},
|
69 |
+
{"id": "baking", "text": "Ingredients primarily used for baking. This parent category includes all baking ingredients, mixes, and decorating supplies."},
|
70 |
+
{"id": "ao_baking", "text": "Specialty baking ingredients with artisanal or organic attributes. Includes organic flour, specialty sugars, premium chocolate, heirloom grain products, and gourmet baking ingredients."},
|
71 |
+
{"id": "baking_mixes", "text": "Pre-measured dry ingredient combinations. Includes cake mixes, brownie mixes, pancake/waffle mixes, muffin mixes, bread mixes, and biscuit mixes."},
|
72 |
+
{"id": "baking_morsels_bars_and_cocoa", "text": "Chocolate and cocoa products for baking. Includes chocolate chips, baking chocolate bars, cocoa powder, white chocolate chips, and flavored baking morsels."},
|
73 |
+
{"id": "cake_decorations", "text": "Items used to decorate baked goods. Includes sprinkles, decorating icing, food coloring, fondant, sugar decorations, and cake toppers."},
|
74 |
+
{"id": "flour_and_meal", "text": "Ground grain products for baking and cooking. Includes all-purpose flour, bread flour, cake flour, whole wheat flour, almond flour, cornmeal, and specialty flours."},
|
75 |
+
{"id": "frosting", "text": "Ready-to-use cake and cookie toppings. Includes canned frosting, frosting tubes, glaze mixes, icing, cream cheese frosting, and specialty frosting varieties."},
|
76 |
+
{"id": "thickening_and_leavening_agents", "text": "Ingredients that change food texture or help it rise. Includes cornstarch, baking powder, baking soda, yeast, gelatin, pectin, xanthan gum, and arrowroot."},
|
77 |
+
{"id": "boxed_dinners", "text": "Shelf-stable meal kits with minimal preparation. Includes macaroni and cheese, hamburger helper, rice dishes, pasta meals, and boxed dinner kits requiring few additional ingredients."},
|
78 |
+
{"id": "broths_and_stocks", "text": "Liquid cooking bases for soups and recipes. Includes chicken broth, beef stock, vegetable broth, bone broth, bouillon, and cooking stock concentrates."},
|
79 |
+
{"id": "canned_goods", "text": "Food preserved in metal cans or glass jars. This parent category includes all canned and jarred shelf-stable foods."},
|
80 |
+
{"id": "ao_canned_goods", "text": "Premium preserved foods with artisanal or organic attributes. Includes organic canned vegetables, gourmet preserved items, imported specialty canned goods, and premium jarred items."},
|
81 |
+
{"id": "canned_beans", "text": "Legumes preserved in liquid. Includes kidney beans, black beans, chickpeas, pinto beans, baked beans, refried beans, and mixed bean varieties."},
|
82 |
+
{"id": "canned_fruit", "text": "Fruit preserved in syrup or juice. Includes peaches, pears, pineapple, mandarin oranges, fruit cocktail, applesauce, and specialty preserved fruits."},
|
83 |
+
{"id": "canned_meals", "text": "Ready-to-eat complete dishes in cans. Includes ravioli, chili, stew, pasta dishes, hash, and fully-prepared shelf-stable meals requiring minimal preparation."},
|
84 |
+
{"id": "canned_meat_poultry_and_hashes", "text": "Preserved meat products in cans. Includes canned chicken, potted meat, corned beef hash, SPAM, Vienna sausages, and shelf-stable meat products."},
|
85 |
+
{"id": "canned_seafood_and_tuna", "text": "Preserved fish and seafood in cans or pouches. Includes tuna (in water, oil), salmon, sardines, crab meat, clams, anchovies, and specialty canned seafood varieties."},
|
86 |
+
{"id": "canned_soups_and_stews", "text": "Ready-to-eat or condensed liquid meals. Includes cream soups, broth-based soups, chili, condensed soups, ready-to-eat soups, and hearty stews."},
|
87 |
+
{"id": "canned_tomatoes_and_dried_tomatoes", "text": "Preserved tomato products. Includes diced tomatoes, tomato sauce, paste, crushed tomatoes, whole peeled tomatoes, sun-dried tomatoes, and tomato puree."},
|
88 |
+
{"id": "canned_vegetables", "text": "Vegetables preserved in liquid. Includes green beans, corn, peas, carrots, mixed vegetables, mushrooms, asparagus, and specialty canned vegetable varieties."},
|
89 |
+
{"id": "cereal_and_breakfast_food", "text": "Ready-to-eat and hot morning meal foods. Includes cold cereals, hot cereals (oatmeal, grits, cream of wheat), breakfast bars, granola, and breakfast pastries."},
|
90 |
+
{"id": "condiments", "text": "Flavor-enhancing additions to prepared foods. This parent category includes all sauces, spices, and food enhancers."},
|
91 |
+
{"id": "ao_condiments", "text": "Specialty flavor enhancers with artisanal or organic attributes. Includes craft hot sauces, small-batch preserves, organic condiments, and gourmet flavor enhancers."},
|
92 |
+
{"id": "fruit_spreads", "text": "Sweet preserved fruit products. Includes jams, jellies, preserves, marmalade, fruit butters, honey, and specialty fruit spreads."},
|
93 |
+
{"id": "hot_sauces", "text": "Spicy condiments for food enhancement. Includes cayenne pepper sauce, habanero sauce, sriracha, tabasco, chipotle sauce, and specialty hot sauces of varying heat levels."},
|
94 |
+
{"id": "ketchup_mayo_and_mustards", "text": "Common sandwich and burger condiments. Includes ketchup, mayonnaise, yellow mustard, dijon mustard, specialty mustards, aioli, and basic table condiments."},
|
95 |
+
{"id": "nut_butters_and_spreads", "text": "Paste-like products made from ground nuts and seeds. Includes peanut butter, almond butter, cashew butter, sunflower seed butter, hazelnut spread, and specialty nut butters."},
|
96 |
+
{"id": "pickles_and_olives", "text": "Vegetables preserved in brine or vinegar. Includes dill pickles, sweet pickles, relish, green olives, kalamata olives, pickled vegetables, and specialty pickled items."},
|
97 |
+
{"id": "salad_dressings_and_toppings", "text": "Liquid and dry additions for salads. Includes ranch, Italian, balsamic, blue cheese dressings, croutons, salad toppings, and vinaigrettes."},
|
98 |
+
{"id": "sauces_marinades_and_gravy", "text": "Liquid flavor enhancers for cooking and finishing. Includes barbecue sauce, teriyaki sauce, pasta sauce, gravy, marinade, steak sauce, and cooking sauces."},
|
99 |
+
{"id": "sugars_sweeteners_and_honey", "text": "Sweet additions for beverages and baking. Includes granulated sugar, brown sugar, powdered sugar, honey, maple syrup, artificial sweeteners, and sugar alternatives."},
|
100 |
+
{"id": "cooking_oils_and_vinegar", "text": "Liquid fats for cooking and acidic flavor enhancers. Includes vegetable oil, olive oil, specialty oils, white vinegar, balsamic vinegar, apple cider vinegar, and specialty vinegars."},
|
101 |
+
{"id": "dried_pasta_and_pasta_sauces", "text": "Shelf-stable Italian-style noodles and accompanying sauces. This parent category includes all pasta and jarred sauces."},
|
102 |
+
{"id": "dried_pasta", "text": "Shelf-stable wheat or grain-based noodle products. Includes spaghetti, penne, fettuccine, elbow macaroni, specialty shapes, whole wheat pasta, and gluten-free pasta varieties."},
|
103 |
+
{"id": "pasta_sauces", "text": "Ready-to-use flavor bases for pasta dishes. Includes marinara, meat sauce, alfredo, pesto, vodka sauce, and specialty pasta sauce varieties in jars or pouches."},
|
104 |
+
{"id": "dried_soup_mixes_and_bouillon", "text": "Dehydrated soup bases and flavor enhancers. Includes ramen, bouillon cubes, soup mixes, broth concentrates, and instant soup packets requiring water addition."},
|
105 |
+
{"id": "herbs_spices_and_seasonings", "text": "Flavor additions for cooking. This parent category includes all dried herbs, spices, and seasoning blends."},
|
106 |
+
{"id": "dried_herbs_and_spices", "text": "Dehydrated plant parts for flavor enhancement. Includes basil, oregano, cinnamon, cumin, paprika, individual spices, and dried herb varieties."},
|
107 |
+
{"id": "salt_and_pepper", "text": "Basic seasoning agents for cooking and table use. Includes table salt, sea salt, kosher salt, specialty salts, black pepper, white pepper, and peppercorns."},
|
108 |
+
{"id": "seasoning_mixes", "text": "Pre-blended spice combinations for specific dishes. Includes taco seasoning, Italian seasoning, chili powder, poultry seasoning, meat rubs, and meal-specific spice blends."},
|
109 |
+
{"id": "international_foods", "text": "Products from global cuisines organized by region. Includes Mexican, Asian, Mediterranean, Indian, European, and other international food products and ingredients."},
|
110 |
+
{"id": "potatoes_and_stuffing", "text": "Shelf-stable potato products and bread mixes. Includes instant mashed potatoes, scalloped/au gratin potatoes, stuffing mix, and dehydrated potato products."},
|
111 |
+
{"id": "rice_grains_and_dried_beans", "text": "Shelf-stable carbohydrate staples. This parent category includes all uncooked grains and dry legumes."},
|
112 |
+
{"id": "dried_beans", "text": "Dehydrated legumes requiring cooking. Includes pinto beans, black beans, kidney beans, lentils, split peas, chickpeas, and dried bean varieties."},
|
113 |
+
{"id": "grains", "text": "Edible seeds from grass-like plants. Includes barley, quinoa, couscous, bulgur, farro, millet, and ancient grains."},
|
114 |
+
{"id": "rice", "text": "Various processed rice grain varieties. Includes white rice, brown rice, jasmine rice, basmati rice, arborio rice, wild rice, and specialty rice varieties."},
|
115 |
+
{"id": "snacks_and_candy", "text": "Ready-to-eat treats and sweets. This parent category includes all snack foods and confectionery items."},
|
116 |
+
{"id": "chips", "text": "Crispy snack foods in thin, flat format. Includes potato chips, tortilla chips, corn chips, vegetable chips, kettle chips, and flavored chip varieties."},
|
117 |
+
{"id": "chocolate_candy_and_gum", "text": "Sweet confectionery products. This parent category includes all candy items with and without chocolate."},
|
118 |
+
{"id": "candy_and_gum", "text": "Sweet non-chocolate confections. Includes hard candy, chewy candy, gummy candy, licorice, mints, chewing gum, caramels, and non-chocolate sweets."},
|
119 |
+
{"id": "chocolate", "text": "Cocoa-based confections and treats. Includes chocolate bars, chocolate candy, truffles, chocolate-covered nuts/fruits, and chocolate gift boxes."},
|
120 |
+
{"id": "cookies", "text": "Sweet baked treats in individual portions. Includes chocolate chip cookies, sandwich cookies, shortbread, specialty cookies, and packaged cookie varieties."},
|
121 |
+
{"id": "crackers", "text": "Crisp, dry, flat baked snack products. Includes saltines, cheese crackers, graham crackers, water crackers, snack crackers, and specialty cracker varieties."},
|
122 |
+
{"id": "fruit_snacks", "text": "Processed fruit-based treats. Includes fruit leather, fruit snacks, dried fruit rolls, fruit-flavored gummies, and portable fruit-based treats."},
|
123 |
+
{"id": "jerky_and_rinds", "text": "Dried meat snacks and crispy pork products. Includes beef jerky, turkey jerky, meat sticks, pork rinds, chicharrones, and dried meat snack varieties."},
|
124 |
+
{"id": "nuts_and_dried_fruit", "text": "Shelf-stable natural snacks. This parent category includes all nuts, seeds, and dried fruit products."},
|
125 |
+
{"id": "dried_fruit", "text": "Dehydrated fruit products for snacking. Includes raisins, dried cranberries, dried apricots, banana chips, apple rings, and mixed dried fruit varieties."},
|
126 |
+
{"id": "nuts", "text": "Edible seeds and kernels in shells or shelled. Includes almonds, peanuts, cashews, walnuts, pistachios, mixed nuts, and specialty nut varieties."},
|
127 |
+
{"id": "packaged_snack_cakes", "text": "Factory-produced sweet baked goods with extended shelf life. Includes snack cakes, mini muffins, donettes, cream-filled cakes, and individually wrapped sweet treats."},
|
128 |
+
{"id": "popcorn_and_pretzels", "text": "Crunchy grain-based snack foods. Includes microwave popcorn, ready-to-eat popcorn, hard pretzels, soft pretzels, pretzel bites, and flavored varieties."},
|
129 |
+
{"id": "snack_bars", "text": "Portable compressed food items in bar form. Includes granola bars, protein bars, cereal bars, energy bars, fruit bars, and meal replacement bars."},
|
130 |
+
{"id": "baby_and_child", "text": "Products specifically designed for infants and young children. Includes baby food, formula, diapers, baby wipes, child-specific snacks, and infant care items."},
|
131 |
+
{"id": "pet_products", "text": "Items for domestic animals. Includes dog food, cat food, pet treats, litter, pet supplies, toys, accessories, and pet care products."},
|
132 |
+
{"id": "personal_care", "text": "Products for human hygiene and grooming. Includes soap, shampoo, deodorant, lotion, toothpaste, feminine care, and toiletry items."},
|
133 |
+
{"id": "household_and_cleaning", "text": "Products for home maintenance. Includes cleaning supplies, laundry products, paper goods, storage items, and household essentials."},
|
134 |
+
{"id": "health_and_pharmacy", "text": "Products related to health and wellness. Includes over-the-counter medications, vitamins, supplements, first aid supplies, and pharmacy items."},
|
135 |
+
{"id": "floral_and_garden", "text": "Plant products and gardening supplies. Includes cut flowers, potted plants, bouquets, seeds, soil, garden tools, and seasonal plant items."},
|
136 |
+
{"id": "kitchenware", "text": "Tools and equipment for food preparation. Includes cookware, utensils, gadgets, small appliances, food storage, and kitchen accessories."},
|
137 |
+
{"id": "paper_products", "text": "Disposable paper-based household items. Includes paper towels, toilet paper, facial tissue, napkins, paper plates, and disposable tableware."},
|
138 |
+
{"id": "seasonal_and_holiday", "text": "Items specific to holidays or times of year. Includes decorations, seasonal foods, holiday-themed products, and limited-time specialty items."},
|
139 |
+
{"id": "electronics_and_media", "text": "Electronic devices and entertainment products. Includes batteries, chargers, headphones, small electronics, DVDs, magazines, and basic media items."}
|
140 |
+
]
|
category_embeddings.pickle
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:abc3b4442b669e95e7e8c218fe5f5f9ea989dbe98b460f9b76dc0064a204725e
|
3 |
+
size 1276161
|
category_matching.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import numpy as np
|
3 |
+
import pickle
|
4 |
+
import os.path
|
5 |
+
from typing import Dict, List, Any, Tuple
|
6 |
+
from embeddings import create_product_embeddings
|
7 |
+
from similarity import compute_similarities
|
8 |
+
from utils import SafeProgress
|
9 |
+
import voyageai
|
10 |
+
|
11 |
+
# Update default path to be consistent
|
12 |
+
DEFAULT_CATEGORY_EMBEDDINGS_PATH = "data/category_embeddings.pickle"
|
13 |
+
|
14 |
+
def load_categories(file_path="categories.json") -> Dict[str, str]:
    """
    Load categories from a JSON file.

    The file is expected to contain a list of objects, each carrying an
    "id" and a "text" field.

    Args:
        file_path: Path to the categories JSON file

    Returns:
        Dictionary mapping category IDs to their descriptions; an empty
        dict if the file is missing or malformed (best-effort by design).
    """
    try:
        # Explicit encoding: categories.json is UTF-8, and relying on the
        # platform default (e.g. cp1252 on Windows) can corrupt or fail the load.
        with open(file_path, 'r', encoding='utf-8') as f:
            categories_list = json.load(f)

        # Convert the list-of-objects form to {id: text} for O(1) lookups.
        categories = {item["id"]: item["text"] for item in categories_list}
        print(f"Loaded {len(categories)} categories")
        return categories
    except Exception as e:
        # Best-effort loader: log and return an empty mapping rather than crash.
        print(f"Error loading categories: {e}")
        return {}
|
35 |
+
|
36 |
+
def create_category_embeddings(categories: Dict[str, str], progress=None,
                             pickle_path=DEFAULT_CATEGORY_EMBEDDINGS_PATH,
                             force_regenerate=False) -> Dict[str, Any]:
    """
    Create embeddings for category descriptions, with pickle-file caching.

    Args:
        categories: Dictionary mapping category IDs to their descriptions
        progress: Optional progress tracking object
        pickle_path: Path to the pickle file for caching embeddings
        force_regenerate: If True, regenerate embeddings even if cache exists

    Returns:
        Dictionary mapping category IDs to their embeddings
    """
    progress_tracker = SafeProgress(progress, desc="Generating category embeddings")

    # Fast path: reuse cached embeddings unless regeneration is forced.
    if not force_regenerate and os.path.exists(pickle_path):
        progress_tracker(0.1, desc=f"Loading cached embeddings from {pickle_path}")
        try:
            with open(pickle_path, 'rb') as f:
                # NOTE(review): pickle.load on an untrusted file can execute
                # arbitrary code; acceptable here because the cache is written
                # by this same module.
                category_embeddings = pickle.load(f)
            progress_tracker(1.0, desc=f"Loaded embeddings for {len(category_embeddings)} categories from cache")
            return category_embeddings
        except Exception as e:
            print(f"Error loading cached embeddings: {e}")
            # Cache unreadable: fall through and regenerate.

    progress_tracker(0.1, desc=f"Processing {len(categories)} categories")

    # Embed the description text with the same embedder used for products,
    # so category and product vectors live in the same vector space.
    category_ids = list(categories.keys())
    category_texts = list(categories.values())

    texts_with_embeddings = create_product_embeddings(category_texts, progress=progress)

    # Map embeddings back to category IDs (the embedder keys by text).
    category_embeddings = {}
    for i, category_id in enumerate(category_ids):
        if i < len(category_texts) and category_texts[i] in texts_with_embeddings:
            category_embeddings[category_id] = texts_with_embeddings[category_texts[i]]

    # Ensure the target directory exists. Guard the bare-filename case:
    # os.path.dirname("file.pickle") == "" and os.makedirs("") raises
    # FileNotFoundError.
    parent_dir = os.path.dirname(pickle_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Persist to the cache; failure to save is non-fatal (embeddings are
    # still returned, they just won't be reused next run).
    progress_tracker(0.9, desc=f"Saving embeddings to {pickle_path}")
    try:
        with open(pickle_path, 'wb') as f:
            pickle.dump(category_embeddings, f)
    except Exception as e:
        print(f"Error saving embeddings to pickle file: {e}")

    progress_tracker(1.0, desc=f"Completed embeddings for {len(category_embeddings)} categories")
    return category_embeddings
93 |
+
|
94 |
+
def load_category_embeddings(pickle_path=DEFAULT_CATEGORY_EMBEDDINGS_PATH) -> Dict[str, Any]:
    """
    Load pre-computed category embeddings from a pickle file.

    Args:
        pickle_path: Path to the pickle file with cached embeddings

    Returns:
        Dictionary mapping category IDs to their embeddings; an empty
        dict when the file is absent or cannot be unpickled.
    """
    # Missing cache: report and hand back an empty mapping.
    if not os.path.exists(pickle_path):
        print(f"No embeddings found at {pickle_path}")
        return {}

    try:
        with open(pickle_path, 'rb') as handle:
            embeddings = pickle.load(handle)
    except Exception as err:
        # Corrupt/unreadable cache behaves the same as a missing one,
        # after logging the underlying error.
        print(f"Error loading cached embeddings: {err}")
        print(f"No embeddings found at {pickle_path}")
        return {}

    print(f"Loaded embeddings for {len(embeddings)} categories from {pickle_path}")
    return embeddings
|
115 |
+
|
116 |
+
def match_products_to_categories(product_names: List[str], categories: Dict[str, str], top_n=5,
                             confidence_threshold=0.5, progress=None,
                             embeddings_path=DEFAULT_CATEGORY_EMBEDDINGS_PATH) -> Dict[str, List]:
    """
    Match products to their most likely categories via embedding similarity.

    Args:
        product_names: List of product names to categorize
        categories: Dictionary mapping category IDs to their descriptions
        top_n: Number of top categories to return per product
        confidence_threshold: Minimum similarity score to include
        progress: Optional progress tracking object
        embeddings_path: Path to pre-computed category embeddings

    Returns:
        Dictionary mapping each product to a list of
        (category_id, category_text, score) tuples.
    """
    tracker = SafeProgress(progress, desc="Matching products to categories")

    # Category embeddings: prefer the on-disk cache, build them otherwise.
    tracker(0.2, desc="Loading category embeddings")
    category_vectors = load_category_embeddings(embeddings_path)
    if not category_vectors:
        tracker(0.3, desc="Creating category embeddings")
        category_vectors = create_category_embeddings(categories, progress, pickle_path=embeddings_path)

    # Product embeddings are computed fresh for every call.
    tracker(0.4, desc="Creating product embeddings")
    product_vectors = create_product_embeddings(product_names, progress=progress)

    tracker(0.6, desc="Computing similarities")
    similarity_map = compute_similarities(category_vectors, product_vectors)

    # Keep, per product, the first top_n matches at or above the threshold,
    # annotated with the category's description text.
    tracker(0.8, desc="Processing results")
    results = {}
    for product_name, scored_pairs in similarity_map.items():
        matches = []
        for category_id, score in scored_pairs:
            if score < confidence_threshold:
                continue
            if len(matches) >= top_n:
                break
            matches.append((category_id, categories.get(category_id, "Unknown"), score))
        results[product_name] = matches

    tracker(1.0, desc="Completed category matching")
    return results
|
169 |
+
|
170 |
+
def hybrid_category_matching(products: List[str], categories: Dict[str, str],
                          embedding_top_n: int = 20, final_top_n: int = 5,
                          confidence_threshold: float = 0.5,
                          progress=None) -> Dict[str, List[Tuple]]:
    """
    Two-stage matching: embeddings retrieve candidate categories, then a
    Voyage AI re-ranker orders them.

    Args:
        products: List of product names to categorize
        categories: Dictionary mapping category IDs to their descriptions
        embedding_top_n: Number of top categories to retrieve using embeddings
        final_top_n: Number of final categories to return after re-ranking
        confidence_threshold: Minimum score threshold for final results
        progress: Optional progress tracking object

    Returns:
        Dictionary mapping products to their matched categories with scores
    """
    progress_tracker = SafeProgress(progress, desc="Hybrid category matching")
    progress_tracker(0.1, desc="Stage 1: Finding candidates with embeddings")

    # Stage 1: embedding retrieval with a generous candidate pool so the
    # re-ranker has room to reorder.
    embedding_results = match_products_to_categories(
        products,
        categories,
        top_n=embedding_top_n,
        progress=progress_tracker
    )

    progress_tracker(0.4, desc="Stage 2: Re-ranking candidates")

    # Voyage client reads its API key from the environment.
    client = voyageai.Client()

    # Stage 2: re-rank each product's candidates.
    final_results = {}

    for i, product in enumerate(progress_tracker.tqdm(products, desc="Re-ranking product candidates")):
        progress_tracker((0.4 + 0.5 * i / len(products)), desc=f"Re-ranking: {product}")

        # No embedding candidates means nothing to re-rank.
        candidates = embedding_results.get(product) or []
        if not candidates:
            final_results[product] = []
            continue

        # candidates are (category_id, category_text, score) tuples.
        candidate_ids = [c[0] for c in candidates]
        candidate_texts = [f"Category: {c[1]}" for c in candidates]

        try:
            query = f"Which category best describes the product: {product}"
            reranking = client.rerank(
                query=query,
                documents=candidate_texts,
                model="rerank-2",
                top_k=final_top_n
            )

            product_categories = []
            for result in reranking.results:
                # Prefer the API-provided document index: matching by text via
                # list.index() silently maps to the FIRST candidate whenever two
                # categories share the same description.
                candidate_index = getattr(result, "index", None)
                if candidate_index is None:
                    candidate_index = candidate_texts.index(result.document)
                category_id = candidate_ids[candidate_index]
                score = result.relevance_score

                # Only keep results above the confidence threshold.
                if score >= confidence_threshold:
                    product_categories.append((category_id, result.document, score))

            print(f"Product: {product}")
            print(f"Top 3 candidates before re-ranking: {candidates[:3]}")
            print(f"Top 3 candidates after re-ranking: {product_categories[:3]}")

            final_results[product] = product_categories

        except Exception as e:
            print(f"Error during re-ranking for '{product}': {e}")
            # Fall back to embedding ranking if the rerank call fails.
            final_results[product] = candidates[:final_top_n]

    progress_tracker(1.0, desc="Hybrid matching complete")
    return final_results
|
chicory_api.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
from typing import List, Dict, Any, Optional
|
5 |
+
from utils import SafeProgress
|
6 |
+
|
7 |
+
def call_chicory_parser(product_names: List[str], batch_size: int = 25, delay_seconds: float = 0.1, progress=None) -> Dict[str, Any]:
    """
    Call the Chicory Parser V3 API to get ingredient predictions.

    Args:
        product_names: List of product names to parse
        batch_size: Maximum number of products to process in one batch
        delay_seconds: Delay between batches in seconds
        progress: Optional progress tracking object (Gradio progress bar)

    Returns:
        Dictionary mapping product names to their Chicory Parser results
    """
    tracker = SafeProgress(progress, desc="Parsing products")
    total = len(product_names)

    # Small inputs fit in a single API request -- no batching required.
    if total <= batch_size:
        tracker(0.1, desc=f"Parsing {total} products...")
        parsed = _make_chicory_api_call(product_names)
        tracker(1.0, desc="Parsing complete")
        return parsed

    # Otherwise split the work into fixed-size index ranges.
    merged = {}
    num_batches = (total + batch_size - 1) // batch_size
    ranges = [(start, min(start + batch_size, total))
              for start in range(0, total, batch_size)]

    # Drive the ranges through the tracker's tqdm wrapper for batch-level progress.
    for batch_idx, (start, end) in enumerate(tracker.tqdm(ranges, desc="Processing batches")):
        chunk = product_names[start:end]

        # Report mid-batch progress with a specific description.
        tracker((batch_idx + 0.5) / num_batches,
                desc=f"Batch {batch_idx + 1}/{num_batches}: {len(chunk)} products")

        merged.update(_make_chicory_api_call(chunk))

        # Throttle between batches, but skip the pause after the final one.
        if end < total:
            time.sleep(delay_seconds)

    tracker(1.0, desc=f"Completed parsing {total} products")
    return merged
|
55 |
+
|
56 |
+
def _make_chicory_api_call(product_names: List[str]) -> Dict[str, Any]:
    """
    Make a single request to the Chicory Parser V3 prediction endpoint.

    Args:
        product_names: Product names to parse in this request.

    Returns:
        Dictionary mapping each product name (echoed back by the API as
        "input_text") to its full prediction result. Returns an empty dict
        on any request or parse failure instead of raising.
    """
    url = "https://prod-parserv3.chicoryapp.com/api/v3/prediction"

    # The API expects a JSON body of {"items": [{"id": ..., "text": ...}, ...]}.
    items = [{"id": i, "text": name} for i, name in enumerate(product_names)]
    payload = json.dumps({"items": items})

    # Set headers
    headers = {
        'Content-Type': 'application/json'
    }

    try:
        # A timeout prevents a stalled connection from hanging the caller
        # forever (requests has no default timeout).
        response = requests.post(url, headers=headers, data=payload, timeout=30)
        response.raise_for_status()  # Raise exception for HTTP errors

        # Parse the response
        results = response.json()

        # Key results by the echoed input text so callers can look up by name.
        # NOTE(review): duplicate product names collapse to the last result.
        product_results = {}
        for result in results:
            product_name = result["input_text"]
            product_results[product_name] = result

        return product_results

    except requests.exceptions.RequestException as e:
        print(f"Error calling Chicory Parser API: {e}")
        return {}
    except json.JSONDecodeError:
        # response is always bound here: decoding only happens after the POST succeeded.
        print(f"Error parsing Chicory API response: {response.text}")
        return {}
|
comparison.py
ADDED
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import numpy as np
|
3 |
+
from typing import Dict, List, Tuple, Any
|
4 |
+
import concurrent.futures
|
5 |
+
import time
|
6 |
+
import os
|
7 |
+
from api_utils import get_openai_client, get_voyage_client, process_in_parallel, rank_ingredients_openai
|
8 |
+
from ui_formatters import format_comparison_html, create_results_container
|
9 |
+
|
10 |
+
def compare_ingredient_methods(products: List[str], ingredients_dict: Dict[str, Any],
                              embedding_top_n: int = 20, final_top_n: int = 3,
                              confidence_threshold: float = 0.5,
                              progress=None) -> Dict[str, Dict[str, List[Tuple]]]:
    """
    Compare four different methods for ingredient matching:
    1. Base embeddings (without re-ranking)
    2. Voyage AI reranker (via hybrid approach)
    3. Chicory parser
    4. GPT-4o structured output

    Args:
        products: List of product names to categorize
        ingredients_dict: Dictionary of ingredient names to embeddings
        embedding_top_n: Number of top ingredients to retrieve using embeddings
        final_top_n: Number of final results to show for each method
        confidence_threshold: Minimum score threshold for final results
        progress: Optional progress tracking object

    Returns:
        Dictionary mapping products to a per-method dict with keys
        "base", "voyage", "chicory", "openai"; each value is a list of
        (ingredient, score) tuples.
    """
    # NOTE(review): imported inside the function, presumably to avoid circular
    # imports between modules -- confirm before hoisting to module level.
    # preprocess_product_for_matching is imported but not used in this function.
    from utils import SafeProgress, preprocess_product_for_matching
    from embeddings import create_product_embeddings
    from chicory_api import call_chicory_parser
    from similarity import compute_similarities

    progress_tracker = SafeProgress(progress, desc="Comparing ingredient matching methods")

    # Step 1: Generate embeddings for all products (used by multiple methods)
    progress_tracker(0.1, desc="Generating product embeddings")
    product_embeddings = create_product_embeddings(products, progress=progress_tracker)

    # Step 2: Get embedding-based candidates for all products
    progress_tracker(0.2, desc="Finding embedding candidates")
    similarities = compute_similarities(ingredients_dict, product_embeddings)

    # Filter to top N candidates per product (candidates feed every method below)
    embedding_results = {}
    for product, product_similarities in similarities.items():
        embedding_results[product] = product_similarities[:embedding_top_n]

    # Step 3: Call Chicory Parser API (this is done for all products at once)
    progress_tracker(0.3, desc="Calling Chicory Parser API")
    chicory_results = call_chicory_parser(products, progress=progress_tracker)

    # Create final results dictionary with base embeddings (which don't need any further processing)
    comparison_results = {}
    for product in products:
        if product in embedding_results:
            # Initialize with base embeddings already calculated
            candidates = embedding_results[product]
            base_results = [(c[0], c[1]) for c in candidates[:final_top_n] if c[1] >= confidence_threshold]
            comparison_results[product] = {
                "base": base_results,
                "voyage": [],
                "chicory": [],
                "openai": []
            }

            # Also process Chicory results immediately as they're already fetched
            chicory_matches = []
            if product in chicory_results:
                chicory_data = chicory_results[product]
                # NOTE(review): assumes the Chicory payload is a flat dict with
                # "ingredient" and "confidence" keys -- verify against the API.
                if isinstance(chicory_data, dict):
                    ingredient = chicory_data.get("ingredient", "")
                    confidence = chicory_data.get("confidence", 0)
                    if ingredient and confidence >= confidence_threshold:
                        chicory_matches.append((ingredient, confidence))
            comparison_results[product]["chicory"] = chicory_matches
        else:
            # Product had no embedding candidates: keep an empty entry so the
            # UI can still render a row for it.
            comparison_results[product] = {
                "base": [],
                "voyage": [],
                "chicory": [],
                "openai": []
            }

    # Initialize clients for reranking - REPLACED WITH UTILITY FUNCTIONS
    voyage_client = get_voyage_client()
    openai_client = get_openai_client()

    # Define the methods that will be executed in parallel (now focused only on the API-heavy tasks)
    def process_voyage_reranking(product):
        # Rerank one product's embedding candidates with Voyage; returns
        # (product, [(ingredient, score), ...]).
        if product not in embedding_results or not embedding_results[product]:
            return product, []

        candidates = embedding_results[product]
        candidate_ingredients = [c[0] for c in candidates]
        candidate_texts = [f"Ingredient: {c[0]}" for c in candidates]

        try:
            # Apply Voyage reranking to the candidates
            query = product  # Use product directly as query
            reranking = voyage_client.rerank(
                query=query,
                documents=candidate_texts,
                model="rerank-2",
                top_k=final_top_n
            )

            # Process reranking results
            voyage_ingredients = []
            for result in reranking.results:
                # Find the ingredient for this result.
                # NOTE(review): list.index picks the first match, so duplicate
                # candidate texts would map to the same ingredient.
                candidate_index = candidate_texts.index(result.document)
                ingredient = candidate_ingredients[candidate_index]
                score = float(result.relevance_score)

                # Only include results above the confidence threshold
                if score >= confidence_threshold:
                    voyage_ingredients.append((ingredient, score))

            return product, voyage_ingredients
        except Exception as e:
            print(f"Error during Voyage reranking for '{product}': {e}")
            # Fall back to embedding results
            return product, [(c[0], c[1]) for c in candidates[:final_top_n] if c[1] >= confidence_threshold]

    def process_openai(product):
        # Rank one product's embedding candidates with the OpenAI helper;
        # returns (product, [(ingredient, score), ...]).
        if product not in embedding_results or not embedding_results[product]:
            return product, []

        candidates = embedding_results[product]
        candidate_ingredients = [c[0] for c in candidates]

        try:
            # Use the shared utility function
            openai_ingredients = rank_ingredients_openai(
                product=product,
                candidates=candidate_ingredients,
                client=openai_client,
                model="o3-mini",
                max_results=final_top_n,
                confidence_threshold=confidence_threshold
            )

            return product, openai_ingredients
        except Exception as e:
            print(f"Error during OpenAI processing for '{product}': {e}")
            # Fall back to embedding results
            return product, [(c[0], c[1]) for c in candidates[:final_top_n] if c[1] >= confidence_threshold]

    # Process Voyage AI reranking in parallel - REPLACED WITH SHARED UTILITY
    progress_tracker(0.4, desc="Running Voyage AI reranking in parallel")
    voyage_results = process_in_parallel(
        items=products,
        processor_func=process_voyage_reranking,
        max_workers=min(20, len(products)),
        progress_tracker=progress_tracker,
        progress_start=0.4,
        progress_end=0.65,
        progress_desc="Voyage AI"
    )

    # Update comparison results with Voyage results
    for product, results in voyage_results.items():
        if product in comparison_results:
            comparison_results[product]["voyage"] = results

    # Process OpenAI queries in parallel - REPLACED WITH SHARED UTILITY
    progress_tracker(0.7, desc="Running OpenAI processing in parallel")
    openai_results = process_in_parallel(
        items=products,
        processor_func=process_openai,
        max_workers=min(20, len(products)),
        progress_tracker=progress_tracker,
        progress_start=0.7,
        progress_end=0.95,
        progress_desc="OpenAI"
    )

    # Update comparison results with OpenAI results
    for product, results in openai_results.items():
        if product in comparison_results:
            comparison_results[product]["openai"] = results

    progress_tracker(1.0, desc="Comparison complete")
    return comparison_results
|
189 |
+
|
190 |
+
def compare_ingredient_methods_ui(product_input, is_file=False, embedding_top_n=20,
                                 final_top_n=3, confidence_threshold=0.5, progress=None):
    """
    Compare multiple ingredient matching methods on the same products

    Args:
        product_input: Text input with product names or file path
        is_file: Whether the input is a file.
            NOTE(review): currently unused -- product_input is always treated
            as newline-separated text; confirm whether file support is still
            intended.
        embedding_top_n: Number of top ingredients to retrieve using embeddings
        final_top_n: Number of final results to show for each method
        confidence_threshold: Minimum score threshold for final results
        progress: Optional progress tracking object

    Returns:
        HTML formatted comparison results (or a plain error string when the
        input is empty)
    """
    # Imported here rather than at module level -- presumably to avoid a
    # circular import; confirm before hoisting.
    from utils import SafeProgress, load_embeddings

    progress_tracker = SafeProgress(progress, desc="Comparing ingredient matching methods")
    progress_tracker(0.1, desc="Processing input")


    # Split text input by lines and remove empty lines
    if not product_input:
        return "Please enter at least one product."
    product_names = [p.strip() for p in product_input.split('\n') if p.strip()]
    if not product_names:
        return "Please enter at least one product."

    # Load ingredient embeddings
    try:
        progress_tracker(0.2, desc="Loading ingredient embeddings")
        ingredients_dict = load_embeddings("data/ingredient_embeddings_voyageai.pkl")

        progress_tracker(0.3, desc="Comparing methods")
        comparison_results = compare_ingredient_methods(
            products=product_names,
            ingredients_dict=ingredients_dict,
            embedding_top_n=embedding_top_n,
            final_top_n=final_top_n,
            confidence_threshold=confidence_threshold,
            progress=progress_tracker
        )
    except Exception as e:
        # Surface the full traceback in the UI so failures are debuggable.
        import traceback
        error_details = traceback.format_exc()
        return f"<div style='color: red;'>Error comparing methods: {str(e)}<br><pre>{error_details}</pre></div>"

    # Format results as HTML using centralized formatters
    progress_tracker(0.9, desc="Formatting results")

    result_elements = []
    for product in product_names:
        if product in comparison_results:
            result_elements.append(format_comparison_html(product, comparison_results[product]))

    output_html = create_results_container(
        result_elements,
        header_text=f"Comparing {len(product_names)} products using multiple ingredient matching methods."
    )

    progress_tracker(1.0, desc="Complete")
    return output_html
|
config.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# UI configuration. UI_THEME is read by main.py at startup and passed to
# ui_formatters.set_theme().
UI_THEME = "dark"  # "light" or "dark"
|
data/category_embeddings.pickle
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8c51642451d7f5853975e974b46d7466c1a4c238f9caaa302c7ad454111c4fed
|
3 |
+
size 1275461
|
data/ingredient_embeddings_voyageai.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:394e5ca827ca948d6e44d830b12e071c24ac5898a52b9ce00ff54480a0f3e3c0
|
3 |
+
size 27292336
|
debug_embeddings.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Debug tool for checking ingredient embeddings
|
4 |
+
Run with: python debug_embeddings.py [optional_embeddings_path]
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
import json
|
10 |
+
import pickle
|
11 |
+
import logging
|
12 |
+
from pathlib import Path
|
13 |
+
|
14 |
+
# Configure logging
|
15 |
+
logging.basicConfig(
|
16 |
+
level=logging.INFO,
|
17 |
+
format='%(asctime)s - %(levelname)s - %(message)s'
|
18 |
+
)
|
19 |
+
logger = logging.getLogger('debug_embeddings')
|
20 |
+
|
21 |
+
def check_embeddings_file(filepath):
    """Check that an embeddings file exists, loads, and looks structurally valid."""
    logger.info(f"Checking embeddings file: {filepath}")

    # The file must exist before anything else.
    if not os.path.exists(filepath):
        logger.error(f"ERROR: Embeddings file not found at {filepath}")
        return False

    # Report the size in megabytes for a quick sanity check.
    size_mb = os.path.getsize(filepath) / (1024 * 1024)
    logger.info(f"File size: {size_mb:.2f} MB")

    # Extension decides the loader: pickle for .pkl/.pickle, JSON otherwise.
    pickle_format = filepath.endswith(('.pkl', '.pickle'))

    try:
        if pickle_format:
            with open(filepath, 'rb') as fh:
                data = pickle.load(fh)
        else:
            with open(filepath, 'r') as fh:
                data = json.load(fh)

        if not isinstance(data, dict):
            logger.error("ERROR: Embeddings file is not a valid dictionary")
            return False

        num_ingredients = len(data)
        logger.info(f"Number of ingredients/categories: {num_ingredients}")

        if num_ingredients == 0:
            logger.error("ERROR: Embeddings dictionary is empty")
            return False

        # Spot-check up to three random entries and report their shape/type.
        import random
        sample_keys = random.sample(list(data.keys()), min(3, len(data)))
        logger.info(f"Sample keys: {sample_keys}")

        for key in sample_keys:
            embedding = data[key]
            if isinstance(embedding, list):
                embedding_dim = len(embedding)
                logger.info(f"Embedding for '{key}' is a list with dimension: {embedding_dim}")
            elif hasattr(embedding, 'shape'):  # numpy array
                logger.info(f"Embedding for '{key}' is a numpy array with shape: {embedding.shape}")
            else:
                logger.info(f"Embedding for '{key}' is of type: {type(embedding)}")

        return True

    except json.JSONDecodeError:
        logger.error("ERROR: File is not valid JSON")
        return False
    except pickle.UnpicklingError:
        logger.error("ERROR: File is not a valid pickle file")
        return False
    except Exception as e:
        logger.error(f"ERROR: Unexpected error checking embeddings: {str(e)}")
        return False
|
83 |
+
|
84 |
+
def main():
    """Validate the embeddings file given by argv/env/default and suggest fixes."""
    # Get embeddings path from argument or environment or default
    if len(sys.argv) > 1:
        filepath = sys.argv[1]
    else:
        filepath = os.environ.get('EMBEDDINGS_PATH', 'data/ingredient_embeddings_voyageai.pkl')

    # Check if path exists and is valid
    if check_embeddings_file(filepath):
        logger.info("✅ Embeddings file looks valid!")

        # Suggest setting environment variable if not already set
        if 'EMBEDDINGS_PATH' not in os.environ:
            logger.info(f"TIP: Set the EMBEDDINGS_PATH environment variable to: {filepath}")
            logger.info(f"   export EMBEDDINGS_PATH=\"{filepath}\"")
    else:
        logger.error("❌ Embeddings file has issues that need to be fixed")

        # Known-good relative paths for the project's embedding files.
        specific_files = [
            'data/ingredient_embeddings_voyageai.pkl',
            'data/category_embeddings.pickle'
        ]

        # Look for embedding files in data directory
        data_dir = Path('data')
        if data_dir.exists():
            logger.info("Checking 'data' directory for embedding files:")
            for file in data_dir.glob('*embed*.p*'):
                logger.info(f"  - {file}")
                # BUGFIX: compare the full relative path, not file.name --
                # file.name ('foo.pkl') never matched the 'data/foo.pkl'
                # entries in specific_files, so the hint was never printed.
                if str(file) in specific_files:
                    logger.info(f"  ✓ Found target file: {file}")
                    logger.info(f"    Try running with: python debug_embeddings.py {file}")

        # Look for similar files that might be the correct embeddings
        dir_path = os.path.dirname(filepath) or '.'
        try:
            similar_files = list(Path(dir_path).glob("*embed*.p*"))
            if similar_files:
                logger.info("Found similar files that might contain embeddings:")
                for file in similar_files:
                    logger.info(f"  - {file}")
        except Exception:
            # Best-effort discovery only; never let the hint logic crash the tool.
            pass

if __name__ == "__main__":
    main()
|
embeddings.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Dict, Any, Optional
|
2 |
+
from utils import SafeProgress
|
3 |
+
import os
|
4 |
+
import voyageai
|
5 |
+
import time
|
6 |
+
import numpy as np
|
7 |
+
from concurrent.futures import ThreadPoolExecutor
|
8 |
+
|
9 |
+
# Set Voyage AI API key directly
|
10 |
+
voyageai.api_key = os.getenv("VOYAGE_API_KEY")
|
11 |
+
|
12 |
+
def get_embeddings_batch(texts, model="voyage-3-large", batch_size=100):
    """
    Get embeddings for a list of texts in batches.

    Args:
        texts: Strings to embed. Newlines are replaced with spaces first.
        model: Voyage AI embedding model name.
        batch_size: Maximum number of texts per API request.

    Returns:
        List of embeddings aligned index-for-index with ``texts``. Every
        entry of a failed batch is ``None`` so callers can detect and skip it.
    """
    all_embeddings = []

    # Pre-process all texts to replace newlines
    texts = [text.replace("\n", " ") for text in texts]

    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]

        try:
            response = voyageai.Embedding.create(input=batch, model=model)
            batch_embeddings = [item['embedding'] for item in response['data']]
            all_embeddings.extend(batch_embeddings)

            # Sleep briefly to avoid rate limits (skip after the final batch)
            if i + batch_size < len(texts):
                time.sleep(0.5)

        except Exception as e:
            print(f"Error in batch {i//batch_size + 1}: {e}")
            # Add empty embeddings for failed batch to keep alignment with input
            all_embeddings.extend([None] * len(batch))

    return all_embeddings
|
39 |
+
|
40 |
+
def create_product_embeddings(products: List[str], batch_size: int = 100, progress=None) -> Dict[str, Any]:
    """
    Create embeddings for product names with optimization for duplicates

    Args:
        products: List of product names to create embeddings for
        batch_size: Maximum number of products to process in one batch
        progress: Optional progress tracking object (Gradio progress bar)

    Returns:
        Dictionary mapping product names to ``{"embedding": <vector>}``.
        Products whose embedding call failed are silently omitted.
    """
    progress_tracker = SafeProgress(progress, desc="Generating embeddings")
    total_products = len(products)

    # Initialize results dictionary
    product_embeddings = {}

    # Use the same model as for ingredients (voyage-3-large)
    model = "voyage-3-large"

    # Process in batches with de-duplication
    progress_tracker(0.1, desc=f"Starting embeddings for {total_products} products")

    # De-duplication step: embed each distinct product name only once and
    # fan the result back out to every occurrence.
    unique_products = []
    product_to_index = {}  # product name -> position in unique_products
    index_map = {}  # Maps original index to index in unique_products

    for i, product in enumerate(products):
        if product in product_to_index:
            # Product already seen, just store the mapping
            index_map[i] = product_to_index[product]
        else:
            # New unique product
            product_to_index[product] = len(unique_products)
            index_map[i] = len(unique_products)
            unique_products.append(product)

    progress_tracker(0.2, desc=f"Found {len(unique_products)} unique products out of {total_products} total")

    if len(unique_products) == 0:
        progress_tracker(1.0, desc="No valid products to process")
        return {}

    # Get embeddings in batches for unique products only
    try:
        # Pre-process all texts to replace newlines
        clean_products = [product.replace("\n", " ") for product in unique_products]

        progress_tracker(0.3, desc=f"Calling VoyageAI API for {len(clean_products)} unique products")

        # Process in smaller batches for better reliability
        unique_embeddings = get_embeddings_batch(clean_products, model=model, batch_size=batch_size)

        # Map embeddings back to all products
        progress_tracker(0.8, desc=f"Mapping embeddings back to all products")
        for i, product in enumerate(products):
            unique_idx = index_map[i]
            # Skip None entries (failed batches) so callers only see valid vectors.
            if unique_idx < len(unique_embeddings) and unique_embeddings[unique_idx] is not None:
                # Store as dictionary with 'embedding' key for consistent format
                product_embeddings[product] = {
                    "embedding": unique_embeddings[unique_idx]
                }

        progress_tracker(0.9, desc="Processing embeddings completed")

    except Exception as e:
        # Best-effort: report the failure but still return whatever was mapped.
        progress_tracker(0.9, desc=f"Error generating embeddings: {str(e)}")
        print(f"Error generating product embeddings: {e}")

    progress_tracker(1.0, desc=f"Completed embeddings for {len(product_embeddings)} products")
    return product_embeddings
|
113 |
+
|
114 |
+
def _generate_embeddings_for_batch(batch: List[str]) -> Dict[str, Any]:
|
115 |
+
"""
|
116 |
+
Generate embeddings for a batch of products
|
117 |
+
"""
|
118 |
+
# This is a placeholder for your actual embedding generation logic
|
119 |
+
# Replace with your actual implementation
|
120 |
+
import time
|
121 |
+
|
122 |
+
# Your existing embedding code should go here instead of this placeholder
|
123 |
+
embeddings = {}
|
124 |
+
for product in batch:
|
125 |
+
# Replace with actual embedding creation
|
126 |
+
embeddings[product] = {"embedding": [0.1, 0.2, 0.3]}
|
127 |
+
|
128 |
+
return embeddings
|
generate_category_embeddings.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import pickle
|
3 |
+
from category_matching import load_categories, create_category_embeddings
|
4 |
+
|
5 |
+
def main(categories_file, output_file):
    """Load categories, embed them with Voyage AI, and pickle the embeddings."""
    # Load categories from the JSON file; bail out early if nothing loaded.
    categories = load_categories(categories_file)
    if not categories:
        print("No categories loaded. Exiting.")
        return
    print(f"Loaded {len(categories)} categories.")

    # Generate category embeddings using Voyage AI.
    print("Generating category embeddings...")
    embeddings = create_category_embeddings(categories)

    # Persist the embeddings as a pickle for later lookup.
    with open(output_file, 'wb') as handle:
        pickle.dump(embeddings, handle)
    print(f"Category embeddings saved to {output_file}")

if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(description="Generate and pickle category embeddings using Voyage AI.")
    arg_parser.add_argument("--categories", type=str, default="categories.json",
                            help="Path to the categories JSON file.")
    arg_parser.add_argument("--output", type=str, default="data/category_embeddings.pickle",
                            help="Path to output pickle file for embeddings")
    cli_args = arg_parser.parse_args()
    main(cli_args.categories, cli_args.output)
|
main.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import gradio as gr
|
5 |
+
from utils import load_embeddings
|
6 |
+
from ui import create_demo
|
7 |
+
from config import UI_THEME
|
8 |
+
from ui_formatters import set_theme
|
9 |
+
|
10 |
+
def main():
    """Main entry point for the application.

    Parses CLI args, loads ingredient embeddings, applies the UI theme, and
    launches the Gradio interface. Exits with status 1 when the embeddings
    file is missing or fails to load.
    """
    parser = argparse.ArgumentParser(description='Run the Product Categorization web app')
    parser.add_argument('--embeddings', default='data/ingredient_embeddings_voyageai.pkl',
                        help='Path to the ingredient embeddings pickle file')
    parser.add_argument('--share', action='store_true', help='Create a public link for sharing')

    args = parser.parse_args()

    # Check if embeddings file exists
    if not os.path.exists(args.embeddings):
        print(f"Error: Embeddings file {args.embeddings} not found!")
        print(f"Please ensure the file exists at {os.path.abspath(args.embeddings)}")
        sys.exit(1)

    # Load embeddings
    try:
        embeddings_data = load_embeddings(args.embeddings)
        # Update the embeddings in the ui_core module.
        # NOTE(review): this injects state by assigning a module attribute;
        # ui_core presumably reads `embeddings` at call time -- confirm.
        import ui_core
        ui_core.embeddings = embeddings_data
    except Exception as e:
        print(f"Error loading embeddings: {e}")
        sys.exit(1)

    # Set the application theme
    set_theme(UI_THEME)

    # Create and launch the interface
    demo = create_demo()

    # Launch with only supported parameters
    demo.launch(
        share=args.share,
        show_api=False
    )

if __name__ == "__main__":
    main()
|
openai_expansion.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from typing import List, Dict, Any
|
3 |
+
from openai import OpenAI
|
4 |
+
import concurrent.futures
|
5 |
+
from utils import SafeProgress
|
6 |
+
from api_utils import get_openai_client
|
7 |
+
|
8 |
+
def expand_product_descriptions(products: List[str],
|
9 |
+
max_workers: int = 5,
|
10 |
+
progress=None) -> Dict[str, str]:
|
11 |
+
"""
|
12 |
+
Expand product descriptions using OpenAI's structured output
|
13 |
+
|
14 |
+
Args:
|
15 |
+
products: List of product names to expand
|
16 |
+
max_workers: Maximum number of concurrent API calls
|
17 |
+
progress: Optional progress tracking object
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
Dictionary mapping original product names to expanded descriptions
|
21 |
+
"""
|
22 |
+
progress_tracker = SafeProgress(progress, desc="Expanding product descriptions")
|
23 |
+
|
24 |
+
# Set up OpenAI client
|
25 |
+
openai_client = get_openai_client()
|
26 |
+
|
27 |
+
expanded_descriptions = {}
|
28 |
+
|
29 |
+
def process_product(product):
|
30 |
+
try:
|
31 |
+
response = openai_client.responses.create(
|
32 |
+
# model="o3-mini",
|
33 |
+
model="gpt-4o-mini",
|
34 |
+
max_output_tokens=100,
|
35 |
+
# reasoning={"effort": "low"},
|
36 |
+
input=[
|
37 |
+
{"role": "system", "content": """You are a product description expert. Your task is to expand product names into descriptions that would help an embedding model categorize them correctly.
|
38 |
+
"""},
|
39 |
+
{"role": "user", "content": f'Describe "{product}" to an embedding model categorizing products'}
|
40 |
+
],
|
41 |
+
text={
|
42 |
+
"format": {
|
43 |
+
"type": "json_schema",
|
44 |
+
"name": "product_description",
|
45 |
+
"schema": {
|
46 |
+
"type": "object",
|
47 |
+
"properties": {
|
48 |
+
"expanded_description": {
|
49 |
+
"type": "string",
|
50 |
+
"description": "An expanded description of the product that includes its category, type, common ingredients or components, and typical use cases."
|
51 |
+
}
|
52 |
+
},
|
53 |
+
"required": ["expanded_description"],
|
54 |
+
"additionalProperties": False
|
55 |
+
},
|
56 |
+
"strict": True
|
57 |
+
}
|
58 |
+
}
|
59 |
+
)
|
60 |
+
|
61 |
+
# Parse the response
|
62 |
+
result = json.loads(response.output_text)
|
63 |
+
return product, result["expanded_description"]
|
64 |
+
except Exception as e:
|
65 |
+
print(f"Error expanding description for '{product}': {e}")
|
66 |
+
return product, f"{product} - No expanded description available."
|
67 |
+
|
68 |
+
# Process in batches for better parallelism
|
69 |
+
total_products = len(products)
|
70 |
+
progress_tracker(0.1, desc=f"Processing {total_products} products")
|
71 |
+
|
72 |
+
# Use thread pool for concurrent API calls
|
73 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
74 |
+
future_to_product = {executor.submit(process_product, product): i
|
75 |
+
for i, product in enumerate(products)}
|
76 |
+
|
77 |
+
for i, future in enumerate(concurrent.futures.as_completed(future_to_product)):
|
78 |
+
progress_percent = 0.1 + (0.8 * (i+1) / total_products)
|
79 |
+
product_index = future_to_product[future]
|
80 |
+
progress_tracker(progress_percent, desc=f"Expanded {i+1}/{total_products} products")
|
81 |
+
|
82 |
+
try:
|
83 |
+
original_product, expanded_description = future.result()
|
84 |
+
expanded_descriptions[original_product] = expanded_description
|
85 |
+
except Exception as e:
|
86 |
+
product = products[product_index]
|
87 |
+
print(f"Error processing expansion for '{product}': {e}")
|
88 |
+
expanded_descriptions[product] = product # Fallback to original product name
|
89 |
+
|
90 |
+
progress_tracker(1.0, desc="Expansion complete")
|
91 |
+
return expanded_descriptions
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
voyageai
|
2 |
+
numpy
|
3 |
+
gradio
|
4 |
+
openai
|
5 |
+
requests
|
6 |
+
tqdm
|
similarity.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from typing import Dict, List, Tuple, Any
|
3 |
+
import json
|
4 |
+
import voyageai
|
5 |
+
from openai import OpenAI
|
6 |
+
from api_utils import get_openai_client
|
7 |
+
|
8 |
+
def _collect_embedding_vectors(embeddings_dict):
    """Extract parallel (names, vectors) lists from an embeddings dict.

    Values may be raw vectors (list / np.ndarray) or dicts of the form
    {"embedding": vector, ...}; None and empty entries are skipped.
    """
    names = []
    vectors = []
    for name, data in embeddings_dict.items():
        if data is None:
            continue
        if isinstance(data, dict) and "embedding" in data:
            vector = data["embedding"]
            if vector is not None:
                names.append(name)
                vectors.append(vector)
        elif isinstance(data, (list, np.ndarray)):
            if len(data) == 0:
                continue
            names.append(name)
            vectors.append(data)
    return names, vectors


def _as_row_matrix(vectors, label):
    """Convert a list of vectors to a 2D float32 matrix, or None on failure."""
    matrix = np.array(vectors, dtype=np.float32)
    if matrix.ndim == 1:
        # A single flat vector came back as 1D; promote it to (1, dim).
        print(f"Warning: {label} embeddings have only 1 dimension, reshaping. Shape: {matrix.shape}")
        if len(matrix) == 0:
            print(f"Error: Empty {label} embeddings array")
            return None
        matrix = matrix.reshape(1, -1)
    return matrix


def _normalize_rows(matrix):
    """L2-normalize each row; zero rows are guarded with a tiny epsilon norm."""
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    norms = np.where(norms == 0, 1e-10, norms)  # avoid division by zero
    return matrix / norms


def compute_similarities(ingredients_dict, products_dict):
    """
    Compute cosine similarities between ingredient embeddings and product embeddings.

    Args:
        ingredients_dict: Dictionary of ingredient names to embeddings
            (raw vectors or {"embedding": vector} dicts)
        products_dict: Dictionary of product names to embeddings
            (raw vectors or {"embedding": vector} dicts)

    Returns:
        Dictionary mapping each product name to a list of
        (ingredient_name, similarity) tuples sorted by similarity descending.
        Returns {} on any invalid/empty input or on a dimension mismatch.
    """
    # Validate inputs
    if not ingredients_dict:
        print("Warning: ingredients_dict is empty")
        return {}
    if not products_dict:
        print("Warning: products_dict is empty")
        return {}

    # Ingredients: extract valid vectors, build matrix, normalize
    ingredient_names, ingredient_vectors = _collect_embedding_vectors(ingredients_dict)
    if not ingredient_names:
        print("Warning: No valid ingredient embeddings found")
        return {}
    ingredient_matrix = _as_row_matrix(ingredient_vectors, "Ingredient")
    if ingredient_matrix is None:
        return {}
    normalized_ingredients = _normalize_rows(ingredient_matrix)

    # Products: same extraction/normalization pipeline
    product_names, product_vectors = _collect_embedding_vectors(products_dict)
    if not product_names:
        print("Warning: No valid product embeddings found")
        return {}
    product_matrix = _as_row_matrix(product_vectors, "Product")
    if product_matrix is None:
        return {}

    # The dot product below requires matching embedding dimensions
    product_dim = product_matrix.shape[1]
    ingredient_dim = normalized_ingredients.shape[1]
    if product_dim != ingredient_dim:
        print(f"Warning: Dimension mismatch between product embeddings ({product_dim}) and ingredient embeddings ({ingredient_dim})")
        return {}

    normalized_products = _normalize_rows(product_matrix)

    # Cosine similarity = dot product of unit vectors
    similarity_matrix = np.dot(normalized_products, normalized_ingredients.T)

    # For each product, pair every ingredient with its score, best first
    results = {}
    for i, product_name in enumerate(product_names):
        scores = similarity_matrix[i]
        ranked = sorted(
            ((ingredient_names[j], float(scores[j])) for j in range(len(ingredient_names))),
            key=lambda pair: pair[1],
            reverse=True,
        )
        results[product_name] = ranked
    return results
|
137 |
+
|
138 |
+
def hybrid_ingredient_matching(products: List[str], ingredients_dict: Dict[str, Any],
                              embedding_top_n: int = 20, final_top_n: int = 5,
                              confidence_threshold: float = 0.5,
                              progress=None) -> Dict[str, List[Tuple]]:
    """
    Two-stage matching: first use embeddings to find candidate ingredients,
    then ask an OpenAI model to pick the single best candidate.

    Bug fix vs. the original: the re-ranking `try` carried two consecutive
    `except Exception as e:` clauses; the second one (which fell back to
    ``candidates[:final_top_n]``) was unreachable dead code and has been
    removed. Only the ``candidates[:1]`` fallback was ever executed.

    Args:
        products: List of product names to categorize
        ingredients_dict: Dictionary of ingredient names to embeddings
        embedding_top_n: Number of top ingredients to retrieve using embeddings
        final_top_n: Kept for interface compatibility; not used by the
            single-best-match re-ranking path
        confidence_threshold: Minimum relevance score for a match to be kept
        progress: Optional progress tracking object

    Returns:
        Dictionary mapping each product to a list with at most one
        (ingredient, score) tuple (empty list when no confident match).
    """
    # Local imports avoid circular dependencies between pipeline modules
    from utils import SafeProgress
    from embeddings import create_product_embeddings

    progress_tracker = SafeProgress(progress, desc="Hybrid ingredient matching")
    progress_tracker(0.1, desc="Stage 1: Finding candidates with embeddings")

    # Stage 1: embed the products and score them against every ingredient
    product_embeddings = create_product_embeddings(products, progress=progress_tracker)
    similarities = compute_similarities(ingredients_dict, product_embeddings)

    # Keep only the top-N embedding candidates per product
    embedding_results = {}
    for product, product_similarities in similarities.items():
        embedding_results[product] = product_similarities[:embedding_top_n]

    progress_tracker(0.4, desc="Stage 2: Re-ranking candidates")

    # Initialize OpenAI client using the centralized function
    openai_client = get_openai_client()

    # Stage 2: re-rank the candidates for each product
    final_results = {}

    for i, product in enumerate(products):
        progress_tracker((0.4 + 0.5 * i / len(products)), desc=f"Re-ranking: {product}")

        # Products without embedding candidates get an empty result
        candidates = embedding_results.get(product, [])
        if not candidates:
            final_results[product] = []
            continue

        # Only the ingredient names are sent to the model for re-ranking
        candidate_ingredients = [c[0] for c in candidates]

        try:
            # Structured output: the schema forces a single best_match object
            response = openai_client.responses.create(
                model="o3-mini",
                # reasoning={"effort": "low"},
                input=[
                    {"role": "system", "content": "You are a food ingredient matching expert. Select the single best ingredient that matches the given product."},
                    {"role": "user", "content": f"Product: {product}\n\nPotential ingredients: {', '.join(candidate_ingredients)}"}
                ],
                text={
                    "format": {
                        "type": "json_schema",
                        "name": "ingredient_selection",
                        "schema": {
                            "type": "object",
                            "properties": {
                                "best_match": {
                                    "type": "object",
                                    "properties": {
                                        "ingredient": {
                                            "type": "string",
                                            "description": "The name of the best matching ingredient"
                                        },
                                        "explanation": {
                                            "type": "string",
                                            "description": "Brief explanation for the matching"
                                        },
                                        "relevance_score": {
                                            "type": "number",
                                            "description": "Score between 0 and 1 indicating relevance"
                                        }
                                    },
                                    "required": ["ingredient", "relevance_score", "explanation"],
                                    "additionalProperties": False
                                }
                            },
                            "required": ["best_match"],
                            "additionalProperties": False
                        },
                        "strict": True
                    }
                }
            )

            # Parse the structured response
            best_match = json.loads(response.output_text)["best_match"]

            # Only keep the result if it meets the confidence threshold
            if best_match["relevance_score"] >= confidence_threshold:
                final_results[product] = [(best_match["ingredient"], best_match["relevance_score"])]
            else:
                final_results[product] = []

        except Exception as e:
            print(f"Error during OpenAI re-ranking for '{product}': {e}")
            # Fall back to the top embedding result if re-ranking fails
            final_results[product] = candidates[:1]

    progress_tracker(1.0, desc="Hybrid ingredient matching complete")
    return final_results
|
ui.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from comparison import compare_ingredient_methods_ui
|
3 |
+
|
4 |
+
# Import from our new UI modules
|
5 |
+
from ui_core import embeddings, get_css, load_examples
|
6 |
+
from ui_ingredient_matching import categorize_products
|
7 |
+
from ui_category_matching import categorize_products_by_category
|
8 |
+
from ui_hybrid_matching import categorize_products_hybrid, categorize_products_hybrid_ingredients
|
9 |
+
from ui_expanded_matching import categorize_products_with_expansion
|
10 |
+
from ui_formatters import get_formatted_css
|
11 |
+
|
12 |
+
def create_demo():
    """Create the Gradio interface.

    Builds a tabbed Blocks app with six tabs — plain ingredient matching,
    hybrid ingredient matching, category matching, hybrid category matching,
    a method-comparison view, and expanded-description matching — plus the
    click wiring that connects each tab's buttons to its handler.
    Returns the (un-launched) gr.Blocks demo.
    """
    with gr.Blocks(css=get_css()) as demo:
        gr.Markdown("# Product Categorization Tool\nAnalyze products by matching to ingredients or categories using AI embeddings.")
        
        with gr.Tabs() as tabs:
            # Original Ingredient Matching Tab
            with gr.TabItem("Ingredient Matching"):
                with gr.Row():
                    with gr.Column(scale=1):
                        # Input section
                        text_input = gr.Textbox(
                            lines=10,
                            placeholder="Enter product names, one per line",
                            label="Product Names"
                        )
                        input_controls = gr.Row()
                        with input_controls:
                            top_n = gr.Slider(1, 25, 10, step=1, label="Top N Results")
                            confidence = gr.Slider(0.1, 0.9, 0.5, label="Similarity Threshold")
                        
                        with gr.Row():
                            examples_btn = gr.Button("Load Examples", variant="secondary")
                            categorize_btn = gr.Button("Find Similar Ingredients", variant="primary")
                    
                    with gr.Column(scale=1):
                        # Results section
                        text_output = gr.HTML(label="Similar Ingredients Results", elem_id="results-container")
            
            
            # New Hybrid Ingredient Matching Tab
            with gr.TabItem("Hybrid Ingredient Matching"):
                with gr.Row():
                    with gr.Column(scale=1):
                        # Input section
                        hybrid_ing_text_input = gr.Textbox(
                            lines=10,
                            placeholder="Enter product names, one per line",
                            label="Product Names"
                        )
                        hybrid_ing_input_controls = gr.Row()
                        with hybrid_ing_input_controls:
                            ing_embedding_top_n = gr.Slider(1, 50, 20, step=1, label="Embedding Top N Results")
                            ing_final_top_n = gr.Slider(1, 10, 5, step=1, label="Final Top N Ingredients")
                            hybrid_ing_confidence = gr.Slider(0.1, 0.9, 0.5, label="Matching Threshold")
                        
                        with gr.Row():
                            hybrid_ing_examples_btn = gr.Button("Load Examples", variant="secondary")
                            hybrid_ing_match_btn = gr.Button("Match Ingredients using Hybrid Approach", variant="primary")
                    
                    with gr.Column(scale=1):
                        # Results section
                        hybrid_ing_output = gr.HTML(label="Hybrid Ingredient Matching Results", elem_id="results-container")
            # New Category Matching Tab
            with gr.TabItem("Category Matching"):
                with gr.Row():
                    with gr.Column(scale=1):
                        # Input section
                        category_text_input = gr.Textbox(
                            lines=10,
                            placeholder="Enter product names, one per line",
                            label="Product Names"
                        )
                        category_input_controls = gr.Row()
                        with category_input_controls:
                            category_top_n = gr.Slider(1, 10, 5, step=1, label="Top N Categories")
                            category_confidence = gr.Slider(0.1, 0.9, 0.5, label="Matching Threshold")
                        
                        with gr.Row():
                            category_examples_btn = gr.Button("Load Examples", variant="secondary")
                            match_categories_btn = gr.Button("Match to Categories", variant="primary")
                    
                    with gr.Column(scale=1):
                        # Results section
                        category_output = gr.HTML(label="Category Matching Results", elem_id="results-container")
            
            # New Hybrid Matching Tab
            with gr.TabItem("Hybrid Category Matching"):
                with gr.Row():
                    with gr.Column(scale=1):
                        # Input section
                        hybrid_text_input = gr.Textbox(
                            lines=10,
                            placeholder="Enter product names, one per line",
                            label="Product Names"
                        )
                        hybrid_input_controls = gr.Row()
                        with hybrid_input_controls:
                            embedding_top_n = gr.Slider(1, 50, 20, step=1, label="Embedding Top N Results")
                            final_top_n = gr.Slider(1, 10, 5, step=1, label="Final Top N Categories")
                            hybrid_confidence = gr.Slider(0.1, 0.9, 0.5, label="Matching Threshold")
                        
                        with gr.Row():
                            hybrid_examples_btn = gr.Button("Load Examples", variant="secondary")
                            hybrid_match_btn = gr.Button("Match using Hybrid Approach", variant="primary")
                    
                    with gr.Column(scale=1):
                        # Results section
                        hybrid_output = gr.HTML(label="Hybrid Matching Results", elem_id="results-container")
            
            
            # New Comparison Tab
            with gr.TabItem("Compare Methods"):
                with gr.Row():
                    with gr.Column():
                        compare_product_input = gr.Textbox(
                            label="Enter product names (one per line)",
                            placeholder="4 Tbsp sweet pickle relish\nchocolate chips\nfresh parsley",
                            lines=5
                        )
                        
                        with gr.Row():
                            compare_embedding_top_n = gr.Slider(
                                minimum=5, maximum=50, value=20, step=5,
                                label="Initial embedding candidates"
                            )
                            compare_final_top_n = gr.Slider(
                                minimum=1, maximum=10, value=3, step=1,
                                label="Final results per method"
                            )
                            compare_confidence_threshold = gr.Slider(
                                minimum=0.0, maximum=1.0, value=0.5, step=0.05,
                                label="Confidence threshold"
                            )
                        
                        compare_btn = gr.Button("Compare Methods", variant="primary")
                        compare_examples_btn = gr.Button("Load Examples", variant="secondary")
                    
                    with gr.Column():
                        comparison_output = gr.HTML(label="Results", elem_id="results-container")
                
                # Connect the compare button
                # gr.State(False) pins the "is_file" argument to text-input mode
                compare_btn.click(
                    fn=compare_ingredient_methods_ui,
                    inputs=[
                        compare_product_input,
                        gr.State(False),  # Always text input mode
                        compare_embedding_top_n,
                        compare_final_top_n,
                        compare_confidence_threshold
                    ],
                    outputs=comparison_output
                )
                
                # Add examples button functionality
                compare_examples_btn.click(
                    fn=load_examples,
                    inputs=[],
                    outputs=compare_product_input
                )
            
            # New Expanded Description Tab
            with gr.TabItem("Expanded Description Matching"):
                with gr.Row():
                    with gr.Column(scale=1):
                        # Input section
                        expanded_text_input = gr.Textbox(
                            lines=10,
                            placeholder="Enter product names, one per line",
                            label="Product Names"
                        )
                        expanded_input_controls = gr.Row()
                        with expanded_input_controls:
                            expanded_top_n = gr.Slider(1, 20, 10, step=1, label="Top N Results")
                            expanded_confidence = gr.Slider(0.1, 0.9, 0.5, label="Matching Threshold")
                        
                        # Add toggle here for matching type
                        expanded_match_type = gr.Radio(
                            choices=["ingredients", "categories"],
                            value="ingredients",
                            label="Match Type",
                            info="Choose whether to match against ingredients or categories"
                        )
                        
                        with gr.Row():
                            expanded_match_btn = gr.Button("Match with Expanded Descriptions", variant="primary")
                            expanded_examples_btn = gr.Button("Load Examples")
                    
                    with gr.Column(scale=1):
                        # Results section
                        expanded_output = gr.HTML(label="Results with Expanded Descriptions", elem_id="results-container")
        
        # Connect buttons for ingredient matching
        # (gr.State(False) = text input, not file upload, for every handler)
        categorize_btn.click(
            fn=categorize_products,
            inputs=[text_input, gr.State(False), top_n, confidence],
            outputs=[text_output],
        )
        
        # Connect buttons for category matching
        match_categories_btn.click(
            fn=categorize_products_by_category,
            inputs=[category_text_input, gr.State(False), category_top_n, category_confidence],
            outputs=[category_output],
        )
        
        # Connect buttons for hybrid matching
        hybrid_match_btn.click(
            fn=categorize_products_hybrid,
            inputs=[hybrid_text_input, gr.State(False), embedding_top_n, final_top_n, hybrid_confidence],
            outputs=[hybrid_output],
        )
        
        # Connect buttons for hybrid ingredient matching
        hybrid_ing_match_btn.click(
            fn=categorize_products_hybrid_ingredients,
            inputs=[hybrid_ing_text_input, gr.State(False), ing_embedding_top_n, ing_final_top_n, hybrid_ing_confidence],
            outputs=[hybrid_ing_output],
        )
        
        hybrid_ing_examples_btn.click(
            fn=load_examples,  # Reuse the same examples
            inputs=[],
            outputs=hybrid_ing_text_input
        )
        
        # Connect buttons for expanded description matching
        expanded_match_btn.click(
            fn=categorize_products_with_expansion,
            inputs=[expanded_text_input, gr.State(False), expanded_top_n, expanded_confidence, expanded_match_type],
            outputs=[expanded_output],
        )
        
        expanded_examples_btn.click(
            fn=load_examples,  # Reuse the same examples
            inputs=[],
            outputs=expanded_text_input
        )
        
        # Examples buttons
        examples_btn.click(
            fn=load_examples,
            inputs=[],
            outputs=text_input
        )
        
        category_examples_btn.click(
            fn=load_examples,  # Reuse the same examples
            inputs=[],
            outputs=category_text_input
        )
        
        hybrid_examples_btn.click(
            fn=load_examples,  # Reuse the same examples
            inputs=[],
            outputs=hybrid_text_input
        )
        
        gr.Markdown("Powered by Voyage AI embeddings • Built with Gradio")
    
    return demo
|
ui_category_matching.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from utils import SafeProgress
|
3 |
+
from category_matching import load_categories, match_products_to_categories
|
4 |
+
from ui_core import parse_input
|
5 |
+
from ui_formatters import format_categories_html
|
6 |
+
|
7 |
+
def categorize_products_by_category(product_input, is_file=False, top_n=5, confidence_threshold=0.5, progress=gr.Progress()):
    """Match each input product against the predefined category list.

    Parses the raw input, loads the categories, runs the matcher, and
    renders the per-product results as an HTML report.
    """
    tracker = SafeProgress(progress)
    tracker(0, desc="Starting categorization...")

    # Turn the raw textbox/file input into a list of product names
    product_names, error = parse_input(product_input, is_file)
    if error:
        return error

    # Load the category definitions
    tracker(0.2, desc="Loading categories...")
    categories = load_categories()

    # Run the product -> category matcher
    tracker(0.3, desc="Matching products to categories...")
    match_results = match_products_to_categories(
        product_names,
        categories,
        top_n=int(top_n),
        confidence_threshold=confidence_threshold,
        progress=progress
    )

    # Render the results as HTML
    tracker(0.9, desc="Formatting results...")
    if not match_results:
        tracker(1.0, desc="Done!")
        return "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>No results found. Please check your input or try different products.</div>"

    pieces = [
        "<div style='font-family: Arial, sans-serif; max-width: 100%; overflow-x: auto;'>",
        f"<p style='color: #555;'>Matched {len(product_names)} products to categories.</p>",
    ]
    for product, matched_categories in match_results.items():
        pieces.append(format_categories_html(product, matched_categories))
        pieces.append("<hr style='margin: 15px 0; border: 0; border-top: 1px solid #eee;'>")
    pieces.append("</div>")

    tracker(1.0, desc="Done!")
    return "".join(pieces)
|
ui_core.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
import json
|
4 |
+
import pickle
|
5 |
+
import numpy as np
|
6 |
+
from typing import Tuple, List, Dict, Any, Optional
|
7 |
+
import gradio as gr
|
8 |
+
|
9 |
+
# Configure logging
|
10 |
+
logging.basicConfig(
|
11 |
+
level=logging.INFO,
|
12 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
13 |
+
)
|
14 |
+
logger = logging.getLogger('ui_core')
|
15 |
+
|
16 |
+
# Global variables
|
17 |
+
embeddings = {}
|
18 |
+
# Update default path to point to the pickle file
|
19 |
+
EMBEDDINGS_PATH = os.environ.get('EMBEDDINGS_PATH', 'data/ingredient_embeddings_voyageai.pkl')
|
20 |
+
CATEGORY_EMBEDDINGS_PATH = os.environ.get('CATEGORY_EMBEDDINGS_PATH', 'data/category_embeddings.pickle')
|
21 |
+
|
22 |
+
def load_embeddings(filepath: str = EMBEDDINGS_PATH) -> Dict[str, Any]:
    """Load ingredient embeddings from disk.

    Accepts either a pickle (.pkl/.pickle) or a JSON file, chosen by file
    extension. If the given path does not exist, a few likely alternative
    locations are probed before giving up. List values are converted to
    numpy arrays for faster downstream similarity math.

    Args:
        filepath: Path to the embeddings file.

    Returns:
        Mapping of ingredient name -> embedding vector; empty dict on any
        failure (missing file, bad format, unpickling/JSON error).
    """
    try:
        logger.info(f"Attempting to load embeddings from: {filepath}")
        if not os.path.exists(filepath):
            logger.error(f"Embeddings file not found: {filepath}")
            # Try alternative file formats / known default locations.
            candidates = [
                filepath.replace('.pkl', '.pickle'),
                filepath.replace('.pickle', '.pkl'),
                'data/ingredient_embeddings_voyageai.pkl',
                'data/ingredient_embeddings.pickle'
            ]
            fallback = next(
                (p for p in candidates if p != filepath and os.path.exists(p)),
                None
            )
            if fallback is None:
                return {}
            logger.info(f"Found alternative embeddings file: {fallback}")
            filepath = fallback

        # Determine file type and load accordingly (extension-driven).
        if filepath.endswith(('.pkl', '.pickle')):
            logger.info(f"Loading pickle file: {filepath}")
            with open(filepath, 'rb') as fh:
                raw = pickle.load(fh)
        else:
            logger.info(f"Loading JSON file: {filepath}")
            with open(filepath, 'r') as fh:
                raw = json.load(fh)

        # Validate the loaded data: must be a non-empty dict.
        if not isinstance(raw, dict) or not raw:
            logger.error(f"Invalid embeddings format in {filepath}")
            return {}

        # Normalize values: plain lists become numpy arrays, anything else
        # (already an ndarray, etc.) passes through unchanged.
        processed = {
            name: np.array(vec) if isinstance(vec, list) else vec
            for name, vec in raw.items()
        }

        logger.info(f"Successfully loaded {len(processed)} ingredient embeddings")
        return processed

    except json.JSONDecodeError:
        logger.error(f"Invalid JSON format in embeddings file: {filepath}")
        return {}
    except pickle.UnpicklingError:
        logger.error(f"Invalid pickle format in embeddings file: {filepath}")
        return {}
    except Exception as e:
        logger.error(f"Error loading embeddings: {str(e)}")
        return {}
|
86 |
+
|
87 |
+
# Load embeddings at module import time
embeddings = load_embeddings()

# If embeddings is empty, try loading category embeddings
# (fallback so downstream matching still has something to compare against).
if not embeddings:
    logger.info("No ingredient embeddings found, trying category embeddings...")
    embeddings = load_embeddings(CATEGORY_EMBEDDINGS_PATH)

# Sample product names for the example button
# NOTE(review): the list contains duplicate entries — presumably intentional
# sample data for exercising the UI; confirm before de-duplicating.
EXAMPLE_PRODUCTS = """Nature's Promise Spring Water Multipack
Red's Burritos
Nature's Promise Spring Water Multipack
Schweppes Seltzer 12 Pack
Hunt's Pasta Sauce
Buitoni Filled Pasta
Buitoni Filled Pasta
Samuel Adams or Blue Moon 12 Pack
Mrs. T's Pierogies
Buitoni Filled Pasta
Pillsbury Dough
Nature's Promise Organic Celery Hearts
MorningStar Farms Meatless Nuggets, Patties or Crumbles
Nature's Promise Organic Celery Hearts
Boar's Head Mild Provolone Cheese
Athenos Feta Crumbles"""
|
112 |
+
|
113 |
+
def load_examples():
    """Return the canned example product names for the input textbox."""
    return EXAMPLE_PRODUCTS
|
116 |
+
|
117 |
+
from ui_formatters import get_formatted_css, THEME, set_theme
|
118 |
+
|
119 |
+
def get_css():
    """CSS for the Gradio interface, generated from the active theme."""
    css = get_formatted_css()
    return css
|
122 |
+
|
123 |
+
def parse_input(input_text, is_file=False) -> Tuple[List[str], Optional[str]]:
    """Parse user input into a list of product names.

    Args:
        input_text: Raw text (or file contents) with one product name per line.
        is_file: Kept for interface compatibility; file contents are parsed
            exactly like pasted text (one name per line).

    Returns:
        Tuple of (product_names, error_html). ``error_html`` is None on
        success; on failure the list is empty and ``error_html`` carries a
        user-facing message.
    """
    try:
        # Both the file path and the text path receive newline-separated
        # names, so a single parse covers both. (The original duplicated the
        # identical comprehension in each branch of `if is_file`.)
        product_names = [line.strip() for line in input_text.split('\n') if line.strip()]

        if not product_names:
            return [], "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>No valid product names found. Please check your input.</div>"

        return product_names, None

    except Exception as e:
        logger.error(f"Error parsing input: {str(e)}")
        return [], f"<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>Error parsing input: {str(e)}</div>"
|
ui_expanded_matching.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from utils import SafeProgress
|
3 |
+
from embeddings import create_product_embeddings
|
4 |
+
from similarity import compute_similarities
|
5 |
+
from openai_expansion import expand_product_descriptions
|
6 |
+
from ui_core import embeddings, parse_input, CATEGORY_EMBEDDINGS_PATH
|
7 |
+
from ui_formatters import format_expanded_results_html, create_results_container
|
8 |
+
from api_utils import get_openai_client, process_in_parallel, rank_ingredients_openai, rank_categories_openai
|
9 |
+
from category_matching import load_categories, load_category_embeddings
|
10 |
+
import json
|
11 |
+
import os
|
12 |
+
|
13 |
+
|
14 |
+
def categorize_products_with_expansion(product_input, is_file=False, top_n=10, confidence_threshold=0.5, match_type="ingredients", progress=gr.Progress()):
    """
    Categorize products using expanded descriptions from OpenAI

    Pipeline: parse input -> expand each product name into a richer
    description via OpenAI -> embed products -> shortlist candidates by
    embedding similarity -> re-rank the shortlist with an OpenAI model ->
    render HTML result cards.

    NOTE(review): `progress=gr.Progress()` is a mutable default, but this is
    Gradio's documented pattern for progress injection — do not "fix" it.

    Args:
        product_input: Text input with product names
        is_file: Whether the input is a file
        top_n: Number of top results to show
        confidence_threshold: Confidence threshold for matches
        match_type: Either "ingredients" or "categories"
        progress: Progress tracking object

    Returns:
        HTML formatted results
    """
    progress_tracker = SafeProgress(progress)
    progress_tracker(0, desc="Starting...")

    # Parse input
    product_names, error = parse_input(product_input, is_file)
    if error:
        # parse_input already produced user-facing error HTML.
        return error

    # Validate embeddings are loaded if doing ingredient matching
    if match_type == "ingredients" and not embeddings:
        return "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>Error: No ingredient embeddings loaded. Please check that the embeddings file exists and is properly formatted.</div>"

    # Expand product descriptions
    progress_tracker(0.2, desc="Expanding product descriptions...")
    expanded_descriptions = expand_product_descriptions(product_names, progress=progress)

    if not expanded_descriptions:
        return "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>Error: Failed to expand product descriptions. Please try again or check your OpenAI API key.</div>"

    # Get shared OpenAI client
    openai_client = get_openai_client()

    if match_type == "ingredients":
        # Generate product embeddings
        progress_tracker(0.4, desc="Generating product embeddings...")
        product_embeddings = create_product_embeddings(product_names, progress=progress)

        # Compute embedding similarities for ingredients
        progress_tracker(0.6, desc="Computing ingredient similarities...")
        all_similarities = compute_similarities(embeddings, product_embeddings)

        if not all_similarities:
            return "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>Error: No similarities found. Please try different product names.</div>"

        # Setup for OpenAI reranking
        embedding_top_n = 20 # Number of candidates to consider from embeddings

        progress_tracker(0.7, desc="Re-ranking with expanded descriptions...")

        # Function for processing each product
        # (closure over all_similarities / expanded_descriptions / openai_client)
        def process_reranking(product):
            if product not in all_similarities:
                return product, []

            candidates = all_similarities[product][:embedding_top_n]
            if not candidates:
                return product, []

            candidate_ingredients = [c[0] for c in candidates]
            expanded_text = expanded_descriptions.get(product, "")

            try:
                # Use the shared utility function
                reranked_ingredients = rank_ingredients_openai(
                    product=product,
                    candidates=candidate_ingredients,
                    expanded_description=expanded_text,
                    client=openai_client,
                    model="o3-mini",
                    max_results=top_n,
                    confidence_threshold=confidence_threshold,
                    debug=True
                )

                return product, reranked_ingredients

            except Exception as e:
                print(f"Error reranking {product}: {e}")
                # Fall back to top embedding match
                # (candidates is non-empty here — checked above)
                return product, candidates[:1] if candidates[0][1] >= confidence_threshold else []

        # Process all products in parallel
        final_results = process_in_parallel(
            items=product_names,
            processor_func=process_reranking,
            max_workers=min(10, len(product_names)),
            progress_tracker=progress_tracker,
            progress_start=0.7,
            progress_end=0.9,
            progress_desc="Re-ranking"
        )

    else: # categories
        # Load category embeddings instead of JSON categories
        progress_tracker(0.5, desc="Loading category embeddings...")
        category_embeddings = load_category_embeddings()

        if not category_embeddings:
            return "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>Error: No category embeddings found. Please check that the embeddings file exists at data/category_embeddings.pickle.</div>"

        # Generate product embeddings
        progress_tracker(0.6, desc="Generating product embeddings...")
        product_embeddings = create_product_embeddings(product_names, progress=progress)

        # Compute embedding similarities for categories
        progress_tracker(0.7, desc="Computing category similarities...")
        all_similarities = compute_similarities(category_embeddings, product_embeddings)

        if not all_similarities:
            return "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>Error: No category similarities found. Please try different product names.</div>"

        embedding_top_n = min(20, top_n * 2) # Number of candidates to consider from embeddings

        # Collect all needed category IDs first
        # (so the JSON file only has to be scanned once, below)
        needed_category_ids = set()
        for product, similarities in all_similarities.items():
            for category_id, score in similarities[:embedding_top_n]:
                if score >= confidence_threshold:
                    needed_category_ids.add(category_id)

        # Load only the needed categories from JSON
        progress_tracker(0.75, desc="Loading category descriptions...")
        category_descriptions = {}
        if needed_category_ids:
            try:
                with open("categories.json", 'r') as f:
                    categories_list = json.load(f)
                    for item in categories_list:
                        if item["id"] in needed_category_ids:
                            category_descriptions[item["id"]] = item["text"]
            except Exception as e:
                print(f"Error loading category descriptions: {e}")

        # Function to process each product
        def process_category_matching(product):
            if product not in all_similarities:
                return product, []

            candidates = all_similarities[product][:embedding_top_n]
            if not candidates:
                return product, []

            # Get the expanded description
            expanded_text = expanded_descriptions.get(product, "")

            try:
                # Use rank_categories_openai instead of match_products_to_categories_with_description
                category_matches = rank_categories_openai(
                    product=product,
                    categories=category_descriptions,
                    expanded_description=expanded_text,
                    client=openai_client,
                    # model="o3-mini",
                    model="gpt-4o-mini",
                    # model="gpt-4o",
                    max_results=top_n,
                    confidence_threshold=confidence_threshold,
                    debug=True
                )

                # Format results with category descriptions if needed
                # -> produces (id, text, score) triples for the formatter
                formatted_matches = []
                for category_id, score in category_matches:
                    category_text = category_descriptions.get(category_id, "Unknown category")
                    formatted_matches.append((category_id, category_text, score))

                return product, formatted_matches
            except Exception as e:
                print(f"Error matching {product} to categories: {e}")
                return product, []

        # Process all products in parallel
        # NOTE(review): progress_start=0.7 although the tracker was already
        # advanced to 0.75 above — progress may appear to jump backwards.
        final_results = process_in_parallel(
            items=product_names,
            processor_func=process_category_matching,
            max_workers=min(10, len(product_names)),
            progress_tracker=progress_tracker,
            progress_start=0.7,
            progress_end=0.9,
            progress_desc="Category matching"
        )

    # Format results
    progress_tracker(0.9, desc="Formatting results...")

    result_elements = []
    for product, matches in final_results.items():
        result_elements.append(
            format_expanded_results_html(
                product=product,
                results=matches,
                expanded_description=expanded_descriptions.get(product, ""),
                match_type=match_type
            )
        )

    output_html = create_results_container(
        result_elements,
        header_text=f"Matched {len(product_names)} products to {match_type} using expanded descriptions."
    )

    if not final_results:
        output_html = "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>No results found. Please check your input or try different products.</div>"

    progress_tracker(1.0, desc="Done!")
    return output_html
|
ui_formatters.py
ADDED
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Dict, Tuple, Any
|
2 |
+
from utils import get_confidence_color, get_confidence_bg_color
|
3 |
+
|
4 |
+
# Theme configuration (can be easily switched between light/dark)
THEME = "light" # Options: "light", "dark"

# Theme-specific colors
# Each theme maps the same set of role keys to concrete hex colors.
THEMES = {
    "light": {
        "background": "#ffffff",
        "card_bg": "#ffffff",
        "card_border": "#ddd",
        "header_bg": "#2c3e50",
        "header_text": "#ffffff",
        "text_primary": "#333333",
        "text_secondary": "#555555",
        "section_bg": "#f8f9fa",
    },
    "dark": {
        "background": "#121212",
        "card_bg": "#1e1e1e",
        "card_border": "#333",
        "header_bg": "#37474f",
        "header_text": "#ffffff",
        "text_primary": "#e0e0e0",
        "text_secondary": "#b0bec5",
        "section_bg": "#263238",
    }
}

# Get current theme colors
COLORS = THEMES[THEME]

# Base styling constants (adjusted based on theme)
# NOTE: evaluated once at import time; set_theme() rebinds these globals.
STYLES = {
    "card": f"margin-bottom: 20px; border: 1px solid {COLORS['card_border']}; border-radius: 8px; overflow: hidden; background-color: {COLORS['card_bg']};",
    "header": f"background-color: {COLORS['header_bg']}; padding: 12px 15px; border-bottom: 1px solid {COLORS['card_border']};",
    "header_text": f"margin: 0; font-size: 18px; color: {COLORS['header_text']};",
    "flex_container": "display: flex; flex-wrap: wrap;",
    "method_container": f"flex: 1; min-width: 200px; padding: 15px; border-right: 1px solid {COLORS['card_border']};",
    "method_title": f"margin-top: 0; color: {COLORS['text_primary']}; padding-bottom: 8px;",
    "item_list": "list-style-type: none; padding-left: 0;",
    "item": "margin-bottom: 8px; padding: 8px; border-radius: 4px;",
    "empty_message": "color: #7f8c8d; font-style: italic;",
    "info_panel": f"padding: 10px; background-color: {COLORS['section_bg']}; margin-bottom: 10px; border-radius: 4px;"
}

# Method colors (consistent across themes)
METHOD_COLORS = {
    "base": "#f39c12", # Orange
    "voyage": "#3498db", # Blue
    "chicory": "#9b59b6", # Purple
    "openai": "#2ecc71", # Green
    "expanded": "#e74c3c", # Red
    "hybrid": "#1abc9c", # Turquoise
    "categories": "#1abc9c" # Same as hybrid
}

# Method display names
METHOD_NAMES = {
    "base": "Base Embeddings",
    "voyage": "Voyage AI Reranker",
    "chicory": "Chicory Parser",
    "openai": "OpenAI o3-mini",
    "expanded": "Expanded Description",
    "hybrid": "Hybrid Matching",
    "categories": "Category Matches"
}
|
69 |
+
|
70 |
+
def format_method_results(method_key, results, color_hex=None):
    """
    Format results for a single method section

    Accepts heterogeneous result items — (name, score) pairs,
    (id, text, score) triples, or several dict shapes — and renders each as
    a list row with a confidence badge.

    NOTE(review): `name` is interpolated into HTML without escaping — safe
    only for trusted/internal strings.

    Args:
        method_key: Key identifying the method (base, voyage, etc.)
        results: List of (name, score) tuples or format-specific data structure
        color_hex: Optional color override (otherwise uses METHOD_COLORS)

    Returns:
        HTML string for the method section
    """
    # Get color from METHOD_COLORS if not provided
    if color_hex is None:
        color_hex = METHOD_COLORS.get(method_key, "#777777")

    # Get method name from METHOD_NAMES or use the key with capitalization
    method_name = METHOD_NAMES.get(method_key, method_key.replace('_', ' ').title())

    html = f"<div class='method-results' style='{STYLES['method_container']}'>"
    html += f"<h4 style='{STYLES['method_title']}; border-bottom: 2px solid {color_hex};'>{method_name}</h4>"

    if results:
        html += f"<ul style='{STYLES['item_list']}'>"

        # Handle different result formats
        for item in results:
            # Handle tuple with 2 elements (name, score)
            if isinstance(item, tuple) and len(item) == 2:
                name, score = item
            # Handle tuple with 3 elements (common in category results)
            elif isinstance(item, tuple) and len(item) == 3:
                id_val, text, score = item
                name = f"<strong>{id_val}</strong>: {text}" if text else id_val
            # Handle dictionary format
            elif isinstance(item, dict) and "name" in item and "score" in item:
                name = item["name"]
                score = item["score"]
            # Handle dictionary format with different keys
            elif isinstance(item, dict) and "category" in item and "confidence" in item:
                name = item["category"]
                score = item["confidence"]
            # Handle dictionary format for ingredients
            elif isinstance(item, dict) and "ingredient" in item and "relevance_score" in item:
                name = item["ingredient"]
                score = item["relevance_score"]
            # Default case - just convert to string
            else:
                name = str(item)
                score = 0.0

            # Ensure score is a float (non-numeric scores render as 0%)
            try:
                score = float(score)
            except (ValueError, TypeError):
                score = 0.0

            # int() truncates toward zero, so 0.999 renders as 99%
            confidence_percent = int(score * 100)
            confidence_color = get_confidence_color(score)
            bg_color = get_confidence_bg_color(score)

            # Improved layout with better contrast and labeled confidence
            html += f"<li style='display: flex; justify-content: space-between; align-items: center; margin-bottom: 6px; padding: 6px; border-radius: 4px; background-color: rgba(240, 240, 240, 0.4);'>"
            html += f"<span style='font-weight: 500; flex: 1;'>{name}</span>"
            html += f"<span style='background-color: {bg_color}; border: 1px solid {confidence_color}; color: #000; font-weight: 600; padding: 2px 6px; border-radius: 4px; min-width: 70px; text-align: center; margin-left: 8px;'>Confidence: {confidence_percent}%</span>"
            html += "</li>"

        html += "</ul>"
    else:
        html += f"<p style='{STYLES['empty_message']}'>No results found</p>"

    html += "</div>"
    return html
|
143 |
+
|
144 |
+
def format_result_card(title, content, header_bg_color=None):
    """Wrap *content* in a styled card with a colored header bar.

    Args:
        title: Text shown in the card header.
        content: HTML body of the card.
        header_bg_color: Header background; defaults to the theme header color.

    Returns:
        HTML string for the complete card.
    """
    bg = COLORS['header_bg'] if header_bg_color is None else header_bg_color

    pieces = [
        f"<div class='result-card' style='{STYLES['card']}'>",
        f"<div class='card-header' style='{STYLES['header']}; background-color: {bg};'>",
        f"<h3 style='{STYLES['header_text']}'>{title}</h3>",
        "</div>",
        f"<div class='card-content'>{content}</div>",
        "</div>",
    ]
    return ''.join(pieces)
|
166 |
+
|
167 |
+
def format_comparison_html(product, method_results):
    """Render a side-by-side comparison card for one product.

    Args:
        product: Product name (card title).
        method_results: Mapping of method key -> result list.

    Returns:
        HTML string for the comparison card.
    """
    # One column per matching method, in fixed display order.
    columns = [
        format_method_results(method_key=key, results=method_results.get(key, []))
        for key in ("base", "voyage", "chicory", "openai")
    ]

    body = (
        f"<div class='methods-comparison' style='{STYLES['flex_container']}'>"
        + ''.join(columns)
        + "</div>"
    )
    return format_result_card(title=product, content=body)
|
192 |
+
|
193 |
+
def format_expanded_results_html(product, results, expanded_description, match_type="ingredients"):
    """
    Format results using expanded descriptions

    Renders the expanded description panel followed by the ranked matches,
    each with a percentage confidence badge.

    Args:
        product: Product name
        results: List of tuples - either (match, score) for ingredients or (id, text, score) for categories
        expanded_description: Expanded product description
        match_type: Either "ingredients" or "categories"

    Returns:
        HTML for the result card
    """
    content = ""

    # Add expanded description section
    content += f"<div style='{STYLES['info_panel']}'>"
    content += "<h4 style='margin-top: 0; border-bottom: 1px solid rgba(0,0,0,0.1); padding-bottom: 8px;'>Expanded Description</h4>"
    content += f"<p style='margin-bottom: 8px;'>{expanded_description}</p>"
    content += "</div>"

    # Format the results section - create custom section
    color_hex = METHOD_COLORS.get(match_type, "#1abc9c")

    # Add results section with custom title
    content += f"<div class='method-results' style='margin-top: 15px; border-left: 3px solid {color_hex}; padding-left: 15px;'>"

    title_text = "Ingredients" if match_type == "ingredients" else "Categories"
    content += f"<h4 style='margin-top: 0; color: {color_hex};'>{title_text}</h4>"

    if results:
        content += "<ul style='margin-top: 5px; padding-left: 20px;'>"
        for item in results:
            # Handle both 2-value (match, score) and 3-value (id, text, score) tuples
            if len(item) == 2:
                match, score = item
                display_text = match
            elif len(item) == 3:
                category_id, category_text, score = item # For categories, use both id and text
                display_text = f"<strong>{category_id}</strong>: {category_text}"
            else:
                continue # Skip any invalid formats

            # int() truncates toward zero (0.999 -> 99%)
            confidence_percent = int(score * 100)
            # Improved styling for confidence percentage - using black text for better contrast
            confidence_color = get_confidence_color(score)
            bg_color = get_confidence_bg_color(score)
            content += f"<li style='display: flex; justify-content: space-between; align-items: center; margin-bottom: 4px;'>"
            content += f"<span style='font-weight: 500; flex: 1;'>{display_text}</span>"
            content += f"<span style='background-color: {bg_color}; border: 1px solid {confidence_color}; color: #000; font-weight: 600; padding: 2px 6px; border-radius: 4px; min-width: 70px; text-align: center; margin-left: 8px;'>Confidence: {confidence_percent}%</span>"
            content += "</li>"
        content += "</ul>"
    else:
        content += "<p style='color: #777; font-style: italic; margin: 5px 0;'>No matches found above confidence threshold.</p>"

    content += "</div>"

    return format_result_card(title=product, content=content)
|
251 |
+
|
252 |
+
def format_hybrid_results_html(product, results, summary=None):
    """Render hybrid-matching results for one product as a card.

    Args:
        product: Product name (card title).
        results: List of (ingredient, score) tuples.
        summary: Optional explanatory text shown above the results.

    Returns:
        HTML string for the card.
    """
    parts = []

    # Optional summary banner above the ingredient list.
    if summary:
        parts.append(
            f"<div class='matching-summary' style='{STYLES['info_panel']}'>"
            f"<p style='margin: 0; font-style: italic; color: {COLORS['text_secondary']};'>{summary}</p>"
            "</div>"
        )

    parts.append(format_method_results(method_key="hybrid", results=results))

    return format_result_card(title=product, content=''.join(parts))
|
279 |
+
|
280 |
+
def create_results_container(html_elements, header_text=None):
    """Concatenate result cards into a single results container.

    Args:
        html_elements: Iterable of HTML fragments (one per product card).
        header_text: Optional summary line rendered above the results.

    Returns:
        HTML string for the container.
    """
    header = (
        f"<p style='color: {COLORS['text_secondary']};'>{header_text}</p>"
        if header_text
        else ""
    )
    return (
        "<div class='results-container' style='font-family: Arial, sans-serif;'>"
        + header
        + ''.join(html_elements)
        + "</div>"
    )
|
300 |
+
|
301 |
+
def format_categories_html(product, categories, chicory_result=None, header_color=None):
    """
    Format category matching results as HTML

    Optionally prefixes the category list with a Chicory-parser panel when
    a result dict is supplied.

    Args:
        product: Product name
        categories: List of (category, score) tuples
        chicory_result: Optional chicory parser result for the product
        header_color: Optional header background color

    Returns:
        HTML string
    """
    content = ""

    # Add Chicory results if available
    if chicory_result:
        content += f"<div style='{STYLES['info_panel']}'>"
        content += "<h4 style='margin-top: 0; border-bottom: 1px solid rgba(0,0,0,0.1); padding-bottom: 8px;'>Chicory Parser Results</h4>"

        if isinstance(chicory_result, dict):
            # Missing keys fall back to a placeholder name / zero confidence.
            ingredient = chicory_result.get("ingredient", "Not found")
            confidence = chicory_result.get("confidence", 0)
            confidence_percent = int(confidence * 100)

            content += f"<div style='display: flex; justify-content: space-between; align-items: center; padding: 8px; border-radius: 4px;'>"
            content += f"<span style='font-weight: bold;'>{ingredient}</span>"
            content += f"<span style='background-color: {get_confidence_bg_color(confidence)}; border: 1px solid {get_confidence_color(confidence)}; color: #000; font-weight: 600; padding: 2px 6px; border-radius: 4px; min-width: 70px; text-align: center;'>Confidence: {confidence_percent}%</span>"
            content += "</div>"
        else:
            # Truthy but non-dict chicory_result — render a placeholder.
            content += f"<p style='{STYLES['empty_message']}'>No Chicory results available</p>"

        content += "</div>"

    # Add the category results
    content += format_method_results(
        method_key="categories",
        results=categories,
        color_hex=header_color or METHOD_COLORS.get("categories", "#1abc9c")
    )

    return format_result_card(title=product, content=content)
|
343 |
+
|
344 |
+
def get_formatted_css():
    """
    Generate CSS for the UI based on the currently active theme.

    Returns:
        CSS string ready to use in Gradio
    """
    # Built at call time so a preceding set_theme() call is reflected in COLORS
    css = f"""
    .gradio-container .prose {{
        max-width: 100%;
    }}
    #results-container {{
        height: 600px !important;
        overflow-y: auto !important;
        overflow-x: hidden !important;
        padding: 15px !important;
        border: 1px solid {COLORS['card_border']} !important;
        background-color: {COLORS['background']} !important;
        color: {COLORS['text_primary']} !important;
    }}
    /* Style for method columns */
    .methods-comparison {{
        display: flex;
        flex-wrap: wrap;
    }}
    .method-results {{
        flex: 1;
        min-width: 200px;
        padding: 15px;
        border-right: 1px solid {COLORS['card_border']};
    }}
    /* Make the product header more visible */
    .product-header {{
        background-color: {COLORS['header_bg']} !important;
        padding: 12px 15px !important;
        border-bottom: 1px solid {COLORS['card_border']} !important;
    }}
    .product-header h3 {{
        margin: 0 !important;
        font-size: 18px !important;
        color: {COLORS['header_text']} !important;
        background-color: transparent !important;
    }}
    /* Remove all nested scrollbars */
    #results-container * {{
        overflow: visible !important;
        height: auto !important;
        max-height: none !important;
    }}
    """
    return css
|
394 |
+
|
395 |
+
def set_theme(theme_name):
    """
    Switch the UI theme (light or dark).

    Args:
        theme_name: 'light' or 'dark'

    Returns:
        None - updates the module-level THEME/COLORS/STYLES globals
    """
    global THEME, COLORS, STYLES

    # Unknown theme names are silently ignored, keeping the current theme
    if theme_name not in THEMES:
        return

    THEME = theme_name
    COLORS = THEMES[THEME]

    # Rebuild the inline-style fragments that embed theme colors
    STYLES.update({
        "card": f"margin-bottom: 20px; border: 1px solid {COLORS['card_border']}; border-radius: 8px; overflow: hidden; background-color: {COLORS['card_bg']};",
        "header": f"background-color: {COLORS['header_bg']}; padding: 12px 15px; border-bottom: 1px solid {COLORS['card_border']};",
        "header_text": f"margin: 0; font-size: 18px; color: {COLORS['header_text']};",
        "method_container": f"flex: 1; min-width: 200px; padding: 15px; border-right: 1px solid {COLORS['card_border']};",
        "method_title": f"margin-top: 0; color: {COLORS['text_primary']}; padding-bottom: 8px;",
        "info_panel": f"padding: 10px; background-color: {COLORS['section_bg']}; margin-bottom: 10px; border-radius: 4px;",
    })
ui_hybrid_matching.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from utils import SafeProgress
|
3 |
+
from category_matching import load_categories, hybrid_category_matching
|
4 |
+
from similarity import hybrid_ingredient_matching
|
5 |
+
from ui_core import embeddings, parse_input
|
6 |
+
from ui_formatters import format_hybrid_results_html, create_results_container
|
7 |
+
|
8 |
+
def categorize_products_hybrid_generic(product_input, is_file=False, embedding_top_n=20,
                                       final_top_n=5, confidence_threshold=0.5,
                                       match_type="categories",
                                       progress=gr.Progress()):
    """Generic hybrid matching for either categories or ingredients"""
    tracker = SafeProgress(progress)
    tracker(0, desc=f"Starting hybrid {match_type} matching...")

    # Parse raw text / uploaded file into a list of product names
    product_names, error = parse_input(product_input, is_file)
    if error:
        return error

    # Keyword arguments shared by both hybrid matchers
    matcher_kwargs = dict(
        embedding_top_n=int(embedding_top_n),
        final_top_n=int(final_top_n),
        confidence_threshold=confidence_threshold,
        progress=progress,
    )

    if match_type == "categories":
        tracker(0.2, desc="Loading categories...")
        categories = load_categories()

        tracker(0.3, desc="Finding and re-ranking categories...")
        match_results = hybrid_category_matching(product_names, categories, **matcher_kwargs)
    else:
        # Ingredient matching runs against the preloaded ingredient embeddings
        tracker(0.3, desc="Finding and re-ranking ingredients...")
        match_results = hybrid_ingredient_matching(product_names, embeddings, **matcher_kwargs)

    tracker(0.9, desc="Formatting results...")

    # One formatted card per product, assembled through the shared formatter
    cards = [
        format_hybrid_results_html(
            product=product,
            results=matches,
            summary=f"{match_type.capitalize()} matches using hybrid approach.",
        )
        for product, matches in match_results.items()
    ]

    output_html = create_results_container(
        cards,
        header_text=f"Matched {len(product_names)} products to {match_type} using hybrid approach.",
    )

    # With no matches at all, replace the container with an error banner
    if not match_results:
        output_html = "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>No results found. Please check your input or try different products.</div>"

    tracker(1.0, desc="Done!")
    return output_html
71 |
+
# Then use it like this:
|
72 |
+
def categorize_products_hybrid(product_input, is_file=False, embedding_top_n=20,
                               final_top_n=5, confidence_threshold=0.5,
                               progress=gr.Progress()):
    """Hybrid category matching: thin wrapper over the generic hybrid path."""
    return categorize_products_hybrid_generic(
        product_input,
        is_file=is_file,
        embedding_top_n=embedding_top_n,
        final_top_n=final_top_n,
        confidence_threshold=confidence_threshold,
        match_type="categories",
        progress=progress,
    )
79 |
+
|
80 |
+
def categorize_products_hybrid_ingredients(product_input, is_file=False, embedding_top_n=20,
                                           final_top_n=5, confidence_threshold=0.5,
                                           progress=gr.Progress()):
    """Hybrid ingredient matching: thin wrapper over the generic hybrid path."""
    return categorize_products_hybrid_generic(
        product_input,
        is_file=is_file,
        embedding_top_n=embedding_top_n,
        final_top_n=final_top_n,
        confidence_threshold=confidence_threshold,
        match_type="ingredients",
        progress=progress,
    )
ui_ingredient_matching.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from utils import SafeProgress
|
3 |
+
from embeddings import create_product_embeddings
|
4 |
+
from similarity import compute_similarities
|
5 |
+
from chicory_api import call_chicory_parser
|
6 |
+
from ui_core import embeddings, parse_input
|
7 |
+
from ui_formatters import format_categories_html, create_results_container
|
8 |
+
|
9 |
+
def categorize_products(product_input, is_file=False, top_n=10, confidence_threshold=0.5, progress=gr.Progress()):
    """Categorize products from text input or file"""
    tracker = SafeProgress(progress)
    tracker(0, desc="Starting...")

    # Parse raw text / uploaded file into a list of product names
    product_names, error = parse_input(product_input, is_file)
    if error:
        return error

    # Guard: the module-level ingredient embeddings must have loaded
    if not embeddings:
        return "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>Error: No ingredient embeddings loaded. Please check that the embeddings file exists and is properly formatted.</div>"

    tracker(0.2, desc="Generating product embeddings...")
    products_embeddings = create_product_embeddings(product_names, progress=progress)

    if not products_embeddings:
        return "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>Error: Failed to generate product embeddings. Please try again with different product names.</div>"

    tracker(0.5, desc="Calling Chicory Parser API...")
    chicory_results = call_chicory_parser(product_names, progress=progress)

    tracker(0.7, desc="Computing similarities...")
    all_similarities = compute_similarities(embeddings, products_embeddings)

    tracker(0.9, desc="Formatting results...")
    pieces = [
        "<div style='font-family: Arial, sans-serif; max-width: 100%; overflow-x: auto;'>",
        f"<p style='color: #555;'>Processing {len(product_names)} products.</p>",
    ]

    for product, similarities in all_similarities.items():
        # Keep only matches above the confidence floor, then cap at top_n
        top_matches = [
            (ingredient, score) for ingredient, score in similarities
            if score >= confidence_threshold
        ][:int(top_n)]

        # Chicory payload for this product (falsy list when absent)
        chicory_data = chicory_results.get(product, [])

        pieces.append(format_categories_html(product, top_matches, chicory_result=chicory_data))
        pieces.append("<hr style='margin: 15px 0; border: 0; border-top: 1px solid #eee;'>")

    pieces.append("</div>")
    output_html = "".join(pieces)

    # With no similarities at all, replace everything with an error banner
    if not all_similarities:
        output_html = "<div style='color: #d32f2f; font-weight: bold; padding: 20px;'>No results found. Please check your input or try different products.</div>"

    tracker(1.0, desc="Done!")
    return output_html
utils.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Iterator, Any
|
2 |
+
from tqdm import tqdm as tqdm_original
|
3 |
+
import sys
|
4 |
+
import pickle
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
|
8 |
+
class SafeProgress:
    """Wrapper for progress tracking that handles both tqdm (console) and Gradio progress"""

    def __init__(self, progress_obj=None, desc="Processing", track_tqdm=True):
        # progress_obj: optional Gradio progress callable; None means console-only
        self.progress = progress_obj
        self.desc = desc
        self.track_tqdm = track_tqdm
        # Lazily created console tqdm bar; None while no bar is active
        self.console_progress = None

    def __call__(self, value, desc=None):
        """Update the progress indicators directly with a fraction in [0, 1]."""
        desc = self.desc if desc is None else desc

        # Forward to Gradio first; its failures must never break console output
        if self.progress is not None:
            try:
                self.progress(value, desc=desc)
            except Exception as e:
                print(f"Progress update error: {e}")

        if value < 1.0:
            if self.console_progress is None:
                # First partial update: create a 0-100 console bar
                self.console_progress = tqdm_original(total=100, desc=desc, file=sys.stdout)
                self.console_progress.update(int(value * 100))
            else:
                # Advance only forward; tqdm's .n tracks the current position
                delta = int(value * 100) - self.console_progress.n
                if delta > 0:
                    self.console_progress.update(delta)
                self.console_progress.set_description(desc)
        elif self.console_progress is not None:
            # value >= 1.0: top the bar off, close it, and reset for reuse
            self.console_progress.update(100 - self.console_progress.n)
            self.console_progress.close()
            self.console_progress = None

    def tqdm(self, iterable, desc=None, total=None):
        """Wrap an iterable with a progress bar for iteration (generator)."""
        desc = self.desc if desc is None else desc

        # Prefer Gradio's own tqdm integration when the progress object has one
        if self.progress is not None and hasattr(self.progress, 'tqdm'):
            yield from self.progress.tqdm(iterable, desc=desc, total=total)
            return

        # Console path; mirror progress into Gradio manually when length is known
        length = total if total is not None else len(iterable) if hasattr(iterable, "__len__") else None

        with tqdm_original(iterable, desc=desc, total=length, file=sys.stdout) as bar:
            for index, item in enumerate(bar):
                if self.progress is not None and length:
                    self.progress((index + 1) / length, desc=desc)
                yield item
+
|
72 |
+
def load_embeddings(embeddings_path):
    """
    Load ingredient embeddings from a pickle file.

    Args:
        embeddings_path: Path to the pickled ingredient-embeddings mapping

    Returns:
        The unpickled embeddings object (mapping of ingredient -> vector)
    """
    print(f"Loading ingredient embeddings from {embeddings_path}")
    with open(embeddings_path, "rb") as handle:
        ingredients_embeddings = pickle.load(handle)
    # Report the count so startup logs show whether the file loaded fully
    print(f"Loaded {len(ingredients_embeddings)} ingredient embeddings")
    return ingredients_embeddings
|
80 |
+
|
81 |
+
def preprocess_product_for_matching(product, progress=None, description=None):
    """
    Normalize a raw product dict into the shape used for ingredient matching.

    Args:
        product (dict): Product dictionary containing at minimum 'name' and 'ingredients'
        progress (SafeProgress, optional): Progress bar to update
        description (str, optional): Description for progress update

    Returns:
        dict: Processed product with normalized fields, or None when the
        product has no ingredients or raises during processing.
    """
    try:
        # Pull out just the fields matching needs, with safe defaults
        cleaned = {
            'id': product.get('id', ''),
            'name': product.get('name', '').strip(),
            'ingredients': product.get('ingredients', '').strip(),
            'image_url': product.get('image_url', ''),
            'url': product.get('url', ''),
        }

        # Products without ingredient text cannot be matched at all
        if not cleaned['ingredients']:
            if progress:
                # NOTE(review): SafeProgress exposes __call__, not update() — confirm
                # which progress object callers actually pass here
                progress.update(1, description=f"{description}: Skipping product without ingredients")
            return None

        # Collapse newlines so the ingredient text is a single line
        cleaned['ingredients'] = cleaned['ingredients'].replace('\n', ' ').strip()

        if progress:
            progress.update(1, description=f"{description}: Processed {cleaned['name']}")

        return cleaned
    except Exception as e:
        # Best-effort pipeline: report the failure via progress and drop the product
        if progress:
            progress.update(1, description=f"{description}: Error processing product: {str(e)}")
        return None
123 |
+
# Keep these color utility functions in utils.py as they're generic helpers:
|
124 |
+
def get_confidence_color(score):
    """Return the hex border/accent color for a confidence score."""
    # Thresholds checked highest-first; first match wins
    for threshold, color in ((0.8, "#1a8a38"),   # strong green
                             (0.65, "#4caf50"),  # medium green
                             (0.5, "#8bc34a")):  # light green
        if score >= threshold:
            return color
    return "#9e9e9e"  # gray for low confidence
|
135 |
+
def get_confidence_bg_color(score):
    """Return the hex background color for a confidence badge."""
    # Thresholds checked highest-first; first match wins
    for threshold, color in ((0.8, "#2e7d32"),   # dark green
                             (0.65, "#558b2f"),  # medium green
                             (0.5, "#9e9d24")):  # light green/yellow
        if score >= threshold:
            return color
    return "#757575"  # gray for low confidence
+
|
146 |
+
def get_confidence_text_color(score):
    """Return a text color readable on the confidence badge background."""
    # White on the green badges (>= 0.5), light gray on the gray badge
    return "#ffffff" if score >= 0.5 else "#f5f5f5"
+
|
153 |
+
# Remove any UI formatting-specific functions that now exist in ui_formatters.py:
|
154 |
+
# - format_categories_html
|
155 |
+
# - create_results_container
|
156 |
+
# - Any other UI formatting functions
|