import cv2
import gradio as gr
import numpy as np
import onnxruntime as ort
import torch
from rdkit import Chem
from rdkit.Chem import Draw

from postprocessor import RTDETRPostProcessor
from utils import bbox_to_graph_with_charge, mol_from_graph_with_chiral

# Detector class indices that represent bonds; all other classes are atoms or charges.
bond_labels = [13, 14, 15, 16, 17]

idx_to_labels = {
    0: 'other', 1: 'C', 2: 'O', 3: 'N', 4: 'Cl', 5: 'Br', 6: 'S', 7: 'F', 8: 'B',
    9: 'I', 10: 'P', 11: '*', 12: 'Si',
    13: 'NONE',  # NONE presumably denotes a plain single bond (unconfirmed)
    14: 'BEGINWEDGE', 15: 'BEGINDASH', 16: '=', 17: '#',
    18: '-4', 19: '-2', 20: '-1', 21: '1', 22: '+2',
}

# Load the ONNX model and build the post-processor once, instead of on every request.
session = ort.InferenceSession("model.onnx")
postprocessor = RTDETRPostProcessor(num_classes=23, use_focal_loss=True)


def image_to_numpy(image):
    """Resize a PIL image to the 640x640 model input.

    Returns a CHW float32 array scaled to [0, 1], plus the original width and height.
    """
    w, h = image.size
    # Force three channels so the CHW transpose below always applies
    # (a grayscale upload would otherwise stay 2-D).
    img_array = np.array(image.convert("RGB"))
    img_resized = cv2.resize(img_array, (640, 640), interpolation=cv2.INTER_LINEAR)
    img_normalized = img_resized.astype(np.float32) / 255.0
    img_normalized = img_normalized.transpose(2, 0, 1)  # HWC -> CHW
    return img_normalized, w, h


def visualize_molecule(smiles):
    """Render a SMILES string as a PIL image, or return None if RDKit cannot parse it."""
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        return Draw.MolToImage(mol)
    except Exception:
        return None


def predict(input_image):
    try:
        img_array, w, h = image_to_numpy(input_image)
        processed_image = np.expand_dims(img_array, 0)  # add batch dim -> (1, 3, 640, 640)

        input_name = session.get_inputs()[0].name
        outputs = session.run(None, {input_name: processed_image})
        preds = {
            'pred_logits': torch.from_numpy(outputs[0]),
            'pred_boxes': torch.from_numpy(outputs[1]),
        }

        # Original (w, h), shape (1, 2): lets the post-processor rescale boxes to the input image.
        ori_size = torch.tensor([[w, h]], dtype=torch.long)
        result = postprocessor(preds, ori_size)[0]
        scores, boxes, labels = result['scores'], result['boxes'], result['labels']

        # Keep only detections above the 0.5 confidence threshold.
        keep = scores > 0.5
        scores, boxes, labels = scores[keep], boxes[keep], labels[keep]

        # Box centres, assuming boxes are (x1, y1, x2, y2) corners.
        x_center = (boxes[:, 0] + boxes[:, 2]) / 2
        y_center = (boxes[:, 1] + boxes[:, 3]) / 2
        center_coords = torch.stack((x_center, y_center), dim=1)

        output = {
            'bbox': boxes.cpu().numpy(),
            'bbox_centers': center_coords.cpu().numpy(),
            'scores': scores.cpu().numpy(),
            'pred_classes': labels.cpu().numpy(),
        }

        atoms_df, bonds_list = bbox_to_graph_with_charge(
            output, idx_to_labels=idx_to_labels, bond_labels=bond_labels, result=[]
        )
        smiles, mol_rebuilt = mol_from_graph_with_chiral(atoms_df, bonds_list)

        mol_image = visualize_molecule(smiles)
        if mol_image is None:
            # The image output cannot display text, so report the failure in the SMILES box.
            return None, "Invalid SMILES"
        return mol_image, smiles
        # return mol_rebuilt, smiles
    except Exception as e:
        return None, f"Error: {e}"


iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(label="Upload molecular image", type="pil", show_label=False),
    outputs=[
        gr.Image(label="Prediction"),
        gr.Text(label="SMILES"),
    ],
    title="OCSR",
    description="Convert a molecular image into SMILES.",
    examples=[["example.png"]],
)

if __name__ == "__main__":
    iface.launch()