Commit
·
88282d2
1
Parent(s):
06ec94b
add gradio app and model
Browse files- app.py +197 -0
- models/cvt_model.pth +3 -0
app.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from transformers import CvtForImageClassification, AutoFeatureExtractor
|
4 |
+
from PIL import Image
|
5 |
+
import os
|
6 |
+
|
7 |
+
# Configuración del dispositivo
|
8 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
9 |
+
|
10 |
+
# Cargar el extractor de características de Hugging Face
|
11 |
+
extractor = AutoFeatureExtractor.from_pretrained("microsoft/cvt-13")
|
12 |
+
|
13 |
+
# Definir las clases en el mismo orden que el modelo las predice
|
14 |
+
class_names = [
|
15 |
+
"glioma_tumor",
|
16 |
+
"meningioma_tumor",
|
17 |
+
"no_tumor",
|
18 |
+
"pituitary_tumor"
|
19 |
+
]
|
20 |
+
|
21 |
+
# Función para cargar el modelo (solo una vez)
|
22 |
+
def load_model():
|
23 |
+
model_dir = "models" # Ruta a los pesos
|
24 |
+
model_file_pytorch = "cvt_model.pth"
|
25 |
+
|
26 |
+
# Cargar los pesos del modelo desde el archivo .pth
|
27 |
+
checkpoint = torch.load(os.path.join(model_dir, model_file_pytorch), map_location=device)
|
28 |
+
|
29 |
+
# Cargar el modelo dependiendo de si tenemos el modelo completo o solo los pesos
|
30 |
+
if isinstance(checkpoint, CvtForImageClassification):
|
31 |
+
model_pytorch = checkpoint # El checkpoint ya es un modelo completo
|
32 |
+
else:
|
33 |
+
model_pytorch = CvtForImageClassification.from_pretrained("microsoft/cvt-13")
|
34 |
+
model_pytorch.load_state_dict(checkpoint) # Cargar los pesos en el modelo
|
35 |
+
|
36 |
+
model_pytorch.to(device)
|
37 |
+
model_pytorch.eval()
|
38 |
+
return model_pytorch
|
39 |
+
|
40 |
+
# Cargar el modelo una vez cuando la app se inicie
|
41 |
+
model_pytorch = load_model()
|
42 |
+
|
43 |
+
# Función para hacer predicción con la imagen cargada
|
44 |
+
def predict_image(image):
|
45 |
+
# Preprocesar la imagen usando el extractor de características
|
46 |
+
inputs = extractor(images=image, return_tensors="pt").to(device)
|
47 |
+
|
48 |
+
# Hacer la predicción con el modelo
|
49 |
+
with torch.no_grad():
|
50 |
+
outputs = model_pytorch(**inputs)
|
51 |
+
|
52 |
+
# Obtener los logits de la salida
|
53 |
+
logits = outputs.logits
|
54 |
+
|
55 |
+
# Convertir los logits en probabilidades
|
56 |
+
probabilities = torch.nn.functional.softmax(logits, dim=-1)
|
57 |
+
|
58 |
+
# Obtener la clase predicha (índice con mayor probabilidad)
|
59 |
+
predicted_index = probabilities.argmax(dim=-1).item()
|
60 |
+
|
61 |
+
# Mapear el índice de la clase predicha al nombre de la clase
|
62 |
+
predicted_class = class_names[predicted_index]
|
63 |
+
|
64 |
+
# Retornar el nombre de la clase predicha
|
65 |
+
return predicted_class
|
66 |
+
# Función para limpiar los inputs
|
67 |
+
def clear_inputs():
|
68 |
+
return None, None, None
|
69 |
+
|
70 |
+
# Definir el tema y la interfaz de Gradio
|
71 |
+
theme = gr.themes.Soft(
|
72 |
+
primary_hue="indigo",
|
73 |
+
secondary_hue="indigo",
|
74 |
+
).set(
|
75 |
+
background_fill_primary='#121212', # Dark background
|
76 |
+
background_fill_secondary='#1e1e1e',
|
77 |
+
block_background_fill='#1e1e1e', # Almost black
|
78 |
+
block_border_color='#333',
|
79 |
+
block_label_text_color='#fffff',
|
80 |
+
block_label_text_color_dark = '#fffff',
|
81 |
+
block_title_text_color_dark = '#fffff',
|
82 |
+
button_primary_background_fill='#4f46e5', # Violet
|
83 |
+
button_primary_background_fill_hover='#2563eb', # Light blue
|
84 |
+
button_secondary_background_fill='#4f46e5',
|
85 |
+
button_secondary_background_fill_hover='#2563eb',
|
86 |
+
input_background_fill='#333', # Dark grey
|
87 |
+
input_border_color='#444', # Intermediate grey
|
88 |
+
block_label_background_fill='#4f46e5',
|
89 |
+
block_label_background_fill_dark='#4f46e5',
|
90 |
+
slider_color='#2563eb',
|
91 |
+
slider_color_dark='#2563eb',
|
92 |
+
button_primary_text_color='#fffff',
|
93 |
+
button_secondary_text_color='#fffff',
|
94 |
+
button_secondary_background_fill_hover_dark='#4f46e5',
|
95 |
+
button_cancel_background_fill_hover='#444',
|
96 |
+
button_cancel_background_fill_hover_dark='#444'
|
97 |
+
)
|
98 |
+
|
99 |
+
with gr.Blocks(theme=theme, css="""
|
100 |
+
body, gradio-app {
|
101 |
+
background-image: url('https://b2928487.smushcdn.com/2928487/wp-content/uploads/2022/04/Brain-inspiredAI-2048x1365.jpeg?lossy=1&strip=1&webp=1');
|
102 |
+
background-size: cover;
|
103 |
+
color: white;
|
104 |
+
}
|
105 |
+
.gradio-container {
|
106 |
+
background-color: transparent;
|
107 |
+
background-image: url('https://b2928487.smushcdn.com/2928487/wp-content/uploads/2022/04/Brain-inspiredAI-2048x1365.jpeg?lossy=1&strip=1&webp=1') !important;
|
108 |
+
background-size: cover !important;
|
109 |
+
color: white;
|
110 |
+
}
|
111 |
+
.gradio-container .gr-dropdown-container select::after {
|
112 |
+
content: '▼';
|
113 |
+
color: white;
|
114 |
+
padding-left: 5px;
|
115 |
+
}
|
116 |
+
.gradio-container .gr-dropdown-container select:focus {
|
117 |
+
outline: none;
|
118 |
+
border-color: #4f46e5;
|
119 |
+
}
|
120 |
+
.gradio-container select {
|
121 |
+
color: white;
|
122 |
+
}
|
123 |
+
input, select, span, button, svg, .secondary-wrap {
|
124 |
+
color: white;
|
125 |
+
}
|
126 |
+
|
127 |
+
h1 {
|
128 |
+
color: white;
|
129 |
+
font-size: 4em;
|
130 |
+
margin: 20px auto;
|
131 |
+
}
|
132 |
+
.gradio-container h1 {
|
133 |
+
font-size: 5em;
|
134 |
+
color: white;
|
135 |
+
text-align: center;
|
136 |
+
text-shadow: 2px 2px 0px #8A2BE2,
|
137 |
+
4px 4px 0px #00000033;
|
138 |
+
text-transform: uppercase;
|
139 |
+
margin: 18px auto;
|
140 |
+
}
|
141 |
+
.gradio-container input {
|
142 |
+
color: white;
|
143 |
+
}
|
144 |
+
.gradio-container .output {
|
145 |
+
color: white;
|
146 |
+
}
|
147 |
+
.required-dropdown li {
|
148 |
+
color: white;
|
149 |
+
}
|
150 |
+
.button-style {
|
151 |
+
background-color: #4f46e5;
|
152 |
+
color: white;
|
153 |
+
}
|
154 |
+
.button-style:hover {
|
155 |
+
background-color: #2563eb;
|
156 |
+
color: white;
|
157 |
+
}
|
158 |
+
|
159 |
+
.gradio-container .contain textarea {
|
160 |
+
color: white;
|
161 |
+
font-weight: 600;
|
162 |
+
font-size: 1.5rem;
|
163 |
+
}
|
164 |
+
.contain textarea {
|
165 |
+
color: white;
|
166 |
+
font-weight: 600;
|
167 |
+
font-size: 1.5rem;
|
168 |
+
}
|
169 |
+
textarea {
|
170 |
+
color: white;
|
171 |
+
font-weight: 600;
|
172 |
+
font-size: 1.5rem;
|
173 |
+
background-color: black;
|
174 |
+
}
|
175 |
+
textarea .scroll-hide {
|
176 |
+
color: white;
|
177 |
+
}
|
178 |
+
.scroll-hide svelte-1f354aw {
|
179 |
+
color: white;
|
180 |
+
}
|
181 |
+
""") as demo:
|
182 |
+
|
183 |
+
gr.Markdown("# Brain Tumor Classification 🧠")
|
184 |
+
|
185 |
+
with gr.Row():
|
186 |
+
with gr.Column():
|
187 |
+
image_input = gr.Image(type="pil", label="Sube la imagen")
|
188 |
+
model_input = gr.Dropdown(choices=["model_1", "model_2"], label="Selecciona un modelo", elem_classes=['required-dropdown'])
|
189 |
+
classify_btn = gr.Button("Clasificar", elem_classes=['button-style'])
|
190 |
+
clear_btn = gr.Button("Limpiar")
|
191 |
+
with gr.Column():
|
192 |
+
prediction_output = gr.Textbox(label="Predicción")
|
193 |
+
|
194 |
+
classify_btn.click(predict_image, inputs=[image_input], outputs=prediction_output)
|
195 |
+
clear_btn.click(clear_inputs, inputs=[], outputs=[image_input, model_input, prediction_output])
|
196 |
+
|
197 |
+
demo.launch()
|
models/cvt_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e10dc9a09c803275139d55eeec5cd0414507661eab754e89991f5c1fc6c9bcbf
|
3 |
+
size 80316819
|