"""Inference endpoint that tags stored images and saves the tags to MongoDB.

Each request names a window (start index, count) of documents in the
``imagerequests`` collection that do not have a ``keywords`` field yet. Every
image in that window is run through a timm tagger model and the resulting
``{tag: confidence}`` mapping is written back to the document.
"""

import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional

import numpy as np
import pandas as pd
import timm
import torch
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
from huggingface_inference_toolkit.logging import logger
from PIL import Image
from pymongo.mongo_client import MongoClient
from simple_parsing import field
from timm.data import create_transform, resolve_data_config
from torch import Tensor, nn
from torch.nn import functional as F

HF_TOKEN = os.environ.get("HF_TOKEN", "")
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODEL_REPO_MAP = {
    "vit": "SmilingWolf/wd-vit-large-tagger-v3",
}


def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return *image* as an RGB image, flattening any alpha onto white."""
    # Convert to RGB/RGBA if not already (deals with palette images etc.).
    if image.mode not in ["RGB", "RGBA"]:
        image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
    # Composite RGBA over a white background, then drop the alpha channel.
    if image.mode == "RGBA":
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image


def pil_pad_square(image: Image.Image) -> Image.Image:
    """Return *image* centred on a white square whose side is its largest dimension."""
    w, h = image.size
    # The largest dimension becomes the side of the square canvas.
    px = max(image.size)
    canvas = Image.new("RGB", (px, px), (255, 255, 255))
    canvas.paste(image, ((px - w) // 2, (px - h) // 2))
    return canvas


@dataclass
class LabelData:
    """Tag names plus the row indices belonging to each tag category."""

    # Tag names, in the row order of selected_tags.csv.
    names: list[str]
    # Row indices whose category is 9.
    rating: list[np.int64]
    # Row indices whose category is 0.
    general: list[np.int64]
    # Row indices whose category is 4.
    character: list[np.int64]


def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """Download ``selected_tags.csv`` from a Hub repo and parse it into LabelData.

    Args:
        repo_id: Hugging Face Hub repository to download from.
        revision: Optional revision forwarded to ``hf_hub_download``.
        token: Optional access token forwarded to ``hf_hub_download``.

    Returns:
        The tag names and the per-category row indices.

    Raises:
        FileNotFoundError: If the download fails with ``HfHubHTTPError``.
    """
    try:
        csv_path = hf_hub_download(
            repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token
        )
        csv_path = Path(csv_path).resolve()
    except HfHubHTTPError as e:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e

    df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
    tag_data = LabelData(
        names=df["name"].tolist(),
        rating=list(np.where(df["category"] == 9)[0]),
        general=list(np.where(df["category"] == 0)[0]),
        character=list(np.where(df["category"] == 4)[0]),
    )
    return tag_data


def _top_labels(scored: list, indices: list, threshold: float) -> dict:
    """Return the (name, prob) pairs at *indices* above *threshold*, best first."""
    kept = {name: prob for name, prob in (scored[i] for i in indices) if prob > threshold}
    return dict(sorted(kept.items(), key=lambda item: item[1], reverse=True))


def get_tags(
    probs: Tensor,
    labels: LabelData,
    gen_threshold: float,
    char_threshold: float,
):
    """Turn per-tag probabilities into a caption and per-category label dicts.

    Args:
        probs: One probability per entry of ``labels.names``; ``.numpy()`` is
            called on it, so it is expected to be a CPU tensor.
        labels: Tag names and category indices.
        gen_threshold: General tags are kept only above this confidence.
        char_threshold: Character tags are kept only above this confidence.

    Returns:
        A tuple ``(caption, taglist, rating_labels, char_labels, gen_labels)``.
        ``caption`` joins the kept general tag names followed by the kept
        character tag names with ", ". ``taglist`` is the caption with
        underscores replaced by spaces and parentheses backslash-escaped. The
        three dicts map tag name to probability; the general and character
        dicts are ordered by descending probability.
    """
    # Pair every tag name with its predicted probability.
    scored = list(zip(labels.names, probs.numpy()))

    # Ratings are reported in full, without any threshold.
    rating_labels = dict([scored[i] for i in labels.rating])

    gen_labels = _top_labels(scored, labels.general, gen_threshold)
    char_labels = _top_labels(scored, labels.character, char_threshold)

    # General names first, then character names; each group is already sorted
    # by confidence, but the combined list is not re-sorted.
    combined_names = list(gen_labels)
    combined_names.extend(char_labels)

    # Convert to a string suitable for use as a training caption.
    caption = ", ".join(combined_names)
    # Raw strings: "\(" is an invalid escape sequence in a normal literal and
    # triggers a SyntaxWarning on recent Python versions.
    taglist = caption.replace("_", " ").replace("(", r"\(").replace(")", r"\)")

    return caption, taglist, rating_labels, char_labels, gen_labels


@dataclass
class ScriptOptions:
    """Tagger settings; the handler in this file reads only the defaults."""

    # Image path option; not read anywhere in this file.
    image_file: Path = field(positional=True)
    # Key into MODEL_REPO_MAP.
    model: str = field(default="vit")
    # Confidence threshold for general tags.
    gen_threshold: float = field(default=0.35)
    # Confidence threshold for character tags.
    char_threshold: float = field(default=0.75)


def _parse_batch_range(prompt: str) -> tuple[int, int]:
    """Parse ``"<start_index>,<limit_count>"`` into two integers.

    Raises:
        ValueError: If *prompt* is not exactly two comma-separated integers.
    """
    parts = prompt.split(",")
    if len(parts) != 2:
        raise ValueError(
            "`inputs` must be of the form '<start_index>,<limit_count>', "
            f"got {prompt!r}"
        )
    try:
        start_index, limit_count = (int(part) for part in parts)
    except ValueError as e:
        raise ValueError(
            f"`inputs` must contain two integers separated by a comma, got {prompt!r}"
        ) from e
    return start_index, limit_count


class EndpointHandler:
    """Tags images referenced by MongoDB documents and stores the tags back."""

    def __init__(self, path=""):
        """Load the tagger model, its labels and transform, and connect to MongoDB.

        Args:
            path: Accepted for interface compatibility; not used here.

        Raises:
            ValueError: If the configured model key is not in MODEL_REPO_MAP.
        """
        # NOTE(review): this stores the ScriptOptions class itself, not an
        # instance; the attribute reads below presumably resolve to the
        # dataclass field defaults. `image_file` has no default, so a plain
        # ScriptOptions() call would likely fail - confirm before changing.
        self.opts = ScriptOptions
        repo_id = MODEL_REPO_MAP.get(self.opts.model)
        if repo_id is None:
            raise ValueError(f"Unknown model {self.opts.model!r}; expected one of {list(MODEL_REPO_MAP)}")

        self.model: nn.Module = timm.create_model("hf-hub:" + repo_id).eval()
        state_dict = timm.models.load_state_dict_from_hf(repo_id)
        self.model.load_state_dict(state_dict)

        self.labels: LabelData = load_labels_hf(repo_id=repo_id)
        self.transform = create_transform(**resolve_data_config(self.model.pretrained_cfg, model=self.model))

        # Move model to GPU, if available.
        if torch_device.type != "cpu":
            self.model = self.model.to(torch_device)

        uri = os.environ.get("MongoDB", "")
        self.client = MongoClient(uri)
        self.db = self.client['nomorecopyright']
        self.collection = self.db['imagerequests']
        # Only documents that have not been tagged yet.
        self.query = {"keywords": {"$exists": False}}
        # NOTE(review): not passed to find() anywhere in this file. It excludes
        # `_id`, which __call__ needs for the update, so it cannot be used as-is.
        self.projection = {"_id": 0, "createdImage": 1}

    def _tag_image(self, image: Image.Image) -> dict[str, float]:
        """Run the tagger on one image and return ``{tag: confidence}``."""
        # Ensure RGB, then pad to square with a white background.
        img_input = pil_pad_square(pil_ensure_rgb(image))
        # Run the model's input transform to convert to tensor and rescale.
        inputs: Tensor = self.transform(img_input).unsqueeze(0)
        # NCHW image RGB to BGR.
        inputs = inputs[:, [2, 1, 0]]

        with torch.inference_mode():
            if torch_device.type != "cpu":
                inputs = inputs.to(torch_device)
            outputs = self.model(inputs)
            # Apply the final activation function (timm doesn't support doing
            # this internally).
            outputs = F.sigmoid(outputs)
            # get_tags calls .numpy(), which needs the tensor on the CPU.
            if torch_device.type != "cpu":
                outputs = outputs.to("cpu")

        _caption, _taglist, ratings, character, general = get_tags(
            probs=outputs.squeeze(0),
            labels=self.labels,
            gen_threshold=self.opts.gen_threshold,
            char_threshold=self.opts.char_threshold,
        )
        merged = {**ratings, **character, **general}
        # Plain floats, so the values are not numpy scalars when stored.
        return {key: float(value) for key, value in merged.items()}

    def __call__(self, data: Dict[str, Any]) -> str:
        """Tag a window of untagged documents and store the keywords.

        Args:
            data: Request body. ``data["inputs"]`` must be a string of the
                form ``"<start_index>,<limit_count>"``.

        Returns:
            The string ``'OK'`` once the window has been processed. A failure
            on an individual document is logged and does not abort the batch.

        Raises:
            ValueError: If ``inputs`` is missing, not a string, or malformed.
        """
        logger.info(f"Received incoming request with {data=}")
        if "inputs" in data and isinstance(data["inputs"], str):
            prompt = data.pop("inputs")
        else:
            raise ValueError(
                "Provided input body must contain the key `inputs` with a string of the"
                " form '<start_index>,<limit_count>'."
            )

        start_index, limit_count = _parse_batch_range(prompt)
        logger.info(f"Start index: {start_index}, Limit count: {limit_count}")

        # Materialise the window before updating anything: each update removes
        # a document from the query's result set.
        # NOTE(review): no sort is applied, so which documents a given
        # skip/limit window selects depends on the server's return order.
        documents = list(self.collection.find(self.query).skip(start_index).limit(limit_count))

        start_time = time.time()
        for document in documents:
            # Best effort per document: log the failure and carry on.
            try:
                image = load_image(document.get('createdImage', 'https://nomorecopyright.com/default.jpg'))
                results = self._tag_image(image)
                saveQuery = {"_id": document.get('_id')}
                # Add the keywords with their confidence scores.
                self.collection.update_one(saveQuery, {'$set': {'keywords': results}})
            except Exception as e:
                logger.error(f"Error processing image: {e}")
        end_time = time.time()
        print(f"Time taken: {end_time-start_time:.2f} seconds")
        return 'OK'