jibsn commited on
Commit
1f7d212
·
verified ·
1 Parent(s): 2120316

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -0
app.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import onnxruntime as ort
3
+ import numpy as np
4
+ from PIL import Image
5
+ import io
6
+ import rdkit
7
+ from rdkit import Chem
8
+ from rdkit.Chem import Draw
9
+
10
+ def preprocess_image(image):
11
+ """
12
+ 预处理输入图片
13
+ """
14
+ # 将图片调整为模型所需的输入尺寸
15
+ image = image.resize((224, 224)) # 根据实际模型需求调整尺寸
16
+ # 转换为numpy数组并归一化
17
+ img_array = np.array(image)
18
+ img_array = img_array.astype(np.float32) / 255.0
19
+ # 添加批次维度
20
+ img_array = np.expand_dims(img_array, axis=0)
21
+ # 根据模型训练时的预处理方式进行调整
22
+ img_array = img_array.transpose(0, 3, 1, 2) # BHWC to BCHW
23
+ return img_array
24
+
25
+ def visualize_molecule(smiles):
26
+ """
27
+ 使用RDKit将SMILES转换为分子结构图
28
+ """
29
+ try:
30
+ mol = Chem.MolFromSmiles(smiles)
31
+ if mol is None:
32
+ return None
33
+ img = Draw.MolToImage(mol)
34
+ return img
35
+ except:
36
+ return None
37
+
38
+ def predict(input_image):
39
+ """
40
+ 主要的推理函数
41
+ """
42
+ try:
43
+ # 加载和初始化ONNX模型
44
+ session = ort.InferenceSession("model.onnx") # 替换为实际模型路径
45
+
46
+ # 预处理图片
47
+ processed_image = preprocess_image(input_image)
48
+
49
+ # 获取模型输入输出名称
50
+ input_name = session.get_inputs()[0].name
51
+ output_name = session.get_outputs()[0].name
52
+
53
+ # 进行推理
54
+ predictions = session.run([output_name], {input_name: processed_image})
55
+
56
+ # 假设模型输出是SMILES字符串
57
+ smiles = predictions[0][0] # 根据实际模型输出格式调整
58
+
59
+ # 使用RDKit生成分子结构图
60
+ mol_image = visualize_molecule(smiles)
61
+
62
+ if mol_image is None:
63
+ return "无效的SMILES字符串", None
64
+
65
+ return smiles, mol_image
66
+
67
+ except Exception as e:
68
+ return f"发生错误: {str(e)}", None
69
+
70
+ # 创建Gradio界面
71
+ iface = gr.Interface(
72
+ fn=predict,
73
+ inputs=gr.Image(type="pil"),
74
+ outputs=[
75
+ gr.Text(label="SMILES字符串"),
76
+ gr.Image(label="分子结构图")
77
+ ],
78
+ title="化学结构OCR",
79
+ description="上传一张包含化学结构的图片,获取对应的SMILES表示和分子结构图。",
80
+ examples=[
81
+ ["example1.jpg"], # 添加示例图片
82
+ ["example2.jpg"]
83
+ ]
84
+ )
85
+
86
+ # 启动应用
87
+ if __name__ == "__main__":
88
+ iface.launch()