"""
Fashion RAG Pipeline - Assignment
Week 9: Multimodal RAG Pipeline with H&M Fashion Dataset
OBJECTIVE: Build a complete multimodal RAG (Retrieval-Augmented Generation) pipeline
that can search through fashion items using both text and image queries, then generate
helpful responses using an LLM.
LEARNING GOALS:
- Understand the three phases of RAG: Retrieval, Augmentation, Generation
- Work with multimodal data (images + text)
- Use vector databases for similarity search
- Integrate LLM for response generation
- Build an end-to-end AI pipeline
DATASET: H&M Fashion Caption Dataset
- 20K+ fashion items with images and text descriptions
- URL: https://huggingface.co/datasets/tomytjandra/h-and-m-fashion-caption
PIPELINE OVERVIEW:
1. RETRIEVAL: Find similar fashion items using vector search
2. AUGMENTATION: Create enhanced prompts with retrieved context
3. GENERATION: Generate helpful responses using LLM
Commands to run (this file is app.py):
python app.py --query "black dress for evening"
python app.py --app
"""
import argparse
import os
from random import sample
import re
# Suppress warnings
import warnings
from typing import Any, Dict, List, Optional, Tuple
# Gradio for web interface
import gradio as gr
# Core dependencies
import lancedb
import pandas as pd
import torch
from datasets import load_dataset
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from PIL import Image
# LLM dependencies
from transformers import AutoModelForCausalLM, AutoTokenizer
warnings.filterwarnings("ignore")
def is_huggingface_space():
    """
    Check whether the code is running inside a Hugging Face Spaces environment.

    Returns:
        bool: True if running in HF Spaces, False otherwise.
    """
    return os.environ.get("SYSTEM") == "spaces"
# =============================================================================
# SECTION 1: DATABASE SETUP AND SCHEMA
# =============================================================================
def register_embedding_model(model_name: str = "open-clip") -> Any:
"""
Register embedding model for vector search
TODO: Complete this function
HINT: Use EmbeddingFunctionRegistry to get and create the model
Args:
model_name: Name of the embedding model
Returns:
Embedding model instance
"""
# Get the registry instance
registry = EmbeddingFunctionRegistry.get_instance()
print(f"πŸ” Registering embedding model: {model_name}")
# Get and create the model
model = registry.get(model_name).create()
# Return the model
return model
# Global embedding model
clip_model = register_embedding_model()
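# Sanity-check sketch: the registered model exposes the embedding size used
# by the schema's Vector field below. The exact value depends on which
# open-clip checkpoint LanceDB resolves (512 for ViT-B/32-style defaults).
#   print(clip_model.ndims())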
class FashionItem(LanceModel):
"""
Schema for fashion items in vector database
TODO: Complete the schema definition
HINT: This defines the structure of data stored in the vector database
REQUIRED FIELDS:
1. vector: Vector field for CLIP embeddings (use clip_model.ndims())
2. image_uri: String field for image file paths
3. description: Optional string field for text descriptions
"""
# Add vector field for embeddings
vector: Vector(clip_model.ndims()) = clip_model.VectorField()
# Add image field
image_uri: str = clip_model.SourceField()
# Add text description field
description: Optional[str] = None
@property
def image(self):
if isinstance(self.image_uri, str) and os.path.exists(self.image_uri):
return Image.open(self.image_uri)
elif hasattr(self.image_uri, "save"): # PIL Image object
return self.image_uri
else:
            # Neither an existing file path nor a PIL image: nothing to show
return None
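# How this schema is used (a sketch based on the calls in this file): records
# inserted into the table carry only the source and metadata fields; LanceDB
# computes the CLIP embedding for `vector` from image_uri at ingestion time
# because image_uri is declared as the SourceField. A record therefore looks
# like (both values hypothetical):
#
#   example_record = {
#       "image_uri": "fashion_images/fashion_0000.jpg",
#       "description": "solid black jersey dress",
#   }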
# =============================================================================
# SECTION 2: RETRIEVAL - Vector Database Operations
# =============================================================================
def setup_fashion_database(
database_path: str = "fashion_db",
table_name: str = "fashion_items",
dataset_name: str = "tomytjandra/h-and-m-fashion-caption",
sample_size: int = 1000,
images_dir: str = "fashion_images",
) -> None:
"""
Set up vector database with H&M fashion dataset
Complete this function to:
1. Connect to LanceDB database
2. Check if table already exists (skip if it does)
3. Load H&M dataset from HuggingFace
4. Process and save images locally
5. Create vector database table
"""
print("πŸ”§ Setting up fashion database...")
print(f"Database path: {database_path}")
print(f"Dataset: {dataset_name}")
print(f"Sample size: {sample_size}")
# Connect to LanceDB
db = lancedb.connect(database_path)
    # Reuse the table if it already exists
    if table_name in db.table_names():
        existing_table = db.open_table(table_name)
        print(f"βœ… Table '{table_name}' already exists with {len(existing_table)} items")
        return
    print(f"πŸ—οΈ Table '{table_name}' does not exist, creating new fashion database...")
# Load dataset from HuggingFace
print("πŸ“₯ Loading H&M fashion dataset...")
dataset = load_dataset(dataset_name)
train_data = dataset["train"]
    # Randomly sample down to sample_size items
if len(train_data) > sample_size:
indices = sample(range(len(train_data)), sample_size)
train_data = train_data.select(indices)
print(f"Processing {len(train_data)} fashion items...")
# Create images directory
os.makedirs(images_dir, exist_ok=True)
# Process each item
table_data = []
for i, item in enumerate(train_data):
# Get image and text
image = item["image"]
text = item["text"]
# Save image
image_path = os.path.join(images_dir, f"fashion_{i:04d}.jpg")
image.save(image_path)
# Create record
record = {
"image_uri": image_path,
"description": text
}
table_data.append(record)
if (i + 1) % 100 == 0:
print(f" Processed {i + 1}/{len(train_data)} items...")
    # Create the vector database table
    if table_data:
        print("πŸ—„οΈ Creating vector database table...")
        # LanceDB embeds each image via the schema's SourceField at insert time
        db.create_table(
            table_name,
            schema=FashionItem,
            data=table_data,
        )
print(f"βœ… Created table '{table_name}' with {len(table_data)} items")
else:
print("❌ No data to create table, please check dataset loading")
print("πŸŽ‰ Fashion database setup complete!")
def search_fashion_items(
database_path: str,
table_name: str,
query: str,
search_type: str = "auto",
limit: int = 3,
) -> Tuple[List[FashionItem], str]:
    """
    Search for fashion items using a text or image query.

    Steps:
    1. Determine the search type ("auto" checks whether the query is a file path)
    2. Connect to the vector database and open the table
    3. Run a CLIP similarity search (PIL image for image queries, raw string
       for text queries)
    4. Return the results and the search type actually used

    Args:
        database_path: Path to the LanceDB database
        table_name: Name of the table to search
        query: Search query (text or image path)
        search_type: "auto", "text", or "image"
        limit: Number of results to return

    Returns:
        Tuple of (results_list, actual_search_type)
    """
print(f"πŸ” Searching for: {query}")
    # Determine the search type: honor an explicit "text"/"image" request;
    # with "auto", a query that is an existing file path is an image query
    if search_type in ("text", "image"):
        actual_search_type = search_type
    else:
        actual_search_type = "image" if os.path.exists(query) else "text"
    print(f" Detected search type: {actual_search_type}")
# Connect to database
db = lancedb.connect(database_path)
print(f"πŸ“‚ Connected to database: {database_path}")
# Open the table
table = db.open_table(table_name)
print(f"πŸ“– Opened table: {table_name}")
    # Perform the search based on the detected type
    if actual_search_type == "image":
        # Load the image; the table computes its CLIP embedding for the search
        image = Image.open(query)
        print(f" Searching with image: {query}")
        results = table.search(image).limit(limit).to_pydantic(FashionItem)
else:
# Text search
print(f" Searching with text: {query}")
results = table.search(query).limit(limit).to_pydantic(FashionItem)
# Print results found
print(f" Found {len(results)} results using {actual_search_type} search")
# Return results and search type
return results, actual_search_type
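# Hedged usage sketch (not called anywhere) showing both query modes of
# search_fashion_items(). It assumes setup_fashion_database() has already run
# with the default paths, so fashion_0000.jpg exists on disk.
def _example_search_modes() -> None:
    # Text query: any string that is not an existing file path
    items, mode = search_fashion_items("fashion_db", "fashion_items", "red summer dress")
    print(mode, [item.description for item in items])
    # Image query: an existing file path is auto-detected as an image search
    items, mode = search_fashion_items(
        "fashion_db", "fashion_items", "fashion_images/fashion_0000.jpg"
    )
    print(mode, len(items))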
# =============================================================================
# SECTION 3: AUGMENTATION - Prompt Engineering
# =============================================================================
def create_fashion_prompt(
    query: str, retrieved_items: List[FashionItem], search_type: str
) -> str:
    """
    Create an enhanced prompt for the LLM from the retrieved fashion items.

    Prompt structure:
    1. System prompt: defines the AI as a fashion assistant
    2. Context section: lists the retrieved items' descriptions
    3. Query section: restates the user's original query
    4. The combined prompt asks for fashion recommendations

    Args:
        query: Original user query
        retrieved_items: Retrieved fashion items
        search_type: Type of search performed ("text" or "image")

    Returns:
        Enhanced prompt string for the LLM
    """
    # System prompt: define the assistant's role and constraints
    system_prompt = (
        "You are a fashion assistant with expertise in clothing and accessories. "
        "Your task is to provide helpful fashion recommendations based on the "
        "user's query and the retrieved items. "
        "For each retrieved item, provide a helpful fashion recommendation. "
        "Be funny, creative, and engaging in your response. "
        "Talk only about the retrieved items and do not make up any information. "
        "If you do not have enough information, say so. "
        "Do not talk about anything else."
    )
print("πŸ“ Creating enhanced prompt...")
# Format retrieved items context
context = "Here are some relevant fashion items from our catalog:\n\n"
    for i, item in enumerate(retrieved_items, 1):
        print(f" Adding item {i}: {item.description[:60]}...")
        context += f"{i}. {item.description}\n\n"
    # Query section: phrase it differently for image vs. text searches
if search_type == "image":
query_section = (
f"User searched for an image: {query}\n"
"Please provide fashion recommendations based on the retrieved items and the image."
)
else:
query_section = (
f"User query: {query}\n"
"Please provide fashion recommendations based on the retrieved items and the query."
)
print(f" Query section created: {query_section[:60]}...")
    # Combine the system prompt, context, and query section into the final prompt
    prompt = f"{system_prompt}\n\n{context}\n{query_section}"
    return prompt
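# Hedged sketch (not called anywhere): inspect the assembled prompt without a
# database. SimpleNamespace stands in for FashionItem purely for illustration;
# any object with a .description attribute works here.
def _example_prompt_assembly() -> None:
    from types import SimpleNamespace

    stub_items = [
        SimpleNamespace(description="solid black jersey dress with long sleeves"),
        SimpleNamespace(description="fitted satin evening dress in dark navy"),
    ]
    print(create_fashion_prompt("black dress for evening", stub_items, "text"))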
# =============================================================================
# SECTION 4: GENERATION - LLM Response Generation
# =============================================================================
def setup_llm_model(model_name: str = "Qwen/Qwen2.5-0.5B-Instruct") -> Tuple[Any, Any]:
"""
Set up LLM model and tokenizer
Complete this function to load the LLM model and tokenizer
STEPS TO IMPLEMENT:
1. Load tokenizer
2. Load model
3. Configure model settings for GPU/CPU
5. Return tokenizer and model
Args:
model_name: Name of the model to load
Returns:
Tuple of (tokenizer, model)
"""
print(f"πŸ€– Loading LLM model: {model_name}")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
print(" Tokenizer loaded successfully")
# Load model
model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=torch.float32, device_map="cpu"
)
    # Some tokenizers (including Qwen's) define no pad token; padding during
    # tokenization requires one, so reuse the EOS token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
# Print success message and return
print("βœ… LLM model loaded successfully")
return tokenizer, model
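# Optional refinement (a sketch, not used by this pipeline): instruct-tuned
# models such as Qwen2.5-Instruct usually respond better when prompted via the
# tokenizer's chat template instead of raw concatenated text.
def _example_chat_prompt(tokenizer: Any, user_prompt: str) -> str:
    messages = [{"role": "user", "content": user_prompt}]
    # Wraps the message in the model's chat markup and appends the marker
    # that tells the model to begin an assistant turn
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )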
def generate_fashion_response(
prompt: str, tokenizer: Any, model: Any, max_tokens: int = 200
) -> str:
"""
Generate response using LLM
Complete this function to generate text using the LLM
STEPS TO IMPLEMENT:
1. Check if tokenizer and model are loaded
2. Encode the prompt with attention mask
3. Generate response using model.generate()
4. Decode the response and clean it up
5. Return the generated text
Args:
prompt: Input prompt for the model
tokenizer: Model tokenizer
model: LLM model
max_tokens: Maximum tokens to generate
Returns:
Generated response text
"""
if not tokenizer or not model:
return "⚠️ LLM not loaded - showing search results only"
    # Encode the prompt with an attention mask, truncating overly long prompts
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=2048, padding=True
    )
    # Keep all tensors on CPU to match the model's device
    inputs = {k: v.to("cpu") for k, v in inputs.items()}
    # Generate the response
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens; string-replacing the prompt out
    # of the full decode is fragile when truncation or detokenization alters it
    generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
return response
# =============================================================================
# SECTION 5: IMAGE STORAGE
# =============================================================================
def save_retrieved_images(
results: Dict[str, Any], output_dir: str = "retrieved_fashion_images"
) -> List[str]:
"""Save retrieved fashion images to output directory"""
# Create output directory
os.makedirs(output_dir, exist_ok=True)
query_safe = re.sub(r"[^\w\s-]", "", str(results["query"]))[:30]
query_safe = re.sub(r"[-\s]+", "_", query_safe)
saved_paths = []
print(f"πŸ’Ύ Saving {len(results['results'])} retrieved images...")
for i, item in enumerate(results["results"], 1):
original_path = item.image_uri
image = Image.open(original_path)
# Generate new filename
filename = f"{query_safe}_result_{i:02d}.jpg"
save_path = os.path.join(output_dir, filename)
# Save image
image.save(save_path, "JPEG", quality=95)
saved_paths.append(save_path)
print(f" βœ… Saved image {i}: {filename}")
print(f" Description: {item.description[:60]}...")
print(f"πŸ’Ύ Saved {len(saved_paths)} images to: {output_dir}")
return saved_paths
# =============================================================================
# SECTION 6: COMPLETE RAG PIPELINE
# =============================================================================
def run_fashion_rag_pipeline(
query: str,
database_path: str = "fashion_db",
table_name: str = "fashion_items",
search_type: str = "auto",
limit: int = 3,
save_images: bool = True,
) -> Dict[str, Any]:
"""
Run complete fashion RAG pipeline
Complete this function to orchestrate the entire pipeline:
1. RETRIEVAL: Search for relevant fashion items using vector database
2. AUGMENTATION: Create enhanced prompt with retrieved context
3. GENERATION: Generate LLM response using the enhanced prompt
4. IMAGE STORAGE: Save retrieved images if requested
This is the main function that ties everything together!
PIPELINE PHASES:
Phase 1 - RETRIEVAL: Find similar fashion items
Phase 2 - AUGMENTATION: Create context-rich prompt
Phase 3 - GENERATION: Generate helpful response
Phase 4 - STORAGE: Save retrieved images
"""
print("πŸš€ Starting Fashion RAG Pipeline")
print("=" * 50)
# PHASE 1: RETRIEVAL
print("πŸ” PHASE 1: RETRIEVAL")
    # Search for fashion items in the vector database
results, actual_search_type = search_fashion_items(
database_path=database_path,
table_name=table_name,
query=query,
search_type=search_type,
limit=limit,
)
print(f" Found {len(results)} relevant items")
print(f" Search type used: {actual_search_type}")
# PHASE 2: AUGMENTATION
print("πŸ“ PHASE 2: AUGMENTATION")
    # Create an enhanced prompt from the retrieved items
enhanced_prompt = create_fashion_prompt(
query=query,
retrieved_items=results,
search_type=actual_search_type,
)
print(f" Created enhanced prompt ({len(enhanced_prompt)} chars)")
# PHASE 3: GENERATION
print("πŸ€– PHASE 3: GENERATION")
# Set up LLM and generate response
tokenizer, model = setup_llm_model()
if not tokenizer or not model:
print("⚠️ LLM not loaded - skipping response generation")
response = "⚠️ LLM not available"
else:
# Generate response using the enhanced prompt
response = generate_fashion_response(
prompt=enhanced_prompt,
tokenizer=tokenizer,
model=model,
max_tokens=200,
)
print(f" Generated response ({len(response)} chars)")
# Prepare final results dictionary
final_results = {
"query": query,
"results": results,
"response": response,
"search_type": actual_search_type
}
# Save retrieved images if requested
if save_images:
saved_image_paths = save_retrieved_images(final_results)
final_results["saved_image_paths"] = saved_image_paths
# Return final results
return final_results
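# Hedged usage sketch (not called anywhere): run the full pipeline from Python
# rather than the CLI, assuming setup_fashion_database() has been run once.
def _example_pipeline_call() -> None:
    result = run_fashion_rag_pipeline("black dress for evening", save_images=False)
    print(result["search_type"], len(result["results"]))
    print(result["response"])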
# =============================================================================
# GRADIO WEB APP
# =============================================================================
def fashion_search_app(query):
"""
Process fashion query and return response with images for Gradio
Complete this function to handle web app queries
STEPS TO IMPLEMENT:
1. Check if query is provided
2. Setup database if needed
3. Run RAG pipeline
4. Extract LLM response and images
5. Return formatted results for Gradio
"""
if not query.strip():
return "Please enter a search query", []
# Setup database if needed (will skip if exists)
print("πŸ”§ Checking/setting up fashion database...")
setup_fashion_database()
# Run the RAG pipeline
result = run_fashion_rag_pipeline(
query=query,
database_path="fashion_db",
table_name="fashion_items",
search_type="auto",
limit=3,
save_images=True,
)
print("🎯 RAG pipeline completed")
# Get LLM response
llm_response = result['response']
print(f"πŸ€– LLM Response: {llm_response[:60]}...")
# Get retrieved images for display
retrieved_images = []
for item in result['results']:
if os.path.exists(item.image_uri):
img = Image.open(item.image_uri)
retrieved_images.append(img)
# Return response and images
return llm_response, retrieved_images
def launch_gradio_app():
"""Launch the Gradio web interface"""
# Create Gradio interface
with gr.Blocks(title="Fashion RAG Assistant") as app:
gr.Markdown("# πŸ‘— Fashion RAG Assistant")
gr.Markdown("Search for fashion items and get AI-powered recommendations!")
with gr.Row():
with gr.Column(scale=1):
# Input
query_input = gr.Textbox(
label="Search Query",
placeholder="Enter your fashion query (e.g., 'black dress for evening')",
lines=2,
)
search_btn = gr.Button("Search", variant="primary")
# Examples
gr.Examples(
examples=[
"black dress for evening",
"casual summer outfit",
"blue jeans",
"white shirt",
"winter jacket",
],
inputs=query_input,
)
with gr.Column(scale=2):
# Output
response_output = gr.Textbox(
label="Fashion Recommendation", lines=10, interactive=True, autoscroll=True
)
# Retrieved Images
images_output = gr.Gallery(
label="Retrieved Fashion Items", columns=3, height=400
)
# Connect the search function
search_btn.click(
fn=fashion_search_app,
inputs=query_input,
outputs=[response_output, images_output],
)
# Also trigger on Enter key
query_input.submit(
fn=fashion_search_app,
inputs=query_input,
outputs=[response_output, images_output],
)
print("πŸš€ Starting Fashion RAG Gradio App...")
print("πŸ“ Note: First run will download dataset and setup database")
app.launch(share=True)
# =============================================================================
# MAIN EXECUTION
# =============================================================================
def main():
"""Main function to handle command line arguments and run the pipeline"""
# If running in Hugging Face Spaces, automatically launch the app
if is_huggingface_space():
print("πŸ€— Running in Hugging Face Spaces - launching web app automatically")
launch_gradio_app()
return
parser = argparse.ArgumentParser(
description="Fashion RAG Pipeline Assignment - SOLUTION"
)
parser.add_argument("--query", type=str, help="Search query (text or image path)")
parser.add_argument("--app", action="store_true", help="Launch Gradio web app")
args = parser.parse_args()
# Launch web app if requested
if args.app:
launch_gradio_app()
return
if not args.query:
print("❌ Please provide a query with --query or use --app for web interface")
print("Examples:")
print(" python solution_fashion_rag.py --query 'black dress for evening'")
print(" python solution_fashion_rag.py --query 'fashion_images/dress.jpg'")
print(" python solution_fashion_rag.py --app")
return
# Setup database first (will skip if already exists)
print("πŸ”§ Checking/setting up fashion database...")
setup_fashion_database()
# Run the complete RAG pipeline with default settings
result = run_fashion_rag_pipeline(
query=args.query,
database_path="fashion_db",
table_name="fashion_items",
search_type="auto",
limit=3,
save_images=True,
)
# Display results
print("\n" + "=" * 50)
print("🎯 PIPELINE RESULTS")
print("=" * 50)
print(f"Query: {result['query']}")
print(f"Search Type: {result['search_type']}")
print(f"Results Found: {len(result['results'])}")
print("\nπŸ“‹ Retrieved Items:")
for i, item in enumerate(result["results"], 1):
print(f"{i}. {item.description}")
print(f"\nπŸ€– LLM Response:")
print(result["response"])
# Show saved images info if any
if result.get("saved_image_paths"):
print(f"\nπŸ“Έ Saved Images:")
for i, path in enumerate(result["saved_image_paths"], 1):
print(f"{i}. {path}")
if __name__ == "__main__":
main()