# HANDLER v6 — SDXL txt2img + inpaint, supports image_url/mask_url, guards UNet channels
import base64
import io
import json
import os
from typing import Any, Dict

import requests
import torch
from diffusers import StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline
from huggingface_hub import snapshot_download
from PIL import Image

MODEL_ID = os.getenv("MODEL_ID", "andro-flock/LUSTIFY-SDXL-NSFW-checkpoint-v2-0-INPAINTING")


class EndpointHandler:
    """Serve SDXL text-to-image and inpainting requests from one model repo.

    The inpaint pipeline is mandatory. The txt2img pipeline is kept only when
    the repo's UNet reports 4 input channels; otherwise text-to-image requests
    are served by inpainting a fully-masked blank canvas.
    """

    def __init__(self, path: str = "."):
        """Download the model snapshot and load the pipelines.

        Args:
            path: Accepted for interface compatibility; it is not used. The
                weights are always fetched from ``MODEL_ID``.

        Raises:
            RuntimeError: If the inpaint pipeline cannot be loaded with any of
                the candidate dtypes.
        """
        print("HANDLER v6: init start")
        token = os.getenv("HF_TOKEN")
        local_dir = snapshot_download(MODEL_ID, token=token)
        print(f"HANDLER v6: snapshot at {local_dir}")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipe_txt2img = None
        self.pipe_inpaint = None

        # Half precision is the cheapest option on GPU, but half-precision
        # pipelines generally cannot run on CPU, so try float32 first there.
        if self.device == "cuda":
            dtype_order = (torch.float16, torch.bfloat16, torch.float32)
        else:
            dtype_order = (torch.float32, torch.bfloat16, torch.float16)

        last_err = None
        for dtype in dtype_order:
            try:
                # txt2img is optional: a failure here must not abort the dtype.
                self.pipe_txt2img = self._load_txt2img(local_dir, dtype)

                # Load inpaint (required)
                self.pipe_inpaint = StableDiffusionXLInpaintPipeline.from_pretrained(
                    local_dir, torch_dtype=dtype, use_safetensors=True
                ).to(self.device)
                in_ch = getattr(self.pipe_inpaint.unet.config, "in_channels", "NA")
                print(f"HANDLER v6: inpaint OK ({dtype}, in_ch={in_ch})")
                break
            except Exception as e:
                # Drop both pipelines so the next dtype starts from a clean slate.
                last_err = e
                self.pipe_txt2img = None
                self.pipe_inpaint = None
                print(f"HANDLER v6: inpaint failed on {dtype}: {e}")

        if self.pipe_inpaint is None:
            raise RuntimeError(f"Failed to load pipelines: {last_err}")

        # Attention slicing is a best-effort memory optimisation; attempt it
        # on each pipeline independently and report (not hide) any failure.
        for pipe in (self.pipe_inpaint, self.pipe_txt2img):
            if pipe is None:
                continue
            try:
                pipe.enable_attention_slicing()
            except Exception as e:
                print(f"HANDLER v6: attention slicing unavailable: {e}")

        print("HANDLER v6: ready")

    def _load_txt2img(self, local_dir: str, dtype: torch.dtype):
        """Load the txt2img pipeline, returning ``None`` when it is unusable.

        The pipeline is kept only if its UNet reports 4 input channels. Any
        loading error is logged and reported as ``None`` rather than raised.
        """
        try:
            pipe = StableDiffusionXLPipeline.from_pretrained(
                local_dir, torch_dtype=dtype, use_safetensors=True
            ).to(self.device)
            # Keep txt2img ONLY if UNet is 4-ch (proper base)
            if getattr(pipe.unet.config, "in_channels", 4) == 4:
                print(f"HANDLER v6: txt2img OK ({dtype}, in_ch=4)")
                return pipe
        except Exception as e:
            print(f"HANDLER v6: txt2img failed on {dtype}: {e}")
            return None

        print("HANDLER v6: txt2img UNet in_ch != 4; disabling txt2img for this repo")
        try:
            # Move the rejected weights off the accelerator before dropping them.
            pipe.to("cpu")
        except Exception as e:
            print(f"HANDLER v6: could not offload rejected txt2img pipeline: {e}")
        return None

    # ---------- helpers ----------
    def _unwrap(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Return the request parameters as a flat dict.

        Handles these shapes of ``data``:
          * no ``"inputs"`` key: ``data`` itself is returned;
          * ``"inputs"`` is a dict: that dict is returned;
          * ``"inputs"`` is a JSON string encoding a dict: the decoded dict;
          * ``"inputs"`` is any other string: it is used as the prompt, merged
            over an optional ``"parameters"`` dict found next to it.
        """
        if "inputs" not in data:
            return data

        inner = data["inputs"]
        if isinstance(inner, dict):
            return inner
        if isinstance(inner, str):
            try:
                decoded = json.loads(inner)
            except ValueError:
                decoded = None
            if isinstance(decoded, dict):
                return decoded
            # Not a JSON object: treat the raw string as the prompt itself.
            params = data.get("parameters")
            merged = dict(params) if isinstance(params, dict) else {}
            merged["prompt"] = inner
            return merged
        return data

    def _fetch_url_bytes(self, url: str) -> bytes:
        """Download ``url`` and return the response body.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        # NOTE(review): the URL comes straight from the request payload and is
        # fetched without host or size restrictions — confirm this endpoint is
        # only reachable by trusted callers, or add an allow-list / size cap.
        r = requests.get(url, timeout=60)
        r.raise_for_status()
        return r.content

    def _to_pil(self, payload: Any, mode: str) -> Image.Image:
        """Decode ``payload`` into a PIL image converted to ``mode``.

        Accepts a PIL image, raw bytes, a base64 string, a ``data:`` URL, or
        an HTTP(S) URL.
        """
        if isinstance(payload, Image.Image):
            return payload.convert(mode)
        if isinstance(payload, str):
            if payload.startswith(("http://", "https://")):
                payload = self._fetch_url_bytes(payload)
            else:
                if payload.startswith("data:"):
                    # Strip the "data:<mime>;base64," header.
                    payload = payload.split(",", 1)[1]
                payload = base64.b64decode(payload)
        return Image.open(io.BytesIO(payload)).convert(mode)

    def _int(self, data, key, default):
        """Return ``data[key]`` as an int, or ``default`` if missing/invalid."""
        try:
            return int(data.get(key, default))
        except (TypeError, ValueError, OverflowError):
            return default

    def _float(self, data, key, default):
        """Return ``data[key]`` as a float, or ``default`` if missing/invalid."""
        try:
            return float(data.get(key, default))
        except (TypeError, ValueError, OverflowError):
            return default

    @staticmethod
    def _first_payload(data: Dict[str, Any], keys) -> Any:
        """Return the first non-empty value among ``keys``, else ``None``.

        ``None`` and empty strings/bytes count as "not provided", so an empty
        field does not hide a later alternative key.
        """
        for key in keys:
            value = data.get(key)
            if value is None:
                continue
            if isinstance(value, (str, bytes)) and not value:
                continue
            return value
        return None

    # ---------- main entry ----------
    def __call__(self, data: Dict[str, Any]):
        """Generate an image and return it as a base64-encoded PNG.

        Recognised request keys: ``prompt``, ``negative_prompt``,
        ``num_inference_steps``, ``guidance_scale``, ``seed``, ``width``,
        ``height``, ``strength``, ``image`` / ``init_image`` / ``image_url``,
        and ``mask`` / ``mask_url``.

        Without an input image the request is treated as text-to-image;
        with one it is an inpaint, using an all-white mask when none is given.

        Returns:
            ``{"image_base64": <PNG bytes encoded as base64 text>}``
        """
        data = self._unwrap(data)

        prompt = data.get("prompt", "")
        negative_prompt = data.get("negative_prompt", None)
        steps = self._int(data, "num_inference_steps", 30)
        guidance = self._float(data, "guidance_scale", 7.0)

        seed = data.get("seed", None)
        generator = None
        if seed is not None:
            try:
                generator = torch.Generator(device=self.device).manual_seed(int(seed))
            except Exception as e:
                # An unusable seed falls back to non-deterministic sampling.
                print(f"HANDLER v6: ignoring invalid seed {seed!r}: {e}")
                generator = None

        # Normalize keys for images/masks
        # Accept: image / init_image / image_url ; mask / mask_url
        init_img_payload = self._first_payload(data, ("image", "init_image", "image_url"))
        mask_payload = self._first_payload(data, ("mask", "mask_url"))

        # --------- choose mode ---------
        if init_img_payload is None:
            # txt2img mode (only if we truly have a 4-ch UNet)
            width = self._int(data, "width", 1024)
            height = self._int(data, "height", 1024)
            # Snap down to a multiple of 8, with a floor of 64 pixels.
            width = max(64, (width // 8) * 8)
            height = max(64, (height // 8) * 8)

            if self.pipe_txt2img is not None:
                image = self.pipe_txt2img(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    width=width,
                    height=height,
                    num_inference_steps=steps,
                    guidance_scale=guidance,
                    generator=generator,
                ).images[0]
            else:
                # Fallback: blank-canvas inpaint (works with 9-ch UNet)
                canvas = Image.new("RGB", (width, height), (255, 255, 255))
                mask = Image.new("L", (width, height), 255)
                image = self.pipe_inpaint(
                    prompt=prompt,
                    image=canvas,
                    mask_image=mask,
                    negative_prompt=negative_prompt,
                    # Pass the size explicitly so the requested resolution is
                    # honoured instead of the pipeline's default.
                    width=width,
                    height=height,
                    num_inference_steps=steps,
                    guidance_scale=guidance,
                    generator=generator,
                ).images[0]
        else:
            # inpaint mode
            init_img = self._to_pil(init_img_payload, "RGB")
            if mask_payload is not None:
                # NEAREST keeps the resized mask free of interpolated greys.
                mask_img = self._to_pil(mask_payload, "L").resize(init_img.size, Image.NEAREST)
            else:
                mask_img = Image.new("L", init_img.size, 255)  # edit-all default

            strength = self._float(data, "strength", 0.85)

            # NOTE(review): height/width are not passed here, so the output
            # resolution is whatever the pipeline defaults to — confirm that
            # is intended for non-square or non-default-sized inputs.
            image = self.pipe_inpaint(
                prompt=prompt,
                image=init_img,
                mask_image=mask_img,
                negative_prompt=negative_prompt,
                num_inference_steps=steps,
                guidance_scale=guidance,
                strength=strength,
                generator=generator,
            ).images[0]

        # Return PNG as base64
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        out_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
        return {"image_base64": out_b64}