Spaces:
Sleeping
Sleeping
File size: 847 Bytes
2c1ff7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import torch
from z_modelops import NameToLanguages, load_labels
from z_dataops import transform
import json
from torch import nn
def load_model(location="model/rnn.pth", map_location=None):
    """Load a pickled model object (architecture together with weights).

    Args:
        location: Path of the file passed to ``torch.load``.
        map_location: Forwarded to ``torch.load``. For example, pass ``"cpu"``
            to load a checkpoint that was saved from a GPU onto a CPU-only
            host. The default ``None`` keeps whatever devices are recorded in
            the file, which matches the previous behavior.

    Returns:
        The deserialized object stored in ``location``.
    """
    # SECURITY: weights_only=False enables full pickle deserialization, which
    # can execute arbitrary code embedded in the file. It is needed because
    # the whole model object (not just a state_dict) is stored. Only ever point
    # this at a trusted, locally shipped checkpoint.
    return torch.load(location, map_location=map_location, weights_only=False)
def infer_lang(name: str, model, label: dict, k=3) -> list[str]:
    """Return the ``k`` highest-scoring labels for ``name``, best first.

    Args:
        name: Raw input string; encoded with ``transform`` before being fed
            to the model.
        model: Called on a batch of one (``unsqueeze(0)``). Its output is
            presumably ``(1, num_classes)`` logits -- TODO confirm.
        label: Mapping looked up with ``str(class_index)``, so its keys are
            the string form of the class indices (e.g. ``"0"``, ``"1"``).
        k: Number of predictions to return. Uses slice semantics, so a ``k``
            larger than the number of classes simply returns every class.

    Returns:
        List of labels ordered by decreasing model score.
    """
    name_tensor = transform(name)
    # Run in eval mode so layers such as dropout are deterministic. Only touch
    # the mode if the model is currently training, and restore it afterwards
    # so callers (e.g. a validation step inside a training loop) are unaffected.
    was_training = getattr(model, "training", False)
    if was_training:
        model.eval()
    try:
        with torch.no_grad():
            logits = model(name_tensor.unsqueeze(0))
    finally:
        if was_training:
            model.train()
    # Softmax is strictly monotonic, so ranking the raw logits gives the same
    # order without extra work. It also avoids artificial ties when small
    # probabilities underflow to 0. `.tolist()` works for CPU and GPU tensors
    # alike, whereas `.numpy()` fails on GPU tensors.
    top_k_idx = logits.argsort(dim=1, descending=True)[0, :k].tolist()
    return [label[str(idx)] for idx in top_k_idx]
def setup_inference():
    """Load everything needed for inference.

    Returns:
        A ``(model, labels)`` tuple: the model from ``load_model()`` at its
        default path, and whatever ``load_labels()`` returns.
    """
    # Tuple elements are evaluated left to right, so the model is still
    # loaded before the labels.
    return load_model(), load_labels()
if __name__ == "__main__":
    # Running the module directly only loads the model and labels (a quick
    # check that both load); no inference is performed here.
    model, labels = setup_inference()
|