Create model.py
Browse files
model.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from tsai.models.GatedTabTransformer import GatedTabTransformer
|
3 |
+
from fastai.tabular.all import TabularDataLoaders
|
4 |
+
|
5 |
+
# Define el DataLoaders que utilizaste durante el entrenamiento (esto es necesario para configurar correctamente el modelo)
|
6 |
+
# Crear los loaders de datos
|
7 |
+
dls = TabularDataLoaders.from_df(data, path=".", y_names="clase",
|
8 |
+
cat_names=cat_names, cont_names=cont_names,
|
9 |
+
procs=[Categorify, FillMissing, Normalize])
|
10 |
+
|
11 |
+
# Aquí asumo que tienes las clases categóricas y los nombres de las características continuas
|
12 |
+
# Nombres de variables categóricas
|
13 |
+
|
14 |
+
|
15 |
+
cat_names = ['Globulos rojos', 'Celulas de pus', 'Grumos de celulas de pus',
|
16 |
+
'Bacterias', 'Hipertensión', 'Diabetes mellitus',
|
17 |
+
'Enfermedad arterial coronaria', 'Apetito', 'Edema pedioso', 'Anemia']
|
18 |
+
|
19 |
+
# Nombres de variables continuas
|
20 |
+
|
21 |
+
|
22 |
+
cont_names = ['Edad', 'Presion arterial', 'Gravedad especifica', 'Albumina', 'Azucar',
|
23 |
+
'Glucosa en sangre aleatoria', 'Urea en sangre', 'Creatinina serica',
|
24 |
+
'Sodio', 'Potasio', 'Hemoglobina', 'Volumen de celulas empaquetadas',
|
25 |
+
'Recuento de globulos blancos', 'Recuento de globulos rojos']
|
26 |
+
# Carga del modelo preentrenado
|
27 |
+
# model = GatedTabTransformer(emb_szs={}, n_cont=10, out_sz=2) # Ajusta los parámetros según tu configuración
|
28 |
+
model = GatedTabTransformer(dls.classes, dls.cont_names, dls.c, mlp_d_model=128, mlp_d_ffn=256, mlp_layers=6)
|
29 |
+
model.load_state_dict(torch.load('/content/GatedTabTransformer_model.pth'))
|