jibsn commited on
Commit
1d97607
·
verified ·
1 Parent(s): 8b15fa0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -28
app.py CHANGED
@@ -20,7 +20,6 @@ idx_to_labels = {0:'other',1:'C',2:'O',3:'N',4:'Cl',5:'Br',6:'S',7:'F',8:'B',
20
 
21
  def image_to_numpy(image_path):
22
 
23
- # image = Image.open(image_path)
24
  w, h = image_path.size
25
 
26
  img_array = np.array(image_path)
@@ -37,9 +36,7 @@ def image_to_numpy(image_path):
37
 
38
 
39
  def visualize_molecule(smiles):
40
- """
41
- 使用RDKit将SMILES转换为分子结构图
42
- """
43
  try:
44
  mol = Chem.MolFromSmiles(smiles)
45
  if mol is None:
@@ -50,23 +47,16 @@ def visualize_molecule(smiles):
50
  return None
51
 
52
  def predict(input_image):
53
- """
54
- 主要的推理函数
55
- """
56
  try:
57
- # 加载和初始化ONNX模型
58
  session = ort.InferenceSession("model.onnx")
59
 
60
- # 预处理图片
61
- # Example usage: #change thie image
62
  img_array,w,h = image_to_numpy(input_image)
63
  processed_image=np.expand_dims(img_array, 0)
64
-
65
- # 获取模型输入输出名称
66
  input_name = session.get_inputs()[0].name
67
  output_name = session.get_outputs()[0].name
68
-
69
- # 进行推理
70
  outputs = session.run(None, {input_name: processed_image})
71
  preds = {'pred_logits':torch.from_numpy(outputs[0]), 'pred_boxes':torch.from_numpy(outputs[1])}
72
  ori_size=torch.Tensor([w,h]).long().unsqueeze(0)
@@ -82,10 +72,6 @@ def predict(input_image):
82
  'scores': score_[selected_indices]
83
  }
84
 
85
- # filtered_output_dict={image_path: output
86
- # }
87
-
88
-
89
  x_center = (output["boxes"][:, 0] + output["boxes"][:, 2]) / 2
90
  y_center = (output["boxes"][:, 1] + output["boxes"][:, 3]) / 2
91
  center_coords = torch.stack((x_center, y_center), dim=1)
@@ -98,32 +84,30 @@ def predict(input_image):
98
  bond_labels=bond_labels, result=[])
99
  smiles, mol_rebuit = mol_from_graph_with_chiral(atoms_df, bonds_list)
100
 
101
- # 使用RDKit生成分子结构图
102
  mol_image = visualize_molecule(smiles)
103
 
104
  if mol_image is None:
105
- return "无效的SMILES字符串", None
106
 
107
  return smiles, mol_image
108
 
109
  except Exception as e:
110
- return f"发生错误: {str(e)}", None
111
 
112
- # 创建Gradio界面
113
  iface = gr.Interface(
114
  fn=predict,
115
- inputs=gr.Image(type="pil"),
116
  outputs=[
117
- gr.Text(label="SMILES字符串"),
118
- gr.Image(label="分子结构图")
 
119
  ],
120
- title="化学结构OCR",
121
- description="上传一张包含化学结构的图片,获取对应的SMILES表示和分子结构图。",
122
  examples=[
123
  ["example.png"]
124
  ]
125
  )
126
 
127
- # 启动应用
128
  if __name__ == "__main__":
129
  iface.launch()
 
20
 
21
  def image_to_numpy(image_path):
22
 
 
23
  w, h = image_path.size
24
 
25
  img_array = np.array(image_path)
 
36
 
37
 
38
  def visualize_molecule(smiles):
39
+
 
 
40
  try:
41
  mol = Chem.MolFromSmiles(smiles)
42
  if mol is None:
 
47
  return None
48
 
49
  def predict(input_image):
50
+
 
 
51
  try:
 
52
  session = ort.InferenceSession("model.onnx")
53
 
 
 
54
  img_array,w,h = image_to_numpy(input_image)
55
  processed_image=np.expand_dims(img_array, 0)
56
+
 
57
  input_name = session.get_inputs()[0].name
58
  output_name = session.get_outputs()[0].name
59
+
 
60
  outputs = session.run(None, {input_name: processed_image})
61
  preds = {'pred_logits':torch.from_numpy(outputs[0]), 'pred_boxes':torch.from_numpy(outputs[1])}
62
  ori_size=torch.Tensor([w,h]).long().unsqueeze(0)
 
72
  'scores': score_[selected_indices]
73
  }
74
 
 
 
 
 
75
  x_center = (output["boxes"][:, 0] + output["boxes"][:, 2]) / 2
76
  y_center = (output["boxes"][:, 1] + output["boxes"][:, 3]) / 2
77
  center_coords = torch.stack((x_center, y_center), dim=1)
 
84
  bond_labels=bond_labels, result=[])
85
  smiles, mol_rebuit = mol_from_graph_with_chiral(atoms_df, bonds_list)
86
 
 
87
  mol_image = visualize_molecule(smiles)
88
 
89
  if mol_image is None:
90
+ return "Invalid SMILES", None
91
 
92
  return smiles, mol_image
93
 
94
  except Exception as e:
95
+ return f"Error: {str(e)}", None
96
 
 
97
  iface = gr.Interface(
98
  fn=predict,
99
+ inputs=gr.Image(label="Upload molecular image", type="pil", show_label=False).style(height=256),
100
  outputs=[
101
+ gr.Image(label="Prediction").style(height=256),
102
+ gr.Text(label="SMILES").style(show_copy_button=True),
103
+ gr.Textbox(label="Molfile").style(show_copy_button=True),
104
  ],
105
+ title="OCSR",
106
+ description="Convert a molecular image into SMILES.<br> ",
107
  examples=[
108
  ["example.png"]
109
  ]
110
  )
111
 
 
112
  if __name__ == "__main__":
113
  iface.launch()