# manisharma494's picture
# Update app.py
# c7b2f0b verified
#!/usr/bin/env python3
"""
Visual Search System - Fixed Threading & Session State Issues
-----------------------------------------------------------
- Fixed: No SessionInfo/ScriptRunContext errors
- Fixed: Embedding generation with proper error handling
- Thread-safe progress tracking without Streamlit APIs
- All session state access moved to main thread only
"""
import os
import json
import time
import threading
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
import requests
import pandas as pd
import numpy as np
import streamlit as st
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import datetime
from typing import Optional, Tuple, List
# -----------------------
# Configuration
# -----------------------
# On-disk layout (paths are relative to the working directory)
IMAGES_DIR = Path("images")
EMBED_DIR = Path("embeddings")
CSV_FILE = Path("photos_url.csv")
PROGRESS_FILE = Path("progress.json")
SETUP_COMPLETE_FILE = Path("setup_complete.flag")
EMB_NPY = EMBED_DIR.joinpath("image_embeddings.npy")
EMB_INDEX_JSON = EMBED_DIR.joinpath("index.json")

# Download / processing limits
MAX_IMAGES = 250           # at most this many URLs are taken from the CSV
JPEG_QUALITY = 85          # JPEG quality of the stored thumbnails
TARGET_MAX_SIZE = (800, 800)
MAX_WORKERS = 6            # download threads; kept low for stability
RETRY_COUNT = 3
BATCH_SIZE = 20            # embeddings are flushed to disk every BATCH_SIZE images

# Colour-histogram embedding parameters
HIST_BINS_PER_CHANNEL = 32
HIST_RANGE = (0, 256)

# Hugging Face Inference API for text queries; text search needs HF_TOKEN.
HF_TOKEN = os.environ.get("HF_TOKEN")
CLIP_MODEL = "sentence-transformers/clip-ViT-B-32"
API_URL = "https://api-inference.huggingface.co/models/" + CLIP_MODEL
HEADERS = {"Authorization": "Bearer " + HF_TOKEN} if HF_TOKEN else {}

# Setup phase identifiers (stored in the progress snapshots)
PHASE_IDLE = "idle"
PHASE_1_DOWNLOAD = "download"
PHASE_2_EMBEDDING = "embedding"
PHASE_3_COMPLETE = "complete"
PHASE_ERROR = "error"
# -----------------------
# Thread-Safe Progress Tracker (NO Streamlit APIs)
# -----------------------
class SafeProgressTracker:
    """Progress tracker that doesn't use Streamlit APIs in threads.

    The latest snapshot is kept in memory behind a lock and mirrored to
    PROGRESS_FILE as JSON, so a worker thread can publish progress and the
    Streamlit script can poll it on each rerun.
    """

    def __init__(self):
        self._lock = threading.Lock()
        # Latest snapshot written by update(); empty until the first update.
        self._data = {}

    def update(self, phase: str, current: int, total: int, fails: int = 0,
               message: str = "", details: str = ""):
        """Record a progress snapshot and persist it - safe to call from any thread.

        Args:
            phase: One of the PHASE_* constants; unknown values get a generic name.
            current: Units of work finished so far in this phase.
            total: Units of work in this phase; ``percentage`` is 0 when not positive.
            fails: Number of failed units.
            message: Headline text for the UI.
            details: Secondary text for the UI.

        File write errors are printed; the in-memory snapshot is still updated.
        """
        phase_names = {
            PHASE_IDLE: "πŸš€ Initializing",
            PHASE_1_DOWNLOAD: "πŸ“₯ Downloading Images",
            PHASE_2_EMBEDDING: "🧠 Creating Embeddings",
            PHASE_3_COMPLETE: "βœ… System Ready",
            PHASE_ERROR: "❌ Error Occurred"
        }
        with self._lock:
            percentage = (current / total * 100) if total > 0 else 0
            self._data = {
                "phase": phase,
                "phase_name": phase_names.get(phase, f"Phase: {phase}"),
                "current": current,
                "total": total,
                "fails": fails,
                "percentage": percentage,
                "message": message,
                "details": details,
                "timestamp": time.time(),
                "formatted_time": datetime.datetime.now().strftime("%H:%M:%S")
            }
            self._persist(self._data)

    @staticmethod
    def _persist(data: dict) -> None:
        """Write *data* to PROGRESS_FILE via a temp file + rename (no Streamlit APIs).

        Readers therefore see either the previous or the new snapshot, never a
        truncated JSON file. Errors are printed, not raised.
        """
        tmp_path = PROGRESS_FILE.with_name(PROGRESS_FILE.name + ".tmp")
        try:
            with open(tmp_path, 'w') as f:
                json.dump(data, f, indent=2)
            tmp_path.replace(PROGRESS_FILE)
        except Exception as e:
            print(f"Progress save error: {e}")

    def read(self) -> Optional[dict]:
        """Return the latest snapshot, or None if nothing has been recorded.

        PROGRESS_FILE is preferred when it exists; if it is missing or cannot
        be read/parsed, a copy of the in-memory snapshot is returned instead.
        """
        with self._lock:
            try:
                if PROGRESS_FILE.exists():
                    with open(PROGRESS_FILE, 'r') as f:
                        return json.load(f)
            except (OSError, ValueError):
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
                # fall back to the in-memory snapshot below.
                pass
            return self._data.copy() if self._data else None


# Global progress tracker used by the setup thread and the UI.
# NOTE(review): Streamlit re-executes this script on every rerun, so the UI
# likely gets a fresh instance each run and sees the setup thread's progress
# only through PROGRESS_FILE - confirm before relying on the in-memory copy.
progress_tracker = SafeProgressTracker()
# -----------------------
# Utility Functions
# -----------------------
def ensure_dirs():
    """Create IMAGES_DIR and EMBED_DIR (with parents) if they are missing.

    Failures are printed rather than raised; if creating IMAGES_DIR fails,
    EMBED_DIR is not attempted.
    """
    try:
        for directory in (IMAGES_DIR, EMBED_DIR):
            directory.mkdir(parents=True, exist_ok=True)
    except Exception as e:
        print(f"Directory creation error: {e}")
def seq_filename(i: int) -> str:
    """Return the zero-padded JPEG file name for sequence number *i* (7 -> '0007.jpg')."""
    return "{:04d}.jpg".format(i)
def load_csv_urls(csv_file: Optional[Path] = None,
                  max_images: Optional[int] = None) -> List[str]:
    """Load up to *max_images* http(s) URLs from a CSV file.

    The URL column is the first column whose name contains "url"
    (case-insensitive), otherwise the first column. Values are stripped and
    only those starting with ``http://`` or ``https://`` are kept. The limit
    is applied *after* filtering, so blank/invalid rows do not reduce the
    number of usable URLs.

    Args:
        csv_file: CSV to read; defaults to CSV_FILE.
        max_images: Maximum number of URLs to return; defaults to MAX_IMAGES.

    Returns:
        The URLs in file order, or ``[]`` if the file is missing or cannot be
        parsed (errors are printed, not raised).
    """
    csv_path = CSV_FILE if csv_file is None else Path(csv_file)
    limit = MAX_IMAGES if max_images is None else max_images
    try:
        if not csv_path.exists():
            return []
        # Try different encodings ('latin1' accepts any byte sequence).
        for encoding in ('utf-8', 'utf-8-sig', 'latin1'):
            try:
                df = pd.read_csv(csv_path, encoding=encoding)
                break
            except UnicodeDecodeError:
                continue
        else:
            print("Failed to read CSV with any encoding")
            return []
        # Find URL column; str() guards against non-string (e.g. integer) headers.
        url_cols = [c for c in df.columns if "url" in str(c).lower()]
        column = url_cols[0] if url_cols else df.columns[0]
        # astype(str) turns NaN into 'nan', which fails the scheme check below.
        candidates = (value.strip() for value in df[column].astype(str))
        valid_urls = [url for url in candidates
                      if url.startswith(('http://', 'https://'))]
        return valid_urls[:limit]
    except Exception as e:
        print(f"CSV loading error: {e}")
        return []
def is_setup_complete() -> bool:
    """Return True when the completion flag exists and both embedding files are non-empty.

    Filesystem errors (e.g. a file removed between the ``exists()`` and
    ``stat()`` calls) are treated as "not complete" rather than raised.
    """
    required = (SETUP_COMPLETE_FILE, EMB_NPY, EMB_INDEX_JSON)
    try:
        if not all(path.exists() for path in required):
            return False
        return EMB_NPY.stat().st_size > 0 and EMB_INDEX_JSON.stat().st_size > 0
    except OSError:
        return False
def mark_setup_complete():
    """Write SETUP_COMPLETE_FILE with run metadata and publish the final progress state.

    The metadata records the completion time and the number of ``*.jpg``
    files in IMAGES_DIR. Any error is printed and otherwise ignored.
    """
    try:
        completed_at = time.time()
        formatted = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        jpg_count = len(list(IMAGES_DIR.glob("*.jpg")))
        record = dict(
            completed_at=completed_at,
            formatted_time=formatted,
            images_processed=jpg_count,
            embeddings_created=True,
        )
        with open(SETUP_COMPLETE_FILE, 'w') as handle:
            json.dump(record, handle, indent=2)
        progress_tracker.update(PHASE_3_COMPLETE, 1, 1, 0,
                                "βœ… Setup Complete!", "Ready for search")
    except Exception as err:
        print(f"Setup completion error: {err}")
# -----------------------
# Image Download Functions (Thread-Safe)
# -----------------------
def download_single_image(i: int, url: str) -> bool:
    """Download one image and store it as ``IMAGES_DIR / seq_filename(i)`` - NO Streamlit APIs.

    The image is converted to RGB, shrunk to fit TARGET_MAX_SIZE and saved as
    JPEG through a temp file that is then renamed over the target.

    Returns:
        True if a non-empty file already exists or the download succeeded;
        False after RETRY_COUNT failed attempts or a non-retryable HTTP status.
    """
    fname = IMAGES_DIR / seq_filename(i)
    if fname.exists() and fname.stat().st_size > 0:
        return True
    temp_fname = fname.with_suffix('.tmp')
    for attempt in range(RETRY_COUNT):
        last_attempt = attempt == RETRY_COUNT - 1
        try:
            # (connect, read) timeouts. The with-block closes the streamed
            # response so its pooled connection is released even on errors.
            with requests.get(url, stream=True, timeout=(30, 90)) as response:
                status = response.status_code
                if status == 200:
                    # requests does not decode Content-Encoding on response.raw;
                    # ask urllib3 to do it so gzip/deflate bodies reach PIL as image bytes.
                    response.raw.decode_content = True
                    img = Image.open(response.raw).convert("RGB")
                    img.thumbnail(TARGET_MAX_SIZE, Image.Resampling.LANCZOS)
                    # Atomic save
                    img.save(temp_fname, "JPEG", quality=JPEG_QUALITY, optimize=True)
                    temp_fname.replace(fname)
                    return True
            # 4xx other than 408 (timeout) / 429 (rate limit) will not change
            # on retry, so give up now instead of backing off.
            permanent = 400 <= status < 500 and status not in (408, 429)
            if last_attempt or permanent:
                return False
        except Exception as e:
            # Best-effort removal of a partially written temp file.
            try:
                temp_fname.unlink(missing_ok=True)
            except OSError:
                pass
            if last_attempt:
                print(f"Download failed {url}: {e}")
                return False
        time.sleep(2 ** attempt)  # Exponential backoff
    return False
def process_downloads_thread_safe(urls: List[str]) -> bool:
    """Download every missing image for *urls* on a thread pool - NO Streamlit APIs.

    The URL at 1-based position ``k`` is stored as ``seq_filename(k)``; files
    that already exist and are non-empty are skipped. Progress is published
    via progress_tracker every 5 completions and once at the end.

    Always returns True: download problems are reported, not fatal, so the
    embedding phase can proceed with whatever images are available.
    """
    if not urls:
        progress_tracker.update(PHASE_1_DOWNLOAD, 1, 1, 0,
                                "βœ… No URLs provided", "Skipping downloads")
        return True
    # Find what needs downloading
    tasks = []
    for i, url in enumerate(urls, 1):
        img_path = IMAGES_DIR / seq_filename(i)
        if not (img_path.exists() and img_path.stat().st_size > 0):
            tasks.append((i, url))
    if not tasks:
        progress_tracker.update(PHASE_1_DOWNLOAD, len(urls), len(urls), 0,
                                "βœ… All images already downloaded", "")
        return True
    total = len(tasks)
    completed = 0
    failed = 0
    progress_tracker.update(PHASE_1_DOWNLOAD, 0, total, 0,
                            f"πŸ”„ Downloading {total} images...", "Starting download")
    try:
        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
            futures = [executor.submit(download_single_image, i, url) for i, url in tasks]
            for future in as_completed(futures):
                try:
                    success = future.result()
                except Exception as e:
                    print(f"Download future error: {e}")
                    success = False
                # Count each future exactly once, whether it failed or raised,
                # so the periodic progress update below is never skipped.
                completed += 1
                if not success:
                    failed += 1
                if completed % 5 == 0 or completed == total:
                    success_rate = (completed - failed) / completed * 100
                    details = f"βœ… {completed - failed} successful β€’ ❌ {failed} failed"
                    message = f"πŸ“₯ Downloaded {completed}/{total} ({success_rate:.1f}% success)"
                    progress_tracker.update(PHASE_1_DOWNLOAD, completed, total, failed,
                                            message, details)
    except Exception as e:
        # Do not hard fail; continue to embeddings with whatever is available
        progress_tracker.update(PHASE_1_DOWNLOAD, completed, total, failed,
                                f"⚠️ Download issues: {e}", "Continuing setup")
        return True
    # total > 0 here because `tasks` is non-empty.
    success_rate = (total - failed) / total * 100
    progress_tracker.update(PHASE_1_DOWNLOAD, total, total, failed,
                            f"πŸ“₯ Downloaded {total - failed}/{total} ({success_rate:.1f}% success)",
                            "Proceeding to embeddings")
    return True
# -----------------------
# Fixed Embedding Generation (Thread-Safe)
# -----------------------
def create_safe_embedding(img_path: Path) -> np.ndarray:
    """Return an L2-normalised RGB colour-histogram vector for *img_path*.

    The image is converted to RGB and resized to 224x224. A
    HIST_BINS_PER_CHANNEL-bin histogram over HIST_RANGE is computed per
    channel and the three are concatenated (R, G, B). If the norm is
    negligible, the vector is returned unnormalised.

    Returns a float32 zero vector of the same length (HIST_BINS_PER_CHANNEL * 3)
    when the file is missing/empty or cannot be decoded; errors are printed,
    not raised.
    """
    feature_dim = HIST_BINS_PER_CHANNEL * 3
    try:
        # Check if file exists and has content
        if not img_path.exists() or img_path.stat().st_size == 0:
            print(f"Invalid image file: {img_path}")
            return np.zeros(feature_dim, dtype=np.float32)
        # The context manager closes the underlying file deterministically.
        with Image.open(img_path) as src:
            img = src.convert("RGB").resize((224, 224), Image.Resampling.LANCZOS)
        # PIL "RGB" pixels are always 8-bit (0-255), matching HIST_RANGE. The old
        # "max <= 1 -> scale by 255" step wrongly stretched near-black images
        # (all pixels 0/1) to 0/255, so it has been removed.
        arr = np.asarray(img, dtype=np.float32)
        channels = [
            np.histogram(arr[:, :, ch], bins=HIST_BINS_PER_CHANNEL,
                         range=HIST_RANGE)[0].astype(np.float32)
            for ch in range(3)
        ]
        vec = np.concatenate(channels)
        norm = np.linalg.norm(vec)
        # Return unnormalized if norm is too small
        return vec / norm if norm > 1e-12 else vec
    except Exception as e:
        print(f"Embedding creation error for {img_path}: {e}")
        return np.zeros(feature_dim, dtype=np.float32)
def _save_embeddings_atomic(embeddings_array: np.ndarray, index: List[str]) -> None:
    """Replace EMB_NPY and EMB_INDEX_JSON with the given data via temp files + rename.

    The array is written through an open file handle on purpose: given a
    *path* without a ``.npy`` suffix, ``np.save`` appends one, so saving to
    ``image_embeddings.tmp`` by path would create ``image_embeddings.tmp.npy``
    and the subsequent rename would fail. I/O errors propagate to the caller.
    """
    temp_npy = EMB_NPY.with_suffix('.tmp')
    temp_json = EMB_INDEX_JSON.with_suffix('.tmp')
    with open(temp_npy, 'wb') as f:
        np.save(f, embeddings_array)
    with open(temp_json, 'w') as f:
        json.dump(index, f, indent=2)
    # Atomic move
    temp_npy.replace(EMB_NPY)
    temp_json.replace(EMB_INDEX_JSON)


def process_embeddings_thread_safe() -> bool:
    """Create embeddings in background thread - NO Streamlit APIs.

    Embeds every non-empty ``*.jpg`` in IMAGES_DIR (sorted by name) with
    create_safe_embedding and writes the matrix to EMB_NPY and the matching
    file names to EMB_INDEX_JSON, flushing every BATCH_SIZE images. If the
    stored index already lists exactly the current files, nothing is
    recomputed. With no images, empty artifacts are written.

    Returns:
        True on success; False after reporting PHASE_ERROR via progress_tracker.
    """
    feature_dim = HIST_BINS_PER_CHANNEL * 3
    image_files = sorted(f for f in IMAGES_DIR.glob("*.jpg") if f.stat().st_size > 0)
    if not image_files:
        try:
            # Create empty artifacts so app can run without images
            np.save(EMB_NPY, np.zeros((0, feature_dim), dtype=np.float32))
            with open(EMB_INDEX_JSON, 'w') as f:
                json.dump([], f, indent=2)
            progress_tracker.update(PHASE_2_EMBEDDING, 1, 1, 0,
                                    "βœ… No images to process", "Empty index created")
            return True
        except Exception as e:
            progress_tracker.update(PHASE_ERROR, 0, 1, 1, f"❌ No images and failed to init embeddings: {e}", "")
            return False
    image_names = [f.name for f in image_files]
    # Check if embeddings already exist and are current
    try:
        if EMB_NPY.exists() and EMB_INDEX_JSON.exists():
            existing_embeddings = np.load(EMB_NPY)
            with open(EMB_INDEX_JSON, 'r') as f:
                existing_index = json.load(f)
            # A matching count is not enough - the files themselves may differ.
            if len(existing_embeddings) == len(image_files) and existing_index == image_names:
                progress_tracker.update(PHASE_2_EMBEDDING, len(image_files), len(image_files), 0,
                                        "βœ… Embeddings up to date", "")
                return True
    except Exception as e:
        print(f"Error checking existing embeddings: {e}")
    total = len(image_files)
    embeddings = []
    index = []
    processed = 0
    failed = 0
    progress_tracker.update(PHASE_2_EMBEDDING, 0, total, 0,
                            f"🧠 Creating embeddings for {total} images...",
                            "Processing visual features")
    try:
        for img_file in image_files:
            embedding = create_safe_embedding(img_file)
            # Zero vectors mark images that could not be embedded; keep them so
            # rows stay aligned with `index`, but count them as failures.
            if not np.any(embedding):
                failed += 1
            embeddings.append(embedding)
            index.append(img_file.name)
            processed += 1
            # Save in batches for resilience
            if processed % BATCH_SIZE == 0 or processed == total:
                try:
                    _save_embeddings_atomic(np.vstack(embeddings).astype(np.float32), index)
                    details = f"πŸ’Ύ Batch saved β€’ πŸ“Š {len(embeddings)} embeddings"
                    if failed > 0:
                        details += f" β€’ ⚠️ {failed} errors"
                    message = f"🧠 Processed {processed}/{total}"
                    if processed == total:
                        message = "βœ… All embeddings created!"
                    progress_tracker.update(PHASE_2_EMBEDDING, processed, total, failed,
                                            message, details)
                except Exception as e:
                    progress_tracker.update(PHASE_ERROR, processed, total, failed,
                                            f"❌ Save failed: {e}", "")
                    return False
        return True
    except Exception as e:
        progress_tracker.update(PHASE_ERROR, processed, total, failed,
                                f"❌ Processing failed: {e}", "")
        return False
# -----------------------
# Background Setup Thread (NO Streamlit APIs)
# -----------------------
def _write_empty_embeddings_fallback() -> bool:
    """Write an empty embedding matrix and index after a failed embedding phase.

    Reports the outcome through progress_tracker and returns True on success,
    False if the files could not be written.
    """
    try:
        empty = np.zeros((0, HIST_BINS_PER_CHANNEL * 3), dtype=np.float32)
        np.save(EMB_NPY, empty)
        with open(EMB_INDEX_JSON, 'w') as handle:
            json.dump([], handle, indent=2)
        progress_tracker.update(PHASE_2_EMBEDDING, 1, 1, 0,
                                "βœ… Initialized empty embeddings", "No images processed")
        return True
    except Exception as err:
        progress_tracker.update(PHASE_ERROR, 0, 1, 1,
                                f"❌ Embedding fallback failed: {err}", "")
        return False


def run_setup_background():
    """Run the whole setup (downloads, embeddings, completion flag) - NO Streamlit calls.

    Started on a daemon thread from the UI. If embedding fails, empty
    artifacts are written so the app can still start; the run stops without
    marking completion only when that fallback also fails. Unexpected errors
    are printed and reported as PHASE_ERROR.
    """
    try:
        ensure_dirs()
        url_list = load_csv_urls()
        if not url_list:
            progress_tracker.update(PHASE_1_DOWNLOAD, 1, 1, 0,
                                    "βœ… No valid URLs found", "Skipping downloads")
        print(f"Starting setup with {len(url_list)} URLs")

        # Phase 1: downloads (never fatal)
        process_downloads_thread_safe(url_list)
        print("Download phase completed")

        # Phase 2: embeddings, with an empty-index fallback
        embeddings_ready = process_embeddings_thread_safe()
        if not embeddings_ready and not _write_empty_embeddings_fallback():
            return
        print("Embedding phase completed")

        # Phase 3: mark complete
        mark_setup_complete()
        print("Setup completed successfully")
    except Exception as err:
        print(f"Setup error: {err}")
        progress_tracker.update(PHASE_ERROR, 0, 1, 1,
                                f"❌ Setup error: {err}", "Unexpected error")
# -----------------------
# Search Functions (Main Thread Only)
# -----------------------
def cosine_sim(embeddings, q_vec):
    """Return the cosine similarity of each row of *embeddings* with *q_vec*.

    A small epsilon (1e-12) is added to every norm to avoid division by zero.
    An empty float32 array is returned when *embeddings* has no rows.
    """
    has_no_rows = embeddings.size == 0 or embeddings.shape[0] == 0
    if has_no_rows:
        return np.array([], dtype=np.float32)
    eps = 1e-12
    row_lengths = np.linalg.norm(embeddings, axis=1, keepdims=True)
    query_length = np.linalg.norm(q_vec)
    unit_rows = embeddings / (row_lengths + eps)
    unit_query = q_vec / (query_length + eps)
    return np.matmul(unit_rows, unit_query)
def search_text_safe(query: str, embeddings: np.ndarray, index: List[str], top_k: int = 5):
    """Embed *query* via the HF Inference API and rank indexed images - main thread only.

    Uses st.* for user feedback, so it must run on the Streamlit script thread.

    Returns:
        Up to *top_k* ``(filename, similarity)`` pairs, best first, or ``[]``
        on any failure (a warning/error is shown instead).

    NOTE(review): the stored embeddings are colour histograms (see
    create_safe_embedding) while the query vector comes from CLIP_MODEL. The
    two spaces are unrelated; the pad/truncate step below only makes the
    shapes compatible, so the scores are not semantically meaningful.
    """
    if not HF_TOKEN:
        st.warning("⚠️ Text search requires HF_TOKEN environment variable")
        return []
    try:
        # Nothing to rank against: skip the (slow) network round-trip.
        if embeddings.size == 0:
            st.warning("⚠️ No embeddings available yet. Add images to enable search.")
            return []
        with st.spinner(f"πŸ” Searching for '{query}'..."):
            response = requests.post(API_URL, headers=HEADERS,
                                     json={"inputs": query}, timeout=(30, 120))
        if response.status_code == 503:
            st.warning("⏳ AI model is loading. Please wait 10-20 seconds and try again.")
            return []
        if response.status_code != 200:
            st.error(f"❌ API Error ({response.status_code}). Please try again.")
            return []
        try:
            q_vec = np.array(response.json(), dtype=np.float32)
            if q_vec.ndim > 1:
                q_vec = q_vec.mean(axis=0)
            # Ensure size compatibility with the stored feature dimension.
            dim = embeddings.shape[1]
            if len(q_vec) < dim:
                q_vec = np.concatenate([q_vec, np.zeros(dim - len(q_vec), dtype=np.float32)])
            elif len(q_vec) > dim:
                q_vec = q_vec[:dim]
            similarities = cosine_sim(embeddings, q_vec)
            best = np.argsort(-similarities)[:top_k]
            return [(index[i], float(similarities[i])) for i in best]
        except (ValueError, KeyError, TypeError) as e:
            # TypeError: non-numeric JSON (e.g. an error dict) or a scalar reply.
            st.error(f"❌ Failed to process search response: {e}")
            return []
    except requests.exceptions.Timeout:
        st.error("❌ Search timed out. Please try again.")
        return []
    except Exception as e:
        st.error(f"❌ Search failed: {e}")
        return []
def search_uploaded_safe(uploaded_file, embeddings: np.ndarray, index: List[str], top_k: int = 5):
    """Rank indexed images by colour-histogram similarity to an uploaded image - main thread only.

    The upload is re-encoded as JPEG (quality 90) into a unique temporary
    file, embedded with create_safe_embedding, and compared with cosine_sim.
    The temp file lives outside IMAGES_DIR, so concurrent sessions cannot
    overwrite each other's query and it is never mistaken for an indexed
    image. It is always removed afterwards.

    Returns:
        Up to *top_k* ``(filename, similarity)`` pairs, best first, or ``[]``
        on failure (shown to the user via st.warning / st.error).
    """
    import tempfile  # only needed for the per-query temp file

    temp_path = None
    try:
        if embeddings.size == 0:
            st.warning("⚠️ No embeddings available yet. Add images to enable search.")
            return []
        img = Image.open(uploaded_file).convert("RGB")
        fd, temp_name = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)  # PIL reopens the path itself
        temp_path = Path(temp_name)
        img.save(temp_path, "JPEG", quality=90)
        q_vec = create_safe_embedding(temp_path)
        similarities = cosine_sim(embeddings, q_vec)
        best = np.argsort(-similarities)[:top_k]
        return [(index[i], float(similarities[i])) for i in best]
    except Exception as e:
        st.error(f"❌ Image analysis failed: {e}")
        return []
    finally:
        # Cleanup runs on every path, including exceptions.
        if temp_path is not None:
            try:
                temp_path.unlink(missing_ok=True)
            except OSError:
                pass
# -----------------------
# Main Application UI (Main Thread Only)
# -----------------------
def apply_styling():
    """Inject the app's global CSS into the page via st.markdown.

    Defines the gradient background, the ``main-header``, ``glass-card`` and
    ``metric-card`` classes used by the UI and the button styling, and hides
    the ``#MainMenu``, ``footer`` and ``header`` elements.
    """
    stylesheet = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;400;500;600;700&display=swap');
.stApp {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
font-family: 'Poppins', sans-serif;
}
.main-header {
background: rgba(255, 255, 255, 0.1);
backdrop-filter: blur(10px);
border-radius: 20px;
padding: 2rem;
margin: 1rem 0;
text-align: center;
color: white;
border: 1px solid rgba(255, 255, 255, 0.2);
}
.main-header h1 {
font-size: 2.5rem;
font-weight: 600;
margin: 0 0 0.5rem 0;
text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
}
.glass-card {
background: rgba(255, 255, 255, 0.1);
backdrop-filter: blur(10px);
border-radius: 15px;
padding: 2rem;
margin: 1rem 0;
border: 1px solid rgba(255, 255, 255, 0.2);
}
.stButton > button {
background: linear-gradient(45deg, #667eea, #764ba2);
color: white;
border: none;
border-radius: 20px;
padding: 0.8rem 2rem;
font-weight: 500;
transition: all 0.3s ease;
}
.stButton > button:hover {
transform: translateY(-2px);
box-shadow: 0 8px 20px rgba(102, 126, 234, 0.4);
}
.metric-card {
background: rgba(255, 255, 255, 0.1);
backdrop-filter: blur(10px);
border-radius: 12px;
padding: 1.5rem;
text-align: center;
border: 1px solid rgba(255, 255, 255, 0.2);
}
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}
</style>
"""
    st.markdown(stylesheet, unsafe_allow_html=True)
def init_session_state():
    """Seed st.session_state with this app's keys on first run - main thread only.

    ``setup_thread`` holds the background setup thread (None until started);
    ``setup_started`` records whether the user has launched setup.
    """
    defaults = (("setup_thread", None), ("setup_started", False))
    for key, value in defaults:
        if key not in st.session_state:
            st.session_state[key] = value
def _render_results(results):
    """Render ``(filename, similarity)`` pairs as an image grid of up to 3 columns.

    Images are loaded from IMAGES_DIR; a missing file shows an error tile.
    The similarity percentage is coloured green (>70), orange (>50), else blue.
    """
    st.markdown("---")
    st.write(f"**🎯 Found {len(results)} similar images:**")
    cols = st.columns(min(3, len(results)))
    for i, (filename, similarity) in enumerate(results):
        img_path = IMAGES_DIR / filename
        with cols[i % len(cols)]:
            if not img_path.exists():
                st.error(f"❌ {filename}")
                continue
            st.image(str(img_path), use_column_width=True)
            similarity_percent = similarity * 100
            color = "#4CAF50" if similarity_percent > 70 else "#FF9800" if similarity_percent > 50 else "#2196F3"
            st.markdown(f"""
<div style="text-align: center; color: {color}; font-weight: 600;">
{similarity_percent:.1f}% Match
</div>
""", unsafe_allow_html=True)


def _render_search_page():
    """Render system stats and the text / image search UI (setup already complete)."""
    st.markdown(f"""
<div class="main-header">
<h1>πŸ” Visual Search System</h1>
<p>Intelligent search across {MAX_IMAGES} images</p>
</div>
""", unsafe_allow_html=True)
    try:
        # Load search data
        embeddings = np.load(EMB_NPY)
        with open(EMB_INDEX_JSON, 'r') as f:
            index = json.load(f)
        # System stats
        col1, col2, col3 = st.columns(3)
        with col1:
            st.markdown(f"""
<div class="metric-card">
<h3 style="color: #667eea; margin: 0;">πŸ“Έ Images</h3>
<p style="font-size: 2rem; color: white; margin: 0.5rem 0;">{len(index):,}</p>
<p style="color: rgba(255,255,255,0.8); margin: 0;">Available</p>
</div>
""", unsafe_allow_html=True)
        with col2:
            st.markdown(f"""
<div class="metric-card">
<h3 style="color: #667eea; margin: 0;">🧠 Features</h3>
<p style="font-size: 2rem; color: white; margin: 0.5rem 0;">{embeddings.shape[1]}</p>
<p style="color: rgba(255,255,255,0.8); margin: 0;">Dimensions</p>
</div>
""", unsafe_allow_html=True)
        with col3:
            st.markdown(f"""
<div class="metric-card">
<h3 style="color: #667eea; margin: 0;">⚑ Status</h3>
<p style="font-size: 2rem; color: #4CAF50; margin: 0.5rem 0;">Ready</p>
<p style="color: rgba(255,255,255,0.8); margin: 0;">System</p>
</div>
""", unsafe_allow_html=True)
        # Search interface
        st.markdown('<div class="glass-card">', unsafe_allow_html=True)
        st.markdown("### πŸ” Search Options")
        search_type = st.radio("Choose search method:",
                               ["πŸ”€ Text Search", "πŸ“ Image Upload"], horizontal=True)
        top_k = st.slider("Number of results:", 1, 15, 6)
        if search_type == "πŸ”€ Text Search":
            st.markdown("#### πŸ”€ Describe What You're Looking For")
            # A non-empty (but hidden) label avoids Streamlit's empty-label warning.
            query = st.text_input("Search query", placeholder="e.g., sunset, cat, building, nature...",
                                  label_visibility="collapsed")
            if query and st.button("πŸ” Search", type="primary"):
                results = search_text_safe(query, embeddings, index, top_k)
                if results:
                    _render_results(results)
        else:  # Image upload
            st.markdown("#### πŸ“ Upload Image for Similarity Search")
            uploaded_file = st.file_uploader("Query image", type=['png', 'jpg', 'jpeg', 'webp'],
                                             label_visibility="collapsed")
            if uploaded_file:
                _, center, _ = st.columns([1, 2, 1])
                with center:
                    st.image(uploaded_file, caption="Query Image")
                if st.button("πŸ” Find Similar", type="primary"):
                    results = search_uploaded_safe(uploaded_file, embeddings, index, top_k)
                    if results:
                        _render_results(results)
        st.markdown('</div>', unsafe_allow_html=True)
    except Exception as e:
        st.error(f"❌ Failed to load search system: {e}")
        if st.button("πŸ”„ Retry"):
            st.rerun()


def _render_progress(progress_data):
    """Render a progress snapshot and schedule the next poll (or offer a restart on error)."""
    phase = progress_data.get("phase", PHASE_IDLE)
    current = progress_data.get("current", 0)
    total = progress_data.get("total", 0)
    percentage = progress_data.get("percentage", 0)
    message = progress_data.get("message", "Processing...")
    details = progress_data.get("details", "")
    fails = progress_data.get("fails", 0)
    st.markdown(f"""
<div class="glass-card">
<h2 style="color: white; text-align: center; margin-bottom: 1rem;">{message}</h2>
</div>
""", unsafe_allow_html=True)
    if total > 0:
        st.progress(percentage / 100)
    col1, col2, col3 = st.columns(3)
    with col1:
        st.markdown(f"""
<div class="metric-card">
<h3 style="color: #667eea; margin: 0;">πŸ“Š Progress</h3>
<p style="font-size: 1.5rem; color: white; margin: 0.5rem 0;">{current}/{total}</p>
<p style="color: rgba(255,255,255,0.8); margin: 0;">{percentage:.1f}%</p>
</div>
""", unsafe_allow_html=True)
    with col2:
        st.markdown(f"""
<div class="metric-card">
<h3 style="color: #667eea; margin: 0;">⏱️ Status</h3>
<p style="font-size: 1.5rem; color: white; margin: 0.5rem 0;">Processing</p>
<p style="color: rgba(255,255,255,0.8); margin: 0;">Active</p>
</div>
""", unsafe_allow_html=True)
    with col3:
        quality = "Good" if fails < 5 else "Issues" if fails < 20 else "Poor"
        color = "#4CAF50" if fails < 5 else "#FF9800" if fails < 20 else "#f44336"
        st.markdown(f"""
<div class="metric-card">
<h3 style="color: #667eea; margin: 0;">✨ Quality</h3>
<p style="font-size: 1.5rem; color: {color}; margin: 0.5rem 0;">{quality}</p>
<p style="color: rgba(255,255,255,0.8); margin: 0;">{fails} issues</p>
</div>
""", unsafe_allow_html=True)
    if details:
        st.markdown(f"""
<div style="text-align: center; color: rgba(255,255,255,0.8);
background: rgba(255,255,255,0.1); padding: 1rem;
border-radius: 10px; margin: 1rem 0;">
{details}
</div>
""", unsafe_allow_html=True)
    if phase == PHASE_ERROR:
        st.error("❌ Setup encountered errors. You can restart the process.")
        if st.button("πŸ”„ Restart Setup"):
            # missing_ok: another session may already have removed it.
            PROGRESS_FILE.unlink(missing_ok=True)
            st.session_state.setup_started = False
            st.session_state.setup_thread = None
            st.rerun()
    elif phase == PHASE_3_COMPLETE:
        # Only the final phase means "done"; earlier phases also reach 100%
        # (e.g. downloads finished) while embeddings are still pending.
        st.success("πŸŽ‰ Setup completed successfully! Redirecting...")
        time.sleep(2)
        st.rerun()
    else:
        # Auto-refresh
        time.sleep(3)
        st.rerun()


def _render_start_screen():
    """Validate the CSV and offer the button that starts the background setup thread."""
    if not CSV_FILE.exists():
        st.error("❌ Required file missing: photos_url.csv")
        return
    urls = load_csv_urls()
    if not urls:
        st.error("❌ No valid URLs found in CSV file")
        return
    st.success(f"βœ… Found {len(urls):,} valid image URLs")
    st.markdown("""
<div class="glass-card">
<h3 style="color: white; text-align: center;">Setup Process</h3>
<p style="color: rgba(255,255,255,0.8); text-align: center;">
πŸ“₯ <strong>Phase 1:</strong> Download images in parallel<br>
🧠 <strong>Phase 2:</strong> Create visual embeddings<br>
βœ… <strong>Phase 3:</strong> Enable search functionality
</p>
</div>
""", unsafe_allow_html=True)
    _, center, _ = st.columns([1, 2, 1])
    with center:
        if st.button("πŸš€ Start Setup Process", type="primary", use_container_width=True):
            st.session_state.setup_started = True
            # Start background thread (it never calls Streamlit APIs).
            thread = threading.Thread(target=run_setup_background, daemon=True)
            thread.start()
            st.session_state.setup_thread = thread
            st.rerun()


def _render_setup_page():
    """Render the setup flow: live progress, a "starting" placeholder, or the start screen."""
    st.markdown(f"""
<div class="main-header">
<h1>πŸš€ Visual Search Setup</h1>
<p>Preparing to process {MAX_IMAGES} images</p>
</div>
""", unsafe_allow_html=True)
    progress_data = progress_tracker.read()
    if progress_data:
        _render_progress(progress_data)
    elif st.session_state.setup_started:
        # Thread launched but no snapshot written yet.
        st.info("⏳ Setup is starting...")
        time.sleep(2)
        st.rerun()
    else:
        _render_start_screen()


def main():
    """Main application - all session state access happens on this (script) thread.

    Configures the page, then shows the search UI when setup is complete,
    otherwise the setup/progress UI.
    """
    st.set_page_config(
        page_title="Visual Search System",
        page_icon="πŸ”",
        layout="wide",
        initial_sidebar_state="collapsed"
    )
    apply_styling()
    init_session_state()  # Safe - main thread only
    if is_setup_complete():
        _render_search_page()
    else:
        _render_setup_page()
# Build the page when this file runs as the main script (e.g. under `streamlit run`).
if __name__ == "__main__":
    main()