you_might_speak / z_inference.py
Deepak Sahu
training; app
2c1ff7f
raw
history blame contribute delete
847 Bytes
import torch
from z_modelops import NameToLanguages, load_labels
from z_dataops import transform
import json
from torch import nn
def load_model(location="model/rnn.pth", map_location=None):
    """Load a fully pickled model (architecture + weights) ready for inference.

    Args:
        location: path to a file written with ``torch.save(model, ...)``.
        map_location: forwarded to ``torch.load``; for example ``"cpu"`` loads
            a GPU-saved checkpoint on a CPU-only machine. ``None`` (default)
            keeps the devices recorded in the file.

    Returns:
        The unpickled model, switched to evaluation mode.
    """
    # weights_only=False unpickles arbitrary Python objects, which can execute
    # code: only load checkpoints from a trusted source.
    model = torch.load(location, map_location=map_location, weights_only=False)
    # A pickled module keeps the train/eval flag it had when saved; force eval
    # mode so mode-dependent layers (e.g. dropout) behave as at inference time.
    model.eval()
    return model
def infer_lang(name: str, model, label: dict, k=3) -> list:
    """Return the ``k`` most likely labels for ``name``, most likely first.

    Args:
        name: the input name; encoded with ``transform`` before being fed to
            the model.
        model: callable mapping a batch holding one encoded name to logits;
            assumed shape (1, num_classes) given the ``dim=1`` ranking and
            ``[0]`` indexing below — TODO confirm against the model.
        label: mapping from a class index written as a string (``"0"``,
            ``"1"``, ...) to its label.
        k: how many predictions to return; clamped to the number of classes,
            so a ``k`` larger than that returns every class.

    Returns:
        A list of at most ``k`` labels ordered from most to least likely.
    """
    name_tensor = transform(name)
    with torch.no_grad():
        # unsqueeze(0) adds a batch dimension of size 1 for this single name.
        logits = model(name_tensor.unsqueeze(0))
        # Softmax is monotonic, so ranking raw logits yields the same order
        # while avoiding probabilities underflowing into ties at 0.
        num_classes = logits.shape[1]
        top_k = max(0, min(k, num_classes))
        top_k_idx = torch.topk(logits, top_k, dim=1).indices[0]
    # .tolist() works regardless of device, unlike .numpy() on a CUDA tensor.
    return [label[str(idx)] for idx in top_k_idx.tolist()]
def setup_inference():
    """Load everything ``infer_lang`` needs: the model and its label mapping.

    Returns:
        A ``(model, labels)`` tuple, where ``model`` comes from
        ``load_model()`` with its default path and ``labels`` from
        ``load_labels()``. The model is loaded first, then the labels.
    """
    return load_model(), load_labels()
if __name__ == "__main__":
    # Usage: python z_inference.py [NAME ...]
    # Loads the model and labels; for each NAME given on the command line,
    # prints its top predicted labels. With no arguments it only loads.
    import sys

    model, labels = setup_inference()
    for query in sys.argv[1:]:
        print(f"{query}: {infer_lang(query, model, labels)}")