Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3 | |
""" | |
Visual Search System - Fixed Threading & Session State Issues | |
----------------------------------------------------------- | |
- Fixed: No SessionInfo/ScriptRunContext errors | |
- Fixed: Embedding generation with proper error handling | |
- Thread-safe progress tracking without Streamlit APIs | |
- All session state access moved to main thread only | |
""" | |
import os | |
import json | |
import time | |
import threading | |
from pathlib import Path | |
from concurrent.futures import ThreadPoolExecutor, as_completed | |
import requests | |
import pandas as pd | |
import numpy as np | |
import streamlit as st | |
from PIL import Image | |
from PIL import ImageFile | |
ImageFile.LOAD_TRUNCATED_IMAGES = True | |
import datetime | |
from typing import Optional, Tuple, List | |
# ----------------------- | |
# Configuration | |
# ----------------------- | |
# Filesystem layout (all paths are relative to the working directory).
IMAGES_DIR = Path("images")
EMBED_DIR = Path("embeddings")
CSV_FILE = Path("photos_url.csv")
PROGRESS_FILE = Path("progress.json")
SETUP_COMPLETE_FILE = Path("setup_complete.flag")

# Download and image-normalisation settings.
MAX_IMAGES = 250  # upper bound on the number of URLs taken from the CSV
JPEG_QUALITY = 85
TARGET_MAX_SIZE = (800, 800)
MAX_WORKERS = 6  # size of the download thread pool
RETRY_COUNT = 3
BATCH_SIZE = 20  # embeddings are flushed to disk every BATCH_SIZE images

# Embedding artifacts.
EMB_NPY = EMBED_DIR.joinpath("image_embeddings.npy")
EMB_INDEX_JSON = EMBED_DIR.joinpath("index.json")

# Colour-histogram parameters (bins per RGB channel, value range).
HIST_BINS_PER_CHANNEL = 32
HIST_RANGE = (0, 256)

# Hugging Face Inference API settings used for text queries.
HF_TOKEN = os.environ.get("HF_TOKEN")
CLIP_MODEL = "sentence-transformers/clip-ViT-B-32"
API_URL = "https://api-inference.huggingface.co/models/" + CLIP_MODEL
HEADERS = {}
if HF_TOKEN:
    # Only send an Authorization header when a token is configured.
    HEADERS = {"Authorization": "Bearer " + HF_TOKEN}

# Phase identifiers stored in the progress file.
PHASE_IDLE = "idle"
PHASE_1_DOWNLOAD = "download"
PHASE_2_EMBEDDING = "embedding"
PHASE_3_COMPLETE = "complete"
PHASE_ERROR = "error"
# ----------------------- | |
# Thread-Safe Progress Tracker (NO Streamlit APIs) | |
# ----------------------- | |
class SafeProgressTracker:
    """Record setup progress in memory and mirror it to ``PROGRESS_FILE``.

    No Streamlit API is used, so ``update`` can be called from worker
    threads while the UI thread polls ``read``. One lock serialises both.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._data = {}

    def update(self, phase: str, current: int, total: int, fails: int = 0,
               message: str = "", details: str = ""):
        """Store a new progress snapshot and write it to ``PROGRESS_FILE``.

        Args:
            phase: One of the ``PHASE_*`` constants.
            current: Units of work finished so far.
            total: Total units of work; ``0`` yields a percentage of 0.
            fails: Number of failed units.
            message: Headline text for the UI.
            details: Secondary text for the UI.

        A failed file write is reported on stdout and otherwise ignored; the
        in-memory snapshot is still updated.
        """
        with self._lock:
            percentage = (current / total * 100) if total > 0 else 0
            phase_names = {
                PHASE_IDLE: "π Initializing",
                PHASE_1_DOWNLOAD: "π₯ Downloading Images",
                PHASE_2_EMBEDDING: "π§ Creating Embeddings",
                PHASE_3_COMPLETE: "β System Ready",
                PHASE_ERROR: "β Error Occurred",
            }
            self._data = {
                "phase": phase,
                "phase_name": phase_names.get(phase, f"Phase: {phase}"),
                "current": current,
                "total": total,
                "fails": fails,
                "percentage": percentage,
                "message": message,
                "details": details,
                "timestamp": time.time(),
                "formatted_time": datetime.datetime.now().strftime("%H:%M:%S"),
            }
            # Persist to disk so the UI can pick the state up (no Streamlit APIs).
            try:
                with open(PROGRESS_FILE, 'w') as f:
                    json.dump(self._data, f, indent=2)
            except (OSError, TypeError, ValueError) as e:
                # Best effort: progress reporting must never abort the worker.
                print(f"Progress save error: {e}")

    def read(self) -> Optional[dict]:
        """Return the latest progress snapshot, or ``None`` if there is none.

        The file is preferred when it exists; if it is missing, unreadable or
        not valid JSON, a copy of the in-memory snapshot is returned instead.
        """
        with self._lock:
            try:
                if PROGRESS_FILE.exists():
                    with open(PROGRESS_FILE, 'r') as f:
                        return json.load(f)
            except (OSError, ValueError) as e:
                # json.JSONDecodeError is a ValueError; fall back to memory.
                print(f"Progress read error: {e}")
            return self._data.copy() if self._data else None


# Global progress tracker shared by the background thread and the UI.
progress_tracker = SafeProgressTracker()
# ----------------------- | |
# Utility Functions | |
# ----------------------- | |
def ensure_dirs():
    """Make sure the image and embedding directories exist.

    Errors are printed rather than raised; the first failure stops the loop.
    """
    try:
        for directory in (IMAGES_DIR, EMBED_DIR):
            directory.mkdir(parents=True, exist_ok=True)
    except Exception as e:
        print(f"Directory creation error: {e}")
def seq_filename(i: int) -> str:
    """Return the JPEG file name for sequence number ``i``, zero-padded to 4 digits."""
    return "{:04d}.jpg".format(i)
def load_csv_urls() -> List[str]:
    """Return up to ``MAX_IMAGES`` http(s) URLs read from ``CSV_FILE``.

    The first column whose name contains "url" (case-insensitive) is used,
    falling back to the first column. Values are stripped, and only those
    starting with ``http://`` or ``https://`` are kept. The limit is applied
    after filtering, so invalid rows do not use up the quota.

    Returns:
        The valid URLs in file order; an empty list if the file is missing,
        cannot be decoded, or any other error occurs (the error is printed).
    """
    try:
        if not CSV_FILE.exists():
            return []
        # Try the likely encodings in turn.
        df = None
        for encoding in ("utf-8", "utf-8-sig", "latin1"):
            try:
                df = pd.read_csv(CSV_FILE, encoding=encoding)
                break
            except UnicodeDecodeError:
                continue
        if df is None:
            print("Failed to read CSV with any encoding")
            return []
        # Column labels are not guaranteed to be strings, hence str(c).
        url_cols = [c for c in df.columns if "url" in str(c).lower()]
        column = url_cols[0] if url_cols else df.columns[0]
        valid_urls = []
        for raw in df[column].astype(str):
            url = raw.strip()
            # The prefix test also rejects empty cells and the "nan" that
            # astype(str) produces for missing values.
            if url.startswith(("http://", "https://")):
                valid_urls.append(url)
                if len(valid_urls) >= MAX_IMAGES:
                    break
        return valid_urls
    except Exception as e:
        print(f"CSV loading error: {e}")
        return []
def is_setup_complete() -> bool:
    """Return True when the completion flag and non-empty embedding files exist.

    Any filesystem error while checking (for example a file removed between
    ``exists()`` and ``stat()``) is treated as "not complete".
    """
    try:
        if not SETUP_COMPLETE_FILE.exists():
            return False
        return all(
            artifact.exists() and artifact.stat().st_size > 0
            for artifact in (EMB_NPY, EMB_INDEX_JSON)
        )
    except OSError:
        return False
def mark_setup_complete():
    """Write the completion flag file and publish the final progress state.

    The flag file holds a small JSON summary (timestamps, number of ``*.jpg``
    files in ``IMAGES_DIR``). Failures are printed, not raised.
    """
    try:
        summary = {
            "completed_at": time.time(),
            "formatted_time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "images_processed": len(list(IMAGES_DIR.glob("*.jpg"))),
            "embeddings_created": True,
        }
        with open(SETUP_COMPLETE_FILE, 'w') as flag_file:
            flag_file.write(json.dumps(summary, indent=2))
        progress_tracker.update(
            PHASE_3_COMPLETE, 1, 1, 0,
            "β Setup Complete!", "Ready for search",
        )
    except Exception as e:
        print(f"Setup completion error: {e}")
# ----------------------- | |
# Image Download Functions (Thread-Safe) | |
# ----------------------- | |
def download_single_image(i: int, url: str) -> bool:
    """Download ``url`` and store it as a resized JPEG (no Streamlit APIs).

    The image is converted to RGB, shrunk to fit ``TARGET_MAX_SIZE`` and
    written to ``IMAGES_DIR / seq_filename(i)`` through a temporary file, so
    a partial file never appears under the final name.

    Args:
        i: Sequence number that determines the target file name.
        url: Image URL to fetch.

    Returns:
        True if the file already existed (non-empty) or was saved; False
        after ``RETRY_COUNT`` failed attempts.
    """
    fname = IMAGES_DIR / seq_filename(i)
    if fname.exists() and fname.stat().st_size > 0:
        return True
    temp_fname = fname.with_suffix('.tmp')
    for attempt in range(RETRY_COUNT):
        last_attempt = attempt == RETRY_COUNT - 1
        try:
            img = None
            # The context manager returns the streamed connection to the pool.
            with requests.get(url, stream=True, timeout=(30, 90)) as response:
                if response.status_code == 200:
                    # Without this, gzip/deflate bodies reach PIL still compressed.
                    response.raw.decode_content = True
                    with Image.open(response.raw) as source:
                        img = source.convert("RGB")  # forces a full decode
            if img is not None:
                img.thumbnail(TARGET_MAX_SIZE, Image.Resampling.LANCZOS)
                # Atomic save: write the temp file, then rename into place.
                img.save(temp_fname, "JPEG", quality=JPEG_QUALITY, optimize=True)
                temp_fname.replace(fname)
                return True
        except Exception as e:
            # Do not leave a half-written temp file behind.
            try:
                if temp_fname.exists():
                    temp_fname.unlink()
            except OSError:
                pass
            if last_attempt:
                print(f"Download failed {url}: {e}")
                return False
        if not last_attempt:
            time.sleep(2 ** attempt)  # exponential backoff
    return False
def process_downloads_thread_safe(urls: List[str]) -> bool:
    """Download all missing images with a thread pool (no Streamlit APIs).

    URL number ``i`` (1-based) maps to ``IMAGES_DIR / seq_filename(i)``; files
    that already exist and are non-empty are skipped. Progress is reported
    through ``progress_tracker`` every five completions and at the end.

    Args:
        urls: Image URLs in their final numbering order.

    Returns:
        Always True: download failures are counted and reported, but the
        caller proceeds to the embedding phase with whatever was fetched.
    """
    if not urls:
        progress_tracker.update(PHASE_1_DOWNLOAD, 1, 1, 0,
                                "β No URLs provided", "Skipping downloads")
        return True

    # Work out which images still need to be fetched.
    tasks = []
    for i, url in enumerate(urls, 1):
        img_path = IMAGES_DIR / seq_filename(i)
        if not (img_path.exists() and img_path.stat().st_size > 0):
            tasks.append((i, url))
    if not tasks:
        progress_tracker.update(PHASE_1_DOWNLOAD, len(urls), len(urls), 0,
                                "β All images already downloaded", "")
        return True

    total = len(tasks)
    completed = 0
    failed = 0
    progress_tracker.update(PHASE_1_DOWNLOAD, 0, total, 0,
                            f"π Downloading {total} images...", "Starting download")
    try:
        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
            futures = [executor.submit(download_single_image, i, url)
                       for i, url in tasks]
            for future in as_completed(futures):
                # Only the result lookup is guarded, so each future is
                # counted exactly once whatever happens afterwards.
                try:
                    success = future.result()
                except Exception as e:
                    print(f"Download future error: {e}")
                    success = False
                completed += 1
                if not success:
                    failed += 1
                # Update progress every 5 completions.
                if completed % 5 == 0 or completed == total:
                    success_rate = (completed - failed) / completed * 100
                    details = f"β {completed - failed} successful β’ β {failed} failed"
                    message = f"π₯ Downloaded {completed}/{total} ({success_rate:.1f}% success)"
                    progress_tracker.update(PHASE_1_DOWNLOAD, completed, total, failed,
                                            message, details)
    except Exception as e:
        # Do not hard fail; continue to embeddings with whatever is available.
        progress_tracker.update(PHASE_1_DOWNLOAD, completed, total, failed,
                                f"β οΈ Download issues: {e}", "Continuing setup")
        return True

    success_rate = (total - failed) / total * 100
    # Proceed regardless; the embedding step can handle zero/partial images.
    progress_tracker.update(PHASE_1_DOWNLOAD, total, total, failed,
                            f"π₯ Downloaded {total - failed}/{total} ({success_rate:.1f}% success)",
                            "Proceeding to embeddings")
    return True
# ----------------------- | |
# Fixed Embedding Generation (Thread-Safe) | |
# ----------------------- | |
def create_safe_embedding(img_path: Path) -> np.ndarray:
    """Return an L2-normalised RGB colour-histogram embedding for an image.

    The image is converted to RGB, resized to 224x224, and each channel is
    histogrammed into ``HIST_BINS_PER_CHANNEL`` bins over ``HIST_RANGE``.
    The three histograms are concatenated and divided by their norm.

    Args:
        img_path: Path of the image file to embed.

    Returns:
        A float32 vector of length ``HIST_BINS_PER_CHANNEL * 3``. A zero
        vector is returned (and the reason printed) when the file is
        missing, empty or cannot be processed.
    """
    dim = HIST_BINS_PER_CHANNEL * 3
    try:
        if not img_path.exists() or img_path.stat().st_size == 0:
            print(f"Invalid image file: {img_path}")
            return np.zeros(dim, dtype=np.float32)
        # Close the file handle as soon as the pixels have been decoded.
        with Image.open(img_path) as source:
            img = source.convert("RGB").resize((224, 224), Image.Resampling.LANCZOS)
        # An RGB image is already 8-bit (0-255), so no rescaling is applied:
        # rescaling whenever max <= 1 would corrupt near-black images.
        arr = np.asarray(img, dtype=np.float32)
        channels = [
            np.histogram(arr[:, :, ch], bins=HIST_BINS_PER_CHANNEL,
                         range=HIST_RANGE)[0].astype(np.float32)
            for ch in range(3)  # R, G, B
        ]
        vec = np.concatenate(channels)
        norm = np.linalg.norm(vec)
        # Return the vector unnormalised if its norm is too small to divide by.
        return vec / norm if norm > 1e-12 else vec
    except Exception as e:
        print(f"Embedding creation error for {img_path}: {e}")
        return np.zeros(dim, dtype=np.float32)
def process_embeddings_thread_safe() -> bool:
    """Build the embedding matrix and index for all images (no Streamlit APIs).

    Every non-empty ``*.jpg`` in ``IMAGES_DIR`` is embedded in sorted order.
    The matrix (``EMB_NPY``) and the list of file names (``EMB_INDEX_JSON``)
    are rewritten every ``BATCH_SIZE`` images and at the end, each through a
    temporary file that is renamed into place.

    Returns:
        True when the artifacts are on disk (including the empty-index case
        and the "already up to date" case); False after reporting
        ``PHASE_ERROR`` if saving or processing failed.
    """
    dim = HIST_BINS_PER_CHANNEL * 3
    image_files = sorted([f for f in IMAGES_DIR.glob("*.jpg")
                          if f.stat().st_size > 0])
    if not image_files:
        try:
            # Create empty artifacts so the app can run without images.
            np.save(EMB_NPY, np.zeros((0, dim), dtype=np.float32))
            with open(EMB_INDEX_JSON, 'w') as f:
                json.dump([], f, indent=2)
            progress_tracker.update(PHASE_2_EMBEDDING, 1, 1, 0,
                                    "β No images to process", "Empty index created")
            return True
        except Exception as e:
            progress_tracker.update(PHASE_ERROR, 0, 1, 1, f"β No images and failed to init embeddings: {e}", "")
            return False

    # Skip the work only if the stored index names exactly these files;
    # equal counts alone do not prove the rows match the images.
    expected_names = [f.name for f in image_files]
    try:
        if EMB_NPY.exists() and EMB_INDEX_JSON.exists():
            existing_embeddings = np.load(EMB_NPY)
            with open(EMB_INDEX_JSON, 'r') as f:
                existing_index = json.load(f)
            if (len(existing_embeddings) == len(image_files)
                    and existing_index == expected_names):
                progress_tracker.update(PHASE_2_EMBEDDING, len(image_files), len(image_files), 0,
                                        "β Embeddings up to date", "")
                return True
    except Exception as e:
        print(f"Error checking existing embeddings: {e}")

    total = len(image_files)
    embeddings = []
    index = []
    processed = 0
    failed = 0
    progress_tracker.update(PHASE_2_EMBEDDING, 0, total, 0,
                            f"π§ Creating embeddings for {total} images...",
                            "Processing visual features")
    temp_npy = EMB_NPY.with_suffix('.tmp')
    temp_json = EMB_INDEX_JSON.with_suffix('.tmp')
    try:
        for img_file in image_files:
            embedding = create_safe_embedding(img_file)
            # Always append; a zero vector marks an image that failed.
            if not np.any(embedding):
                failed += 1
            embeddings.append(embedding)
            index.append(img_file.name)
            processed += 1
            # Save in batches for resilience.
            if processed % BATCH_SIZE == 0 or processed == total:
                try:
                    embeddings_array = np.vstack(embeddings).astype(np.float32)
                    # np.save appends ".npy" to a path lacking that suffix,
                    # so write through a file object to keep the temp name.
                    with open(temp_npy, 'wb') as f:
                        np.save(f, embeddings_array)
                    with open(temp_json, 'w') as f:
                        json.dump(index, f, indent=2)
                    # Atomic move into place.
                    temp_npy.replace(EMB_NPY)
                    temp_json.replace(EMB_INDEX_JSON)
                    details = f"πΎ Batch saved β’ π {len(embeddings)} embeddings"
                    if failed > 0:
                        details += f" β’ β οΈ {failed} errors"
                    message = f"π§ Processed {processed}/{total}"
                    if processed == total:
                        message = "β All embeddings created!"
                    progress_tracker.update(PHASE_2_EMBEDDING, processed, total, failed,
                                            message, details)
                except Exception as e:
                    progress_tracker.update(PHASE_ERROR, processed, total, failed,
                                            f"β Save failed: {e}", "")
                    return False
        return True
    except Exception as e:
        progress_tracker.update(PHASE_ERROR, processed, total, failed,
                                f"β Processing failed: {e}", "")
        return False
# ----------------------- | |
# Background Setup Thread (NO Streamlit APIs) | |
# ----------------------- | |
def run_setup_background():
    """Run the whole setup pipeline; intended as a thread target (no Streamlit calls).

    Steps: create directories, load URLs, download images, build embeddings,
    then mark setup complete. If the embedding step reports failure, empty
    embedding artifacts are written instead; if that also fails the function
    reports ``PHASE_ERROR`` and stops. Any unexpected exception is printed
    and reported as ``PHASE_ERROR``.
    """
    try:
        ensure_dirs()

        url_list = load_csv_urls()
        if not url_list:
            progress_tracker.update(
                PHASE_1_DOWNLOAD, 1, 1, 0,
                "β No valid URLs found", "Skipping downloads",
            )
        print(f"Starting setup with {len(url_list)} URLs")

        # Phase 1: downloads (never aborts the pipeline).
        process_downloads_thread_safe(url_list)
        print("Download phase completed")

        # Phase 2: embeddings, with an empty-index fallback on failure.
        embeddings_ok = process_embeddings_thread_safe()
        if not embeddings_ok:
            try:
                empty_matrix = np.zeros((0, HIST_BINS_PER_CHANNEL * 3), dtype=np.float32)
                np.save(EMB_NPY, empty_matrix)
                with open(EMB_INDEX_JSON, 'w') as index_file:
                    json.dump([], index_file, indent=2)
                progress_tracker.update(
                    PHASE_2_EMBEDDING, 1, 1, 0,
                    "β Initialized empty embeddings", "No images processed",
                )
            except Exception as e:
                progress_tracker.update(
                    PHASE_ERROR, 0, 1, 1,
                    f"β Embedding fallback failed: {e}", "",
                )
                return
        print("Embedding phase completed")

        # Phase 3: flag the system as ready.
        mark_setup_complete()
        print("Setup completed successfully")
    except Exception as e:
        print(f"Setup error: {e}")
        progress_tracker.update(
            PHASE_ERROR, 0, 1, 1,
            f"β Setup error: {e}", "Unexpected error",
        )
# ----------------------- | |
# Search Functions (Main Thread Only) | |
# ----------------------- | |
def cosine_sim(embeddings, q_vec):
    """Return the cosine similarity between each row of ``embeddings`` and ``q_vec``.

    An empty ``embeddings`` array yields an empty float32 array. A small
    epsilon is added to every norm so zero vectors do not divide by zero.
    """
    if embeddings.size == 0 or embeddings.shape[0] == 0:
        return np.array([], dtype=np.float32)
    eps = 1e-12
    row_norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit_rows = embeddings / (row_norms + eps)
    unit_query = q_vec / (np.linalg.norm(q_vec) + eps)
    return np.matmul(unit_rows, unit_query)
def search_text_safe(query: str, embeddings: np.ndarray, index: List[str], top_k: int = 5):
    """Rank indexed images against a text query (main thread only).

    The query is sent to the Hugging Face Inference API at ``API_URL``; the
    returned vector is averaged over axis 0 if it has more than one
    dimension, zero-padded or truncated to the embedding width, and compared
    with ``cosine_sim``.

    Returns:
        Up to ``top_k`` ``(filename, similarity)`` tuples, best first, or an
        empty list after showing a Streamlit warning/error.
    """
    # NOTE(review): padding/truncating the API vector to the stored embedding
    # width only aligns lengths; confirm the two vector spaces are comparable.
    if not HF_TOKEN:
        st.warning("β οΈ Text search requires HF_TOKEN environment variable")
        return []
    try:
        with st.spinner(f"π Searching for '{query}'..."):
            response = requests.post(
                API_URL,
                headers=HEADERS,
                json={"inputs": query},
                timeout=(30, 120),
            )
        status = response.status_code
        if status == 503:
            st.warning("β³ AI model is loading. Please wait 10-20 seconds and try again.")
            return []
        if status != 200:
            st.error(f"β API Error ({response.status_code}). Please try again.")
            return []
        try:
            q_vec = np.array(response.json(), dtype=np.float32)
            if q_vec.ndim > 1:
                q_vec = q_vec.mean(axis=0)
            if embeddings.size == 0:
                st.warning("β οΈ No embeddings available yet. Add images to enable search.")
                return []
            # Bring the query vector to the stored embedding width.
            width = embeddings.shape[1]
            missing = width - len(q_vec)
            if missing > 0:
                q_vec = np.concatenate([q_vec, np.zeros(missing)])
            elif missing < 0:
                q_vec = q_vec[:width]
            similarities = cosine_sim(embeddings, q_vec)
            ranked = np.argsort(-similarities)[:top_k]
            hits = []
            for pos in ranked:
                hits.append((index[pos], float(similarities[pos])))
            return hits
        except (ValueError, KeyError) as e:
            st.error(f"β Failed to process search response: {e}")
            return []
    except requests.exceptions.Timeout:
        st.error("β Search timed out. Please try again.")
        return []
    except Exception as e:
        st.error(f"β Search failed: {e}")
        return []
def search_uploaded_safe(uploaded_file, embeddings: np.ndarray, index: List[str], top_k: int = 5):
    """Rank indexed images against an uploaded image (main thread only).

    The upload is converted to RGB, written to a uniquely named temporary
    JPEG, embedded with ``create_safe_embedding`` and compared with
    ``cosine_sim``. The temporary file is always removed.

    Args:
        uploaded_file: File-like object accepted by ``PIL.Image.open``.
        embeddings: Matrix of stored image embeddings, one row per image.
        index: File names aligned with the rows of ``embeddings``.
        top_k: Maximum number of results.

    Returns:
        Up to ``top_k`` ``(filename, similarity)`` tuples, best first, or an
        empty list after showing a Streamlit warning/error.
    """
    import tempfile  # local import keeps this block self-contained

    temp_path = None
    try:
        if embeddings.size == 0:
            st.warning("β οΈ No embeddings available yet. Add images to enable search.")
            return []
        img = Image.open(uploaded_file).convert("RGB")
        # A unique temp file outside IMAGES_DIR cannot collide with another
        # session or be picked up by the "*.jpg" glob used for indexing.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            temp_path = Path(tmp.name)
        img.save(temp_path, "JPEG", quality=90)
        q_vec = create_safe_embedding(temp_path)
        # A zero vector is what create_safe_embedding returns on failure;
        # ranking against it would produce meaningless results.
        if not np.any(q_vec):
            raise ValueError("could not extract features from the uploaded image")
        similarities = cosine_sim(embeddings, q_vec)
        best = np.argsort(-similarities)[:top_k]
        return [(index[i], float(similarities[i])) for i in best]
    except Exception as e:
        st.error(f"β Image analysis failed: {e}")
        return []
    finally:
        # Clean up even when embedding or ranking raised.
        if temp_path is not None:
            try:
                temp_path.unlink()
            except OSError:
                pass
# ----------------------- | |
# Main Application UI (Main Thread Only) | |
# ----------------------- | |
def apply_styling():
    """Inject the app's CSS (gradient background, glass cards, buttons) into the page."""
    css = """
    <style>
    @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;400;500;600;700&display=swap');
    .stApp {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    font-family: 'Poppins', sans-serif;
    }
    .main-header {
    background: rgba(255, 255, 255, 0.1);
    backdrop-filter: blur(10px);
    border-radius: 20px;
    padding: 2rem;
    margin: 1rem 0;
    text-align: center;
    color: white;
    border: 1px solid rgba(255, 255, 255, 0.2);
    }
    .main-header h1 {
    font-size: 2.5rem;
    font-weight: 600;
    margin: 0 0 0.5rem 0;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
    }
    .glass-card {
    background: rgba(255, 255, 255, 0.1);
    backdrop-filter: blur(10px);
    border-radius: 15px;
    padding: 2rem;
    margin: 1rem 0;
    border: 1px solid rgba(255, 255, 255, 0.2);
    }
    .stButton > button {
    background: linear-gradient(45deg, #667eea, #764ba2);
    color: white;
    border: none;
    border-radius: 20px;
    padding: 0.8rem 2rem;
    font-weight: 500;
    transition: all 0.3s ease;
    }
    .stButton > button:hover {
    transform: translateY(-2px);
    box-shadow: 0 8px 20px rgba(102, 126, 234, 0.4);
    }
    .metric-card {
    background: rgba(255, 255, 255, 0.1);
    backdrop-filter: blur(10px);
    border-radius: 12px;
    padding: 1.5rem;
    text-align: center;
    border: 1px solid rgba(255, 255, 255, 0.2);
    }
    #MainMenu {visibility: hidden;}
    footer {visibility: hidden;}
    header {visibility: hidden;}
    </style>
    """
    # Raw HTML is required for the <style> tag to take effect.
    st.markdown(css, unsafe_allow_html=True)
def init_session_state():
    """Seed the session-state keys used by the setup flow (main thread only).

    Existing values are left untouched.
    """
    defaults = {"setup_thread": None, "setup_started": False}
    for key, initial in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = initial
def _metric_card(title, value, caption, size="2rem", color="white"):
    """Return the HTML for one "metric-card" tile."""
    return "\n".join([
        '<div class="metric-card">',
        f'<h3 style="color: #667eea; margin: 0;">{title}</h3>',
        f'<p style="font-size: {size}; color: {color}; margin: 0.5rem 0;">{value}</p>',
        f'<p style="color: rgba(255,255,255,0.8); margin: 0;">{caption}</p>',
        '</div>',
    ])


def _render_results(results):
    """Show search hits in up to three columns with a colour-coded match score."""
    st.markdown("---")
    st.write(f"**π― Found {len(results)} similar images:**")
    cols = st.columns(min(3, len(results)))
    for i, (filename, similarity) in enumerate(results):
        img_path = IMAGES_DIR / filename
        with cols[i % len(cols)]:
            if not img_path.exists():
                st.error(f"β {filename}")
                continue
            st.image(str(img_path), use_column_width=True)
            similarity_percent = similarity * 100
            color = "#4CAF50" if similarity_percent > 70 else "#FF9800" if similarity_percent > 50 else "#2196F3"
            st.markdown(f"""
            <div style="text-align: center; color: {color}; font-weight: 600;">
            {similarity_percent:.1f}% Match
            </div>
            """, unsafe_allow_html=True)


def _start_setup_thread():
    """Start the background setup thread and remember it in session state."""
    st.session_state.setup_started = True
    thread = threading.Thread(target=run_setup_background, daemon=True)
    thread.start()
    st.session_state.setup_thread = thread


def _render_search_ui():
    """Render the search page: stats, search controls and results."""
    st.markdown(f"""
    <div class="main-header">
    <h1>π Visual Search System</h1>
    <p>Intelligent search across {MAX_IMAGES} images</p>
    </div>
    """, unsafe_allow_html=True)
    try:
        # Load search data.
        embeddings = np.load(EMB_NPY)
        with open(EMB_INDEX_JSON, 'r') as f:
            index = json.load(f)

        # System stats.
        col1, col2, col3 = st.columns(3)
        with col1:
            st.markdown(_metric_card("πΈ Images", f"{len(index):,}", "Available"),
                        unsafe_allow_html=True)
        with col2:
            st.markdown(_metric_card("π§ Features", embeddings.shape[1], "Dimensions"),
                        unsafe_allow_html=True)
        with col3:
            st.markdown(_metric_card("β‘ Status", "Ready", "System", color="#4CAF50"),
                        unsafe_allow_html=True)

        # Search controls.
        st.markdown('<div class="glass-card">', unsafe_allow_html=True)
        st.markdown("### π Search Options")
        search_type = st.radio("Choose search method:",
                               ["π€ Text Search", "π Image Upload"], horizontal=True)
        top_k = st.slider("Number of results:", 1, 15, 6)
        if search_type == "π€ Text Search":
            st.markdown("#### π€ Describe What You're Looking For")
            query = st.text_input("", placeholder="e.g., sunset, cat, building, nature...")
            if query and st.button("π Search", type="primary"):
                results = search_text_safe(query, embeddings, index, top_k)
                if results:
                    _render_results(results)
        else:  # Image upload
            st.markdown("#### π Upload Image for Similarity Search")
            uploaded_file = st.file_uploader("", type=['png', 'jpg', 'jpeg', 'webp'])
            if uploaded_file:
                col1, col2, col3 = st.columns([1, 2, 1])
                with col2:
                    st.image(uploaded_file, caption="Query Image")
                if st.button("π Find Similar", type="primary"):
                    results = search_uploaded_safe(uploaded_file, embeddings, index, top_k)
                    if results:
                        _render_results(results)
        st.markdown('</div>', unsafe_allow_html=True)
    except Exception as e:
        st.error(f"β Failed to load search system: {e}")
        if st.button("π Retry"):
            st.rerun()


def _render_setup_progress(progress_data):
    """Render the live setup progress view and schedule the next refresh."""
    if not progress_data:
        st.info("β³ Setup is starting...")
        time.sleep(2)
        st.rerun()
        return

    phase = progress_data.get("phase", PHASE_IDLE)
    current = progress_data.get("current", 0)
    total = progress_data.get("total", 0)
    percentage = progress_data.get("percentage", 0)
    message = progress_data.get("message", "Processing...")
    details = progress_data.get("details", "")
    fails = progress_data.get("fails", 0)

    st.markdown(f"""
    <div class="glass-card">
    <h2 style="color: white; text-align: center; margin-bottom: 1rem;">{message}</h2>
    </div>
    """, unsafe_allow_html=True)

    # NOTE(review): the source's nesting under `if total > 0` is not
    # recoverable; the bar and tiles are gated here, details are not.
    if total > 0:
        st.progress(percentage / 100)
        col1, col2, col3 = st.columns(3)
        with col1:
            st.markdown(_metric_card("π Progress", f"{current}/{total}",
                                     f"{percentage:.1f}%", size="1.5rem"),
                        unsafe_allow_html=True)
        with col2:
            st.markdown(_metric_card("β±οΈ Status", "Processing", "Active", size="1.5rem"),
                        unsafe_allow_html=True)
        with col3:
            quality = "Good" if fails < 5 else "Issues" if fails < 20 else "Poor"
            color = "#4CAF50" if fails < 5 else "#FF9800" if fails < 20 else "#f44336"
            st.markdown(_metric_card("β¨ Quality", quality, f"{fails} issues",
                                     size="1.5rem", color=color),
                        unsafe_allow_html=True)
    if details:
        st.markdown(f"""
        <div style="text-align: center; color: rgba(255,255,255,0.8);
        background: rgba(255,255,255,0.1); padding: 1rem;
        border-radius: 10px; margin: 1rem 0;">
        {details}
        </div>
        """, unsafe_allow_html=True)

    if phase == PHASE_ERROR:
        st.error("β Setup encountered errors. You can restart the process.")
        if st.button("π Restart Setup"):
            if PROGRESS_FILE.exists():
                PROGRESS_FILE.unlink()
            # The tracker also keeps the error in memory, so deleting the
            # file alone would show the same error again. Overwrite the
            # state, then actually start a new run.
            progress_tracker.update(PHASE_IDLE, 0, 0, 0, "β³ Setup is starting...", "")
            _start_setup_thread()
            st.rerun()
    elif phase == PHASE_3_COMPLETE:
        # Only the final phase means "done"; an intermediate phase also
        # reaches 100% when it finishes.
        st.success("π Setup completed successfully! Redirecting...")
        time.sleep(2)
        st.rerun()
    else:
        # Auto-refresh while work is in progress.
        time.sleep(3)
        st.rerun()


def _render_setup_start():
    """Render the initial screen that validates the CSV and offers to start setup."""
    if not CSV_FILE.exists():
        st.error("β Required file missing: photos_url.csv")
        return
    urls = load_csv_urls()
    if not urls:
        st.error("β No valid URLs found in CSV file")
        return
    st.success(f"β Found {len(urls):,} valid image URLs")
    st.markdown("""
    <div class="glass-card">
    <h3 style="color: white; text-align: center;">Setup Process</h3>
    <p style="color: rgba(255,255,255,0.8); text-align: center;">
    π₯ <strong>Phase 1:</strong> Download images in parallel<br>
    π§ <strong>Phase 2:</strong> Create visual embeddings<br>
    β <strong>Phase 3:</strong> Enable search functionality
    </p>
    </div>
    """, unsafe_allow_html=True)
    col1, col2, col3 = st.columns([1, 2, 1])
    with col2:
        if st.button("π Start Setup Process", type="primary", use_container_width=True):
            _start_setup_thread()
            st.rerun()


def main():
    """Streamlit entry point; all session-state access happens here or in its helpers.

    Shows the search page once setup is complete. Otherwise shows the setup
    progress view (when progress exists or setup was started in this session)
    or the initial "start setup" screen.
    """
    st.set_page_config(
        page_title="Visual Search System",
        page_icon="π",
        layout="wide",
        initial_sidebar_state="collapsed"
    )
    apply_styling()
    init_session_state()  # Safe - main thread only

    if is_setup_complete():
        _render_search_ui()
        return

    st.markdown(f"""
    <div class="main-header">
    <h1>π Visual Search Setup</h1>
    <p>Preparing to process {MAX_IMAGES} images</p>
    </div>
    """, unsafe_allow_html=True)
    progress_data = progress_tracker.read()
    if progress_data or st.session_state.setup_started:
        _render_setup_progress(progress_data)
    else:
        _render_setup_start()


if __name__ == "__main__":
    main()