Spaces:
Running
Running
| import base64 | |
| import io | |
| import json | |
| import logging | |
| import os | |
| import time | |
| from pathlib import Path | |
| from typing import Any | |
| import requests | |
| import timm | |
| import torch | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
class TaggingHead(torch.nn.Module):
    """Linear multi-label head: maps a feature vector to independent per-tag probabilities."""

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        # Wrapped in a Sequential so parameter names are "head.0.weight" /
        # "head.0.bias"; the saved checkpoint's keys depend on this layout.
        self.head = torch.nn.Sequential(torch.nn.Linear(input_dim, num_classes))

    def forward(self, x):
        # Sigmoid rather than softmax: every tag is its own yes/no decision.
        return torch.sigmoid(self.head(x))
def get_tags(tags_file: Path) -> tuple[dict[str, int], int, int]:
    """Read the tag vocabulary JSON.

    Args:
        tags_file: JSON file with a ``"tag_map"`` object (tag -> index) and a
            ``"tag_split"`` object holding ``"gen_tag_count"`` and
            ``"character_tag_count"``.

    Returns:
        ``(tag_map, gen_tag_count, character_tag_count)``.

    Raises:
        KeyError: If any of the expected keys is missing.
    """
    info = json.loads(tags_file.read_text(encoding="utf-8"))
    tag_map = info["tag_map"]
    split = info["tag_split"]
    return tag_map, split["gen_tag_count"], split["character_tag_count"]
def get_character_ip_mapping(mapping_file: Path):
    """Load the JSON object mapping character tags to their associated IP tags.

    The handler looks character tags up in this mapping and collects the
    values into its ``"ip"`` output.
    """
    return json.loads(mapping_file.read_text(encoding="utf-8"))
def get_encoder():
    """Create the EVA02-Large tagger backbone as a headless feature extractor.

    ``pretrained=False`` means no weights are downloaded here; ``load_model``
    fills them from the local checkpoint.

    NOTE(review): with an ``hf_hub:`` model name, timm presumably still fetches
    the architecture config from the Hugging Face Hub, so startup may need
    network access or a warm HF cache. Confirm this for offline deployments.
    """
    backbone = timm.create_model(
        "hf_hub:SmilingWolf/wd-eva02-large-tagger-v3",
        pretrained=False,
    )
    # 0 outputs removes the original classifier so the custom TaggingHead can
    # sit on top of the backbone's features.
    backbone.reset_classifier(0)
    return backbone
def get_decoder(input_dim: int = 1024, num_classes: int = 13461):
    """Build the tagging head that maps encoder features to per-tag probabilities.

    Args:
        input_dim: Feature width produced by the encoder. The default 1024 is the
            width paired with the backbone built by ``get_encoder``.
        num_classes: Number of output tags. The default 13461 matches the shipped
            checkpoint (presumably gen + character tag counts of the tags file).

    Both values must agree with the checkpoint later passed to
    ``load_state_dict``; otherwise loading fails with a shape mismatch.
    """
    return TaggingHead(input_dim, num_classes)
def get_model():
    """Assemble backbone -> tagging head as one ``torch.nn.Sequential``.

    The Sequential positions become state-dict key prefixes ("0." for the
    encoder, "1." for the head). Changing this container layout would stop
    existing checkpoints from loading.
    """
    return torch.nn.Sequential(get_encoder(), get_decoder())
def load_model(weights_file, device):
    """Build the tagger, load ``weights_file`` into it, and return it ready for inference.

    Args:
        weights_file: Path to a state-dict checkpoint matching ``get_model``'s layout.
        device: Target device string, e.g. ``"cpu"`` or ``"cuda"``.

    Returns:
        The model on ``device``, switched to eval mode.
    """
    net = get_model()
    # weights_only=True restricts unpickling to tensors and plain containers
    # rather than arbitrary pickled objects.
    checkpoint = torch.load(weights_file, map_location=device, weights_only=True)
    net.load_state_dict(checkpoint)
    # Module.to() and Module.eval() both return the module itself.
    return net.to(device).eval()
def pure_pil_alpha_to_color_v2(
    image: Image.Image, color: tuple[int, int, int] = (255, 255, 255)
) -> Image.Image:
    """Composite an image that has an alpha band onto a solid background.

    The model expects 3-channel RGB input, so transparent regions need a
    concrete colour. The image is pasted over a ``color`` canvas using its alpha
    band as the paste mask: fully transparent pixels become ``color``, fully
    opaque pixels keep their value, and partial alpha blends the two.

    Args:
        image: Image whose mode includes an alpha band (e.g. "RGBA", "LA").
        color: RGB background colour; white by default.

    Returns:
        A new "RGB" image the same size as ``image``.

    Raises:
        ValueError: If ``image`` has no alpha band.
    """
    image.load()  # make sure lazily-opened images are decoded before reading bands
    background = Image.new("RGB", image.size, color)
    # Look the alpha band up by name rather than assuming band index 3, which
    # only holds for RGBA (in "LA", for instance, alpha is band 1). This also
    # avoids splitting out every band just to use one.
    background.paste(image, mask=image.getchannel("A"))
    return background
def pil_to_rgb(image: Image.Image) -> Image.Image:
    """Return an RGB version of ``image``, flattening any transparency onto white.

    RGBA images are composited directly. Palette images, "LA" images, and
    L/RGB images that declare a transparent colour in ``info["transparency"]``
    are first converted to RGBA so their transparency becomes a real alpha band.
    All other modes are converted straight to RGB.
    """
    if image.mode == "RGBA":
        return pure_pil_alpha_to_color_v2(image)
    # A plain convert("RGB") would drop this transparency and expose whatever
    # colour lies under the transparent pixels (often black), unlike the white
    # background used for RGBA inputs.
    has_hidden_alpha = image.mode in ("P", "LA") or (
        image.mode in ("L", "RGB") and "transparency" in image.info
    )
    if has_hidden_alpha:
        return pure_pil_alpha_to_color_v2(image.convert("RGBA"))
    return image.convert("RGB")
class EndpointHandler:
    """Image-tagging handler: loads model artifacts once, then tags one image per call.

    Each call returns general ("feature") tags, character tags, and the IP tags
    mapped from the detected characters, plus timing and parameter echoes.
    """

    def __init__(self, path: str):
        """Load weights, tag vocabulary and character->IP mapping from ``path``.

        Args:
            path: Directory containing ``model_v0.9.pth``, ``tags_v0.9_13k.json``
                and ``char_ip_map.json``.

        Raises:
            FileNotFoundError: If ``path`` is not a directory or a required file
                is missing.
        """
        repo_path = Path(path)
        # Explicit raise rather than ``assert``: asserts are stripped under ``python -O``.
        if not repo_path.is_dir():
            raise FileNotFoundError(f"Model directory not found: {repo_path}")
        weights_file = repo_path / "model_v0.9.pth"
        tags_file = repo_path / "tags_v0.9_13k.json"
        mapping_file = repo_path / "char_ip_map.json"
        if not weights_file.exists():
            raise FileNotFoundError(f"Model file not found: {weights_file}")
        if not tags_file.exists():
            raise FileNotFoundError(f"Tags file not found: {tags_file}")
        if not mapping_file.exists():
            raise FileNotFoundError(f"Mapping file not found: {mapping_file}")

        self.device = self._select_device()
        self.model = load_model(str(weights_file), self.device)
        self.transform = transforms.Compose(
            [
                transforms.Resize((448, 448)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        self.fetch_image_timeout = 5.0
        self.default_general_threshold = 0.3
        self.default_character_threshold = 0.85
        tag_map, self.gen_tag_count, self.character_tag_count = get_tags(tags_file)
        # Inverted once so per-request index -> tag lookups are O(1).
        self.index_to_tag_map = {v: k for k, v in tag_map.items()}
        self.character_ip_mapping = get_character_ip_mapping(mapping_file)

    @staticmethod
    def _select_device() -> str:
        """Return "cuda" only if CUDA is usable and ``FORCE_CPU`` is not set; else "cpu".

        ``FORCE_CPU`` is truthy for "1", "true", "yes" or "on" (case-insensitive).
        """
        force_cpu = os.environ.get("FORCE_CPU", "0").strip().lower() in {"1", "true", "yes", "on"}
        if force_cpu or not torch.cuda.is_available():
            return "cpu"
        try:
            # is_available() alone does not prove the driver works; probe with a real allocation.
            torch.zeros(1).to("cuda")
        except Exception:  # deliberate best-effort: any CUDA failure means fall back to CPU
            logging.warning("CUDA reported available but is not usable; using CPU", exc_info=True)
            return "cpu"
        return "cuda"

    def _load_image(self, inputs: Any, data: dict[str, Any]) -> "Image.Image":
        """Resolve the request payload to a PIL image (not yet RGB-normalised).

        Accepts a PIL image directly, or a dict with either ``"url"`` (fetched
        over HTTP) or ``"image"`` (base64, optionally as a ``data:`` URI).
        The used key is popped from ``inputs``.

        Raises:
            ValueError: If no usable image source is present.
            requests.HTTPError: If the URL fetch returns an error status.
        """
        if isinstance(inputs, Image.Image):
            return inputs
        if isinstance(inputs, dict):
            if image_url := inputs.pop("url", None):
                with requests.get(image_url, timeout=self.fetch_image_timeout) as res:
                    res.raise_for_status()
                    # res.content is the fully read body with any Content-Encoding
                    # (gzip/deflate) undone; res.raw is the undecoded byte stream.
                    return Image.open(io.BytesIO(res.content))
            if image_b64 := inputs.pop("image", None):
                if isinstance(image_b64, str) and image_b64.startswith("data:"):
                    # Strip a "data:image/...;base64," prefix; b64decode would
                    # otherwise decode its letters as part of the payload.
                    image_b64 = image_b64.partition(",")[2]
                return Image.open(io.BytesIO(base64.b64decode(image_b64)))
        raise ValueError(f"No image or url provided: {data}")

    def _select_tags(
        self,
        probs: torch.Tensor,
        mode: str,
        general_threshold: float,
        character_threshold: float,
        topk_general: int,
        topk_character: int,
    ) -> tuple[list[int], list[float]]:
        """Choose tag indices and their probabilities from the 1-D ``probs`` vector.

        ``probs[:gen_tag_count]`` are general tags and the rest are character tags.
        In "topk" mode the top-k of each group are kept (highest first) and the
        thresholds are ignored. In any other mode, tags strictly above the
        group's threshold are kept, in ascending index order. General tags come
        first, and every index refers to a position in the full ``probs`` vector.

        Returns:
            ``(indices, scores)`` as equal-length Python lists.
        """
        gen_probs = probs[: self.gen_tag_count]
        char_probs = probs[self.gen_tag_count :]
        if mode == "topk":
            k_gen = max(0, min(topk_general, self.gen_tag_count))
            k_char = max(0, min(topk_character, self.character_tag_count))
            # Placeholder lives on probs' device so every concatenated tensor
            # shares a device (a CPU placeholder would clash with CUDA results).
            no_tags = torch.empty(0, dtype=torch.long, device=probs.device)
            gen_idx = torch.topk(gen_probs, k_gen).indices if k_gen > 0 else no_tags
            char_idx = torch.topk(char_probs, k_char).indices if k_char > 0 else no_tags
        else:
            gen_idx = (gen_probs > general_threshold).nonzero(as_tuple=True)[0]
            char_idx = (char_probs > character_threshold).nonzero(as_tuple=True)[0]
        # Shift character indices from slice-relative to full-vector positions.
        indices = torch.cat((gen_idx, char_idx + self.gen_tag_count))
        # Gather on-device; only the small selected result is transferred by tolist().
        scores = probs[indices].float()
        return indices.tolist(), scores.tolist()

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        """Tag one image.

        Args:
            data: Request payload. ``data["inputs"]`` holds the image source
                accepted by ``_load_image``; if that key is absent, ``data``
                itself is used. The optional ``data["parameters"]`` dict may hold:
                ``general_threshold`` / ``character_threshold`` (floats, used in
                threshold mode), ``mode`` ("threshold" by default, or "topk"),
                ``topk_general`` / ``topk_character`` (ints, used in topk mode),
                and ``include_scores`` (bool). ``data`` and its nested dicts are
                consumed (popped) in place.

        Returns:
            Dict with "feature", "character" and "ip" tag lists, "_timings",
            "_params", and, when ``include_scores`` is set, "feature_scores" /
            "character_scores" (tag -> probability).

        Raises:
            ValueError: If no image source is provided or a parameter is malformed.
        """
        inputs = data.pop("inputs", data)
        fetch_start_time = time.time()
        # Flatten any transparency so the model always receives 3-channel RGB.
        image = pil_to_rgb(self._load_image(inputs, data))
        fetch_time = time.time() - fetch_start_time

        # A JSON ``"parameters": null`` is treated like an absent key.
        parameters = data.pop("parameters", None) or {}
        # Coerced so string values from loosely typed clients compare correctly with tensors.
        general_threshold = float(parameters.pop("general_threshold", self.default_general_threshold))
        character_threshold = float(
            parameters.pop("character_threshold", self.default_character_threshold)
        )
        mode = parameters.pop("mode", "threshold")  # "threshold" | "topk"
        include_scores = bool(parameters.pop("include_scores", False))
        topk_general = int(parameters.pop("topk_general", 25))
        topk_character = int(parameters.pop("topk_character", 10))

        inference_start_time = time.time()
        with torch.inference_mode():
            image_tensor = self.transform(image).unsqueeze(0)  # batch of one
            if self.device == "cuda":
                # Pinned host memory allows an asynchronous host-to-device copy.
                image_tensor = image_tensor.pin_memory().to(self.device, non_blocking=True)
            else:
                image_tensor = image_tensor.to(self.device)
            probs = self.model(image_tensor)[0]  # per-tag probabilities for the single image
            indices, scores = self._select_tags(
                probs, mode, general_threshold, character_threshold, topk_general, topk_character
            )
        inference_time = time.time() - inference_start_time

        post_process_start_time = time.time()
        cur_gen_tags: list[str] = []
        cur_char_tags: list[str] = []
        gen_scores_out: dict[str, float] = {}
        char_scores_out: dict[str, float] = {}
        for idx, score in zip(indices, scores):
            tag = self.index_to_tag_map[idx]
            if idx < self.gen_tag_count:
                cur_gen_tags.append(tag)
                gen_scores_out[tag] = score
            else:
                cur_char_tags.append(tag)
                char_scores_out[tag] = score
        ip_tags = sorted(
            {ip for tag in cur_char_tags for ip in self.character_ip_mapping.get(tag, ())}
        )
        post_process_time = time.time() - post_process_start_time
        total_time = fetch_time + inference_time + post_process_time

        logging.info(
            "Timing - Fetch: %.3fs, Inference: %.3fs, Post-process: %.3fs, Total: %.3fs",
            fetch_time,
            inference_time,
            post_process_time,
            total_time,
        )
        out: dict[str, Any] = {
            "feature": cur_gen_tags,
            "character": cur_char_tags,
            "ip": ip_tags,
            "_timings": {
                "fetch_s": round(fetch_time, 4),
                "inference_s": round(inference_time, 4),
                "post_process_s": round(post_process_time, 4),
                "total_s": round(total_time, 4),
            },
            "_params": {
                "mode": mode,
                "general_threshold": general_threshold,
                "character_threshold": character_threshold,
                "topk_general": topk_general,
                "topk_character": topk_character,
            },
        }
        if include_scores:
            out["feature_scores"] = gen_scores_out
            out["character_scores"] = char_scores_out
        return out