File size: 3,676 Bytes
83dd5ce
2c7e99c
 
1f7d212
3ecd782
1f7d212
 
2c7e99c
4087a58
1f7d212
 
3ecd782
bfa9c80
1f7d212
ed788b2
 
 
 
 
3ecd782
6d108e4
 
2666f2a
3ecd782
2666f2a
6d108e4
 
 
 
 
 
 
 
3ecd782
6d108e4
3ecd782
1f7d212
 
1d97607
1f7d212
 
 
 
 
 
 
 
 
 
1d97607
1f7d212
2387726
1f7d212
6d108e4
 
1d97607
1f7d212
 
1d97607
6d108e4
 
3ecd782
6d108e4
 
3ecd782
 
 
 
 
 
 
 
 
1f7d212
3ecd782
 
 
 
 
 
 
 
6d108e4
7d64d36
6d108e4
1f7d212
 
 
 
1d97607
1f7d212
a87bffc
 
1f7d212
 
1d97607
1f7d212
 
 
cf18899
1f7d212
cf18899
3062f33
1f7d212
1d97607
 
1f7d212
f7158ef
1f7d212
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
import onnxruntime as ort
import numpy as np
from PIL import Image
from torchvision import transforms
import io
import rdkit
import cv2
import torch
from rdkit import Chem
from rdkit.Chem import Draw
from postprocessor import RTDETRPostProcessor
from utils import bbox_to_graph_with_charge, mol_from_graph_with_chiral

# Detector class indices that denote bonds (13..17 in idx_to_labels below).
bond_labels = list(range(13, 18))

# Detector class index -> label. Index i is the i-th entry of this list.
# 0-12 are atom symbols ('other' and the '*' wildcard included). 13-17 are
# bond types; 'NONE' (13) presumably means a plain single bond (TODO confirm).
# 18-22 look like formal charges (NOTE(review): confirm against training labels).
idx_to_labels = dict(enumerate([
    'other', 'C', 'O', 'N', 'Cl', 'Br', 'S', 'F', 'B',
    'I', 'P', '*', 'Si',
    'NONE', 'BEGINWEDGE', 'BEGINDASH', '=', '#',
    '-4', '-2', '-1', '1', '+2',
]))


def image_to_numpy(image_path, size=640):
    """Turn a PIL image into the normalized CHW float32 array fed to the detector.

    Args:
        image_path: despite its name, a PIL ``Image`` object (its ``.size`` and
            pixel data are read), e.g. what ``gr.Image(type="pil")`` passes in.
        size: side length of the square the image is resized to (default 640).

    Returns:
        ``(array, w, h)``. ``array`` has shape ``(3, size, size)`` with values
        in ``[0, 1]``. ``w`` and ``h`` are the ORIGINAL width and height, which
        the caller uses to map detections back onto the input image.
    """
    image = image_path
    w, h = image.size

    # The array is transposed HWC -> CHW below, so it must have exactly three
    # channels. Grayscale/palette images would give a 2-D array and RGBA a
    # 4-channel one. Flatten anything non-RGB onto a white background instead:
    # a plain .convert("RGB") just drops alpha, which can leave transparent
    # pixels black and hide dark strokes.
    if image.mode != "RGB":
        rgba = image.convert("RGBA")
        white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        image = Image.alpha_composite(white, rgba).convert("RGB")

    img_array = np.array(image)
    # cv2 dsize is (width, height); both are `size` here.
    img_resized = cv2.resize(img_array, (size, size), interpolation=cv2.INTER_LINEAR)
    img_normalized = img_resized.astype(np.float32) / 255.0
    return img_normalized.transpose(2, 0, 1), w, h


def visualize_molecule(smiles):
    """Render a SMILES string to an image with RDKit.

    Args:
        smiles: SMILES string to draw.

    Returns:
        The image produced by ``Draw.MolToImage``. Returns ``None`` when
        ``smiles`` is empty, cannot be parsed, parses to a molecule with no
        atoms, or RDKit raises while parsing/drawing.
    """
    # An empty SMILES parses to an empty (non-None) Mol and would be drawn as
    # a blank picture; report it as "cannot render" instead.
    if not smiles:
        return None
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None or mol.GetNumAtoms() == 0:
            return None
        return Draw.MolToImage(mol)
    except Exception:
        # Treat any RDKit failure as "cannot render". Unlike the former bare
        # `except:`, this no longer swallows KeyboardInterrupt/SystemExit.
        return None

# ONNX Runtime sessions keyed by model path. Building a session loads the
# whole model, so do it once per process instead of once per request.
_SESSIONS = {}


def _get_session(model_path="model.onnx"):
    """Return the cached ``ort.InferenceSession`` for *model_path*, creating it on first use."""
    session = _SESSIONS.get(model_path)
    if session is None:
        session = ort.InferenceSession(model_path)
        _SESSIONS[model_path] = session
    return session


def _detect(input_image, score_threshold):
    """Run the detector and post-processing; keep detections scoring above *score_threshold*.

    Returns a dict of numpy arrays with keys ``'bbox'``, ``'bbox_centers'``,
    ``'scores'`` and ``'pred_classes'``, as consumed by
    ``bbox_to_graph_with_charge``.
    """
    session = _get_session()
    img_array, w, h = image_to_numpy(input_image)
    batch = np.expand_dims(img_array, 0)  # add the batch axis -> (1, C, H, W)

    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: batch})
    # NOTE(review): assumes output 0 is class logits and output 1 is boxes;
    # confirm against the ONNX export.
    preds = {
        'pred_logits': torch.from_numpy(outputs[0]),
        'pred_boxes': torch.from_numpy(outputs[1]),
    }
    orig_size = torch.tensor([[w, h]], dtype=torch.long)  # shape (1, 2): (w, h)

    # One class per entry of idx_to_labels (23 classes).
    postprocessor = RTDETRPostProcessor(num_classes=len(idx_to_labels), use_focal_loss=True)
    result = postprocessor(preds, orig_size)[0]

    keep = result['scores'] > score_threshold
    boxes = result['boxes'][keep]
    scores = result['scores'][keep]
    labels = result['labels'][keep]

    # Boxes are presumably (x1, y1, x2, y2); the centre is the corner midpoint.
    centers = torch.stack(
        ((boxes[:, 0] + boxes[:, 2]) / 2, (boxes[:, 1] + boxes[:, 3]) / 2), dim=1
    )
    return {
        'bbox': boxes.cpu().numpy(),
        'bbox_centers': centers.cpu().numpy(),
        'scores': scores.cpu().numpy(),
        'pred_classes': labels.cpu().numpy(),
    }


def predict(input_image, *, score_threshold=0.5):
    """Recognise the molecule in *input_image* and return ``(depiction, smiles)``.

    The tuple order matches the Gradio outputs ``[gr.Image, gr.Text]``. On
    failure the image slot is ``None`` and the text slot carries the message,
    so errors are visible in the SMILES box.

    Args:
        input_image: PIL image of a molecule drawing, or ``None`` when nothing
            was uploaded.
        score_threshold: detections scoring at or below this are discarded.

    Returns:
        ``(image, smiles)`` on success, otherwise ``(None, message)``.
    """
    if input_image is None:
        return None, "Please upload a molecular image."
    try:
        detections = _detect(input_image, score_threshold)
        if len(detections['pred_classes']) == 0:
            return None, "Error: no atoms or bonds detected"

        atoms_df, bonds_list = bbox_to_graph_with_charge(
            detections, idx_to_labels=idx_to_labels, bond_labels=bond_labels, result=[]
        )
        smiles, _mol = mol_from_graph_with_chiral(atoms_df, bonds_list)

        mol_image = visualize_molecule(smiles)
        if mol_image is None:
            return None, "Invalid SMILES"
        return mol_image, smiles
    except Exception as e:
        # UI boundary: report the failure in the text output instead of crashing
        # the request. The image slot must stay None, not a string.
        return None, f"Error: {e}"

# Gradio front end: one PIL image in; the molecule depiction and its SMILES out.
iface = gr.Interface(
    predict,
    gr.Image(label="Upload molecular image", type="pil", show_label=False),
    [gr.Image(label="Prediction"), gr.Text(label="SMILES")],
    title="OCSR",
    description="Convert a molecular image into SMILES.<br> ",
    examples=[["example.png"]],
)

if __name__ == "__main__":
    # Serve the UI only when executed as a script, not on import.
    iface.launch()