From 4d878a9c0910178e2b13b1b6ac3a96f638279171 Mon Sep 17 00:00:00 2001 From: Drew Oldag <47493171+drewoldag@users.noreply.github.com> Date: Wed, 13 Nov 2024 12:13:17 -0800 Subject: [PATCH 1/4] Updating the "train a model" demo notebook, adding it to the documentation. (#117) --- docs/notebooks.rst | 2 +- docs/notebooks/TrainingAModel.ipynb | 66 -------------------- docs/notebooks/intro_notebook.ipynb | 84 ------------------------- docs/notebooks/train_model.ipynb | 97 +++++++++++++++++++++++++++++ 4 files changed, 98 insertions(+), 151 deletions(-) delete mode 100644 docs/notebooks/TrainingAModel.ipynb delete mode 100644 docs/notebooks/intro_notebook.ipynb create mode 100644 docs/notebooks/train_model.ipynb diff --git a/docs/notebooks.rst b/docs/notebooks.rst index 7f7e544..ed9952e 100644 --- a/docs/notebooks.rst +++ b/docs/notebooks.rst @@ -3,4 +3,4 @@ Notebooks .. toctree:: - Introducing Jupyter Notebooks + Training a simple model diff --git a/docs/notebooks/TrainingAModel.ipynb b/docs/notebooks/TrainingAModel.ipynb deleted file mode 100644 index 34220f5..0000000 --- a/docs/notebooks/TrainingAModel.ipynb +++ /dev/null @@ -1,66 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import fibad" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Create an instance of a fibad object, instantiated (implicitly) with the default configuration file\n", - "fibad_instance = fibad.Fibad()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Update a few of the configuration parameters\n", - "fibad_instance.config[\"model\"][\"name\"] = \"ExampleCNN\"\n", - "fibad_instance.config[\"data_set\"][\"name\"] = \"CifarDataSet\"\n", - "fibad_instance.config[\"data_loader\"][\"batch_size\"] = 64" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Begin training the Example CNN model using the CIFAR-10 dataset\n", - "fibad_instance.train()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "fibad", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/docs/notebooks/intro_notebook.ipynb b/docs/notebooks/intro_notebook.ipynb deleted file mode 100644 index 0589b29..0000000 --- a/docs/notebooks/intro_notebook.ipynb +++ /dev/null @@ -1,84 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "textblock1", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "# Introducing Jupyter Notebooks in Sphinx\n", - "\n", - "This notebook showcases very basic functionality of rendering your jupyter notebooks as tutorials inside your sphinx documentation.\n", - "\n", - "As part of the LINCC Frameworks python project template, your notebooks will be executed AND rendered at document build time.\n", - "\n", - "You can read more about Sphinx, ReadTheDocs, and building notebooks in [LINCC's documentation](https://lincc-ppt.readthedocs.io/en/latest/practices/sphinx.html)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "codeblock1", - "metadata": {}, - "outputs": [], - "source": [ - "def 
sierpinsky(order):\n", - " \"\"\"Define a method that will create a Sierpinsky triangle of given order,\n", - " and will print it out.\"\"\"\n", - " triangles = [\"*\"]\n", - " for i in range(order):\n", - " spaces = \" \" * (2**i)\n", - " triangles = [spaces + triangle + spaces for triangle in triangles] + [\n", - " triangle + \" \" + triangle for triangle in triangles\n", - " ]\n", - " print(f\"Printing order {order} triangle\")\n", - " print(\"\\n\".join(triangles))" - ] - }, - { - "cell_type": "markdown", - "id": "textblock2", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 1 - }, - "source": [ - "Then, call our method a few times. This will happen on the fly during notebook rendering." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "codeblock2", - "metadata": {}, - "outputs": [], - "source": [ - "for order in range(3):\n", - " sierpinsky(order)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "codeblock3", - "metadata": {}, - "outputs": [], - "source": [ - "sierpinsky(4)" - ] - } - ], - "metadata": { - "jupytext": { - "cell_markers": "\"\"\"" - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/notebooks/train_model.ipynb b/docs/notebooks/train_model.ipynb new file mode 100644 index 0000000..91237d3 --- /dev/null +++ b/docs/notebooks/train_model.ipynb @@ -0,0 +1,97 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Intro to Training and Configurations\n", + "\n", + "First we import fibad and create a new fibad object, instantiated (implicitly), with the default configuration file." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import fibad\n", + "\n", + "fibad_instance = fibad.Fibad()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For this demo, we'll make a few adjustments to the default configuration settings that the `fibad` object was instantiated with. By accessing the `.config` attribute of the fibad instance, we can modify any configuration value. Here we change which built in model to use, the dataset, batch size, number of epochs for training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fibad_instance.config[\"model\"][\"name\"] = \"ExampleCNN\"\n", + "fibad_instance.config[\"data_set\"][\"name\"] = \"CifarDataSet\"\n", + "fibad_instance.config[\"data_loader\"][\"batch_size\"] = 64\n", + "fibad_instance.config[\"train\"][\"epochs\"] = 2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We call the `.train()` method to train the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fibad_instance.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The output of the training will be stored in a time-stamped directory under the `./results/`. By default, a copy of the final configuration used in training is persisted as `runtime_config.toml`. 
To run fibad again with the same configuration, you can reference the runtime_config.toml file.\n", + "\n", + "If running in another notebook, instantiate a fibad object like so:\n", + "```\n", + "new_fibad_instance = fibad.Fibad(config_file='./results//runtime_config.toml')\n", + "```\n", + "\n", + "Or from the command line:\n", + "```\n", + ">> fibad train --runtime-config ./results//runtime_config.toml\n", + "```" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "fibad", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From ae484c1888aa767f2f7b0ccb6069e08142b663e7 Mon Sep 17 00:00:00 2001 From: Michael Tauraso Date: Fri, 15 Nov 2024 14:56:45 -0800 Subject: [PATCH 2/4] Hsc data set updates (#120) * Update to default split creation logic and behaviour * We now do more complete inference of test/train/split depending on what is provided in config * New permutations of test/train/split being provided are covered by new tests * Old tested functionality remains unchanged. --------- Co-authored-by: Aritra Ghosh --- src/fibad/data_sets/hsc_data_set.py | 28 +++++++++++--- tests/fibad/test_hsc_dataset.py | 59 ++++++++++++++++++++++++++++- 2 files changed, 80 insertions(+), 7 deletions(-) diff --git a/src/fibad/data_sets/hsc_data_set.py b/src/fibad/data_sets/hsc_data_set.py index 7df1bd8..7f88ad9 100644 --- a/src/fibad/data_sets/hsc_data_set.py +++ b/src/fibad/data_sets/hsc_data_set.py @@ -53,16 +53,34 @@ def _create_splits(self, config): if isinstance(validate_size, int): validate_size = validate_size / len(self.container) - # Fill in any values not provided + # Initialize Test size when not provided if test_size is None: if train_size is None: train_size = 0.25 - test_size = 1.0 - train_size - elif train_size is None: - train_size = 1.0 - test_size - elif validate_size is None: + + if validate_size is None: # noqa: SIM108 + test_size = 1.0 - train_size + else: + test_size = 1.0 - (train_size + validate_size) + + # Initialize train size when not provided, and can be inferred from test_size and validate_size. + if train_size is None: + if validate_size is None: # noqa: SIM108 + train_size = 1.0 - test_size + else: + train_size = 1.0 - (test_size + validate_size) + + # If we still don't have a validate size, decide whether we will infer a validate size + if (validate_size is None) and (np.round(train_size + test_size) != 1.0): validate_size = 1.0 - (train_size + test_size) + # If splits cover more than the entire dataset, error out. 
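+        # (Illustrative values, also exercised in tests/fibad/test_hsc_dataset.py: with the
+        # default train_size=0.25 and validate_size=0.2, test_size is inferred as 0.55 above;
+        # with only train_size=0.2 and test_size=0.6, no validate split is created and
+        # requesting split="validate" fails.)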
+ if validate_size is None: + if np.round(train_size + test_size) > 1.0: + raise RuntimeError("Split fractions add up to more than 1.0") + elif np.round(train_size + test_size + validate_size) > 1.0: + raise RuntimeError("Split fractions add up to more than 1.0") + # Generate splits self.splits = {} self.splits["test"] = HSCDataSetSplit(self.container, test_size, seed=seed) diff --git a/tests/fibad/test_hsc_dataset.py b/tests/fibad/test_hsc_dataset.py index 2717685..534c9f3 100644 --- a/tests/fibad/test_hsc_dataset.py +++ b/tests/fibad/test_hsc_dataset.py @@ -58,7 +58,7 @@ def mkconfig( filters=False, train_size=0.2, test_size=0.6, - validate_size=0, + validate_size=0.1, seed=False, filter_catalog=False, ): @@ -367,7 +367,7 @@ def test_split(): test_files = generate_files(num_objects=100, num_filters=3, shape=(100, 100)) with FakeFitsFS(test_files): a = HSCDataSet(mkconfig(filters=["HSC-G", "HSC-R", "HSC-I"]), split="validate") - assert len(a) == 20 + assert len(a) == 10 a = HSCDataSet(mkconfig(filters=["HSC-G", "HSC-R", "HSC-I"]), split="test") assert len(a) == 60 @@ -388,6 +388,58 @@ def test_split_no_validate(): a = HSCDataSet(config, split="train") assert len(a) == 20 + with pytest.raises(RuntimeError): + a = HSCDataSet(config, split="validate") + + +def test_split_with_validate_no_test(): + """Test splitting when validate is provided by test size is not""" + test_files = generate_files(num_objects=100, num_filters=3, shape=(100, 100)) + with FakeFitsFS(test_files): + config = mkconfig(filters=["HSC-G", "HSC-R", "HSC-I"], test_size=False, validate_size=0.2) + + a = HSCDataSet(config, split="test") + assert len(a) == 60 + + a = HSCDataSet(config, split="train") + assert len(a) == 20 + + a = HSCDataSet(config, split="validate") + assert len(a) == 20 + + +def test_split_with_validate_no_test_no_train(): + """Test splitting when validate is provided by test size is not""" + test_files = generate_files(num_objects=100, num_filters=3, shape=(100, 100)) + with FakeFitsFS(test_files): + config = mkconfig( + filters=["HSC-G", "HSC-R", "HSC-I"], test_size=False, train_size=False, validate_size=0.2 + ) + + a = HSCDataSet(config, split="test") + assert len(a) == 55 + + a = HSCDataSet(config, split="train") + assert len(a) == 25 + + a = HSCDataSet(config, split="validate") + assert len(a) == 20 + + +def test_split_with_validate_with_test_no_train(): + """Test splitting when validate is provided by test size is not""" + test_files = generate_files(num_objects=100, num_filters=3, shape=(100, 100)) + with FakeFitsFS(test_files): + config = mkconfig( + filters=["HSC-G", "HSC-R", "HSC-I"], test_size=0.6, train_size=False, validate_size=0.2 + ) + + a = HSCDataSet(config, split="test") + assert len(a) == 60 + + a = HSCDataSet(config, split="train") + assert len(a) == 20 + a = HSCDataSet(config, split="validate") assert len(a) == 20 @@ -485,6 +537,9 @@ def test_split_values_configured_no_validate(): a = HSCDataSet(config, split="train") assert len(a) == 22 + a = HSCDataSet(config, split="validate") + assert len(a) == 10 + def test_split_invalid_configured(): """Test that split RuntimeErrors when provided with an invalid datapoint count""" From 4eb83019aa880b003fb1e625aa746af62f9fb185 Mon Sep 17 00:00:00 2001 From: Michael Tauraso Date: Tue, 19 Nov 2024 15:03:02 -0800 Subject: [PATCH 3/4] Rebuild a fits manifest from an HSC data directory and speedup HSC data loading (#115) # Rebuild a fits manifest from an HSC data directory. 
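A hedged usage sketch — the CLI form mirrors the existing `fibad train --runtime-config ...`
invocation shown in the docs, so the exact flags here are illustrative:

    >> fibad rebuild_manifest --runtime-config ./fibad_config.toml

or, from Python (the config path is an example):

    import fibad
    fibad.Fibad(config_file="./fibad_config.toml").rebuild_manifest()
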
* Added a new verb rebuild_manifest * When run with the HSC dataset class this verb will: 0) Scan the data directory and ingest HSC cutout files 1) Read in the original catalog file configured for download for metadata 2) Write out rebuilt_manifest.fits in the data directory * Fixed up config resolution so that fibad_config.toml in the cwd works again for CLI invocations. * Adding progressive logging for long steps. * Rebuild command will never open or use the manifest file in the data directory because the assumption is that file is corrupt. # Speeding up HSC Data loading * Parallelizing _scan_file_dimensions() Using Schwimmbad and multiprocessing to parallelize extracting the dimensions of files in HSCDataSet to effect speedup of 124x on 10M+ file datasets. * Added progressive log entries for HSCDataSet file scan * Use manifest by default when no filter_catalog provided. This skips the file scan on large datasets * Choose number of processes in a way that doesn't run afoul of system limits Co-authored-by: Drew Oldag <47493171+drewoldag@users.noreply.github.com> --- pyproject.toml | 1 + src/fibad/config_utils.py | 63 +++--- src/fibad/data_sets/hsc_data_set.py | 285 +++++++++++++++++++++++++--- src/fibad/download.py | 7 +- src/fibad/fibad.py | 16 +- src/fibad/rebuild_manifest.py | 26 +++ tests/fibad/test_hsc_dataset.py | 14 +- 7 files changed, 345 insertions(+), 67 deletions(-) create mode 100644 src/fibad/rebuild_manifest.py diff --git a/pyproject.toml b/pyproject.toml index f18689f..b195b10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "toml", # Used to load configuration files as dictionaries "torch", # Used for CNN model and in train.py "torchvision", # Used in hsc data loader, example autoencoder, and CNN model data set + "schwimmbad", # Used to speedup hsc data loader file scans ] [project.scripts] diff --git a/src/fibad/config_utils.py b/src/fibad/config_utils.py index a918039..0962871 100644 --- a/src/fibad/config_utils.py +++ b/src/fibad/config_utils.py @@ -76,15 +76,15 @@ def __init__( runtime_config_filepath: Union[Path, str] = None, default_config_filepath: Union[Path, str] = DEFAULT_CONFIG_FILEPATH, ): - self.fibad_default_config = self._read_runtime_config(default_config_filepath) + self.fibad_default_config = ConfigManager._read_runtime_config(default_config_filepath) - self.runtime_config_filepath = runtime_config_filepath - if self.runtime_config_filepath is None: + self.runtime_config_filepath = ConfigManager.resolve_runtime_config(runtime_config_filepath) + if self.runtime_config_filepath is DEFAULT_CONFIG_FILEPATH: self.user_specific_config = ConfigDict() else: - self.user_specific_config = self._read_runtime_config(self.runtime_config_filepath) + self.user_specific_config = ConfigManager._read_runtime_config(self.runtime_config_filepath) - self.external_library_config_paths = self._find_external_library_default_config_paths( + self.external_library_config_paths = ConfigManager._find_external_library_default_config_paths( self.user_specific_config ) @@ -93,7 +93,7 @@ def __init__( self.config = self.merge_configs(self.overall_default_config, self.user_specific_config) if not self.config["general"]["dev_mode"]: - self._validate_runtime_config(self.config, self.overall_default_config) + ConfigManager._validate_runtime_config(self.config, self.overall_default_config) @staticmethod def _read_runtime_config(config_filepath: Union[Path, str] = DEFAULT_CONFIG_FILEPATH) -> ConfigDict: @@ -232,38 +232,37 @@ def 
_validate_runtime_config(runtime_config: ConfigDict, default_config: ConfigD raise RuntimeError(msg) ConfigManager._validate_runtime_config(runtime_config[key], default_config[key]) + @staticmethod + def resolve_runtime_config(runtime_config_filepath: Union[Path, str, None] = None) -> Path: + """Resolve a user-supplied runtime config to where we will actually pull config from. -def resolve_runtime_config(runtime_config_filepath: Union[Path, str, None] = None) -> Path: - """Resolve a user-supplied runtime config to where we will actually pull config from. - - 1) If a runtime config file is specified, we will use that file - 2) If no file is specified and there is a file named "fibad_config.toml" in the cwd we will use that file - 3) If no file is specified and there is no file named "fibad_config.toml" in the current working directory - we will exclusively work off the configuration defaults in the packaged "fibad_default_config.toml" - file. + 1) If a runtime config file is specified, we will use that file. + 2) If no file is specified and there is a file named "fibad_config.toml" in the cwd we will use it. + 3) If no file is specified and there is no file named "fibad_config.toml" in the cwd we will + exclusively work off the configuration defaults in the packaged "fibad_default_config.toml" file. - Parameters - ---------- - runtime_config_filepath : Union[Path, str, None], optional - Location of the supplied config file, by default None + Parameters + ---------- + runtime_config_filepath : Union[Path, str, None], optional + Location of the supplied config file, by default None - Returns - ------- - Path - Path to the configuration file ultimately used for config resolution. When we fall back to the - package supplied default config file, the Path to that file is returned. - """ - if isinstance(runtime_config_filepath, str): - runtime_config_filepath = Path(runtime_config_filepath) + Returns + ------- + Path + Path to the configuration file ultimately used for config resolution. When we fall back to the + package supplied default config file, the Path to that file is returned. + """ + if isinstance(runtime_config_filepath, str): + runtime_config_filepath = Path(runtime_config_filepath) - # If a named config exists in cwd, and no config specified on cmdline, use cwd. - if runtime_config_filepath is None and DEFAULT_USER_CONFIG_FILEPATH.exists(): - runtime_config_filepath = DEFAULT_USER_CONFIG_FILEPATH + # If a named config exists in cwd, and no config specified on cmdline, use cwd. 
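+        # (For example, a bare `fibad train` run from a directory containing
+        # fibad_config.toml resolves to that file below, while an explicit
+        # `fibad train --runtime-config other.toml` path is used as-is and
+        # takes precedence over both fallbacks.)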
+ if runtime_config_filepath is None and DEFAULT_USER_CONFIG_FILEPATH.exists(): + runtime_config_filepath = DEFAULT_USER_CONFIG_FILEPATH - if runtime_config_filepath is None: - runtime_config_filepath = DEFAULT_CONFIG_FILEPATH + if runtime_config_filepath is None: + runtime_config_filepath = DEFAULT_CONFIG_FILEPATH - return runtime_config_filepath + return runtime_config_filepath def create_results_dir(config: ConfigDict, postfix: Union[Path, str]) -> Path: diff --git a/src/fibad/data_sets/hsc_data_set.py b/src/fibad/data_sets/hsc_data_set.py index 7f88ad9..567d7f0 100644 --- a/src/fibad/data_sets/hsc_data_set.py +++ b/src/fibad/data_sets/hsc_data_set.py @@ -1,7 +1,10 @@ # ruff: noqa: D101, D102 import logging +import multiprocessing +import os import re +import resource from copy import copy, deepcopy from pathlib import Path from typing import Literal, Optional, Union @@ -10,9 +13,21 @@ import torch from astropy.io import fits from astropy.table import Table +from schwimmbad import MultiPool from torch.utils.data import Dataset from torchvision.transforms.v2 import CenterCrop, Compose, Lambda +from fibad.download import Downloader +from fibad.downloadCutout.downloadCutout import ( + parse_bool, + parse_degree, + parse_latitude, + parse_longitude, + parse_rerun, + parse_tract_opt, + parse_type, +) + from .data_set_registry import fibad_data_set logger = logging.getLogger(__name__) @@ -20,6 +35,8 @@ @fibad_data_set class HSCDataSet(Dataset): + _called_from_test = False + """Interface object to allow simple access to splits on a corpus of HSC data files f/s operations and management are handled in HSCDatSetContainer @@ -112,6 +129,9 @@ def __getitem__(self, idx: int) -> torch.Tensor: def __len__(self) -> int: return len(self.current_split) + def rebuild_manifest(self, config): + return self.container._rebuild_manifest(config) + class HSCDataSetSplit(Dataset): def __init__( @@ -273,7 +293,17 @@ def __init__(self, config): crop_to = config["data_set"]["crop_to"] filters = config["data_set"]["filters"] - filter_catalog = config["data_set"]["filter_catalog"] + + if config["data_set"]["filter_catalog"]: + filter_catalog = Path(config["data_set"]["filter_catalog"]) + elif not config.get("rebuild_manifest", False): + # Note "rebuild_manifest" is not a config, its a hack for rebuild_manifest mode + # to ensure we don't use the manifest we believe is corrupt. + filter_catalog = Path(config["general"]["data_dir"]) / Downloader.MANIFEST_FILE_NAME + if not filter_catalog.exists(): + filter_catalog = False + else: + filter_catalog = False self._init_from_path( config["general"]["data_dir"], @@ -344,6 +374,7 @@ def _init_from_path( self.cutout_shape = cutout_shape + self.pruned_objects = {} self._prune_objects(filters_ref) if self.cutout_shape is None: @@ -374,13 +405,15 @@ def _scan_file_names(self, filters: Optional[list[str]] = None) -> files_dict: filter_name -> file name. Corresponds to self.files """ + logger.info(f"Scanning files in directory {self.path}") + object_id_regex = r"[0-9]{17}" filter_regex = r"HSC-[GRIZY]" if filters is None else "|".join(filters) full_regex = f"({object_id_regex})_.*_({filter_regex}).fits" files = {} # Go scan the path for object ID's so we have a list. 
- for filepath in Path(self.path).iterdir(): + for index, filepath in enumerate(Path(self.path).iterdir()): filename = filepath.name # If we are filtering based off a user-provided catalog of object ids, Filter out any @@ -408,6 +441,11 @@ def _scan_file_names(self, filters: Optional[list[str]] = None) -> files_dict: msg += "and will not be included in the data set." logger.error(msg) + if index != 0 and index % 1_000_000 == 0: + logger.info(f"Processed {index} files.") + else: + logger.info(f"Processed {index+1} files") + return files def _read_filter_catalog( @@ -460,12 +498,81 @@ def _read_filter_catalog( return (filter_catalog, dim_catalog) if "dim" in colnames else filter_catalog + @staticmethod + def _determine_numprocs() -> int: + # Figure out how many CPUs we are allowed to use + cpu_count = None + sched_getaffinity = getattr(os, "sched_getaffinity", None) + + if sched_getaffinity: + cpu_count = len(sched_getaffinity(0)) + elif multiprocessing: + cpu_count = multiprocessing.cpu_count() + else: + cpu_count = 1 + + # Ideally we would use ~75 processes per CPU to attempt to saturate + # I/O bandwidth using a small number of CPUs. + numproc = 1 if HSCDataSet._called_from_test else 75 * cpu_count + numproc = HSCDataSetContainer._fixup_limit( + numproc, + resource.RLIMIT_NOFILE, + lambda proc: int(4 * proc + 10), + lambda nofile: int((nofile - 10) / 4), + ) + + numproc = HSCDataSetContainer._fixup_limit( + numproc, resource.RLIMIT_NPROC, lambda proc: proc, lambda proc: proc + ) + return numproc + + @staticmethod + def _fixup_limit(nproc, res, est_limit, est_procs) -> int: + # If launching this many processes would trigger other resource limits, work around them + limit_soft, limit_hard = resource.getrlimit(res) + + # If we would violate the hard limit, calculate the number of processes that wouldn't + # violate the limit + if limit_hard < est_limit(nproc): + nproc = est_procs(limit_hard) + + # If we would violate the soft limit, attempt to change it, leaving the hard limit alone + try: + if limit_soft < est_limit(nproc): + resource.setrlimit(res, (est_limit(nproc), limit_hard)) + finally: + # If the change doesn't take, then reduce the number of processes again + limit_soft, limit_hard = resource.getrlimit(res) + if limit_soft < est_limit(nproc): + nproc = est_procs(limit_soft) + + return nproc + def _scan_file_dimensions(self) -> dim_dict: # Scan the filesystem to get the widths and heights of all images into a dict - return { - object_id: [self._fits_file_dims(filepath) for filepath in self._object_files(object_id)] - for object_id in self.ids() - } + logger.info("Scanning for dimensions...") + + retval = {} + with MultiPool(processes=HSCDataSetContainer._determine_numprocs()) as pool: + args = ( + (object_id, list(self._object_files(object_id))) + for object_id in self.ids(log_every=1_000_000) + ) + retval = dict(pool.imap(self._scan_file_dimension, args, chunksize=1000)) + return retval + + @staticmethod + def _scan_file_dimension(processing_unit: tuple[str, list[str]]) -> list[tuple[int, int]]: + object_id, filenames = processing_unit + return (object_id, [HSCDataSetContainer._fits_file_dims(filepath) for filepath in filenames]) + + @staticmethod + def _fits_file_dims(filepath): + try: + with fits.open(filepath) as hdul: + return hdul[1].shape + except OSError: + return (0, 0) def _prune_objects(self, filters_ref: list[str]): """Class initialization helper. Prunes objects from the list of objects. 
@@ -488,12 +595,12 @@ def _prune_objects(self, filters_ref: list[str]): """ filters_ref = sorted(filters_ref) self.prune_count = 0 - for object_id, filters in list(self.files.items()): + for index, (object_id, filters) in enumerate(self.files.items()): # Drop objects with missing filters filters = sorted(list(filters)) if filters != filters_ref: msg = f"HSCDataSet in {self.path} has the wrong group of filters for object {object_id}." - self._prune_object(object_id, msg) + self._mark_for_prune(object_id, msg) logger.info(f"Filters for object {object_id} were {filters}") logger.debug(f"Reference filters were {filters_ref}") @@ -504,8 +611,16 @@ def _prune_objects(self, filters_ref: list[str]): msg = f"A file for object {object_id} has shape ({shape[1]}px, {shape[1]}px)" msg += " this is too small for the given cutout size of " msg += f"({self.cutout_shape[0]}px, {self.cutout_shape[1]}px)" - self._prune_object(object_id, msg) + self._mark_for_prune(object_id, msg) break + if index != 0 and index % 1_000_000 == 0: + logger.info(f"Processed {index} objects for pruning") + else: + logger.info(f"Processed {index + 1} objects for pruning") + + # Prune marked objects + for object_id, reason in self.pruned_objects.items(): + self._prune_object(object_id, reason) # Log about the pruning process pre_prune_object_count = len(self.files) + self.prune_count @@ -516,6 +631,9 @@ def _prune_objects(self, filters_ref: list[str]): logger.warning("Greater than 1% of objects in the data directory were pruned.") logger.info(f"Pruned {self.prune_count} out of {pre_prune_object_count} objects") + def _mark_for_prune(self, object_id, reason): + self.pruned_objects[object_id] = reason + def _prune_object(self, object_id, reason: str): logger.warning(reason) logger.warning(f"Dropping object {object_id} from the dataset") @@ -524,10 +642,6 @@ def _prune_object(self, object_id, reason: str): del self.dims[object_id] self.prune_count += 1 - def _fits_file_dims(self, filepath): - with fits.open(filepath) as hdul: - return hdul[1].shape - def _check_file_dimensions(self) -> tuple[int, int]: """Class initialization helper. Find the maximal pixel size that all images can support @@ -546,12 +660,14 @@ def _check_file_dimensions(self) -> tuple[int, int]: The minimum width and height in pixels of the entire dataset. In other words: the maximal image size in pixels that can be generated from ALL cutout images via cropping. """ + logger.info("Checking file dimensions to determine standard cutout size...") + # Find the maximal cutout size that all images can support all_widths = [shape[0] for shape_list in self.dims.values() for shape in shape_list] - cutout_width = np.min(all_widths) - all_heights = [shape[1] for shape_list in self.dims.values() for shape in shape_list] - cutout_height = np.min(all_heights) + all_dimensions = all_widths + all_heights + cutout_height = np.min(all_dimensions) + cutout_width = cutout_height if ( np.abs(cutout_width - np.mean(all_widths)) > 1 @@ -571,6 +687,92 @@ def _check_file_dimensions(self) -> tuple[int, int]: return cutout_width, cutout_height + def _rebuild_manifest(self, config): + if self.filter_catalog: + raise RuntimeError("Cannot rebuild manifest. Set the filter_catalog=false and rerun") + + logger.info("Reading in catalog file... 
") + location_table = Downloader.filterfits( + Path(config["download"]["fits_file"]).resolve(), ["object_id", "ra", "dec"] + ) + + obj_to_ra = { + str(location_table["object_id"][index]): location_table["ra"][index] + for index in range(len(location_table)) + } + obj_to_dec = { + str(location_table["object_id"][index]): location_table["dec"][index] + for index in range(len(location_table)) + } + + del location_table + + logger.info("Assembling Manifest...") + + # These are the column names expected in a manifest file by the downloader + column_names = Downloader.MANIFEST_COLUMN_NAMES + columns = {column_name: [] for column_name in column_names} + + # These will vary every object and must be implemented below + dynamic_column_names = ["object_id", "filter", "dim", "tract", "ra", "dec", "filename"] + # These are pulled from config ("sw", "sh", "rerun", "type", "image", "mask", and "variance") + static_column_names = [name for name in column_names if name not in dynamic_column_names] + + # Check that all column names we need for a manifest are either in static or dynamic columns + for column_name in column_names: + if column_name not in static_column_names and column_name not in dynamic_column_names: + raise RuntimeError(f"Error Assembling manifest {column_name} not implemented") + + static_values = { + "sw": parse_degree(config["download"]["sw"]), + "sh": parse_degree(config["download"]["sh"]), + "rerun": parse_rerun(config["download"]["rerun"]), + "type": parse_type(config["download"]["type"]), + "image": parse_bool(config["download"]["image"]), + "mask": parse_bool(config["download"]["mask"]), + "variance": parse_bool(config["download"]["variance"]), + } + + for index, (object_id, filter, filename, dim) in enumerate(self._all_files_full()): + for static_col in static_column_names: + columns[static_col].append(static_values[static_col]) + + for dynamic_col in dynamic_column_names: + if dynamic_col == "object_id": + columns[dynamic_col].append(int(object_id)) + elif dynamic_col == "filter": + columns[dynamic_col].append(filter) + elif dynamic_col == "dim": + columns[dynamic_col].append(dim) + elif dynamic_col == "tract": + # There's value in pulling tract from the filename rather than the download catalog + # in case The catalog had it wrong, the filename will have the value the cutout server + # provided. + tract = filename.split("_")[4] + columns[dynamic_col].append(parse_tract_opt(tract)) + elif dynamic_col == "ra": + ra = obj_to_ra[object_id] + columns[dynamic_col].append(parse_longitude(ra)) + elif dynamic_col == "dec": + dec = obj_to_dec[object_id] + columns[dynamic_col].append(parse_latitude(dec)) + elif dynamic_col == "filename": + columns[dynamic_col].append(filename) + else: + # The tower of if statements has been entirely to create this failure path. + # which will be hit when someone alters dynamic column names above without also + # writing an implementation. 
+ raise RuntimeError(f"No implementation to process column {dynamic_col}") + if index != 0 and index % 1_000_000 == 0: + logger.info(f"Addeed {index} objects to manifest") + else: + logger.info(f"Addeed {index+1} objects to manifest") + + logger.info("Writing rebuilt manifest...") + manifest_table = Table(columns) + rebuilt_manifest_path = Path(config["general"]["data_dir"]) / "rebuilt_manifest.fits" + manifest_table.write(rebuilt_manifest_path, overwrite=True, format="fits") + def shape(self) -> tuple[int, int, int]: """Shape of the individual cutouts this will give to a model @@ -647,7 +849,7 @@ def _get_file(self, index: int) -> Path: filter = filter_names[index % self.num_filters] return self._file_to_path(filters[filter]) - def ids(self): + def ids(self, log_every=None): """Public read-only iterator over all object_ids that enforces a strict total order across objects. Will not work prior to self.files initialization in __init__ @@ -656,8 +858,33 @@ def ids(self): Iterator[str] Object IDs currently in the dataset """ - for object_id in self.files: + log = log_every is not None and isinstance(log_every, int) + for index, object_id in enumerate(self.files): + if log and index != 0 and index % log_every == 0: + logger.info(f"Processed {index} objects") yield object_id + else: + if log: + logger.info(f"Processed {index} objects") + + def _all_files_full(self): + """ + Private read-only iterator over all files that enforces a strict total order across + objects and filters. Will not work prior to self.files, and self.path initialization in __init__ + + Yields + ------ + Tuple[object_id, filter, filename, dim] + Members of this tuple are + - The object_id as a string + - The filter name as a string + - The filename relative to self.path + - A tuple containing the dimensions of the fits file in pixels. + """ + for object_id in self.ids(): + dims = self.dims[object_id] + for idx, (filter, filename) in enumerate(self._filter_filename(object_id)): + yield (object_id, filter, filename, dims[idx]) def _all_files(self): """ @@ -673,6 +900,22 @@ def _all_files(self): for filename in self._object_files(object_id): yield filename + def _filter_filename(self, object_id): + """ + Private read-only iterator over all files for a given object. This enforces a strict total order + across filters. Will not work prior to self.files initialization in __init__ + + Yields + ------ + filter_name, file name + The name of a filter and the file name for the fits file. + The file name is relative to self.path + """ + filters = self.files[object_id] + filter_names = sorted(list(filters)) + for filter_name in filter_names: + yield filter_name, filters[filter_name] + def _object_files(self, object_id): """ Private read-only iterator over all files for a given object. This enforces a strict total order @@ -683,10 +926,8 @@ def _object_files(self, object_id): Path The path to the file. """ - filters = self.files[object_id] - filter_names = sorted(list(filters)) - for filter in filter_names: - yield self._file_to_path(filters[filter]) + for _, filename in self._filter_filename(object_id): + yield self._file_to_path(filename) def _file_to_path(self, filename: str) -> Path: """Turns a filename into a full path suitable for open. Equivalent to: diff --git a/src/fibad/download.py b/src/fibad/download.py index 9e7e73e..8f78da6 100644 --- a/src/fibad/download.py +++ b/src/fibad/download.py @@ -28,6 +28,8 @@ class Downloader: # of the immutable fields that we rely on for hash checks are also included. 
RECT_COLUMN_NAMES = list(dict.fromkeys(VARIABLE_FIELDS + dC.Rect.immutable_fields + ["dim"])) + MANIFEST_COLUMN_NAMES = RECT_COLUMN_NAMES + ["filename", "object_id"] + MANIFEST_FILE_NAME = "manifest.fits" def __init__(self, config): @@ -280,9 +282,8 @@ def _write_manifest(self): logger.info(f"Writing out download manifest with {len(combined_manifest)} entries.") # Convert the combined manifest into an astropy table by building a dict of {column_name: column_data} - # for all the fields in a rect, plus our object_id and filename. - column_names = Downloader.RECT_COLUMN_NAMES + ["filename", "object_id"] - columns = {column_name: [] for column_name in column_names} + # for all the fields we require in a manifest + columns = {column_name: [] for column_name in Downloader.MANIFEST_COLUMN_NAMES} for rect, msg in combined_manifest.items(): # This parsing relies on the name format set up in create_rects to work properly diff --git a/src/fibad/fibad.py b/src/fibad/fibad.py index 474a005..4d8cd13 100644 --- a/src/fibad/fibad.py +++ b/src/fibad/fibad.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import Union -from .config_utils import ConfigManager, resolve_runtime_config +from .config_utils import ConfigManager class Fibad: @@ -14,7 +14,7 @@ class Fibad: CLI functions in fibad_cli are implemented by calling this class """ - verbs = ["train", "predict", "download", "prepare"] + verbs = ["train", "predict", "download", "prepare", "rebuild_manifest"] def __init__(self, *, config_file: Union[Path, str] = None, setup_logging: bool = True): """Initialize fibad. Always applies the default config, and merges it with any provided config file. @@ -88,7 +88,7 @@ def __init__(self, *, config_file: Union[Path, str] = None, setup_logging: bool # Setup our handlers from config self._initialize_log_handlers() - self.logger.info(f"Runtime Config read from: {resolve_runtime_config(config_file)}") + self.logger.info(f"Runtime Config read from: {ConfigManager.resolve_runtime_config(config_file)}") def _initialize_log_handlers(self): """Private initialization helper, Adds handlers and level setting to the global self.logger object""" @@ -180,8 +180,16 @@ def predict(self, **kwargs): def prepare(self, **kwargs): """ - See Fibad.predict.run() + See Fibad.prepare.run() """ from .prepare import run return run(config=self.config, **kwargs) + + def rebuild_manifest(self, **kwargs): + """ + See Fibad.rebuild_manifest.run() + """ + from .rebuild_manifest import run + + return run(config=self.config, **kwargs) diff --git a/src/fibad/rebuild_manifest.py b/src/fibad/rebuild_manifest.py new file mode 100644 index 0000000..445df05 --- /dev/null +++ b/src/fibad/rebuild_manifest.py @@ -0,0 +1,26 @@ +import logging + +from fibad.pytorch_ignite import setup_model_and_dataset + +logger = logging.getLogger(__name__) + + +def run(config): + """Rebuild a broken download manifest + + Parameters + ---------- + config : dict + The parsed config file as a nested + dict + """ + + config["rebuild_manifest"] = True + + _, data_set = setup_model_and_dataset(config, split=config["train"]["split"]) + + logger.info("Starting rebuild of manifest") + + data_set.rebuild_manifest(config) + + logger.info("Finished Rebuild Manifest") diff --git a/tests/fibad/test_hsc_dataset.py b/tests/fibad/test_hsc_dataset.py index 534c9f3..d8f64a2 100644 --- a/tests/fibad/test_hsc_dataset.py +++ b/tests/fibad/test_hsc_dataset.py @@ -9,6 +9,8 @@ test_dir = Path(__file__).parent / "test_data" / "dataloader" +HSCDataSet._called_from_test = True + class FakeFitsFS: 
""" @@ -131,8 +133,8 @@ def test_load(caplog): # 10 objects should load assert len(a) == 10 - # The number of filters, and image dimensions should be correct - assert a.shape() == (5, 262, 263) + # The number of filters, and image dimensions should be correct and square + assert a.shape() == (5, 262, 262) # No warnings should be printed assert caplog.text == "" @@ -152,8 +154,8 @@ def test_load_duplicate(caplog): # Only 10 objects should load assert len(a) == 10 - # The number of filters, and image dimensions should be correct - assert a.shape() == (5, 262, 263) + # The number of filters, and image dimensions should be correct and square + assert a.shape() == (5, 262, 262) # We should get duplicate object errors assert "Duplicate object ID" in caplog.text @@ -327,8 +329,8 @@ def test_partial_filter(caplog): # 10 objects should load assert len(a) == 10 - # The number of filters, and image dimensions should be correct - assert a.shape() == (2, 262, 263) + # The number of filters, and image dimensions should be correct and square + assert a.shape() == (2, 262, 262) # No warnings should be printed assert caplog.text == "" From 6c775b9df39064fae85ba42b54cacb63f5555cbb Mon Sep 17 00:00:00 2001 From: Aritra Ghosh Date: Fri, 22 Nov 2024 19:25:28 -0800 Subject: [PATCH 4/4] Removing dimensions print statement --- src/fibad/data_sets/hsc_data_set.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/fibad/data_sets/hsc_data_set.py b/src/fibad/data_sets/hsc_data_set.py index 567d7f0..6628f42 100644 --- a/src/fibad/data_sets/hsc_data_set.py +++ b/src/fibad/data_sets/hsc_data_set.py @@ -358,7 +358,6 @@ def _init_from_path( if isinstance(self.filter_catalog, tuple): self.files = self.filter_catalog[0] self.dims = self.filter_catalog[1] - print(self.dims) elif isinstance(self.filter_catalog, dict): self.files = self.filter_catalog self.dims = self._scan_file_dimensions()