Skip to content

Commit

Permalink
Add final_c2m_driver and modify codes accordingly
Browse files Browse the repository at this point in the history
  • Loading branch information
choglass committed Nov 9, 2024
1 parent 547ad56 commit 669856d
Show file tree
Hide file tree
Showing 9 changed files with 421 additions and 24 deletions.
3 changes: 2 additions & 1 deletion cell2mol/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import sys
import cell2mol
#from cell2mol import c2m_driver
from cell2mol import new_c2m_driver
# from cell2mol import new_c2m_driver
from cell2mol import final_c2m_driver

if __package__ == "":
path = os.path.dirname(os.path.dirname(__file__))
Expand Down
57 changes: 40 additions & 17 deletions cell2mol/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@
#### CLASSES FOR CELL2MOL 2 ####
##################################
class specie(object):
def __init__(self, labels: list, coord: list, frac_coord: list, radii: list=None) -> None:
def __init__(self, labels: list, coord: list, frac_coord: list=None, radii: list=None) -> None:

# Sanity Checks
assert len(labels) == len(coord)
assert len(coord) == len(frac_coord)
if frac_coord is not None:
self.frac_coord = frac_coord
assert len(coord) == len(frac_coord)

# Optional Information
if radii is not None: self.radii = radii
Expand Down Expand Up @@ -195,8 +197,12 @@ def set_atoms(self, atomlist=None, create_adjacencies: bool=False, debug: int=0)
## For each l in labels, create an atom class object.
ismetal = elemdatabase.elementblock[l] == "d" or elemdatabase.elementblock[l] == "f"
if debug > 0: print(f"SPECIE.SET_ATOMS: {ismetal=}")
if ismetal: newatom = metal(l, self.coord[idx], self.frac_coord[idx], radii=self.radii[idx])
else: newatom = atom(l, self.coord[idx], self.frac_coord[idx],radii=self.radii[idx])
if self.frac_coord is not None:
if ismetal: newatom = metal(l, self.coord[idx], self.frac_coord[idx], radii=self.radii[idx])
else: newatom = atom(l, self.coord[idx], self.frac_coord[idx],radii=self.radii[idx])
else :
if ismetal: newatom = metal(l, self.coord[idx], radii=self.radii[idx])
else: newatom = atom(l, self.coord[idx], radii=self.radii[idx])
if debug > 0: print(f"SPECIE.SET_ATOMS: added atom to specie: {self.formula}")
newatom.add_parent(self, index=idx)
self.atoms.append(newatom)
Expand Down Expand Up @@ -344,8 +350,9 @@ def __repr__(self, indirect: bool=False):
### MOLECULE ##
###############
class molecule(specie):
def __init__(self, labels: list, coord: list, frac_coord: list, radii: list=None) -> None:
def __init__(self, labels: list, coord: list, frac_coord: list=None, radii: list=None) -> None:
self.subtype = "molecule"
if frac_coord is not None: self.frac_coord = frac_coord
specie.__init__(self, labels, coord, frac_coord, radii)

def __repr__(self):
Expand Down Expand Up @@ -393,7 +400,8 @@ def split_complex(self, debug: int=0):
# Split the "rest" to obtain the ligands
rest_labels = extract_from_list(rest_idx, self.labels, dimension=1)
rest_coord = extract_from_list(rest_idx, self.coord, dimension=1)
rest_frac = extract_from_list(rest_idx, self.frac_coord, dimension=1)
if self.frac_coord is not None:
rest_frac = extract_from_list(rest_idx, self.frac_coord, dimension=1)
rest_indices = extract_from_list(rest_idx, self.indices, dimension=1)
rest_radii = extract_from_list(rest_idx, self.radii, dimension=1)
rest_atoms = extract_from_list(rest_idx, self.atoms, dimension=1)
Expand All @@ -413,21 +421,25 @@ def split_complex(self, debug: int=0):
lig_indices = extract_from_list(b, rest_indices, dimension=1)
lig_labels = extract_from_list(b, rest_labels, dimension=1)
lig_coord = extract_from_list(b, rest_coord, dimension=1)
lig_frac_coord = extract_from_list(b, rest_frac, dimension=1)
if self.frac_coord is not None:
lig_frac_coord = extract_from_list(b, rest_frac, dimension=1)
lig_radii = extract_from_list(b, rest_radii, dimension=1)
lig_atoms = extract_from_list(b, rest_atoms, dimension=1)

if debug > 0: print(f"CREATING LIGAND: {labels2formula(lig_labels)}")
# Create Ligand Object
newligand = ligand(lig_labels, lig_coord, lig_frac_coord, radii=lig_radii)
if self.frac_coord is not None:
newligand = ligand(lig_labels, lig_coord, lig_frac_coord, radii=lig_radii)
else :
newligand = ligand(lig_labels, lig_coord, radii=lig_radii)
# For debugging
newligand.origin = "split_complex"
# Define the molecule as parent of the ligand. Bottom-Up hierarchy
newligand.add_parent(self, indices=lig_indices)

if self.check_parent("unit_cell"):
cell_indices = [a.get_parent_index("unit_cell") for a in lig_atoms]
newligand.add_parent(self.get_parent("unit_cell"), indices=cell_indices)
if self.check_parent("unitcell"):
cell_indices = [a.get_parent_index("unitcell") for a in lig_atoms]
newligand.add_parent(self.get_parent("unitcell"), indices=cell_indices)

if self.check_parent("reference"):
ref_indices = [a.get_parent_index("reference") for a in lig_atoms]
Expand Down Expand Up @@ -463,13 +475,18 @@ def get_hapticity(self, debug: int=0):
for entry in lig.haptic_type:
if entry not in self.haptic_type: self.haptic_type.append(entry)
return self.haptic_type


def save(self, path):
print(f"SAVING cell2mol CELL object to {path}")
with open(path, "wb") as fil:
pickle.dump(self,fil)
###############
### LIGAND ####
###############
class ligand(specie):
def __init__(self, labels: list, coord: list, frac_coord: list, radii: list=None) -> None:
def __init__(self, labels: list, coord: list, frac_coord: list=None, radii: list=None) -> None:
self.subtype = "ligand"
if frac_coord is not None: self.frac_coord = frac_coord
specie.__init__(self, labels, coord, frac_coord, radii)
self.evaluate_as_nitrosyl()

Expand Down Expand Up @@ -592,7 +609,8 @@ def split_ligand(self, debug: int=0):
print(f"\tLIGAND.SPLIT_LIGAND: {connected_idx=}")
conn_labels = extract_from_list(connected_idx, self.labels, dimension=1)
conn_coord = extract_from_list(connected_idx, self.coord, dimension=1)
conn_frac_coord = extract_from_list(connected_idx, self.frac_coord, dimension=1)
if self.frac_coord is not None:
conn_frac_coord = extract_from_list(connected_idx, self.frac_coord, dimension=1)
conn_radii = extract_from_list(connected_idx, self.radii, dimension=1)
conn_atoms = extract_from_list(connected_idx, self.atoms, dimension=1)
if debug >= 2: print(f"\tLIGAND.SPLIT_LIGAND: {conn_labels=}")
Expand All @@ -607,11 +625,15 @@ def split_ligand(self, debug: int=0):
if debug > 1: print(f"\tLIGAND.SPLIT_LIGAND: {gr_indices=}")
gr_labels = extract_from_list(b, conn_labels, dimension=1, debug=debug)
gr_coord = extract_from_list(b, conn_coord, dimension=1)
gr_frac_coord = extract_from_list(b, conn_frac_coord, dimension=1)
if self.frac_coord is not None:
gr_frac_coord = extract_from_list(b, conn_frac_coord, dimension=1)
gr_radii = extract_from_list(b, conn_radii, dimension=1)
gr_atoms = extract_from_list(b, conn_atoms, dimension=1)
# Create Group Object
newgroup = group(gr_labels, gr_coord, gr_frac_coord, radii=gr_radii)
if self.frac_coord is not None:
newgroup = group(gr_labels, gr_coord, gr_frac_coord, radii=gr_radii)
else:
newgroup = group(gr_labels, gr_coord, radii=gr_radii)
# For debugging
newgroup.origin = "split_ligand"
# Define the ligand as parent of the group. Bottom-Up hierarchy
Expand Down Expand Up @@ -655,8 +677,9 @@ def get_hapticity(self, debug: int=0):
#### GROUP ####
###############
class group(specie):
def __init__(self, labels: list, coord: list, frac_coord: list, radii: list=None) -> None:
def __init__(self, labels: list, coord: list, frac_coord: list=None, radii: list=None) -> None:
self.subtype = "group"
if frac_coord is not None: self.frac_coord = frac_coord
specie.__init__(self, labels, coord, frac_coord, radii)

#######################################################
Expand Down
59 changes: 59 additions & 0 deletions cell2mol/final_c2m_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os
import sys
from cell2mol.helper import parsing_arguments
from cell2mol.refcell import process_refcell
from cell2mol.unitcell import process_unitcell
from cell2mol.xyz_molecule import get_molecule

def main():
input, system_type, cell_para, debug_mode = parsing_arguments()
current_dir = os.getcwd()
input_path = os.path.normpath(input)
dir, file = os.path.split(input_path)
name, extension = os.path.splitext(file)

print(input, input_path, system_type, cell_para, debug_mode, name, extension)
if not os.path.exists(input_path):
raise FileNotFoundError(f"Input file not found: {input_path}")

if extension == ".cif":
handle_cif_file(input_path, system_type, name, current_dir, debug_mode)
elif extension == ".xyz":
handle_xyz_file(input_path, system_type, name, cell_para, current_dir, debug_mode)
else:
sys.exit("Invalid file extension")


def handle_cif_file(input_path, system_type, name, current_dir, debug_mode):
if system_type == "reference":
print("Processing reference (Wyckoff sites) from .cif file")
process_refcell(input_path, name, current_dir, debug_mode)
elif system_type == "unitcell":
print("Processing unit cell from .cif file")
process_unitcell(input_path, name, current_dir, debug_mode)
else:
sys.exit("Invalid system type for .cif file")


def handle_xyz_file(input_path, system_type, name, cell_para, current_dir, debug_mode):
if system_type == "unitcell":
if cell_para is None:
sys.exit("Cell parameters must be provided for .xyz file of a unit cell")
else:
print("Processing unit cell from .xyz file")
# users should provide chemical formula (Fe-O2-H3) of refmoleculist
# if users provide smiles, we check compare_species in connectivity module
elif system_type == "molecule":
print("Processing molecule from .xyz file")
get_molecule(input_path, name, current_dir, debug_mode)
else:
sys.exit("Invalid system type for .xyz file")

def write_error(error_obj, prefix):
if hasattr(error_obj, "error_case"):
case = error_obj.error_case
error_filepath = os.path.join(current_dir, f"{prefix}_error_{case}.out")
write_to_file(error_filepath, lambda: handle_error(case))

if __name__ == "__main__" or __name__ == "cell2mol.final_c2m_driver":
main()
43 changes: 38 additions & 5 deletions cell2mol/helper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python

import argparse

import numpy as np

def parsing_arguments():
"""Parses the arguments of the command line.
Expand All @@ -22,28 +22,61 @@ def parsing_arguments():
dest="filename",
type=str,
required=True,
help="Filename of Input (.info or .cif file)",
help="Filename of Input (.info, .xyz, or .cif file)",
)

parser.add_argument(
"-t",
"--type",
dest="system_type",
type=str,
choices=["reference", "unitcell", "molecule"],
required=True,
help="Type of information in the input file ('reference', 'unitcell' or 'molecule')",
)

parser.add_argument(
"--cell-para",
dest="cell_para",
type=float,
nargs=6,
help="Cell parameters (a, b, c, alpha, beta, gamma) for .xyz file",
)

parser.add_argument(
"-v",
"--verbose",
#dest="verbose",
help="Extended output for debugging.",
action="store_true",
)

parser.add_argument(
"-q",
"--quiet",
#dest="quiet",
help="Suppress all screen output. Overrides --verbose flag.",
action="store_true",
)

args = parser.parse_args()

return args.filename, args.verbose, args.quiet
cell_para = None
if args.filename.endswith(".xyz") and args.system_type == "unitcell":
if args.cell_para is None:
parser.error("Cell parameters must be provided for .xyz file of an unit cell")
cell_para = np.array(args.cell_para)

debug_mode = determine_debug_level(args.verbose, args.quiet)
return args.filename, args.system_type, cell_para, debug_mode

def determine_debug_level(isverbose, isquiet):
if isverbose and not isquiet:
return 2
elif isverbose and isquiet:
return 0
elif not isverbose and isquiet:
return 0
elif not isverbose and not isquiet:
return 1


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion cell2mol/new_c2m_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@
##########################################
# Define new cell object for the unit cell
newcell = cell(name, cell_labels, cell_pos, cell_fracs, cell_vector, cell_param)
newcell.get_subtype("unit_cell")
newcell.get_subtype("unitcell")

if refcell.error_case != 0:
pass
Expand Down
87 changes: 87 additions & 0 deletions cell2mol/refcell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import os
import sys
from ase.io import read
from contextlib import redirect_stdout
from cell2mol.classes import cell
from cell2mol.read_write import get_wyckoff_positions
from cell2mol.cell_operations import frac2cart_fromparam
from cell2mol.new_cell_reconstruction import modify_cov_factor_due_to_H, modify_cov_factor_due_to_possible_charges
from cell2mol.other import handle_error

# Constants
VERSION = "2.0"
COV_FACTOR = 1.3
METAL_FACTOR = 1.0

def process_refcell(input_path, name, current_dir, debug=0):
# Set up filenames
cell_fname = os.path.join(current_dir, f"Cell_{name}.cell")
ref_cell_fname = os.path.join(current_dir, f"Ref_Cell_{name}.cell")
output_fname = os.path.join(current_dir, "cell2mol.out")
summary_fname = os.path.join(current_dir, "summary.out")


# Redirect stdout to the output file for logging
with open(output_fname, "w") as output:
with redirect_stdout(output):
print(f"cell2mol version {VERSION}")
print(f"INITIATING cell object from input path: {input_path}")
print(f"Debug level: {debug}")

# # Read .cif file
structure = read(input_path)
cell_vector = structure.cell.array
cell_param = structure.cell.cellpar()

# Create the reference cell
refcell = create_reference(input_path, name, cell_vector, cell_param, debug)

# Finalize and save the reference cell object if no errors
if refcell.error_case == 0:
get_unique_species_in_reference(refcell, debug)
else:
print(f"Error occurred in processing reference cell: error case {refcell.error_case}")
refcell.save(ref_cell_fname)

error_fname = os.path.join(current_dir, f"reference_error_{refcell.error_case}.out")
with open(error_fname, "w") as error_output:
with redirect_stdout(error_output):
handle_error(refcell.error_case)
return refcell

def create_reference (input_path, name, cell_vector, cell_param, debug):
"""Create the reference cell object."""

ref_labels, ref_fracs = get_wyckoff_positions(input_path)
ref_pos = frac2cart_fromparam(ref_fracs, cell_param)

refcell = cell(name, ref_labels, ref_pos, ref_fracs, cell_vector, cell_param)
refcell.get_subtype("reference")
refcell.get_reference_molecules(ref_labels, ref_fracs, cov_factor=COV_FACTOR, debug=debug)
refcell = modify_cov_factor_due_to_H(refcell, debug=debug)

return refcell

def get_unique_species_in_reference (refcell, debug):
"""Processes the reference cell to obtain unique species and handle any errors."""

refcell.get_unique_species(debug=debug)
if debug >= 1:
print(f"Unique species: {[specie.formula for specie in refcell.unique_species]}")
print(f"Species list: {[specie.formula for specie in refcell.species_list]}\n")

refcell = modify_cov_factor_due_to_possible_charges(refcell, debug=debug)
refcell.get_selected_cs(debug=debug)
refcell.assess_errors(mode="possible_charges")

# Run the main function
if __name__ == "__main__":

input = sys.argv[1]
current_dir = os.getcwd()
input_path = os.path.normpath(input)
dir, file = os.path.split(input_path)
name, extension = os.path.splitext(file)

# Example usage, replace with actual arguments
process_refcell(input_path, name, current_dir, debug=1)
Loading

0 comments on commit 669856d

Please sign in to comment.