|
|
import functools
import io

import gradio as gr
import numpy as np
import onnxruntime as ort
import rdkit
from PIL import Image
from rdkit import Chem
from rdkit.Chem import Draw
|
|
|
|
|
def preprocess_image(image, size=(640, 640)):
    """Convert a PIL image into the ONNX model's input tensor.

    The image is converted to 3-channel RGB and resized to ``size``. It is
    then cast to float32 and scaled from [0, 255] to [0.0, 1.0]. Finally it
    is rearranged from HWC to NCHW with a leading batch dimension of 1.

    Args:
        image: A PIL image of any mode. Grayscale ("L"), bilevel ("1"),
            palette ("P") and RGBA inputs are converted to RGB first.
        size: Target ``(width, height)`` passed to ``image.resize``.
            Defaults to 640x640. NOTE(review): presumably the input size the
            ONNX model was exported with; confirm.

    Returns:
        A float32 numpy array of shape ``(1, 3, height, width)``.
    """
    # Without this, grayscale/palette images yield a 2-D array, which breaks
    # the 4-axis transpose below. For "P" mode the values would also be
    # palette indices rather than intensities. RGBA would yield 4 channels.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    image = image.resize(size)

    # (H, W, 3) uint8 -> float32 in [0, 1]
    img_array = np.array(image).astype(np.float32) / 255.0

    # Add the batch axis: (1, H, W, 3)
    img_array = np.expand_dims(img_array, axis=0)

    # Channels-last -> channels-first: (1, 3, H, W)
    img_array = img_array.transpose(0, 3, 1, 2)

    return img_array
|
|
|
|
|
def visualize_molecule(smiles):
    """Render a SMILES string as a 2D structure image with RDKit.

    Args:
        smiles: SMILES string to depict.

    Returns:
        The image produced by ``Draw.MolToImage``, or ``None`` in any of
        these cases: ``smiles`` is empty or ``None``, RDKit cannot parse it,
        or drawing fails.
    """
    # RDKit parses "" into an empty molecule rather than returning None.
    # That would render a blank picture, so treat empty input as invalid.
    if not smiles:
        return None

    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        return Draw.MolToImage(mol)
    except Exception:
        # Best-effort: any RDKit failure is reported to the caller as
        # "no image". Unlike a bare except, this still lets KeyboardInterrupt
        # and SystemExit propagate.
        return None
|
|
|
|
|
@functools.lru_cache(maxsize=None)
def _get_session(model_path="model.onnx"):
    """Load the ONNX model once per path and reuse it across requests."""
    # Failed loads raise and are not cached, so a later request retries.
    return ort.InferenceSession(model_path)


def predict(input_image):
    """Run chemical-structure OCR on an uploaded image.

    The image is preprocessed and passed through the ONNX model. The
    model's first output is converted into an atom/bond graph, and the
    graph is rebuilt into a molecule whose SMILES is drawn with RDKit.

    Args:
        input_image: The PIL image from the Gradio input, or ``None`` when
            nothing was uploaded.

    Returns:
        A ``(text, image)`` pair for the two Gradio outputs. On success this
        is the SMILES string and its depiction. Otherwise it is an error
        message (in Chinese) and ``None``.
    """
    # Gradio passes None when the user submits without uploading an image.
    if input_image is None:
        return "请上传一张图片", None

    try:
        # Reuse a cached session instead of reloading the model per request.
        session = _get_session()

        processed_image = preprocess_image(input_image)

        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name
        output = session.run([output_name], {input_name: processed_image})[0]

        # NOTE(review): bbox_to_graph_with_charge, mol_from_graph_with_chiral,
        # idx_to_labels and bond_labels are not imported or defined in the
        # visible source. Confirm they are brought into scope; otherwise
        # every request fails with a NameError reported by the handler below.
        atoms_df, bonds_list, charge_list = bbox_to_graph_with_charge(
            output,
            idx_to_labels=idx_to_labels,
            bond_labels=bond_labels,
            result=[],
        )
        # The second return value (the rebuilt molecule) is not needed here.
        smiles, _ = mol_from_graph_with_chiral(atoms_df, bonds_list, charge_list)

        mol_image = visualize_molecule(smiles)
        if mol_image is None:
            return "无效的SMILES字符串", None

        return smiles, mol_image

    except Exception as e:
        # Top-level UI boundary: surface the error text instead of crashing.
        return f"发生错误: {e}", None
|
|
|
|
|
|
|
|
# Gradio UI: one image upload in; SMILES text and a rendered structure out.
# Components are created in the same order they are passed to Interface.
_image_input = gr.Image(type="pil")
_result_outputs = [
    gr.Text(label="SMILES字符串"),
    gr.Image(label="分子结构图"),
]
# Sample files shown under the input (paths relative to the working dir).
_example_inputs = [
    ["example1.jpg"],
    ["example2.jpg"],
]

iface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_result_outputs,
    title="化学结构OCR",
    description="上传一张包含化学结构的图片,获取对应的SMILES表示和分子结构图。",
    examples=_example_inputs,
)


if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    iface.launch()