|
import copy |
|
import json |
|
import math |
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
from scipy.spatial import cKDTree |
|
from rdkit import Chem |
|
from rdkit.Chem import RWMol |
|
from rdkit.Chem import Draw, AllChem |
|
from rdkit.Chem import rdDepictor |
|
import matplotlib.pyplot as plt |
|
import re |
|
|
|
def output_to_smiles(output,idx_to_labels,bond_labels,result):
    """Turn one detector prediction into the result of the graph-to-molecule step.

    Args:
        output: Mapping with tensor entries ``"boxes"``, ``"scores"`` and
            ``"labels"``.  Box centres are taken as the mean of columns 0/2
            and 1/3, so the boxes are presumably ``(x1, y1, x2, y2)`` rows
            (TODO confirm against the detector).
        idx_to_labels: Class-id to label-string mapping, forwarded unchanged.
        bond_labels: Class ids that denote bonds, forwarded unchanged.
        result: Forwarded unchanged to ``bbox_to_graph_with_charge``.

    Returns:
        Whatever ``mol_from_graph_with_chiral`` returns for the assembled
        atom table and bond list.
    """
    boxes = output["boxes"]

    # Centre of every box: midpoint of the two x columns and the two y columns.
    centers = torch.stack(
        (
            (boxes[:, 0] + boxes[:, 2]) / 2,
            (boxes[:, 1] + boxes[:, 3]) / 2,
        ),
        dim=1,
    )

    # The graph-building code works on NumPy arrays, so move everything to CPU.
    detections = {
        'bbox': boxes.to("cpu").numpy(),
        'bbox_centers': centers.to("cpu").numpy(),
        'scores': output["scores"].to("cpu").numpy(),
        'pred_classes': output["labels"].to("cpu").numpy(),
    }

    atoms_list, bonds_list = bbox_to_graph_with_charge(
        detections,
        idx_to_labels=idx_to_labels,
        bond_labels=bond_labels,
        result=result,
    )
    return mol_from_graph_with_chiral(atoms_list, bonds_list)
|
|
|
|
|
def bbox_to_graph(output, idx_to_labels, bond_labels, result):
    """Build an atom table and a bond list from detected bounding boxes.

    Args:
        output: Dict of NumPy arrays with keys ``'bbox'`` (one 4-value box per
            detection), ``'bbox_centers'`` (x, y per detection), ``'scores'``
            and ``'pred_classes'``.
        idx_to_labels: Mapping from class id to label string.  Atom labels
            are expected to end with their charge (``'0'``, ``'1'``, ``'-1'``);
            this is inferred from the suffix handling below.
        bond_labels: Collection of class ids that denote bonds; every other
            class id is treated as an atom.
        result: Unused; kept so existing callers keep working.

    Returns:
        ``(atoms_list, bonds_list)``.  ``atoms_list`` is a DataFrame with
        columns ``atom``, ``x``, ``y`` (rows removed as duplicates leave gaps
        in its index).  ``bonds_list`` holds tuples
        ``(begin, end, label, label, score)`` where ``begin``/``end`` are
        *positional* row numbers into ``atoms_list``.
    """
    pred_classes = np.asarray(output['pred_classes'])
    # np.isin always yields a boolean mask, even when there are no detections
    # (a list-comprehension mask would become a float array and break indexing).
    atoms_mask = ~np.isin(pred_classes, list(bond_labels))
    bond_mask = ~atoms_mask

    atoms_list = pd.DataFrame({
        'atom': [idx_to_labels[a] for a in pred_classes[atoms_mask]],
        'x': output['bbox_centers'][atoms_mask, 0],
        'y': output['bbox_centers'][atoms_mask, 1],
    })

    # Remove a second detection of the same element sitting (almost) on top of
    # a charged atom.  iterrows() walks a snapshot of the table, so rows that
    # were already dropped must be skipped explicitly.
    for idx, row in atoms_list.iterrows():
        if idx not in atoms_list.index:
            continue
        label = row.atom
        if not label or label[-1] == '0':
            continue
        # Strip the charge suffix: two characters for a negative charge
        # ('O-1' -> 'O'), otherwise one ('N1' -> 'N').
        if len(label) > 1 and label[-2] == '-':
            prefix = label[:-2]
        else:
            prefix = label[:-1]
        if not prefix:
            # An empty prefix would match every atom in the table.
            continue

        overlapping = atoms_list[atoms_list.atom.str.startswith(prefix)]
        kdt = cKDTree(overlapping[['x', 'y']])
        dists, neighbours = kdt.query([row.x, row.y], k=2)
        if dists[1] < 7:
            atoms_list.drop(overlapping.index[neighbours[1]], axis=0, inplace=True)

    bonds_list = []
    if atoms_list.empty:
        return atoms_list, bonds_list

    # The atom positions no longer change, so one KD-tree serves every bond.
    atom_tree = cKDTree(atoms_list[['x', 'y']].values)
    thin_bond_labels = ('-', 'SINGLE', 'NONE', 'ENDUPRIGHT', 'BEGINWEDGE',
                        'BEGINDASH', 'ENDDOWNRIGHT')

    for bbox, bond_type, score in zip(output['bbox'][bond_mask],
                                      pred_classes[bond_mask],
                                      output['scores'][bond_mask]):
        bond_label = idx_to_labels[bond_type]
        _margin = 5 if bond_label in thin_bond_labels else 8

        # Shrink the box by the margin; its two corners form one diagonal and
        # swapping the y values gives the other diagonal.
        diagonal = (bbox + [_margin, _margin, -_margin, -_margin]).reshape([2, -1])
        anti_diagonal = np.column_stack((diagonal[:, 0], diagonal[::-1, 1]))
        anchor_positions = np.concatenate([diagonal, anti_diagonal])

        dists, neighbours = atom_tree.query(anchor_positions, k=1)

        # The bond lies along whichever diagonal has its ends closest to atoms
        # (ties go to the first diagonal, as with argmin).
        if dists[0] + dists[1] <= dists[2] + dists[3]:
            begin_idx, end_idx = neighbours[:2]
        else:
            begin_idx, end_idx = neighbours[2:]

        # Both ends snapping to the same atom means no usable bond.
        if begin_idx != end_idx:
            bonds_list.append((begin_idx, end_idx, bond_label, bond_label, score))

    return atoms_list, bonds_list
|
|
|
|
|
def calculate_distance(coord1, coord2):
    """Return the Euclidean distance between two points in the x/y plane.

    Only the first two components of each coordinate are used.
    ``math.hypot`` is used instead of squaring by hand so that very large or
    very small coordinates do not overflow or underflow.
    """
    return math.hypot(coord1[0] - coord2[0], coord1[1] - coord2[1])
|
|
|
def assemble_atoms_with_charges(atom_list, charge_list):
    """Attach every detected charge to its nearest atom.

    Each atom label first receives the neutral suffix ``'0'``.  Then, for
    every row of ``charge_list``, the atom closest to the charge position has
    its label replaced by ``<first alphabetic run of the label><charge>``;
    a charge of ``'1'`` is written as ``'+'``.  When several charges pick the
    same atom, the last one wins.

    Args:
        atom_list: DataFrame with columns ``atom``, ``x``, ``y``.  Its
            ``atom`` column is updated in place.
        charge_list: DataFrame with columns ``charge``, ``x``, ``y``.

    Returns:
        The same ``atom_list`` object with the updated ``atom`` column.
    """
    if atom_list.empty:
        # No atom can receive a charge (and the KD-tree lookup would fail).
        return atom_list

    raw_labels = atom_list['atom'].tolist()
    labels = [label + '0' for label in raw_labels]

    if len(charge_list) > 0:
        kdt = cKDTree(atom_list[['x', 'y']].to_numpy(dtype=float))
        # The KD-tree reports *row positions*, so the labels are kept in a
        # plain list and indexed by position rather than by DataFrame label.
        _, nearest = kdt.query(charge_list[['x', 'y']].to_numpy(dtype=float), k=1)
        for charge_, pos in zip(charge_list['charge'], nearest):
            if charge_ == '1':
                charge_ = '+'
            pos = int(pos)
            symbol = re.search(r'[A-Za-z]+', labels[pos])
            # Labels without any letter keep their original text as symbol.
            base = symbol.group(0) if symbol else raw_labels[pos]
            labels[pos] = base + charge_

    atom_list['atom'] = labels
    return atom_list
|
|
|
|
|
|
|
def assemble_atoms_with_charges2(atom_list, charge_list, max_distance=10):
    """Give each atom the closest still-unused charge within ``max_distance``.

    Atoms are visited in row order.  An atom that finds a charge gets the
    charge text appended to its label (``'1'`` is written as ``'+'``) and that
    charge is marked as used so no later atom can take it.  An atom without a
    charge gets the neutral suffix ``'0'``.

    Args:
        atom_list: DataFrame with columns ``atom``, ``x``, ``y``.  Its
            ``atom`` column is updated in place.
        charge_list: DataFrame with columns ``charge``, ``x``, ``y``.
        max_distance: Largest atom-to-charge distance that still counts as a
            match.

    Returns:
        The same ``atom_list`` object with the updated ``atom`` column.
    """
    charge_points = list(zip(charge_list['x'].tolist(), charge_list['y'].tolist()))
    charge_texts = charge_list['charge'].tolist()
    used_charge_positions = set()

    new_labels = []
    for atom_label, atom_x, atom_y in zip(atom_list['atom'].tolist(),
                                          atom_list['x'].tolist(),
                                          atom_list['y'].tolist()):
        closest = None
        min_distance = float('inf')
        for pos, (charge_x, charge_y) in enumerate(charge_points):
            if pos in used_charge_positions:
                continue
            distance = math.hypot(atom_x - charge_x, atom_y - charge_y)
            if distance <= max_distance and distance < min_distance:
                closest = pos
                min_distance = distance

        if closest is None:
            new_labels.append(atom_label + '0')
            continue

        # Record the charge that was actually chosen so it is not reused.
        used_charge_positions.add(closest)
        charge_ = charge_texts[closest]
        if charge_ == '1':
            charge_ = '+'
        new_labels.append(atom_label + charge_)

    atom_list['atom'] = new_labels
    return atom_list
|
|
|
|
|
|
|
def bbox_to_graph_with_charge(output, idx_to_labels, bond_labels, result,
                              charge_labels=(18, 19, 20, 21, 22)):
    """Build an atom table and a bond list when charges are detected as boxes.

    Detections are split three ways by class id: charges (``charge_labels``),
    bonds (``bond_labels``) and atoms (everything else).  Charges are merged
    into the atom labels, near-duplicate atoms around charged atoms are
    removed, and every bond box is connected to the two atoms nearest to the
    ends of its best-fitting diagonal.

    Args:
        output: Dict of NumPy arrays with keys ``'bbox'`` (one 4-value box per
            detection), ``'bbox_centers'`` (x, y per detection), ``'scores'``
            and ``'pred_classes'``.
        idx_to_labels: Mapping from class id to label string.
        bond_labels: Collection of class ids that denote bonds.
        result: Unused; kept so existing callers keep working.
        charge_labels: Class ids that denote charge boxes.  Defaults to the
            ids that were previously hard-coded here.

    Returns:
        ``(atoms_list, bonds_list)``.  ``atoms_list`` is a DataFrame with
        columns ``atom``, ``x``, ``y``, ``bbox`` (rows removed as duplicates
        leave gaps in its index).  ``bonds_list`` holds tuples
        ``(begin, end, label, label, score)`` where ``begin``/``end`` are
        *positional* row numbers into ``atoms_list``.
    """
    pred_classes = np.asarray(output['pred_classes'])
    # np.isin always yields boolean masks, even when there are no detections.
    charge_mask = np.isin(pred_classes, list(charge_labels))
    is_bond = np.isin(pred_classes, list(bond_labels))
    atoms_mask = ~(is_bond | charge_mask)
    bond_mask = is_bond & ~charge_mask

    atoms_list = pd.DataFrame({
        'atom': [idx_to_labels[a] for a in pred_classes[atoms_mask]],
        'x': output['bbox_centers'][atoms_mask, 0],
        'y': output['bbox_centers'][atoms_mask, 1],
        'bbox': output['bbox'][atoms_mask].tolist(),
    })

    bonds_list = []
    if atoms_list.empty:
        # Without atoms there is nothing to charge and nothing to connect.
        return atoms_list, bonds_list

    charge_list = pd.DataFrame({
        'charge': [idx_to_labels[a] for a in pred_classes[charge_mask]],
        'x': output['bbox_centers'][charge_mask, 0],
        'y': output['bbox_centers'][charge_mask, 1],
    })

    if len(charge_list) > 0:
        atoms_list = assemble_atoms_with_charges(atoms_list, charge_list)
    else:
        atoms_list['atom'] = atoms_list['atom'] + '0'

    # Remove a second detection of the same element sitting (almost) on top of
    # a charged atom.  iterrows() walks a snapshot of the table, so rows that
    # were already dropped must be skipped explicitly.
    for idx, row in atoms_list.iterrows():
        if idx not in atoms_list.index:
            continue
        label = row.atom
        if not label or label[-1] == '0':
            continue
        # Strip the charge suffix: two characters when it starts with '-'
        # ('O-1' -> 'O'), otherwise one ('N+' -> 'N').
        if len(label) > 1 and label[-2] == '-':
            prefix = label[:-2]
        else:
            prefix = label[:-1]
        if not prefix:
            # An empty prefix would match every atom in the table.
            continue

        overlapping = atoms_list[atoms_list.atom.str.startswith(prefix)]
        kdt = cKDTree(overlapping[['x', 'y']])
        dists, neighbours = kdt.query([row.x, row.y], k=2)
        if dists[1] < 7:
            atoms_list.drop(overlapping.index[neighbours[1]], axis=0, inplace=True)

    # The atom positions no longer change, so one KD-tree serves every bond.
    atom_tree = cKDTree(atoms_list[['x', 'y']].values)
    thin_bond_labels = ('-', 'SINGLE', 'NONE', 'ENDUPRIGHT', 'BEGINWEDGE',
                        'BEGINDASH', 'ENDDOWNRIGHT')

    # bond_mask only selects class ids contained in bond_labels, so no further
    # membership test is needed inside the loop.
    for bbox, bond_type, score in zip(output['bbox'][bond_mask],
                                      pred_classes[bond_mask],
                                      output['scores'][bond_mask]):
        bond_label = idx_to_labels[bond_type]
        _margin = 5 if bond_label in thin_bond_labels else 8

        # Shrink the box by the margin; its two corners form one diagonal and
        # swapping the y values gives the other diagonal.
        diagonal = (bbox + [_margin, _margin, -_margin, -_margin]).reshape([2, -1])
        anti_diagonal = np.column_stack((diagonal[:, 0], diagonal[::-1, 1]))
        anchor_positions = np.concatenate([diagonal, anti_diagonal])

        dists, neighbours = atom_tree.query(anchor_positions, k=1)

        # The bond lies along whichever diagonal has its ends closest to atoms
        # (ties go to the first diagonal, as with argmin).
        if dists[0] + dists[1] <= dists[2] + dists[3]:
            begin_idx, end_idx = neighbours[:2]
        else:
            begin_idx, end_idx = neighbours[2:]

        # Both ends snapping to the same atom means no usable bond.
        if begin_idx != end_idx:
            bonds_list.append((begin_idx, end_idx, bond_label, bond_label, score))

    return atoms_list, bonds_list
|
|
|
|
|
|
|
def _split_atom_label(node):
    """Split an atom label such as ``'N0'``, ``'O-1'`` or ``'N+'`` into (symbol, formal charge)."""
    has_sign = '-' in node or '+' in node
    last = node[-1]
    if last in '0123456789':
        if has_sign:
            # Signed numeric suffix, e.g. 'O-1' -> ('O', -1).
            symbol, charge = node[:-2], int(node[-2:])
        else:
            symbol, charge = node[:-1], int(last)
    elif last == '+':
        symbol, charge = node[:-1], 1
    elif last == '-':
        symbol, charge = node[:-1], -1
    else:
        symbol, charge = node, 0
    if 'other' in node:
        # Unrecognised atom classes become the wildcard atom.
        symbol = '*'
    return symbol, charge


def mol_from_graph_with_chiral(atoms_list, bonds):
    """Assemble an RDKit molecule, including wedge/dash stereo, from a detected graph.

    Args:
        atoms_list: DataFrame with columns ``atom`` (label ending in a charge
            suffix), ``x`` and ``y``.  Rows are addressed by position.
        bonds: List of ``(begin, end, bond_type, bond_dir, score)`` tuples
            whose ``begin``/``end`` are row positions in ``atoms_list``.
            The list is modified in place: bond types are normalised to
            ``'SINGLE'``/``'DOUBLE'``/``'TRIPLE'`` and bonds ending on a
            stereo centre are reversed so that they start there.

    Returns:
        ``(smiles, mol)`` on success, or ``None`` when RDKit rejects the
        structure (the error is printed).
    """
    mol = RWMol()
    nodes_idx = {}
    atoms = atoms_list.atom.values.tolist()

    # Conformer coordinates: y is flipped around 300 and everything is divided
    # by 100 (presumably a 300 px image height - TODO confirm with the caller).
    coords = tuple(
        (row['x'] / 100, (300 - row['y']) / 100, 0 / 100)
        for _, row in atoms_list.iterrows()
    )

    # Normalise detector bond labels to RDKit bond-order names; the original
    # label stays available in the bond_dir slot of each tuple.
    bond_type_aliases = {
        '-': 'SINGLE', 'NONE': 'SINGLE', 'ENDUPRIGHT': 'SINGLE',
        'BEGINWEDGE': 'SINGLE', 'BEGINDASH': 'SINGLE', 'ENDDOWNRIGHT': 'SINGLE',
        '=': 'DOUBLE',
        '#': 'TRIPLE',
    }
    for i, (idx_1, idx_2, bond_type, bond_dir, score) in enumerate(bonds):
        if bond_type in bond_type_aliases:
            bonds[i] = (idx_1, idx_2, bond_type_aliases[bond_type], bond_dir, score)

    bond_types = {'SINGLE': Chem.rdchem.BondType.SINGLE,
                  'DOUBLE': Chem.rdchem.BondType.DOUBLE,
                  'TRIPLE': Chem.rdchem.BondType.TRIPLE,
                  'AROMATIC': Chem.rdchem.BondType.AROMATIC}

    bond_dirs = {'NONE': Chem.rdchem.BondDir.NONE,
                 'ENDUPRIGHT': Chem.rdchem.BondDir.ENDUPRIGHT,
                 'BEGINWEDGE': Chem.rdchem.BondDir.BEGINWEDGE,
                 'BEGINDASH': Chem.rdchem.BondDir.BEGINDASH,
                 'ENDDOWNRIGHT': Chem.rdchem.BondDir.ENDDOWNRIGHT}

    # Pre-bind the names used by the diagnostic output below so the error
    # handler itself cannot fail when the atom loop never got started.
    node = a = idx = None

    try:
        for idx, node in enumerate(atoms):
            a, fc = _split_atom_label(node)
            ad = Chem.Atom(a)
            ad.SetFormalCharge(fc)
            nodes_idx[idx] = mol.AddAtom(ad)

        # Start the rollback snapshot with the atoms-only molecule so that an
        # invalid *first* bond can be undone as well.
        prev_mol = copy.deepcopy(mol)

        existing_bonds = set()
        for idx_1, idx_2, bond_type, bond_dir, score in bonds:
            if (idx_1 in nodes_idx) and (idx_2 in nodes_idx):
                if (idx_1, idx_2) not in existing_bonds and (idx_2, idx_1) not in existing_bonds:
                    try:
                        mol.AddBond(nodes_idx[idx_1], nodes_idx[idx_2], bond_types[bond_type])
                    except Exception as e:
                        print([idx_1, idx_2, bond_type, bond_dir, score], "erro @add bonds ")
                        print(f"erro@add existing_bonds: {e}\n{bonds}")
                        continue
            existing_bonds.add((idx_1, idx_2))

            # Keep the bond only if the molecule still round-trips through
            # SMILES; otherwise restore the last good state.
            if Chem.MolFromSmiles(Chem.MolToSmiles(mol.GetMol())):
                prev_mol = copy.deepcopy(mol)
            else:
                mol = copy.deepcopy(prev_mol)

        chiral_centers = Chem.FindMolChiralCenters(
            mol, includeUnassigned=True, includeCIP=False, useLegacyImplementation=False)
        chiral_center_ids = [center for center, _ in chiral_centers]

        # Wedge/dash directions are read from the begin atom, so every bond
        # that ends on a stereo centre is reversed to start there.
        for center in chiral_center_ids:
            for index, tup in enumerate(bonds):
                if center != tup[1]:
                    continue
                bonds[index] = (tup[1], tup[0], tup[2], tup[3], tup[4])
                if mol.GetBondBetweenAtoms(int(tup[0]), int(tup[1])) is None:
                    # The bond was rejected or rolled back earlier; re-adding
                    # it here would bring the invalid bond back.
                    continue
                mol.RemoveBond(int(tup[0]), int(tup[1]))
                try:
                    mol.AddBond(int(tup[1]), int(tup[0]), bond_types[tup[2]])
                except Exception as e:
                    print(index, tup, center)
                    print(f"bonds: {bonds}")
                    print(f"erro@chiral_center_ids: {e}")

        mol = mol.GetMol()

        # Attach the detected 2D layout as the only conformer.
        mol.RemoveAllConformers()
        conf = Chem.Conformer(mol.GetNumAtoms())
        conf.Set3D(True)
        for i, (x, y, z) in enumerate(coords):
            conf.SetAtomPosition(i, (x, y, z))
        mol.AddConformer(conf)

        Chem.AssignStereochemistryFrom3D(mol)

        # Look-up of wedge/dash labels by (begin, end); a wedge takes
        # precedence over a dash for the same atom pair.
        stereo_dirs = {}
        for row in bonds:
            direction = row[3]
            key = (int(row[0]), int(row[1]))
            if direction == 'BEGINWEDGE':
                stereo_dirs[key] = direction
            elif direction == 'BEGINDASH':
                stereo_dirs.setdefault(key, direction)

        n_atoms = len(atoms)
        for i in chiral_center_ids:
            for j in range(n_atoms):
                direction = stereo_dirs.get((i, j))
                if direction is None:
                    continue
                bond = mol.GetBondBetweenAtoms(i, j)
                if bond is not None:
                    bond.SetBondDir(bond_dirs[direction])

        Chem.SanitizeMol(mol)
        Chem.DetectBondStereochemistry(mol)
        Chem.AssignChiralTypesFromBondDirs(mol)
        Chem.AssignStereochemistry(mol)

        smiles = Chem.MolToSmiles(mol)
        return smiles, mol

    except Chem.rdchem.AtomValenceException as e:
        print(f"捕获到 AtomValenceException 异常@@{e}")

    except Exception as e:
        print(f"捕获到 异常@@{e}")
        print(f"Error@@node {node} atom@@ {a} \n")
        print(atoms, idx, atoms[idx] if idx is not None else None)

    return None
|
|
|
|
|
|
|
|
|
def mol_from_graph_without_chiral(atoms, bonds):
    """Assemble an RDKit molecule from a detected graph, ignoring stereo.

    Args:
        atoms: Sequence of atom labels ending in a charge suffix
            (``'C0'``, ``'N1'``, ``'O-1'``).  A label without such a suffix is
            used as the element symbol with charge 0.
        bonds: List of ``(begin, end, bond_type, bond_dir, score)`` tuples
            whose ``begin``/``end`` index into ``atoms``.  The list is
            modified in place: bond types are normalised to
            ``'SINGLE'``/``'DOUBLE'``/``'TRIPLE'``.

    Returns:
        The SMILES string of the assembled molecule, or ``None`` when an
        ``AtomValenceException`` is raised (a message is printed).
    """
    mol = RWMol()
    nodes_idx = {}

    # Normalise detector bond labels to RDKit bond-order names.
    bond_type_aliases = {
        '-': 'SINGLE', 'NONE': 'SINGLE', 'ENDUPRIGHT': 'SINGLE',
        'BEGINWEDGE': 'SINGLE', 'BEGINDASH': 'SINGLE', 'ENDDOWNRIGHT': 'SINGLE',
        '=': 'DOUBLE',
        '#': 'TRIPLE',
    }
    for i, (idx_1, idx_2, bond_type, bond_dir, score) in enumerate(bonds):
        if bond_type in bond_type_aliases:
            bonds[i] = (idx_1, idx_2, bond_type_aliases[bond_type], bond_dir, score)

    bond_types = {'SINGLE': Chem.rdchem.BondType.SINGLE,
                  'DOUBLE': Chem.rdchem.BondType.DOUBLE,
                  'TRIPLE': Chem.rdchem.BondType.TRIPLE,
                  'AROMATIC': Chem.rdchem.BondType.AROMATIC}

    try:
        for idx, node in enumerate(atoms):
            if '-1' in node:
                symbol, fc = node[:-2], -1
            elif ('0' in node) or ('1' in node):
                symbol, fc = node[:-1], int(node[-1])
            else:
                # No charge suffix: never reuse values from a previous atom.
                symbol, fc = node, 0

            atom = Chem.Atom(symbol)
            atom.SetFormalCharge(fc)
            nodes_idx[idx] = mol.AddAtom(atom)

        # Start the rollback snapshot with the atoms-only molecule so that an
        # invalid *first* bond can be undone as well.
        prev_mol = copy.deepcopy(mol)

        existing_bonds = set()
        for idx_1, idx_2, bond_type, bond_dir, score in bonds:
            if (idx_1 in nodes_idx) and (idx_2 in nodes_idx):
                if (idx_1, idx_2) not in existing_bonds and (idx_2, idx_1) not in existing_bonds:
                    try:
                        mol.AddBond(nodes_idx[idx_1], nodes_idx[idx_2], bond_types[bond_type])
                    except Exception:
                        # Best effort: a bond that cannot be added is skipped.
                        continue
            existing_bonds.add((idx_1, idx_2))

            # Keep the bond only if the molecule still round-trips through
            # SMILES; otherwise restore the last good state.
            if Chem.MolFromSmiles(Chem.MolToSmiles(mol.GetMol())):
                prev_mol = copy.deepcopy(mol)
            else:
                mol = copy.deepcopy(prev_mol)

        mol = mol.GetMol()
        mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
        return Chem.MolToSmiles(mol)

    except Chem.rdchem.AtomValenceException as e:
        print("捕获到 AtomValenceException 异常")

    return None
|
|
|
|
|
|
|
|
|
|
|
|