# OCSR / app.py
# Last commit 3062f33 "Update app.py" by jibsn (verified).
# Hugging Face file-view residue: raw | history | blame | 3.64 kB
import functools
import io
import logging

import gradio as gr
import onnxruntime as ort
import numpy as np
from PIL import Image
from torchvision import transforms
import rdkit
import cv2
import torch
from rdkit import Chem
from rdkit.Chem import Draw

from postprocessor import RTDETRPostProcessor
from utils import bbox_to_graph_with_charge, mol_from_graph_with_chiral
# Detector class ids treated as bonds when building the molecular graph:
# 13 'NONE', 14 'BEGINWEDGE', 15 'BEGINDASH', 16 '=', 17 '#'.
bond_labels = list(range(13, 18))

# Detector class id -> label string, ids 0..22 in order.
# NOTE(review): the original author was unsure whether 'NONE' denotes a single
# bond ("NONE is single ?") — TODO confirm against the training labels.
_CLASS_NAMES = [
    # 0-12: atom-like classes ('*' presumably a wildcard / R-group atom)
    'other', 'C', 'O', 'N', 'Cl', 'Br', 'S', 'F', 'B', 'I', 'P', '*', 'Si',
    # 13-17: bond classes (see bond_labels above)
    'NONE', 'BEGINWEDGE', 'BEGINDASH', '=', '#',
    # 18-22: presumably formal-charge annotations
    '-4', '-2', '-1', '1', '+2',
]
idx_to_labels = dict(enumerate(_CLASS_NAMES))
def image_to_numpy(image_path):
    """Convert a PIL image into the detector's input array.

    Args:
        image_path: A PIL image (despite the name, not a file path); its
            ``.size`` is read and it is converted with ``np.array``.

    Returns:
        ``(array, w, h)``: ``array`` is float32, scaled to [0, 1], resized to
        640x640 and laid out channels-first as ``(3, 640, 640)``; ``w`` and
        ``h`` are the original width and height in pixels.
    """
    image = image_path
    # Every image must reach the model with the same 3-channel layout. Without
    # this, "L"/"P" images give a 2-D array (a different input rank after
    # batching) and RGBA images give 4 channels.
    if image.mode != "RGB":
        if image.mode in ("RGBA", "LA") or "transparency" in image.info:
            # Flatten transparency onto white. A plain convert("RGB") would show
            # whatever colour is stored under transparent pixels (frequently
            # black), hiding dark line drawings.
            rgba = image.convert("RGBA")
            canvas = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
            image = Image.alpha_composite(canvas, rgba)
        image = image.convert("RGB")

    w, h = image.size
    img_array = np.array(image)
    # Squash to the fixed 640x640 input size (aspect ratio is not preserved).
    img_resized = cv2.resize(img_array, (640, 640), interpolation=cv2.INTER_LINEAR)
    img_normalized = img_resized.astype(np.float32) / 255.0
    if img_normalized.ndim == 3:
        # HWC -> CHW
        img_normalized = img_normalized.transpose(2, 0, 1)
    return img_normalized, w, h
def visualize_molecule(smiles):
    """Render a SMILES string as a 2-D depiction with RDKit.

    Args:
        smiles: SMILES string to parse.

    Returns:
        A PIL image from ``Draw.MolToImage``, or ``None`` if the string is
        empty/None, cannot be parsed, yields a molecule with no atoms, or
        drawing fails.
    """
    # Nothing detected -> empty SMILES; report it as invalid rather than
    # rendering a blank image.
    if not smiles:
        return None
    try:
        mol = Chem.MolFromSmiles(smiles)
        # MolFromSmiles signals unparsable input by returning None.
        if mol is None or mol.GetNumAtoms() == 0:
            return None
        return Draw.MolToImage(mol)
    except Exception:
        # Best-effort rendering: RDKit errors become "no image". Unlike a bare
        # except, KeyboardInterrupt/SystemExit still propagate.
        return None
@functools.lru_cache(maxsize=1)
def _get_session():
    """Load the ONNX detector once and reuse it for every request."""
    return ort.InferenceSession("model.onnx")


def _detect(input_image, score_threshold=0.5):
    """Run the detector on a PIL image and keep detections above the threshold.

    Returns a dict of numpy arrays in the layout expected by
    ``bbox_to_graph_with_charge``:
    ``bbox`` (x1, y1, x2, y2 per row), ``bbox_centers`` (cx, cy), ``scores``
    and ``pred_classes``. Boxes are presumably in original-image pixels, since
    the post-processor receives the original (w, h) — TODO confirm.
    """
    session = _get_session()
    img_array, w, h = image_to_numpy(input_image)
    batch = np.expand_dims(img_array, 0)
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: batch})
    preds = {
        'pred_logits': torch.from_numpy(outputs[0]),
        'pred_boxes': torch.from_numpy(outputs[1]),
    }
    orig_size = torch.tensor([[w, h]], dtype=torch.long)
    # One class per entry of idx_to_labels (23 ids).
    postprocessor = RTDETRPostProcessor(num_classes=len(idx_to_labels), use_focal_loss=True)
    result = postprocessor(preds, orig_size)[0]

    keep = result['scores'] > score_threshold
    boxes = result['boxes'][keep]
    centers = torch.stack(
        ((boxes[:, 0] + boxes[:, 2]) / 2, (boxes[:, 1] + boxes[:, 3]) / 2), dim=1
    )
    return {
        'bbox': boxes.cpu().numpy(),
        'bbox_centers': centers.cpu().numpy(),
        'scores': result['scores'][keep].cpu().numpy(),
        'pred_classes': result['labels'][keep].cpu().numpy(),
    }


def predict(input_image):
    """Recognise the molecule in an uploaded image.

    Gradio callback: the two return values fill the "Prediction" image output
    and the "SMILES" text output, in that order.

    Args:
        input_image: PIL image from the upload widget, or None if empty.

    Returns:
        ``(mol_image, smiles)`` on success. On failure the image slot is
        ``None`` and the text slot carries the message ("Invalid SMILES" or
        "Error: ..."), so the message is shown in the SMILES box.
    """
    if input_image is None:
        return None, "Error: no image provided"
    try:
        detections = _detect(input_image)
        atoms_df, bonds_list = bbox_to_graph_with_charge(detections, idx_to_labels=idx_to_labels,
                                                         bond_labels=bond_labels, result=[])
        smiles, _rebuilt_mol = mol_from_graph_with_chiral(atoms_df, bonds_list)
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing,
        # but keep the traceback in the server log.
        logging.getLogger(__name__).exception("OCSR prediction failed")
        return None, f"Error: {str(e)}"

    mol_image = visualize_molecule(smiles)
    if mol_image is None:
        return None, "Invalid SMILES"
    return mol_image, smiles
# Gradio front end. Outputs are positional: predict's first return value fills
# "Prediction" and its second fills "SMILES".
# The input widget is created first, matching the original creation order.
_upload_widget = gr.Image(label="Upload molecular image", type="pil", show_label=False)
_result_widgets = [gr.Image(label="Prediction"), gr.Text(label="SMILES")]

iface = gr.Interface(
    fn=predict,
    inputs=_upload_widget,
    outputs=_result_widgets,
    title="OCSR",
    description="Convert a molecular image into SMILES.<br> ",
    examples=[["example.png"]],
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    iface.launch()