""" | |
Fashion RAG Pipeline - Assignment | |
Week 9: Multimodal RAG Pipeline with H&M Fashion Dataset | |
OBJECTIVE: Build a complete multimodal RAG (Retrieval-Augmented Generation) pipeline | |
that can search through fashion items using both text and image queries, then generate | |
helpful responses using an LLM. | |
LEARNING GOALS: | |
- Understand the three phases of RAG: Retrieval, Augmentation, Generation | |
- Work with multimodal data (images + text) | |
- Use vector databases for similarity search | |
- Integrate LLM for response generation | |
- Build an end-to-end AI pipeline | |
DATASET: H&M Fashion Caption Dataset | |
- 20K+ fashion items with images and text descriptions | |
- URL: https://huggingface.co/datasets/tomytjandra/h-and-m-fashion-caption | |
PIPELINE OVERVIEW: | |
1. RETRIEVAL: Find similar fashion items using vector search | |
2. AUGMENTATION: Create enhanced prompts with retrieved context | |
3. GENERATION: Generate helpful responses using LLM | |
Commands to run: | |
python assignment_fashion_rag.py --query "black dress for evening" | |
python assignment_fashion_rag.py --app | |
""" | |
import argparse
import os
import re
from random import sample

# Suppress warnings
import warnings
from typing import Any, Dict, List, Optional, Tuple

# Gradio for web interface
import gradio as gr

# Core dependencies
import lancedb
import torch
from datasets import load_dataset
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from PIL import Image

# LLM dependencies
from transformers import AutoModelForCausalLM, AutoTokenizer

warnings.filterwarnings("ignore")
def is_huggingface_space():
    """
    Check whether the code is running inside a Hugging Face Spaces environment.

    Returns:
        bool: True if running in HF Spaces, False otherwise.
    """
    return os.environ.get("SYSTEM") == "spaces"

# =============================================================================
# SECTION 1: DATABASE SETUP AND SCHEMA
# =============================================================================

def register_embedding_model(model_name: str = "open-clip") -> Any:
    """
    Register the embedding model used for vector search.

    Args:
        model_name: Name of the embedding model

    Returns:
        Embedding model instance
    """
    # Get the registry instance and create the named model from it
    registry = EmbeddingFunctionRegistry.get_instance()
    print(f"🔧 Registering embedding model: {model_name}")
    model = registry.get(model_name).create()
    return model
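
# Usage sketch (not executed): the registry lookup below mirrors the function
# above and assumes LanceDB's built-in "open-clip" embedding function, which
# requires the `open-clip-torch` package to be installed.
#
#   registry = EmbeddingFunctionRegistry.get_instance()
#   clip = registry.get("open-clip").create()
#   print(clip.ndims())  # embedding width, e.g. 512 for a ViT-B/32 backbone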

# Global embedding model
clip_model = register_embedding_model()


class FashionItem(LanceModel):
    """
    Schema for fashion items in the vector database.

    Fields:
    1. vector: CLIP embedding (width clip_model.ndims())
    2. image_uri: path to the image file (the embedding source)
    3. description: optional text description
    """

    vector: Vector(clip_model.ndims()) = clip_model.VectorField()
    image_uri: str = clip_model.SourceField()
    description: Optional[str] = None

    def image(self):
        """Return the item's image as a PIL Image, or None if unavailable."""
        if isinstance(self.image_uri, str) and os.path.exists(self.image_uri):
            return Image.open(self.image_uri)
        elif hasattr(self.image_uri, "save"):  # already a PIL Image object
            return self.image_uri
        else:
            return None
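
# How the schema drives embedding (sketch, not executed): because image_uri is
# declared as the SourceField and vector as the VectorField, LanceDB computes
# each row's CLIP embedding automatically at insert time, so rows only need
# the source fields. The file path below is a hypothetical example.
#
#   db = lancedb.connect("fashion_db")
#   tbl = db.create_table("demo", schema=FashionItem)
#   tbl.add([{"image_uri": "fashion_images/fashion_0000.jpg",
#             "description": "solid black jersey dress"}])  # vector filled in by CLIP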

# =============================================================================
# SECTION 2: RETRIEVAL - Vector Database Operations
# =============================================================================

def setup_fashion_database(
    database_path: str = "fashion_db",
    table_name: str = "fashion_items",
    dataset_name: str = "tomytjandra/h-and-m-fashion-caption",
    sample_size: int = 1000,
    images_dir: str = "fashion_images",
) -> None:
    """
    Set up the vector database with the H&M fashion dataset.

    Steps:
    1. Connect to the LanceDB database
    2. Skip setup if the table already exists
    3. Load the H&M dataset from Hugging Face
    4. Save the images locally
    5. Create the vector database table
    """
    print("🔧 Setting up fashion database...")
    print(f"Database path: {database_path}")
    print(f"Dataset: {dataset_name}")
    print(f"Sample size: {sample_size}")

    # Connect to LanceDB
    db = lancedb.connect(database_path)

    # Skip setup if the table already exists
    if table_name in db.table_names():
        existing_table = db.open_table(table_name)
        print(f"✅ Table '{table_name}' already exists with {existing_table.count_rows()} items")
        return
    print(f"🏗️ Table '{table_name}' does not exist, creating new fashion database...")

    # Load dataset from Hugging Face
    print("📥 Loading H&M fashion dataset...")
    dataset = load_dataset(dataset_name)
    train_data = dataset["train"]

    # Randomly sample down to sample_size items
    if len(train_data) > sample_size:
        indices = sample(range(len(train_data)), sample_size)
        train_data = train_data.select(indices)
    print(f"Processing {len(train_data)} fashion items...")

    # Create images directory
    os.makedirs(images_dir, exist_ok=True)

    # Save each image to disk and build the table records
    table_data = []
    for i, item in enumerate(train_data):
        image = item["image"]
        text = item["text"]

        image_path = os.path.join(images_dir, f"fashion_{i:04d}.jpg")
        image.save(image_path)

        table_data.append({"image_uri": image_path, "description": text})

        if (i + 1) % 100 == 0:
            print(f"  Processed {i + 1}/{len(train_data)} items...")

    # Create the vector database table; LanceDB computes the CLIP embeddings
    # automatically because image_uri is the schema's SourceField
    if table_data:
        print("🏗️ Creating vector database table...")
        table = db.create_table(table_name, schema=FashionItem, data=table_data)
        print(f"✅ Created table '{table_name}' with {len(table_data)} items")
    else:
        print("❌ No data to create table, please check dataset loading")

    print("🎉 Fashion database setup complete!")

def search_fashion_items(
    database_path: str,
    table_name: str,
    query: str,
    search_type: str = "auto",
    limit: int = 3,
) -> Tuple[List[FashionItem], str]:
    """
    Search for fashion items using a text or image query.

    Steps:
    1. Determine whether the query is text or an image path ("auto" mode)
    2. Connect to the vector database and open the table
    3. Perform a similarity search using CLIP embeddings
    4. Return the results and the search type actually used

    Args:
        database_path: Path to the LanceDB database
        table_name: Name of the table to search
        query: Search query (text or image path)
        search_type: "auto", "text", or "image"
        limit: Number of results to return

    Returns:
        Tuple of (results_list, actual_search_type)
    """
    print(f"🔍 Searching for: {query}")

    # Determine the search type: in "auto" mode, a query that is an existing
    # file path means an image search; anything else is a text search
    if search_type == "auto":
        actual_search_type = "image" if os.path.exists(query) else "text"
    else:
        actual_search_type = search_type
    print(f"  Detected search type: {actual_search_type}")

    # Connect to the database and open the table
    db = lancedb.connect(database_path)
    print(f"📁 Connected to database: {database_path}")
    table = db.open_table(table_name)
    print(f"📋 Opened table: {table_name}")

    # Perform the search; LanceDB embeds the query (PIL image or string)
    # with the registered CLIP model before running the vector search
    if actual_search_type == "image":
        image = Image.open(query)
        print(f"  Searching with image: {query}")
        results = table.search(image).limit(limit).to_pydantic(FashionItem)
    else:
        print(f"  Searching with text: {query}")
        results = table.search(query).limit(limit).to_pydantic(FashionItem)

    print(f"  Found {len(results)} results using {actual_search_type} search")
    return results, actual_search_type
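
# Usage sketch (not executed): the same function serves both query modes;
# "dress.jpg" is a hypothetical local file used to trigger the image path.
#
#   items, kind = search_fashion_items("fashion_db", "fashion_items",
#                                      "black dress for evening")  # kind == "text"
#   items, kind = search_fashion_items("fashion_db", "fashion_items",
#                                      "dress.jpg")  # kind == "image" if the file exists
#   for it in items:
#       print(it.image_uri, it.description)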

# =============================================================================
# SECTION 3: AUGMENTATION - Prompt Engineering
# =============================================================================

def create_fashion_prompt(
    query: str, retrieved_items: List[FashionItem], search_type: str
) -> str:
    """
    Create an enhanced prompt for the LLM using retrieved fashion items.

    Prompt structure:
    1. System prompt: define the AI as a fashion assistant
    2. Context section: list the retrieved fashion items with descriptions
    3. Query section: include the user's original query
    4. Instruction: ask for fashion recommendations

    Args:
        query: Original user query
        retrieved_items: List of retrieved fashion items
        search_type: Type of search performed

    Returns:
        Enhanced prompt string for the LLM
    """
    # System prompt: define the AI as a fashion assistant with expertise
    system_prompt = (
        "You are a fashion assistant with expertise in clothing and accessories. "
        "Your task is to provide helpful fashion recommendations based on user queries "
        "and retrieved items. For each retrieved item, provide a helpful fashion "
        "recommendation. Be funny, creative, and engaging in your response. "
        "Talk only about the retrieved items and do not make up any information. "
        "If you do not have enough information, say so. Do not talk about anything else."
    )
    print("📝 Creating enhanced prompt...")

    # Format the retrieved items as context
    context = "Here are some relevant fashion items from our catalog:\n\n"
    for i, item in enumerate(retrieved_items, 1):
        print(f"  Adding item {i}: {str(item.description)[:60]}...")
        context += f"{i}. {item.description}\n\n"

    # Create the user query section, handling image vs. text searches
    if search_type == "image":
        query_section = (
            f"User searched for an image: {query}\n"
            "Please provide fashion recommendations based on the retrieved items and the image."
        )
    else:
        query_section = (
            f"User query: {query}\n"
            "Please provide fashion recommendations based on the retrieved items and the query."
        )
    print(f"  Query section created: {query_section[:60]}...")

    # Combine system prompt, context, and query section into the final prompt
    prompt = f"{system_prompt}\n\n{context}\n{query_section}\n\n"
    return prompt
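
# Resulting prompt shape (illustrative; the item text is made up):
#
#   You are a fashion assistant with expertise in ... Do not talk about anything else.
#
#   Here are some relevant fashion items from our catalog:
#
#   1. Solid black jersey dress with long sleeves...
#   2. ...
#
#   User query: black dress for evening
#   Please provide fashion recommendations based on the retrieved items and the query.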

# =============================================================================
# SECTION 4: GENERATION - LLM Response Generation
# =============================================================================

def setup_llm_model(model_name: str = "Qwen/Qwen2.5-0.5B-Instruct") -> Tuple[Any, Any]:
    """
    Set up the LLM model and tokenizer.

    Steps:
    1. Load the tokenizer
    2. Load the model
    3. Configure the model for CPU execution
    4. Return the tokenizer and model

    Args:
        model_name: Name of the model to load

    Returns:
        Tuple of (tokenizer, model)
    """
    print(f"🤖 Loading LLM model: {model_name}")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    print("  Tokenizer loaded successfully")

    # Load model on CPU in float32
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float32, device_map="cpu"
    )

    # Many causal LMs define no pad token; padding during tokenization
    # requires one, so reuse the EOS token as the pad token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("✅ LLM model loaded successfully")
    return tokenizer, model
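
# Optional refinement (sketch, not used here): instruct-tuned models such as
# Qwen2.5-Instruct are trained with a chat template, so wrapping the prompt
# via apply_chat_template usually yields better-formatted answers than raw
# text; `enhanced_prompt` stands in for the prompt built by the pipeline.
#
#   messages = [{"role": "user", "content": enhanced_prompt}]
#   chat_prompt = tokenizer.apply_chat_template(
#       messages, tokenize=False, add_generation_prompt=True
#   )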

def generate_fashion_response(
    prompt: str, tokenizer: Any, model: Any, max_tokens: int = 200
) -> str:
    """
    Generate a response using the LLM.

    Steps:
    1. Check that the tokenizer and model are loaded
    2. Encode the prompt with an attention mask
    3. Generate a response with model.generate()
    4. Decode only the newly generated tokens
    5. Return the generated text

    Args:
        prompt: Input prompt for the model
        tokenizer: Model tokenizer
        model: LLM model
        max_tokens: Maximum tokens to generate

    Returns:
        Generated response text
    """
    if not tokenizer or not model:
        return "⚠️ LLM not loaded - showing search results only"

    # Encode the prompt with truncation and padding (returns input_ids and
    # attention_mask)
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=2048, padding=True
    )

    # Ensure everything runs on CPU
    inputs = {k: v.to("cpu") for k, v in inputs.items()}

    # Generate the response
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the tokens generated after the prompt; slicing by input
    # length is more robust than string-replacing the prompt out of the
    # decoded text
    generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
    return response
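
# Usage sketch (not executed), assuming setup_llm_model() succeeded:
#
#   tokenizer, model = setup_llm_model()
#   text = generate_fashion_response("Recommend a summer outfit.", tokenizer, model)
#   print(text)
#
# With do_sample=True and temperature=0.7 the output varies between runs; set
# do_sample=False inside the function for deterministic (greedy) decoding.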

# =============================================================================
# SECTION 5: IMAGE STORAGE
# =============================================================================

def save_retrieved_images(
    results: Dict[str, Any], output_dir: str = "retrieved_fashion_images"
) -> List[str]:
    """Save retrieved fashion images to the output directory."""
    # Create the output directory and build a filesystem-safe query prefix
    os.makedirs(output_dir, exist_ok=True)
    query_safe = re.sub(r"[^\w\s-]", "", str(results["query"]))[:30]
    query_safe = re.sub(r"[-\s]+", "_", query_safe)

    saved_paths = []
    print(f"💾 Saving {len(results['results'])} retrieved images...")
    for i, item in enumerate(results["results"], 1):
        image = Image.open(item.image_uri)

        # Generate a new filename and save the image
        filename = f"{query_safe}_result_{i:02d}.jpg"
        save_path = os.path.join(output_dir, filename)
        image.save(save_path, "JPEG", quality=95)
        saved_paths.append(save_path)

        print(f"  ✅ Saved image {i}: {filename}")
        print(f"     Description: {str(item.description)[:60]}...")

    print(f"💾 Saved {len(saved_paths)} images to: {output_dir}")
    return saved_paths
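
# Example output layout (illustrative): for the query "black dress for
# evening", the saved files would be named
#
#   retrieved_fashion_images/black_dress_for_evening_result_01.jpg
#   retrieved_fashion_images/black_dress_for_evening_result_02.jpg
#   retrieved_fashion_images/black_dress_for_evening_result_03.jpg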

# =============================================================================
# SECTION 6: COMPLETE RAG PIPELINE
# =============================================================================

def run_fashion_rag_pipeline(
    query: str,
    database_path: str = "fashion_db",
    table_name: str = "fashion_items",
    search_type: str = "auto",
    limit: int = 3,
    save_images: bool = True,
) -> Dict[str, Any]:
    """
    Run the complete fashion RAG pipeline. This is the main function that ties
    everything together:

    Phase 1 - RETRIEVAL: Find similar fashion items via the vector database
    Phase 2 - AUGMENTATION: Create a context-rich prompt from the results
    Phase 3 - GENERATION: Generate a helpful LLM response
    Phase 4 - STORAGE: Save the retrieved images if requested
    """
    print("🚀 Starting Fashion RAG Pipeline")
    print("=" * 50)

    # PHASE 1: RETRIEVAL
    print("🔍 PHASE 1: RETRIEVAL")
    results, actual_search_type = search_fashion_items(
        database_path=database_path,
        table_name=table_name,
        query=query,
        search_type=search_type,
        limit=limit,
    )
    print(f"  Found {len(results)} relevant items")
    print(f"  Search type used: {actual_search_type}")

    # PHASE 2: AUGMENTATION
    print("📝 PHASE 2: AUGMENTATION")
    enhanced_prompt = create_fashion_prompt(
        query=query,
        retrieved_items=results,
        search_type=actual_search_type,
    )
    print(f"  Created enhanced prompt ({len(enhanced_prompt)} chars)")

    # PHASE 3: GENERATION
    print("🤖 PHASE 3: GENERATION")
    tokenizer, model = setup_llm_model()
    if not tokenizer or not model:
        print("⚠️ LLM not loaded - skipping response generation")
        response = "⚠️ LLM not available"
    else:
        response = generate_fashion_response(
            prompt=enhanced_prompt,
            tokenizer=tokenizer,
            model=model,
            max_tokens=200,
        )
        print(f"  Generated response ({len(response)} chars)")

    # Prepare the final results dictionary
    final_results = {
        "query": query,
        "results": results,
        "response": response,
        "search_type": actual_search_type,
    }

    # PHASE 4: STORAGE - save the retrieved images if requested
    if save_images:
        final_results["saved_image_paths"] = save_retrieved_images(final_results)

    return final_results
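
# End-to-end usage sketch (not executed):
#
#   result = run_fashion_rag_pipeline("black dress for evening")
#   print(result["search_type"])  # "text"
#   print(result["response"])     # LLM recommendation
#   for item in result["results"]:
#       print(item.description)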

# =============================================================================
# GRADIO WEB APP
# =============================================================================

def fashion_search_app(query):
    """
    Process a fashion query and return the response plus images for Gradio.

    Steps:
    1. Check that a query was provided
    2. Set up the database if needed
    3. Run the RAG pipeline
    4. Extract the LLM response and retrieved images
    5. Return formatted results for Gradio
    """
    if not query or not query.strip():
        return "Please enter a search query", []

    # Set up the database if needed (skipped if it already exists)
    print("🔧 Checking/setting up fashion database...")
    setup_fashion_database()

    # Run the RAG pipeline
    result = run_fashion_rag_pipeline(
        query=query,
        database_path="fashion_db",
        table_name="fashion_items",
        search_type="auto",
        limit=3,
        save_images=True,
    )
    print("🎯 RAG pipeline completed")

    # Get the LLM response
    llm_response = result["response"]
    print(f"🤖 LLM Response: {llm_response[:60]}...")

    # Collect the retrieved images for display
    retrieved_images = []
    for item in result["results"]:
        if os.path.exists(item.image_uri):
            retrieved_images.append(Image.open(item.image_uri))

    return llm_response, retrieved_images

def launch_gradio_app():
    """Launch the Gradio web interface."""
    with gr.Blocks(title="Fashion RAG Assistant") as app:
        gr.Markdown("# 👗 Fashion RAG Assistant")
        gr.Markdown("Search for fashion items and get AI-powered recommendations!")

        with gr.Row():
            with gr.Column(scale=1):
                # Input
                query_input = gr.Textbox(
                    label="Search Query",
                    placeholder="Enter your fashion query (e.g., 'black dress for evening')",
                    lines=2,
                )
                search_btn = gr.Button("Search", variant="primary")

                # Examples
                gr.Examples(
                    examples=[
                        "black dress for evening",
                        "casual summer outfit",
                        "blue jeans",
                        "white shirt",
                        "winter jacket",
                    ],
                    inputs=query_input,
                )

            with gr.Column(scale=2):
                # Output
                response_output = gr.Textbox(
                    label="Fashion Recommendation", lines=10, interactive=True, autoscroll=True
                )

                # Retrieved images
                images_output = gr.Gallery(
                    label="Retrieved Fashion Items", columns=3, height=400
                )

        # Connect the search function to the button
        search_btn.click(
            fn=fashion_search_app,
            inputs=query_input,
            outputs=[response_output, images_output],
        )

        # Also trigger on Enter key
        query_input.submit(
            fn=fashion_search_app,
            inputs=query_input,
            outputs=[response_output, images_output],
        )

    print("🚀 Starting Fashion RAG Gradio App...")
    print("📝 Note: First run will download the dataset and set up the database")
    app.launch(share=True)

# =============================================================================
# MAIN EXECUTION
# =============================================================================

def main():
    """Handle command line arguments and run the pipeline."""
    # If running in Hugging Face Spaces, launch the web app automatically
    if is_huggingface_space():
        print("🤗 Running in Hugging Face Spaces - launching web app automatically")
        launch_gradio_app()
        return

    parser = argparse.ArgumentParser(description="Fashion RAG Pipeline Assignment")
    parser.add_argument("--query", type=str, help="Search query (text or image path)")
    parser.add_argument("--app", action="store_true", help="Launch Gradio web app")
    args = parser.parse_args()

    # Launch the web app if requested
    if args.app:
        launch_gradio_app()
        return

    if not args.query:
        print("❌ Please provide a query with --query or use --app for the web interface")
        print("Examples:")
        print("  python assignment_fashion_rag.py --query 'black dress for evening'")
        print("  python assignment_fashion_rag.py --query 'fashion_images/dress.jpg'")
        print("  python assignment_fashion_rag.py --app")
        return

    # Set up the database first (skipped if it already exists)
    print("🔧 Checking/setting up fashion database...")
    setup_fashion_database()

    # Run the complete RAG pipeline with default settings
    result = run_fashion_rag_pipeline(
        query=args.query,
        database_path="fashion_db",
        table_name="fashion_items",
        search_type="auto",
        limit=3,
        save_images=True,
    )

    # Display the results
    print("\n" + "=" * 50)
    print("🎯 PIPELINE RESULTS")
    print("=" * 50)
    print(f"Query: {result['query']}")
    print(f"Search Type: {result['search_type']}")
    print(f"Results Found: {len(result['results'])}")

    print("\n📋 Retrieved Items:")
    for i, item in enumerate(result["results"], 1):
        print(f"{i}. {item.description}")

    print("\n🤖 LLM Response:")
    print(result["response"])

    # Show saved image info, if any
    if result.get("saved_image_paths"):
        print("\n📸 Saved Images:")
        for i, path in enumerate(result["saved_image_paths"], 1):
            print(f"{i}. {path}")


if __name__ == "__main__":
    main()