File size: 2,761 Bytes
1f7d212
 
 
 
 
 
 
 
 
 
 
 
 
 
b8b9877
1f7d212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d64d36
 
 
 
1f7d212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
import onnxruntime as ort
import numpy as np
from PIL import Image
import io
import rdkit
from rdkit import Chem
from rdkit.Chem import Draw

def preprocess_image(image):
    """
    预处理输入图片
    """
    # 将图片调整为模型所需的输入尺寸
    image = image.resize((640, 640))  # 根据实际模型需求调整尺寸
    # 转换为numpy数组并归一化
    img_array = np.array(image)
    img_array = img_array.astype(np.float32) / 255.0
    # 添加批次维度
    img_array = np.expand_dims(img_array, axis=0)
    # 根据模型训练时的预处理方式进行调整
    img_array = img_array.transpose(0, 3, 1, 2)  # BHWC to BCHW
    return img_array

def visualize_molecule(smiles):
    """
    使用RDKit将SMILES转换为分子结构图
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        img = Draw.MolToImage(mol)
        return img
    except:
        return None

def predict(input_image):
    """
    主要的推理函数
    """
    try:
        # 加载和初始化ONNX模型
        session = ort.InferenceSession("model.onnx")  # 替换为实际模型路径
        
        # 预处理图片
        processed_image = preprocess_image(input_image)
        
        # 获取模型输入输出名称
        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name
        
        # 进行推理
        predictions = session.run([output_name], {input_name: processed_image})
        
        # 假设模型输出是SMILES字符串
        output = predictions[0]  # 根据实际模型输出格式调整
        atoms_df, bonds_list,charge_list =bbox_to_graph_with_charge(output, idx_to_labels=idx_to_labels,
                                                    bond_labels=bond_labels,  result=[])
        smiles,mol_rebuit=mol_from_graph_with_chiral(atoms_df, bonds_list,charge_list )
        
        # 使用RDKit生成分子结构图
        mol_image = visualize_molecule(smiles)
        
        if mol_image is None:
            return "无效的SMILES字符串", None
        
        return smiles, mol_image
        
    except Exception as e:
        return f"发生错误: {str(e)}", None

# 创建Gradio界面
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Text(label="SMILES字符串"),
        gr.Image(label="分子结构图")
    ],
    title="化学结构OCR",
    description="上传一张包含化学结构的图片,获取对应的SMILES表示和分子结构图。",
    examples=[
        ["example1.jpg"],  # 添加示例图片
        ["example2.jpg"]
    ]
)

# 启动应用
if __name__ == "__main__":
    iface.launch()