"""Rebuild the tabular DataLoaders and restore a trained GatedTabTransformer.

The script recreates fastai ``TabularDataLoaders`` with the categorical and
continuous feature columns listed below (target column ``clase``), builds a
``GatedTabTransformer`` sized from those DataLoaders, and loads saved weights
from ``/content/GatedTabTransformer_model.pth``.

NOTE(review): ``data`` (the source DataFrame holding the feature columns and
the ``clase`` target) is not defined in this file; it must exist in the
namespace before this script runs -- TODO confirm where it is created.
"""
from fastai.tabular.all import (
    Categorify,
    FillMissing,
    Normalize,
    TabularDataLoaders,
)
import torch
from tsai.models.GatedTabTransformer import GatedTabTransformer

# Categorical feature columns; names must match the DataFrame headers exactly.
cat_names = ['Globulos rojos', 'Celulas de pus', 'Grumos de celulas de pus',
             'Bacterias', 'Hipertensión', 'Diabetes mellitus',
             'Enfermedad arterial coronaria', 'Apetito', 'Edema pedioso', 'Anemia']

# Continuous feature columns; names must match the DataFrame headers exactly.
cont_names = ['Edad', 'Presion arterial', 'Gravedad especifica', 'Albumina', 'Azucar',
              'Glucosa en sangre aleatoria', 'Urea en sangre', 'Creatinina serica',
              'Sodio', 'Potasio', 'Hemoglobina', 'Volumen de celulas empaquetadas',
              'Recuento de globulos blancos', 'Recuento de globulos rojos']

# The column lists above must be defined before this call; building the
# DataLoaders first (as the script previously did) raises NameError.
dls = TabularDataLoaders.from_df(
    data,
    path=".",
    y_names="clase",
    cat_names=cat_names,
    cont_names=cont_names,
    procs=[Categorify, FillMissing, Normalize],
)

# The architecture hyper-parameters must match the ones used when the
# checkpoint was saved, or load_state_dict will reject the weights.
model = GatedTabTransformer(
    dls.classes,
    dls.cont_names,
    dls.c,
    mlp_d_model=128,
    mlp_d_ffn=256,
    mlp_layers=6,
)

# The model is built on CPU, so map the checkpoint tensors to CPU as well;
# without this, a checkpoint saved from GPU fails on a CPU-only machine.
state_dict = torch.load('/content/GatedTabTransformer_model.pth', map_location='cpu')
model.load_state_dict(state_dict)

# Restored weights are meant for inference: disable training-only behaviour
# such as dropout.
model.eval()