OCSR / app.py
jibsn's picture
Update app.py
7d64d36 verified
raw
history blame
2.76 kB
import functools
import io
import logging

import gradio as gr
import numpy as np
import onnxruntime as ort
import rdkit
from PIL import Image
from rdkit import Chem
from rdkit.Chem import Draw
def preprocess_image(image, size=(640, 640)):
    """Convert a PIL image into the model's input tensor.

    Args:
        image: PIL image in any mode (RGB, RGBA, L, P, ...).
        size: ``(width, height)`` passed to ``image.resize``. Defaults to
            640x640; adjust to whatever the exported model actually expects.

    Returns:
        A C-contiguous float32 array of shape ``(1, 3, height, width)``
        (batch, channel, height, width) with values scaled into ``[0, 1]``.
    """
    # Uploaded PNGs are often RGBA and scans may be grayscale ("L"). Without
    # this conversion the array would have 4 channels or no channel axis at
    # all, and the BHWC -> BCHW transpose below would either raise or feed the
    # model a wrongly shaped tensor.
    # NOTE(review): assumes the model was trained on 3-channel RGB input.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = image.resize(size)

    # uint8 pixels -> float32 in [0, 1]; shape (height, width, 3)
    img_array = np.asarray(image, dtype=np.float32) / 255.0
    # Add the batch dimension: (1, height, width, 3), i.e. BHWC
    img_array = img_array[np.newaxis, ...]
    # BHWC -> BCHW; transpose yields a strided view, so materialize a dense
    # buffer for the inference runtime.
    return np.ascontiguousarray(img_array.transpose(0, 3, 1, 2))
def visualize_molecule(smiles):
    """Render a SMILES string as a 2D structure image with RDKit.

    Args:
        smiles: SMILES string to draw.

    Returns:
        The image produced by ``Draw.MolToImage``, or ``None`` when *smiles*
        is empty/``None``, RDKit cannot parse it, or drawing fails.
    """
    # Reject empty input explicitly: RDKit can parse "" into a zero-atom
    # molecule (drawing a blank canvas) rather than returning None.
    if not smiles:
        return None
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        return Draw.MolToImage(mol)
    except Exception:
        # Best-effort rendering: any RDKit failure means "no image". Unlike a
        # bare except, this no longer swallows KeyboardInterrupt/SystemExit.
        return None
# Path of the exported ONNX model — replace with the actual model path.
MODEL_PATH = "model.onnx"

logger = logging.getLogger(__name__)


@functools.lru_cache(maxsize=None)
def _get_session(model_path=MODEL_PATH):
    """Create the ONNX Runtime session for *model_path* once and reuse it."""
    # Building an InferenceSession loads and optimizes the whole graph; doing
    # that on every request would add the cost to each prediction. A failed
    # load raises and is not cached, so it is retried on the next call.
    return ort.InferenceSession(model_path)


def predict(input_image):
    """Run chemical-structure OCR on one uploaded image.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A ``(text, image)`` pair. On success it is the predicted SMILES string
        and an RDKit rendering of it. On failure it is a user-facing message
        and ``None``.
    """
    if input_image is None:
        return "请先上传一张包含化学结构的图片", None
    try:
        session = _get_session()
        processed_image = preprocess_image(input_image)

        # The model is queried by its first declared input/output
        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name
        predictions = session.run([output_name], {input_name: processed_image})

        # Raw model output — NOT a SMILES string; it is turned into a molecular
        # graph below. (Presumably atom/bond detections — TODO confirm against
        # the exported model.)
        output = predictions[0]

        # NOTE(review): bbox_to_graph_with_charge, mol_from_graph_with_chiral,
        # idx_to_labels and bond_labels are neither defined nor imported in
        # this file as shown; until they are, every call raises NameError and
        # ends up in the except branch below.
        atoms_df, bonds_list, charge_list = bbox_to_graph_with_charge(
            output,
            idx_to_labels=idx_to_labels,
            bond_labels=bond_labels,
            result=[],
        )
        smiles, _ = mol_from_graph_with_chiral(atoms_df, bonds_list, charge_list)

        # Render the predicted molecule with RDKit
        mol_image = visualize_molecule(smiles)
        if mol_image is None:
            return "无效的SMILES字符串", None
        return smiles, mol_image
    except Exception as e:
        # Top-level UI boundary: keep the app responsive and show the error,
        # but keep the traceback in the server log for debugging.
        logger.exception("Prediction failed")
        return f"发生错误: {str(e)}", None
# Build the Gradio UI: one PIL image goes in; a SMILES string and a rendered
# molecule come out. Components are created in the same order as before
# (input image, then the text output, then the image output).
_input_component = gr.Image(type="pil")
_output_components = [
    gr.Text(label="SMILES字符串"),
    gr.Image(label="分子结构图"),
]
# Example images offered in the UI (one row per example, one value per input)
_example_inputs = [["example1.jpg"], ["example2.jpg"]]

iface = gr.Interface(
    fn=predict,
    inputs=_input_component,
    outputs=_output_components,
    title="化学结构OCR",
    description="上传一张包含化学结构的图片,获取对应的SMILES表示和分子结构图。",
    examples=_example_inputs,
)
def _main():
    """Start the Gradio web server for the interface defined above."""
    iface.launch()


# Launch only when run as a script, not when imported
if __name__ == "__main__":
    _main()