Tercero's picture
Create model.py
195aaf5 verified
import torch
from tsai.models.GatedTabTransformer import GatedTabTransformer
from fastai.tabular.all import TabularDataLoaders
# Define el DataLoaders que utilizaste durante el entrenamiento (esto es necesario para configurar correctamente el modelo)
# Crear los loaders de datos
dls = TabularDataLoaders.from_df(data, path=".", y_names="clase",
cat_names=cat_names, cont_names=cont_names,
procs=[Categorify, FillMissing, Normalize])
# Aquí asumo que tienes las clases categóricas y los nombres de las características continuas
# Nombres de variables categóricas
cat_names = ['Globulos rojos', 'Celulas de pus', 'Grumos de celulas de pus',
'Bacterias', 'Hipertensión', 'Diabetes mellitus',
'Enfermedad arterial coronaria', 'Apetito', 'Edema pedioso', 'Anemia']
# Nombres de variables continuas
cont_names = ['Edad', 'Presion arterial', 'Gravedad especifica', 'Albumina', 'Azucar',
'Glucosa en sangre aleatoria', 'Urea en sangre', 'Creatinina serica',
'Sodio', 'Potasio', 'Hemoglobina', 'Volumen de celulas empaquetadas',
'Recuento de globulos blancos', 'Recuento de globulos rojos']
# Carga del modelo preentrenado
# model = GatedTabTransformer(emb_szs={}, n_cont=10, out_sz=2) # Ajusta los parámetros según tu configuración
model = GatedTabTransformer(dls.classes, dls.cont_names, dls.c, mlp_d_model=128, mlp_d_ffn=256, mlp_layers=6)
model.load_state_dict(torch.load('/content/GatedTabTransformer_model.pth'))