Spaces:
Sleeping
Sleeping
| import huggingface_hub | |
| import pretrainedmodels | |
| import torch | |
| import torch.nn as nn | |
def get_model(model_name="se_resnext50_32x4d", num_classes=101, pretrained="imagenet"):
    """Build a ``pretrainedmodels`` network with a resized classifier head.

    Args:
        model_name (str): Key used to look up the model constructor in
            ``pretrainedmodels.__dict__``.
        num_classes (int): Output size of the replacement ``last_linear``
            layer.
        pretrained (str | None): Forwarded unchanged as the constructor's
            ``pretrained`` keyword argument. ``None`` is an accepted value
            (presumably meaning "no pre-trained weights" -- confirm against
            the ``pretrainedmodels`` documentation).

    Returns:
        torch.nn.Module: The constructed model, with its ``last_linear`` and
        ``avg_pool`` attributes replaced.

    Raises:
        KeyError: If ``model_name`` is not a key of
            ``pretrainedmodels.__dict__``.
    """
    # Resolve the constructor first, then instantiate it.
    build_network = pretrainedmodels.__dict__[model_name]
    network = build_network(pretrained=pretrained)

    # Swap the classifier for a fresh linear layer that keeps the original
    # input width but emits ``num_classes`` outputs.
    head_width = network.last_linear.in_features
    network.last_linear = nn.Linear(head_width, num_classes)

    # Replace the pooling layer with an adaptive average pool of output size 1.
    network.avg_pool = nn.AdaptiveAvgPool2d(1)
    return network
def load_model(device):
    """Load the age estimation model from Hugging Face Hub.

    Builds an ``se_resnext50_32x4d`` network via ``get_model`` (passing
    ``pretrained=None``), downloads ``pretrained.pth`` from the
    ``public-data/yu4u-age-estimation-pytorch`` repository, loads it as the
    model's state dict, moves the model to ``device`` and switches it to
    evaluation mode.

    Args:
        device (torch.device): The device to load the model onto.

    Returns:
        torch.nn.Module: The loaded model, on ``device`` and in eval mode.
    """
    model = get_model(model_name="se_resnext50_32x4d", pretrained=None)
    path = huggingface_hub.hf_hub_download(
        "public-data/yu4u-age-estimation-pytorch", "pretrained.pth"
    )
    # Deserialize onto the CPU regardless of the device the checkpoint was
    # saved from; without ``map_location`` a checkpoint holding CUDA tensors
    # cannot be loaded on a CPU-only host. The model is moved to the requested
    # device below.
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model