Skip to content

Commit

Permalink
Cleanup of inputs - homogenized wording in error messages, added value validation methods to AbstractTag for validations that are called multiple times, cleaned up cluttered tests and added comments to make it clearer what is being tested
Browse files Browse the repository at this point in the history
  • Loading branch information
benrich37 committed Dec 10, 2024
1 parent b32127a commit d391836
Show file tree
Hide file tree
Showing 4 changed files with 413 additions and 233 deletions.
166 changes: 97 additions & 69 deletions src/pymatgen/io/jdftx/generic_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _validate_value_type(
value = [self.read(tag, str(x)) for x in value] if self.can_repeat else self.read(tag, str(value))
tag, is_valid, value = self._validate_value_type(type_check, tag, value)
except (TypeError, ValueError):
warning = f"Could not fix the typing for {tag} "
warning = f"Could not fix the typing for tag '{tag}'"
try:
warning += f"{value}!"
except (ValueError, TypeError):
Expand All @@ -110,7 +110,7 @@ def _validate_value_type(

def _validate_repeat(self, tag: str, value: Any) -> None:
if not isinstance(value, list):
raise TypeError(f"The {tag} tag can repeat but is not a list: {value}")
raise TypeError(f"The '{tag}' tag can repeat but is not a list: '{value}'")

@abstractmethod
def read(self, tag: str, value_str: str) -> Any:
Expand All @@ -124,6 +124,39 @@ def read(self, tag: str, value_str: str) -> Any:
Any: The parsed value.
"""

def _general_read_validate(self, tag: str, value_str: Any) -> None:
"""General validation for values to be passed to a read method."""
try:
value = str(value_str)
except (ValueError, TypeError):
value = "(unstringable)"
if not isinstance(value_str, str):
raise TypeError(f"Value '{value}' for '{tag}' should be a string!")

def _single_value_read_validate(self, tag: str, value: str) -> None:
"""Validation for values to be passed to a read method for AbstractTag inheritors that only
read a single value."""
self._general_read_validate(tag, value)
if len(value.split()) > 1:
raise ValueError(f"'{value}' for '{tag}' should not have a space in it!")

def _check_unread_values(self, tag: str, unread_values: list[str]) -> None:
"""Check for unread values and raise an error if any are found. Used in the read method of TagContainers."""
if len(unread_values) > 0:
raise ValueError(
f"Something is wrong in the JDFTXInfile formatting, the following values for tag '{tag}' "
f"were not processed: {unread_values}"
)

def _check_nonoptional_subtags(self, tag: str, subdict: dict[str, Any], subtags: dict[str, AbstractTag]) -> None:
"""Check for non-optional subtags and raise an error if any are missing.
Used in the read method of TagContainers."""
for subtag, subtag_type in subtags.items():
if not subtag_type.optional and subtag not in subdict:
raise ValueError(
f"The subtag '{subtag}' for tag '{tag}' is not optional but was not populated during the read!"
)

@abstractmethod
def write(self, tag: str, value: Any) -> str:
"""Write the tag and its value as a string.
Expand All @@ -149,7 +182,7 @@ def _write(self, tag: str, value: Any, multiline_override: bool = False) -> str:
if self.multiline_tag or multiline_override:
tag_str += "\\\n"
if self.write_value:
tag_str += f"{value} "
tag_str += f"{value}".strip() + " "
return tag_str

def _get_token_len(self) -> int:
Expand All @@ -165,7 +198,7 @@ def get_list_representation(self, tag: str, value: Any) -> list | list[list]:
Returns:
list | list[list]: The value converted to a list representation.
"""
raise ValueError(f"Tag object has no get_list_representation method: {tag}")
raise ValueError(f"Tag object with tag '{tag}' has no get_list_representation method")

def get_dict_representation(self, tag: str, value: Any) -> dict | list[dict]:
"""Convert the value to a dict representation.
Expand All @@ -177,7 +210,7 @@ def get_dict_representation(self, tag: str, value: Any) -> dict | list[dict]:
Returns:
dict | list[dict]: The value converted to a dict representation.
"""
raise ValueError(f"Tag object has no get_dict_representation method: {tag}")
raise ValueError(f"Tag object with tag '{tag}' has no get_dict_representation method")


@dataclass
Expand Down Expand Up @@ -234,8 +267,7 @@ def read(self, tag: str, value: str) -> bool:
Returns:
bool: The parsed boolean value.
"""
if len(value.split()) > 1:
raise ValueError(f"'{value}' for {tag} should not have a space in it!")
self._single_value_read_validate(tag, value)
try:
if not self.write_value:
# accounts for exceptions where only the tagname is used, e.g.
Expand All @@ -246,7 +278,7 @@ def read(self, tag: str, value: str) -> bool:
self.raise_value_error(tag, value)
return self._TF_options["read"][value]
except (ValueError, TypeError, KeyError) as err:
raise ValueError(f"Could not set '{value}' as True/False for {tag}!") from err
raise ValueError(f"Could not set '{value}' as True/False for tag '{tag}'!") from err

def write(self, tag: str, value: Any) -> str:
"""Write the tag and its value as a string.
Expand Down Expand Up @@ -303,16 +335,10 @@ def read(self, tag: str, value: str) -> str:
Returns:
str: The parsed string value.
"""
# This try except block needs to go before the value.split check
try:
value = str(value)
except (ValueError, TypeError) as err:
raise ValueError(f"Could not set (unstringable) to a str for {tag}!") from err
if len(value.split()) > 1:
raise ValueError(f"'{value}' for {tag} should not have a space in it!")
self._single_value_read_validate(tag, value)
if self.options is None or value in self.options:
return value
raise ValueError(f"The '{value}' string must be one of {self.options} for {tag}")
raise ValueError(f"The string value '{value}' must be one of {self.options} for tag '{tag}'")

def write(self, tag: str, value: Any) -> str:
"""Write the tag and its value as a string.
Expand Down Expand Up @@ -366,14 +392,11 @@ def read(self, tag: str, value: str) -> int:
Returns:
int: The parsed integer value.
"""
if not isinstance(value, str):
raise TypeError(f"Value {value} for {tag} should be a string!")
if len(value.split()) > 1:
raise ValueError(f"'{value}' for {tag} should not have a space in it!")
self._single_value_read_validate(tag, value)
try:
return int(float(value))
except (ValueError, TypeError) as err:
raise ValueError(f"Could not set '{value}' to a int for {tag}!") from err
raise ValueError(f"Could not set value '{value}' to an int for tag '{tag}'!") from err

def write(self, tag: str, value: Any) -> str:
"""Write the tag and its value as a string.
Expand Down Expand Up @@ -429,14 +452,11 @@ def read(self, tag: str, value: str) -> float:
Returns:
float: The parsed float value.
"""
if not isinstance(value, str):
raise TypeError(f"Value {value} for {tag} should be a string!")
if len(value.split()) > 1:
raise ValueError(f"'{value}' for {tag} should not have a space in it!")
self._single_value_read_validate(tag, value)
try:
value_float = float(value)
except (ValueError, TypeError) as err:
raise ValueError(f"Could not set '{value}' to a float for {tag}!") from err
raise ValueError(f"Could not set value '{value}' to a float for tag '{tag}'!") from err
return value_float

def write(self, tag: str, value: Any) -> str:
Expand Down Expand Up @@ -508,11 +528,8 @@ def read(self, tag: str, value: str) -> str:
Returns:
str: The parsed string value.
"""
try:
value = str(value)
except (ValueError, TypeError) as err:
raise ValueError(f"Could not set (unstringable) to a str for {tag}!") from err
return value
self._general_read_validate(tag, value)
return str(value)

def write(self, tag: str, value: Any) -> str:
"""Write the tag and its value as a string.
Expand Down Expand Up @@ -557,7 +574,7 @@ def _validate_single_entry(
self, value: dict | list[dict], try_auto_type_fix: bool = False
) -> tuple[list[str], list[bool], Any]:
if not isinstance(value, dict):
raise TypeError(f"This tag should be a dict: {value}, which is of the type {type(value)}")
raise TypeError(f"The value '{value}' (of type {type(value)}) must be a dict for this TagContainer!")
tags_checked: list[str] = []
types_checks: list[bool] = []
updated_value = deepcopy(value)
Expand Down Expand Up @@ -625,6 +642,7 @@ def read(self, tag: str, value: str) -> dict:
Returns:
dict: The parsed value.
"""
self._general_read_validate(tag, value)
value_list = value.split()
if tag == "ion":
special_constraints = [x in ["HyperPlane", "Linear", "None", "Planar"] for x in value_list]
Expand All @@ -647,7 +665,9 @@ def read(self, tag: str, value: str) -> dict:
subtag_count = value_list.count(subtag) # Get number of times subtag appears in line
if not subtag_type.can_repeat:
if subtag_count > 1:
raise ValueError(f"Subtag {subtag} is not allowed to repeat repeats in {tag}'s value {value}")
raise ValueError(
f"Subtag '{subtag}' for tag '{tag}' is not allowed to repeat but repeats value {value}"
)
idx_start = value_list.index(subtag)
token_len = subtag_type.get_token_len()
idx_end = idx_start + token_len
Expand Down Expand Up @@ -680,14 +700,10 @@ def read(self, tag: str, value: str) -> dict:
del value_list[0]

# reorder all tags to match order of __MASTER_TAG_LIST__ and do coarse-grained validation of read.

subdict = {x: tempdict[x] for x in self.subtags if x in tempdict}
for subtag, subtag_type in self.subtags.items():
if not subtag_type.optional and subtag not in subdict:
raise ValueError(f"The {subtag} tag is not optional but was not populated during the read!")
if len(value_list) > 0:
raise ValueError(
f"Something is wrong in the JDFTXInfile formatting, some values were not processed: {value}"
)
self._check_nonoptional_subtags(tag, subdict, self.subtags)
self._check_unread_values(tag, value_list)
return subdict

def write(self, tag: str, value: Any) -> str:
Expand All @@ -702,7 +718,7 @@ def write(self, tag: str, value: Any) -> str:
"""
if not isinstance(value, dict):
raise TypeError(
f"value = {value}\nThe value to the {tag} write method must be a dict since it is a TagContainer!"
f"The value '{value}' (of type {type(value)}) for tag '{tag}' must be a dict for this TagContainer!"
)

final_value = ""
Expand All @@ -714,9 +730,10 @@ def write(self, tag: str, value: Any) -> str:
# if it is not a list, then the tag will still be printed by the else
# this could be relevant if someone manually sets the tag's can_repeat value to a non-list.
print_str_list = [self.subtags[subtag].write(subtag, entry) for entry in subvalue]
print_str = " ".join(print_str_list)
print_str = " ".join([v.strip() for v in print_str_list]) + " "
# print_str = " ".join(print_str_list)
else:
print_str = self.subtags[subtag].write(subtag, subvalue)
print_str = self.subtags[subtag].write(subtag, subvalue).strip() + " "

if self.multiline_tag:
final_value += f"{indent}{print_str}\\\n"
Expand Down Expand Up @@ -751,7 +768,7 @@ def get_token_len(self) -> int:

def _make_list(self, value: dict) -> list:
if not isinstance(value, dict):
raise TypeError(f"The value {value} is not a dict, so could not be converted")
raise TypeError(f"The value '{value}' is not a dict, so could not be converted")
value_list = []
for subtag, subtag_value in value.items():
subtag_type = self.subtags[subtag]
Expand Down Expand Up @@ -799,12 +816,18 @@ def get_list_representation(self, tag: str, value: Any) -> list:
# cannot repeat: list of bool/str/int/float (elec-cutoff)
# cannot repeat: list of lists (lattice)
if self.can_repeat and not isinstance(value, list):
raise ValueError("Values for repeatable tags must be a list here")
raise ValueError(
f"Value '{value}' must be a list when passed to 'get_list_representation' since "
f"tag '{tag}' is repeatable."
)
if self.can_repeat:
if all(isinstance(entry, list) for entry in value):
return value # no conversion needed
if any(not isinstance(entry, dict) for entry in value):
raise ValueError(f"The {tag} tag set to {value} must be a list of dict")
raise ValueError(
f"The tag '{tag}' set to value '{value}' must be a list of dicts when passed to "
"'get_list_representation' since the tag is repeatable."
)
tag_as_list = [self._make_list(entry) for entry in value]
else:
tag_as_list = self._make_list(value)
Expand All @@ -815,11 +838,17 @@ def _check_for_mixed_nesting(tag: str, value: Any) -> None:
has_nested_dict = any(isinstance(x, dict) for x in value)
has_nested_list = any(isinstance(x, list) for x in value)
if has_nested_dict and has_nested_list:
raise ValueError(f"{tag} with {value} cannot have nested lists/dicts mixed with bool/str/int/floats!")
raise ValueError(
f"tag '{tag}' with value '{value}' cannot have nested lists/dicts mixed with bool/str/int/floats!"
)
if has_nested_dict:
raise ValueError(f"{tag} with {value} cannot have nested dicts mixed with bool/str/int/floats!")
raise ValueError(
f"tag '{tag}' with value '{value}' cannot have nested dicts mixed with bool/str/int/floats!"
)
if has_nested_list:
raise ValueError(f"{tag} with {value} cannot have nested lists mixed with bool/str/int/floats!")
raise ValueError(
f"tag '{tag}' with value '{value}' cannot have nested lists mixed with bool/str/int/floats!"
)

def _make_str_for_dict(self, tag: str, value_list: list) -> str:
"""Convert the value to a string representation.
Expand Down Expand Up @@ -848,12 +877,15 @@ def get_dict_representation(self, tag: str, value: list) -> dict | list[dict]:
# convert list or list of lists representation into string the TagContainer can process back into (nested) dict

if self.can_repeat and not isinstance(value, list):
raise ValueError("Values for repeatable tags must be a list here")
raise ValueError(
f"Value '{value}' must be a list when passed to 'get_dict_representation' since "
f"tag '{tag}' is repeatable."
)
if (
self.can_repeat and len({len(x) for x in value}) > 1
): # Creates a list of every unique length of the subdicts
# TODO: Populate subdicts with fewer entries with JDFTx defaults to make compatible
raise ValueError(f"The values for {tag} {value} provided in a list of lists have different lengths")
raise ValueError(f"The values '{value}' for tag '{tag}' provided in a list of lists have different lengths")
value = value.tolist() if isinstance(value, np.ndarray) else value

# there are 4 types of TagContainers in the list representation:
Expand Down Expand Up @@ -959,9 +991,10 @@ def get_format_index_for_str_value(self, tag: str, value: str) -> int:
return i
except (ValueError, TypeError) as e:
problem_log.append(f"Format {i}: {e}")
errormsg = f"No valid read format for '{tag} {value}' tag\n"
"Add option to format_options or double-check the value string and retry!\n\n"
raise ValueError(errormsg)
raise ValueError(
f"No valid read format for tag '{tag}' with value '{value}'\n"
"Add option to format_options or double-check the value string and retry!\n\n"
)

def raise_invalid_format_option_error(self, tag: str, i: int) -> None:
"""Raise an error for an invalid format option.
Expand All @@ -973,7 +1006,7 @@ def raise_invalid_format_option_error(self, tag: str, i: int) -> None:
Raises:
ValueError: If the format option is invalid.
"""
raise ValueError(f"{tag} option {i} is not it: validation failed")
raise ValueError(f"tag '{tag}' failed to validate for option {i}")

def _determine_format_option(self, tag: str, value_any: Any, try_auto_type_fix: bool = False) -> tuple[int, Any]:
"""Determine the format option for the value of this tag.
Expand Down Expand Up @@ -1006,9 +1039,10 @@ def _determine_format_option(self, tag: str, value_any: Any, try_auto_type_fix:
return i, value
except (ValueError, TypeError, KeyError) as e:
exceptions.append(e)
err_str = f"The format for {tag} for:\n{value_any}\ncould not be determined from the available options! "
"Check your inputs and/or MASTER_TAG_LIST!"
raise ValueError(err_str)
raise ValueError(
f"The format for tag '{tag}' with value '{value_any}' could not be determined from the available options! "
"Check your inputs and/or MASTER_TAG_LIST!"
)

def get_token_len(self) -> int:
"""Get the token length of the tag.
Expand Down Expand Up @@ -1043,6 +1077,7 @@ def read(self, tag: str, value_str: str) -> dict:
Returns:
dict: The parsed value.
"""
self._general_read_validate(tag, value_str)
value = value_str.split()
tempdict = {}
for subtag, subtag_type in self.subtags.items():
Expand All @@ -1053,13 +1088,8 @@ def read(self, tag: str, value_str: str) -> dict:
tempdict[subtag] = subtag_type.read(subtag, subtag_value)
del value[idx_start:idx_end]
subdict = {x: tempdict[x] for x in self.subtags if x in tempdict}
for subtag, subtag_type in self.subtags.items():
if not subtag_type.optional and subtag not in subdict:
raise ValueError(f"The {subtag} tag is not optional but was not populated during the read!")
if len(value) > 0:
raise ValueError(
f"Something is wrong in the JDFTXInfile formatting, some values were not processed: {value}"
)
self._check_nonoptional_subtags(tag, subdict, self.subtags)
self._check_unread_values(tag, value)
return subdict


Expand All @@ -1083,6 +1113,7 @@ def read(self, tag: str, value_str: str) -> dict:
Returns:
dict: The parsed value.
"""
self._general_read_validate(tag, value_str)
value = value_str.split()
tempdict = {}
# Each subtag is a freq, which will be a BoolTagContainer
Expand All @@ -1095,10 +1126,7 @@ def read(self, tag: str, value_str: str) -> dict:
# reorder all tags to match order of __MASTER_TAG_LIST__ and do coarse-grained validation of read
subdict = {x: tempdict[x] for x in self.subtags if x in tempdict}
# There are no forced subtags for dump
if len(value) > 0:
raise ValueError(
f"Something is wrong in the JDFTXInfile formatting, some values were not processed: {value}"
)
self._check_unread_values(tag, value)
return subdict


Expand Down
Loading

0 comments on commit d391836

Please sign in to comment.