File size: 847 Bytes
2c1ff7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from z_modelops import NameToLanguages, load_labels
from z_dataops import transform
import json
from torch import nn

def load_model(location="model/rnn.pth", map_location=None):
    """Load a fully pickled model (architecture together with weights).

    Args:
        location: Path to a file written with ``torch.save(model, ...)``.
        map_location: Forwarded to ``torch.load``. For example, pass ``"cpu"``
            to load a checkpoint saved on a GPU onto a CPU-only machine.
            ``None`` keeps ``torch.load``'s default device placement.

    Returns:
        The object pickled into ``location`` (expected to be the model).
    """
    # SECURITY: weights_only=False un-pickles arbitrary Python objects, which can
    # execute code during loading. Only load checkpoint files from trusted sources.
    return torch.load(location, map_location=map_location, weights_only=False)

def infer_lang(name: str, model, label: dict, k=3) -> list:
    """Return the ``k`` most likely language labels for ``name``.

    Args:
        name: Raw input string; encoded with ``transform`` before being fed to
            ``model``.
        model: Callable taking a batched input tensor. Its output is expected
            to be scores shaped ``(batch, num_classes)``.
        label: Mapping from class index *as a string* (e.g. ``"0"``) to a
            language name. Presumably the mapping returned by ``load_labels``;
            confirm against that helper.
        k: Number of predictions wanted. Clamped to ``[0, num_classes]``.

    Returns:
        A list of up to ``k`` label names, ordered from most to least likely.
    """
    name_tensor = transform(name)
    with torch.no_grad():
        # unsqueeze(0) adds a batch dimension of size 1.
        logits = model(name_tensor.unsqueeze(0))
    # Softmax is monotonic, so ranking the raw logits yields the same order;
    # topk also avoids fully sorting every class score.
    n_classes = logits.size(1)
    k = max(0, min(k, n_classes))
    # .tolist() works regardless of the tensor's device (unlike .numpy()).
    top_k_idx = torch.topk(logits[0], k).indices.tolist()
    return [label[str(idx)] for idx in top_k_idx]

def setup_inference():
    """Load the model and the label map, and prepare the model for inference.

    Returns:
        ``(model, labels)`` as returned by ``load_model()`` and
        ``load_labels()``. The model is switched to evaluation mode when it is
        an ``nn.Module``.
    """
    model = load_model()
    # A pickled module keeps the train/eval flag it had when it was saved;
    # force eval mode so layers such as dropout behave deterministically.
    if isinstance(model, nn.Module):
        model.eval()
    # Load the index -> language label mapping.
    labels = load_labels()
    return model, labels


if __name__ == "__main__":
    import sys

    model, labels = setup_inference()
    # Optional CLI usage: `python <script> NAME [NAME ...]` prints the top
    # predicted languages for each name. With no arguments this only checks
    # that the model and labels load.
    for arg in sys.argv[1:]:
        print(f"{arg}: {', '.join(infer_lang(arg, model, labels))}")