Skip to content

Commit

Permalink
Reorganize spin prediction model
Browse files Browse the repository at this point in the history
  • Loading branch information
choglass committed Jan 9, 2024
1 parent f6728de commit 375c0ee
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 37 deletions.
152 changes: 139 additions & 13 deletions cell2mol/c2m_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
from typing import Tuple
import sklearn
from cell2mol import __file__
from cell2mol.elementdata import ElementData


elemdatabase = ElementData()

##################################################################################
def get_refmoleclist_and_check_missingH(cell: object, reflabels: list, fracs: list, debug: int=0) -> Tuple[object, float, float]:
Expand All @@ -38,12 +42,12 @@ def get_refmoleclist_and_check_missingH(cell: object, reflabels: list, fracs: li

# Get ref.molecules --> output: a valid list of ref.molecules
# refmoleclist, covalent_factor, metal_factor, Warning = get_reference_molecules(reflabels, refpos, debug=debug)
refmoleclist, covalent_factor, metal_factor, Warning = get_reference_molecules_simple (reflabels, refpos, debug=1)
refmoleclist, covalent_factor, metal_factor, Warning = get_reference_molecules_simple (reflabels, refpos, debug)
cell.warning_list.append(Warning)

# Check missing hydrogens in ref.molecules
if not any(cell.warning_list):
Warning, ismissingH, Missing_H_in_C, Missing_H_in_CoordWater = check_missingH(refmoleclist)
Warning, ismissingH, Missing_H_in_C, Missing_H_in_CoordWater = check_missingH(refmoleclist, debug)
cell.warning_list.append(Missing_H_in_C)
cell.warning_list.append(Missing_H_in_CoordWater)

Expand Down Expand Up @@ -244,27 +248,143 @@ def assign_spin (cell: object, debug: int=0) -> object:
print("#########################################")

for mol in cell.moleclist:
N = count_N(mol)
# count number of electrons in the complex
N = count_N(mol)

if mol.type == "Complex":
if len(mol.metalist) == 1: # mono-metallic complexes
N = count_N(mol)
met = mol.metalist[0]
period = elemdatabase.elementperiod[met.label]
d_elec = count_d_elec (met.label, met.totcharge)

if period == 4: # 3d transition metals
if d_elec in [0, 1, 9, 10]:
if N % 2 == 0:
mol.magnetism(1)
else:
mol.magnetism(2)
elif d_elec in [2, 3] and met.hapticity == False :
if N % 2 == 0:
mol.magnetism(3)
else:
mol.magnetism(4)
elif d_elec in [4, 5, 6, 7, 8] or (d_elec in [2, 3] and met.hapticity == True) :
# Predict the spin multiplicity of the metal with the random forest model
feature = generate_feature_vector (met)
path_rf = os.path.join( os.path.abspath(os.path.dirname(__file__)), "total_spin_3131.pkl")
rf = pickle.load(open(path_rf, 'rb'))
predictions = rf.predict(feature)
spin_rf = predictions[0]
mol.magnetism(spin_rf)
else :
print("Error: d_elec is not in the range of 0-10", d_elec)

if met.hapticity == False :
rel = calcualte_relative_metal_radius (met)
met.relative_radius(rel, rel, rel)
else :
rel = calcualte_relative_metal_radius (met)
rel_g, rel_c = calcualte_relative_metal_radius_haptic_complexes (met)
met.relative_radius(rel, rel_g, rel_c)

for lig in mol.ligandlist:
if count_N(lig) %2 == 0:
lig.magnetism(1)
else:
lig.magnetism(2)

if debug >= 1: print(f"{mol.type=}, {mol.formula=}, {mol.spin=} {spin_rf=}")
if debug >= 1: print(f"{met.label=} {met.hapticity=} {met.hapttype=} {met.geometry=} {met.coordination_number=} {met.coordinating_atoms=}")

#elif (period == 5 or period == 6 ) and (d_elec in [2, 3] and met.hapticity == False) :
# TODO : Predict the ground-state spin of coordination complexes with 4d or 5d transition metals (d2, d3)
else : # 4d or 5d transition metals
if N % 2 == 0:
mol.magnetism(1)
else:
mol.magnetism(2)

for lig in mol.ligandlist:
if count_N(lig) %2 == 0:
lig.magnetism(1)
else:
lig.magnetism(2)

else : # Bi- & Poly-metallic complexes
if N % 2 == 0:
mol.magnetism(1)
else:
mol.magnetism(2)

for lig in mol.ligandlist:
if count_N(lig) %2 == 0:
lig.magnetism(1)
else:
lig.magnetism(2)

else: # mol.type == "Other"
if N % 2 == 0:
mol.magnetism(1)
else:
mol.magnetism(2)

if debug >= 1:
for mol in cell.moleclist:
if mol.type == "Complex":
print(f"{mol.type=}, {mol.formula=}, {mol.spin=}")
for lig in mol.ligandlist:
if lig.natoms != 1:
print(f"\t{lig.formula=}, {lig.spin=}")
else :
print(f"\t{lig.formula=}")
else :
if mol.natoms != 1:
print(f"{mol.type=}, {mol.formula=}, {mol.spin=}")
else :
print(f"{mol.type=}, {mol.formula=}")

return cell

##################################################################################
def assign_spin_old (cell: object, debug: int=0) -> object:
"""Assign spin multiplicity to molecules in the cell object
Args:
cell (object): cell object
debug (int, optional): debug level. Defaults to 0.
Returns:
object: cell object with spin multiplicity assigned
"""

if debug >= 1:
print("#########################################")
print("Assigning spin multiplicity")
print("#########################################")

for mol in cell.moleclist:
# count number of electrons in the complex
N = count_N(mol)

if mol.type == "Complex":
if len(mol.metalist) == 1: # mono-metallic complexes
met = mol.metalist[0]

# count valence electrons
d_elec = count_d_elec(met.label, met.totcharge)

# calculate relative metal radius
rel = calcualte_relative_metal_radius(met)

# Count nitrosyl ligands
# Make a list of ligands
arr = []
for lig in mol.ligandlist:
arr.append(sorted(lig.labels))
if count_N(lig) %2 == 0:
lig.magnetism(1)
else:
lig.magnetism(2)

lig.magnetism(2)

# Count nitrosyl ligands
nitrosyl = count_nitrosyl(np.array(arr, dtype=object))
if debug >= 2: print(np.array(arr, dtype=object))
if debug >= 2: print(f"{nitrosyl=}")
Expand All @@ -279,6 +399,7 @@ def assign_spin (cell: object, debug: int=0) -> object:
feature = generate_feature_vector (met)
print(feature)
path_rf = os.path.join( os.path.abspath(os.path.dirname(__file__)), "TM-GSspin_RandomForest.pkl")

#print(path_rf)
rf = pickle.load(open(path_rf, 'rb'))
predictions = rf.predict(feature)
Expand All @@ -287,9 +408,13 @@ def assign_spin (cell: object, debug: int=0) -> object:
mol.ml_prediction(spin_rf)

if spin == 0 : # unknown spin state
mol.magnetism(1)
if N % 2 == 0:
mol.magnetism(1)
else:
mol.magnetism(2)
else:
mol.magnetism(spin)
mol.magnetism(spin_rf)
#mol.magnetism(spin)

if debug >= 1: print(f"{mol.type=}, {mol.formula=}, {mol.spin=} {mol.spin_rf=}")
if debug >= 1: print(f"{met.label=} {met.hapticity=} {met.geometry=} {met.coordination_number=} {met.coordinating_atoms=}")
Expand All @@ -300,7 +425,7 @@ def assign_spin (cell: object, debug: int=0) -> object:
#if debug >= 1: print(f"{rel_g=} {rel_c=}")
met.relative_radius(rel, rel_g, rel_c)

if count_N(mol) % 2 == 0:
if N % 2 == 0:
mol.magnetism(1) # spin multiplicity = 1 Singlet
else:
mol.magnetism(2) # spin multiplicity = 2 Doublet
Expand All @@ -310,7 +435,7 @@ def assign_spin (cell: object, debug: int=0) -> object:
if debug >= 1: print(f"met_OS={met.totcharge} {d_elec=} {N=} {nitrosyl=} {met.rel=} {met.rel_g=} {met.rel_c=}\n")

else : # Bi- & Poly-metallic complexes
if count_N(mol) % 2 == 0:
if N % 2 == 0:
mol.magnetism(1)
else:
mol.magnetism(2)
Expand All @@ -322,7 +447,7 @@ def assign_spin (cell: object, debug: int=0) -> object:
lig.magnetism(2)

else: # mol.type == "Other" or "Ligand"
if count_N(mol) % 2 == 0:
if N % 2 == 0:
mol.magnetism(1)
else:
mol.magnetism(2)
Expand Down Expand Up @@ -394,8 +519,9 @@ def cell2mol(infopath: str, refcode: str, output_dir: str, step: int=3, debug: i

if not any(newcell.warning_list):
if debug >= 1: print("Charge Assignment successfully finished.\n")
# TODO : Compare assigned charges with ML predicted charges

# Spin state assignment
# Spin state assignment
newcell = assign_spin(newcell, debug=debug)

if debug >= 1: newcell.print_charge_assignment()
Expand Down
43 changes: 28 additions & 15 deletions cell2mol/random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@

dataframe=argv[1]
metal = argv[2]
mode = argv[3]
#mode = argv[3]
prop = argv[3]


print("Sklearn version:", sklearn.__version__)

Expand All @@ -47,12 +49,21 @@

#exit()

prop = "spin_multiplicity"
#prop = "spin_multiplicity"

# extract = ["elem_nr", "m_ox", "d_elec"] # F_TM
# extract = ["CN", "geom_nr", "rel_m"] # F_CE
extract = ["elem_nr", "m_ox", "d_elec", "CN", "geom_nr", "rel_m"] # F_TM+CE
#prop = "m_ox"

if prop == "spin_multiplicity" or prop == "spin" or prop == "s" :
# extract = ["elem_nr", "m_ox", "d_elec"] # F_TM
# extract = ["CN", "geom_nr", "rel_m"] # F_CE
#extract = ["elem_nr", "m_ox", "d_elec", "CN", "geom_nr", "rel_m"] # F_TM+CE
extract = ["elem_nr", "m_ox", "d_elec", "CN", "geom_nr", "rel_m", "hapticity"] # F_TM+CE+hapticity
elif prop == "m_ox":
extract = ["elem_nr", "CN", "geom_nr", "rel_m", "hapticity"]
else :
print("No such property in the database")
exit()


Nfix=list(df["refcode"])
print("the number of complexes :", len(Nfix))
Expand Down Expand Up @@ -146,19 +157,21 @@
print(f"\n Incorrect predictions for replica {rep}:")
for idx, sys in enumerate(is_correct):
if not sys and print_incorrect:
m_ox = df[ df.refcode == l_te[idx] ]["m_ox"].item()
metal_elem = df[ df.refcode == l_te[idx] ]["metal"].item()
print(
f"System {l_te[idx]} has prediction {predictions[idx]} with probability {np.max(prediction_probs[idx])} and reference {y_te[idx]}"
f"System {l_te[idx]} has prediction {predictions[idx]} with probability {np.max(prediction_probs[idx])} and reference {y_te[idx]} metal {metal_elem} m_ox {m_ox}"
)

print("\n \n Summary of replica results:")
print(f"Training mean accuracy was {np.mean(acc_train)} with STD {np.std(acc_train)}")
print(f"Test mean accuracy was {np.mean(acc_test)} with STD {np.std(acc_test)}")
print(f"Training mean f1_score_micro was {np.mean(f1_train_micro)} with STD {np.std(f1_train_micro)}")
print(f"Test mean f1_score_micro was {np.mean(f1_test_micro)} with STD {np.std(f1_test_micro)}")
print(f"Training mean f1_score_macro was {np.mean(f1_train_macro)} with STD {np.std(f1_train_macro)}")
print(f"Test mean f1_score_macro was {np.mean(f1_test_macro)} with STD {np.std(f1_test_macro)}")
print(f"Training mean f1_score_weighted was {np.mean(f1_train_weighted)} with STD {np.std(f1_train_weighted)}")
print(f"Test mean f1_score_weighted was {np.mean(f1_test_weighted)} with STD {np.std(f1_test_weighted)}")
print(f"Training mean accuracy was {round(np.mean(acc_train),3)} with STD {round(np.std(acc_train),3)}")
print(f"Test mean accuracy was {round(np.mean(acc_test),3)} with STD {round(np.std(acc_test),3)}")
print(f"Training mean f1_score_micro was {round(np.mean(f1_train_micro),3)} with STD {round(np.std(f1_train_micro),3)}")
print(f"Test mean f1_score_micro was {round(np.mean(f1_test_micro),3)} with STD {round(np.std(f1_test_micro),3)}")
print(f"Training mean f1_score_macro was {round(np.mean(f1_train_macro),3)} with STD {round(np.std(f1_train_macro),3)}")
print(f"Test mean f1_score_macro was {round(np.mean(f1_test_macro),3)} with STD {round(np.std(f1_test_macro),3)}")
print(f"Training mean f1_score_weighted was {round(np.mean(f1_train_weighted),3)} with STD {round(np.std(f1_train_weighted),3)}")
print(f"Test mean f1_score_weighted was {round(np.mean(f1_test_weighted),3)} with STD {round(np.std(f1_test_weighted),3)}")

try:
assert len(maxprob) == len(l_oos)
Expand All @@ -170,7 +183,7 @@
np.savetxt(metal + "_maxprob_{}.txt".format(mode), dat, delimiter=" ", fmt="%s")

# We train a model on all available data and save it
filename = "{}_{}.pkl".format(metal, mode)
filename = "{}_{}_{}.pkl".format(metal, prop, len(df))
learner = learner = rf_random.best_estimator_.fit(X, Y)
pickle.dump(learner, open(filename, "wb"))
print("feature importance", learner.feature_importances_)
Expand Down
26 changes: 17 additions & 9 deletions cell2mol/spin.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def get_centroid(arr: np.array) -> list:
return centroid

################################
def calcualte_relative_metal_radius_haptic_complexes (metal):
def calcualte_relative_metal_radius_haptic_complexes (metal, debug=0):

diff_list_g = []
diff_list_c = []
Expand Down Expand Up @@ -264,13 +264,13 @@ def calcualte_relative_metal_radius_haptic_complexes (metal):
average_c = round(np.average(diff_list_c), 3)
rel_c = round(average_c/elemdatabase.CovalentRadius3[metal.label], 3)

print(f"{len(metal.group_list)=}, {diff_list_g}, {average_g=}, {rel_g=}, {metal.label}, {elemdatabase.CovalentRadius3[metal.label]}")
print(f"{len(metal.group_list)=}, {diff_list_c}, {average_c=}, {rel_c=}, {metal.label}, {elemdatabase.CovalentRadius3[metal.label]}")
if debug >=2 : print(f"{len(metal.group_list)=}, {diff_list_g}, {average_g=}, {rel_g=}, {metal.label}, {elemdatabase.CovalentRadius3[metal.label]}")
if debug >=2 : print(f"{len(metal.group_list)=}, {diff_list_c}, {average_c=}, {rel_c=}, {metal.label}, {elemdatabase.CovalentRadius3[metal.label]}")

return rel_g, rel_c

################################
def calcualte_relative_metal_radius (metal):
def calcualte_relative_metal_radius (metal, debug=0):
""" Calculate relative metal radius for a given transition metal coordination complex
Args:
metal (obj): metal atom object
Expand All @@ -287,7 +287,7 @@ def calcualte_relative_metal_radius (metal):
average = round(np.average(diff_list), 3)
rel = round(average/elemdatabase.CovalentRadius3[metal.label], 3)

print(f"{metal.coordinating_atoms}, {diff_list}, {average=}, {rel=}, {metal.label}, {elemdatabase.CovalentRadius3[metal.label]}")
if debug >=2 : print(f"{metal.coordinating_atoms}, {diff_list}, {average=}, {rel=}, {metal.label}, {elemdatabase.CovalentRadius3[metal.label]}")

return rel

Expand All @@ -304,9 +304,17 @@ def generate_feature_vector (metal):
d_elec = count_d_elec (metal.label, m_ox)
CN = metal.coordination_number
geom_nr = make_geom_list()[metal.geometry]
rel = calcualte_relative_metal_radius (metal)

feature = np.array([[elem_nr, m_ox, d_elec, CN, geom_nr, rel]])

if metal.hapticity == False :
rel = calcualte_relative_metal_radius (metal)
hapticity = 0
else :
dummy, rel = calcualte_relative_metal_radius_haptic_complexes (metal)
hapticity = 1

print(f"{elem_nr=} {m_ox=} {d_elec=} {rel=} {hapticity=}\n")

feature = np.array([[elem_nr, m_ox, d_elec, CN, geom_nr, rel, hapticity]])

return feature

Expand Down Expand Up @@ -348,7 +356,7 @@ def get_posspin_v2 (d_elec: int, m_ox: int, geometry: str, metal: str) -> list:
elif metal == "Ni" and m_ox == 3:
posspin = ["LS"]
else :
posspin = ["LS", "HS"]
posspin = ["LS", "IS"]
return posspin

################################
Expand Down
Binary file added cell2mol/total_spin_3131.pkl
Binary file not shown.

0 comments on commit 375c0ee

Please sign in to comment.