""" Fashion RAG Pipeline - Assignment Week 9: Multimodal RAG Pipeline with H&M Fashion Dataset OBJECTIVE: Build a complete multimodal RAG (Retrieval-Augmented Generation) pipeline that can search through fashion items using both text and image queries, then generate helpful responses using an LLM. LEARNING GOALS: - Understand the three phases of RAG: Retrieval, Augmentation, Generation - Work with multimodal data (images + text) - Use vector databases for similarity search - Integrate LLM for response generation - Build an end-to-end AI pipeline DATASET: H&M Fashion Caption Dataset - 20K+ fashion items with images and text descriptions - URL: https://huggingface.co/datasets/tomytjandra/h-and-m-fashion-caption PIPELINE OVERVIEW: 1. RETRIEVAL: Find similar fashion items using vector search 2. AUGMENTATION: Create enhanced prompts with retrieved context 3. GENERATION: Generate helpful responses using LLM Commands to run: python assignment_fashion_rag.py --query "black dress for evening" python assignment_fashion_rag.py --app """ import argparse import os from random import sample import re # Suppress warnings import warnings from typing import Any, Dict, List, Optional, Tuple # Gradio for web interface import gradio as gr # Core dependencies import lancedb import pandas as pd import torch from datasets import load_dataset from lancedb.embeddings import EmbeddingFunctionRegistry from lancedb.pydantic import LanceModel, Vector from PIL import Image # LLM dependencies from transformers import AutoModelForCausalLM, AutoTokenizer warnings.filterwarnings("ignore") def is_huggingface_space(): """ Checks if the code is running within a Hugging Face Spaces environment. Returns: bool: True if running in HF Spaces, False otherwise. """ if os.environ.get("SYSTEM") == "spaces": return True else: return False # ============================================================================= # SECTION 1: DATABASE SETUP AND SCHEMA # ============================================================================= def register_embedding_model(model_name: str = "open-clip") -> Any: """ Register embedding model for vector search TODO: Complete this function HINT: Use EmbeddingFunctionRegistry to get and create the model Args: model_name: Name of the embedding model Returns: Embedding model instance """ # Get the registry instance registry = EmbeddingFunctionRegistry.get_instance() print(f"šŸ” Registering embedding model: {model_name}") # Get and create the model model = registry.get(model_name).create() # Return the model return model # Global embedding model clip_model = register_embedding_model() class FashionItem(LanceModel): """ Schema for fashion items in vector database TODO: Complete the schema definition HINT: This defines the structure of data stored in the vector database REQUIRED FIELDS: 1. vector: Vector field for CLIP embeddings (use clip_model.ndims()) 2. image_uri: String field for image file paths 3. 
class FashionItem(LanceModel):
    """
    Schema for fashion items in the vector database.

    HINT: This defines the structure of data stored in the vector database.

    REQUIRED FIELDS:
    1. vector: Vector field for CLIP embeddings (use clip_model.ndims())
    2. image_uri: String field for image file paths
    3. description: Optional string field for text descriptions
    """

    # Vector field for embeddings (computed automatically from image_uri)
    vector: Vector(clip_model.ndims()) = clip_model.VectorField()
    # Image field; the source the embedding is computed from
    image_uri: str = clip_model.SourceField()
    # Free-text description of the item
    description: Optional[str] = None

    @property
    def image(self):
        if isinstance(self.image_uri, str) and os.path.exists(self.image_uri):
            return Image.open(self.image_uri)
        elif hasattr(self.image_uri, "save"):  # PIL Image object
            return self.image_uri
        else:
            # No usable image; callers must handle None
            return None


# =============================================================================
# SECTION 2: RETRIEVAL - Vector Database Operations
# =============================================================================

def setup_fashion_database(
    database_path: str = "fashion_db",
    table_name: str = "fashion_items",
    dataset_name: str = "tomytjandra/h-and-m-fashion-caption",
    sample_size: int = 1000,
    images_dir: str = "fashion_images",
) -> None:
    """
    Set up the vector database with the H&M fashion dataset.

    This function:
    1. Connects to the LanceDB database
    2. Checks if the table already exists (skips setup if it does)
    3. Loads the H&M dataset from HuggingFace
    4. Processes and saves images locally
    5. Creates the vector database table
    """
    print("šŸ”§ Setting up fashion database...")
    print(f"Database path: {database_path}")
    print(f"Dataset: {dataset_name}")
    print(f"Sample size: {sample_size}")

    # Connect to LanceDB
    db = lancedb.connect(database_path)

    # Skip setup entirely if the table already exists
    if table_name in db.table_names():
        existing_table = db.open_table(table_name)
        print(f"āœ… Table '{table_name}' already exists with {len(existing_table)} items")
        return

    print(f"šŸ—ļø Table '{table_name}' does not exist, creating new fashion database...")

    # Load dataset from HuggingFace
    print("šŸ“„ Loading H&M fashion dataset...")
    dataset = load_dataset(dataset_name)
    train_data = dataset["train"]

    # Downsample to at most sample_size items
    if len(train_data) > sample_size:
        indices = sample(range(len(train_data)), sample_size)
        train_data = train_data.select(indices)

    print(f"Processing {len(train_data)} fashion items...")

    # Create images directory
    os.makedirs(images_dir, exist_ok=True)

    # Process each item: save its image to disk and record the path + caption
    table_data = []
    for i, item in enumerate(train_data):
        image = item["image"]
        text = item["text"]

        # Save image
        image_path = os.path.join(images_dir, f"fashion_{i:04d}.jpg")
        image.save(image_path)

        # Create record (the vector is computed from image_uri at insert time)
        record = {"image_uri": image_path, "description": text}
        table_data.append(record)

        if (i + 1) % 100 == 0:
            print(f"  Processed {i + 1}/{len(train_data)} items...")

    # Create vector database table
    if table_data:
        print("šŸ—„ļø Creating vector database table...")
        table = db.create_table(table_name, schema=FashionItem, data=table_data)
        print(f"āœ… Created table '{table_name}' with {len(table_data)} items")
    else:
        print("āŒ No data to create table, please check dataset loading")

    print("šŸŽ‰ Fashion database setup complete!")
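# A minimal sketch of inspecting the table this produces (illustrative; uses
# the default paths and names above). Because the schema declares image_uri as
# a SourceField, LanceDB computes each row's CLIP vector when rows are added:
#
#   db = lancedb.connect("fashion_db")
#   tbl = db.open_table("fashion_items")
#   print(len(tbl))     # number of indexed items, e.g. 1000
#   print(tbl.schema)   # vector / image_uri / description columns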
def search_fashion_items(
    database_path: str,
    table_name: str,
    query: str,
    search_type: str = "auto",
    limit: int = 3,
) -> Tuple[List[FashionItem], str]:
    """
    Search for fashion items using a text or image query.

    STEPS:
    1. Determine the search type (auto-detect: file path => image search)
    2. Connect to the database
    3. Open the table
    4. Search based on type:
       - Image: load with PIL and search
       - Text: search directly with the string
    5. Return the results and the search type used

    Args:
        database_path: Path to LanceDB database
        table_name: Name of the table to search
        query: Search query (text or image path)
        search_type: "auto", "text", or "image"
        limit: Number of results to return

    Returns:
        Tuple of (results_list, actual_search_type)
    """
    print(f"šŸ” Searching for: {query}")

    # Determine search type: honor an explicit type, otherwise auto-detect.
    # HINT: os.path.exists(query) checks whether the query is a file path;
    # if the file exists it is an image search, otherwise a text search.
    if search_type in ("text", "image"):
        actual_search_type = search_type
    elif os.path.exists(query):
        actual_search_type = "image"
    else:
        actual_search_type = "text"
    print(f"  Detected search type: {actual_search_type}")

    # Connect to database and open the table
    db = lancedb.connect(database_path)
    print(f"šŸ“‚ Connected to database: {database_path}")
    table = db.open_table(table_name)
    print(f"šŸ“– Opened table: {table_name}")

    # Perform the search; LanceDB embeds the query (image or text) with the
    # registered CLIP model and runs a vector similarity search
    if actual_search_type == "image":
        image = Image.open(query)
        print(f"  Searching with image: {query}")
        results = table.search(image).limit(limit).to_pydantic(FashionItem)
    else:
        print(f"  Searching with text: {query}")
        results = table.search(query).limit(limit).to_pydantic(FashionItem)

    print(f"  Found {len(results)} results using {actual_search_type} search")
    return results, actual_search_type
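# If you also want similarity scores, LanceDB exposes the vector distance in a
# `_distance` column when results are materialized as a DataFrame (a minimal
# sketch; assumes an open `table` as in search_fashion_items above):
#
#   df = table.search("black dress").limit(3).to_pandas()
#   print(df[["description", "_distance"]])  # smaller distance = closer match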
\ "Talk about only retrieved items and do not make up any information. " \ "If you do not have enough information, please say so. " \ "Do not talk about anything else" print("šŸ“ Creating enhanced prompt...") # Format retrieved items context context = "Here are some relevant fashion items from our catalog:\n\n" for i, item in enumerate(retrieved_items, 1): print (f" Adding item {i}: {item}...") # Ensure item has description and image URI context += f"{i}. {item.description}\n\n" # Create user query section # HINT: Handle different search types (image vs text) if search_type == "image": query_section = ( f"User searched for an image: {query}\n" "Please provide fashion recommendations based on the retrieved items and the image." ) else: query_section = ( f"User query: {query}\n" "Please provide fashion recommendations based on the retrieved items and the query." ) print(f" Query section created: {query_section[:60]}...") # Combine into final prompt # HINT: Combine system prompt, context, query section, and response instruction prompt = f"{system_prompt}\n\n{context}\n{query_section}\n\n " return prompt # ============================================================================= # SECTION 4: GENERATION - LLM Response Generation # ============================================================================= def setup_llm_model(model_name: str = "Qwen/Qwen2.5-0.5B-Instruct") -> Tuple[Any, Any]: """ Set up LLM model and tokenizer Complete this function to load the LLM model and tokenizer STEPS TO IMPLEMENT: 1. Load tokenizer 2. Load model 3. Configure model settings for GPU/CPU 5. Return tokenizer and model Args: model_name: Name of the model to load Returns: Tuple of (tokenizer, model) """ print(f"šŸ¤– Loading LLM model: {model_name}") # Load tokenizer tokenizer = AutoTokenizer.from_pretrained(model_name) print(" Tokenizer loaded successfully") # Load model model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float32, device_map="cpu" ) # Set pad token if not exists # TODO: Why are we doing this ? if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Print success message and return print("āœ… LLM model loaded successfully") return tokenizer, model def generate_fashion_response( prompt: str, tokenizer: Any, model: Any, max_tokens: int = 200 ) -> str: """ Generate response using LLM Complete this function to generate text using the LLM STEPS TO IMPLEMENT: 1. Check if tokenizer and model are loaded 2. Encode the prompt with attention mask 3. Generate response using model.generate() 4. Decode the response and clean it up 5. 
def generate_fashion_response(
    prompt: str, tokenizer: Any, model: Any, max_tokens: int = 200
) -> str:
    """
    Generate a response using the LLM.

    STEPS:
    1. Check that the tokenizer and model are loaded
    2. Encode the prompt with an attention mask
    3. Generate a response using model.generate()
    4. Decode the response and clean it up
    5. Return the generated text

    Args:
        prompt: Input prompt for the model
        tokenizer: Model tokenizer
        model: LLM model
        max_tokens: Maximum tokens to generate

    Returns:
        Generated response text
    """
    if not tokenizer or not model:
        return "āš ļø LLM not loaded - showing search results only"

    # Encode prompt with attention mask; truncate long prompts so they fit
    # comfortably in the model's context window
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=2048, padding=True
    )

    # Ensure everything runs on CPU
    inputs = {k: v.to("cpu") for k, v in inputs.items()}

    # Generate response (sampling gives some variety in the recommendations)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode and strip the echoed prompt so only the new text remains
    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = full_response.replace(prompt, "").strip()
    return response


# =============================================================================
# SECTION 5: IMAGE STORAGE
# =============================================================================

def save_retrieved_images(
    results: Dict[str, Any], output_dir: str = "retrieved_fashion_images"
) -> List[str]:
    """Save retrieved fashion images to the output directory."""
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Build a filesystem-safe prefix from the query
    query_safe = re.sub(r"[^\w\s-]", "", str(results["query"]))[:30]
    query_safe = re.sub(r"[-\s]+", "_", query_safe)

    saved_paths = []
    print(f"šŸ’¾ Saving {len(results['results'])} retrieved images...")

    for i, item in enumerate(results["results"], 1):
        original_path = item.image_uri
        image = Image.open(original_path)

        # Generate new filename
        filename = f"{query_safe}_result_{i:02d}.jpg"
        save_path = os.path.join(output_dir, filename)

        # Save image
        image.save(save_path, "JPEG", quality=95)
        saved_paths.append(save_path)

        print(f"  āœ… Saved image {i}: {filename}")
        print(f"     Description: {item.description[:60]}...")

    print(f"šŸ’¾ Saved {len(saved_paths)} images to: {output_dir}")
    return saved_paths
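# Worked example of the filename sanitization above: punctuation is stripped,
# the result is capped at 30 chars, then whitespace/hyphens collapse to
# underscores:
#
#   >>> q = re.sub(r"[^\w\s-]", "", "black dress, for evening!")[:30]
#   >>> re.sub(r"[-\s]+", "_", q)
#   'black_dress_for_evening'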
# =============================================================================
# SECTION 6: COMPLETE RAG PIPELINE
# =============================================================================

def run_fashion_rag_pipeline(
    query: str,
    database_path: str = "fashion_db",
    table_name: str = "fashion_items",
    search_type: str = "auto",
    limit: int = 3,
    save_images: bool = True,
) -> Dict[str, Any]:
    """
    Run the complete fashion RAG pipeline.

    This is the main function that ties everything together!

    PIPELINE PHASES:
    Phase 1 - RETRIEVAL: find similar fashion items via the vector database
    Phase 2 - AUGMENTATION: create a context-rich prompt
    Phase 3 - GENERATION: generate a helpful LLM response
    Phase 4 - STORAGE: save retrieved images if requested
    """
    print("šŸš€ Starting Fashion RAG Pipeline")
    print("=" * 50)

    # PHASE 1: RETRIEVAL
    print("šŸ” PHASE 1: RETRIEVAL")
    results, actual_search_type = search_fashion_items(
        database_path=database_path,
        table_name=table_name,
        query=query,
        search_type=search_type,
        limit=limit,
    )
    print(f"  Found {len(results)} relevant items")
    print(f"  Search type used: {actual_search_type}")

    # PHASE 2: AUGMENTATION
    print("šŸ“ PHASE 2: AUGMENTATION")
    enhanced_prompt = create_fashion_prompt(
        query=query,
        retrieved_items=results,
        search_type=actual_search_type,
    )
    print(f"  Created enhanced prompt ({len(enhanced_prompt)} chars)")

    # PHASE 3: GENERATION
    print("šŸ¤– PHASE 3: GENERATION")
    tokenizer, model = setup_llm_model()
    if not tokenizer or not model:
        print("āš ļø LLM not loaded - skipping response generation")
        response = "āš ļø LLM not available"
    else:
        response = generate_fashion_response(
            prompt=enhanced_prompt,
            tokenizer=tokenizer,
            model=model,
            max_tokens=200,
        )
        print(f"  Generated response ({len(response)} chars)")

    # Prepare final results dictionary
    final_results = {
        "query": query,
        "results": results,
        "response": response,
        "search_type": actual_search_type,
    }

    # PHASE 4: STORAGE - save retrieved images if requested
    if save_images:
        saved_image_paths = save_retrieved_images(final_results)
        final_results["saved_image_paths"] = saved_image_paths

    return final_results
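# Example programmatic use (illustrative; assumes setup_fashion_database()
# has already populated the default "fashion_db"):
#
#   out = run_fashion_rag_pipeline("red summer dress", limit=3, save_images=False)
#   print(out["search_type"])      # "text"
#   for it in out["results"]:
#       print(it.description)      # captions of the retrieved items
#   print(out["response"])         # the LLM's recommendation text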
# =============================================================================
# GRADIO WEB APP
# =============================================================================

def fashion_search_app(query):
    """
    Process a fashion query and return the response plus images for Gradio.

    STEPS:
    1. Check that a query was provided
    2. Set up the database if needed
    3. Run the RAG pipeline
    4. Extract the LLM response and images
    5. Return formatted results for Gradio
    """
    if not query.strip():
        return "Please enter a search query", []

    # Setup database if needed (will skip if it exists)
    print("šŸ”§ Checking/setting up fashion database...")
    setup_fashion_database()

    # Run the RAG pipeline
    result = run_fashion_rag_pipeline(
        query=query,
        database_path="fashion_db",
        table_name="fashion_items",
        search_type="auto",
        limit=3,
        save_images=True,
    )
    print("šŸŽÆ RAG pipeline completed")

    # Get LLM response
    llm_response = result["response"]
    print(f"šŸ¤– LLM Response: {llm_response[:60]}...")

    # Collect retrieved images for the gallery
    retrieved_images = []
    for item in result["results"]:
        if os.path.exists(item.image_uri):
            img = Image.open(item.image_uri)
            retrieved_images.append(img)

    return llm_response, retrieved_images


def launch_gradio_app():
    """Launch the Gradio web interface."""
    with gr.Blocks(title="Fashion RAG Assistant") as app:
        gr.Markdown("# šŸ‘— Fashion RAG Assistant")
        gr.Markdown("Search for fashion items and get AI-powered recommendations!")

        with gr.Row():
            with gr.Column(scale=1):
                # Input
                query_input = gr.Textbox(
                    label="Search Query",
                    placeholder="Enter your fashion query (e.g., 'black dress for evening')",
                    lines=2,
                )
                search_btn = gr.Button("Search", variant="primary")

                # Examples
                gr.Examples(
                    examples=[
                        "black dress for evening",
                        "casual summer outfit",
                        "blue jeans",
                        "white shirt",
                        "winter jacket",
                    ],
                    inputs=query_input,
                )

            with gr.Column(scale=2):
                # Output
                response_output = gr.Textbox(
                    label="Fashion Recommendation",
                    lines=10,
                    interactive=True,
                    autoscroll=True,
                )
                # Retrieved images
                images_output = gr.Gallery(
                    label="Retrieved Fashion Items", columns=3, height=400
                )

        # Connect the search function to the button and the Enter key
        search_btn.click(
            fn=fashion_search_app,
            inputs=query_input,
            outputs=[response_output, images_output],
        )
        query_input.submit(
            fn=fashion_search_app,
            inputs=query_input,
            outputs=[response_output, images_output],
        )

    print("šŸš€ Starting Fashion RAG Gradio App...")
    print("šŸ“ Note: First run will download the dataset and set up the database")
    app.launch(share=True)
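# Image queries go through the same CLI entry point below: any query that is
# an existing file path is auto-detected as an image search (see
# search_fashion_items), e.g.:
#
#   python assignment_fashion_rag.py --query "fashion_images/fashion_0001.jpg"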
search_type="auto", limit=3, save_images=True, ) # Display results print("\n" + "=" * 50) print("šŸŽÆ PIPELINE RESULTS") print("=" * 50) print(f"Query: {result['query']}") print(f"Search Type: {result['search_type']}") print(f"Results Found: {len(result['results'])}") print("\nšŸ“‹ Retrieved Items:") for i, item in enumerate(result["results"], 1): print(f"{i}. {item.description}") print(f"\nšŸ¤– LLM Response:") print(result["response"]) # Show saved images info if any if result.get("saved_image_paths"): print(f"\nšŸ“ø Saved Images:") for i, path in enumerate(result["saved_image_paths"], 1): print(f"{i}. {path}") if __name__ == "__main__": main()