File size: 3,676 Bytes
83dd5ce 2c7e99c 1f7d212 3ecd782 1f7d212 2c7e99c 4087a58 1f7d212 3ecd782 bfa9c80 1f7d212 ed788b2 3ecd782 6d108e4 2666f2a 3ecd782 2666f2a 6d108e4 3ecd782 6d108e4 3ecd782 1f7d212 1d97607 1f7d212 1d97607 1f7d212 2387726 1f7d212 6d108e4 1d97607 1f7d212 1d97607 6d108e4 3ecd782 6d108e4 3ecd782 1f7d212 3ecd782 6d108e4 7d64d36 6d108e4 1f7d212 1d97607 1f7d212 a87bffc 1f7d212 1d97607 1f7d212 cf18899 1f7d212 cf18899 3062f33 1f7d212 1d97607 1f7d212 f7158ef 1f7d212 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import functools
import io
import logging

import gradio as gr
import onnxruntime as ort
import numpy as np
from PIL import Image
from torchvision import transforms
import rdkit
import cv2
import torch
from rdkit import Chem
from rdkit.Chem import Draw

from postprocessor import RTDETRPostProcessor
from utils import bbox_to_graph_with_charge, mol_from_graph_with_chiral
# Detector class ids that denote bonds (rather than atoms or charges); passed
# to bbox_to_graph_with_charge as ``bond_labels``.
bond_labels = list(range(13, 18))

# Human-readable name for each detector class id; position in the tuple is the
# class id (0..22).
#   0      'other'
#   1-12   element symbols, plus '*' (presumably a wildcard / attachment point)
#   13-17  bond classes (the ids listed in bond_labels): 'NONE', wedge, dash,
#          and SMILES-style '=' and '#'.
#          NOTE(review): 'NONE' (13) looks like it denotes a plain single
#          bond — confirm against the training label set.
#   18-22  presumably formal charges.
#          NOTE(review): '1' carries no '+' sign, unlike '+2'; kept verbatim
#          because the graph builder may parse these exact strings.
_CLASS_NAMES = (
    'other', 'C', 'O', 'N', 'Cl', 'Br', 'S', 'F', 'B',
    'I', 'P', '*', 'Si', 'NONE', 'BEGINWEDGE', 'BEGINDASH',
    '=', '#', '-4', '-2', '-1', '1', '+2',
)
idx_to_labels = dict(enumerate(_CLASS_NAMES))
def image_to_numpy(image_path):
    """Convert an uploaded image into the detector's input array.

    Args:
        image_path: A PIL image, despite the name — its ``.size`` is unpacked
            as ``(w, h)`` and the object itself is passed to ``np.array``.

    Returns:
        ``(array, w, h)``: ``array`` is float32 scaled by 1/255, resized to
        640x640 and transposed channels-first (``(3, 640, 640)`` for PIL
        input); ``w`` and ``h`` are the original image dimensions, which the
        caller hands to the post-processor as the original size.
    """
    image = image_path
    # Grayscale ('L'), palette ('P') or RGBA images would otherwise produce a
    # 2-D or 4-channel array. Normalise to 3-channel RGB first.
    # NOTE(review): assumes the ONNX model expects 3 input channels — confirm
    # against the model's input shape.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")
    w, h = image.size
    img_array = np.array(image)
    img_resized = cv2.resize(img_array, (640, 640), interpolation=cv2.INTER_LINEAR)
    img_normalized = img_resized.astype(np.float32) / 255.0
    # HWC -> CHW. Still guarded for objects without a PIL-style ``mode`` that
    # may yield a 2-D array.
    if img_normalized.ndim == 3:
        img_normalized = img_normalized.transpose(2, 0, 1)
    return img_normalized, w, h
def visualize_molecule(smiles):
    """Render a SMILES string as an image using RDKit.

    Args:
        smiles: SMILES string to draw.

    Returns:
        The image produced by ``Draw.MolToImage``, or ``None`` when *smiles*
        is empty/``None``, cannot be parsed, or drawing raises.
    """
    if not smiles:
        # Nothing was recognised; an empty string would at best render a blank
        # picture, so report it the same way as an unparsable SMILES.
        return None
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:  # parse failure
            return None
        return Draw.MolToImage(mol)
    except Exception:
        # Best-effort rendering: any RDKit error just means "no picture".
        # Exception (not a bare except) so KeyboardInterrupt/SystemExit
        # still propagate.
        return None
# Number of detector classes passed to the post-processor; must match the
# exported ONNX model (and the 0..22 keys of idx_to_labels).
_NUM_CLASSES = 23

# Detections scoring at or below this value are discarded.
_SCORE_THRESHOLD = 0.5

_logger = logging.getLogger(__name__)


@functools.lru_cache(maxsize=1)
def _get_session(model_path="model.onnx"):
    """Return a cached ONNX Runtime session for *model_path*.

    Building a session loads and optimises the model file, so it is created
    once and reused across requests instead of being rebuilt per call.
    """
    return ort.InferenceSession(model_path)


def _detect(input_image):
    """Run the detector on *input_image* and return the kept detections.

    Returns the dict handed to ``bbox_to_graph_with_charge``: numpy arrays
    under ``'bbox'``, ``'bbox_centers'``, ``'scores'`` and ``'pred_classes'``,
    restricted to detections scoring above ``_SCORE_THRESHOLD``.
    """
    session = _get_session()
    img_array, w, h = image_to_numpy(input_image)
    batch = np.expand_dims(img_array, 0)  # add the leading batch axis
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: batch})
    # NOTE(review): assumes output 0 is the class logits and output 1 the
    # boxes — confirm against the exported model's output order.
    preds = {
        "pred_logits": torch.from_numpy(outputs[0]),
        "pred_boxes": torch.from_numpy(outputs[1]),
    }
    # Original (w, h) of the upload, shape (1, 2).
    orig_size = torch.tensor([[w, h]], dtype=torch.long)
    postprocessor = RTDETRPostProcessor(num_classes=_NUM_CLASSES, use_focal_loss=True)
    result = postprocessor(preds, orig_size)[0]

    keep = result["scores"] > _SCORE_THRESHOLD
    boxes = result["boxes"][keep]
    scores = result["scores"][keep]
    labels = result["labels"][keep]

    # Box centres. Assumes boxes are (x0, y0, x1, y1) — TODO confirm against
    # RTDETRPostProcessor's output format.
    centers = torch.stack(
        ((boxes[:, 0] + boxes[:, 2]) / 2, (boxes[:, 1] + boxes[:, 3]) / 2),
        dim=1,
    )
    return {
        "bbox": boxes.cpu().numpy(),
        "bbox_centers": centers.cpu().numpy(),
        "scores": scores.cpu().numpy(),
        "pred_classes": labels.cpu().numpy(),
    }


def predict(input_image):
    """Recognise the molecule in *input_image*.

    Args:
        input_image: PIL image supplied by the Gradio image input.

    Returns:
        ``(image, text)`` matching the Interface outputs: the re-rendered
        molecule (or ``None``) followed by the SMILES string, or an error
        message when recognition fails.
    """
    if input_image is None:
        # gr.Image may deliver None when nothing was uploaded.
        return None, "Error: no image provided"
    try:
        detections = _detect(input_image)
        atoms_df, bonds_list = bbox_to_graph_with_charge(
            detections,
            idx_to_labels=idx_to_labels,
            bond_labels=bond_labels,
            result=[],
        )
        smiles, _mol = mol_from_graph_with_chiral(atoms_df, bonds_list)
    except Exception as e:  # UI boundary: report instead of crashing the app
        _logger.exception("OCSR prediction failed")
        return None, f"Error: {e}"

    mol_image = visualize_molecule(smiles)
    if mol_image is None:
        # Messages go in the second (text) slot; the first slot feeds gr.Image.
        return None, "Invalid SMILES"
    return mol_image, smiles
# Gradio UI: a single PIL image input; outputs are the re-rendered molecule and
# its SMILES text, so predict() must return (image, text) in that order.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(label="Upload molecular image", type="pil", show_label=False),
    outputs=[
        gr.Image(label="Prediction"),
        gr.Text(label="SMILES"),
    ],
    title="OCSR",
    description="Convert a molecular image into SMILES.<br> ",
    # Presumably a sample image shipped alongside this script.
    examples=[
        ["example.png"],
    ],
)

if __name__ == "__main__":
    iface.launch()