noequalindi committed
Commit 88282d2 · 1 Parent(s): 06ec94b

add gradio app and model

Files changed (2)
  1. app.py +197 -0
  2. models/cvt_model.pth +3 -0
app.py ADDED
@@ -0,0 +1,197 @@
+ import gradio as gr
+ import torch
+ from transformers import CvtForImageClassification, AutoFeatureExtractor
+ from PIL import Image
+ import os
+
+ # Device configuration
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load the feature extractor from Hugging Face
+ extractor = AutoFeatureExtractor.from_pretrained("microsoft/cvt-13")
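+ # Note: in recent transformers releases, AutoImageProcessor is the preferred
+ # replacement for the deprecated AutoFeatureExtractor; both load the same
+ # microsoft/cvt-13 preprocessing config, so either works here.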
+
+ # Define the classes in the same order the model predicts them
+ class_names = [
+     "glioma_tumor",
+     "meningioma_tumor",
+     "no_tumor",
+     "pituitary_tumor"
+ ]
+
+ # Load the model (only once)
+ def load_model():
+     model_dir = "models"  # Path to the weights
+     model_file_pytorch = "cvt_model.pth"
+
+     # Load the checkpoint from the .pth file (weights_only=False so that a
+     # pickled full model, not just a state_dict, can be restored on newer torch)
+     checkpoint = torch.load(os.path.join(model_dir, model_file_pytorch), map_location=device, weights_only=False)
+
+     # The checkpoint may be either a complete model or just a state_dict
+     if isinstance(checkpoint, CvtForImageClassification):
+         model_pytorch = checkpoint  # The checkpoint is already a complete model
+     else:
+         # Instantiate the base model with a head sized for our four classes;
+         # otherwise load_state_dict would fail against the 1000-class ImageNet
+         # head that microsoft/cvt-13 ships with
+         model_pytorch = CvtForImageClassification.from_pretrained(
+             "microsoft/cvt-13",
+             num_labels=len(class_names),
+             ignore_mismatched_sizes=True,
+         )
+         model_pytorch.load_state_dict(checkpoint)  # Load the weights into the model
+
+     model_pytorch.to(device)
+     model_pytorch.eval()
+     return model_pytorch
+
+ # Load the model once when the app starts
+ model_pytorch = load_model()
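+
+ # Optional sanity check (a small added sketch): the classification head should
+ # expose exactly one logit per entry in class_names
+ assert model_pytorch.config.num_labels == len(class_names), "class_names does not match the model head"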
+
+ # Run a prediction on the uploaded image
+ def predict_image(image):
+     # Preprocess the image with the feature extractor
+     inputs = extractor(images=image, return_tensors="pt").to(device)
+
+     # Run the model without tracking gradients
+     with torch.no_grad():
+         outputs = model_pytorch(**inputs)
+
+     # Take the logits from the output
+     logits = outputs.logits
+
+     # Convert the logits into probabilities
+     probabilities = torch.nn.functional.softmax(logits, dim=-1)
+
+     # Pick the predicted class (index with the highest probability)
+     predicted_index = probabilities.argmax(dim=-1).item()
+
+     # Map the predicted index to its class name and return it
+     predicted_class = class_names[predicted_index]
+     return predicted_class
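+
+ # Variant (a sketch, not wired into the UI below): return the confidence along
+ # with the label, e.g. "glioma_tumor (97.3%)", from the max softmax probability
+ def predict_image_with_confidence(image):
+     inputs = extractor(images=image, return_tensors="pt").to(device)
+     with torch.no_grad():
+         logits = model_pytorch(**inputs).logits
+     probabilities = torch.nn.functional.softmax(logits, dim=-1)
+     confidence, predicted_index = probabilities.max(dim=-1)
+     return f"{class_names[predicted_index.item()]} ({confidence.item():.1%})"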
+
+ # Reset the inputs and the output
+ def clear_inputs():
+     return None, None, None
+
+ # Define the Gradio theme
+ theme = gr.themes.Soft(
+     primary_hue="indigo",
+     secondary_hue="indigo",
+ ).set(
+     background_fill_primary='#121212',  # Dark background
+     background_fill_secondary='#1e1e1e',
+     block_background_fill='#1e1e1e',  # Almost black
+     block_border_color='#333',
+     block_label_text_color='#ffffff',
+     block_label_text_color_dark='#ffffff',
+     block_title_text_color_dark='#ffffff',
+     button_primary_background_fill='#4f46e5',  # Violet
+     button_primary_background_fill_hover='#2563eb',  # Light blue
+     button_secondary_background_fill='#4f46e5',
+     button_secondary_background_fill_hover='#2563eb',
+     input_background_fill='#333',  # Dark grey
+     input_border_color='#444',  # Intermediate grey
+     block_label_background_fill='#4f46e5',
+     block_label_background_fill_dark='#4f46e5',
+     slider_color='#2563eb',
+     slider_color_dark='#2563eb',
+     button_primary_text_color='#ffffff',
+     button_secondary_text_color='#ffffff',
+     button_secondary_background_fill_hover_dark='#4f46e5',
+     button_cancel_background_fill_hover='#444',
+     button_cancel_background_fill_hover_dark='#444'
+ )
+
+ with gr.Blocks(theme=theme, css="""
+ body, gradio-app {
+     background-image: url('https://b2928487.smushcdn.com/2928487/wp-content/uploads/2022/04/Brain-inspiredAI-2048x1365.jpeg?lossy=1&strip=1&webp=1');
+     background-size: cover;
+     color: white;
+ }
+ .gradio-container {
+     background-color: transparent;
+     background-image: url('https://b2928487.smushcdn.com/2928487/wp-content/uploads/2022/04/Brain-inspiredAI-2048x1365.jpeg?lossy=1&strip=1&webp=1') !important;
+     background-size: cover !important;
+     color: white;
+ }
+ .gradio-container .gr-dropdown-container select::after {
+     content: '▼';
+     color: white;
+     padding-left: 5px;
+ }
+ .gradio-container .gr-dropdown-container select:focus {
+     outline: none;
+     border-color: #4f46e5;
+ }
+ .gradio-container select {
+     color: white;
+ }
+ input, select, span, button, svg, .secondary-wrap {
+     color: white;
+ }
+
+ h1 {
+     color: white;
+     font-size: 4em;
+     margin: 20px auto;
+ }
+ .gradio-container h1 {
+     font-size: 5em;
+     color: white;
+     text-align: center;
+     text-shadow: 2px 2px 0px #8A2BE2,
+                  4px 4px 0px #00000033;
+     text-transform: uppercase;
+     margin: 18px auto;
+ }
+ .gradio-container input {
+     color: white;
+ }
+ .gradio-container .output {
+     color: white;
+ }
+ .required-dropdown li {
+     color: white;
+ }
+ .button-style {
+     background-color: #4f46e5;
+     color: white;
+ }
+ .button-style:hover {
+     background-color: #2563eb;
+     color: white;
+ }
+
+ .gradio-container .contain textarea,
+ .contain textarea {
+     color: white;
+     font-weight: 600;
+     font-size: 1.5rem;
+ }
+ textarea {
+     color: white;
+     font-weight: 600;
+     font-size: 1.5rem;
+     background-color: black;
+ }
+ textarea .scroll-hide {
+     color: white;
+ }
+ .scroll-hide.svelte-1f354aw {
+     color: white;
+ }
+ """) as demo:
+
+     gr.Markdown("# Brain Tumor Classification 🧠")
+
+     with gr.Row():
+         with gr.Column():
+             image_input = gr.Image(type="pil", label="Upload an image")
+             model_input = gr.Dropdown(choices=["model_1", "model_2"], label="Select a model", elem_classes=['required-dropdown'])
+             classify_btn = gr.Button("Classify", elem_classes=['button-style'])
+             clear_btn = gr.Button("Clear")
+         with gr.Column():
+             prediction_output = gr.Textbox(label="Prediction")
+
+     # Note: the model dropdown is display-only for now; predict_image always
+     # uses the single CvT checkpoint loaded above
+     classify_btn.click(predict_image, inputs=[image_input], outputs=prediction_output)
+     clear_btn.click(clear_inputs, inputs=[], outputs=[image_input, model_input, prediction_output])
+
+ demo.launch()
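+ # Optional tweak (not part of the original app): when running locally rather
+ # than on Spaces, a temporary public link can be created with
+ # demo.launch(share=True)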
models/cvt_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e10dc9a09c803275139d55eeec5cd0414507661eab754e89991f5c1fc6c9bcbf
+ size 80316819