Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torchvision.transforms as transforms
|
4 |
+
import torch
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import numpy as np
|
7 |
+
import gradio as gr
|
8 |
+
|
9 |
+
|
10 |
+
class CNN(nn.Module):
|
11 |
+
def __init__(self, in_channels=1, num_classes=4):
|
12 |
+
super(CNN, self).__init__()
|
13 |
+
|
14 |
+
self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)
|
15 |
+
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
|
16 |
+
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
|
17 |
+
|
18 |
+
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
19 |
+
|
20 |
+
self.batch_norm1 = nn.BatchNorm2d(32)
|
21 |
+
self.batch_norm2 = nn.BatchNorm2d(64)
|
22 |
+
self.batch_norm3 = nn.BatchNorm2d(128)
|
23 |
+
|
24 |
+
self.dropout = nn.Dropout(0.5)
|
25 |
+
|
26 |
+
# Calcular el tama帽o de la entrada a la capa fully connected
|
27 |
+
self.fc1 = nn.Linear(128 * (200 // 8) * (200 // 8), 256)
|
28 |
+
self.fc2 = nn.Linear(256, num_classes)
|
29 |
+
|
30 |
+
def forward(self, x):
|
31 |
+
x = F.relu(self.batch_norm1(self.conv1(x)))
|
32 |
+
x = self.pool(x)
|
33 |
+
|
34 |
+
x = F.relu(self.batch_norm2(self.conv2(x)))
|
35 |
+
x = self.pool(x)
|
36 |
+
|
37 |
+
x = F.relu(self.batch_norm3(self.conv3(x)))
|
38 |
+
x = self.pool(x)
|
39 |
+
|
40 |
+
x = x.view(x.shape[0], -1) # Aplanar
|
41 |
+
x = self.dropout(F.relu(self.fc1(x)))
|
42 |
+
x = self.fc2(x)
|
43 |
+
|
44 |
+
return x
|
45 |
+
|
46 |
+
|
47 |
+
model = CNN()
|
48 |
+
model.load_state_dict(
|
49 |
+
torch.load("gabriel_complex_modelo.pth", map_location=torch.device("cpu"))
|
50 |
+
)
|
51 |
+
|
52 |
+
|
53 |
+
def inference(model, imagen, device="cpu"):
|
54 |
+
label_mapping = {0: "C铆rculo", 1: "Tri谩ngulo", 2: "Cuadrado", 3: "Estrella"}
|
55 |
+
|
56 |
+
model.eval() # Ponemos el modelo en modo evaluaci贸n
|
57 |
+
|
58 |
+
# Realizar la inferencia
|
59 |
+
with torch.no_grad():
|
60 |
+
scores = model(imagen) # Output: tensor con logits
|
61 |
+
probabilities = torch.softmax(
|
62 |
+
scores, dim=1
|
63 |
+
) # Convertir logits a probabilidades
|
64 |
+
_, prediction = scores.max(1) # Obtener la clase con mayor probabilidad
|
65 |
+
label_predicho = prediction.item()
|
66 |
+
|
67 |
+
# Diccionario con las probabilidades
|
68 |
+
probabilities_dict = {
|
69 |
+
label_mapping[i]: float(probabilities[0, i]) for i in range(4)
|
70 |
+
}
|
71 |
+
|
72 |
+
return label_mapping[label_predicho], probabilities_dict
|
73 |
+
|
74 |
+
|
75 |
+
def predict(img):
|
76 |
+
image_array = img["composite"][:, :, 3]
|
77 |
+
image_array = 255 - image_array
|
78 |
+
|
79 |
+
image_tensor = torch.from_numpy(image_array).unsqueeze(0)
|
80 |
+
|
81 |
+
transform_to_gray = transforms.Compose(
|
82 |
+
[
|
83 |
+
transforms.Resize((200, 200)),
|
84 |
+
transforms.ConvertImageDtype(dtype=torch.float32), # Convertir a flotante
|
85 |
+
]
|
86 |
+
)
|
87 |
+
|
88 |
+
image = transform_to_gray(image_tensor)
|
89 |
+
image = image.unsqueeze(0) # Agregar dimensi贸n extra
|
90 |
+
|
91 |
+
# Hacemos la inferencia
|
92 |
+
label_predict, probabilities = inference(model, image, device="cpu")
|
93 |
+
print(label_predict)
|
94 |
+
print(probabilities)
|
95 |
+
|
96 |
+
return probabilities # Retorna el diccionario con las probabilidades
|
97 |
+
|
98 |
+
|
99 |
+
with gr.Blocks() as demo:
|
100 |
+
with gr.Row():
|
101 |
+
im = gr.Sketchpad(type="numpy", crop_size="1:1")
|
102 |
+
out = gr.Label()
|
103 |
+
|
104 |
+
im.change(predict, outputs=out, inputs=im, show_progress="hidden")
|
105 |
+
|
106 |
+
demo.launch(share=True, debug=False)
|