Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/changes/2000.maintenance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Ensure that model data writer is not changing the general output path for all other modules.
Change and simplify model data writer API.
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ def main():
else layout.export_telescope_list_table(crs_name=app_context.args["export"])
)
writer.ModelDataWriter.dump(
args_dict=app_context.args,
output_file=app_context.args.get("output_file"),
output_file_format=app_context.args.get("output_file_format", "ascii.ecsv"),
metadata=metadata,
product_data=product_data,
validate_schema_file=validate_schema_file,
Expand Down
2 changes: 1 addition & 1 deletion src/simtools/applications/db_get_array_layouts_from_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ def main():

if not app_context.args.get("output_file_from_default", False):
writer.ModelDataWriter.dump(
args_dict=app_context.args,
output_file=app_context.args["output_file"],
output_file_format=app_context.args.get("output_file_format"),
metadata=None,
product_data=layout,
)
Expand Down
7 changes: 3 additions & 4 deletions src/simtools/applications/generate_regular_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,16 @@ def main():
)

data_writer = writer.ModelDataWriter(
product_data_file=output_file,
product_data_format=app_context.args.get("output_file_format", "ascii.ecsv"),
args_dict=app_context.args,
output_file=output_file,
output_file_format=app_context.args.get("output_file_format", "ascii.ecsv"),
)
data_writer.write(metadata=None, product_data=array_table)

write_array_elements_info_yaml(
array_table,
app_context.args["site"],
app_context.args["model_version"],
Path(data_writer.product_data_file).with_suffix(".info.yml"),
Path(data_writer.output_file).with_suffix(".info.yml"),
)


Expand Down
3 changes: 2 additions & 1 deletion src/simtools/applications/submit_data_from_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def main():
)

writer.ModelDataWriter.dump(
args_dict=app_context.args,
output_file=app_context.args["output_file"],
output_file_format=app_context.args.get("output_file_format"),
metadata=_metadata,
product_data=data_validator.validate_and_transform(),
)
Expand Down
3 changes: 0 additions & 3 deletions src/simtools/camera/camera_efficiency.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,9 +565,6 @@ def dump_nsb_pixel_rate(self):
output_path=self.output_dir / cfg.get("telescope") / "nsb_pixel_rate",
)

# temporary fix
self.io_handler.set_paths(output_path=self.output_dir)

def _get_x_max_for_efficiency_type(self):
"""
Get X max value in g/cm2 depending on the efficiency type.
Expand Down
3 changes: 2 additions & 1 deletion src/simtools/camera/single_photon_electron_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def write_single_pe_spectrum(self):
)

writer.ModelDataWriter.dump(
args_dict=self.args_dict,
output_file=self.args_dict["output_file"],
output_file_format=self.args_dict.get("output_file_format"),
metadata=self.metadata,
product_data=table,
validate_schema_file=None,
Expand Down
95 changes: 46 additions & 49 deletions src/simtools/data_model/model_data_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from astropy.io.registry.base import IORegistryError

import simtools.utils.general as gen
from simtools import settings
from simtools.data_model import schema, validate_data
from simtools.data_model.metadata_collector import MetadataCollector
from simtools.db import db_handler
Expand All @@ -20,69 +21,62 @@ class ModelDataWriter:

Parameters
----------
product_data_file: str
output_file: str
Name of output file.
product_data_format: str
output_file_format: str
Format of output file.
args_dict: Dictionary
Dictionary with configuration parameters.
output_path: str or Path
Path to output file.
args_dict: dict
Dictionary with configuration parameters.

"""

def __init__(
self,
product_data_file=None,
product_data_format=None,
output_path=None,
args_dict=None,
):
def __init__(self, output_file=None, output_file_format=None, output_path=None):
"""Initialize model data writer."""
self._logger = logging.getLogger(__name__)
self.io_handler = io_handler.IOHandler()
self.schema_dict = {}
if args_dict is not None:
output_path = args_dict.get("output_path", output_path)
if output_path is not None:
self.io_handler.set_paths(output_path=output_path)
self.output_label = "model_data_writer"
self.io_handler.set_paths(
output_path=output_path or settings.config.args.get("output_path"),
output_path_label=self.output_label,
)
try:
self.product_data_file = self.io_handler.get_output_file(file_name=product_data_file)
self.output_file = self.io_handler.get_output_file(
file_name=output_file, output_path_label=self.output_label
)
except TypeError:
self.product_data_file = None
self.product_data_format = self._astropy_data_format(product_data_format)
self.output_file = None
self.output_file_format = self._derive_data_format(output_file_format, self.output_file)

@staticmethod
def dump(
args_dict, output_file=None, metadata=None, product_data=None, validate_schema_file=None
output_file=None,
metadata=None,
product_data=None,
output_file_format="ascii.ecsv",
validate_schema_file=None,
):
"""
Write model data and metadata (as static method).

Parameters
----------
args_dict: dict
Dictionary with configuration parameters (including output file name and path).
output_file: string or Path
Name of output file (args["output_file"] is used if this parameter is not set).
metadata: MetadataCollector object
Metadata to be written.
product_data: astropy Table
Model data to be written
output_file_format: str
Format of output file.
validate_schema_file: str
Schema file used in validation of output data.

"""
writer = ModelDataWriter(
product_data_file=(
args_dict.get("output_file", None) if output_file is None else output_file
),
product_data_format=args_dict.get("output_file_format", "ascii.ecsv"),
args_dict=args_dict,
output_file=output_file,
output_file_format=output_file_format,
)
if validate_schema_file is not None and not args_dict.get("skip_output_validation", True):
skip_output_validation = settings.config.args.get("skip_output_validation", True)
if validate_schema_file is not None and not skip_output_validation:
product_data = writer.validate_and_transform(
product_data_table=product_data,
validate_schema_file=validate_schema_file,
Expand Down Expand Up @@ -137,9 +131,8 @@ def dump_model_parameter(
Validated parameter dictionary.
"""
writer = ModelDataWriter(
product_data_file=output_file,
product_data_format="json",
args_dict=None,
output_file=output_file,
output_file_format="json",
output_path=output_path,
)
if check_db_for_existing_parameter:
Expand Down Expand Up @@ -434,19 +427,17 @@ def write(self, product_data=None, metadata=None):
gen.change_dict_keys_case(metadata.get_top_level_metadata(), True)
)

self._logger.info(f"Writing data to {self.product_data_file}")
if isinstance(product_data, dict) and Path(self.product_data_file).suffix == ".json":
self.write_dict_to_model_parameter_json(self.product_data_file, product_data)
self._logger.info(f"Writing data to {self.output_file}")
if isinstance(product_data, dict) and Path(self.output_file).suffix == ".json":
self.write_dict_to_model_parameter_json(self.output_file, product_data)
return
try:
product_data.write(
self.product_data_file, format=self.product_data_format, overwrite=True
)
product_data.write(self.output_file, format=self.output_file_format, overwrite=True)
except IORegistryError:
self._logger.error(f"Error writing model data to {self.product_data_file}.")
self._logger.error(f"Error writing model data to {self.output_file}.")
raise
if metadata is not None:
metadata.write(self.product_data_file, add_activity_name=True)
metadata.write(self.output_file, add_activity_name=True)

def write_dict_to_model_parameter_json(self, file_name, data_dict):
"""
Expand All @@ -465,10 +456,13 @@ def write_dict_to_model_parameter_json(self, file_name, data_dict):
if data writing was not successful.
"""
data_dict = ModelDataWriter.prepare_data_dict_for_writing(data_dict)
self._logger.info(f"Writing data to {self.io_handler.get_output_file(file_name)}")
output_file = self.io_handler.get_output_file(
file_name, output_path_label=self.output_label
)
self._logger.info(f"Writing data to {output_file}")
ascii_handler.write_data_to_file(
data=data_dict,
output_file=self.io_handler.get_output_file(file_name),
output_file=output_file,
sort_keys=True,
numpy_types=True,
)
Expand Down Expand Up @@ -508,16 +502,19 @@ def prepare_data_dict_for_writing(data_dict):
return data_dict

@staticmethod
def _astropy_data_format(product_data_format):
def _derive_data_format(product_data_format, output_file=None):
"""
Ensure conformance with astropy data format naming.
Derive data format and ensure conformance with astropy data format naming.

If product_data_format is None and output_file is given, derive format
from output_file suffix.

Parameters
----------
product_data_format: string
format identifier

"""
if product_data_format == "ecsv":
product_data_format = "ascii.ecsv"
return product_data_format
if product_data_format is None and output_file is not None:
product_data_format = Path(output_file).suffix.lstrip(".")
return "ascii.ecsv" if product_data_format == "ecsv" else product_data_format
33 changes: 23 additions & 10 deletions src/simtools/io/io_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ class IOHandler(metaclass=IOHandlerSingleton):
def __init__(self):
"""Initialize IOHandler."""
self.logger = logging.getLogger(__name__)
self.output_path = None
self.output_path = {}
self.model_path = None

def set_paths(self, output_path=None, model_path=None):
def set_paths(self, output_path=None, model_path=None, output_path_label="default"):
"""
Set paths for input and output.

Expand All @@ -35,18 +35,22 @@ def set_paths(self, output_path=None, model_path=None):
Path pointing to the output directory.
model_path: str or Path
Path pointing to the model file directory.
output_path_label: str
Label for the output path.
"""
self.output_path = output_path
self.output_path[output_path_label] = output_path
self.model_path = model_path

def get_output_directory(self, sub_dir=None):
def get_output_directory(self, sub_dir=None, output_path_label="default"):
"""
Create and get path of an output directory.

Parameters
----------
sub_dir: str or list of str, optional
Name of the subdirectory (ray_tracing, model etc)
output_path_label: str
Label for the output path.

Returns
-------
Expand All @@ -63,16 +67,19 @@ def get_output_directory(self, sub_dir=None):
parts = sub_dir
else:
parts = [sub_dir]
path = Path(self.output_path, *parts)
try:
output_path = Path(self.output_path[output_path_label], *parts)
except KeyError as exc:
raise KeyError(f"Output path label '{output_path_label}' not found") from exc

try:
path.mkdir(parents=True, exist_ok=True)
output_path.mkdir(parents=True, exist_ok=True)
except FileNotFoundError as exc:
raise FileNotFoundError(f"Error creating directory {path!s}") from exc
raise FileNotFoundError(f"Error creating directory {output_path!s}") from exc

return path.resolve()
return output_path.resolve()

def get_output_file(self, file_name, sub_dir=None):
def get_output_file(self, file_name, sub_dir=None, output_path_label="default"):
"""
Get path of an output file.

Expand All @@ -82,12 +89,18 @@ def get_output_file(self, file_name, sub_dir=None):
File name.
sub_dir: sub_dir: str or list of str, optional
Name of the subdirectory (ray_tracing, model etc)
output_path_label: str
Label for the output path.

Returns
-------
Path
"""
return self.get_output_directory(sub_dir).joinpath(file_name).absolute()
return (
self.get_output_directory(sub_dir, output_path_label=output_path_label)
.joinpath(file_name)
.absolute()
)

def get_test_data_file(self, file_name=None):
"""
Expand Down
6 changes: 1 addition & 5 deletions src/simtools/ray_tracing/incident_angles.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,11 +744,7 @@ def save_model_parameters(self, results_by_offset):
table["Incidence angle"] = bin_centers * u.deg
table["Fraction"] = hist

writer = ModelDataWriter(
product_data_file=self.output_dir / f"{param_name}.ecsv",
product_data_format="ecsv",
args_dict=self.config_data,
)
writer = ModelDataWriter(output_file=self.output_dir / f"{param_name}.ecsv")
writer.write(
product_data=table,
metadata=MetadataCollector(args_dict=self.config_data),
Expand Down
3 changes: 2 additions & 1 deletion src/simtools/ray_tracing/mirror_panel_psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,8 @@ def write_optimization_data(self):
),
)
writer.ModelDataWriter.dump(
args_dict=self.args_dict,
output_file=self.args_dict.get("output_file"),
output_file_format=self.args_dict.get("output_file_format"),
metadata=MetadataCollector(args_dict=self.args_dict),
product_data=result_table,
)
2 changes: 1 addition & 1 deletion tests/unit_tests/configuration/test_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def test_initialize_io_handler(configurator, tmp_test_directory):
configurator.config["output_path"] = tmp_test_directory
configurator._initialize_io_handler()

assert _io_handler.output_path == tmp_test_directory
assert _io_handler.output_path.get("default") == tmp_test_directory


def test_check_parameter_configuration_status(configurator, args_dict, tmp_test_directory):
Expand Down
Loading