| from rdkit import Chem |
|
|
| from src import constants |
|
|
|
|
def remove_dummy_atoms(rdmol, sanitize=False):
    """
    Build a molecule from ``rdmol`` in which every dummy atom is removed.

    Dummy atoms are the ones whose symbol is '*'.

    :param rdmol: RDKit molecule to strip; edits are applied through a
        Chem.EditableMol created from it
    :param sanitize: if True, Chem.SanitizeMol is called on the result
    :return: molecule obtained from the editable wrapper after the removals
    """
    # Walk the dummy atoms from the highest index down so that deleting one
    # cannot shift the index of an atom that is still waiting to be removed.
    doomed = sorted(
        (atom.GetIdx() for atom in rdmol.GetAtoms() if atom.GetSymbol() == '*'),
        reverse=True,
    )

    editable = Chem.EditableMol(rdmol)
    for atom_idx in doomed:
        editable.RemoveAtom(atom_idx)

    stripped = editable.GetMol()
    if sanitize:
        Chem.SanitizeMol(stripped)
    return stripped
|
|
|
|
def build_molecule(coords, atom_types, bonds=None, bond_types=None,
                   atom_props=None, atom_decoder=None, bond_decoder=None):
    """
    Assemble an RDKit molecule from atom types, 3D coordinates and bonds.

    Atom type indices are translated through ``atom_decoder``. A decoded
    token may end in 'H' (one explicit hydrogen), '+' (formal charge +1) or
    '-' (formal charge -1); the suffix is stripped before the atom is
    created. The token 'NOATOM' becomes a dummy atom ('*'), which is removed
    again before the molecule is returned, together with any bond that would
    have touched it.

    :param coords: N x 3, indexed as ``coords[i, k]`` with ``.item()``
    :param atom_types: N, each entry must support ``.item()``
    :param bonds: 2 x N_bonds, must support ``.size(1)`` and ``.T``
    :param bond_types: N_bonds, each entry is used to index ``bond_decoder``
    :param atom_props: Dict, key: property name, value: list of float values (N,)
    :param atom_decoder: list, defaults to ``constants.atom_decoder``
    :param bond_decoder: list, defaults to ``constants.bond_decoder``
    :return: RDKit molecule (not sanitized) without dummy atoms
    :raises AssertionError: if the input lengths disagree, or if the same
        atom pair is listed twice with different bond types
    """
    atom_decoder = constants.atom_decoder if atom_decoder is None else atom_decoder
    bond_decoder = constants.bond_decoder if bond_decoder is None else bond_decoder
    assert len(coords) == len(atom_types)
    assert bonds is None or bonds.size(1) == len(bond_types)

    # --- atoms ---
    rw_mol = Chem.RWMol()
    for atom_idx, type_idx in enumerate(atom_types):
        token = atom_decoder[type_idx.item()]
        formal_charge = None
        num_explicit_hs = None

        # Split an optional one-character suffix off the decoded token.
        # The length check keeps a plain hydrogen 'H' from being stripped.
        if len(token) > 1 and token.endswith('H'):
            symbol, num_explicit_hs = token[:-1], 1
        elif token.endswith('+'):
            symbol, formal_charge = token[:-1], 1
        elif token.endswith('-'):
            symbol, formal_charge = token[:-1], -1
        else:
            symbol = token

        # Placeholder entries become dummy atoms so that atom indices still
        # line up with the rows of `coords` and `bonds`; they are dropped at
        # the very end.
        if symbol == 'NOATOM':
            symbol = '*'

        rd_atom = Chem.Atom(symbol)
        if num_explicit_hs is not None:
            rd_atom.SetNumExplicitHs(num_explicit_hs)
        if formal_charge is not None:
            rd_atom.SetFormalCharge(formal_charge)

        if atom_props is not None:
            for prop_name, prop_values in atom_props.items():
                rd_atom.SetDoubleProp(prop_name, prop_values[atom_idx].item())

        rw_mol.AddAtom(rd_atom)

    # --- coordinates ---
    n_atoms = rw_mol.GetNumAtoms()
    conformer = Chem.Conformer(n_atoms)
    for atom_idx in range(n_atoms):
        x = coords[atom_idx, 0].item()
        y = coords[atom_idx, 1].item()
        z = coords[atom_idx, 2].item()
        conformer.SetAtomPosition(atom_idx, (x, y, z))
    rw_mol.AddConformer(conformer)

    # --- bonds ---
    if bonds is not None:
        for pair, type_idx in zip(bonds.T, bond_types):
            decoded = bond_decoder[type_idx]
            begin = pair[0].item()
            end = pair[1].item()

            # Skip explicit non-bonds and anything attached to a dummy atom.
            if decoded == 'NOBOND':
                continue
            if rw_mol.GetAtomWithIdx(begin).GetSymbol() == '*':
                continue
            if rw_mol.GetAtomWithIdx(end).GetSymbol() == '*':
                continue

            # The same pair may appear more than once (e.g. both directions);
            # a repeat is accepted only if it agrees with the existing bond.
            existing = rw_mol.GetBondBetweenAtoms(begin, end)
            if existing is not None:
                assert existing.GetBondType() == decoded, \
                    "Trying to assign two different types to the same bond."
                continue

            # Ignore undecodable bond types and self-loops.
            if decoded is not None and begin != end:
                rw_mol.AddBond(begin, end, decoded)

    return remove_dummy_atoms(rw_mol, sanitize=False)
|
|