Skip to content

Commit

Permalink
Test run 3
Browse files Browse the repository at this point in the history
  • Loading branch information
choglass committed Jun 4, 2024
1 parent 5f528b4 commit a876327
Show file tree
Hide file tree
Showing 8 changed files with 412 additions and 251 deletions.
173 changes: 97 additions & 76 deletions cell2mol/charge_assignment.py

Large diffs are not rendered by default.

22 changes: 11 additions & 11 deletions cell2mol/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1530,15 +1530,15 @@ def get_selected_cs(self, debug: int=0):
tmp = specie.get_possible_cs(debug=debug)
if tmp is None:
self.selected_cs.append(None)
if specie.subtype != "metal":
elif specie.subtype != "metal":
self.selected_cs.append(list([cs.corr_total_charge for cs in specie.possible_cs]))
else :
self.selected_cs.append(specie.possible_cs)

if None in self.selected_cs:
self.error_empty_poscharges = True
self.error_get_poscharges = True
else :
self.error_empty_poscharges = False
self.error_get_poscharges = False

#######################################################
def assign_charges (self, debug: int=0):
Expand Down Expand Up @@ -1719,13 +1719,13 @@ def assign_charges_old (self, debug: int=0) -> object:
for idx, spec in enumerate(self.unique_species):
tmp = spec.get_possible_cs(debug=debug)
if tmp is None:
self.error_empty_poscharges = True
self.error_get_poscharges = True
return # Stopping. Empty list of possible charges received.
if spec.subtype != "metal":
elif spec.subtype != "metal":
selected_cs.append(list([cs.corr_total_charge for cs in spec.possible_cs]))
else :
selected_cs.append(spec.possible_cs)
self.error_empty_poscharges = False
self.error_get_poscharges = False

# Finds the charge_state that satisfies that the crystal must be neutral
final_charge_distribution = balance_charge(self.unique_indices, self.unique_species, debug=debug)
Expand Down Expand Up @@ -1860,13 +1860,13 @@ def assess_errors(self, mode):
if self.has_isolated_H: case = 1
elif self.has_missing_H: case = 2
else : case = 0
elif mode == "unique_species":
elif mode == "possible_charges":
print("-------------------------------")
print("Errors in unique species")
print("Errors in possible charges")
print("-------------------------------")
if self.has_isolated_H: case = 1
elif self.has_missing_H: case = 2
elif self.error_empty_poscharges: case = 5
elif self.error_get_poscharges: case = 5
else : case = 0
elif mode == "reconstruction":
print("-------------------------------")
Expand All @@ -1885,7 +1885,7 @@ def assess_errors(self, mode):
elif self.has_missing_H: case = 2
elif self.error_get_fragments: case = 3
elif self.error_reconstruction: case = 4
elif self.error_empty_poscharges : case = 5
elif self.error_get_poscharges : case = 5
elif self.error_multiple_distrib : case = 6
elif self.error_empty_distrib : case = 7
elif self.error_create_bonds : case = 8
Expand All @@ -1901,7 +1901,7 @@ def assess_errors(self, mode):
# # elif self.error_get_fragments: case = 3
# # elif self.error_reconstruction: case = 4
# # Assign Charges
# # elif self.error_empty_poscharges : case = 5
# # elif self.error_get_poscharges : case = 5
# # elif self.error_multiple_distrib : case = 6
# # elif self.error_empty_distrib : case = 7
# # elif self.error_prepare_mols : case = 8
Expand Down
41 changes: 27 additions & 14 deletions cell2mol/new_c2m_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,10 @@
### PREPARES THE REFERENCE CELL OBJECT ###
##########################################
cov_factor = 1.3

# cov_factor -= 0.1
print(f"{cov_factor=}")
metal_factor = 1.0

# Get reference molecules
# labels, pos, ref_labels, ref_fracs, cellvec, cell_param = readinfo(infopath)
# labels, pos, and cellvec will not be used
Expand All @@ -112,8 +113,7 @@
# Increase covalent factor for H atoms
cov_factor += 0.05
refcell.get_reference_molecules(ref_labels, ref_fracs, cov_factor=cov_factor, debug=debug)
if debug >= 1:
print(f"Covalent factor increases: {cov_factor=}")
if debug >= 1: print(f"Covalent factor increases: {cov_factor=}")
refcell.check_missing_H(debug=debug)
refcell.assess_errors(mode="hydrogens")
refcell.save(ref_cell_fname)
Expand All @@ -126,23 +126,36 @@
print(f"refcell.species_list {[specie.formula for specie in refcell.species_list]}\n")
# Get possible charge states for the unique species in the reference cell
refcell.get_selected_cs(debug=debug)
refcell.assess_errors(mode="unique_species")

refcell.assess_errors(mode="possible_charges")

if refcell.error_get_poscharges:
if debug >= 1: print(f"{refcell.selected_cs=}")
while refcell.error_get_poscharges and cov_factor > 1.15:
# Decrease covalent factor for H atoms
cov_factor -= 0.05
refcell.get_reference_molecules(ref_labels, ref_fracs, cov_factor=cov_factor, debug=debug)
refcell.check_missing_H(debug=debug)
if not refcell.has_isolated_H:
refcell.get_unique_species(debug=debug)
refcell.get_selected_cs(debug=debug)

if debug >= 1: print(f"Covalent factor decreases: {cov_factor=}")
refcell.assess_errors(mode="possible_charges")
# Save reference cell object
refcell.save(ref_cell_fname)

##########################################
# Define new cell object for the unit cell
newcell = cell(name, cell_labels, cell_pos, cell_fracs, cell_vector, cell_param)
newcell.get_subtype("unit_cell")

if refcell.error_case != 0:
sys.exit(1)
pass
else:
reconstruction = True
charge_assignment = False
spin_assignment = False

# Define new cell object for the unit cell
newcell = cell(name, cell_labels, cell_pos, cell_fracs, cell_vector, cell_param)
newcell.get_subtype("unit_cell")


# Get reference molecules
newcell.get_reference_molecules(refcell.labels, refcell.frac_coord, cov_factor=cov_factor, debug=-1)
if not newcell.has_isolated_H:
Expand Down Expand Up @@ -188,10 +201,10 @@
print("*** Reference molecules ***")
print(refcell)
print_output(refcell.refmoleclist)


print("***Unit cell molecules ***")
print(newcell)
if hasattr(newcell, "moleclist"):
print("***Unit cell molecules ***")
print(newcell)
print_output(newcell.moleclist)

surmmary.close()
Expand Down
8 changes: 4 additions & 4 deletions cell2mol/new_c2m_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ def cell2mol(newcell: object, refcell: object, sym_ops, reconstruction: bool=Tru
if not newcell.error_reconstruction:

if None in refcell.selected_cs :
newcell.error_empty_poscharges = True
newcell.error_get_poscharges = True
else:
newcell.error_empty_poscharges = False
newcell.error_get_poscharges = False
print_possible_and_selected_cs(newcell, refcell, debug=debug)

# Find charge distribution for the unit cell
Expand All @@ -63,7 +63,7 @@ def cell2mol(newcell: object, refcell: object, sym_ops, reconstruction: bool=Tru
# Assign charge for the unit cell and check charge neutrality
newcell.assign_charges(debug=debug)

if newcell.error_empty_poscharges : return newcell
if newcell.error_get_poscharges : return newcell
elif newcell.error_multiple_distrib : return newcell
elif newcell.error_empty_distrib : return newcell
else :
Expand All @@ -85,7 +85,7 @@ def cell2mol(newcell: object, refcell: object, sym_ops, reconstruction: bool=Tru
print(" Spin Assignment ")
print("#########################################")
tini = time.time()
if not newcell.error_empty_poscharges and not newcell.error_multiple_distrib and not newcell.error_empty_distrib:
if not newcell.error_get_poscharges and not newcell.error_multiple_distrib and not newcell.error_empty_distrib:
newcell.assign_spin(debug=debug)
tend = time.time()
if debug >= 1: print(f"\nTotal execution time for Spin Assignment: {tend - tini:.2f} seconds")
Expand Down
40 changes: 1 addition & 39 deletions cell2mol/new_charge_assignment.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import numpy as np
from cell2mol.hungarian import reorder
import copy
from cell2mol.xyz2mol import xyz2mol
from cell2mol.charge_assignment import check_rdkit_obj_connectivity, arrange_data_for_reorder, charge_state, protonation
from rdkit import Chem
from cell2mol.charge_assignment import protonation, get_charge
import itertools

#######################################################
Expand Down Expand Up @@ -194,41 +191,6 @@ def set_charge_state(reference, target, mode, debug: int=0):
print(f"SET_CHARGE_STATE: WARNING!!! {target.formula=} {final_charge=} {cs.corr_total_charge=} final_charge != cs.corr_total_charge")
target.set_charges(cs.corr_total_charge, cs.corr_atom_charges, cs.smiles, cs.rdkit_obj)
print(f"SET_CHARGE_STATE:{target.formula=} {target.totcharge=} {target.smiles=}")

######################################################
def get_charge(charge: int, prot: object, allow: bool=True, embed_chiral: bool=True, debug: int=0):
    """Build the charge_state produced when xyz2mol is asked for a given total charge.

    The molecule is described by the protonation state `prot`, from which this
    function reads `natoms`, `atnums`, `coords`, `adjmat` and `cov_factor`.

    Parameters:
        charge: total charge requested from xyz2mol.
        prot: protonation state describing the molecule.
        allow: forwarded to xyz2mol as `allow_charged_fragments`.
        embed_chiral: accepted for interface compatibility; not read in this body.
        debug: verbosity level; the SMILES string is printed when `debug >= 2`.

    Returns:
        A `charge_state` built from the first molecule returned by xyz2mol.
    """
    n_atoms = prot.natoms
    atomic_numbers = prot.atnums

    # Note carried over from the original code: prot.coords and prot.cov_factor
    # will not be used (they are still passed along positionally).
    mols = xyz2mol(
        atomic_numbers,
        prot.coords,
        prot.adjmat,
        prot.cov_factor,
        charge=charge,
        allow_charged_fragments=allow,
    )
    print(f"GET_CHARGE.{len(mols)=} received from xyz2mol with charge {charge}")
    if len(mols) > 1:
        print("WARNING: More than 1 mol received from xyz2mol for initcharge:", charge)

    # Only the first molecule is analysed, even when several were received.
    first_mol = mols[0]

    # SMILES string generated with rdkit
    smiles = Chem.MolToSmiles(first_mol)
    if debug >= 2:
        print(f"GET_CHARGE. {smiles=}")

    # Per-atom formal charges and their sum
    formal_charges = [first_mol.GetAtomWithIdx(idx).GetFormalCharge() for idx in range(n_atoms)]
    net_charge = sum(formal_charges)

    # Connectivity check of the rdkit object against the expected atom count
    connectivity_ok = check_rdkit_obj_connectivity(first_mol, prot.natoms, charge, debug=debug)

    # Pack everything into the charge_state object handed back to the caller
    return charge_state(connectivity_ok, net_charge, formal_charges, first_mol, smiles, charge, allow, prot)

######################################################
def prepare_mol (mol):
Expand Down
2 changes: 1 addition & 1 deletion cell2mol/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def handle_error(case: int):
if case == 2: print("We detected that H atoms are likely missing. This will cause errors in the charge prediction, so STOPPING pre-emptively.")
if case == 3: print("We failed to get fragments. STOPPING pre-emptively.")
if case == 4: print("After reconstruction of the unit cell, we still detected some fragments. STOPPING pre-emptively.")
if case == 5: print("Empty list of possible charges received for molecule or ligand")
if case == 5: print("Error in list of possible charges received for molecule or ligand")
if case == 6: print("More than one valid possible charge distribution found")
if case == 7: print("No valid possible charge distribution found")
# if case == 8: print("Error while preparing molecules")
Expand Down
303 changes: 226 additions & 77 deletions cell2mol/test/check_Cell_object.ipynb

Large diffs are not rendered by default.

74 changes: 45 additions & 29 deletions cell2mol/xyz2mol.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,8 @@ def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True):
# make a list of valences, e.g. for CO: [[4],[2,1]]
valences_list_of_lists = []
AC_valence = list(AC.sum(axis=1))
print(f"{AC_valence=}")
wrong = 0

for i, (atomicNum, valence) in enumerate(zip(atoms, AC_valence)):
# valence can't be smaller than number of neighbours
Expand All @@ -502,12 +504,20 @@ def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True):
possible_valence.append(valence)
# if atomicNum == 15:
# print("Possible valences for:", atomicNum,"are",possible_valence, valence)
if not possible_valence:
pass
# print('Valence of atom',i,'is',valence,'which bigger than allowed max',max(atomic_valence[atomicNum]),'. Stopping')
if len(possible_valence) == 0:
print('WARNING!! Valence of atom', elemdatabase.elementsym[atomicNum], i,\
'is',valence,'which bigger than allowed max',max(atomic_valence[atomicNum]),'. Stopping')
possible_valence.append(valence)
wrong += 1
# sys.exit()
valences_list_of_lists.append(possible_valence)

print(f"{wrong=}")
if wrong > 0:
# print(f"AC2BO: {wrong=}")
return None, atomic_valence_electrons

print(f"\tAC2BO: {valences_list_of_lists=}")

# convert [[4],[2,1]] to [[4,2],[4,1]]
valences_list = []
for i in itertools.product(*valences_list_of_lists):
Expand All @@ -519,12 +529,15 @@ def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True):
best_BO = AC.copy()
# print("Final valences list:", list(valences_list), len(list(valences_list)))
BO_is_OK_list = []
# print(f"AC2BO: {valences_list=}")
for valences in valences_list:

# print("Sending", valences, AC_valence, "to get_UA")
#print(f"\tSending", valences, AC_valence, "to get_UA")
UA, DU_from_AC = get_UA(valences, AC_valence)

check_len = len(UA) == 0
#print (f"\tAC2BO: check_len", check_len)
#print(f"\tUA", UA)
if check_len:
check_bo = BO_is_OK(
AC,
Expand All @@ -540,6 +553,7 @@ def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True):
check_bo = None

if check_len and check_bo:
print(f"\tAC2BO: return AC", check_len, check_bo)
return AC, atomic_valence_electrons

UA_pairs_list = get_UA_pairs(UA, AC, use_graph=use_graph)
Expand All @@ -566,25 +580,29 @@ def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True):
allow_charged_fragments=allow_charged_fragments,
)

if status:
print(f"\tAC2BO: status", status)
return BO, atomic_valence_electrons
elif (
BO.sum() >= best_BO.sum()
and valences_not_too_large(BO, valences)
and charge_OK
):
# print(f"\tAC2BO: status", status, "BO.sum()", BO.sum(), "best_BO.sum()", best_BO.sum())
best_BO = BO.copy()
# if status:
# return BO, atomic_valence_electrons
# elif (
# BO.sum() >= best_BO.sum()
# and valences_not_too_large(BO, valences)
# and charge_OK
# ):
# best_BO = BO.copy()
# if status:
# return BO, atomic_valence_electrons
if status:
if (
BO.sum() >= best_BO.sum()
and valences_not_too_large(BO, valences)
and charge_OK
):
best_BO = BO.copy()
print("AC2BO: best bo", best_BO)
# if (
# BO.sum() >= best_BO.sum()
# and valences_not_too_large(BO, valences)
# and charge_OK
# ):
# best_BO = BO.copy()
# print("AC2BO: best bo", best_BO)
# print("best bo", best_BO)
print(f"\tAC2BO: return best bo")
#print("AC2BO: return best bo", best_BO)
return best_BO, atomic_valence_electrons


Expand All @@ -599,7 +617,9 @@ def AC2mol(mol, AC, atoms, charge, allow_charged_fragments=True, use_graph=True)
allow_charged_fragments=allow_charged_fragments,
use_graph=use_graph,
)

if BO is None:
return [], None

# add BO connectivity and charge info to mol object
mol = BO2mol(
mol,
Expand Down Expand Up @@ -859,14 +879,10 @@ def xyz2mol(
use_graph=use_graph,
)

# Check for stereocenters and chiral centers

if embed_chiral:
is_okay = []
for new_mol in new_mols:
is_okay.append(chiral_stereo_check(new_mol))
return new_mols, all(is_okay)

# Check for stereocenters and chiral centers -> Move to get_charge function
# if embed_chiral:
# chiral_stereo_check(new_mol))

if exportBO:
return new_mols, BO
else:
Expand Down

0 comments on commit a876327

Please sign in to comment.