import os import io import base64 import time import gradio as gr from PIL import Image import logging import numpy as np from gradio_client import Client import json # Setup logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # ───────── Backend connection ───────── HF_TOKEN = os.getenv("HF_TOKEN") if not HF_TOKEN: raise ValueError("HF_TOKEN environment variable is required") # Try to connect to backend try: client = Client("SnapwearAI/Pattern-Transfer-Backend", hf_token=HF_TOKEN) logger.info("✅ Backend client established") backend_connected = True except Exception as e: logger.warning(f"⚠️ Backend connection failed: {e}") client = None backend_connected = False # ───────── Styling ───────── css = """ body, .gradio-container { font-family: 'Inter', 'SF Pro Display', -apple-system, BlinkMacSystemFont, sans-serif; } #col-left, #col-mid, #col-right { margin: 0 auto; max-width: 400px; } #col-showcase { margin: 0 auto; max-width: 1200px; } #button { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: #ffffff; font-weight: 600; font-size: 16px; border: none; border-radius: 12px; padding: 12px 24px; transition: all 0.3s ease; } #button:hover { transform: translateY(-2px); box-shadow: 0 8px 25px rgba(102,126,234,0.3); } #button:disabled { background: #ccc !important; cursor: not-allowed; transform: none; box-shadow: none; } .hero-section { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 40px 20px; border-radius: 20px; margin: 20px 0; text-align: center; } .feature-box { background: #f8fafc; border: 1px solid #e2e8f0; padding: 20px; border-radius: 12px; margin: 10px 0; border-left: 4px solid #667eea; } .showcase-section { background: #ffffff; border: 1px solid #e2e8f0; padding: 30px; border-radius: 16px; box-shadow: 0 4px 20px rgba(0,0,0,0.1); margin: 20px 0; } .step-header { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 15px; border-radius: 12px; text-align: center; font-weight: 600; margin: 10px 0; } .social-links { text-align: center; margin: 20px 0; } .social-links a { margin: 0 10px; padding: 8px 16px; background: #667eea; color: white; text-decoration: none; border-radius: 8px; transition: all 0.3s ease; } .social-links a:hover { background: #764ba2; transform: translateY(-2px); } .status-banner { padding: 15px; border-radius: 12px; margin: 10px 0; text-align: center; font-weight: 600; } .status-ready { background: #d4edda; border: 1px solid #c3e6cb; color: #155724; } .status-starting { background: #fff3cd; border: 1px solid #ffeaa7; color: #856404; } .status-processing { background: #cce5ff; border: 1px solid #99ccff; color: #004085; } .status-error { background: #f8d7da; border: 1px solid #f5c6cb; color: #721c24; } .queue-info { background: #e8f4fd; border: 1px solid #bee5eb; padding: 12px; border-radius: 8px; margin: 10px 0; text-align: center; font-size: 14px; color: #0c5460; } """ def image_to_base64(image): """Convert PIL Image to base64 string.""" if image is None: return "" if hasattr(image, 'mode') and image.mode != 'RGB': image = image.convert('RGB') buffer = io.BytesIO() image.save(buffer, format="PNG") buffer.seek(0) return base64.b64encode(buffer.getvalue()).decode('utf-8') def base64_to_image(b64_string): """Convert base64 string to PIL Image.""" if not b64_string: return None try: image_data = base64.b64decode(b64_string) return Image.open(io.BytesIO(image_data)) except Exception as e: logger.error(f"Failed to decode base64 image: {e}") return None def try_connect_backend(): """Try to connect to backend and return status""" global client, backend_connected try: test_client = Client("SnapwearAI/Pattern-Transfer-Backend", hf_token=HF_TOKEN) client = test_client backend_connected = True return "🟢 Backend is ready! You can now generate pattern transfers.", True except Exception as e: client = None backend_connected = False error_str = str(e).lower() if "timeout" in error_str or "read operation timed out" in error_str: return "🟡 Backend is starting up (this takes 5-6 minutes on first load). Please wait and try again.", False else: return f"🔴 Backend error: {str(e)}", False def call_backend_with_retry(print_image, product_image, max_retries=3): """Call the backend with proper error handling and queue awareness.""" global client, backend_connected # Validate inputs if not print_image: return None, "❌ Please upload a print/pattern image" if not product_image: return None, "❌ Please upload a product image" # Check if we have a client if not client or not backend_connected: # Try to reconnect status_msg, is_ready = try_connect_backend() if not is_ready: return None, status_msg # Use fixed default values guidance_scale = 50.0 num_steps = 50 for attempt in range(max_retries): try: logger.info(f"Calling backend (attempt {attempt + 1}/{max_retries})") # Convert images to base64 print_b64 = image_to_base64(print_image) product_b64 = image_to_base64(product_image) logger.info("Images converted to base64") # Make the backend call with progress tracking start_time = time.time() # Add queue position info if available try: result = client.predict( print_b64, product_b64, guidance_scale, num_steps, api_name="/predict" ) except Exception as prediction_error: # Handle queue-related messages in error error_str = str(prediction_error).lower() if "queue" in error_str or "position" in error_str: # Extract queue info if present return None, f"📋 Request queued. {str(prediction_error)}" else: raise prediction_error processing_time = time.time() - start_time logger.info(f"Backend call completed in {processing_time:.2f}s") # Process the result if result and len(result) >= 2: result_b64, status = result[0], result[1] if result_b64: result_image = base64_to_image(result_b64) if result_image: logger.info("Successfully received and decoded result image") # Add processing time to status if not already present if "Generated in" not in status: status = f"{status} (Total time: {processing_time:.1f}s)" return result_image, status else: return None, "❌ Failed to decode result image" else: return None, status or "❌ No image returned" else: return None, "❌ Invalid response from backend" except Exception as e: error_str = str(e).lower() if "timeout" in error_str: # Backend might be starting up again backend_connected = False client = None return None, "🟡 Backend timed out. It may be starting up or busy with other requests. Please try again in a few moments." elif "queue" in error_str or "busy" in error_str: return None, f"📋 Server is busy processing other requests. Please wait and try again. {str(e)}" logger.error(f"Backend call attempt {attempt + 1} failed: {e}") if attempt == max_retries - 1: return None, f"❌ Backend error: {str(e)}" time.sleep(3) # Wait before retry return None, "❌ All attempts failed" # ───────── Main UI ───────── with gr.Blocks(css=css, title="AI Style Transfer Studio - Pattern & Color Transfer") as demo: # ──────── Hero Section ──────── gr.HTML("""
• Instant results • Perfect for designers, brands & creators
First Time: Backend takes 5-6 minutes to start up after being idle.
Multiple Users: Requests are processed one at a time to ensure quality. You'll be queued if others are using the system.
Processing Time: 30-60 seconds per request once processing begins.
Queue Updates: You'll see your position and estimated wait time.
Apply any pattern to any product in 30-60 seconds
Preserves product shape, lighting, and texture for realistic results
Transfer prints, patterns, textures, and colors across any product type
Upload any pattern, print, texture, or design you want to transfer
') # Print examples if os.path.exists("Assets/print"): print_examples = [os.path.join("Assets/print", f) for f in os.listdir("Assets/print")][:10] if print_examples: gr.Examples( label="✨ Example Patterns", inputs=print_image, examples_per_page=10, examples=print_examples, ) # ② Product Upload with gr.Column(elem_id="col-mid"): product_image = gr.Image( label="Product Image", type="pil", height=400, ) gr.HTML('Upload the product you want to apply the pattern to
') # Product examples if os.path.exists("Assets/product"): product_examples = [os.path.join("Assets/product", f) for f in os.listdir("Assets/product")][:12] if product_examples: gr.Examples( label="📦 Example Products", inputs=product_image, examples_per_page=12, examples=product_examples, ) # ③ Result + Controls with gr.Column(elem_id="col-right"): result_img = gr.Image( label="✨ Transformed Result", show_share_button=True, height=400 ) # Status display with queue info status_text = gr.Text( label="Generation Status", interactive=False, placeholder="Upload images and click generate..." ) # Generate button generate_btn = gr.Button( "🚀 Transform Pattern", elem_id="button", size="lg", variant="primary" ) # Queue status info gr.HTML("""Pattern transfer examples will appear here once example files are added to Assets/examples/
") # Color Transfer Showcase with gr.Row(): gr.HTML('Color transfer examples will appear here once example files are added to Assets/examples/color/
") # ──────── Use Cases ──────── gr.HTML("""Visualize patterns on garments before production
Show product variations without inventory
Preview designs on products instantly
Create unique visuals for social media
Transform your creative vision with our models.
© 2024 Snapwear AI. Professional AI tools for fashion and design.