|
import gradio as gr |
|
import onnxruntime as ort |
|
import numpy as np |
|
from PIL import Image |
|
from torchvision import transforms |
|
import io |
|
import rdkit |
|
import cv2 |
|
import torch |
|
from rdkit import Chem |
|
from rdkit.Chem import Draw |
|
from postprocessor import RTDETRPostProcessor |
|
from utils import bbox_to_graph_with_charge, mol_from_graph_with_chiral |
|
|
|
# Class indices handed to the graph builder as ``bond_labels``; they are the
# entries 13-17 of ``idx_to_labels`` below.
bond_labels = list(range(13, 18))

# Detector class index -> label string. The labels are listed in index
# order, so each label's position in the tuple is its class index.
idx_to_labels = dict(enumerate((
    'other', 'C', 'O', 'N', 'Cl', 'Br', 'S', 'F', 'B',   # 0-8
    'I', 'P', '*', 'Si',                                 # 9-12
    'NONE', 'BEGINWEDGE', 'BEGINDASH', '=', '#',         # 13-17 (the bond_labels range)
    '-4', '-2', '-1', '1', '+2',                         # 18-22 (presumably formal charges - confirm)
)))
|
|
|
|
|
def image_to_numpy(image_path):
    """Convert an image object into a 640x640 float32 array scaled to [0, 1].

    Despite its name, ``image_path`` is not a filesystem path: the function
    reads its ``.size`` attribute and passes it to ``np.array``.

    Args:
        image_path: Image object exposing ``.size`` and convertible with
            ``np.array`` (e.g. a PIL image).

    Returns:
        A tuple ``(array, w, h)``. ``w`` and ``h`` are the two values unpacked
        from ``image_path.size`` before any resizing. ``array`` is the
        resized image as float32 divided by 255; when it has three
        dimensions its last axis is moved to the front.
    """
    orig_w, orig_h = image_path.size

    # Resize to the fixed 640x640 size with bilinear interpolation.
    pixels = cv2.resize(
        np.array(image_path), (640, 640), interpolation=cv2.INTER_LINEAR
    )
    scaled = pixels.astype(np.float32) / 255.0

    # A 2-D array has no trailing channel axis to move, so it is returned
    # unchanged.
    # NOTE(review): such an array reaches the caller without a channel axis -
    # confirm callers only pass multi-channel images.
    if scaled.ndim != 3:
        return scaled, orig_w, orig_h

    # Move the last (channel) axis to the front.
    return scaled.transpose(2, 0, 1), orig_w, orig_h
|
|
|
|
|
def visualize_molecule(smiles):
    """Render a SMILES string as a molecule image.

    Args:
        smiles: SMILES string, parsed with ``Chem.MolFromSmiles``.

    Returns:
        The image produced by ``Draw.MolToImage``, or ``None`` when the
        SMILES cannot be parsed or when parsing/drawing raises.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        return Draw.MolToImage(mol)
    except Exception:
        # Best-effort rendering: any library failure maps to "no image".
        # ``Exception`` rather than a bare ``except`` so that
        # KeyboardInterrupt and SystemExit are not swallowed.
        return None
|
|
|
# Location of the ONNX model file and the minimum detection score that
# ``predict`` keeps.
_MODEL_PATH = "model.onnx"
_SCORE_THRESHOLD = 0.5

# Inference session shared by all calls to ``predict``; created on first use
# so the model file is not reloaded for every request.
_ORT_SESSION = None


def _get_session():
    """Return the shared ONNX Runtime session, creating it on first use."""
    global _ORT_SESSION
    if _ORT_SESSION is None:
        _ORT_SESSION = ort.InferenceSession(_MODEL_PATH)
    return _ORT_SESSION


def predict(input_image):
    """Run the detector on an image and rebuild the molecule as SMILES.

    The image is preprocessed with ``image_to_numpy``, passed through the
    ONNX model, post-processed with ``RTDETRPostProcessor``, filtered by
    score, and converted to a molecular graph and SMILES string by the
    ``utils`` helpers.

    Args:
        input_image: Image object accepted by ``image_to_numpy``, or
            ``None`` when no image was supplied.

    Returns:
        A tuple ``(image, text)`` matching the interface's two outputs
        (image first, text second). On success ``image`` is the rendered
        molecule and ``text`` is the SMILES string. On any failure
        ``image`` is ``None`` and ``text`` holds the message, so the
        message is shown in the text output instead of being handed to
        the image output.
    """
    if input_image is None:
        return None, "Error: no image provided"

    try:
        session = _get_session()

        img_array, w, h = image_to_numpy(input_image)
        # Add a leading batch axis of size 1.
        batch = np.expand_dims(img_array, 0)

        input_name = session.get_inputs()[0].name
        outputs = session.run(None, {input_name: batch})
        preds = {
            'pred_logits': torch.from_numpy(outputs[0]),
            'pred_boxes': torch.from_numpy(outputs[1]),
        }

        # Original (w, h) as a 1x2 integer tensor for the post-processor.
        ori_size = torch.Tensor([w, h]).long().unsqueeze(0)
        postprocessor = RTDETRPostProcessor(num_classes=23, use_focal_loss=True)
        result = postprocessor(preds, ori_size)[0]

        # Keep only detections scoring above the threshold.
        keep = result['scores'] > _SCORE_THRESHOLD
        boxes = result['boxes'][keep]
        scores = result['scores'][keep]
        labels = result['labels'][keep]

        # Box centres: midpoint of columns 0/2 (x) and columns 1/3 (y).
        x_center = (boxes[:, 0] + boxes[:, 2]) / 2
        y_center = (boxes[:, 1] + boxes[:, 3]) / 2
        center_coords = torch.stack((x_center, y_center), dim=1)

        detections = {
            'bbox': boxes.to("cpu").numpy(),
            'bbox_centers': center_coords.to("cpu").numpy(),
            'scores': scores.to("cpu").numpy(),
            'pred_classes': labels.to("cpu").numpy(),
        }

        atoms_df, bonds_list = bbox_to_graph_with_charge(
            detections,
            idx_to_labels=idx_to_labels,
            bond_labels=bond_labels,
            result=[],
        )
        # The second return value (the rebuilt molecule) is not needed here.
        smiles, _ = mol_from_graph_with_chiral(atoms_df, bonds_list)

        mol_image = visualize_molecule(smiles)
        if mol_image is None:
            return None, "Invalid SMILES"

        return mol_image, smiles

    except Exception as e:
        # Top-level UI boundary: report the failure in the text output
        # rather than raising.
        return None, f"Error: {e}"
|
|
|
# Gradio front end: one uploaded image (delivered to ``predict`` as a PIL
# image) in; an image component labelled "Prediction" and a text component
# labelled "SMILES" out, in that order.
iface = gr.Interface(
    title="OCSR",
    description="Convert a molecular image into SMILES.<br> ",
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload molecular image", show_label=False),
    outputs=[gr.Image(label="Prediction"), gr.Text(label="SMILES")],
    examples=[["example.png"]],
)


# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()