"""Snapwear Pattern Mockup Studio — Gradio front-end.

This app collects a pattern/print image and a product image, sends both
(base64-encoded PNGs) to the ``SnapwearAI/Pattern-Transfer-Backend`` Space via
``gradio_client``, and displays the returned image and status text.
"""
import base64
import html
import io
import json  # noqa: F401 -- imported by the original file; currently unused
import logging
import os
import time

import gradio as gr
import numpy as np  # noqa: F401 -- imported by the original file; currently unused
from gradio_client import Client
from PIL import Image

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ───────── Configuration ─────────
BACKEND_SPACE = "SnapwearAI/Pattern-Transfer-Backend"
# Fixed generation parameters sent with every request.
DEFAULT_GUIDANCE_SCALE = 50.0
DEFAULT_NUM_STEPS = 50
# Seconds to wait between retry attempts after a generic backend failure.
RETRY_DELAY_SECONDS = 3
# File extensions accepted as example images (compared case-insensitively).
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif")

# ───────── Backend connection ─────────
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable is required")

# Try to connect once at import time; users can retry later via "Check Status".
try:
    client = Client(BACKEND_SPACE, hf_token=HF_TOKEN)
    logger.info("✅ Backend client established")
    backend_connected = True
except Exception as e:
    logger.warning(f"⚠️ Backend connection failed: {e}")
    client = None
    backend_connected = False

# ───────── Styling ─────────
css = """
body, .gradio-container {
    font-family: 'Inter', 'SF Pro Display', -apple-system, BlinkMacSystemFont, sans-serif;
}
#col-left, #col-mid, #col-right { margin: 0 auto; max-width: 400px; }
#col-showcase { margin: 0 auto; max-width: 1200px; }
#button {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: #ffffff;
    font-weight: 600;
    font-size: 16px;
    border: none;
    border-radius: 12px;
    padding: 12px 24px;
    transition: all 0.3s ease;
}
#button:hover { transform: translateY(-2px); box-shadow: 0 8px 25px rgba(102,126,234,0.3); }
#button:disabled { background: #ccc !important; cursor: not-allowed; transform: none; box-shadow: none; }
.hero-section {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 40px 20px;
    border-radius: 20px;
    margin: 20px 0;
    text-align: center;
}
.feature-box {
    background: #f8fafc;
    border: 1px solid #e2e8f0;
    padding: 20px;
    border-radius: 12px;
    margin: 10px 0;
    border-left: 4px solid #667eea;
}
.showcase-section {
    background: #ffffff;
    border: 1px solid #e2e8f0;
    padding: 30px;
    border-radius: 16px;
    box-shadow: 0 4px 20px rgba(0,0,0,0.1);
    margin: 20px 0;
}
.step-header {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 15px;
    border-radius: 12px;
    text-align: center;
    font-weight: 600;
    margin: 10px 0;
}
.social-links { text-align: center; margin: 20px 0; }
.social-links a {
    margin: 0 10px;
    padding: 8px 16px;
    background: #667eea;
    color: white;
    text-decoration: none;
    border-radius: 8px;
    transition: all 0.3s ease;
}
.social-links a:hover { background: #764ba2; transform: translateY(-2px); }
.status-banner { padding: 15px; border-radius: 12px; margin: 10px 0; text-align: center; font-weight: 600; }
.status-ready { background: #d4edda; border: 1px solid #c3e6cb; color: #155724; }
.status-starting { background: #fff3cd; border: 1px solid #ffeaa7; color: #856404; }
.status-processing { background: #cce5ff; border: 1px solid #99ccff; color: #004085; }
.status-error { background: #f8d7da; border: 1px solid #f5c6cb; color: #721c24; }
.queue-info {
    background: #e8f4fd;
    border: 1px solid #bee5eb;
    padding: 12px;
    border-radius: 8px;
    margin: 10px 0;
    text-align: center;
    font-size: 14px;
    color: #0c5460;
}
"""


# ───────── Image helpers ─────────
def image_to_base64(image):
    """Encode a PIL image as a base64 PNG string.

    Args:
        image: A PIL image, or ``None``. Non-RGB images are converted to RGB
            first, so alpha channels are dropped.

    Returns:
        The base64 text of the PNG bytes, or ``""`` when ``image`` is None.
    """
    if image is None:
        return ""
    if hasattr(image, "mode") and image.mode != "RGB":
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def base64_to_image(b64_string):
    """Decode a base64 string into a fully loaded PIL image.

    Args:
        b64_string: Base64-encoded image bytes, or a falsy value.

    Returns:
        The decoded ``PIL.Image.Image``, or ``None`` if the input is empty or
        cannot be decoded (the failure is logged).
    """
    if not b64_string:
        return None
    try:
        image_data = base64.b64decode(b64_string)
        image = Image.open(io.BytesIO(image_data))
        # Image.open is lazy; force decoding here so corrupt or truncated
        # payloads are caught by this handler rather than later at render time.
        image.load()
        return image
    except Exception as e:
        logger.error(f"Failed to decode base64 image: {e}")
        return None


# ───────── Backend helpers ─────────
def _is_timeout_error(error_str):
    """Return True if a lower-cased error message describes a timeout."""
    return "timeout" in error_str or "timed out" in error_str


def try_connect_backend():
    """(Re)create the global backend client and report whether it is usable.

    Updates the module-level ``client`` and ``backend_connected`` globals.

    Returns:
        ``(status_message, is_ready)``. On a timeout-like failure the message
        says the backend is starting up; other failures carry the error text.
    """
    global client, backend_connected
    try:
        new_client = Client(BACKEND_SPACE, hf_token=HF_TOKEN)
    except Exception as e:
        client = None
        backend_connected = False
        logger.warning(f"Backend connection attempt failed: {e}")
        if _is_timeout_error(str(e).lower()):
            return "🟡 Backend is starting up (this takes 5-6 minutes on first load). Please wait and try again.", False
        return f"🔴 Backend error: {str(e)}", False
    client = new_client
    backend_connected = True
    return "🟢 Backend is ready! You can now generate pattern transfers.", True


def _process_result(result, processing_time):
    """Convert the backend's ``(image_b64, status)`` response into UI outputs.

    Args:
        result: The value returned by ``client.predict``. Expected to be a
            list/tuple whose first two items are the base64 image and a
            status string.
        processing_time: Wall-clock seconds the backend call took.

    Returns:
        ``(PIL image or None, status text)``.
    """
    if not isinstance(result, (list, tuple)) or len(result) < 2:
        return None, "❌ Invalid response from backend"

    result_b64, status = result[0], result[1]
    if not result_b64:
        return None, status or "❌ No image returned"

    result_image = base64_to_image(result_b64)
    if result_image is None:
        return None, "❌ Failed to decode result image"

    logger.info("Successfully received and decoded result image")
    # Guard against a missing status so the membership test below cannot fail.
    status = str(status) if status else ""
    if "Generated in" not in status:
        status = f"{status} (Total time: {processing_time:.1f}s)".strip()
    return result_image, status


def call_backend_with_retry(print_image, product_image, max_retries=3):
    """Send a pattern-transfer request to the backend, retrying generic failures.

    Args:
        print_image: PIL image of the pattern/print to transfer.
        product_image: PIL image of the product to apply it to.
        max_retries: Maximum number of backend call attempts.

    Returns:
        ``(result_image_or_None, status_message)`` suitable for the
        ``[result_img, status_text]`` Gradio outputs.

    Timeouts reset the global client so the next call reconnects. Queue or busy
    errors are returned to the user immediately. Other errors are retried after
    a short delay, up to ``max_retries`` attempts.
    """
    global client, backend_connected

    # Validate inputs
    if print_image is None:
        return None, "❌ Please upload a print/pattern image"
    if product_image is None:
        return None, "❌ Please upload a product image"

    # (Re)connect if there is no usable client.
    if client is None or not backend_connected:
        status_msg, is_ready = try_connect_backend()
        if not is_ready:
            return None, status_msg

    # The payload is identical for every attempt, so encode it once.
    try:
        print_b64 = image_to_base64(print_image)
        product_b64 = image_to_base64(product_image)
    except Exception as e:
        logger.error(f"Failed to encode input images: {e}")
        return None, f"❌ Failed to prepare images: {str(e)}"
    logger.info("Images converted to base64")

    for attempt in range(max_retries):
        try:
            logger.info(f"Calling backend (attempt {attempt + 1}/{max_retries})")
            start_time = time.monotonic()
            try:
                result = client.predict(
                    print_b64,
                    product_b64,
                    DEFAULT_GUIDANCE_SCALE,
                    DEFAULT_NUM_STEPS,
                    api_name="/predict",
                )
            except Exception as prediction_error:
                # Queue-related errors are surfaced to the user, not retried.
                error_str = str(prediction_error).lower()
                if "queue" in error_str or "position" in error_str:
                    return None, f"📋 Request queued. {str(prediction_error)}"
                raise

            processing_time = time.monotonic() - start_time
            logger.info(f"Backend call completed in {processing_time:.2f}s")
            return _process_result(result, processing_time)

        except Exception as e:
            error_str = str(e).lower()
            if _is_timeout_error(error_str):
                # The backend may have gone idle; force a reconnect next time.
                backend_connected = False
                client = None
                return None, "🟡 Backend timed out. It may be starting up or busy with other requests. Please try again in a few moments."
            if "queue" in error_str or "busy" in error_str:
                return None, f"📋 Server is busy processing other requests. Please wait and try again. {str(e)}"

            logger.error(f"Backend call attempt {attempt + 1} failed: {e}")
            if attempt == max_retries - 1:
                return None, f"❌ Backend error: {str(e)}"
            time.sleep(RETRY_DELAY_SECONDS)  # Wait before retry

    return None, "❌ All attempts failed"


# ───────── UI helpers ─────────
def _status_banner(message, css_class):
    """Wrap a status message in the styled banner div (message is HTML-escaped)."""
    return f'<div class="status-banner {css_class}">{html.escape(message, quote=False)}</div>'


def update_status_display():
    """Re-check the backend connection and return the status banner HTML."""
    status_msg, is_ready = try_connect_backend()
    if is_ready:
        css_class = "status-ready"
    elif "starting up" in status_msg:
        css_class = "status-starting"
    else:
        css_class = "status-error"
    return _status_banner(status_msg, css_class)


def _list_example_images(folder, limit):
    """Return up to ``limit`` image paths from ``folder``, sorted by file name.

    Returns ``[]`` if the folder does not exist. Only files whose extension is
    in ``IMAGE_EXTENSIONS`` are included.
    """
    if not os.path.isdir(folder):
        return []
    names = sorted(
        name
        for name in os.listdir(folder)
        if name.lower().endswith(IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(folder, name))
    )
    return [os.path.join(folder, name) for name in names[:limit]]


def _showcase_rows(folder, count):
    """Build ``[product, print, result]`` path triples ``1..count`` from ``folder``.

    Triples with any missing file are skipped (and logged).
    """
    rows = []
    for index in range(1, count + 1):
        row = [os.path.join(folder, f"{index}_{role}.jpg") for role in ("product", "print", "result")]
        if all(os.path.isfile(path) for path in row):
            rows.append(row)
        else:
            logger.info(f"Skipping showcase example {index} in {folder}: missing file(s)")
    return rows


def _render_showcase(folder, count, label, missing_message, inputs):
    """Render a ``gr.Examples`` gallery of showcase triples, or a placeholder.

    Must be called inside an active ``gr.Blocks`` context. The placeholder
    text is shown when no complete example triple exists or when building the
    gallery fails.
    """
    rows = _showcase_rows(folder, count)
    if rows:
        try:
            gr.Examples(
                examples=rows,
                inputs=inputs,
                label=label,
                examples_per_page=count,
            )
            return
        except Exception as e:
            logger.warning(f"Could not build showcase from {folder}: {e}")
    gr.HTML(f'<p style="text-align: center;">{missing_message}</p>')


# ───────── Main UI ─────────
with gr.Blocks(css=css, title="AI Style Transfer Studio - Pattern & Color Transfer") as demo:
    # ──────── Hero Section ────────
    gr.HTML("""
    <div class="hero-section">
        <h1>🎨 Snapwear Pattern Mockup Studio</h1>
        <p>Transform Any Pattern onto Any Product Instantly</p>
        <p>• Instant results • Perfect for designers, brands &amp; creators</p>
    </div>
    """)

    # ──────── Status Check Section ────────
    with gr.Row():
        with gr.Column():
            # Initial status message reflects the import-time connection attempt.
            if backend_connected:
                initial_status = _status_banner(
                    "🟢 Model is ready! You can generate pattern transfers.", "status-ready"
                )
            else:
                initial_status = _status_banner(
                    '🟡 Model may be starting up. Click "Check Status" to verify.', "status-starting"
                )
            status_display = gr.HTML(value=initial_status)

            # Status check button
            check_status_btn = gr.Button("🔄 Check Status", size="sm")

    # ──────── Info Box ────────
    gr.HTML("""
    <div class="feature-box">
        <h3>ℹ️ How It Works</h3>
        <p><strong>First Time:</strong> Backend takes 5-6 minutes to start up after being idle.</p>
        <p><strong>Multiple Users:</strong> Requests are processed one at a time to ensure quality. You'll be queued if others are using the system.</p>
        <p><strong>Processing Time:</strong> 30-60 seconds per request once processing begins.</p>
        <p><strong>Queue Updates:</strong> You'll see your position and estimated wait time.</p>
    </div>
    """)

    # ──────── Key Features ────────
    gr.HTML("""
    <div class="feature-box">
        <h3>🚀 Instant Transfer</h3>
        <p>Apply any pattern to any product in 30-60 seconds</p>
    </div>
    <div class="feature-box">
        <h3>🎯 Perfect Mapping</h3>
        <p>Preserves product shape, lighting, and texture for realistic results</p>
    </div>
    <div class="feature-box">
        <h3>🎨 Endless Possibilities</h3>
        <p>Transfer prints, patterns, textures, and colors across any product type</p>
    </div>
    """)

    # ──────── Step Headers ────────
    with gr.Row():
        with gr.Column(elem_id="col-left"):
            gr.HTML('<div class="step-header">Step 1: Upload Pattern/Print 🎨</div>')
        with gr.Column(elem_id="col-mid"):
            gr.HTML('<div class="step-header">Step 2: Upload Product 📦</div>')
        with gr.Column(elem_id="col-right"):
            gr.HTML('<div class="step-header">Step 3: Generate Magic ✨</div>')

    # ──────── Main Interface ────────
    with gr.Row():
        # ① Pattern/Print Upload
        with gr.Column(elem_id="col-left"):
            print_image = gr.Image(
                label="Pattern/Print Image",
                type="pil",
                height=400,
            )
            gr.HTML('<p style="text-align: center;">Upload any pattern, print, texture, or design you want to transfer</p>')

            # Print examples
            print_examples = _list_example_images("Assets/print", 10)
            if print_examples:
                gr.Examples(
                    label="✨ Example Patterns",
                    inputs=print_image,
                    examples_per_page=10,
                    examples=print_examples,
                )

        # ② Product Upload
        with gr.Column(elem_id="col-mid"):
            product_image = gr.Image(
                label="Product Image",
                type="pil",
                height=400,
            )
            gr.HTML('<p style="text-align: center;">Upload the product you want to apply the pattern to</p>')

            # Product examples
            product_examples = _list_example_images("Assets/product", 12)
            if product_examples:
                gr.Examples(
                    label="📦 Example Products",
                    inputs=product_image,
                    examples_per_page=12,
                    examples=product_examples,
                )

        # ③ Result + Controls
        with gr.Column(elem_id="col-right"):
            result_img = gr.Image(
                label="✨ Transformed Result",
                show_share_button=True,
                height=400,
            )

            # Status display with queue info
            status_text = gr.Text(
                label="Generation Status",
                interactive=False,
                placeholder="Upload images and click generate...",
            )

            # Generate button
            generate_btn = gr.Button(
                "🚀 Transform Pattern",
                elem_id="button",
                size="lg",
                variant="primary",
            )

            # Queue status info
            gr.HTML('<div class="queue-info">💡 If busy, you\'ll be automatically queued and see position updates</div>')

    # ──────── Showcase Examples ────────
    gr.HTML("""
    <div class="showcase-section">
        <h2 style="text-align: center;">🌟 Showcase: Pattern &amp; Color Transfer Examples</h2>
    </div>
    """)

    # Pattern Transfer Showcase
    with gr.Row():
        gr.HTML('<h3 style="text-align: center;">🎨 Pattern Transfer Showcase</h3>')
    _render_showcase(
        folder="Assets/examples",
        count=4,
        label="Pattern Transfer Examples - Click any example to try it yourself!",
        missing_message="Pattern transfer examples will appear here once example files are added to Assets/examples/",
        inputs=[product_image, print_image, result_img],
    )

    # Color Transfer Showcase
    with gr.Row():
        gr.HTML('<h3 style="text-align: center;">🌈 Color Transfer Showcase</h3>')
    _render_showcase(
        folder="Assets/examples/color",
        count=3,
        label="Color Transfer Examples - Perfect for recoloring products!",
        missing_message="Color transfer examples will appear here once example files are added to Assets/examples/color/",
        inputs=[product_image, print_image, result_img],
    )

    # ──────── Use Cases ────────
    gr.HTML("""
    <div class="showcase-section">
        <h2 style="text-align: center;">🎯 Perfect For</h2>
        <div class="feature-box">
            <h4>👗 Fashion Designers</h4>
            <p>Visualize patterns on garments before production</p>
        </div>
        <div class="feature-box">
            <h4>🛍️ E-commerce Brands</h4>
            <p>Show product variations without inventory</p>
        </div>
        <div class="feature-box">
            <h4>🎨 Print-on-Demand</h4>
            <p>Preview designs on products instantly</p>
        </div>
        <div class="feature-box">
            <h4>📱 Content Creators</h4>
            <p>Create unique visuals for social media</p>
        </div>
    </div>
    """)

    # ──────── Event Handlers ────────
    # Status check button click
    check_status_btn.click(
        fn=update_status_display,
        outputs=[status_display],
    )

    # Generate button click with progress tracking
    generate_btn.click(
        fn=call_backend_with_retry,
        inputs=[print_image, product_image],
        outputs=[result_img, status_text],
        show_progress="full",
        concurrency_limit=1,  # Ensure only one generation at a time on frontend too
    )

    # ──────── Footer ────────
    gr.HTML("""
    <div style="text-align: center; margin: 30px 0;">
        <h3>🚀 Powered by Snapwear AI</h3>
        <p>Transform your creative vision with our models.</p>
        <p>© 2024 Snapwear AI. Professional AI tools for fashion and design.</p>
    </div>
    """)


if __name__ == "__main__":
    demo.queue(
        max_size=20,
        default_concurrency_limit=1,  # Single concurrent request to match backend
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_api=False,
    )