Tercero commited on
Commit
195aaf5
·
verified ·
1 Parent(s): d77e2d6

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +29 -0
model.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from tsai.models.GatedTabTransformer import GatedTabTransformer
3
+ from fastai.tabular.all import TabularDataLoaders
4
+
5
+ # Define el DataLoaders que utilizaste durante el entrenamiento (esto es necesario para configurar correctamente el modelo)
6
+ # Crear los loaders de datos
7
+ dls = TabularDataLoaders.from_df(data, path=".", y_names="clase",
8
+ cat_names=cat_names, cont_names=cont_names,
9
+ procs=[Categorify, FillMissing, Normalize])
10
+
11
+ # Aquí asumo que tienes las clases categóricas y los nombres de las características continuas
12
+ # Nombres de variables categóricas
13
+
14
+
15
+ cat_names = ['Globulos rojos', 'Celulas de pus', 'Grumos de celulas de pus',
16
+ 'Bacterias', 'Hipertensión', 'Diabetes mellitus',
17
+ 'Enfermedad arterial coronaria', 'Apetito', 'Edema pedioso', 'Anemia']
18
+
19
+ # Nombres de variables continuas
20
+
21
+
22
+ cont_names = ['Edad', 'Presion arterial', 'Gravedad especifica', 'Albumina', 'Azucar',
23
+ 'Glucosa en sangre aleatoria', 'Urea en sangre', 'Creatinina serica',
24
+ 'Sodio', 'Potasio', 'Hemoglobina', 'Volumen de celulas empaquetadas',
25
+ 'Recuento de globulos blancos', 'Recuento de globulos rojos']
26
+ # Carga del modelo preentrenado
27
+ # model = GatedTabTransformer(emb_szs={}, n_cont=10, out_sz=2) # Ajusta los parámetros según tu configuración
28
+ model = GatedTabTransformer(dls.classes, dls.cont_names, dls.c, mlp_d_model=128, mlp_d_ffn=256, mlp_layers=6)
29
+ model.load_state_dict(torch.load('/content/GatedTabTransformer_model.pth'))