|
|
|
import os, io, json, base64, requests |
|
from typing import Any, Dict |
|
from PIL import Image |
|
import torch |
|
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLInpaintPipeline |
|
from huggingface_hub import snapshot_download |
|
|
|
# Hub repo id of the SDXL checkpoint served by this handler; setting the
# MODEL_ID environment variable overrides the default below.
MODEL_ID = os.environ.get(
    "MODEL_ID", "andro-flock/LUSTIFY-SDXL-NSFW-checkpoint-v2-0-INPAINTING"
)
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoints handler serving an SDXL checkpoint.

    Two diffusers pipelines are built from ``MODEL_ID``:

    * ``StableDiffusionXLInpaintPipeline`` -- always loaded. It serves requests
      that carry an init image (optionally with a mask), and plain
      text-to-image requests when no txt2img pipeline is available.
    * ``StableDiffusionXLPipeline`` -- only built when the UNet takes 4 input
      channels. It reuses the inpaint pipeline's modules, so the weights are
      held in memory once.

    Every call returns ``{"image_base64": <base64-encoded PNG>}``.
    """

    # Longest side (px) used for inpainting when the request gives no explicit
    # width/height; keeps the output near SDXL's 1024px default while
    # preserving the input image's aspect ratio.
    DEFAULT_LONG_SIDE = 1024

    def __init__(self, path: str = "."):
        """Download ``MODEL_ID`` and load the pipelines onto the best device.

        Args:
            path: Repository path passed in by the endpoint runtime. Unused:
                weights are always fetched from ``MODEL_ID``, authenticated
                with the ``HF_TOKEN`` environment variable when it is set.

        Raises:
            RuntimeError: If the inpaint pipeline fails to load with every
                candidate dtype (the last loader error is chained).
        """
        print("HANDLER v6: init start")

        local_dir = snapshot_download(MODEL_ID, token=os.getenv("HF_TOKEN"))
        print(f"HANDLER v6: snapshot at {local_dir}")

        # "cuda" when a GPU is visible, otherwise "cpu".
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Text-to-image pipeline, or None for inpainting-only checkpoints.
        self.pipe_txt2img = None
        # Inpainting pipeline; always set once __init__ returns.
        self.pipe_inpaint = None

        last_err = None
        for dtype in self._candidate_dtypes():
            # The inpaint pipeline is the one every request path needs, so it
            # is loaded first and gets first claim on device memory.
            try:
                pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
                    local_dir, torch_dtype=dtype, use_safetensors=True
                ).to(self.device)
            except Exception as e:
                last_err = e
                print(f"HANDLER v6: inpaint failed on {dtype}: {e}")
                # Free whatever the failed attempt left behind before trying
                # the next (possibly larger) dtype.
                self._release_device_memory()
                continue

            self.pipe_inpaint = pipe
            in_ch = getattr(pipe.unet.config, "in_channels", "NA")
            print(f"HANDLER v6: inpaint OK ({dtype}, in_ch={in_ch})")
            self.pipe_txt2img = self._build_txt2img(dtype)
            break

        if self.pipe_inpaint is None:
            raise RuntimeError(f"Failed to load pipelines: {last_err}") from last_err

        # Best effort: a pipeline that rejects attention slicing still works.
        for pipe in (self.pipe_inpaint, self.pipe_txt2img):
            if pipe is None:
                continue
            try:
                pipe.enable_attention_slicing()
            except Exception as e:
                print(f"HANDLER v6: attention slicing not enabled: {e}")

        print("HANDLER v6: ready")

    def _candidate_dtypes(self):
        """Return the torch dtypes to try when loading, in preference order.

        On CUDA the cheapest dtypes come first, with float32 as the last
        resort. On CPU only float32 is tried: half-precision weights load
        there, but inference is slow and not every op supports float16 on CPU.
        """
        if self.device == "cuda":
            return (torch.float16, torch.bfloat16, torch.float32)
        return (torch.float32,)

    @staticmethod
    def _release_device_memory():
        """Drop unreachable objects and return cached CUDA blocks to the driver."""
        import gc

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _build_txt2img(self, dtype):
        """Build a txt2img pipeline that shares the inpaint pipeline's modules.

        Args:
            dtype: The dtype the inpaint pipeline was loaded with (for logs).

        Returns:
            A ``StableDiffusionXLPipeline``, or None when the UNet is not a
            4-channel (plain latent) UNet or construction fails.
        """
        if getattr(self.pipe_inpaint.unet.config, "in_channels", 4) != 4:
            print("HANDLER v6: txt2img UNet in_ch != 4; disabling txt2img for this repo")
            return None
        try:
            components = dict(self.pipe_inpaint.components)
            # Give txt2img its own scheduler instance so the two pipelines do
            # not share mutable timestep state; all weights remain shared.
            scheduler = components["scheduler"]
            components["scheduler"] = type(scheduler).from_config(scheduler.config)
            pipe = StableDiffusionXLPipeline(
                **components,
                force_zeros_for_empty_prompt=self.pipe_inpaint.config.get(
                    "force_zeros_for_empty_prompt", True
                ),
            )
        except Exception as e:
            print(f"HANDLER v6: txt2img failed on {dtype}: {e}")
            return None
        print(f"HANDLER v6: txt2img OK ({dtype}, in_ch=4)")
        return pipe

    def _unwrap(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Normalise a request body into a flat dict of generation options.

        Accepted shapes:

        * ``{"prompt": ..., ...}`` -- returned unchanged.
        * ``{"inputs": {...}}`` -- the inner dict is used.
        * ``{"inputs": "<JSON object>"}`` -- the string is parsed.
        * ``{"inputs": "<any other text>"}`` -- used as the prompt.

        A ``"parameters"`` dict next to ``"inputs"`` is merged underneath the
        inputs, so keys given inside ``inputs`` take precedence. When
        ``inputs`` has any other type, the raw body is returned.
        """
        if "inputs" not in data:
            return data

        inner = data["inputs"]
        if isinstance(inner, str):
            try:
                parsed = json.loads(inner)
            except (ValueError, RecursionError):
                parsed = None
            # Non-JSON text (or JSON that is not an object) is the prompt.
            inner = parsed if isinstance(parsed, dict) else {"prompt": inner}

        if not isinstance(inner, dict):
            return data

        params = data.get("parameters")
        if isinstance(params, dict):
            return {**params, **inner}
        return inner

    def _fetch_url_bytes(self, url: str) -> bytes:
        """Download ``url`` (60 s timeout) and return the response body.

        Raises:
            requests.HTTPError: On a non-2xx response status.
            requests.RequestException: On connection failures or timeouts.
        """
        # NOTE(security): ``url`` comes straight from the request, so callers
        # can make the endpoint fetch arbitrary (including internal) hosts,
        # and the body size is unbounded. Restrict hosts / cap the size if the
        # endpoint is reachable by untrusted clients.
        r = requests.get(url, timeout=60)
        r.raise_for_status()
        return r.content

    def _to_pil(self, payload: Any, mode: str) -> "Image.Image":
        """Decode an image payload into a PIL image converted to ``mode``.

        ``payload`` may be a PIL image, raw bytes, an ``http(s)://`` URL, a
        ``data:`` URL, or a bare base64 string.

        Raises:
            ValueError: For a ``data:`` URL without a ``,`` separator or
                invalid base64.
            PIL.UnidentifiedImageError: When the bytes are not an image.
        """
        if isinstance(payload, Image.Image):
            return payload.convert(mode)

        if isinstance(payload, str):
            if payload.startswith(("http://", "https://")):
                payload = self._fetch_url_bytes(payload)
            else:
                if payload.startswith("data:"):
                    _, sep, payload = payload.partition(",")
                    if not sep:
                        raise ValueError("malformed data URL: missing ',' separator")
                payload = base64.b64decode(payload)

        # convert() returns a fully loaded copy, so the source can be closed.
        with Image.open(io.BytesIO(payload)) as img:
            return img.convert(mode)

    def _int(self, data, key, default):
        """Return ``data[key]`` coerced to int, or ``default`` if absent/invalid."""
        try:
            return int(data.get(key, default))
        except (TypeError, ValueError, OverflowError):
            return default

    def _float(self, data, key, default):
        """Return ``data[key]`` coerced to float, or ``default`` if absent/invalid."""
        try:
            return float(data.get(key, default))
        except (TypeError, ValueError, OverflowError):
            return default

    @staticmethod
    def _round8(value):
        """Round a pixel size down to a multiple of 8, with a floor of 64."""
        return max(64, (value // 8) * 8)

    @classmethod
    def _fit_dims(cls, width, height, long_side=None):
        """Scale ``(width, height)`` so the longer side equals ``long_side``.

        The aspect ratio is preserved and both results are passed through
        ``_round8``. ``long_side`` defaults to ``DEFAULT_LONG_SIDE``.
        """
        if long_side is None:
            long_side = cls.DEFAULT_LONG_SIDE
        scale = long_side / max(width, height, 1)
        return cls._round8(round(width * scale)), cls._round8(round(height * scale))

    @staticmethod
    def _first_payload(data, keys):
        """Return the value of the first key in ``keys`` that holds a payload.

        ``None`` and empty strings/bytes count as "not provided".
        """
        for key in keys:
            value = data.get(key)
            if value is None or (isinstance(value, (str, bytes)) and not value):
                continue
            return value
        return None

    def _make_generator(self, seed):
        """Return a seeded ``torch.Generator`` on ``self.device``, or None.

        An absent seed gives None (random). An unusable seed is logged and
        also gives None.
        """
        if seed is None:
            return None
        try:
            return torch.Generator(device=self.device).manual_seed(int(seed))
        except (TypeError, ValueError, OverflowError, RuntimeError) as e:
            print(f"HANDLER v6: ignoring invalid seed {seed!r}: {e}")
            return None

    @staticmethod
    def _png_base64(image):
        """Encode a PIL image as PNG and return it as a base64 text string."""
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode("utf-8")

    def _run_txt2img(self, data, common):
        """Generate an image from text only.

        ``width``/``height`` default to 1024 and are passed through
        ``_round8``. ``common`` holds the keyword arguments shared by both
        pipelines.
        """
        width = self._round8(self._int(data, "width", 1024))
        height = self._round8(self._int(data, "height", 1024))

        if self.pipe_txt2img is not None:
            return self.pipe_txt2img(width=width, height=height, **common).images[0]

        # Inpainting-only checkpoint: repaint a blank canvas under a full mask.
        # Explicit width/height stop the pipeline substituting its own default
        # size. strength=1.0 starts from pure noise and runs every step.
        canvas = Image.new("RGB", (width, height), (255, 255, 255))
        mask = Image.new("L", (width, height), 255)
        return self.pipe_inpaint(
            image=canvas,
            mask_image=mask,
            width=width,
            height=height,
            strength=1.0,
            **common,
        ).images[0]

    def _run_inpaint(self, data, init_payload, common):
        """Inpaint ``init_payload`` under the request's mask.

        With no ``mask``/``mask_url``, the whole image is repainted.
        """
        init_img = self._to_pil(init_payload, "RGB")

        mask_payload = self._first_payload(data, ("mask", "mask_url"))
        if mask_payload is not None:
            mask_img = self._to_pil(mask_payload, "L").resize(init_img.size, Image.NEAREST)
        else:
            mask_img = Image.new("L", init_img.size, 255)

        # Without explicit height/width the SDXL inpaint pipeline uses the
        # UNet's default square size, distorting non-square inputs. Default to
        # the input's aspect ratio, fitted to DEFAULT_LONG_SIDE, instead.
        fit_w, fit_h = self._fit_dims(*init_img.size)
        width = self._round8(self._int(data, "width", fit_w))
        height = self._round8(self._int(data, "height", fit_h))

        return self.pipe_inpaint(
            image=init_img,
            mask_image=mask_img,
            width=width,
            height=height,
            strength=self._float(data, "strength", 0.85),
            **common,
        ).images[0]

    def __call__(self, data: Dict[str, Any]):
        """Serve one generation request.

        Without an init image (``image`` / ``init_image`` / ``image_url``),
        this is text-to-image; with one, it is inpainting.

        Returns:
            ``{"image_base64": <base64 PNG>}``.
        """
        data = self._unwrap(data)

        common = {
            "prompt": data.get("prompt", ""),
            "negative_prompt": data.get("negative_prompt", None),
            "num_inference_steps": self._int(data, "num_inference_steps", 30),
            "guidance_scale": self._float(data, "guidance_scale", 7.0),
            "generator": self._make_generator(data.get("seed", None)),
        }

        init_payload = self._first_payload(data, ("image", "init_image", "image_url"))
        if init_payload is None:
            image = self._run_txt2img(data, common)
        else:
            image = self._run_inpaint(data, init_payload, common)

        return {"image_base64": self._png_base64(image)}
|
|