Skip to content

Commit

Permalink
Fix IndexError from non_local_groups
Browse files Browse the repository at this point in the history
  • Loading branch information
choglass committed Jan 5, 2025
1 parent 75e470f commit 479542b
Show file tree
Hide file tree
Showing 12 changed files with 994 additions and 96 deletions.
54 changes: 46 additions & 8 deletions cell2mol/charge_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,20 @@ def get_possible_charge_state(spec: object, debug: int=0):
print(f"GET_POSSIBLE_CHARGE_STATE: {spec.formula} ({spec.subtype}) {spec.cov_factor=}")
charge_states = []
### Evaluates possible charges for each protonation state ###
print(f"GET_POSSIBLE_CHARGE_STATE: {spec.formula} ({spec.subtype}) ({len(spec.protonation_states)}) {spec.protonation_states=}")
for prot in spec.protonation_states:
# charge_states = []
charge_states_for_one_prot = []
final_charges = get_list_of_charges_to_try(prot)
if debug >= 2: print(f" POSCHARGE will try charges {final_charges}")

for ich in final_charges:
ch_state = get_charge(ich, prot) ## Protonation is passed to the ch_state object (ch_state.protonation)
charge_states.append(ch_state)
charge_states_for_one_prot.append(ch_state)
if ch_state is not None:
if debug >= 2: print(f" POSCHARGE: charge {ich} with smiles {ch_state.smiles}")
else :
if debug >= 2: print(f" POSCHARGE: charge {ich} failed {ch_state}")
charge_states.extend(charge_states_for_one_prot)
if debug >= 2: print(f"POSCHARGE: {len(charge_states)=}")
### After collecting charge states, then best ones are selected
if spec.subtype == "ligand":
Expand Down Expand Up @@ -222,6 +224,7 @@ def get_protonation_states_specie(specie: object, debug: int=0) -> list:

# Boolean that decides whether a non-local approach is needed
non_local_groups = 0
non_local_groups_indices = []
needs_nonlocal = False

# Initialization of the variables
Expand Down Expand Up @@ -387,6 +390,7 @@ def get_protonation_states_specie(specie: object, debug: int=0) -> list:
if a.connec == 1:
needs_nonlocal = True
non_local_groups += 1
non_local_groups_indices.append(idx)
if debug >= 2: print(f" GET_PROTONATION_STATES: will be sent to nonlocal due to {a.label} atom")
elif a.connec > 1:
block[idx] = 1
Expand All @@ -401,6 +405,7 @@ def get_protonation_states_specie(specie: object, debug: int=0) -> list:
elif a.connec == 2:
needs_nonlocal = True
non_local_groups += 1
non_local_groups_indices.append(idx)
if debug >= 2: print(f" GET_PROTONATION_STATES: will be sent to nonlocal due to {a.label} atom")
# block[idx] = 1
# elemlist[idx] = "H"
Expand Down Expand Up @@ -435,7 +440,8 @@ def get_protonation_states_specie(specie: object, debug: int=0) -> list:
if a.connec >= 3:
block[idx] = 1
# needs_nonlocal = True
# non_local_groups += 1
# non_local_groups += 1
# non_local_groups_indices.append(idx)
# elemlist[idx] = "H"
# addedlist[idx] = 1
else:
Expand All @@ -451,6 +457,7 @@ def get_protonation_states_specie(specie: object, debug: int=0) -> list:
else:
needs_nonlocal = True
non_local_groups += 1
non_local_groups_indices.append(idx)
if debug >= 2: print(f" GET_PROTONATION_STATES: will be sent to nonlocal due to {a.label} atom")
# Phosphorous
elif (a.connec >= 3) and a.label == "P":
Expand Down Expand Up @@ -480,6 +487,7 @@ def get_protonation_states_specie(specie: object, debug: int=0) -> list:
else:
needs_nonlocal = True
non_local_groups += 1
non_local_groups_indices.append(idx)
if debug >= 2: print(f" GET_PROTONATION_STATES: will be sent to nonlocal due to {a.label} atom")
# Silicon
elif a.label == "Si":
Expand All @@ -498,6 +506,7 @@ def get_protonation_states_specie(specie: object, debug: int=0) -> list:
if not needs_nonlocal:
needs_nonlocal = True
non_local_groups += 1
non_local_groups_indices.append(idx)
if debug >= 2: print(f" GET_PROTONATION_STATES: will be sent to nonlocal due to {a.label} atom with no rules")

# If, at this stage, we have found that any atom must be added, this is done before entering the non_local part.
Expand Down Expand Up @@ -544,12 +553,13 @@ def get_protonation_states_specie(specie: object, debug: int=0) -> list:
if debug >= 2: print(f" GET_PROTONATION_STATES: local_labels: {local_labels}")
if debug >= 2: print(f" GET_PROTONATION_STATES: block: {block}")
if debug >= 2: print(f" GET_PROTONATION_STATES: addedlist: {addedlist}")
if debug >= 2: print(f" GET_PROTONATION_STATES: {non_local_groups} non_local_groups groups found")
if debug >= 2: print(f" GET_PROTONATION_STATES: {non_local_groups=}")
if debug >= 2: print(f" GET_PROTONATION_STATES: {len(non_local_groups_indices)} non_local_groups groups found")
if debug >= 2: print(f" GET_PROTONATION_STATES: non_local_groups={[ligand.labels[idx] for idx in non_local_groups_indices]}")
if debug >= 2: print(f" GET_PROTONATION_STATES: {non_local_groups_indices=}")
# CREATES ALL COMBINATIONS OF PROTONATION STATES#
# Creates [0,1] tuples for each non_local protonation site
tmp = []
for kdx in range(0,non_local_groups):
for kdx in range(0, len(non_local_groups_indices)):
tmp.append([0,1])

if len(tmp) > 1:
Expand All @@ -574,13 +584,13 @@ def get_protonation_states_specie(specie: object, debug: int=0) -> list:

o_s = np.sum(com)
toallocate = int(0)
print(f"{non_local_groups=}")
print(f"{non_local_groups=} {non_local_groups_indices=}")
for jdx, a in enumerate(ligand.atoms):
if a.mconnec >= 1 and a.label not in avoid and block[jdx] == 0:
print(jdx, a.label, a.mconnec)
print("====")
for jdx, a in enumerate(ligand.atoms):
if a.mconnec >= 1 and a.label not in avoid and block[jdx] == 0:
if a.mconnec >= 1 and a.label not in avoid and block[jdx] == 0 and jdx in non_local_groups_indices:
print(a.label)
if non_local_groups > 1:
print(f"{com=} {toallocate=}")
Expand Down Expand Up @@ -904,11 +914,29 @@ def get_metal_poscharges(metal: object, debug: int=0) -> list:
# Coordination Geometry Table of the D-Block Elements and Their Ions.
# J. Chem. Educ. 1997, 74, 915.

# metalloids = ["B", "Si", "Ge", "As", "Sb", "Te", "Po"]
mol = metal.get_parent("molecule")
if not hasattr(mol,"is_haptic"): mol.get_hapticity()
atnum = elemdatabase.elementnr[metal.label]

at_charge = defaultdict(list)

# Alkali Metals
at_charge[3] = [1] # Li
at_charge[11] = [1] # Na
at_charge[13] = [1] # K
at_charge[31] = [1] # Rb
at_charge[49] = [1] # Cs
at_charge[81] = [1] # Fr

# Alkaline Earth Metals
at_charge[4] = [2] # Be
at_charge[12] = [2] # Mg
at_charge[20] = [2] # Ca
at_charge[38] = [2] # Sr
at_charge[56] = [2] # Ba
at_charge[88] = [2] # Ra

# 1st-row transition metals.
at_charge[21] = [3] # Sc
at_charge[22] = [2, 3, 4] # Ti
Expand Down Expand Up @@ -943,6 +971,16 @@ def get_metal_poscharges(metal: object, debug: int=0) -> list:
at_charge[79] = [1, 3] # Au
at_charge[80] = [2] # Hg

# post-transition metals
at_charge[13] = [0, 3] # Al
at_charge[31] = [0, 3] # Ga
at_charge[32] = [0, 2, 4] # Ge
at_charge[49] = [0, 3] # In
at_charge[50] = [0, 2, 4] # Sn
at_charge[81] = [0, 1, 3] # Tl
at_charge[82] = [0, 2, 4] # Pb
at_charge[83] = [0, 3] # Bi

poscharges = at_charge[atnum]

list_of_zero_OS = ["Fe", "Ni", "Ru"]
Expand Down
Loading

0 comments on commit 479542b

Please sign in to comment.