import torch
import gradio as gr

from utils import create_vocab, setup_seed
from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab

# Fix the random seed for reproducible generation.
setup_seed(4)

# Run inference on CPU.
device = torch.device("cpu")

# Build the MLM vocabulary and extend it with the task-specific tokens.
vocab_mlm = create_vocab()
vocab_mlm = add_tokens_to_vocab(vocab_mlm)

# Load the trained masked-language-model checkpoint onto the CPU.
save_path = 'mlm-model-27.pt'
model = torch.load(save_path, map_location=device)
model = model.to(device)

def CTXGen(X1, X2, X3, top_k):
    """Fill the masked ("X") fields of an "X1|X2|X3|||" record with the MLM and
    return the (given or predicted) subtype, potency, and top-k candidate tokens."""
    # Gradio passes top_k from the UI; torch.topk needs a plain int.
    top_k = int(top_k)

    predicted_token_probability_all = []
    model.eval()
    topk = []

    with torch.no_grad():
        new_seq = None
        # Assemble the input as a single "X1|X2|X3|||" record.
        seq = [f"{X1}|{X2}|{X3}|||"]
        # Make sure the placeholder token "X" is in the vocabulary (index 4).
        vocab_mlm.token_to_idx["X"] = 4

        padded_seq, _, idx_msa, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
        idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)

        # Every "X" in the padded sequence is a position the model must predict.
        mask_positions = [i for i, token in enumerate(padded_seq) if token == "X"]
        if not mask_positions:
            raise ValueError("Nothing found in the sequence to predict.")

        # Fill the masked positions one at a time, left to right.
        for mask_position in mask_positions:
            padded_seq[mask_position] = "[MASK]"
            input_ids = vocab_mlm[padded_seq]
            input_ids = torch.tensor([input_ids]).to(device)

            logits = model(input_ids, idx_msa)
            mask_logits = logits[0, mask_position, :]

            # Top-k candidate tokens and their probabilities at this position.
            predicted_token_probability, predicted_token_id = torch.topk(torch.softmax(mask_logits, dim=-1), k=top_k)
            topk.append(predicted_token_id)

            # Keep the most likely token and write it back into the sequence
            # before moving on to the next masked position.
            predicted_token = vocab_mlm.idx_to_token[predicted_token_id[0].item()]
            predicted_token_probability_all.append(predicted_token_probability[0].item())
            padded_seq[mask_position] = predicted_token

    # Top-k candidates for the first masked position, decoded back to tokens.
    cls_pos = vocab_mlm.to_tokens(list(topk[0]))
    Topk = cls_pos

    # Report the user-supplied value where it was given, otherwise the predicted
    # token together with its probability.
    if X1 != "X":
        Subtype = X1
        Potency = padded_seq[2], predicted_token_probability_all[0]
    elif X2 != "X":
        Subtype = padded_seq[1], predicted_token_probability_all[0]
        Potency = X2
    else:
        Subtype = padded_seq[1], predicted_token_probability_all[0]
        Potency = padded_seq[2], predicted_token_probability_all[1]

    return Subtype, Potency, Topk
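
# A minimal sketch of calling CTXGen directly, outside the Gradio UI. For illustration
# only: the sequence below is a made-up placeholder, and "X" marks the fields
# (X1 = subtype, X2 = potency) that the model should predict.
#
#   Subtype, Potency, Topk = CTXGen("X", "X", "GCCSDPRCAWRC", top_k=3)
#   print(Subtype, Potency, Topk)
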
# Gradio UI: three text inputs (subtype, potency, sequence; use "X" for any field
# the model should predict) plus a numeric top-k.
iface = gr.Interface(fn=CTXGen,
                     inputs=["text", "text", "text", "number"],
                     outputs=["text", "text", "text"])

iface.launch()