Spaces:
Runtime error
Runtime error
| from pprint import pprint | |
| from time import perf_counter | |
| from traceback import print_exc | |
| from typing import Any | |
| from app_settings import Settings | |
| from backend.image_saver import ImageSaver | |
| from backend.lcm_text_to_image import LCMTextToImage | |
| from backend.models.lcmdiffusion_setting import DiffusionTask | |
| from backend.utils import get_blank_image | |
| from models.interface_types import InterfaceType | |
class Context:
    """Drive an ``LCMTextToImage`` pipeline and remember how the last run went.

    The instance keeps two pieces of bookkeeping that callers read through
    ``latency()`` and ``error()``: the duration measured for the most recent
    run that got through pipeline generation, and the message of the most
    recent exception (cleared at the start of every generation attempt).
    """

    def __init__(
        self,
        interface_type: InterfaceType,
        device="cpu",
    ):
        """Build the pipeline wrapper.

        Args:
            interface_type: Interface enum member; only its ``.value`` is
                stored on the instance.
            device: Device identifier handed to ``LCMTextToImage``.
        """
        self.interface_type = interface_type.value
        self.lcm_text_to_image = LCMTextToImage(device)
        self._error = ""
        self._latency = 0

    def latency(self):
        """Return the last recorded generation time in seconds (0 initially)."""
        return self._latency

    def error(self):
        """Return the last recorded error message, or "" if there is none."""
        return self._error

    def generate_text_to_image(
        self,
        settings: Settings,
        reshape: bool = False,
        device: str = "cpu",
        save_config=True,
    ) -> Any:
        """Run one generation pass described by ``settings``.

        In order: clear the stored error, drop ``init_image`` when the task is
        text-to-image, optionally save the settings object returned by
        ``state.get_settings()``, print the diffusion settings, initialise the
        pipeline and generate, record the latency, append the ControlNet
        control image when ControlNet is enabled, and finally swap every image
        the safety checker rejects for a blank image.

        Args:
            settings: Application settings; only ``lcm_diffusion_setting`` is
                read here.
            reshape: Forwarded unchanged to ``LCMTextToImage.generate``.
            device: Forwarded unchanged to ``LCMTextToImage.init``.
            save_config: When true, ``get_settings().save()`` is called
                before generating.

        Returns:
            The images returned by the pipeline (after the optional ControlNet
            append and safety filtering), or ``None`` when ``lcm_lora`` is
            falsy or when any exception is raised. In the exception case the
            message is stored and available through ``error()``.
        """
        try:
            self._error = ""
            started_at = perf_counter()
            # Function-level import; presumably avoids a circular import with
            # `state` -- TODO confirm.
            from state import get_settings

            lcm_setting = settings.lcm_diffusion_setting

            # Text-to-image must not carry over an init image.
            if lcm_setting.diffusion_task == DiffusionTask.text_to_image.value:
                lcm_setting.init_image = None

            if save_config:
                get_settings().save()

            pprint(lcm_setting.model_dump())

            # Nothing is generated without an lcm_lora setting; the stored
            # latency is left untouched and no error is recorded.
            if not lcm_setting.lcm_lora:
                return None

            self.lcm_text_to_image.init(device, lcm_setting)
            images = self.lcm_text_to_image.generate(lcm_setting, reshape)

            # Measured from function entry, so it covers the config save and
            # pipeline init as well as the generate call.
            self._latency = elapsed = perf_counter() - started_at
            print(f"Latency : {elapsed:.2f} seconds")

            controlnet = lcm_setting.controlnet
            if controlnet and controlnet.enabled:
                # Reads a private attribute of the controlnet setting.
                images.append(controlnet._control_image)

            if lcm_setting.use_safety_checker:
                print("Safety Checker is enabled")
                from state import get_safety_checker

                checker = get_safety_checker()
                placeholder = get_blank_image(
                    lcm_setting.image_width,
                    lcm_setting.image_height,
                )
                # Replace rejected images in place so positions are kept.
                for position, candidate in enumerate(images):
                    if checker.is_safe(candidate):
                        continue
                    images[position] = placeholder
        except Exception as failure:
            print(f"Error in generating images: {failure}")
            self._error = str(failure)
            print_exc()
            return None
        else:
            return images

    def save_images(
        self,
        images: Any,
        settings: Settings,
    ) -> list[str]:
        """Save ``images`` through ``ImageSaver`` when saving is enabled.

        Args:
            images: Images to save; a falsy value means there is nothing to do.
            settings: Application settings; ``generated_images`` supplies the
                save flag, path, format and quality, and
                ``lcm_diffusion_setting`` is passed along to the saver.

        Returns:
            Whatever ``ImageSaver.save_images`` returns, or an empty list when
            ``images`` is falsy or ``generated_images.save_image`` is falsy.
        """
        if not images or not settings.generated_images.save_image:
            return []

        return ImageSaver.save_images(
            settings.generated_images.path,
            images=images,
            lcm_diffusion_setting=settings.lcm_diffusion_setting,
            format=settings.generated_images.format,
            jpeg_quality=settings.generated_images.save_image_quality,
        )