import torch
import gradio as gr

from utils import create_vocab, setup_seed
from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab

setup_seed(4)

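
# CTXGen: fill in the masked fields of a conotoxin record with the masked-language model.
# X1 is the subtype token, X2 the potency token and X3 the conotoxin sequence; any of them
# may contain the placeholder "X", which marks the positions to be predicted. model_name
# selects the checkpoint to load. Returns the subtype, the potency and the top-5 candidate
# tokens for the first masked position (or the given values when a field was supplied).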
def CTXGen(X1, X2, X3, model_name):
    device = torch.device("cpu")

    # Build the vocabulary and add the extra tokens used by the model
    vocab_mlm = create_vocab()
    vocab_mlm = add_tokens_to_vocab(vocab_mlm)

    # Load the selected checkpoint on CPU
    save_path = model_name
    model = torch.load(save_path, weights_only=False, map_location=torch.device('cpu'))
    model = model.to(device)

    predicted_token_probability_all = []
    model.eval()
    topk = []
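    # Run inference without gradient tracking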
    with torch.no_grad():
        new_seq = None
        # The model input is "subtype|potency|sequence|||"
        seq = [f"{X1}|{X2}|{X3}|||"]
        # Register the placeholder "X" in the vocabulary (reusing index 4)
        vocab_mlm.token_to_idx["X"] = 4
        padded_seq, _, idx_msa, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
        idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)

        # Every "X" token marks a position the model must predict
        mask_positions = [i for i, token in enumerate(padded_seq) if token == "X"]
        if not mask_positions:
            raise ValueError("No 'X' placeholders found in the input; there is nothing to predict.")

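        # Fill the masked positions one at a time, keeping the most probable token at each step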
        for mask_position in mask_positions:
            padded_seq[mask_position] = "[MASK]"
            input_ids = vocab_mlm[padded_seq]
            input_ids = torch.tensor([input_ids]).to(device)
            logits = model(input_ids, idx_msa)
            mask_logits = logits[0, mask_position, :]
            # Top-5 candidate tokens and their probabilities at this position
            predicted_token_probability, predicted_token_id = torch.topk(torch.softmax(mask_logits, dim=-1), k=5)
            topk.append(predicted_token_id)
            predicted_token = vocab_mlm.idx_to_token[predicted_token_id[0].item()]
            predicted_token_probability_all.append(predicted_token_probability[0].item())
            padded_seq[mask_position] = predicted_token

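        # Top-5 candidate tokens predicted for the first masked position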
        cls_pos = vocab_mlm.to_tokens(topk[0].tolist())

        # Assemble the outputs depending on which fields the user supplied
        if X1 != "X":
            Topk = X1
            Subtype = X1
            Potency = padded_seq[2], predicted_token_probability_all[0]
        elif X2 != "X":
            Topk = cls_pos
            Subtype = padded_seq[1], predicted_token_probability_all[0]
            Potency = X2
        else:
            Topk = cls_pos
            Subtype = padded_seq[1], predicted_token_probability_all[0]
            Potency = padded_seq[2], predicted_token_probability_all[1]
    return Subtype, Potency, Topk

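
# Gradio UI: dropdowns for the subtype and potency tokens (or "X" to let the model
# predict them), a textbox for the conotoxin sequence, and a dropdown for the checkpoint.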
iface = gr.Interface(
    fn=CTXGen,
    inputs=[
        gr.Dropdown(choices=['X', '<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
                             '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
                             '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
                             '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>', '<α6β4>', '<α2β4>',
                             '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
                             '<α9α10>', '<α6α3β4>', '<NaTTXS>', '<Na17>'], label="Subtype"),
        gr.Dropdown(choices=['X', '<high>', '<low>'], label="Potency"),
        gr.Textbox(label="Conotoxin"),
        gr.Dropdown(choices=['model_final.pt', 'model_C1.pt', 'model_C2.pt', 'model_C3.pt',
                             'model_C4.pt', 'model_C5.pt', 'model_mlm.pt'], label="Model")
    ],
    outputs=[
        gr.Textbox(label="Subtype"),
        gr.Textbox(label="Potency"),
        gr.Textbox(label="Top5")
    ]
)

# Launch the Gradio app (serves a local web UI by default)
iface.launch()