Spaces:
Build error
Build error
| """ | |
| Wrapper class to call the stablediffusion.cpp shared library for GGUF support | |
| """ | |
| import ctypes | |
| import platform | |
| from ctypes import ( | |
| POINTER, | |
| c_bool, | |
| c_char_p, | |
| c_float, | |
| c_int, | |
| c_int64, | |
| c_void_p, | |
| ) | |
| from dataclasses import dataclass | |
| from os import path | |
| from typing import List, Any | |
| import numpy as np | |
| from PIL import Image | |
| from backend.gguf.sdcpp_types import ( | |
| RngType, | |
| SampleMethod, | |
| Schedule, | |
| SDCPPLogLevel, | |
| SDImage, | |
| SdType, | |
| ) | |
@dataclass
class ModelConfig:
    """Options used to create a stable-diffusion.cpp context.

    Every field is forwarded, in declaration order, as one argument of the
    library's ``new_sd_ctx`` call made by ``GGUFDiffusion``. Path fields are
    UTF-8 encoded before the call; an empty string is passed as an empty C
    string. ``GGUFDiffusion`` requires ``clip_l_path``, ``t5xxl_path``,
    ``diffusion_model_path`` and ``vae_path`` to name existing files.

    ``@dataclass`` gives the class a keyword-argument constructor, e.g.
    ``ModelConfig(vae_path="ae.safetensors", n_threads=8)``; the class-level
    defaults below remain available as before.
    """

    # Model file / directory paths (empty string when not used).
    model_path: str = ""
    clip_l_path: str = ""
    t5xxl_path: str = ""
    diffusion_model_path: str = ""
    vae_path: str = ""
    taesd_path: str = ""
    control_net_path: str = ""
    lora_model_dir: str = ""
    embed_dir: str = ""
    stacked_id_embed_dir: str = ""
    # Context options, passed straight through to new_sd_ctx.
    vae_decode_only: bool = True
    vae_tiling: bool = False
    free_params_immediately: bool = False
    n_threads: int = 4
    wtype: SdType = SdType.SD_TYPE_Q4_0
    rng_type: RngType = RngType.CUDA_RNG
    schedule: Schedule = Schedule.DEFAULT
    keep_clip_on_cpu: bool = False
    keep_control_net_cpu: bool = False
    keep_vae_on_cpu: bool = False
@dataclass
class Txt2ImgConfig:
    """Generation parameters forwarded to the library's ``txt2img`` call.

    Fields are passed in declaration order after the context pointer.
    ``prompt`` and ``negative_prompt`` are UTF-8 encoded before the call, and
    ``batch_count`` also decides how many images are read back from the
    returned buffer.

    ``@dataclass`` gives the class a keyword-argument constructor, e.g.
    ``Txt2ImgConfig(prompt="a cat", sample_steps=4)``.
    """

    prompt: str = "a man wearing sun glasses, highly detailed"
    negative_prompt: str = ""
    clip_skip: int = -1
    cfg_scale: float = 2.0
    guidance: float = 3.5
    width: int = 512
    height: int = 512
    sample_method: SampleMethod = SampleMethod.EULER_A
    sample_steps: int = 1
    seed: int = -1
    batch_count: int = 2
    # Passed to a POINTER(SDImage) parameter, so it must be such a pointer
    # (not a PIL image); None is sent as a NULL pointer.
    control_cond: Any = None
    control_strength: float = 0.90
    style_strength: float = 0.5
    normalize_input: bool = False
    # Passed to the library's ``const char*`` parameter; provide bytes.
    input_id_images_path: bytes = b""
class GGUFDiffusion:
    """GGUF diffusion model wrapper around the stable-diffusion.cpp library.

    Loads the platform-specific stable-diffusion shared library, creates a
    generation context from a ``ModelConfig`` and exposes text-to-image
    generation. Call ``terminate`` to free the context.

    GGUF format: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
    Implemented based on stable-diffusion.h
    """

    def __init__(
        self,
        libpath: str,
        config: ModelConfig,
        logging_enabled: bool = False,
    ):
        """Load the shared library and create a stable-diffusion context.

        Args:
            libpath: Directory containing the stable-diffusion shared library
                for the current platform.
            config: Model file paths and context options.
            logging_enabled: When True, library log messages are printed to
                stdout.

        Raises:
            ValueError: If the platform is unsupported, the library cannot be
                loaded, a required model file is missing, or the library
                returns a NULL context.
        """
        sdcpp_shared_lib_path = self._get_sdcpp_shared_lib_path(libpath)
        if not sdcpp_shared_lib_path:
            # Unknown platform: fail clearly instead of calling CDLL("").
            raise ValueError(
                f"Unsupported platform for GGUF diffusion: {platform.system()}"
            )
        try:
            self.libsdcpp = ctypes.CDLL(sdcpp_shared_lib_path)
        except OSError as e:
            print(f"Failed to load library {sdcpp_shared_lib_path}")
            raise ValueError(f"Error: {e}") from e

        # These four files are mandatory; the other paths may stay empty.
        required_files = (
            (
                config.clip_l_path,
                "CLIP model file not found,please check readme.md for GGUF model usage",
            ),
            (
                config.t5xxl_path,
                "T5XXL model file not found,please check readme.md for GGUF model usage",
            ),
            (
                config.diffusion_model_path,
                "Diffusion model file not found,please check readme.md for GGUF model usage",
            ),
            (
                config.vae_path,
                "VAE model file not found,please check readme.md for GGUF model usage",
            ),
        )
        for file_path, error_message in required_files:
            if not file_path or not path.exists(file_path):
                raise ValueError(error_message)

        self.model_config = config

        if logging_enabled:
            # sd_set_log_callback takes no context argument, so the callback
            # can be installed before new_sd_ctx; messages logged while the
            # context is being created are then forwarded as well.
            self._set_logcallback()

        self.libsdcpp.new_sd_ctx.argtypes = [
            c_char_p,  # const char* model_path
            c_char_p,  # const char* clip_l_path
            c_char_p,  # const char* t5xxl_path
            c_char_p,  # const char* diffusion_model_path
            c_char_p,  # const char* vae_path
            c_char_p,  # const char* taesd_path
            c_char_p,  # const char* control_net_path_c_str
            c_char_p,  # const char* lora_model_dir
            c_char_p,  # const char* embed_dir_c_str
            c_char_p,  # const char* stacked_id_embed_dir_c_str
            c_bool,  # bool vae_decode_only
            c_bool,  # bool vae_tiling
            c_bool,  # bool free_params_immediately
            c_int,  # int n_threads
            SdType,  # enum sd_type_t wtype
            RngType,  # enum rng_type_t rng_type
            Schedule,  # enum schedule_t s
            c_bool,  # bool keep_clip_on_cpu
            c_bool,  # bool keep_control_net_cpu
            c_bool,  # bool keep_vae_on_cpu
        ]
        self.libsdcpp.new_sd_ctx.restype = POINTER(c_void_p)
        self.sd_ctx = self.libsdcpp.new_sd_ctx(
            self._str_to_bytes(self.model_config.model_path),
            self._str_to_bytes(self.model_config.clip_l_path),
            self._str_to_bytes(self.model_config.t5xxl_path),
            self._str_to_bytes(self.model_config.diffusion_model_path),
            self._str_to_bytes(self.model_config.vae_path),
            self._str_to_bytes(self.model_config.taesd_path),
            self._str_to_bytes(self.model_config.control_net_path),
            self._str_to_bytes(self.model_config.lora_model_dir),
            self._str_to_bytes(self.model_config.embed_dir),
            self._str_to_bytes(self.model_config.stacked_id_embed_dir),
            self.model_config.vae_decode_only,
            self.model_config.vae_tiling,
            self.model_config.free_params_immediately,
            self.model_config.n_threads,
            self.model_config.wtype,
            self.model_config.rng_type,
            self.model_config.schedule,
            self.model_config.keep_clip_on_cpu,
            self.model_config.keep_control_net_cpu,
            self.model_config.keep_vae_on_cpu,
        )
        # A NULL ctypes pointer is falsy. Fail here rather than hand a NULL
        # context to txt2img later.
        if not self.sd_ctx:
            raise ValueError("Failed to create stable-diffusion context")

    def _set_logcallback(self):
        """Register ``log_callback`` as the library's log handler."""
        print("Setting logging callback")
        # C signature: void (*)(enum log_level, const char* text, void* data)
        SdLogCallbackType = ctypes.CFUNCTYPE(
            None,
            SDCPPLogLevel,
            ctypes.c_char_p,
            ctypes.c_void_p,
        )
        self.libsdcpp.sd_set_log_callback.argtypes = [
            SdLogCallbackType,
            ctypes.c_void_p,
        ]
        self.libsdcpp.sd_set_log_callback.restype = None
        # Keep the C function pointer as a member so it is not garbage
        # collected while the library may still call it.
        self.c_log_callback = SdLogCallbackType(self.log_callback)
        self.libsdcpp.sd_set_log_callback(self.c_log_callback, None)

    def _get_sdcpp_shared_lib_path(
        self,
        root_path: str,
    ) -> str:
        """Return the path of the shared library for the current OS.

        Returns an empty string (after printing "Unknown platform.") when the
        operating system is not Windows, Linux or Darwin.
        """
        system_name = platform.system()
        print(f"GGUF Diffusion on {system_name}")
        lib_names = {
            "Windows": "stable-diffusion.dll",
            "Linux": "libstable-diffusion.so",
            "Darwin": "libstable-diffusion.dylib",
        }
        lib_name = lib_names.get(system_name)
        if lib_name is None:
            print("Unknown platform.")
            return ""
        return path.join(root_path, lib_name)

    @staticmethod
    def log_callback(
        level,
        text,
        data,
    ):
        """Print a log message received from the library.

        A static method, so the C library's three arguments map directly
        onto (level, text, data); as a bound method an extra ``self`` would
        be supplied and every call would fail with TypeError.
        """
        # ctypes converts a NULL c_char_p argument to None.
        if text:
            print(text.decode("utf-8", errors="replace"), end="")

    def _str_to_bytes(self, in_str: str, encoding: str = "utf-8") -> bytes:
        """Encode ``in_str``; empty or None values become ``b""``."""
        if in_str:
            return in_str.encode(encoding)
        return b""

    def generate_text2mg(self, txt2img_cfg: Txt2ImgConfig) -> List[Any]:
        """Generate images from a text prompt with the library's txt2img.

        Args:
            txt2img_cfg: Prompt and sampling parameters.

        Returns:
            A list of ``txt2img_cfg.batch_count`` PIL images, or an empty list
            if the library returned a NULL buffer.

        Raises:
            ValueError: If a returned image has an unsupported channel count.
        """
        self.libsdcpp.txt2img.restype = POINTER(SDImage)
        self.libsdcpp.txt2img.argtypes = [
            c_void_p,  # sd_ctx_t* sd_ctx (pointer to context object)
            c_char_p,  # const char* prompt
            c_char_p,  # const char* negative_prompt
            c_int,  # int clip_skip
            c_float,  # float cfg_scale
            c_float,  # float guidance
            c_int,  # int width
            c_int,  # int height
            SampleMethod,  # enum sample_method_t sample_method
            c_int,  # int sample_steps
            c_int64,  # int64_t seed
            c_int,  # int batch_count
            POINTER(SDImage),  # const sd_image_t* control_cond (pointer to SDImage)
            c_float,  # float control_strength
            c_float,  # float style_strength
            c_bool,  # bool normalize_input
            c_char_p,  # const char* input_id_images_path
        ]
        input_id_images_path = txt2img_cfg.input_id_images_path
        if isinstance(input_id_images_path, str):
            # c_char_p only accepts bytes; accept str as a convenience.
            input_id_images_path = self._str_to_bytes(input_id_images_path)
        image_buffer = self.libsdcpp.txt2img(
            self.sd_ctx,
            self._str_to_bytes(txt2img_cfg.prompt),
            self._str_to_bytes(txt2img_cfg.negative_prompt),
            txt2img_cfg.clip_skip,
            txt2img_cfg.cfg_scale,
            txt2img_cfg.guidance,
            txt2img_cfg.width,
            txt2img_cfg.height,
            txt2img_cfg.sample_method,
            txt2img_cfg.sample_steps,
            txt2img_cfg.seed,
            txt2img_cfg.batch_count,
            txt2img_cfg.control_cond,
            txt2img_cfg.control_strength,
            txt2img_cfg.style_strength,
            txt2img_cfg.normalize_input,
            input_id_images_path,
        )
        return self._get_sd_images_from_buffer(
            image_buffer,
            txt2img_cfg.batch_count,
        )

    def _get_sd_images_from_buffer(
        self,
        image_buffer: Any,
        batch_count: int,
    ) -> List[Any]:
        """Convert the first ``batch_count`` SDImage entries into PIL images.

        Channel counts 1, 3 and 4 map to modes "L", "RGB" and "RGBA".

        Raises:
            ValueError: For any other channel count.
        """
        images = []
        if not image_buffer:
            # NULL result pointer: nothing to convert.
            return images
        # NOTE(review): neither the returned array nor the per-image pixel
        # buffers are freed by this wrapper; presumably they are allocated by
        # the library — confirm ownership before adding a free.
        for i in range(batch_count):
            image = image_buffer[i]
            print(
                f"Generated image: {image.width}x{image.height} with {image.channel} channels"
            )
            width = image.width
            height = image.height
            channels = image.channel
            pixel_data = np.ctypeslib.as_array(
                image.data, shape=(height, width, channels)
            )
            if channels == 1:
                pil_image = Image.fromarray(pixel_data.squeeze(), mode="L")
            elif channels == 3:
                pil_image = Image.fromarray(pixel_data, mode="RGB")
            elif channels == 4:
                pil_image = Image.fromarray(pixel_data, mode="RGBA")
            else:
                raise ValueError(f"Unsupported number of channels: {channels}")
            images.append(pil_image)
        return images

    def terminate(self):
        """Free the context and drop this object's reference to the library.

        Safe to call repeatedly, and on an instance whose ``__init__`` raised
        before ``libsdcpp`` or ``sd_ctx`` were assigned.
        """
        libsdcpp = getattr(self, "libsdcpp", None)
        if not libsdcpp:
            return
        sd_ctx = getattr(self, "sd_ctx", None)
        if sd_ctx:
            libsdcpp.free_sd_ctx.argtypes = [c_void_p]
            libsdcpp.free_sd_ctx.restype = None
            libsdcpp.free_sd_ctx(sd_ctx)
        self.sd_ctx = None
        self.libsdcpp = None