import gradio as gr
import csv
import functools
import os
from pathlib import Path

import cv2
import numpy as np
from PIL import Image
import onnxruntime as ort
from huggingface_hub import hf_hub_download
import spaces

# Default square input size (pixels); used when the ONNX model does not
# declare a fixed input height.
IMAGE_SIZE = 448

# Maximum number of models whose session and tag lists stay in memory.
_MODEL_CACHE_SIZE = 4


def preprocess_image(image, target_size=IMAGE_SIZE):
    """Convert a 3-channel RGB image into the float32 BGR array fed to the model.

    The image is padded with white to a square, resized to
    ``target_size`` x ``target_size`` and cast to float32. Pixel values are
    kept in their original 0-255 range (no normalisation).

    Args:
        image: PIL image or array in RGB channel order with 3 channels.
        target_size: Side length of the square output, in pixels.

    Returns:
        ``np.ndarray`` of shape (target_size, target_size, 3), dtype float32.
    """
    image = np.array(image)
    # Reverse the channel axis: PIL-native RGB -> BGR (the order used by
    # SmilingWolf's reference preprocessing for the WD tagger models).
    image = image[:, :, ::-1]

    # Pad with white so the image becomes square without distortion.
    size = max(image.shape[0:2])
    pad_x = size - image.shape[1]
    pad_y = size - image.shape[0]
    pad_l = pad_x // 2
    pad_t = pad_y // 2
    image = np.pad(
        image,
        ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)),
        mode="constant",
        constant_values=255,
    )

    # INTER_AREA when shrinking, Lanczos when enlarging.
    interp = cv2.INTER_AREA if size > target_size else cv2.INTER_LANCZOS4
    image = cv2.resize(image, (target_size, target_size), interpolation=interp)
    return image.astype(np.float32)


def _load_rgb_image(image_path):
    """Open *image_path*, flatten any transparency onto white, return an RGB image."""
    # The context manager closes the underlying file; convert() forces the
    # pixel data to be loaded before that happens.
    with Image.open(image_path) as img:
        rgba = img.convert("RGBA")
    # Composite onto white so transparent regions match the white padding
    # instead of becoming black.
    canvas = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    canvas.alpha_composite(rgba)
    return canvas.convert("RGB")


@functools.lru_cache(maxsize=_MODEL_CACHE_SIZE)
def _load_model(model_id):
    """Download the model files for *model_id* and build an inference session.

    Results are cached per ``model_id`` so repeated clicks do not re-create the
    session or re-parse the tag CSV. Exceptions are not cached.

    Returns:
        ``(session, input_size, rating_tags, character_tags, general_tags)``.
        The tag tuples keep the row order of ``selected_tags.csv``.
    """
    print("Hugging Faceからモデルをダウンロード中")
    onnx_path = hf_hub_download(model_id, "model.onnx")
    csv_path = hf_hub_download(model_id, "selected_tags.csv")

    session = ort.InferenceSession(onnx_path)
    # Input is fed as (batch, height, width, channels); use the declared
    # height when it is a fixed integer, otherwise fall back to IMAGE_SIZE.
    input_shape = session.get_inputs()[0].shape
    height = input_shape[1] if len(input_shape) > 1 else None
    input_size = height if isinstance(height, int) else IMAGE_SIZE

    with open(csv_path, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        rows = list(reader)
    # Column 1 is the tag name and column 2 its category code.
    rating_tags = tuple(row[1] for row in rows if row[2] == "9")
    character_tags = tuple(row[1] for row in rows if row[2] == "4")
    general_tags = tuple(row[1] for row in rows if row[2] == "0")
    return session, input_size, rating_tags, character_tags, general_tags


@spaces.GPU
def main(image_path, model_id):
    """Run the tagger on one image and return the three UI output strings.

    Args:
        image_path: Path of the uploaded image (``None`` if nothing was uploaded).
        model_id: Hugging Face repo containing ``model.onnx`` and
            ``selected_tags.csv``.

    Returns:
        ``(NSFW_flag, IP_flag, tag_text)`` as produced by :func:`evaluate_tags`.

    Raises:
        gr.Error: if no image was provided.
    """
    if image_path is None:
        raise gr.Error("Please upload an image first.")

    session, input_size, rating_tags, character_tags, general_tags = _load_model(
        model_id.strip()
    )

    image = preprocess_image(_load_rgb_image(image_path), input_size)
    batch = image[np.newaxis]  # add the batch dimension
    prob = session.run(None, {session.get_inputs()[0].name: batch})[0][0]

    # Tags and ratings
    return evaluate_tags(prob, rating_tags, character_tags, general_tags)


def evaluate_tags(prob, rating_tags, character_tags, general_tags, thresh=0.35):
    """Turn per-tag model probabilities into the three UI strings.

    Assumes ``prob`` is laid out like ``selected_tags.csv``: all rating tags
    first, then all general tags, then all character tags (the same layout the
    character detection below relies on).

    Args:
        prob: 1-D sequence of probabilities, one per CSV row.
        rating_tags, character_tags, general_tags: tag names per category,
            in CSV order.
        thresh: Minimum probability for a general/character tag to be reported.

    Returns:
        ``(NSFW_flag, IP_flag, tag_text)``.
    """
    n_rating = len(rating_tags)
    general_start = n_rating
    character_start = general_start + len(general_tags)

    rating_probs = prob[:n_rating]
    general_probs = prob[general_start:character_start]
    character_probs = prob[character_start:character_start + len(character_tags)]

    # NSFW/SFW decision: strongest NSFW rating vs. the "general" rating.
    tag_confidences = dict(zip(rating_tags, rating_probs))
    max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
    max_sfw_score = tag_confidences.get("general", 0)
    NSFW_flag = "NSFWの可能性が高いです" if max_nsfw_score > max_sfw_score else "SFWの可能性が高いです"

    # Possible copyrighted characters, with probability as a percentage.
    detected_characters = [
        (tag, p) for tag, p in zip(character_tags, character_probs) if p >= thresh
    ]
    character_tags_with_probs = [
        (tag, f"{round(float(p) * 100, 2)}%") for tag, p in detected_characters
    ]
    IP_flag = f"版権キャラクター: {character_tags_with_probs}の可能性があります" if character_tags_with_probs else "版権キャラクターの可能性が低いと思われます"

    # Comma-separated tag string: general tags first, then characters.
    general_hits = [tag for tag, p in zip(general_tags, general_probs) if p >= thresh]
    character_hits = [tag for tag, _ in detected_characters]
    tag_text = ", ".join(general_hits + character_hits)

    return NSFW_flag, IP_flag, tag_text


class webui:
    """Gradio front end: image + model ID in, NSFW/IP/tag text out."""

    def __init__(self):
        self.demo = gr.Blocks()

    def launch(self):
        """Build the layout, wire the button to ``main`` and start the server."""
        with self.demo:
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(type='filepath', label="Analysis Image")
                    model_id = gr.Textbox(label="Model ID", value="SmilingWolf/wd-vit-tagger-v3")
                    output_0 = gr.Textbox(label="NSFW Flag")
                    output_1 = gr.Textbox(label="IP Flag")
                    output_2 = gr.Textbox(label="Tags")
                    submit = gr.Button(value="Start Analysis")

                    submit.click(
                        main,
                        inputs=[input_image, model_id],
                        outputs=[output_0, output_1, output_2]
                    )

        self.demo.launch(share=True)  # also create a public share link


if __name__ == "__main__":
    ui = webui()
    ui.launch()