Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| from transformers import AutoModel | |
| from config import ( | |
| HIDDEN_SIZE, | |
| DROPOUT_PROB, | |
| LAST_NUM_NEURON, | |
| HF_REPO_NAME, | |
| WEIGHTS_FILE_NAME, | |
| PRETRAINED_MODEL, | |
| ) | |
| from huggingface_hub import hf_hub_download | |
class EnergySmellsDetector(nn.Module):
    """Pretrained transformer encoder followed by a dropout + linear + sigmoid head.

    The encoder is obtained through ``AutoModel.from_pretrained``. Its
    ``pooler_output`` is passed through dropout and one linear layer
    (``HIDDEN_SIZE`` -> ``LAST_NUM_NEURON``), and the result is squashed
    element-wise with a sigmoid.
    """

    def __init__(self, model_name):
        """Create the encoder and the classification head.

        Args:
            model_name: Identifier handed to ``AutoModel.from_pretrained``.
        """
        super().__init__()
        # Attribute names and their registration order are kept unchanged:
        # they determine the keys of this module's state_dict.
        self.model = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(DROPOUT_PROB)
        self.fc = nn.Linear(HIDDEN_SIZE, LAST_NUM_NEURON)

    def forward(self, input_ids, attention_mask):
        """Run the encoder and head, returning sigmoid scores as float64.

        Args:
            input_ids: Forwarded to the encoder as ``input_ids``.
            attention_mask: Forwarded to the encoder as ``attention_mask``.

        Returns:
            Tensor of sigmoid-activated outputs of the linear head, cast
            with ``.to(float)`` (i.e. to ``torch.float64``).
        """
        encoded = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # NOTE(review): assumes the encoder exposes a non-None pooler_output;
        # confirm for the configured PRETRAINED_MODEL.
        pooled = self.dropout(encoded.pooler_output)
        scores = torch.sigmoid(self.fc(pooled))
        return scores.to(float)
def load_model_from_hf():
    """Fetch the weights file from the Hugging Face Hub and load it into a detector.

    Downloads ``WEIGHTS_FILE_NAME`` from the ``HF_REPO_NAME`` repository,
    builds an ``EnergySmellsDetector`` on top of ``PRETRAINED_MODEL``, and
    loads the downloaded state dict into it with all tensors mapped to CPU.

    Returns:
        The ``EnergySmellsDetector`` instance with the downloaded weights loaded.
    """
    weights_path = hf_hub_download(
        repo_id=HF_REPO_NAME,
        filename=WEIGHTS_FILE_NAME,
    )

    detector = EnergySmellsDetector(PRETRAINED_MODEL)

    # NOTE(review): torch.load unpickles the downloaded file; if the installed
    # torch version supports it, consider weights_only=True — confirm the
    # checkpoint contains only tensors first.
    cpu_device = torch.device('cpu')
    state_dict = torch.load(weights_path, map_location=cpu_device)
    detector.load_state_dict(state_dict)

    # NOTE(review): the module is returned without calling .eval(), so dropout
    # stays active unless the caller switches modes — verify against callers.
    return detector