# Source: Hugging Face file view (raw / history / blame).
# Uploaded by Boboiazumi - commit message: "Upload 3 files".
# Commit b2e8a4e (verified), file size 1.44 kB.
import gradio as gr
import torch
import torchvision.models as models
from torchvision import transforms
from torch import nn
from PIL import Image
# Preprocessing applied to every incoming image: resize to 128x128 and
# convert the PIL image to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ]
)

# Class names, indexed by the position of the classifier output.
label = ["nsfw", "safe"]

# MobileNetV2 whose last classifier layer is swapped for a 2-output linear
# head, so the architecture matches the checkpoint loaded below.
model = models.mobilenet_v2()
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(in_features=num_ftrs, out_features=2)
model = model.to("cpu")

# weights_only=True restricts unpickling to tensor data; map_location keeps
# everything on the CPU regardless of where the checkpoint was saved.
model.load_state_dict(
    torch.load("cnn_model.pth", map_location="cpu", weights_only=True)
)

# Inference only: put the module in evaluation mode.
model.eval()
def inference(image):
    """Classify an image as NSFW or safe and build the message shown to the user.

    Args:
        image: PIL image coming from the Gradio image input, or ``None`` when
            the button is pressed before any image has been provided.

    Returns:
        A user-facing string with the verdict, the predicted label and its
        softmax probability formatted to four decimals. When ``image`` is
        ``None`` a short prompt asking for an image is returned instead.
    """
    if image is None:
        # Gradio hands over None for an empty image input; without this guard
        # the preprocessing pipeline raises and the UI only shows "Error".
        return "Silakan unggah gambar terlebih dahulu."

    # Force three channels: RGBA, grayscale or palette images would otherwise
    # yield a tensor with a channel count MobileNetV2's first conv rejects.
    batch = transform(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        probabilities = torch.nn.functional.softmax(model(batch), dim=1)

    predicted_class = int(torch.argmax(probabilities, dim=1).item())
    predicted_label = label[predicted_class]
    # .item() converts the 0-d tensor to a Python float so the message reads
    # "0.9876" instead of "tensor(0.9876)".
    score = probabilities[0, predicted_class].item()

    if predicted_label == "nsfw":
        return (
            'Boneka ini terlalu seksi dan tidak aman dilihat anak kecil (NSFW) '
            f'[{predicted_label}:{score:.4f}]'
        )
    return f'Boneka ini aman (SAFE) [{predicted_label}:{score:.4f}]'
# Two-column layout: image upload on the left, trigger button and result
# text on the right.
# NOTE(review): the nesting below was reconstructed from a paste that lost
# its indentation - confirm it matches the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
        with gr.Column():
            check_button = gr.Button(value="Cek")
            prediction_box = gr.Text(label="Prediction")

    # Run the classifier on the uploaded image when the button is clicked
    # and show the returned message in the prediction text field.
    check_button.click(inference, image_input, prediction_box)

# Enable request queuing, then start the app.
demo.queue()
demo.launch()