"""NSFW image classification helper and a placeholder-image generator."""

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
from transformers import AutoProcessor, FocalNetForImageClassification


class NSFWDetector:
    """Classify PIL images as Safe / Questionable / Unsafe.

    The processor and the FocalNet classifier are both loaded from
    ``self.model_path`` when the detector is constructed.
    """

    def __init__(self):
        self.model_path = "TostAI/nsfw-image-detection-large"
        self.feature_extractor = AutoProcessor.from_pretrained(self.model_path)
        self.model = FocalNetForImageClassification.from_pretrained(
            self.model_path
        )
        # Inference only: switch the module to evaluation mode.
        self.model.eval()

        # NOTE(review): ``check_image`` does not read this pipeline (it goes
        # through ``self.feature_extractor``). It is kept because it is a
        # public attribute that other code may rely on -- confirm before
        # removing.
        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

        # Human-readable names for the generic labels in the model config.
        self.label_to_category = {
            "LABEL_0": "Safe",
            "LABEL_1": "Questionable",
            "LABEL_2": "Unsafe",
        }

    def check_image(self, image: "Image.Image") -> "tuple[bool, str, float]":
        """Classify a single image.

        Args:
            image: A PIL image in any mode; it is converted to RGB first.

        Returns:
            A ``(flagged, category, confidence)`` tuple where ``flagged`` is
            ``True`` for every category other than ``"Safe"``, ``category``
            is the mapped label name, and ``confidence`` is the top softmax
            probability expressed as a percentage.
        """
        rgb_image = image.convert("RGB")
        inputs = self.feature_extractor(images=rgb_image, return_tensors="pt")

        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            confidence, predicted = torch.max(probabilities, 1)

        label = self.model.config.id2label[predicted.item()]
        # A label missing from the mapping is passed through unchanged, which
        # means it is reported as flagged (it is not the string "Safe").
        category = self.label_to_category.get(label, label)

        return category != "Safe", category, confidence.item() * 100


def create_error_image(message="NSFW Content Detected", size=(512, 512)):
    """Return a black RGB image with ``message`` drawn in white at the centre.

    Args:
        message: Text to render on the placeholder.
        size: ``(width, height)`` of the canvas in pixels. The default keeps
            the original 512x512 output.

    Returns:
        The PIL image. If drawing the text fails, the error is printed and
        the plain black image is returned instead.
    """
    width, height = size
    img = Image.new('RGB', (width, height), color='black')
    draw = ImageDraw.Draw(img)

    # Best effort: a failure while rendering text must not prevent the caller
    # from getting a placeholder image back.
    try:
        font = ImageFont.load_default()

        # Measure the rendered text so it can be centred on the canvas.
        left, top, right, bottom = draw.textbbox((0, 0), message, font=font)
        x = (width - (right - left)) // 2
        y = (height - (bottom - top)) // 2

        draw.text((x, y), message, fill='white', font=font)
    except Exception as e:
        print(f"Error adding text to image: {e}")

    return img