| """ | |
| Model management for FLUX Prompt Optimizer | |
| Handles Florence-2 and Bagel model integration | |
| """ | |
| import logging | |
| import requests | |
| import spaces | |
| import torch | |
| from typing import Optional, Dict, Any, Tuple | |
| from PIL import Image | |
| from transformers import AutoProcessor, AutoModelForCausalLM | |
| from config import MODEL_CONFIG, get_device_config | |
| from utils import clean_memory, safe_execute | |
| logger = logging.getLogger(__name__) | |
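# Shape of the configuration this module reads (keys inferred from the usage
# below; the concrete values live in config.py, and the values shown here are
# assumptions for illustration only):
#
# MODEL_CONFIG = {
#     "primary_model": "florence2",
#     "florence2": {
#         "model_id": "microsoft/Florence-2-large",   # assumed model id
#         "trust_remote_code": True,
#         "torch_dtype": torch.float16,
#         "max_new_tokens": 512,                      # assumed budget
#     },
#     "bagel": {"api_url": "...", "timeout": 30},     # endpoint elided, timeout assumed
# }
#
# get_device_config() -> {"device": "cuda" | "cpu", "use_gpu": bool}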


class BaseImageAnalyzer:
    """Base class for image analysis models."""

    def __init__(self):
        self.model = None
        self.processor = None
        self.device_config = get_device_config()
        self.is_initialized = False

    def initialize(self) -> bool:
        """Initialize the model."""
        raise NotImplementedError

    def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Analyze an image and return a (description, metadata) tuple."""
        raise NotImplementedError

    def cleanup(self) -> None:
        """Clean up model resources."""
        if self.model is not None:
            del self.model
            self.model = None
        if self.processor is not None:
            del self.processor
            self.processor = None
        # Reset so a later analyze_image() call re-initializes instead of
        # using a torn-down model
        self.is_initialized = False
        clean_memory()


class Florence2Analyzer(BaseImageAnalyzer):
    """Florence-2 model for image analysis."""

    def __init__(self):
        super().__init__()
        self.config = MODEL_CONFIG["florence2"]

    def initialize(self) -> bool:
        """Initialize the Florence-2 model."""
        if self.is_initialized:
            return True
        try:
            logger.info("Initializing Florence-2 model...")
            model_id = self.config["model_id"]

            # Load processor
            self.processor = AutoProcessor.from_pretrained(
                model_id,
                trust_remote_code=self.config["trust_remote_code"],
            )

            # Load model (float32 on CPU, where half precision is poorly supported)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                trust_remote_code=self.config["trust_remote_code"],
                torch_dtype=self.config["torch_dtype"] if self.device_config["use_gpu"] else torch.float32,
            )

            # Move to the appropriate device
            device = self.device_config["device"] if self.device_config["use_gpu"] else "cpu"
            self.model = self.model.to(device)
            self.model.eval()

            self.is_initialized = True
            logger.info(f"Florence-2 initialized on {self.device_config['device']}")
            return True
        except Exception as e:
            logger.error(f"Florence-2 initialization failed: {e}")
            self.cleanup()
            return False

    @spaces.GPU
    def _gpu_inference(self, image: Image.Image, task_prompt: str) -> str:
        """Run inference on GPU (allocated by the `spaces.GPU` decorator on ZeroGPU)."""
        try:
            # Move the model to the GPU for the duration of this call
            if self.device_config["use_gpu"]:
                self.model = self.model.to("cuda")

            # Prepare inputs
            inputs = self.processor(text=task_prompt, images=image, return_tensors="pt")

            # Move inputs to the same device as the model
            device = "cuda" if self.device_config["use_gpu"] else self.device_config["device"]
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Generate a response, under float16 autocast when on GPU
            autocast_ctx = (
                torch.autocast("cuda", dtype=torch.float16)
                if self.device_config["use_gpu"]
                else contextlib.nullcontext()
            )
            with torch.no_grad(), autocast_ctx:
                generated_ids = self.model.generate(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    max_new_tokens=self.config["max_new_tokens"],
                    num_beams=3,
                    do_sample=False,
                )

            # Decode and parse the task-specific output
            generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
            parsed = self.processor.post_process_generation(
                generated_text,
                task=task_prompt,
                image_size=(image.width, image.height),
            )

            # Extract the caption keyed by the task token
            if task_prompt in parsed:
                return parsed[task_prompt]
            return str(parsed) if parsed else ""
        except Exception as e:
            logger.error(f"Florence-2 GPU inference failed: {e}")
            return ""
        finally:
            # Move the model back to the CPU to free GPU memory
            if self.device_config["use_gpu"]:
                self.model = self.model.to("cpu")
                clean_memory()

    def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Analyze an image using Florence-2."""
        if not self.is_initialized:
            if not self.initialize():
                return "Model initialization failed", {"error": "Florence-2 not available"}
        try:
            # Florence-2 captioning tasks at increasing levels of detail
            tasks = {
                "detailed": "<DETAILED_CAPTION>",
                "more_detailed": "<MORE_DETAILED_CAPTION>",
                "caption": "<CAPTION>",
            }
            results = {}

            # Run analysis for each task
            for task_name, task_prompt in tasks.items():
                if self.device_config["use_gpu"]:
                    result = self._gpu_inference(image, task_prompt)
                else:
                    result = self._cpu_inference(image, task_prompt)
                results[task_name] = result

            # Choose the most detailed non-empty result
            if results["more_detailed"]:
                main_description = results["more_detailed"]
            elif results["detailed"]:
                main_description = results["detailed"]
            else:
                main_description = results["caption"] or "A photograph"

            # Prepare metadata
            metadata = {
                "model": "Florence-2",
                "device": self.device_config["device"],
                "all_results": results,
                "confidence": 0.85,  # Florence-2 is generally reliable
            }
            logger.info(f"Florence-2 analysis complete: {len(main_description)} chars")
            return main_description, metadata
        except Exception as e:
            logger.error(f"Florence-2 analysis failed: {e}")
            return "Analysis failed", {"error": str(e)}

    def _cpu_inference(self, image: Image.Image, task_prompt: str) -> str:
        """Run inference on CPU."""
        try:
            inputs = self.processor(text=task_prompt, images=image, return_tensors="pt")
            with torch.no_grad():
                generated_ids = self.model.generate(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    max_new_tokens=self.config["max_new_tokens"],
                    num_beams=2,  # Reduced beam count for CPU speed
                    do_sample=False,
                )
            generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
            parsed = self.processor.post_process_generation(
                generated_text,
                task=task_prompt,
                image_size=(image.width, image.height),
            )
            if task_prompt in parsed:
                return parsed[task_prompt]
            return str(parsed) if parsed else ""
        except Exception as e:
            logger.error(f"Florence-2 CPU inference failed: {e}")
            return ""


class BagelAnalyzer(BaseImageAnalyzer):
    """Bagel-7B model analyzer, accessed via a remote API."""

    def __init__(self):
        super().__init__()
        self.config = MODEL_CONFIG["bagel"]
        self.session = requests.Session()

    def initialize(self) -> bool:
        """Initialize the Bagel analyzer by testing API connectivity."""
        try:
            test_response = self.session.get(
                self.config["api_url"],
                timeout=self.config["timeout"],
            )
            if test_response.status_code == 200:
                self.is_initialized = True
                logger.info("Bagel API connection established")
                return True
            logger.error(f"Bagel API not accessible: {test_response.status_code}")
            return False
        except Exception as e:
            logger.error(f"Bagel initialization failed: {e}")
            return False

    def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Analyze an image using the Bagel-7B API."""
        if not self.is_initialized:
            if not self.initialize():
                return "Bagel API not available", {"error": "API connection failed"}
        try:
            # Placeholder: the actual request depends on the Bagel API format.
            # A real implementation would (1) convert the image to the required
            # format, (2) call the Bagel endpoint, and (3) parse the response;
            # one hedged sketch of that flow is _call_api below.
            description = "Detailed image analysis via Bagel-7B (API implementation needed)"
            metadata = {
                "model": "Bagel-7B",
                "method": "API",
                "confidence": 0.8,
            }
            logger.info("Bagel analysis complete (placeholder)")
            return description, metadata
        except Exception as e:
            logger.error(f"Bagel analysis failed: {e}")
            return "Analysis failed", {"error": str(e)}


class ModelManager:
    """Manager for handling multiple analysis models."""

    def __init__(self, preferred_model: Optional[str] = None):
        self.preferred_model = preferred_model or MODEL_CONFIG["primary_model"]
        self.analyzers = {}
        self.current_analyzer = None

    def get_analyzer(self, model_name: Optional[str] = None) -> Optional[BaseImageAnalyzer]:
        """Get or create the analyzer for the specified model."""
        model_name = model_name or self.preferred_model
        if model_name not in self.analyzers:
            if model_name == "florence2":
                self.analyzers[model_name] = Florence2Analyzer()
            elif model_name == "bagel":
                self.analyzers[model_name] = BagelAnalyzer()
            else:
                logger.error(f"Unknown model: {model_name}")
                return None
        return self.analyzers[model_name]

    def analyze_image(self, image: Image.Image, model_name: Optional[str] = None) -> Tuple[str, Dict[str, Any]]:
        """Analyze an image with the specified (or preferred) model."""
        analyzer = self.get_analyzer(model_name)
        if analyzer is None:
            return "No analyzer available", {"error": "Model not found"}
        success, result = safe_execute(analyzer.analyze_image, image)
        if success:
            return result
        return "Analysis failed", {"error": result}

    def cleanup_all(self) -> None:
        """Clean up all model resources."""
        for analyzer in self.analyzers.values():
            analyzer.cleanup()
        self.analyzers.clear()
        clean_memory()


# Global model manager instance
model_manager = ModelManager()


def analyze_image(image: Image.Image, model_name: Optional[str] = None) -> Tuple[str, Dict[str, Any]]:
    """
    Convenience function for image analysis.

    Args:
        image: PIL Image to analyze
        model_name: Optional model name ("florence2" or "bagel")

    Returns:
        Tuple of (description, metadata)
    """
    return model_manager.analyze_image(image, model_name)


# Export main components
__all__ = [
    "BaseImageAnalyzer",
    "Florence2Analyzer",
    "BagelAnalyzer",
    "ModelManager",
    "model_manager",
    "analyze_image",
]
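

# Minimal usage sketch (assumes an "example.jpg" next to this module; the
# path and log level are illustrative only):
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    img = Image.open("example.jpg").convert("RGB")
    description, metadata = analyze_image(img, model_name="florence2")
    print(description)
    print(metadata)
    model_manager.cleanup_all()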