dceshubh committed on
Commit
0cae3f6
·
verified ·
1 Parent(s): 0bcafcc

Create app.py

Files changed (1)
  1. app.py +798 -0
app.py ADDED
@@ -0,0 +1,798 @@
"""
Fashion RAG Pipeline - Assignment
Week 9: Multimodal RAG Pipeline with H&M Fashion Dataset

OBJECTIVE: Build a complete multimodal RAG (Retrieval-Augmented Generation) pipeline
that can search through fashion items using both text and image queries, then generate
helpful responses using an LLM.

LEARNING GOALS:
- Understand the three phases of RAG: Retrieval, Augmentation, Generation
- Work with multimodal data (images + text)
- Use vector databases for similarity search
- Integrate an LLM for response generation
- Build an end-to-end AI pipeline

DATASET: H&M Fashion Caption Dataset
- 20K+ fashion items with images and text descriptions
- URL: https://huggingface.co/datasets/tomytjandra/h-and-m-fashion-caption

PIPELINE OVERVIEW:
1. RETRIEVAL: Find similar fashion items using vector search
2. AUGMENTATION: Create enhanced prompts with retrieved context
3. GENERATION: Generate helpful responses using an LLM

Commands to run:
    python app.py --query "black dress for evening"
    python app.py --app
"""

import argparse
import os
import re
import warnings
from random import sample
from typing import Any, Dict, List, Optional, Tuple

# Gradio for web interface
import gradio as gr

# Core dependencies
import lancedb
import torch
from datasets import load_dataset
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from PIL import Image

# LLM dependencies
from transformers import AutoModelForCausalLM, AutoTokenizer

# Suppress noisy library warnings
warnings.filterwarnings("ignore")


def is_huggingface_space():
    """
    Check whether the code is running inside a Hugging Face Spaces environment.

    Returns:
        bool: True if running in HF Spaces, False otherwise.
    """
    # HF Spaces sets the SYSTEM environment variable to "spaces"
    return os.environ.get("SYSTEM") == "spaces"


# =============================================================================
# SECTION 1: DATABASE SETUP AND SCHEMA
# =============================================================================


def register_embedding_model(model_name: str = "open-clip") -> Any:
    """
    Register the embedding model used for vector search.

    Args:
        model_name: Name of the embedding model in the LanceDB registry
    Returns:
        Embedding model instance
    """
    # Get the registry instance
    registry = EmbeddingFunctionRegistry.get_instance()
    print(f"🔍 Registering embedding model: {model_name}")

    # Get and create the model
    model = registry.get(model_name).create()

    return model


# Global embedding model
clip_model = register_embedding_model()
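
# Note (illustrative): the registered "open-clip" model exposes its embedding
# size via clip_model.ndims(); for the default ViT-B/32 CLIP variant this is
# typically 512, but the exact value depends on the variant LanceDB loads.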

class FashionItem(LanceModel):
    """
    Schema for fashion items in the vector database.

    FIELDS:
    1. vector: CLIP embedding (dimension comes from clip_model.ndims())
    2. image_uri: path to the image file; registered as the embedding source
    3. description: optional text description
    """

    # Vector field for embeddings (filled automatically by the CLIP model)
    vector: Vector(clip_model.ndims()) = clip_model.VectorField()

    # Image field; SourceField tells LanceDB to embed this column
    image_uri: str = clip_model.SourceField()

    # Text description field
    description: Optional[str] = None

    @property
    def image(self):
        if isinstance(self.image_uri, str) and os.path.exists(self.image_uri):
            return Image.open(self.image_uri)
        elif hasattr(self.image_uri, "save"):  # already a PIL Image object
            return self.image_uri
        else:
            # No image available for this record
            return None

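
# Example (illustrative sketch): because image_uri is the SourceField and
# vector is the VectorField, LanceDB computes embeddings on insert, so plain
# dict rows are enough:
#   tbl = db.create_table("demo", schema=FashionItem,
#                         data=[{"image_uri": "img.jpg", "description": "a dress"}])
#   # each row now carries a CLIP vector derived from its image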

# =============================================================================
# SECTION 2: RETRIEVAL - Vector Database Operations
# =============================================================================


def setup_fashion_database(
    database_path: str = "fashion_db",
    table_name: str = "fashion_items",
    dataset_name: str = "tomytjandra/h-and-m-fashion-caption",
    sample_size: int = 1000,
    images_dir: str = "fashion_images",
) -> None:
    """
    Set up the vector database with the H&M fashion dataset.

    Steps:
    1. Connect to the LanceDB database
    2. Check if the table already exists (skip setup if it does)
    3. Load the H&M dataset from HuggingFace
    4. Process and save images locally
    5. Create the vector database table
    """
    print("🔧 Setting up fashion database...")
    print(f"Database path: {database_path}")
    print(f"Dataset: {dataset_name}")
    print(f"Sample size: {sample_size}")

    # Connect to LanceDB
    db = lancedb.connect(database_path)

    # Skip setup if the table already exists
    if table_name in db.table_names():
        existing_table = db.open_table(table_name)
        print(f"✅ Table '{table_name}' already exists with {existing_table.count_rows()} items")
        return
    else:
        print(f"🏗️ Table '{table_name}' does not exist, creating new fashion database...")

    # Load dataset from HuggingFace
    print("📥 Loading H&M fashion dataset...")
    dataset = load_dataset(dataset_name)
    train_data = dataset["train"]

    # Sample the dataset down to sample_size items
    if len(train_data) > sample_size:
        indices = sample(range(len(train_data)), sample_size)
        train_data = train_data.select(indices)
    print(f"Processing {len(train_data)} fashion items...")

    # Create images directory
    os.makedirs(images_dir, exist_ok=True)

    # Process each item: save the image locally and record its path + caption
    table_data = []
    for i, item in enumerate(train_data):
        image = item["image"]
        text = item["text"]

        # Save image
        image_path = os.path.join(images_dir, f"fashion_{i:04d}.jpg")
        image.save(image_path)

        # Create record; the vector is computed by LanceDB on insert
        record = {
            "image_uri": image_path,
            "description": text,
        }
        table_data.append(record)

        if (i + 1) % 100 == 0:
            print(f"  Processed {i + 1}/{len(train_data)} items...")

    # Create vector database table
    if table_data:
        print("🗄️ Creating vector database table...")
        table = db.create_table(
            table_name,
            schema=FashionItem,
            data=table_data,
        )
        print(f"✅ Created table '{table_name}' with {len(table_data)} items")
    else:
        print("❌ No data to create table, please check dataset loading")
    print("🎉 Fashion database setup complete!")

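
# Example (illustrative): a smaller development database could be built with
#   setup_fashion_database(database_path="dev_db", sample_size=100)
# Re-running is cheap because the function returns early once the table exists.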

def search_fashion_items(
    database_path: str,
    table_name: str,
    query: str,
    search_type: str = "auto",
    limit: int = 3,
) -> Tuple[List[FashionItem], str]:
    """
    Search for fashion items using a text or image query.

    Steps:
    1. Resolve the search type ("auto" checks whether the query is a file path)
    2. Connect to the vector database and open the table
    3. Perform a similarity search using CLIP embeddings:
       - Image: load the file with PIL and search with the image
       - Text: search directly with the query string
    4. Return the results and the search type actually used

    Args:
        database_path: Path to LanceDB database
        table_name: Name of the table to search
        query: Search query (text or image path)
        search_type: "auto", "text", or "image"
        limit: Number of results to return

    Returns:
        Tuple of (results_list, actual_search_type)
    """
    print(f"🔍 Searching for: {query}")

    # Resolve the search type: an existing file path means an image query
    if search_type == "auto":
        actual_search_type = "image" if os.path.exists(query) else "text"
    else:
        actual_search_type = search_type
    print(f"  Detected search type: {actual_search_type}")

    # Connect to database
    db = lancedb.connect(database_path)
    print(f"📂 Connected to database: {database_path}")

    # Open the table
    table = db.open_table(table_name)
    print(f"📖 Opened table: {table_name}")

    # Perform search based on the resolved type; LanceDB embeds the query
    # with the same registered CLIP model used at indexing time
    if actual_search_type == "image":
        image = Image.open(query)
        print(f"  Searching with image: {query}")
        results = table.search(image).limit(limit).to_pydantic(FashionItem)
    else:
        print(f"  Searching with text: {query}")
        results = table.search(query).limit(limit).to_pydantic(FashionItem)

    print(f"  Found {len(results)} results using {actual_search_type} search")

    return results, actual_search_type

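
# Example (illustrative):
#   results, kind = search_fashion_items(
#       "fashion_db", "fashion_items", "red summer dress", limit=3
#   )
#   # kind == "text"; each result is a FashionItem with .description and .image_uri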

# =============================================================================
# SECTION 3: AUGMENTATION - Prompt Engineering
# =============================================================================


def create_fashion_prompt(
    query: str, retrieved_items: List[FashionItem], search_type: str
) -> str:
    """
    Create an enhanced prompt for the LLM using retrieved fashion items.

    PROMPT STRUCTURE:
    1. System prompt: define the AI as a fashion assistant
    2. Context section: list retrieved fashion items with descriptions
    3. Query section: include the user's original query
    4. Instruction: ask for fashion recommendations

    Args:
        query: Original user query
        retrieved_items: List of retrieved fashion items
        search_type: Type of search performed

    Returns:
        Enhanced prompt string for the LLM
    """
    # System prompt: define the assistant's role and guardrails
    system_prompt = (
        "You are a fashion assistant with expertise in clothing and accessories. "
        "Your task is to provide helpful fashion recommendations based on user "
        "queries and retrieved items. For each retrieved item, provide a helpful "
        "fashion recommendation. Be funny, creative, and engaging in your response. "
        "Talk only about the retrieved items and do not make up any information. "
        "If you do not have enough information, say so. Do not talk about anything else."
    )
    print("📝 Creating enhanced prompt...")

    # Format retrieved items as context
    context = "Here are some relevant fashion items from our catalog:\n\n"
    for i, item in enumerate(retrieved_items, 1):
        print(f"  Adding item {i}: {item.description[:60]}...")
        context += f"{i}. {item.description}\n\n"

    # Create the user query section, which differs for image vs. text search
    if search_type == "image":
        query_section = (
            f"User searched for an image: {query}\n"
            "Please provide fashion recommendations based on the retrieved items and the image."
        )
    else:
        query_section = (
            f"User query: {query}\n"
            "Please provide fashion recommendations based on the retrieved items and the query."
        )

    print(f"  Query section created: {query_section[:60]}...")

    # Combine system prompt, context, query section, and response instruction
    prompt = f"{system_prompt}\n\n{context}\n{query_section}\n\nYour recommendations:"
    return prompt

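
# Example (illustrative; the item text below is invented) of the assembled prompt:
#   You are a fashion assistant ... Do not talk about anything else.
#
#   Here are some relevant fashion items from our catalog:
#
#   1. Solid black jersey dress with long sleeves.
#
#   User query: black dress for evening
#   Please provide fashion recommendations based on the retrieved items and the query.
#
#   Your recommendations: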

# =============================================================================
# SECTION 4: GENERATION - LLM Response Generation
# =============================================================================


def setup_llm_model(model_name: str = "Qwen/Qwen2.5-0.5B-Instruct") -> Tuple[Any, Any]:
    """
    Set up the LLM model and tokenizer.

    Steps:
    1. Load the tokenizer
    2. Load the model on CPU with float32 weights
    3. Ensure a pad token is configured
    4. Return the tokenizer and model

    Args:
        model_name: Name of the model to load

    Returns:
        Tuple of (tokenizer, model)
    """
    print(f"🤖 Loading LLM model: {model_name}")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    print("  Tokenizer loaded successfully")

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float32, device_map="cpu"
    )

    # Some causal LMs ship without a pad token; padding during tokenization
    # requires one, so fall back to the EOS token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("✅ LLM model loaded successfully")
    return tokenizer, model

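
# Note (illustrative alternative): instruct-tuned models such as Qwen2.5 also
# accept chat-formatted input via tokenizer.apply_chat_template(); this file
# deliberately uses a single plain-string prompt to keep the RAG steps visible.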

def generate_fashion_response(
    prompt: str, tokenizer: Any, model: Any, max_tokens: int = 200
) -> str:
    """
    Generate a response using the LLM.

    Steps:
    1. Check that the tokenizer and model are loaded
    2. Encode the prompt (with attention mask) via the tokenizer
    3. Generate a response with model.generate()
    4. Decode only the newly generated tokens
    5. Return the generated text

    Args:
        prompt: Input prompt for the model
        tokenizer: Model tokenizer
        model: LLM model
        max_tokens: Maximum tokens to generate

    Returns:
        Generated response text
    """
    if not tokenizer or not model:
        return "⚠️ LLM not loaded - showing search results only"

    # Encode prompt; tokenizer() returns input_ids plus an attention mask
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=2048, padding=True
    )

    # Ensure everything runs on CPU
    inputs = {k: v.to("cpu") for k, v in inputs.items()}

    # Generate response
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the tokens generated after the prompt; this is more robust
    # than string-replacing the prompt out of the full decoded text
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
    return response

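
# Example (illustrative):
#   tokenizer, model = setup_llm_model()
#   text = generate_fashion_response("Recommend a scarf.", tokenizer, model, max_tokens=50)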

# =============================================================================
# SECTION 5: IMAGE STORAGE
# =============================================================================


def save_retrieved_images(
    results: Dict[str, Any], output_dir: str = "retrieved_fashion_images"
) -> List[str]:
    """Save retrieved fashion images to the output directory."""

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Build a filesystem-safe slug from the query for the filenames
    query_safe = re.sub(r"[^\w\s-]", "", str(results["query"]))[:30]
    query_safe = re.sub(r"[-\s]+", "_", query_safe)

    saved_paths = []

    print(f"💾 Saving {len(results['results'])} retrieved images...")

    for i, item in enumerate(results["results"], 1):
        original_path = item.image_uri
        image = Image.open(original_path)

        # Generate new filename, e.g. black_dress_for_evening_result_01.jpg
        filename = f"{query_safe}_result_{i:02d}.jpg"
        save_path = os.path.join(output_dir, filename)

        # Save image
        image.save(save_path, "JPEG", quality=95)
        saved_paths.append(save_path)

        print(f"  ✅ Saved image {i}: {filename}")
        print(f"     Description: {item.description[:60]}...")

    print(f"💾 Saved {len(saved_paths)} images to: {output_dir}")
    return saved_paths


# =============================================================================
# SECTION 6: COMPLETE RAG PIPELINE
# =============================================================================


def run_fashion_rag_pipeline(
    query: str,
    database_path: str = "fashion_db",
    table_name: str = "fashion_items",
    search_type: str = "auto",
    limit: int = 3,
    save_images: bool = True,
) -> Dict[str, Any]:
    """
    Run the complete fashion RAG pipeline.

    This is the main function that ties everything together:
    Phase 1 - RETRIEVAL: find similar fashion items via the vector database
    Phase 2 - AUGMENTATION: create a context-rich prompt from the results
    Phase 3 - GENERATION: generate a helpful LLM response
    Phase 4 - STORAGE: save the retrieved images if requested
    """
    print("🚀 Starting Fashion RAG Pipeline")
    print("=" * 50)

    # PHASE 1: RETRIEVAL
    print("🔍 PHASE 1: RETRIEVAL")
    results, actual_search_type = search_fashion_items(
        database_path=database_path,
        table_name=table_name,
        query=query,
        search_type=search_type,
        limit=limit,
    )
    print(f"  Found {len(results)} relevant items")
    print(f"  Search type used: {actual_search_type}")

    # PHASE 2: AUGMENTATION
    print("📝 PHASE 2: AUGMENTATION")
    enhanced_prompt = create_fashion_prompt(
        query=query,
        retrieved_items=results,
        search_type=actual_search_type,
    )
    print(f"  Created enhanced prompt ({len(enhanced_prompt)} chars)")

    # PHASE 3: GENERATION
    print("🤖 PHASE 3: GENERATION")
    tokenizer, model = setup_llm_model()
    if not tokenizer or not model:
        print("⚠️ LLM not loaded - skipping response generation")
        response = "⚠️ LLM not available"
    else:
        response = generate_fashion_response(
            prompt=enhanced_prompt,
            tokenizer=tokenizer,
            model=model,
            max_tokens=200,
        )

    print(f"  Generated response ({len(response)} chars)")

    # Prepare final results dictionary
    final_results = {
        "query": query,
        "results": results,
        "response": response,
        "search_type": actual_search_type,
    }

    # PHASE 4: STORAGE - save retrieved images if requested
    if save_images:
        saved_image_paths = save_retrieved_images(final_results)
        final_results["saved_image_paths"] = saved_image_paths

    return final_results

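
# Example (illustrative) of the dictionary returned by run_fashion_rag_pipeline:
#   {
#       "query": "black dress for evening",
#       "results": [<FashionItem>, ...],
#       "response": "...LLM recommendation text...",
#       "search_type": "text",
#       "saved_image_paths": ["retrieved_fashion_images/black_dress_for_evening_result_01.jpg", ...],
#   }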

# =============================================================================
# GRADIO WEB APP
# =============================================================================


def fashion_search_app(query):
    """
    Process a fashion query and return the response plus images for Gradio.

    Steps:
    1. Check that a query was provided
    2. Set up the database if needed
    3. Run the RAG pipeline
    4. Extract the LLM response and retrieved images
    5. Return results formatted for the Gradio outputs
    """
    if not query.strip():
        return "Please enter a search query", []

    # Setup database if needed (skipped when it already exists)
    print("🔧 Checking/setting up fashion database...")
    setup_fashion_database()

    # Run the RAG pipeline
    result = run_fashion_rag_pipeline(
        query=query,
        database_path="fashion_db",
        table_name="fashion_items",
        search_type="auto",
        limit=3,
        save_images=True,
    )
    print("🎯 RAG pipeline completed")

    # Get LLM response
    llm_response = result["response"]
    print(f"🤖 LLM Response: {llm_response[:60]}...")

    # Collect retrieved images for the gallery
    retrieved_images = []
    for item in result["results"]:
        if os.path.exists(item.image_uri):
            img = Image.open(item.image_uri)
            retrieved_images.append(img)

    # Gradio maps this pair onto the (Textbox, Gallery) outputs
    return llm_response, retrieved_images


def launch_gradio_app():
    """Launch the Gradio web interface."""

    # Create Gradio interface
    with gr.Blocks(title="Fashion RAG Assistant") as app:

        gr.Markdown("# 👗 Fashion RAG Assistant")
        gr.Markdown("Search for fashion items and get AI-powered recommendations!")

        with gr.Row():
            with gr.Column(scale=1):
                # Input
                query_input = gr.Textbox(
                    label="Search Query",
                    placeholder="Enter your fashion query (e.g., 'black dress for evening')",
                    lines=2,
                )

                search_btn = gr.Button("Search", variant="primary")

                # Examples
                gr.Examples(
                    examples=[
                        "black dress for evening",
                        "casual summer outfit",
                        "blue jeans",
                        "white shirt",
                        "winter jacket",
                    ],
                    inputs=query_input,
                )

            with gr.Column(scale=2):
                # Output
                response_output = gr.Textbox(
                    label="Fashion Recommendation", lines=10, interactive=True, autoscroll=True
                )

                # Retrieved Images
                images_output = gr.Gallery(
                    label="Retrieved Fashion Items", columns=3, height=400
                )

        # Connect the search function to the button
        search_btn.click(
            fn=fashion_search_app,
            inputs=query_input,
            outputs=[response_output, images_output],
        )

        # Also trigger the search on Enter
        query_input.submit(
            fn=fashion_search_app,
            inputs=query_input,
            outputs=[response_output, images_output],
        )

    print("🚀 Starting Fashion RAG Gradio App...")
    print("📝 Note: First run will download the dataset and set up the database")
    app.launch(share=True)


# =============================================================================
# MAIN EXECUTION
# =============================================================================


def main():
    """Handle command line arguments and run the pipeline."""

    # If running in Hugging Face Spaces, automatically launch the app
    if is_huggingface_space():
        print("🤗 Running in Hugging Face Spaces - launching web app automatically")
        launch_gradio_app()
        return

    parser = argparse.ArgumentParser(
        description="Fashion RAG Pipeline Assignment - SOLUTION"
    )
    parser.add_argument("--query", type=str, help="Search query (text or image path)")
    parser.add_argument("--app", action="store_true", help="Launch Gradio web app")

    args = parser.parse_args()

    # Launch web app if requested
    if args.app:
        launch_gradio_app()
        return

    if not args.query:
        print("❌ Please provide a query with --query or use --app for the web interface")
        print("Examples:")
        print("  python app.py --query 'black dress for evening'")
        print("  python app.py --query 'fashion_images/dress.jpg'")
        print("  python app.py --app")
        return

    # Setup database first (skipped if it already exists)
    print("🔧 Checking/setting up fashion database...")
    setup_fashion_database()

    # Run the complete RAG pipeline with default settings
    result = run_fashion_rag_pipeline(
        query=args.query,
        database_path="fashion_db",
        table_name="fashion_items",
        search_type="auto",
        limit=3,
        save_images=True,
    )

    # Display results
    print("\n" + "=" * 50)
    print("🎯 PIPELINE RESULTS")
    print("=" * 50)
    print(f"Query: {result['query']}")
    print(f"Search Type: {result['search_type']}")
    print(f"Results Found: {len(result['results'])}")
    print("\n📋 Retrieved Items:")
    for i, item in enumerate(result["results"], 1):
        print(f"{i}. {item.description}")

    print("\n🤖 LLM Response:")
    print(result["response"])

    # Show saved images info if any
    if result.get("saved_image_paths"):
        print("\n📸 Saved Images:")
        for i, path in enumerate(result["saved_image_paths"], 1):
            print(f"{i}. {path}")


if __name__ == "__main__":
    main()