import torch from tsai.models.GatedTabTransformer import GatedTabTransformer from fastai.tabular.all import TabularDataLoaders # Define el DataLoaders que utilizaste durante el entrenamiento (esto es necesario para configurar correctamente el modelo) # Crear los loaders de datos dls = TabularDataLoaders.from_df(data, path=".", y_names="clase", cat_names=cat_names, cont_names=cont_names, procs=[Categorify, FillMissing, Normalize]) # Aquí asumo que tienes las clases categóricas y los nombres de las características continuas # Nombres de variables categóricas cat_names = ['Globulos rojos', 'Celulas de pus', 'Grumos de celulas de pus', 'Bacterias', 'Hipertensión', 'Diabetes mellitus', 'Enfermedad arterial coronaria', 'Apetito', 'Edema pedioso', 'Anemia'] # Nombres de variables continuas cont_names = ['Edad', 'Presion arterial', 'Gravedad especifica', 'Albumina', 'Azucar', 'Glucosa en sangre aleatoria', 'Urea en sangre', 'Creatinina serica', 'Sodio', 'Potasio', 'Hemoglobina', 'Volumen de celulas empaquetadas', 'Recuento de globulos blancos', 'Recuento de globulos rojos'] # Carga del modelo preentrenado # model = GatedTabTransformer(emb_szs={}, n_cont=10, out_sz=2) # Ajusta los parámetros según tu configuración model = GatedTabTransformer(dls.classes, dls.cont_names, dls.c, mlp_d_model=128, mlp_d_ffn=256, mlp_layers=6) model.load_state_dict(torch.load('/content/GatedTabTransformer_model.pth'))