diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 591af9f1..4e2d0a06 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -57,7 +57,7 @@ jobs: - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v6 with: - python-version: "3.11" + python-version: "3.13" - name: Install Python package run: | diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 00000000..12543467 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +Tobias Buck, astroai@iwr.uni-heidelberg.de. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. 
+ +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. diff --git a/FILESTRUCTURE.md b/FILESTRUCTURE.md deleted file mode 100644 index c422e004..00000000 --- a/FILESTRUCTURE.md +++ /dev/null @@ -1,40 +0,0 @@ -This is an explanation of the file structure that the cookiecutter generated for you: - -* Python source files: - * The Python package source files are located in the `rubix` directory. - * `tests/test_rubix.py` contains the unit tests for the package. - * `tests/conftest.py` contains testing setup and configuration for `pytest` - * The `notebooks` directory contains an example Jupyter notebook on how to use `rubix`. - This notebook is always executed during `pytest` execution and it is automatically - rendered into the Sphinx documentation. -* Markdown files with meta information on the project. [Markdown](https://www.markdownguide.org/basic-syntax/) is - a good language for these files, as it is easy to write and rendered into something beautiful by your git repository - hosting provider. - * `README.md` is the file that users will typically see first when discovering your project. 
- * `COPYING.md` provides a list of copyright holders. - * `LICENSE.md` contains the license you selected. - * `TODO.md` contains a list of TODOs after running the cookiecutter. Following the - instructions in that file will give you a fully functional repository with a lot - of integration into useful web services activated and running. - * `FILESTRUCTURE.md` describes the generated files. Feel free to remove this from the - repository if you do not need it. -* Python build system files - * `pyproject.toml` is the central place for configuration of your Python package. - It contains the project metadata, setuptools-specific information and the configuration - for your toolchain (like e.g. `pytest`). - * `setup.py` is still required for editable builds, but you should not need to change it. - In the future, `setuptools` will support editable builds purely from `pyproject.toml` - configuration. -* Configuration for CI/Code Analysis and documentation services - * `.github/workflows/ci.yml` describes the Github Workflow for Continuous - integration. For further reading on workflow files, we recommend the - [introduction into Github Actions](https://docs.github.com/en/free-pro-team@latest/actions/learn-github-actions/introduction-to-github-actions) - and [the reference of available options](https://docs.github.com/en/free-pro-team@latest/actions/reference/workflow-syntax-for-github-actions). - * `.github/dependabot.yml` configures the DependaBot integration on GitHub that - allows you to automatically create pull requests for updates of the used actions - in `.github/workflows/ci.yml`. - * `.gitlab-ci.yml` describes the configuration for Gitlab CI. For further - reading, we recommend [Gitlabs quick start guide](https://docs.gitlab.com/ee/ci/quick_start/) - and the [Gitlab CI configuration reference](https://docs.gitlab.com/ce/ci/yaml/) - * `.readthedocs.yml` configures the documentation build process at [ReadTheDocs](https://readthedocs.org). - To customize your build, you can have a look at the [available options](https://docs.readthedocs.io/en/stable/config-file/v2.html). diff --git a/Makefile b/Makefile index d0c3cbf1..b97de95f 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build -SOURCEDIR = source +SOURCEDIR = docs BUILDDIR = build # Put it first so that "make" without argument is like "make help". 
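With `SOURCEDIR` now pointing at `docs`, the Sphinx sources move from `source/` to `docs/` (matching the file renames further below). Assuming the standard Sphinx make-mode catch-all target that these Makefile variables feed into, `make html` then corresponds to the following sketch (not part of the patch itself):

```python
# Sketch: what `make html` resolves to after SOURCEDIR changed to "docs".
# Assumes the standard make-mode target generated by sphinx-quickstart;
# run from the repository root.
from sphinx.cmd.build import main as sphinx_build

# "-M html" selects make-mode HTML output; "docs" is the new SOURCEDIR and
# "build" the BUILDDIR from the Makefile above.
sphinx_build(["-M", "html", "docs", "build"])
```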
diff --git a/README.md b/README.md index ba0008d8..8e914572 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,6 @@ [![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/AstroAI-Lab/rubix/blob/main/docs/CONTRIBUTING.md) [![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/AstroAI-Lab/rubix/ci.yml?branch=main)](https://github.com/AstroAI-Lab/rubix/actions/workflows/ci.yml) -[![GitHub Workflow Status](https://img.shields.io/github/workflow/status/AstroAI-Lab/rubix/CI?label=build)](https://github.com/AstroAI-Lab/rubix/actions/workflows/ci.yml) [![Documentation Status](https://readthedocs.org/projects/rubix/badge/)](https://astro-rubix.web.app) [![codecov](https://codecov.io/gh/AstroAI-Lab/rubix/branch/main/graph/badge.svg)](https://codecov.io/gh/AstroAI-Lab/rubix) [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit) @@ -26,14 +25,16 @@ Key features include: ## Installation -The Python package `rubix` can be downloades from git and can be installed: +The Python package `rubix` is published on GitHub and can be installed alongside its runtime dependencies (including JAX) by choosing the relevant extras. For a CPU-only environment, install with: ``` git clone https://github.com/AstroAI-Lab/rubix.git cd rubix -pip install . +pip install .[cpu] ``` +If you need GPU acceleration, replace `[cpu]` with `[cuda]` (or install `jax[cuda]` following the [JAX instructions](https://github.com/google/jax#installation) before installing Rubix). The plain `pip install .` command installs the minimal package without JAX and will raise `ImportError` if you try to import `rubix` before adding `jax` manually. + ## Development installation If you want to contribute to the development of `rubix`, we recommend @@ -42,7 +43,7 @@ the following editable installation from this repository: ``` git clone https://github.com/AstroAI-Lab/rubix.git cd rubix -python -m pip install --editable .[tests] +python -m pip install --editable .[cpu,tests,dev] ``` Having done so, the test suite can be run using `pytest`: @@ -51,9 +52,21 @@ ``` python -m pytest ``` -This project depends on [jax](https://github.com/google/jax). It only installed for cpu computations with the testing dependencies. For installation instructions with gpu support, -please refer to [here](https://github.com/google/jax?tab=readme-ov-file#installation). +This project depends on [jax](https://github.com/google/jax). The test suite only exercises the `cpu` version. +For installation instructions with GPU support, +please refer to [here](https://github.com/google/jax?tab=readme-ov-file#installation) or simply use the `cuda` extra when pip installing. + +## Configuration overview + +Rubix ships with two YAML files in `rubix/config/`: `rubix_config.yml` (constants, SSP templates, dust recipes, handler mappings, etc.) and `pipeline_config.yml` (pipeline graphs such as `calc_ifu` and `calc_dusty_ifu`). There is no configuration wizard; your runtime settings must supply a dictionary with the following blocks: + +- `pipeline.name`: Identifies the pipeline from `pipeline_config.yml` (e.g., `calc_ifu`, `calc_dusty_ifu`, or `calc_gradient`). +- `galaxy`: Must provide `dist_z` and a `rotation` section (`type` or explicit `alpha`, `beta`, `gamma`).
+- `telescope`: Requires `name`, `psf` (currently only the `gaussian` kernel with `size` and `sigma`), `lsf` (`sigma`), and `noise` (`signal_to_noise` plus `noise_distribution`, choose from `normal` or `uniform`). +- `ssp.dust`: Must declare `extinction_model` and `Rv` before calling the dusty pipeline (see `rubix/spectra/dust/extinction_models.py` for the supported models such as `Cardelli89`). +- `data.args.particle_type`: Should include `"stars"` (and `"gas"` if you want the gas branch) so the filters and rotation functions know which components exist. +The tutorials and notebooks assume square spaxels, so the default telescope factory currently only supports `pixel_type: square`. For a working example, inspect `notebooks/rubix_pipeline_single_function_shard_map.ipynb`, which runs the exact pipeline used in the tests. ## Documentation Sphinx Documentation of all the functions is currently available under [this link](https://astro-rubix.web.app/). @@ -63,7 +76,7 @@ Sphinx Documentation of all the functions is currently available under [this lin Contributions to `rubix` are welcome and greatly appreciated! Whether you're fixing bugs, improving documentation, or suggesting new features, your help is valuable to us. -Please see [here](source/CONTRIBUTING.md) for contribution guidelines. +Please see [here](docs/CONTRIBUTING.md) for contribution guidelines. Thank you for helping improve `rubix`! @@ -71,7 +84,8 @@ Thank you for helping improve `rubix`! Please cite **both** of the following papers ([Cakir et al. 2024](https://arxiv.org/abs/2412.08265), [Schaible et al. 2025](https://arxiv.org/abs/2511.17110)) if you use Rubix in your research: -@ARTICLE{2024arXiv241208265C, +``` + @ARTICLE{2024arXiv241208265C, author = {{{\c{C}}ak{\i}r}, Ufuk and {Schaible}, Anna Lena and {Buck}, Tobias}, title = "{Fast GPU-Powered and Auto-Differentiable Forward Modeling of IFU Data Cubes}", journal = {arXiv e-prints}, @@ -81,14 +95,14 @@ Please cite **both** of the following papers ([Cakir et al. 2024](https://arxiv. eid = {arXiv:2412.08265}, pages = {arXiv:2412.08265}, doi = {10.48550/arXiv.2412.08265}, -archivePrefix = {arXiv}, + archivePrefix = {arXiv}, eprint = {2412.08265}, - primaryClass = {astro-ph.IM}, + primaryClass = {astro-ph.IM}, adsurl = {https://ui.adsabs.harvard.edu/abs/2024arXiv241208265C}, adsnote = {Provided by the SAO/NASA Astrophysics Data System} -} + } -@ARTICLE{2025arXiv251117110S, + @ARTICLE{2025arXiv251117110S, author = {{Schaible}, Anna Lena and {{\c{C}}ak{\i}r}, Ufuk and {Buck}, Tobias and {Mack}, Harald and {Obreja}, Aura and {Oguz}, Nihat and {Oliver}, William H. 
and {C{\u{a}}r{\u{a}}mizaru}, Horea-Alexandru}, title = "{RUBIX: Differentiable forward modelling of galaxy spectral data cubes for gradient-based parameter estimation}", journal = {arXiv e-prints}, @@ -98,13 +112,13 @@ archivePrefix = {arXiv}, eid = {arXiv:2511.17110}, pages = {arXiv:2511.17110}, doi = {10.48550/arXiv.2511.17110}, -archivePrefix = {arXiv}, + archivePrefix = {arXiv}, eprint = {2511.17110}, - primaryClass = {astro-ph.GA}, + primaryClass = {astro-ph.GA}, adsurl = {https://ui.adsabs.harvard.edu/abs/2025arXiv251117110S}, adsnote = {Provided by the SAO/NASA Astrophysics Data System} -} - + } +``` @@ -138,7 +152,7 @@ archivePrefix = {arXiv}, ## Licence -[GNU General Public License v3.0](https://github.com/synthesizer-project/synthesizer/blob/main/LICENSE.md) +[MIT License](https://github.com/AstroAI-Lab/rubix/blob/main/LICENSE.md) ## Acknowledgments diff --git a/source/Contributing.md b/docs/CONTRIBUTING.md similarity index 96% rename from source/Contributing.md rename to docs/CONTRIBUTING.md index a62b1805..0863887a 100644 --- a/source/Contributing.md +++ b/docs/CONTRIBUTING.md @@ -49,6 +49,7 @@ at the root of the repo to activate the pre-commit hooks. If you would like to test whether it works you can run `pre-commit run --all-files` to run the pre-commit hook on the whole repo. You should see each stage complete without issue in a clean clone. + ## Using Black We use [Black](https://black.readthedocs.io/en/stable/) for code formatting. Assuming you installed the development dependencies (if not you can install `black` with pip: `pip install black`), you can run the linting with `black {source_file_or_directory}`. For more details see the [Black documentation](https://black.readthedocs.io/en/stable/usage_and_configuration/the_basics.html). @@ -105,13 +106,13 @@ Adding content should be relatively simple if you follow the instructions below. To add Jupyter notebooks to the documentation: -1. Add your Jupyter notebook to the `notebooks` directory under the `docs/source` folder. Make sure that you 'Restart Kernel and run all cells' to ensure that the notebook is producing up to date, consistent outputs. -2. Add your notebook to the relevant toctree. See below for an example toctree. Each toctree is contained within a Sphinx `.rst` file in each documentation source directory. The top-level file is `source/index.rst`. If your file is in a subfolder, you need to update the `.rst` file in that directory. +1. Add your Jupyter notebook to the `notebooks` directory under the `docs` folder. Make sure that you 'Restart Kernel and run all cells' to ensure that the notebook is producing up to date, consistent outputs. +2. Add your notebook to the relevant toctree. See below for an example toctree. Each toctree is contained within a Sphinx `.rst` file in each documentation source directory. The top-level file is `docs/index.rst`. If your file is in a subfolder, you need to update the `.rst` file in that directory. - If you're creating a new sub-directory of documentation, you will need to carry out a couple more steps: 1. Create a new `.rst` file in that directory -2. Update `source/index.rst` with the path to that `.rst` file +2. Update `docs/index.rst` with the path to that `.rst` file 3. Currently we do not run pytests on jupyter notebooks. So please make sure your notebooks are actually working fine. 
Example toctree: @@ -164,7 +165,7 @@ sphinx-quickstart #### Configuration and Content -The core of the documentation setup resides in the source folder: +The core of the documentation setup resides in the `docs` folder: - `conf.py`: This is the main configuration file where you define extensions (like myst_nb for notebooks), set the theme, and manage global build settings. diff --git a/source/acknowledgments.rst b/docs/acknowledgments.rst similarity index 100% rename from source/acknowledgments.rst rename to docs/acknowledgments.rst diff --git a/source/conf.py b/docs/conf.py similarity index 91% rename from source/conf.py rename to docs/conf.py index 288aba81..00e6cc16 100644 --- a/source/conf.py +++ b/docs/conf.py @@ -11,9 +11,9 @@ # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -project = "rubix" -copyright = "2024, Ufuk, Tobias, Anna Lena" -author = "Ufuk, Tobias, Anna Lena" +project = "Rubix" +copyright = "2025, Anna Lena Schaible, Ufuk Cakir, Tobias Buck" +author = "Anna Lena Schaible, Ufuk Cakir, Tobias Buck" release = "0.1" # -- General configuration --------------------------------------------------- diff --git a/source/index.rst b/docs/index.rst similarity index 100% rename from source/index.rst rename to docs/index.rst diff --git a/docs/installation.rst b/docs/installation.rst new file mode 100644 index 00000000..1483c677 --- /dev/null +++ b/docs/installation.rst @@ -0,0 +1,42 @@ +Installation +============ + +`RUBIX` can be installed via `pip`. + +Clone the repository and navigate to the root directory of the repository. Then run + +``` +pip install .[cpu] +``` + +If you want to contribute to the development of `RUBIX`, we recommend the following editable installation from this repository: + +``` +git clone https://github.com/AstroAI-Lab/rubix +cd rubix +pip install -e .[cpu,tests,dev] +``` +Having done so, the test suite can be run using `pytest`: + +``` +python -m pytest +``` + +Note that if `JAX` is not yet installed, the `cpu` option will install only the CPU version of `JAX` +as a dependency. For a GPU-compatible installation of `JAX`, please refer to the +[JAX installation guide](https://jax.readthedocs.io/en/latest/installation.html) or use the option `cuda`. + +Get started with the simple example notebook `notebooks/rubix_pipeline_single_function_shard_map.ipynb`. + +Configuration +============= + +When you run the pipeline you provide a configuration dict that references the files in `rubix/config/`. The following sections are required for the default pipelines: + +- `pipeline.name`: Choose one of `calc_ifu`, `calc_dusty_ifu`, or another entry from `pipeline_config.yml`. +- `galaxy`: Must include `dist_z` and a `rotation` block (`type` or explicit `alpha`, `beta`, `gamma`). +- `telescope`: Needs `name`, a `psf` block (Gaussian kernel with both `size` and `sigma`), an `lsf` block with `sigma`, and `noise` containing `signal_to_noise` plus a `noise_distribution` (`normal` or `uniform`). +- `ssp.dust`: Declares `extinction_model` and `Rv` before the dusty pipeline can produce an extincted datacube. +- `data.args.particle_type`: Must include `"stars"` (add `"gas"` if you rely on the optional gas branch) so the filtering/rotation steps know which components to process. + +The telescopes in `rubix/telescope` currently only support square pixels, so every config should set `pixel_type: square` in the relevant telescope definition.
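To make the configuration requirements above concrete, here is a minimal sketch of such a dict, modelled on the configs used in the example notebooks (the galaxy id, paths, and the telescope name are placeholders to adapt; `FSPS` is one of the shipped SSP template options):

```python
# Minimal illustrative RUBIX runtime config (a sketch based on the notebook
# examples; ids, paths and the telescope name are placeholders, not defaults).
import os

config = {
    "pipeline": {"name": "calc_ifu"},
    "data": {
        "name": "IllustrisAPI",
        "args": {
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": "data",
        },
        "load_galaxy_args": {"id": 14, "reuse": True},
    },
    "simulation": {"name": "IllustrisTNG", "args": {"path": "data/galaxy-id-14.hdf5"}},
    "output_path": "output",
    "telescope": {
        "name": "MUSE",  # placeholder: any telescope defined in rubix/telescope
        "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
        "lsf": {"sigma": 1.2},
        "noise": {"signal_to_noise": 100, "noise_distribution": "normal"},
    },
    "cosmology": {"name": "PLANCK15"},
    "galaxy": {"dist_z": 0.1, "rotation": {"type": "edge-on"}},
    "ssp": {"template": {"name": "FSPS"}},
}
```

Such a dict is then passed to `RubixPipeline(config)`; for `calc_dusty_ifu` the `ssp` block additionally needs the `dust` settings (`extinction_model`, `Rv`) listed above. See `notebooks/rubix_pipeline_single_function_shard_map.ipynb` for the full working example.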
diff --git a/source/license.rst b/docs/license.rst similarity index 100% rename from source/license.rst rename to docs/license.rst diff --git a/source/modules.rst b/docs/modules.rst similarity index 100% rename from source/modules.rst rename to docs/modules.rst diff --git a/source/notebooks/cosmology.ipynb b/docs/notebooks/cosmology.ipynb similarity index 100% rename from source/notebooks/cosmology.ipynb rename to docs/notebooks/cosmology.ipynb diff --git a/source/notebooks/create_rubix_data.ipynb b/docs/notebooks/create_rubix_data.ipynb similarity index 100% rename from source/notebooks/create_rubix_data.ipynb rename to docs/notebooks/create_rubix_data.ipynb diff --git a/source/notebooks/demo.yml b/docs/notebooks/demo.yml similarity index 100% rename from source/notebooks/demo.yml rename to docs/notebooks/demo.yml diff --git a/source/notebooks/dust_extinction.ipynb b/docs/notebooks/dust_extinction.ipynb similarity index 100% rename from source/notebooks/dust_extinction.ipynb rename to docs/notebooks/dust_extinction.ipynb diff --git a/source/notebooks/filter_curves.ipynb b/docs/notebooks/filter_curves.ipynb similarity index 100% rename from source/notebooks/filter_curves.ipynb rename to docs/notebooks/filter_curves.ipynb diff --git a/source/notebooks/gradient_age_metallicity_adamoptimizer_multi.ipynb b/docs/notebooks/gradient_age_metallicity_adamoptimizer_multi.ipynb similarity index 100% rename from source/notebooks/gradient_age_metallicity_adamoptimizer_multi.ipynb rename to docs/notebooks/gradient_age_metallicity_adamoptimizer_multi.ipynb diff --git a/source/notebooks/gradient_age_metallicity_adamoptimizer_vs_finite_diff.ipynb b/docs/notebooks/gradient_age_metallicity_adamoptimizer_vs_finite_diff.ipynb similarity index 100% rename from source/notebooks/gradient_age_metallicity_adamoptimizer_vs_finite_diff.ipynb rename to docs/notebooks/gradient_age_metallicity_adamoptimizer_vs_finite_diff.ipynb diff --git a/source/notebooks/output/rubix_galaxy.h5 b/docs/notebooks/output/rubix_galaxy.h5 similarity index 100% rename from source/notebooks/output/rubix_galaxy.h5 rename to docs/notebooks/output/rubix_galaxy.h5 diff --git a/source/notebooks/pipeline_demo.ipynb b/docs/notebooks/pipeline_demo.ipynb similarity index 100% rename from source/notebooks/pipeline_demo.ipynb rename to docs/notebooks/pipeline_demo.ipynb diff --git a/source/notebooks/psf.ipynb b/docs/notebooks/psf.ipynb similarity index 100% rename from source/notebooks/psf.ipynb rename to docs/notebooks/psf.ipynb diff --git a/source/notebooks/rubix_pipeline_single_function_shard_map.ipynb b/docs/notebooks/rubix_pipeline_single_function_shard_map.ipynb similarity index 100% rename from source/notebooks/rubix_pipeline_single_function_shard_map.ipynb rename to docs/notebooks/rubix_pipeline_single_function_shard_map.ipynb diff --git a/source/notebooks/rubix_pipeline_stepwise.ipynb b/docs/notebooks/rubix_pipeline_stepwise.ipynb similarity index 83% rename from source/notebooks/rubix_pipeline_stepwise.ipynb rename to docs/notebooks/rubix_pipeline_stepwise.ipynb index ea3b191a..04d9926d 100644 --- a/source/notebooks/rubix_pipeline_stepwise.ipynb +++ b/docs/notebooks/rubix_pipeline_stepwise.ipynb @@ -53,6 +53,17 @@ "os.environ['SPS_HOME'] = '/home/annalena/sps_fsps'" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NBVAL_SKIP\n", + "import os\n", + "os.environ['SPS_HOME'] = '/home/annalena/sps_fsps'" + ] + }, { "cell_type": "code", "execution_count": null, @@ -241,18 +252,88 @@ 
"cell_type": "markdown", "metadata": {}, "source": [ - "## Step 6: Data cube calculation\n", + "## Step 6: Reshape data\n", "\n", - "This is the heart of the `pipeline`. Now we do the lookup for the spectrum for each stellar particle. For the simple stellar population model by `BruzualCharlot2003`, each stellar particle gets a spectrum assigned based on its age and metallicity.\n", + "At the moment we have to reshape the rubix data that we can split the data on multiple GPUs. We plan to move from pmap to shard_map. Then this step should not be necessary any more. This step has purely computational reason and no physics motivated reason." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NBVAL_SKIP\n", + "from rubix.core.data import get_reshape_data\n", + "reshape_data = get_reshape_data(config)\n", "\n", - "We scale the stellar particle spectra by its mass. The stellar spectra have to be scaled by the stellar mass. Later heavier stellar particles should contribute more to the spectrum in a spaxel than lighter stellar particles.\n", + "rubixdata = reshape_data(rubixdata)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 7: Spectra calculation\n", "\n", - "The stellar particles are not at rest and therefore the emitted light is doppler shifted with respect to the observer. Before adding all stellar spectra in each spaxel, we dopplershift the spectra according to their particle velocity and we resample the spectra to the wavelength grid of the observing instrument.\n", + "This is the heart of the `pipeline`. Now we do the lookup for the spectrum for each stellar particle. For the simple stellar population model by `BruzualCharlot2003`, each stellar particle gets a spectrum assigned based on its age and metallicity. In the plot we can see that the spectrum differs for different stellar particles." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NBVAL_SKIP\n", + "from rubix.core.ifu import get_calculate_spectra\n", + "calcultae_spectra = get_calculate_spectra(config)\n", + "\n", + "rubixdata = calcultae_spectra(rubixdata)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NBVAL_SKIP\n", + "import jax.numpy as jnp\n", + "\n", + "plt.plot(jnp.arange(len(rubixdata.stars.spectra[0][0][:])), rubixdata.stars.spectra[0][0][:])\n", + "plt.plot(jnp.arange(len(rubixdata.stars.spectra[0][0][:])), rubixdata.stars.spectra[0][1][:])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 8: Scaling by mass\n", + "\n", + "The stellar spectra have to be scaled by the stellar mass. Later heavier stellar particles should contribute more to the spectrum in a spaxel than lighter stellar particles." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NBVAL_SKIP\n", + "from rubix.core.ifu import get_scale_spectrum_by_mass\n", + "scale_spectrum_by_mass = get_scale_spectrum_by_mass(config)\n", "\n", - "Each stellar particle falls into one spaxel in the datacube. We ad the stellar particles spectra contribution to the according spaxel in the datacube. 
We do these steps for all atellar particles.\n", + "rubixdata = scale_spectrum_by_mass(rubixdata)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 9: Doppler shifting and resampling\n", "\n", - "The first plot shows the spectra for two different spaxels.\n", - "The second plot shows the spatial dimension of the `datacube`, where we summed over the wavelength dimension." + "The stellar particles are not at rest and therefore the emitted light is Doppler-shifted with respect to the observer. Before adding all stellar spectra in each spaxel, we Doppler-shift the spectra according to their particle velocity and we resample the spectra to the wavelength grid of the observing instrument." ] }, { @@ -262,10 +343,10 @@ "outputs": [], "source": [ "# NBVAL_SKIP\n", - "from rubix.core.ifu import get_calculate_datacube_particlewise\n", - "calculate_datacube_particlewise = get_calculate_datacube_particlewise(config)\n", + "from rubix.core.ifu import get_doppler_shift_and_resampling\n", + "doppler_shift_and_resampling = get_doppler_shift_and_resampling(config)\n", "\n", - "rubixdata = calculate_datacube_particlewise(rubixdata)" + "rubixdata = doppler_shift_and_resampling(rubixdata)" ] }, { @@ -283,8 +364,30 @@ "print(wave)\n", "print(rubixdata.stars.datacube[0][0][:])\n", "\n", - "plt.plot(wave, rubixdata.stars.datacube[12][12][:])\n", - "plt.plot(wave, rubixdata.stars.datacube[10][5][:])" + "plt.plot(wave, rubixdata.stars.spectra[0][0][:])\n", + "plt.plot(wave, rubixdata.stars.spectra[0][1][:])" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 10: Datacube\n", + "\n", + "Now we can add all stellar spectra that contribute to one spaxel and get the IFU datacube. The plot shows the spatial dimension of the `datacube`, where we summed over the wavelength dimension." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NBVAL_SKIP\n", + "from rubix.core.ifu import get_calculate_datacube\n", + "calculate_datacube = get_calculate_datacube(config)\n", + "\n", + "rubixdata = calculate_datacube(rubixdata)" ] }, { @@ -414,59 +517,6 @@ "\n", "Congratulations, you have now created step by step your own mock-observed IFU datacube! Now enjoy playing around with the RUBIX pipeline and enjoy doing amazing science with RUBIX :)" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Store datacube in a fits file with header\n", - "\n", - "Keep in mind that this it the luminosity cube. If you want to have a flux cube, you have to convert it. You can do this with the `rubix.spectra.ifu.convert_luminoisty_to_flux` function." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "from rubix.core.fits import store_fits\n", - "\n", - "store_fits(config, rubixdata.stars.datacube, \"./output/\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Visualisation of the datacube\n", - "\n", - "We show, how you can visualize the datacube and see the image and spectra and explore the datacube. We show our own build tool to visualize the datacube and second we provide the opportunity to load the datacube with the Cubeviz module from jdaviz.\n", - "\n", - "`visualize_rubix` uses mpdaf to load the datacube. This is a package specialized to load MUSE datacubes. 
The function will thisplay you on the left an image collapsed along the wavelength and on the right a spectrum for a certain pixel or aperture. \n", - "\n", - "Explanation of the sliders:\n", - "- Waveindex: Waveindex, which wavelength slice is plotted in the image.\n", - "- Wavelengthrange: Range in wavelength that is collapsed to the image.\n", - "- X Pixel: X coordinate of the displayed spectrum and x coordinate of the center of the aperture.\n", - "- Y Pixel: Y coordinate of the displayed spectrum and y coordinate of the center of the aperture.\n", - "- Radius: size of the circular aperture in pixels. If this value is set to zerro, only the spaxel specified in the x and y pixel is considered for the spectrum plot.\n", - "\n", - "Now you can explore your datacube with the sliders!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "from rubix.core.visualisation import visualize_rubix\n", - "\n", - "#visualize_rubix(\"./output/IllustrisTNG_id11_snap99_stars_subsetTrue.fits\")" - ] } ], "metadata": { diff --git a/source/notebooks/ssp_interpolation.ipynb b/docs/notebooks/ssp_interpolation.ipynb similarity index 100% rename from source/notebooks/ssp_interpolation.ipynb rename to docs/notebooks/ssp_interpolation.ipynb diff --git a/source/notebooks/ssp_template.ipynb b/docs/notebooks/ssp_template.ipynb similarity index 98% rename from source/notebooks/ssp_template.ipynb rename to docs/notebooks/ssp_template.ipynb index c5a07ca8..9be9ee16 100644 --- a/source/notebooks/ssp_template.ipynb +++ b/docs/notebooks/ssp_template.ipynb @@ -20,6 +20,17 @@ "os.environ['SPS_HOME'] = '/home/annalena/sps_fsps'" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NBVAL_SKIP\n", + "import os\n", + "os.environ['SPS_HOME'] = '/home/annalena/sps_fsps'" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/source/notebooks/ssp_template_fsps.ipynb b/docs/notebooks/ssp_template_fsps.ipynb similarity index 79% rename from source/notebooks/ssp_template_fsps.ipynb rename to docs/notebooks/ssp_template_fsps.ipynb index 5f13da1e..1c046068 100644 --- a/source/notebooks/ssp_template_fsps.ipynb +++ b/docs/notebooks/ssp_template_fsps.ipynb @@ -107,6 +107,52 @@ "}" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NBVAL_SKIP\n", + "config = {\n", + " \"name\": \"FSPS (Conroy et al. 2009)\",\n", + " # more information on how those models are synthesized: https://github.com/cconroy20/fsps\n", + " # and https://dfm.io/python-fsps/current/\n", + " \"format\": \"fsps\", # Format of the template\n", + " \"source\": \"load_from_file\", # the source can be \"load_from_file\" or \"rerun_from_scratch\"\n", + " # \"load_from_file\" is the default and loads the template from a pre-existing file in h5 format specified by \"file_name\"\n", + " # if that file is not found, it will automatically run fsps and save the output to disk in h5 format under the \"file_name\" given.\n", + " # \"rerun_from_scratch\" # note: this is just meant for the case in which you really want to rerun your template library.\n", + " # You should be aware that fsps templates will silently be overwritten by this. 
Use with caution.\n", + " \"file_name\": \"fsps.h5\", # File name of the template, stored in templates directory\n", + " # Define the Fields in the template and their units\n", + " # This is used to convert them to the required units\n", + " \"fields\":{ # Fields in the template and their units\n", + " # Name defines the name of the key stored in the hdf5 file\n", + " \"age\":{\n", + " \"name\": \"age\",\n", + " \"units\": \"Gyr\", # Age of the template\n", + " \"in_log\": True # If the field is stored in log scale\n", + " },\n", + " \"metallicity\":{\n", + " \"name\": \"metallicity\",\n", + " \"units\": \"\", # Metallicity of the template\n", + " \"in_log\": True # If the field is stored in log scale\n", + " },\n", + " \"wavelength\":{\n", + " \"name\": \"wavelength\",\n", + " \"units\": \"Angstrom\", # Wavelength of the template\n", + " \"in_log\": False # If the field is stored in log scale\n", + " },\n", + " \"flux\":{\n", + " \"name\": \"flux\",\n", + " \"units\": \"Lsun/Angstrom\", # Luminosity of the template as per pyFSPS documentation\n", + " \"in_log\": False # If the field is stored in log scale\n", + " }\n", + " }\n", + "}" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/source/notebooks/telescope.ipynb b/docs/notebooks/telescope.ipynb similarity index 100% rename from source/notebooks/telescope.ipynb rename to docs/notebooks/telescope.ipynb diff --git a/source/publications.rst b/docs/publications.rst similarity index 100% rename from source/publications.rst rename to docs/publications.rst diff --git a/source/rubix.core.rst b/docs/rubix.core.rst similarity index 100% rename from source/rubix.core.rst rename to docs/rubix.core.rst diff --git a/source/rubix.cosmology.rst b/docs/rubix.cosmology.rst similarity index 100% rename from source/rubix.cosmology.rst rename to docs/rubix.cosmology.rst diff --git a/source/rubix.galaxy.input_handler.rst b/docs/rubix.galaxy.input_handler.rst similarity index 100% rename from source/rubix.galaxy.input_handler.rst rename to docs/rubix.galaxy.input_handler.rst diff --git a/source/rubix.galaxy.rst b/docs/rubix.galaxy.rst similarity index 100% rename from source/rubix.galaxy.rst rename to docs/rubix.galaxy.rst diff --git a/source/rubix.pipeline.rst b/docs/rubix.pipeline.rst similarity index 100% rename from source/rubix.pipeline.rst rename to docs/rubix.pipeline.rst diff --git a/source/rubix.rst b/docs/rubix.rst similarity index 100% rename from source/rubix.rst rename to docs/rubix.rst diff --git a/source/rubix.spectra.rst b/docs/rubix.spectra.rst similarity index 100% rename from source/rubix.spectra.rst rename to docs/rubix.spectra.rst diff --git a/source/rubix.spectra.ssp.rst b/docs/rubix.spectra.ssp.rst similarity index 100% rename from source/rubix.spectra.ssp.rst rename to docs/rubix.spectra.ssp.rst diff --git a/source/rubix.telescope.filters.rst b/docs/rubix.telescope.filters.rst similarity index 100% rename from source/rubix.telescope.filters.rst rename to docs/rubix.telescope.filters.rst diff --git a/source/rubix.telescope.lsf.rst b/docs/rubix.telescope.lsf.rst similarity index 100% rename from source/rubix.telescope.lsf.rst rename to docs/rubix.telescope.lsf.rst diff --git a/source/rubix.telescope.noise.rst b/docs/rubix.telescope.noise.rst similarity index 100% rename from source/rubix.telescope.noise.rst rename to docs/rubix.telescope.noise.rst diff --git a/source/rubix.telescope.psf.rst b/docs/rubix.telescope.psf.rst similarity index 100% rename from source/rubix.telescope.psf.rst rename to 
docs/rubix.telescope.psf.rst diff --git a/source/rubix.telescope.rst b/docs/rubix.telescope.rst similarity index 100% rename from source/rubix.telescope.rst rename to docs/rubix.telescope.rst diff --git a/source/rubix.utils.rst b/docs/rubix.utils.rst similarity index 100% rename from source/rubix.utils.rst rename to docs/rubix.utils.rst diff --git a/docs/versions.rst b/docs/versions.rst new file mode 100644 index 00000000..ef8b08b4 --- /dev/null +++ b/docs/versions.rst @@ -0,0 +1,14 @@ +Code versions +============= + +`RUBIX` has different code versions. The current version is `0.1`. + +Version 0.1 +----------- +Forward modelling of IFU cubes of galaxies from cosmological hydrodynamical simulations (IllustrisTNG50, NIHAO, ...) for stellar particles from different stellar templates (Bruzual&Charlot, Mastar, FSPS, EMILES). +Gradient calculation through the whole pipeline for gradient-based parameter estimation on particle parameters. + + +Version 0.2 +----------- +Under development diff --git a/notebooks/cosmology.ipynb b/notebooks/cosmology.ipynb index e956c1bb..7d97377f 100644 --- a/notebooks/cosmology.ipynb +++ b/notebooks/cosmology.ipynb @@ -75,8 +75,8 @@ "from rubix.cosmology.utils import trapz\n", "import jax.numpy as jnp\n", "\n", - "x = jnp.array([0, 1, 2, 3])\n", - "y = jnp.array([0, 1, 4, 9])\n", + "x = jnp.array([0.0, 1.0, 2.0, 3.0])\n", + "y = jnp.array([0.0, 1.0, 4.0, 9.0])\n", "print(trapz(x, y))" ] }, @@ -102,7 +102,7 @@ ], "metadata": { "kernelspec": { - "display_name": "rubix", + "display_name": "publishrubix", "language": "python", "name": "python3" }, @@ -116,7 +116,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.12.12" } }, "nbformat": 4, diff --git a/notebooks/gradient_age_metallicity_variational_inference.ipynb b/notebooks/gradient_age_metallicity_variational_inference.ipynb deleted file mode 100644 index 3bcc8526..00000000 --- a/notebooks/gradient_age_metallicity_variational_inference.ipynb +++ /dev/null @@ -1,643 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "from jax import config\n", - "#config.update(\"jax_enable_x64\", True)\n", - "#config.update('jax_num_cpu_devices', 2)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#NBVAL_SKIP\n", - "import os\n", - "\n", - "# Tell XLA to fake 2 host CPU devices\n", - "#os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=3'\n", - "\n", - "# Only make GPU 0 and GPU 1 visible to JAX:\n", - "#os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'\n", - "\n", - "#os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n", - "\n", - "import jax\n", - "\n", - "# Now JAX will list two CpuDevice entries\n", - "print(jax.devices())\n", - "# → [CpuDevice(id=0), CpuDevice(id=1)]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "import os\n", - "#os.environ['SPS_HOME'] = '/mnt/storage/annalena_data/sps_fsps'\n", - "#os.environ['SPS_HOME'] = '/home/annalena/sps_fsps'\n", - "os.environ['SPS_HOME'] = '/Users/annalena/Documents/GitHub/fsps'\n", - "#os.environ['SPS_HOME'] = '/export/home/aschaibl/fsps'" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Load ssp template from FSPS" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - 
"# NBVAL_SKIP\n", - "from rubix.spectra.ssp.factory import get_ssp_template\n", - "ssp_fsps = get_ssp_template(\"FSPS\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "age_values = ssp_fsps.age\n", - "print(age_values.shape)\n", - "\n", - "metallicity_values = ssp_fsps.metallicity\n", - "print(metallicity_values.shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Configure pipeline" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "from rubix.core.pipeline import RubixPipeline\n", - "import os\n", - "config = {\n", - " \"pipeline\":{\"name\": \"calc_gradient\",},\n", - " \n", - " \"logger\": {\n", - " \"log_level\": \"DEBUG\",\n", - " \"log_file_path\": None,\n", - " \"format\": \"%(asctime)s - %(name)s - %(levelname)s - %(message)s\",\n", - " },\n", - " \"data\": {\n", - " \"name\": \"IllustrisAPI\",\n", - " \"args\": {\n", - " \"api_key\": os.environ.get(\"ILLUSTRIS_API_KEY\"),\n", - " \"particle_type\": [\"stars\"],\n", - " \"simulation\": \"TNG50-1\",\n", - " \"snapshot\": 99,\n", - " \"save_data_path\": \"data\",\n", - " },\n", - " \n", - " \"load_galaxy_args\": {\n", - " \"id\": 14,\n", - " \"reuse\": True,\n", - " },\n", - " \n", - " \"subset\": {\n", - " \"use_subset\": True,\n", - " \"subset_size\": 2,\n", - " },\n", - " },\n", - " \"simulation\": {\n", - " \"name\": \"IllustrisTNG\",\n", - " \"args\": {\n", - " \"path\": \"data/galaxy-id-14.hdf5\",\n", - " },\n", - " \n", - " },\n", - " \"output_path\": \"output\",\n", - "\n", - " \"telescope\":\n", - " {\"name\": \"TESTGRADIENT\",\n", - " \"psf\": {\"name\": \"gaussian\", \"size\": 5, \"sigma\": 0.6},\n", - " \"lsf\": {\"sigma\": 1.2},\n", - " \"noise\": {\"signal_to_noise\": 100,\"noise_distribution\": \"normal\"},\n", - " },\n", - " \"cosmology\":\n", - " {\"name\": \"PLANCK15\"},\n", - " \n", - " \"galaxy\":\n", - " {\"dist_z\": 0.1,\n", - " \"rotation\": {\"type\": \"edge-on\"},\n", - " },\n", - " \n", - " \"ssp\": {\n", - " \"template\": {\n", - " \"name\": \"FSPS\"\n", - " },\n", - " \"dust\": {\n", - " \"extinction_model\": \"Cardelli89\",\n", - " \"dust_to_gas_ratio\": 0.01,\n", - " \"dust_to_metals_ratio\": 0.4,\n", - " \"dust_grain_density\": 3.5,\n", - " \"Rv\": 3.1,\n", - " },\n", - " }, \n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "pipe = RubixPipeline(config)\n", - "inputdata = pipe.prepare_data()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Gradient on the spectrum for each wavelenght" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "from rubix.pipeline import linear_pipeline as pipeline\n", - "\n", - "pipeline_instance = RubixPipeline(config)\n", - "\n", - "pipeline_instance._pipeline = pipeline.LinearTransformerPipeline(\n", - " pipeline_instance.pipeline_config, \n", - " pipeline_instance._get_pipeline_functions()\n", - ")\n", - "pipeline_instance._pipeline.assemble()\n", - "pipeline_instance.func = pipeline_instance._pipeline.compile_expression()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "# pick values\n", - "initial_age_index = 95\n", - "initial_metallicity_index = 4\n", - "age0 = 
age_values[initial_age_index]\n", - "Z0 = metallicity_values[initial_metallicity_index]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "print(f\"age0 = {age0}, Z0 = {Z0}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "import jax.numpy as jnp\n", - "\n", - "inputdata.stars.age = jnp.array([age_values[initial_age_index], age_values[initial_age_index]])\n", - "inputdata.stars.metallicity = jnp.array([metallicity_values[initial_metallicity_index], metallicity_values[initial_metallicity_index]])\n", - "inputdata.stars.mass = jnp.array([[1.0], [1.0]])\n", - "inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])\n", - "inputdata.stars.coords = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "import dataclasses\n", - "import jax\n", - "import jax.numpy as jnp\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "\n", - "def spectrum_1d(age, Z, base_data, pipeline_instance):\n", - " # broadcast per-star\n", - " nstar = base_data.stars.age.shape[0]\n", - " stars2 = dataclasses.replace(\n", - " base_data.stars,\n", - " age=jnp.full((nstar,), age),\n", - " metallicity=jnp.full((nstar,), Z),\n", - " )\n", - " data2 = dataclasses.replace(base_data, stars=stars2)\n", - "\n", - " out = pipeline_instance.func(data2)\n", - "\n", - " cube = out.stars.datacube # shape (…, n_lambda)\n", - " # collapse all non-wavelength axes, keep wavelength last\n", - " spec = cube.reshape((-1, cube.shape[-1])).sum(axis=0)\n", - "\n", - " return jnp.ravel(spec) " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "spec0 = spectrum_1d(age0, Z0, inputdata, pipeline_instance)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "import matplotlib.pyplot as plt\n", - "wave = pipe.telescope.wave_seq" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "from tensorflow_probability.substrates import jax as tfp\n", - "tfd = tfp.distributions\n", - "tfb = tfp.bijectors\n", - "\n", - "import tqdm\n", - "import optax\n", - "import flax.linen as nn\n", - "from flax.metrics import tensorboard" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "class AffineCoupling(nn.Module):\n", - " @nn.compact\n", - " def __call__(self, x, nunits):\n", - " net = nn.leaky_relu(nn.Dense(128)(x))\n", - " net = nn.leaky_relu(nn.Dense(128)(net))\n", - " shift = nn.Dense(nunits)(net)\n", - " scale = nn.softplus(nn.Dense(nunits)(net))\n", - " return tfb.Chain([ tfb.Shift(shift), tfb.Scale(scale)])\n", - "\n", - "def make_nvp_fn(n_layers=2, d=2):\n", - " # We alternate between permutations and flow layers\n", - " layers = [ tfb.Permute([1,0])(tfb.RealNVP(d//2,\n", - " bijector_fn=AffineCoupling(name='affine%d'%i)))\n", - " for i in range(n_layers) ]\n", - "\n", - " # We build the actual nvp from these bijectors and a standard Gaussian distribution\n", - " nvp = tfd.TransformedDistribution(\n", - " tfd.MultivariateNormalDiag(loc=jnp.zeros(2), scale_diag=0.05*jnp.ones(2)),\n", - " 
bijector=tfb.Chain([tfb.Shift([5,0.05])] + layers ))\n", - " # Note that we have here added a shift to the bijector\n", - " return nvp\n", - "\n", - "class NeuralSplineFlowSampler(nn.Module):\n", - " @nn.compact\n", - " def __call__(self, key, n_samples):\n", - " nvp = make_nvp_fn()\n", - " x = nvp.sample(n_samples, seed=key)\n", - " return x, nvp.log_prob(x)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "model = NeuralSplineFlowSampler()\n", - "params = model.init(jax.random.PRNGKey(42), jax.random.PRNGKey(1), 16)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "import pandas as pd\n", - "from chainconsumer import ChainConsumer, Chain, Truth\n", - "\n", - "# 1) Draw samples from the untrained bounded flow\n", - "theta0, logq0 = model.apply(params, key=jax.random.PRNGKey(1), n_samples=500)\n", - "df = pd.DataFrame(theta0, columns=[\"age\", \"Z\"])\n", - "\n", - "# 2) Optional: pick a fiducial point (for synthetic tests use your known truth)\n", - "fid_age = age0 # example: mid of [0, 20]\n", - "fid_Z = Z0 # example: inside [4.5e-5, 4.5e-2]\n", - "\n", - "# 3) Build the ChainConsumer plot\n", - "c = ChainConsumer()\n", - "c.add_chain(Chain(samples=df, name=\"Initial VI\"))\n", - "c.add_truth(Truth(location={\"age\": fid_age, \"Z\": fid_Z}))\n", - "\n", - "fig = c.plotter.plot(figsize=\"column\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "def log_prior_gaussian(theta_batch,\n", - " mu_age=6.0, sigma_age=3.0,\n", - " mu_Z=1.3e-3, sigma_Z=2e-4):\n", - " \"\"\"Gaussian prior in physical space.\"\"\"\n", - " age = theta_batch[:, 0]\n", - " Z = theta_batch[:, 1]\n", - " lp_age = -0.5 * (((age - mu_age) / sigma_age)**2\n", - " + jnp.log(2*jnp.pi*sigma_age**2))\n", - " lp_Z = -0.5 * (((Z - mu_Z) / sigma_Z)**2\n", - " + jnp.log(2*jnp.pi*sigma_Z**2))\n", - " return lp_age + lp_Z # shape (batch,)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "import jax, jax.numpy as jnp\n", - "\n", - "def log_likelihood(y, s, mask=None):\n", - " \"\"\"Full-vector Gaussian log-likelihood.\"\"\"\n", - " if mask is None:\n", - " mask = jnp.ones_like(y)\n", - " r = y - s\n", - " term = (r**2)\n", - " return jnp.sum(term * mask)\n", - "\n", - "def make_batched_loglik(y, base_data, pipeline_instance, mask=None):\n", - " \"\"\"Returns a function mapping a batch of theta -> per-sample log-likelihood.\"\"\"\n", - " def one_theta(theta):\n", - " age, Z = theta[0], theta[1]\n", - " s = spectrum_1d(age, Z, base_data, pipeline_instance) # -> (n_lambda,)\n", - " return log_likelihood(y, s, mask=mask, )\n", - " return jax.vmap(one_theta) # (batch,2) -> (batch,)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "def make_elbo_fn(y, base_data, pipeline_instance,\n", - " mask=None, \n", - " mu_age=7.0, sigma_age=2.0,\n", - " mu_Z=0.001, sigma_Z=1e-3):\n", - " batched_loglik = make_batched_loglik(y, base_data,\n", - " pipeline_instance, mask)\n", - "\n", - " def elbo(params, seed, n_samples=128):\n", - " # Draw θ ~ q_φ(θ)\n", - " theta_batch, log_q = model.apply(params, key=seed, n_samples=n_samples)\n", - " # Compute log p(θ)\n", - " log_p = log_prior_gaussian(theta_batch, mu_age, sigma_age, mu_Z, 
sigma_Z)\n", - " # Compute log p(y|θ)\n", - " log_lik = batched_loglik(theta_batch)\n", - " # ELBO\n", - " elbo_value = jnp.mean(log_lik + log_p - log_q)\n", - " return -elbo_value # minimize\n", - " return elbo" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "# Random key\n", - "seed = jax.random.PRNGKey(0)\n", - "\n", - "# Scheduler and optimizer\n", - "total_steps = 20_000\n", - "lr = 2e-3\n", - "# lr_scheduler = optax.piecewise_constant_schedule(\n", - "# init_value=1e-3,\n", - "# boundaries_and_scales={int(total_steps*0.5): 0.2}\n", - "# )\n", - "optimizer = optax.adam(lr) #lr_scheduler)\n", - "opt_state = optimizer.init(params)\n", - "\n", - "# TensorBoard logs\n", - "from flax.metrics import tensorboard\n", - "summary_writer = tensorboard.SummaryWriter(\"logs/elbo\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "eps = 1e-6\n", - "sigma_obs = jnp.maximum(jnp.abs(spec0) / 1000.0, eps)\n", - "y = spec0\n", - "base_data = inputdata" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "# Build once, outside update_model\n", - "elbo = make_elbo_fn(\n", - " y, # observed full flux vector\n", - " base_data,\n", - " pipeline_instance, \n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "@jax.jit\n", - "def update_model(params, opt_state, seed):#, n_samples=128):\n", - " # split RNG: first return is new seed you’ll keep, second is used to sample θ\n", - " seed, key = jax.random.split(seed)\n", - "\n", - " # loss(params) = -ELBO(params, key, n_samples)\n", - " loss, grads = jax.value_and_grad(elbo)(params, key)#, n_samples)\n", - "\n", - " # apply Adam step; passing params is safest for transforms that need them\n", - " updates, opt_state = optimizer.update(grads, opt_state, params)\n", - " params = optax.apply_updates(params, updates)\n", - "\n", - " return params, opt_state, loss, seed" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "import tqdm\n", - "\n", - "losses = []\n", - "\n", - "for i in tqdm.tqdm(range(total_steps)):\n", - " # one optimization step (minimizes -ELBO)\n", - " params, opt_state, loss, seed = update_model(params, opt_state, seed)\n", - "\n", - " losses.append(float(loss))\n", - "\n", - " # log every 10 steps\n", - " if i % 10 == 0:\n", - " summary_writer.scalar(\"neg_elbo\", float(loss), i)\n", - " #summary_writer.scalar(\"learning_rate\", float(lr_scheduler(i)), i)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "# 1) Sample posterior θ = (age, Z)\n", - "seed, sub = jax.random.split(seed)\n", - "theta, log_q = model.apply(params, key=sub, n_samples=5000) # theta.shape == (5000, 2)\n", - "age = theta[:, 0]\n", - "Z = theta[:, 1]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "c = ChainConsumer()\n", - "\n", - "# fresh RNG split so we don’t reuse training key\n", - "seed, sub = jax.random.split(seed)\n", - "\n", - "# sample θ ~ qϕ(θ)\n", - "theta, log_q = model.apply(params, key=sub, n_samples=20_000) # shape (N, 2)\n", - "age = theta[:, 0]\n", - "Z = theta[:, 1]\n", - 
"\n", - "# ChainConsumer expects a pandas DataFrame\n", - "df = pd.DataFrame({\"age\": age, \"Z\": Z})\n", - "\n", - "# add the VI chain\n", - "c.add_chain(Chain(samples=df, name=\"VI\"))\n", - "\n", - "# optional “truth” dot: use known synthetic truth if you have it; else posterior mean\n", - "# truth_age, truth_Z = 8.0, 1.0e-2 # <- set these if you know them\n", - "# truth_age, truth_Z = float(age.mean()), float(Z.mean())\n", - "truth_age, truth_Z = age0, Z0\n", - "c.add_truth(Truth(location={\"age\": truth_age, \"Z\": truth_Z}))\n", - "\n", - "fig = c.plotter.plot(figsize=\"column\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "plt.figure(figsize=(7,3))\n", - "plt.plot(np.arange(len(losses)), losses, lw=1)\n", - "plt.xlabel(\"Iteration\")\n", - "plt.ylabel(\"Loss\")\n", - "plt.grid(True)\n", - "plt.tight_layout()\n", - "plt.show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "rubix", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.8" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/pyproject.toml b/pyproject.toml index a53894a3..49b9347e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ dependencies = [ "ipywidgets", "jdaviz", "pynbody", + "optax", "opt-einsum >=3.3.0", ] [project.optional-dependencies] @@ -91,7 +92,6 @@ tests = [ "pytest-mock", "requests-mock", "nbval", - "jax[cpu]>0.5.1", "pre-commit", ] docs = [ diff --git a/rubix/core/cosmology.py b/rubix/core/cosmology.py index 49f73994..193f6d51 100644 --- a/rubix/core/cosmology.py +++ b/rubix/core/cosmology.py @@ -21,7 +21,7 @@ def get_cosmology(config: dict) -> RubixCosmology: ValueError: When ``config["cosmology"]["name"]`` is not supported. Example: - :: + >>> config = { ... ... ... "cosmology": diff --git a/rubix/core/data.py b/rubix/core/data.py index 0d4dbf20..ebc55fe1 100644 --- a/rubix/core/data.py +++ b/rubix/core/data.py @@ -2,12 +2,12 @@ import os from dataclasses import dataclass from functools import partial -from typing import Any, Callable, Optional, Union import jax import jax.numpy as jnp import numpy as np from beartype import beartype as typechecker +from beartype.typing import Any, Callable, Optional, Union from jaxtyping import jaxtyped from rubix.galaxy import IllustrisAPI, get_input_handler @@ -265,7 +265,7 @@ def convert_to_rubix(config: Union[dict, str]): ValueError: When ``config['data']['name']`` is unsupported. Example: - :: + >>> import os >>> from rubix.core.data import convert_to_rubix @@ -397,7 +397,7 @@ def prepare_input(config: Union[dict, str]) -> RubixData: ValueError: When subset mode is enabled but neither stars nor gas coordinates exist. 
Example: - :: + >>> import os >>> from rubix.core.data import convert_to_rubix, prepare_input @@ -430,7 +430,7 @@ def prepare_input(config: Union[dict, str]) -> RubixData: # Set the galaxy attributes rubixdata.galaxy.redshift = jnp.float64(data["redshift"]) rubixdata.galaxy.redshift_unit = units["galaxy"]["redshift"] - rubixdata.galaxy.center = jnp.array(data["subhalo_center"], dtype=jnp.float64) + rubixdata.galaxy.center = jnp.array(data["subhalo_center"], dtype=jnp.float32) rubixdata.galaxy.center_unit = units["galaxy"]["center"] rubixdata.galaxy.halfmassrad_stars = jnp.float64(data["subhalo_halfmassrad_stars"]) rubixdata.galaxy.halfmassrad_stars_unit = units["galaxy"]["halfmassrad_stars"] @@ -550,7 +550,7 @@ def get_reshape_data(config: Union[dict, str]) -> Callable: Function that reshapes a `RubixData` instance. Example: - :: + >>> from rubix.core.data import get_reshape_data >>> reshape_data = get_reshape_data(config) >>> rubixdata = reshape_data(rubixdata) diff --git a/rubix/core/dust.py b/rubix/core/dust.py index 45a38b35..e86d85a6 100644 --- a/rubix/core/dust.py +++ b/rubix/core/dust.py @@ -1,6 +1,5 @@ -from typing import Callable - from beartype import beartype as typechecker +from beartype.typing import Callable from jaxtyping import jaxtyped from rubix.core.cosmology import get_cosmology diff --git a/rubix/core/ifu.py b/rubix/core/ifu.py index 19af7ef1..6402e845 100644 --- a/rubix/core/ifu.py +++ b/rubix/core/ifu.py @@ -1,8 +1,7 @@ -from typing import Callable, Union - import jax import jax.numpy as jnp from beartype import beartype as typechecker +from beartype.typing import Callable from jax import lax from jaxtyping import Array, Float, jaxtyped @@ -192,7 +191,7 @@ def calculate_dusty_datacube_particlewise( ext_model = config["ssp"]["dust"]["extinction_model"] Rv = config["ssp"]["dust"]["Rv"] # Dynamically choose the extinction model based on the string name - if ext_model not in RV_MODELS: + if ext_model not in RV_MODELS: # pragma: no cover raise ValueError( f"Extinction model '{ext_model}' is not available. " f"Choose from {RV_MODELS}." diff --git a/rubix/core/lsf.py b/rubix/core/lsf.py index 5e60760c..49839878 100644 --- a/rubix/core/lsf.py +++ b/rubix/core/lsf.py @@ -1,6 +1,5 @@ -from typing import Callable - from beartype import beartype as typechecker +from beartype.typing import Callable from jaxtyping import jaxtyped from rubix.logger import get_logger @@ -24,7 +23,6 @@ def get_convolve_lsf(config: dict) -> Callable[[RubixData], RubixData]: ValueError: When the telescope LSF configuration or sigma is missing. Example: - :: >>> config = { ... ... diff --git a/rubix/core/noise.py b/rubix/core/noise.py index 472023e7..d5ffc037 100644 --- a/rubix/core/noise.py +++ b/rubix/core/noise.py @@ -1,7 +1,6 @@ -from typing import Callable - import jax.numpy as jnp from beartype import beartype as typechecker +from beartype.typing import Callable from jaxtyping import jaxtyped from rubix.logger import get_logger @@ -27,7 +26,6 @@ def get_apply_noise(config: dict) -> Callable[[RubixData], RubixData]: ValueError: When required noise configuration keys are missing. Example: - :: >>> config = { ... ...
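A recurring change in the hunks above (dust.py, ifu.py, lsf.py, noise.py, and several modules below) swaps `from typing import ...` for `from beartype.typing import ...`. beartype.typing is documented as a drop-in mirror of the standard typing module, so the annotations stay compatible with the `@jaxtyped(typechecker=typechecker)` decorators these modules already use, and shapes and dtypes get checked at call time. A minimal sketch of the pattern, assuming only jax, beartype, and jaxtyping (the function name and shapes are illustrative, not from the codebase):

```python
import jax.numpy as jnp
from beartype import beartype as typechecker
from beartype.typing import Callable
from jaxtyping import Array, Float, jaxtyped

# Factories in rubix return callables; beartype.typing.Callable keeps the
# annotation compatible with the runtime checker.
SpectrumFn = Callable[[Float[Array, "n"]], Float[Array, "n"]]


@jaxtyped(typechecker=typechecker)
def scale_spectrum(flux: Float[Array, "n"]) -> Float[Array, "n"]:
    # beartype enforces the jaxtyping shape/dtype annotations at call time.
    return flux * 2.0


fn: SpectrumFn = scale_spectrum
fn(jnp.ones(4))  # OK: rank-1 float array
# fn(jnp.ones((2, 4))) would raise a runtime type error: wrong rank.
```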
diff --git a/rubix/core/pipeline.py b/rubix/core/pipeline.py index 3fb5a03f..b8bd72bd 100644 --- a/rubix/core/pipeline.py +++ b/rubix/core/pipeline.py @@ -1,11 +1,23 @@ import time -from typing import Any, Optional, Sequence, Union +import warnings import jax import jax.numpy as jnp from beartype import beartype as typechecker +from beartype.typing import Any, Optional, Sequence, Union from jax import lax -from jax.experimental.shard_map import shard_map + +try: + from jax.shard_map import shard_map # type: ignore[attr-defined] +except ImportError: # pragma: no cover - older JAX compatibility + warnings.filterwarnings( + "ignore", + message="jax.experimental.shard_map is deprecated in v0.8.0.*", + category=DeprecationWarning, + module=__name__, + ) + from jax.experimental.shard_map import shard_map + from jax.sharding import Mesh, NamedSharding, PartitionSpec as P from jax.tree_util import tree_map from jaxtyping import jaxtyped @@ -36,16 +48,17 @@ class RubixPipeline: Parsed configuration dictionary or path to a configuration file. Example: - :: >>> from rubix.core.pipeline import RubixPipeline >>> config = "path/to/config.yml" + >>> target_datacube = ... # Load or define your target datacube here >>> pipe = RubixPipeline(config) >>> inputdata = pipe.prepare_data() - >>> output = pipe.run(inputdata) >>> final_datacube = pipe.run_sharded(inputdata) - >>> ssp_model = pipeline.ssp - >>> telescope = pipeline.telescope + >>> ssp_model = pipe.ssp + >>> telescope = pipe.telescope + >>> loss_value = pipe.loss(inputdata, target_datacube) + >>> gradient_data = pipe.gradient(inputdata, target_datacube) """ def __init__(self, user_config: Union[dict, str]): @@ -157,7 +170,7 @@ def run_sharded( self.logger.info("Compiling the expressions...") self.func = self._pipeline.compile_expression() - if devices is None: + if devices is None: # pragma: no cover devices = jax.devices() num_devices = len(devices) else: @@ -304,6 +317,6 @@ def loss( jnp.ndarray: Scalar sum of squared errors between output and target. """ - output = self.run(rubixdata) + output = self.run_sharded(rubixdata) loss_value = jnp.sum((output - targetdata) ** 2) return loss_value diff --git a/rubix/core/psf.py b/rubix/core/psf.py index 274c4ef8..46dc40e5 100644 --- a/rubix/core/psf.py +++ b/rubix/core/psf.py @@ -1,6 +1,5 @@ -from typing import Callable - from beartype import beartype as typechecker +from beartype.typing import Callable from jaxtyping import jaxtyped from rubix.logger import get_logger @@ -29,7 +28,6 @@ def get_convolve_psf(config: dict) -> Callable: kernel type. Example: - :: >>> config = { ... ... diff --git a/rubix/core/rotation.py b/rubix/core/rotation.py index 6023270c..f2db5c3f 100644 --- a/rubix/core/rotation.py +++ b/rubix/core/rotation.py @@ -26,7 +26,7 @@ def get_galaxy_rotation(config: dict): or missing. Example: - :: + >>> config = { ... ... ...
"galaxy": { diff --git a/rubix/core/ssp.py b/rubix/core/ssp.py index 850f33cb..dd9bb448 100644 --- a/rubix/core/ssp.py +++ b/rubix/core/ssp.py @@ -1,7 +1,6 @@ -from typing import Callable - import jax from beartype import beartype as typechecker +from beartype.typing import Callable from jaxtyping import jaxtyped from rubix.logger import get_logger diff --git a/rubix/core/telescope.py b/rubix/core/telescope.py index d9fddd6e..b847f20b 100644 --- a/rubix/core/telescope.py +++ b/rubix/core/telescope.py @@ -1,7 +1,6 @@ -from typing import Callable, Union - import jax.numpy as jnp from beartype import beartype as typechecker +from beartype.typing import Callable, Union from jaxtyping import Array, Float, jaxtyped from rubix.logger import get_logger @@ -153,7 +152,7 @@ def get_filter_particles(config: dict) -> Callable: Callable[[RubixData], RubixData]: Function that filters particles. Example: - :: + >>> from rubix.core.telescope import get_filter_particles >>> filter_particles = get_filter_particles(config) diff --git a/rubix/cosmology/base.py b/rubix/cosmology/base.py index a6c460ee..feff6005 100644 --- a/rubix/cosmology/base.py +++ b/rubix/cosmology/base.py @@ -40,7 +40,7 @@ class BaseCosmology(eqx.Module): h (jnp.float32): Dimensionless Hubble constant. Example: - :: + >>> # Create Planck15 cosmology >>> from rubix.cosmology import COSMOLOGY >>> cosmo = COSMOLOGY(0.3089, -1.0, 0.0, 0.6774) @@ -73,7 +73,7 @@ def scale_factor_to_redshift( Float[Array, "..."]: Redshift ``1/a - 1``. Example: - :: + >>> from rubix.cosmology import PLANCK15 as cosmo >>> # Convert scale factor 0.5 to redshift >>> cosmo.scale_factor_to_redshift(jnp.array(0.5)) @@ -121,7 +121,7 @@ def comoving_distance_to_z( Float[Array, "..."]: Comoving distance in Mpc. Example: - :: + >>> from rubix.cosmology import PLANCK15 as cosmo >>> # Calculate comoving distance to redshift 0.5 >>> cosmo.comoving_distance_to_z(0.5) @@ -145,7 +145,7 @@ def luminosity_distance_to_z( Float[Array, "..."]: Luminosity distance in Mpc. Example: - :: + >>> from rubix.cosmology import PLANCK15 as cosmo >>> # Compute the luminosity distance to redshift 0.5 >>> cosmo.luminosity_distance_to_z(0.5) @@ -167,7 +167,7 @@ def angular_diameter_distance_to_z( Float[Array, "..."]: Angular diameter distance in Mpc. Example: - :: + >>> from rubix.cosmology import PLANCK15 as cosmo >>> # Compute the angular diameter distance to redshift 0.5 >>> cosmo.angular_diameter_distance_to_z(0.5) @@ -189,7 +189,7 @@ def distance_modulus_to_z( Float[Array, "..."]: Distance modulus. Example: - :: + >>> from rubix.cosmology import PLANCK15 as cosmo >>> # Compute the distance modulus to redshift 0.5 >>> cosmo.distance_modulus_to_z(0.5) @@ -211,7 +211,7 @@ def _hubble_time(self, z: Union[Float[Array, "..."], float]) -> Float[Array, ".. Float[Array, "..."]: Hubble time in seconds. Example: - :: + >>> from rubix.cosmology import PLANCK15 as cosmo >>> # Calculate the Hubble time at redshift 0.5 >>> cosmo._hubble_time(0.5) @@ -235,7 +235,7 @@ def lookback_to_z( Float[Array, "..."]: Lookback time in seconds. Example: - :: + >>> from rubix.cosmology import PLANCK15 as cosmo >>> # Calculate the lookback time to redshift 0.5 >>> cosmo.lookback_to_z(0.5) @@ -256,7 +256,7 @@ def age_at_z0(self) -> Float[Array, "..."]: The age of the universe at redshift 0 (float). Example: - :: + >>> from rubix.cosmology import PLANCK15 as cosmo >>> # Calculate the age of the universe at redshift 0 >>> cosmo.age_at_z0() @@ -294,7 +294,7 @@ def age_at_z( Float[Array, "..."]: Age in seconds. 
Example: - :: + >>> from rubix.cosmology import PLANCK15 as cosmo >>> # Calculate the age of the universe at redshift 0.5 >>> cosmo.age_at_z(0.5) @@ -317,7 +317,7 @@ def angular_scale( Float[Array, "..."]: Angular scale in kpc/arcsec. Example: - :: + >>> from rubix.cosmology import PLANCK15 as cosmo >>> # Calculate the angular scale at redshift 0.5 >>> cosmo.angular_scale(0.5) @@ -326,34 +326,3 @@ def angular_scale( D_A = self.angular_diameter_distance_to_z(z) # in Mpc scale = D_A * (jnp.pi / (180 * 60 * 60)) * 1e3 # in kpc/arcsec return scale - - """ - I dont think we need this currently, but keeping it here for reference - @jit - def rho_crit(self, redshift): - rho_crit0 = RHO_CRIT0_KPC3_UNITY_H * self.h * self.h - rho_crit = rho_crit0 * self._Ez(redshift) ** 2 - return rho_crit - - @jit - def _integrand_oneOverEz1pz(self, z): - return 1.0 / self._Ez(z) / (1.0 + z) - - @jit - def _Om_at_z(self, z): - E = self._Ez(z) - return self.Om0 * (1.0 + z) ** 3 / E / E - - @jit - def _delta_vir(self, z): - x = self._Om(z) - 1.0 - Delta = 18 * jnp.pi**2 + 82.0 * x - 39.0 * x**2 - return Delta - - @jit - def virial_dynamical_time(self, redshift): - delta = self._delta_vir(redshift) - t_cross = 2**1.5 * self._hubble_time(redshift) * delta**-0.5 - return t_cross - -""" diff --git a/rubix/cosmology/utils.py b/rubix/cosmology/utils.py index 07b07109..be4eb497 100644 --- a/rubix/cosmology/utils.py +++ b/rubix/cosmology/utils.py @@ -71,7 +71,7 @@ def trapz( jnp.ndarray: Scalar results collected from the scan. Example: - :: + >>> from rubix.cosmology.utils import trapz >>> import jax.numpy as jnp diff --git a/rubix/galaxy/alignment.py b/rubix/galaxy/alignment.py index 7ff9c52e..f2010802 100644 --- a/rubix/galaxy/alignment.py +++ b/rubix/galaxy/alignment.py @@ -1,7 +1,6 @@ -from typing import Tuple, Union - import jax.numpy as jnp from beartype import beartype as typechecker +from beartype.typing import Tuple, Union from jax.scipy.spatial.transform import Rotation from jaxtyping import Array, Float, jaxtyped @@ -23,7 +22,7 @@ def center_particles(rubixdata: object, key: str) -> object: ValueError: If the galaxy center lies outside the particle bounds. Example: - :: + >>> from rubix.galaxy.alignment import center_particles >>> rubixdata = center_particles(rubixdata, "stars") """ @@ -84,7 +83,7 @@ def moment_of_inertia_tensor( Float[Array, "..."]: Moment of inertia tensor. Example: - :: + >>> from rubix.galaxy.alignment import moment_of_inertia_tensor >>> I = moment_of_inertia_tensor( ... rubixdata.stars.coords, diff --git a/rubix/galaxy/input_handler/api/illustris_api.py b/rubix/galaxy/input_handler/api/illustris_api.py index b5340c23..c35c2c21 100644 --- a/rubix/galaxy/input_handler/api/illustris_api.py +++ b/rubix/galaxy/input_handler/api/illustris_api.py @@ -224,7 +224,7 @@ def load_galaxy( unsupported particle type is configured. Example: - :: + >>> illustris_api = IllustrisAPI( ... api_key, ... 
simulation="TNG50-1", diff --git a/rubix/galaxy/input_handler/base.py b/rubix/galaxy/input_handler/base.py index 33941fd5..b2ef4d47 100644 --- a/rubix/galaxy/input_handler/base.py +++ b/rubix/galaxy/input_handler/base.py @@ -163,25 +163,6 @@ def _check_galaxy_data(self, galaxy_data, units): if field not in units["galaxy"]: raise ValueError(f"Units for {field} not found in units") - """ - def _check_particle_data(self, particle_data, units): - # Check if all required fields are present - for key in self.config["particles"]: - if key not in particle_data: - raise ValueError(f"Missing particle type {key} in particle data") - for field in self.config["particles"][key]: - if field not in particle_data[key]: - raise ValueError( - f"Missing field {field} in particle data for particle type {key}" - ) - - # Check if the units are correct - for key in particle_data: - for field in particle_data[key]: - if field not in units[key]: - raise ValueError(f"Units for {field} not found in units") - """ - def _check_particle_data(self, particle_data, units): # Get the list of expected particle types from the configuration expected_particle_types = list(self.config["particles"].keys()) diff --git a/rubix/spectra/dust/extinction_models.py b/rubix/spectra/dust/extinction_models.py index 453ec369..8935767a 100644 --- a/rubix/spectra/dust/extinction_models.py +++ b/rubix/spectra/dust/extinction_models.py @@ -39,7 +39,6 @@ class Cardelli89(BaseExtRvModel): Example: Example showing CCM89 curves for a range of R(V) values. - :: .. plot:: :include-source: @@ -209,7 +208,6 @@ class Gordon23(BaseExtRvModel): Example: Example showing G23 curves for a range of R(V) values. - :: .. plot:: :include-source: diff --git a/rubix/spectra/dust/generic_models.py b/rubix/spectra/dust/generic_models.py index 885e7ac6..f1e36c42 100644 --- a/rubix/spectra/dust/generic_models.py +++ b/rubix/spectra/dust/generic_models.py @@ -117,6 +117,7 @@ def Drude1d( ValueError: If ``x_0`` is zero. Examples: + .. plot:: :include-source: @@ -214,7 +215,6 @@ def FM90( Examples: Example showing a FM90 curve with components identified. - :: .. plot:: :include-source: diff --git a/rubix/spectra/dust/helpers.py b/rubix/spectra/dust/helpers.py index 29674488..692e08c3 100644 --- a/rubix/spectra/dust/helpers.py +++ b/rubix/spectra/dust/helpers.py @@ -1,8 +1,7 @@ -from typing import Final, Tuple - import jax import jax.numpy as jnp from beartype import beartype as typechecker +from beartype.typing import Final, Tuple from jaxtyping import Array, Float, jaxtyped # from jax.scipy.special import comb diff --git a/rubix/spectra/ssp/factory.py b/rubix/spectra/ssp/factory.py index b003ae95..b3f078d1 100644 --- a/rubix/spectra/ssp/factory.py +++ b/rubix/spectra/ssp/factory.py @@ -22,6 +22,7 @@ def get_ssp_template(template: str) -> SSPGrid: ValueError: If the template name or source format is not supported. 
Example: + >>> from rubix.spectra.ssp.factory import get_ssp_template >>> ssp = get_ssp_template("FSPS") >>> ssp.age.shape diff --git a/rubix/spectra/ssp/grid.py b/rubix/spectra/ssp/grid.py index d01a2b89..88c9e593 100644 --- a/rubix/spectra/ssp/grid.py +++ b/rubix/spectra/ssp/grid.py @@ -1,6 +1,5 @@ import os from dataclasses import dataclass, fields -from typing import List, Tuple, Union # import equinox as eqx import h5py @@ -9,6 +8,7 @@ from astropy import units as u from astropy.io import fits from beartype import beartype as typechecker +from beartype.typing import List, Tuple, Union from interpax import interp2d from jax.tree_util import Partial from jaxtyping import Array, Float, Int, jaxtyped @@ -77,7 +77,7 @@ def get_lookup_interpolation( Partial: Interpolation function ``f(metallicity, age)``. Examples: - :: + >>> grid = SSPGrid(...) >>> lookup = grid.get_lookup_interpolation() >>> metallicity = 0.02 @@ -256,7 +256,6 @@ class HDF5SSPGrid(SSPGrid): flux (Float[Array, FLUX_AXES]): SSP fluxes in Lsun/Angstrom. Example: - :: >>> config = { ... "name": "Bruzual & Charlot (2003)", @@ -363,7 +362,6 @@ class pyPipe3DSSPGrid(SSPGrid): flux (Float[Array, FLUX_AXES]): SSP fluxes in Lsun/Angstrom. Example: - :: >>> config = { ... "name": "Mastar Charlot & Bruzual (2019)", diff --git a/rubix/spectra/ssp/templates.py b/rubix/spectra/ssp/templates.py index 27353c26..9230b4e6 100644 --- a/rubix/spectra/ssp/templates.py +++ b/rubix/spectra/ssp/templates.py @@ -2,6 +2,7 @@ This module contains the supported templates for the SSP grid. Example: + >>> from rubix.spectra.ssp.templates import BruzualCharlot2003 >>> BruzualCharlot2003 >>> print(BruzualCharlot2003.age) diff --git a/rubix/telescope/base.py b/rubix/telescope/base.py index cbb3f80a..0212150f 100644 --- a/rubix/telescope/base.py +++ b/rubix/telescope/base.py @@ -1,8 +1,7 @@ -from typing import List, Optional, Union - import equinox as eqx import numpy as np from beartype import beartype as typechecker +from beartype.typing import List, Optional, Union from jaxtyping import Array, Float, Int, jaxtyped diff --git a/rubix/telescope/factory.py b/rubix/telescope/factory.py index 60db8682..7649c544 100644 --- a/rubix/telescope/factory.py +++ b/rubix/telescope/factory.py @@ -6,6 +6,7 @@ from beartype import beartype as typechecker from jaxtyping import jaxtyped +from rubix.logger import get_logger from rubix.telescope.apertures import ( CIRCULAR_APERTURE, HEXAGONAL_APERTURE, @@ -22,11 +23,17 @@ class TelescopeFactory: @jaxtyped(typechecker=typechecker) def __init__(self, telescopes_config: Optional[Union[dict, str]] = None) -> None: + logger = get_logger() if telescopes_config is None: + logger.info( + "No telescope config provided, falling back to %s", + TELESCOPE_CONFIG_PATH, + ) warnings.warn( - "No telescope config provided, using default stored in {}".format( + ("No telescope config provided, " "using default stored in {}").format( TELESCOPE_CONFIG_PATH - ) + ), + UserWarning, ) self.telescopes_config = read_yaml(TELESCOPE_CONFIG_PATH) elif isinstance(telescopes_config, str): @@ -46,7 +53,8 @@ def create_telescope(self, name: str) -> BaseTelescope: The telescope object as BaseTelescope. Raises: - ValueError: If the telescope name is not present in the configuration. + ValueError: If the telescope name is not present in the + configuration. 
Example 1 (Uses the defined telescope configuration) ----------------------------------------------------- diff --git a/rubix/telescope/utils.py b/rubix/telescope/utils.py index 2400e510..6a9e3219 100644 --- a/rubix/telescope/utils.py +++ b/rubix/telescope/utils.py @@ -1,8 +1,7 @@ -from typing import List, Tuple, Union - import jax.numpy as jnp import numpy as np from beartype import beartype as typechecker +from beartype.typing import List, Tuple, Union from jaxtyping import Array, Bool, Float, Int, jaxtyped from rubix.cosmology.base import BaseCosmology diff --git a/rubix/utils.py b/rubix/utils.py index 46126dd4..09829e8a 100644 --- a/rubix/utils.py +++ b/rubix/utils.py @@ -7,7 +7,7 @@ import yaml from astropy.cosmology import Planck15 as cosmo -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from rubix.core.data import RubixData @@ -180,7 +180,7 @@ def load_galaxy_data( Tuple[Dict[str, Any], Dict[str, Any]]: Galaxy data and associated units Example: - :: + >>> from rubix.utils import load_galaxy_data >>> galaxy_data, units = load_galaxy_data("path/to/file.hdf5") """ diff --git a/source/installation.rst b/source/installation.rst deleted file mode 100644 index adc3c531..00000000 --- a/source/installation.rst +++ /dev/null @@ -1,29 +0,0 @@ -Installation -============ - -`RUBIX` can be installed via `pip` - -Clone the repository and navigate to the root directory of the repository. Then run - -``` -pip install . -``` - -If you want to contribute to the development of `RUBIX`, we recommend the following editable installation from this repository: - -``` -git clone https://github.com/ufuk-cakir/rubix -cd rubix -pip install -e . -``` -Having done so, the test suit can be run unsing `pytest`: - -``` -python -m pytest -``` - -Note that if `JAX` is not yet installed, only the CPU version of `JAX` will be installed -as a dependency. For a GPU-compatible installation of `JAX`, please refer to the -[JAX installation guide](https://jax.readthedocs.io/en/latest/installation.html). - -Get started with this simple example notebooks/rubix_pipeline_single_function.ipynb. diff --git a/source/versions.rst b/source/versions.rst deleted file mode 100644 index 2c96bb77..00000000 --- a/source/versions.rst +++ /dev/null @@ -1,18 +0,0 @@ -Code versions -============ - -`RUBIX` has different code versions. The current version is `0.1`. - -Version 0.1 ------------ -Forwardmodel IFU cubes of galaxies from cosmological hydrodynamical simulations for stellar particles from IllustrisTNG50. - - -Version 0.2 ------------ -Under developement: Forwardmodel IFU cubes of galaxies from cosmological hydrodynamical simulations for gas particles from IllustrisTNG50. - - -Version 0.3 ------------ -Under developement: Add dust attenuation to the IFU cubes. 
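Taken together, the rubix/core/pipeline.py hunk earlier in this diff routes `RubixPipeline.loss` through the sharded forward model (`run_sharded`), and the new tests in tests/test_pipeline.py further down confirm that `gradient` differentiates that loss via `jax.grad(..., argnums=0)`. A self-contained sketch of the idea, with a toy forward model standing in for the compiled pipeline (all names here are illustrative):

```python
import jax
import jax.numpy as jnp


def forward(params):
    # Toy stand-in for the compiled, sharded pipeline: parameters -> "datacube".
    return jnp.outer(params, params)


def loss(params, target):
    # Sum of squared errors, matching loss() in rubix/core/pipeline.py.
    return jnp.sum((forward(params) - target) ** 2)


params = jnp.array([1.0, 2.0, 3.0])
target = jnp.zeros((3, 3))

# argnums=0 differentiates with respect to the first argument only,
# so the result has the same structure as `params`.
grads = jax.grad(loss, argnums=0)(params, target)
```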
diff --git a/tests/test_core_fits.py b/tests/test_core_fits.py new file mode 100644 index 00000000..440f6e47 --- /dev/null +++ b/tests/test_core_fits.py @@ -0,0 +1,130 @@ +import os +from types import SimpleNamespace +from unittest.mock import MagicMock + +import numpy as np +import pytest +from astropy.io import fits + +from rubix.core.fits import load_fits, store_fits + + +def _make_config(rotation_type: str = "face-on") -> dict: + rotation = { + "type": rotation_type, + "alpha": 0.0, + "beta": 0.0, + "gamma": 0.0, + } + if rotation_type not in ("face-on", "edge-on"): + rotation.update(alpha=0.11, beta=-0.22, gamma=0.33) + + return { + "pipeline": {"name": "test_pipeline"}, + "simulation": {"name": "TestSim"}, + "galaxy": {"dist_z": 0.2, "rotation": rotation}, + "data": { + "subset": {"use_subset": True}, + "load_galaxy_args": {"id": 42}, + "args": {"snapshot": 7}, + }, + "ssp": {"template": {"name": "Template"}}, + "telescope": { + "name": "DummyScope", + "psf": {"name": "gaussian", "size": 3, "sigma": 0.6}, + "lsf": {"sigma": 0.8}, + "noise": {"signal_to_noise": 5, "noise_distribution": "gaussian"}, + }, + "cosmology": {"name": "TEST"}, + } + + +def _patch_logger_and_telescope(monkeypatch): + logger = MagicMock() + monkeypatch.setattr("rubix.core.fits.get_logger", lambda cfg=None: logger) + telescope = SimpleNamespace( + spatial_res=0.5, + wave_res=1.5, + wave_range=(3600, 7000), + ) + monkeypatch.setattr("rubix.core.fits.get_telescope", lambda cfg: telescope) + return logger, telescope + + +def _expected_filename(filepath: str, config: dict) -> str: + base_filename = ( + f"{config['simulation']['name']}" + f"_id{config['data']['load_galaxy_args']['id']}" + f"_snap{config['data']['args']['snapshot']}" + f"_{config['telescope']['name']}" + f"_{config['pipeline']['name']}.fits" + ) + return f"{filepath}{base_filename}" + + +def test_store_fits_face_on_rotation(tmp_path, monkeypatch): + logger, telescope = _patch_logger_and_telescope(monkeypatch) + config = _make_config(rotation_type="face-on") + data = np.arange(24, dtype=np.float32).reshape(2, 3, 4) + filepath = os.path.join(str(tmp_path), "fits_output", "") + + store_fits(config, data, filepath) + + expected_file = _expected_filename(filepath, config) + assert os.path.exists(expected_file) + + with fits.open(expected_file) as hdul: + primary = hdul[0].header + assert primary["PIPELINE"] == config["pipeline"]["name"] + assert primary["DIST_z"] == config["galaxy"]["dist_z"] + assert primary["ROTATION"] == config["galaxy"]["rotation"]["type"] + assert primary["SIM"] == config["simulation"]["name"] + assert primary["GALAXYID"] == config["data"]["load_galaxy_args"]["id"] + assert primary["SNAPSHOT"] == config["data"]["args"]["snapshot"] + assert primary["SUBSET"] == config["data"]["subset"]["use_subset"] + assert primary["SSP"] == config["ssp"]["template"]["name"] + assert primary["INSTR"] == config["telescope"]["name"] + assert primary["PSF"] == config["telescope"]["psf"]["name"] + assert primary["PSF_SIZE"] == config["telescope"]["psf"]["size"] + assert primary["PSFSIGMA"] == config["telescope"]["psf"]["sigma"] + assert primary["LSF"] == config["telescope"]["lsf"]["sigma"] + assert primary["S_TO_N"] == config["telescope"]["noise"]["signal_to_noise"] + assert primary["N_DISTR"] == config["telescope"]["noise"]["noise_distribution"] + assert primary["COSMO"] == config["cosmology"]["name"] + + data_hdu = hdul[1].data + np.testing.assert_array_equal(data_hdu, data.T) + + logger.info.assert_called_once_with(f"Datacube saved to {expected_file}") 
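The final assertion above pins down a convention worth noting: per `np.testing.assert_array_equal(data_hdu, data.T)`, `store_fits` is expected to write the datacube transposed into the first image HDU. A minimal sketch of that round-trip using only numpy and astropy (the file name is illustrative):

```python
import os
import tempfile

import numpy as np
from astropy.io import fits

cube = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "cube.fits")
    # HDU 0 is an empty primary header; HDU 1 carries the transposed cube.
    fits.HDUList([fits.PrimaryHDU(), fits.ImageHDU(data=cube.T)]).writeto(path)
    with fits.open(path) as hdul:
        np.testing.assert_array_equal(hdul[1].data, cube.T)
```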
+ + +def test_store_fits_custom_rotation_exposes_angles(tmp_path, monkeypatch): + logger, telescope = _patch_logger_and_telescope(monkeypatch) + config = _make_config(rotation_type="custom") + data = np.zeros((1, 1, 1), dtype=np.float32) + filepath = os.path.join(str(tmp_path), "fits_output", "") + + store_fits(config, data, filepath) + expected_file = _expected_filename(filepath, config) + assert os.path.exists(expected_file) + + with fits.open(expected_file) as hdul: + primary = hdul[0].header + assert "ROTATION" not in primary + assert primary["ROT_A"] == pytest.approx(0.11) + assert primary["ROT_B"] == pytest.approx(-0.22) + assert primary["ROT_C"] == pytest.approx(0.33) + + logger.info.assert_called_with(f"Datacube saved to {expected_file}") + + +def test_load_fits_returns_cube_instance(monkeypatch): + cube_instance = MagicMock() + cube_factory = MagicMock(return_value=cube_instance) + monkeypatch.setattr("rubix.core.fits.Cube", cube_factory) + + path = "/tmp/dummy.fits" + result = load_fits(path) + + cube_factory.assert_called_once_with(filename=path) + assert result is cube_instance diff --git a/tests/test_core_ifu_dusty.py b/tests/test_core_ifu_dusty.py new file mode 100644 index 00000000..680ab5d5 --- /dev/null +++ b/tests/test_core_ifu_dusty.py @@ -0,0 +1,103 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import jax.numpy as jnp +import numpy as np + +from rubix.core.data import Galaxy, GasData, RubixData, StarsData +from rubix.core.ifu import get_calculate_dusty_datacube_particlewise + + +class DummyExtinctionModel: + def __init__(self, Rv): + self.Rv = Rv + + def extinguish(self, wavelengths, av): + del wavelengths, av + return jnp.array([0.5, 0.5]) + + +def _patch_dusty_dependencies(monkeypatch): + logger = MagicMock() + + telescope = SimpleNamespace( + sbin=2, + wave_seq=jnp.array([4000.0, 5000.0]), + ) + + monkeypatch.setattr("rubix.core.ifu.get_logger", lambda cfg=None: logger) + monkeypatch.setattr("rubix.core.ifu.get_telescope", lambda cfg: telescope) + monkeypatch.setattr( + "rubix.core.ifu.get_lookup_interpolation", + lambda cfg: lambda Z, age: jnp.array([1.0, 1.0]), + ) + monkeypatch.setattr( + "rubix.core.ifu.get_ssp", + lambda cfg: SimpleNamespace(wavelength=jnp.array([1.0, 2.0])), + ) + monkeypatch.setattr( + "rubix.core.ifu.cosmological_doppler_shift", + lambda z, wavelength: wavelength, + ) + monkeypatch.setattr( + "rubix.core.ifu._velocity_doppler_shift_single", + lambda wavelength, velocity, direction: wavelength, + ) + + def fake_resample(initial_spectrum, initial_wavelength, target_wavelength): + del initial_wavelength + return initial_spectrum[: target_wavelength.shape[0]] + + monkeypatch.setattr( + "rubix.core.ifu.resample_spectrum", + fake_resample, + ) + monkeypatch.setattr( + "rubix.core.ifu.Rv_model_dict", + {"Dummy": DummyExtinctionModel}, + ) + monkeypatch.setattr("rubix.core.ifu.RV_MODELS", ["Dummy"]) + + return logger, telescope + + +def _build_rubixdata() -> RubixData: + stars = StarsData() + stars.age = jnp.array([1.0, 2.0]) + stars.metallicity = jnp.array([0.1, 0.2]) + stars.mass = jnp.array([1.0, 2.0]) + stars.velocity = jnp.array([0.0, 0.0]) + stars.pixel_assignment = jnp.array([0, 1], dtype=jnp.int32) + stars.extinction = jnp.ones((2, 2), dtype=jnp.float32) + return RubixData(galaxy=Galaxy(), stars=stars, gas=GasData()) + + +def test_calculate_dusty_datacube_particlewise(monkeypatch): + logger, telescope = _patch_dusty_dependencies(monkeypatch) + + config = { + "pipeline": {"name": "calc_ifu"}, + "logger": 
{"log_level": "DEBUG", "log_file_path": None, "format": ""}, + "telescope": {"name": "Dummy"}, + "cosmology": {"name": "PLANCK15"}, + "galaxy": {"dist_z": 0.1}, + "ssp": { + "template": {"name": "BruzualCharlot2003"}, + "dust": {"extinction_model": "Dummy", "Rv": 3.1}, + }, + } + + rubixdata = _build_rubixdata() + + calculate = get_calculate_dusty_datacube_particlewise(config) + result = calculate(rubixdata) + + datacube = result.stars.datacube + assert datacube.shape == (2, 2, telescope.wave_seq.shape[0]) + + flattened = datacube.reshape(-1, telescope.wave_seq.shape[0]) + np.testing.assert_allclose(flattened[0], [0.5, 0.5]) + np.testing.assert_allclose(flattened[1], [1.0, 1.0]) + assert np.all(flattened[2:] == 0) + + logger.info.assert_called() diff --git a/tests/test_core_rotation.py b/tests/test_core_rotation.py index 5e74f8f8..ce858f23 100644 --- a/tests/test_core_rotation.py +++ b/tests/test_core_rotation.py @@ -1,8 +1,29 @@ +from unittest.mock import MagicMock, patch + +import numpy as np import pytest +from rubix.core.data import Galaxy, GasData, RubixData, StarsData from rubix.core.rotation import get_galaxy_rotation +def _build_rubix_data(): + galaxy = Galaxy(center=np.zeros(3), halfmassrad_stars=1.0) + stars = StarsData( + coords=np.zeros((1, 3)), velocity=np.zeros((1, 3)), mass=np.ones(1) + ) + gas = GasData(coords=np.ones((1, 3)), velocity=np.ones((1, 3))) + return RubixData(galaxy=galaxy, stars=stars, gas=gas) + + +def _base_config(particle_types): + return { + "galaxy": {"rotation": {"alpha": 0.0, "beta": 0.0, "gamma": 0.0}}, + "simulation": {"name": "mock"}, + "data": {"args": {"particle_type": particle_types}}, + } + + def _get_data(): return { "coords": None, @@ -62,3 +83,47 @@ def test_custom_rotation(): config = {"galaxy": {"rotation": {"alpha": 45, "beta": 30, "gamma": 15}}} rotate_galaxy = get_galaxy_rotation(config) assert callable(rotate_galaxy) + + +@patch("rubix.core.rotation.rotate_galaxy_core") +@patch("rubix.core.rotation.get_logger") +def test_rotation_applies_to_gas_and_stars(mock_get_logger, mock_rotate_core): + mock_logger = MagicMock() + mock_get_logger.return_value = mock_logger + mock_rotate_core.side_effect = lambda **kwargs: ( + kwargs["positions"] + 1, + kwargs["velocities"] + 2, + ) + + rubixdata = _build_rubix_data() + config = _base_config(["stars", "gas"]) + rotate = get_galaxy_rotation(config) + + rotated = rotate(rubixdata) + + assert np.all(rotated.gas.coords == np.ones((1, 3)) + 1) + assert np.all(rotated.gas.velocity == np.ones((1, 3)) + 2) + assert np.all(rotated.stars.coords == np.zeros((1, 3)) + 1) + assert np.all(rotated.stars.velocity == np.zeros((1, 3)) + 2) + assert mock_rotate_core.call_count == 2 + + +@patch("rubix.core.rotation.rotate_galaxy_core") +@patch("rubix.core.rotation.get_logger") +def test_rotation_warns_when_gas_missing(mock_get_logger, mock_rotate_core): + mock_logger = MagicMock() + mock_get_logger.return_value = mock_logger + mock_rotate_core.side_effect = lambda **kwargs: ( + kwargs["positions"] + 5, + kwargs["velocities"] + 5, + ) + + rubixdata = _build_rubix_data() + config = _base_config(["stars"]) + rotate = get_galaxy_rotation(config) + rotate(rubixdata) + + mock_logger.warning.assert_called_with( + "Gas not found in particle_type, only rotating stellar component." 
+ ) + assert mock_rotate_core.call_count == 1 diff --git a/tests/test_core_telescope.py b/tests/test_core_telescope.py index 5767c6c2..7f902b22 100644 --- a/tests/test_core_telescope.py +++ b/tests/test_core_telescope.py @@ -1,10 +1,11 @@ -from typing import cast from unittest.mock import MagicMock, patch import jax.numpy as jnp import pytest +from rubix.core.data import Galaxy, GasData, RubixData, StarsData from rubix.core.telescope import ( + get_filter_particles, get_spatial_bin_edges, get_spaxel_assignment, get_telescope, @@ -128,3 +129,90 @@ def test_get_spatial_bin_edges( # Assertions assert isinstance(result, jnp.ndarray) # Ensure the return type matches assert result.shape == (3,) # Check the shape of spatial_bin_edges + + +@patch("rubix.core.telescope.TelescopeFactory") +def test_get_telescope_type_error(mock_factory): + config = {"telescope": {"name": "MUSE"}} + mock_factory.return_value.create_telescope.return_value = MagicMock() + + with pytest.raises(TypeError, match="Expected type BaseTelescope"): + get_telescope(config) + + +def test_spaxel_assignment_handles_stars_and_gas(): + config = { + "telescope": {"name": "MUSE"}, + "galaxy": {"dist_z": 0.5}, + "cosmology": {"name": "PLANCK15"}, + } + + with ( + patch("rubix.core.telescope.get_telescope") as mock_get_telescope, + patch("rubix.core.telescope.get_spatial_bin_edges") as mock_get_spatial, + patch("rubix.core.telescope.square_spaxel_assignment") as mock_assignment, + ): + mock_get_telescope.return_value = MagicMock(pixel_type="square") + mock_get_spatial.return_value = jnp.array([0.0, 1.0]) + mock_assignment.side_effect = ["star-pa", "gas-pa"] + + spaxel_assignment = get_spaxel_assignment(config) + + stars = StarsData(coords=jnp.zeros((1, 3))) + gas = GasData(coords=jnp.ones((1, 3))) + data = RubixData(galaxy=Galaxy(), stars=stars, gas=gas) + + result = spaxel_assignment(data) + + assert result.stars.pixel_assignment == "star-pa" + assert result.stars.spatial_bin_edges is mock_get_spatial.return_value + assert result.gas.pixel_assignment == "gas-pa" + assert result.gas.spatial_bin_edges is mock_get_spatial.return_value + assert mock_assignment.call_count == 2 + + +@patch("rubix.core.telescope.mask_particles_outside_aperture") +@patch("rubix.core.telescope.get_spatial_bin_edges") +def test_filter_particles_masks_stars_and_gas(mock_get_edges, mock_mask_particles): + config = { + "telescope": {"name": "MUSE"}, + "galaxy": {"dist_z": 0.5}, + "cosmology": {"name": "PLANCK15"}, + "data": {"args": {"particle_type": ["stars", "gas"]}}, + } + + mock_get_edges.return_value = jnp.array([0.0, 1.0, 2.0]) + star_mask = jnp.array([True, False, True]) + gas_mask = jnp.array([False, True, True]) + mock_mask_particles.side_effect = [star_mask, gas_mask] + + stars = StarsData( + coords=jnp.zeros((3, 3)), + velocity=jnp.zeros((3, 3)), + mass=jnp.array([1.0, 2.0, 3.0]), + age=jnp.array([1.0, 2.0, 3.0]), + metallicity=jnp.array([1.0, 2.0, 3.0]), + ) + gas = GasData( + coords=jnp.zeros((3, 3)), + velocity=jnp.zeros((3, 3)), + mass=jnp.array([4.0, 5.0, 6.0]), + density=jnp.array([4.0, 5.0, 6.0]), + internal_energy=jnp.array([4.0, 5.0, 6.0]), + metallicity=jnp.array([4.0, 5.0, 6.0]), + ) + data = RubixData(galaxy=Galaxy(), stars=stars, gas=gas) + + filter_fn = get_filter_particles(config) + + result = filter_fn(data) + + assert jnp.array_equal(result.stars.mass, jnp.array([1.0, 0.0, 3.0])) + assert jnp.array_equal(result.stars.age, jnp.array([1.0, 0.0, 3.0])) + assert jnp.array_equal(result.stars.metallicity, jnp.array([1.0, 0.0, 3.0])) + assert 
jnp.array_equal(result.stars.mask, star_mask) + assert jnp.array_equal(result.gas.mass, jnp.array([0.0, 5.0, 6.0])) + assert jnp.array_equal(result.gas.density, jnp.array([0.0, 5.0, 6.0])) + assert jnp.array_equal(result.gas.internal_energy, jnp.array([0.0, 5.0, 6.0])) + assert jnp.array_equal(result.gas.mask, gas_mask) + assert mock_mask_particles.call_count == 2 diff --git a/tests/test_cosmology.py b/tests/test_cosmology.py index fa8b05a3..79f9f992 100644 --- a/tests/test_cosmology.py +++ b/tests/test_cosmology.py @@ -68,7 +68,7 @@ def test_age_at_z(z): @pytest.mark.parametrize("z", [0.1, 0.2, 0.5, 1.0, 2.0]) def test_angular_scale(z): rubix_scale = rubix_cosmo.angular_scale(z) - # Compute the scale using Astropy's angular diameter distance in Mpc and converting to kpc/arcsec + # Use Astropy's angular diameter distance (Mpc) to get kpc/arcsec astropy_scale = ( astropy_cosmo.angular_diameter_distance(z).value * (jnp.pi / (180 * 60 * 60)) diff --git a/tests/test_factory.py b/tests/test_factory.py index 96a0f9a5..2234123d 100644 --- a/tests/test_factory.py +++ b/tests/test_factory.py @@ -1,3 +1,4 @@ +import logging from unittest.mock import MagicMock, patch import pytest @@ -14,9 +15,7 @@ def test_get_input_handler_illustris(): result = get_input_handler(config) - # Check if the mock instance is returned assert result == mock_instance - # Ensure that the constructor is called with the correct arguments mock_handler.assert_called_once_with(path="value1", logger=None) @@ -26,26 +25,31 @@ def test_get_input_handler_unsupported(): with pytest.raises(ValueError) as excinfo: get_input_handler(config) + assert "not supported" in str(excinfo.value) -def test_get_input_handler_illustris(): - config = {"simulation": {"name": "IllustrisTNG", "args": {"path": "value1"}}} - with patch("rubix.galaxy.input_handler.factory.IllustrisHandler") as mock_handler: +def test_get_input_handler_pynbody(): + config = { + "simulation": { + "name": "NIHAO", + "args": {"path": "/tmp/nihao", "halo_path": "/tmp/halo"}, + }, + "galaxy": {"dist_z": 0.12}, + } + logger = logging.getLogger("rubix.tests.factory.pynbody") + logger.info = MagicMock() + + with patch("rubix.galaxy.input_handler.factory.PynbodyHandler") as mock_handler: mock_instance = MagicMock() mock_handler.return_value = mock_instance - result = get_input_handler(config) + result = get_input_handler(config, logger=logger) - # Check if the mock instance is returned assert result == mock_instance - # Ensure that the constructor is called with the correct arguments - mock_handler.assert_called_once_with(path="value1", logger=None) - - -def test_get_input_handler_unsupported(): - config = {"simulation": {"name": "UnknownSim", "args": {}}} - - with pytest.raises(ValueError) as excinfo: - get_input_handler(config) - - assert "not supported" in str(excinfo.value) + mock_handler.assert_called_once_with( + path="/tmp/nihao", + halo_path="/tmp/halo", + dist_z=0.12, + logger=logger, + ) + logger.info.assert_any_call("Using PynbodyHandler to load a NIHAO galaxy") diff --git a/tests/test_galaxy_alignment.py b/tests/test_galaxy_alignment.py index bea693a2..a1fa64ea 100644 --- a/tests/test_galaxy_alignment.py +++ b/tests/test_galaxy_alignment.py @@ -73,6 +73,24 @@ def test_center_galaxy_sucessful(): ) +def test_center_galaxy_gas_branch(): + gas_coordinates = np.array([[1, 2, 3], [4, 5, 6]]) + gas_velocities = np.array([[1, 1, 1], [2, 2, 2]]) + center = np.array([1, 2, 3]) + + mockdata = MockRubixData( + MockGalaxyData(center=center), + MockStarsData(coords=gas_coordinates, 
velocity=gas_velocities), + MockGasData(coords=gas_coordinates, velocity=gas_velocities), + ) + + result = center_particles(mockdata, "gas") + assert np.all(result.gas.coords == gas_coordinates - center) + assert np.all( + result.gas.velocity == gas_velocities - np.median(gas_velocities, axis=0) + ) + + def test_moment_of_inertia_tensor(): """Test the moment_of_inertia_tensor function.""" @@ -206,3 +224,53 @@ def test_rotate_galaxy(): # assert jnp.allclose(rotated_velocities, expected_rotated_velocities), \ # f"Test failed. Expected velocities {expected_rotated_velocities}, got {rotated_velocities}" + + +def test_rotate_galaxy_unknown_key(): + positions = jnp.array([[1.0, 0.0, 0.0]]) + velocities = jnp.array([[0.0, 1.0, 0.0]]) + masses = jnp.array([1.0]) + halfmass = 1.0 + + with pytest.raises(ValueError, match="Unknown key"): + rotate_galaxy( + positions, + velocities, + positions, + masses, + halfmass, + 0.0, + 0.0, + 0.0, + "Unknown", + ) + + +def test_rotate_galaxy_uses_nihao_branch(): + positions = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) + velocities = jnp.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]) + stars = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) + masses = jnp.array([1.0, 1.0]) + halfmass = 1.0 + + alpha = 45.0 + beta = 15.0 + gamma = 30.0 + + expected_positions = apply_rotation(positions, alpha, beta, gamma) + expected_velocities = apply_rotation(velocities, alpha, beta, gamma) + + rotated_positions, rotated_velocities = rotate_galaxy( + positions, + velocities, + stars, + masses, + halfmass, + alpha, + beta, + gamma, + "NIHAO", + ) + + assert jnp.allclose(rotated_positions, expected_positions) + assert jnp.allclose(rotated_velocities, expected_velocities) diff --git a/tests/test_illustris_handler.py b/tests/test_illustris_handler.py index 73cf9073..dc4c8e57 100644 --- a/tests/test_illustris_handler.py +++ b/tests/test_illustris_handler.py @@ -272,3 +272,115 @@ def test_load_data_without_GFM_stellarformation_time(mock_file, mock_exists): data["test_part"] = data["PartType4"] data = handler._get_particle_data(data, "test_part") assert "coordinates" in data + + +def _make_stub_handler(): + handler = object.__new__(IllustrisHandler) + handler._logger = MagicMock() + return handler + + +def test_check_fields_missing_expected(): + handler = _make_stub_handler() + with pytest.raises(ValueError) as exc: + handler._check_fields({"random": {}}) + assert "No expected fields" in str(exc.value) + + +def test_check_fields_unexpected_extra_field(): + handler = _make_stub_handler() + fake_data = { + "Header": {}, + "SubhaloData": {}, + "PartType4": {}, + "Random": {}, + } + with pytest.raises(ValueError) as exc: + handler._check_fields(fake_data) + assert "Unexpected fields" in str(exc.value) + + +def test_check_fields_unsupported_part_type(): + handler = _make_stub_handler() + fake_data = { + "Header": {}, + "SubhaloData": {}, + "PartType4": {}, + "PartType99": {}, + } + with pytest.raises(NotImplementedError) as exc: + handler._check_fields(fake_data) + assert "PartType99" in str(exc.value) + + +def test_check_particle_data_requires_mapped_fields(): + handler = _make_stub_handler() + valid_data = {"stars": {"coords": np.array([0.0])}} + with pytest.raises(ValueError): + handler._check_particle_data(valid_data, {}) + + +def test_get_particle_keys_unsupported_type(): + handler = _make_stub_handler() + handler.MAPPED_PARTICLE_KEYS = {"PartType4": "stars"} + handler.ILLUSTRIS_DATA = [ + "Header", + "SubhaloData", + "PartType4", + "PartTypeX", + ] + fake_file = { + "Header": {}, + 
"SubhaloData": {}, + "PartType4": {}, + "PartTypeX": {}, + } + with pytest.raises(NotImplementedError) as exc: + handler._get_particle_keys(fake_file) + assert "PartTypeX" in str(exc.value) + + +def test_check_particle_data_no_matching_fields(): + handler = _make_stub_handler() + with pytest.raises(ValueError) as exc: + handler._check_particle_data({"unexpected": {}}, {}) + assert "No expected fields" in str(exc.value) + + +def test_check_particle_data_extra_parttype_field_raises_not_implemented(): + handler = _make_stub_handler() + handler.MAPPED_PARTICLE_KEYS = {"PartType4": "stars"} + handler.MAPPED_FIELDS = {"PartType4": {"Coordinates": "coords"}} + particle_data = { + "stars": {"coords": np.array([0.0])}, + "PartType99": {}, + } + with pytest.raises(NotImplementedError) as exc: + handler._check_particle_data(particle_data, {}) + assert "PartType99" in str(exc.value) + + +def test_check_particle_data_extra_field_raises_value_error(): + handler = _make_stub_handler() + handler.MAPPED_PARTICLE_KEYS = {"PartType4": "stars"} + handler.MAPPED_FIELDS = {"PartType4": {"Coordinates": "coords"}} + particle_data = { + "stars": {"coords": np.array([0.0])}, + "extra": {}, + } + with pytest.raises(ValueError) as exc: + handler._check_particle_data(particle_data, {}) + assert "Unexpected fields" in str(exc.value) + + +def test_halfmassrad_stars_requires_coordinates(): + handler = _make_stub_handler() + handler.TIME = 1.0 + handler.HUBBLE_PARAM = 0.5 + fake_file = { + "SubhaloData": {"halfmassrad_stars": np.array(1.0)}, + "PartType4": {}, + } + with pytest.raises(ValueError) as exc: + handler._get_halfmassrad_stars(fake_file) + assert "Coordinates" in str(exc.value) diff --git a/tests/test_input_handler.py b/tests/test_input_handler.py index a7f2b80e..36209f37 100644 --- a/tests/test_input_handler.py +++ b/tests/test_input_handler.py @@ -213,3 +213,15 @@ def test_particle_field_unit_info_missing_error(input_handler): particle_data["stars"]["unsupported_field"] = 1 input_handler._check_particle_data(particle_data, units) assert "Units for unsupported_field not found in units" in str(excinfo.value) + + +def test_missing_particle_field_error(input_handler): + with pytest.raises(ValueError) as excinfo: + particle_data = input_handler.get_particle_data() + del particle_data["stars"]["mass"] + units = input_handler.get_units() + input_handler._check_particle_data(particle_data, units) + assert ( + str(excinfo.value) + == "Missing field mass in particle data for particle type stars" + ) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index f6267ee1..fcc433d2 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,13 +1,16 @@ from copy import deepcopy from pathlib import Path +from unittest.mock import MagicMock import jax.numpy as jnp import pytest from jax import jit, make_jaxpr from jax.tree_util import Partial +from rubix.core.data import Galaxy, GasData, RubixData, StarsData +from rubix.core.pipeline import RubixPipeline from rubix.pipeline import linear_pipeline as lp -from rubix.utils import read_yaml +from rubix.utils import _pad_particles, read_yaml # helper stuff that we need @@ -75,7 +78,12 @@ def test_register_transformer(pipeline_fixture): pipeline.register_transformer(mult) pipeline.register_transformer(div) - assert pipeline.transformers == {"add": add, "sub": sub, "mult": mult, "div": div} + assert pipeline.transformers == { + "add": add, + "sub": sub, + "mult": mult, + "div": div, + } with pytest.raises( ValueError, match="A transformer of this name is already 
present" @@ -97,7 +105,10 @@ def test_update_pipeline(pipeline_fixture_full): assert pipeline._names == ["C", "B", "A", "X", "Z", "D"] - with pytest.raises(RuntimeError, match="Node 'Not there' not found in the config"): + with pytest.raises( + RuntimeError, + match="Node 'Not there' not found in the config", + ): pipeline.update_pipeline("Not there") assert pipeline._names == ["C", "B", "A", "X", "Z", "D"] @@ -121,7 +132,7 @@ def test_build_pipeline_broken(pipeline_fixture_full): with pytest.raises( ValueError, - match="Each node of a pipeline must have a config node containing 'name'", + match=("Each node of a pipeline must have a config node containing " "'name'"), ): pipeline.build_pipeline() @@ -129,7 +140,10 @@ def test_build_pipeline_broken(pipeline_fixture_full): pipeline.config["Transformers"]["X"]["depends_on"] = None - with pytest.raises(ValueError, match="There can only be one starting point"): + with pytest.raises( + ValueError, + match="There can only be one starting point", + ): pipeline.build_pipeline() pipeline.config = deepcopy(cfg) @@ -146,7 +160,11 @@ def test_build_pipeline_broken(pipeline_fixture_full): ): pipeline.build_pipeline() - pipeline.config["Transformers"]["D"] = {"name": add, "depends_on": "X", "args": []} + pipeline.config["Transformers"]["D"] = { + "name": add, + "depends_on": "X", + "args": [], + } with pytest.raises( ValueError, @@ -163,7 +181,10 @@ def test_build_pipeline_broken(pipeline_fixture_full): with pytest.raises( ValueError, - match="Dependencies must be unique in a linear pipeline as branching is not allowed. Found X at least twice", + match=( + "Dependencies must be unique in a linear pipeline as branching is " + "not allowed. Found X at least twice" + ), ): pipeline.build_pipeline() @@ -181,7 +202,10 @@ def test_build_pipeline_broken(pipeline_fixture_full): pipeline.transformers = [] - with pytest.raises(RuntimeError, match="No registered transformers present"): + with pytest.raises( + RuntimeError, + match="No registered transformers present", + ): pipeline.build_pipeline() @@ -261,7 +285,8 @@ def test_apply(pipeline_fixture_full): pipeline, x = pipeline_fixture_full with pytest.raises( - ValueError, match="Cannot apply the pipeline to an empty list of arguments" + ValueError, + match="Cannot apply the pipeline to an empty list of arguments", ): pipeline.apply() @@ -298,7 +323,8 @@ def test_get_jaxpr_for_element(pipeline_fixture_full): assert str(expr) == str(manual_expr) with pytest.raises( - RuntimeError, match="Cannot create intermediate expression for 'Not there'" + RuntimeError, + match="Cannot create intermediate expression for 'Not there'", ): pipeline.get_jaxpr_for_element( "Not there", @@ -328,10 +354,194 @@ def test_compile_element(pipeline_fixture_full): assert jnp.allclose(manual(x), fp(x)) - with pytest.raises(RuntimeError, match="Compilation of element 'Not there' failed"): + with pytest.raises( + RuntimeError, + match="Compilation of element 'Not there' failed", + ): pipeline.compile_element( "Not there", static_kwargs=[ "m", ], ) + + +@pytest.fixture +def simple_pipeline(monkeypatch): + user_config = { + "pipeline": {"name": "test_pipeline"}, + "logger": {"log_level": "INFO"}, + } + pipeline_config = {"Transformers": {}} + logger = MagicMock() + + monkeypatch.setattr("rubix.core.pipeline.get_config", lambda cfg: cfg) + monkeypatch.setattr( + "rubix.core.pipeline.get_pipeline_config", + lambda name: pipeline_config, + ) + monkeypatch.setattr("rubix.core.pipeline.get_logger", lambda cfg: logger) + 
monkeypatch.setattr("rubix.core.pipeline.get_ssp", lambda cfg: MagicMock()) + monkeypatch.setattr( + "rubix.core.pipeline.get_telescope", + lambda cfg: MagicMock(), + ) + + pipeline = RubixPipeline(user_config) + return pipeline, logger + + +def _make_rubix_data(star_count=3, gas_count=2): + stars = StarsData( + coords=jnp.zeros((star_count, 3)), + velocity=jnp.zeros((star_count, 3)), + mass=jnp.arange(star_count, dtype=jnp.float32), + age=jnp.arange(star_count, dtype=jnp.float32), + metallicity=jnp.arange(star_count, dtype=jnp.float32), + ) + gas = GasData( + coords=jnp.zeros((gas_count, 3)), + velocity=jnp.zeros((gas_count, 3)), + mass=jnp.ones(gas_count, dtype=jnp.float32), + density=jnp.ones(gas_count, dtype=jnp.float32), + internal_energy=jnp.ones(gas_count, dtype=jnp.float32), + metallicity=jnp.ones(gas_count, dtype=jnp.float32), + ) + data = RubixData(galaxy=Galaxy(), stars=stars, gas=gas) + return data + + +def test_prepare_data_logs_counts(simple_pipeline, monkeypatch): + pipeline, logger = simple_pipeline + rubixdata = _make_rubix_data(star_count=4, gas_count=3) + + monkeypatch.setattr( + "rubix.core.pipeline.get_rubix_data", + lambda cfg: rubixdata, + ) + + result = pipeline.prepare_data() + + assert result is rubixdata + assert any("Data loaded" in call.args[0] for call in logger.info.call_args_list) + + +def test_pad_particles_extends_arrays(): + data = _make_rubix_data(star_count=2) + padded = _pad_particles(data, pad=3) + + assert padded.stars.coords.shape[0] == 5 + assert jnp.count_nonzero(padded.stars.coords[-3:]) == 0 + assert padded.stars.mass[-3:].sum() == 0 + + +def test_run_sharded_triggers_padding(simple_pipeline, monkeypatch): + pipeline, _ = simple_pipeline + data = _make_rubix_data(star_count=3, gas_count=1) + + mock_pad = MagicMock(side_effect=lambda inp, pad: inp) + monkeypatch.setattr("rubix.core.pipeline._pad_particles", mock_pad) + + monkeypatch.setattr(RubixPipeline, "_get_pipeline_functions", lambda self: []) + + class DummyLinearPipeline: + def __init__(self, cfg, functions): + self.config = cfg + + def assemble(self): + pass + + def compile_expression(self): + class DummyOutput: + def __init__(self): + self.stars = MagicMock(datacube=jnp.zeros((1, 1, 1))) + + return lambda *_: DummyOutput() + + monkeypatch.setattr( + "rubix.core.pipeline.pipeline.LinearTransformerPipeline", + DummyLinearPipeline, + ) + monkeypatch.setattr( + "rubix.core.pipeline.Mesh", + lambda devices, axis_names: None, + ) + + class DummyNamedSharding: + def __init__(self, mesh, spec): + self.spec = spec + + monkeypatch.setattr( + "rubix.core.pipeline.NamedSharding", + DummyNamedSharding, + ) + monkeypatch.setattr( + "rubix.core.pipeline.P", + lambda *args, **kwargs: (args, kwargs), + ) + monkeypatch.setattr( + "rubix.core.pipeline.jax.device_put", + lambda data, spec: data, + ) + monkeypatch.setattr( + "rubix.core.pipeline.lax.psum", + lambda value, axis_name: value, + ) + monkeypatch.setattr( + "rubix.core.pipeline.shard_map", + lambda func, mesh, in_specs, out_specs, check_rep: ( + lambda inputdata: func(inputdata) + ), + ) + + result = pipeline.run_sharded(data, devices=[object(), object()]) + + assert mock_pad.call_count == 1 + _, pad_arg = mock_pad.call_args[0] + assert pad_arg == 1 + assert isinstance(result, jnp.ndarray) + + +def test_gradient_calls_jax_grad(simple_pipeline, monkeypatch): + pipeline, _ = simple_pipeline + expected = MagicMock() + captured = {} + + def fake_grad(fn, argnums=0): + captured["fn"] = fn + captured["argnums"] = argnums + + def gradient_fn(rubixdata, 
targetdata): + captured["rubixdata"] = rubixdata + captured["targetdata"] = targetdata + return expected + + return gradient_fn + + monkeypatch.setattr("rubix.core.pipeline.jax.grad", fake_grad) + rubixdata = MagicMock() + target = MagicMock() + + result = pipeline.gradient(rubixdata, target) + + assert captured["fn"].__func__ is pipeline.loss.__func__ + assert captured["fn"].__self__ is pipeline + assert captured["argnums"] == 0 + assert captured["rubixdata"] is rubixdata + assert captured["targetdata"] is target + assert result is expected + + +def test_loss_uses_run(simple_pipeline): + pipeline, _ = simple_pipeline + rubixdata = MagicMock() + target = jnp.array([1.0, 2.0]) + output = jnp.array([3.0, 4.0]) + + pipeline.run_sharded = MagicMock(return_value=output) + + loss_value = pipeline.loss(rubixdata, target) + + pipeline.run_sharded.assert_called_once_with(rubixdata) + expected = jnp.sum((output - target) ** 2) + assert jnp.allclose(loss_value, expected) diff --git a/tests/test_pynbody_handler.py b/tests/test_pynbody_handler.py index 74a4f4f9..b7fc200b 100644 --- a/tests/test_pynbody_handler.py +++ b/tests/test_pynbody_handler.py @@ -1,5 +1,8 @@ -from unittest.mock import MagicMock, patch +import copy +from contextlib import ExitStack +from unittest.mock import MagicMock, mock_open, patch +import astropy.units as u import numpy as np import pytest @@ -95,28 +98,44 @@ def dm_getitem(key): return mock_sim -@pytest.fixture -def handler_with_mock_data(mock_simulation, mock_config): - with ( - patch("pynbody.load", return_value=mock_simulation), - patch("pynbody.analysis.angmom.faceon", return_value=None), - patch( - "pynbody.analysis.angmom.ang_mom_vec", - return_value=np.array([0.0, 0.0, 1.0]), - ), - patch("pynbody.analysis.angmom.calc_sideon_matrix", return_value=np.eye(3)), - ): - +def _build_pynbody_handler(mock_simulation, mock_config, **overrides): + with ExitStack() as stack: + stack.enter_context(patch("pynbody.load", return_value=mock_simulation)) + stack.enter_context(patch("pynbody.analysis.angmom.faceon", return_value=None)) + stack.enter_context( + patch( + "pynbody.analysis.angmom.ang_mom_vec", + return_value=np.array([0.0, 0.0, 1.0]), + ) + ) + stack.enter_context( + patch( + "pynbody.analysis.angmom.calc_sideon_matrix", + return_value=np.eye(3), + ) + ) handler = PynbodyHandler( path="mock_path", - halo_path="mock_halo_path", + halo_path=overrides.get("halo_path", "mock_halo_path"), + rotation_path=overrides.get("rotation_path", "./data"), + logger=overrides.get("logger"), config=mock_config, - dist_z=mock_config["galaxy"]["dist_z"], - halo_id=1, + dist_z=overrides.get("dist_z", mock_config["galaxy"]["dist_z"]), + halo_id=overrides.get("halo_id", 1), ) return handler +@pytest.fixture +def handler_with_mock_data(mock_simulation, mock_config, tmp_path): + rotation_path = tmp_path / "rotation" + return _build_pynbody_handler( + mock_simulation, + mock_config, + rotation_path=str(rotation_path), + ) + + def test_pynbody_handler_initialization(handler_with_mock_data): """Test initialization of PynbodyHandler.""" assert handler_with_mock_data is not None @@ -165,3 +184,162 @@ def test_stars_data_load(handler_with_mock_data): assert "stars" in data assert "coords" in data["stars"] assert "mass" in data["stars"] + + +def test_load_config_uses_env_path(): + handler = object.__new__(PynbodyHandler) + handler.logger = MagicMock() + env_path = "/tmp/mock_config.yml" + config_content = "fields: {}" + with patch.dict( + "os.environ", + {"RUBIX_PYNBODY_CONFIG": env_path}, + clear=True, 
+ ): + with ( + patch( + "rubix.galaxy.input_handler.pynbody.os.path.exists", + return_value=True, + ), + patch("builtins.open", mock_open(read_data=config_content)), + ): + config = handler._load_config() + handler.logger.info.assert_called_with( + f"Using environment-specified config path: {env_path}" + ) + assert config == {"fields": {}} + + +def test_load_config_default_missing(): + handler = object.__new__(PynbodyHandler) + handler.logger = MagicMock() + with patch.dict("os.environ", {}, clear=True): + with ( + patch( + "rubix.galaxy.input_handler.pynbody.os.path.exists", + return_value=False, + ), + pytest.raises(FileNotFoundError), + ): + handler._load_config() + + +def test_rotation_matrix_saved(mock_simulation, mock_config, tmp_path): + logger = MagicMock() + rotation_path = tmp_path / "rotation_saved" + rotation_path.mkdir() + with ( + patch( + "rubix.galaxy.input_handler.pynbody.os.path.exists", + return_value=True, + ), + patch("rubix.galaxy.input_handler.pynbody.np.save") as mock_save, + ): + handler = _build_pynbody_handler( + mock_simulation, + mock_config, + rotation_path=str(rotation_path), + logger=logger, + ) + assert handler is not None + mock_save.assert_called_once() + logger.info.assert_any_call( + "Rotation matrix calculated and saved to " + f"'{rotation_path}/rotation_matrix.npy'." + ) + + +def test_rotation_matrix_not_saved(mock_simulation, mock_config, tmp_path): + logger = MagicMock() + rotation_path = tmp_path / "rotation_nosave" + with ( + patch( + "rubix.galaxy.input_handler.pynbody.os.path.exists", + return_value=False, + ), + patch("rubix.galaxy.input_handler.pynbody.np.save") as mock_save, + ): + handler = _build_pynbody_handler( + mock_simulation, + mock_config, + rotation_path=str(rotation_path), + logger=logger, + ) + assert handler is not None + mock_save.assert_not_called() + logger.info.assert_any_call("Rotation matrix calculated and not saved.") + + +def test_get_halo_data_without_halo_path(mock_simulation, mock_config): + logger = MagicMock() + handler = _build_pynbody_handler( + mock_simulation, + mock_config, + halo_path=None, + logger=logger, + ) + handler.logger.warning.reset_mock() + result = handler.get_halo_data() + assert result is None + handler.logger.warning.assert_called_once_with("No halo file provided or found.") + + +def test_get_halo_data_default_index(mock_simulation, mock_config): + handler = _build_pynbody_handler(mock_simulation, mock_config) + handler.sim.halos.reset_mock() + handler.sim.halos.return_value.__getitem__.reset_mock() + result = handler.get_halo_data(halo_id=None) + handler.sim.halos.assert_called_once() + handler.sim.halos.return_value.__getitem__.assert_called_once_with(0) + assert result == handler.sim.halos.return_value.__getitem__.return_value + + +def test_get_galaxy_data_without_stars(mock_simulation, mock_config): + logger = MagicMock() + handler = _build_pynbody_handler( + mock_simulation, + mock_config, + logger=logger, + ) + handler.data.pop("stars", None) + handler.logger.warning.reset_mock() + galaxy_data = handler.get_galaxy_data() + assert galaxy_data["halfmassrad_stars"] is None + handler.logger.warning.assert_called_once_with( + "No star data available to calculate the half-mass radius." 
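+        # assert_called_once_with compares the message verbatim, so any
+        # rewording of the handler's warning text will surface here.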
+ ) + + +def test_get_simulation_metadata_returns_expected_values(mock_simulation, mock_config): + handler = _build_pynbody_handler(mock_simulation, mock_config) + metadata = handler.get_simulation_metadata() + assert metadata["path"] == "mock_path" + assert metadata["halo_path"] == "mock_halo_path" + assert "logger" in metadata + + +def test_calculate_halfmass_radius_handles_1d_positions(): + handler = object.__new__(PynbodyHandler) + positions = np.array([1.0, 2.0, 3.0]) + masses = np.array([1.0, 1.0, 1.0]) + radius = handler.calculate_halfmass_radius(positions, masses) + assert radius == 2.0 + + +def test_get_units_warns_for_unknown_unit(mock_simulation, mock_config, tmp_path): + logger = MagicMock() + bad_config = copy.deepcopy(mock_config) + bad_config["units"]["stars"]["coords"] = "NotAUnit" + rotation_path = tmp_path / "rotation_units" + handler = _build_pynbody_handler( + mock_simulation, + bad_config, + rotation_path=str(rotation_path), + logger=logger, + ) + handler.logger.warning.reset_mock() + units = handler.get_units() + assert units["stars"]["coords"] == u.dimensionless_unscaled + handler.logger.warning.assert_called_with( + "Unit 'NotAUnit' for 'stars.coords' not recognized. " "Using dimensionless." + ) diff --git a/tests/test_spectra_ifu.py b/tests/test_spectra_ifu.py index 3d6d9d72..1be48493 100644 --- a/tests/test_spectra_ifu.py +++ b/tests/test_spectra_ifu.py @@ -8,6 +8,7 @@ calculate_cube, calculate_diff, convert_luminoisty_to_flux, + convert_luminoisty_to_flux_factor, cosmological_doppler_shift, get_velocity_component, resample_spectrum, @@ -147,6 +148,28 @@ def test_convert_luminoisty_to_flux(): assert jnp.allclose(flux, expected_flux, rtol=1e-5) +def test_convert_luminoisty_to_flux_factor(): + observation_lum_dist = 10.0 + observation_z = 0.5 + pixel_size = 2.0 + + factor = convert_luminoisty_to_flux_factor( + observation_lum_dist, + observation_z, + pixel_size, + CONSTANTS=mock_config["constants"], + ) + + CONST = mock_config["constants"]["LSOL_TO_ERG"] / ( + mock_config["constants"]["MPC_TO_CM"] ** 2 + ) + expected_factor = ( + CONST / (4 * np.pi * observation_lum_dist**2) / (1 + observation_z) / pixel_size + ) + + assert jnp.isclose(factor, expected_factor, rtol=1e-6) + + def test_velocity_doppler_shift(): wavelength = jnp.array([5000.0, 6000.0, 7000.0]) velocity = jnp.array([[300.0, 400.0, 500.0], [600.0, 700.0, 800.0]]) @@ -172,6 +195,31 @@ def test_velocity_doppler_shift(): ) +def test_velocity_doppler_shift_handles_singleton_leading_axis(): + wavelength = jnp.array([5000.0, 6000.0]) + velocity = jnp.array([[[300.0, 400.0, 500.0], [600.0, 700.0, 800.0]]]) + + doppler_shifted_wavelength = velocity_doppler_shift( + wavelength, + velocity, + direction="y", + SPEED_OF_LIGHT=mock_config["constants"]["SPEED_OF_LIGHT"], + ) + + base_velocities = velocity[0] + expected = jnp.stack( + [ + wavelength + * jnp.exp( + base_velocities[i, 1] / mock_config["constants"]["SPEED_OF_LIGHT"] + ) + for i in range(base_velocities.shape[0]) + ] + ) + + assert jnp.allclose(doppler_shifted_wavelength, expected, rtol=1e-5) + + def test_resample_spectrum(): initial_spectrum = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0]) initial_wavelength = jnp.array([4000.0, 5000.0, 6000.0, 7000.0, 8000.0]) diff --git a/tests/test_ssp_factory.py b/tests/test_ssp_factory.py index 95eb64fd..71d80014 100644 --- a/tests/test_ssp_factory.py +++ b/tests/test_ssp_factory.py @@ -5,7 +5,8 @@ import numpy as np import pytest -from rubix.paths import TEMPLATE_PATH +from rubix import config as rubix_config +from 
rubix.spectra.ssp import factory from rubix.spectra.ssp.factory import HDF5SSPGrid, get_ssp_template, pyPipe3DSSPGrid @@ -23,6 +24,15 @@ def reset_config(): pass +def _set_templates(monkeypatch, templates): + monkeypatch.setitem(rubix_config["ssp"], "templates", templates) + + +@pytest.fixture(autouse=True) +def stub_logger(monkeypatch): + monkeypatch.setattr(factory, "get_logger", lambda: MagicMock()) + + def get_config(): from rubix import config @@ -79,9 +89,10 @@ def test_get_ssp_template_existing_template(): template = get_ssp_template(template_name) template_class_name = config["ssp"]["templates"][template_name]["name"] assert template.__class__.__name__ == template_class_name - assert ( - mock_write_fsps_data_to_disk.call_count <= 1 - ), f"Expected at most 1 call to 'write_fsps_data_to_disk', but got {mock_write_fsps_data_to_disk.call_count}" + assert mock_write_fsps_data_to_disk.call_count <= 1, ( + "Expected at most 1 call to 'write_fsps_data_to_disk', " + f"but got {mock_write_fsps_data_to_disk.call_count}" + ) def test_get_ssp_template_existing_template_BC03(): @@ -97,9 +108,9 @@ def test_get_ssp_template_non_existing_template(): with pytest.raises(ValueError) as excinfo: get_ssp_template(template_name) - assert ( - str(excinfo.value) - == "SSP template unknown_template not found in the supported configuration file." + assert str(excinfo.value) == ( + "SSP template unknown_template not found in the supported " + "configuration file." ) @@ -117,9 +128,9 @@ def test_get_ssp_template_invalid_format(): with pytest.raises(ValueError) as excinfo: get_ssp_template(template_name) - assert ( - str(excinfo.value) - == "Currently only HDF5 format and fits files in the format of pyPipe3D format are supported for SSP templates." + assert str(excinfo.value) == ( + "Currently only HDF5 format and fits files in the format of " + "pyPipe3D format are supported for SSP templates." 
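+        # the adjacent string literals above concatenate implicitly, so the
+        # comparison is against the full single-line error message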
) @@ -148,6 +159,59 @@ def test_get_ssp_template_existing_fsps_template(): assert template.__class__.__name__ == template_class_name +def test_get_ssp_template_fsps_file_missing(monkeypatch): + grid = MagicMock(spec=HDF5SSPGrid) + from_file_spy = MagicMock(side_effect=[FileNotFoundError, grid]) + monkeypatch.setattr(factory.HDF5SSPGrid, "from_file", from_file_spy) + write_spy = MagicMock() + monkeypatch.setattr(factory, "write_fsps_data_to_disk", write_spy) + templates = { + "FSPS": { + "format": "fsps", + "source": "load_from_file", + "file_name": "fsps.h5", + } + } + _set_templates(monkeypatch, templates) + + result = get_ssp_template("FSPS") + + assert from_file_spy.call_count == 2 + write_spy.assert_called_once_with( + "fsps.h5", + file_location=factory.TEMPLATE_PATH, + ) + assert result is grid + + +def test_get_ssp_template_fsps_rerun_from_scratch(monkeypatch): + grid = MagicMock(spec=HDF5SSPGrid) + from_file_spy = MagicMock(return_value=grid) + monkeypatch.setattr(factory.HDF5SSPGrid, "from_file", from_file_spy) + write_spy = MagicMock() + monkeypatch.setattr(factory, "write_fsps_data_to_disk", write_spy) + templates = { + "FSPS": { + "format": "fsps", + "source": "rerun_from_scratch", + "file_name": "fsps.h5", + } + } + _set_templates(monkeypatch, templates) + + result = get_ssp_template("FSPS") + + from_file_spy.assert_called_once_with( + templates["FSPS"], + file_location=factory.TEMPLATE_PATH, + ) + write_spy.assert_called_once_with( + "fsps.h5", + file_location=factory.TEMPLATE_PATH, + ) + assert result is grid + + def test_get_fsps_template_wrong_source_keyword(): config = get_config() config_copy = config.copy() @@ -165,6 +229,6 @@ def test_get_fsps_template_wrong_source_keyword(): with pytest.raises(ValueError) as excinfo: get_ssp_template("FSPS") assert ( - f"The source {supported_templates['FSPS']['source']} of the FSPS SSP template is not supported." - == str(excinfo.value) - ) + f"The source {supported_templates['FSPS']['source']} " + "of the FSPS SSP template is not supported." 
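+        # the expected message is rebuilt from the same mutated config entry,
+        # so this assert tracks whatever bad source keyword the test injected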
+ ) == str(excinfo.value) diff --git a/tests/test_ssp_grid.py b/tests/test_ssp_grid.py index a067a709..8d56ce2b 100644 --- a/tests/test_ssp_grid.py +++ b/tests/test_ssp_grid.py @@ -65,10 +65,10 @@ def test_from_hdf5(): mock_file.return_value = mock_instance mock_instance.__enter__.return_value = mock_instance mock_instance.__getitem__.side_effect = lambda key: { - "age": [1, 2, 3], - "metallicity": [0.1, 0.2, 0.3], - "wavelength": [4000, 5000, 6000], - "flux": [0.5, 1.0, 1.5], + "age": np.array([1, 2, 3], dtype=np.float32), + "metallicity": np.array([0.1, 0.2, 0.3], dtype=np.float32), + "wavelength": np.array([4000, 5000, 6000], dtype=np.float32), + "flux": np.array([0.5, 1.0, 1.5], dtype=np.float32), }[key] result = HDF5SSPGrid.from_file(config, file_location) @@ -81,6 +81,59 @@ def test_from_hdf5(): assert np.allclose(result.flux, [0.5, 1.0, 1.5]) +def test_from_hdf5_handles_log_field(): + config = { + "format": "hdf5", + "file_name": "test.hdf5", + "source": "http://example.com/template.hdf5", + "fields": { + "age": { + "name": "age", + "in_log": False, + "units": "Gyr", + }, + "metallicity": { + "name": "metallicity", + "in_log": False, + "units": "", + }, + "wavelength": { + "name": "wavelength", + "in_log": False, + "units": "Angstrom", + }, + "flux": { + "name": "flux", + "in_log": True, + "units": "Lsun/Angstrom", + }, + }, + "name": "TestSSPGrid", + } + file_location = "/path/to/files" + + with ( + patch("os.path.exists") as mock_exists, + patch("rubix.spectra.ssp.grid.h5py.File") as mock_file, + ): + mock_exists.return_value = True + mock_instance = MagicMock() + mock_file.return_value = mock_instance + mock_instance.__enter__.return_value = mock_instance + mock_instance.__getitem__.side_effect = lambda key: { + "age": np.array([1, 2, 3], dtype=np.float32), + "metallicity": np.array([0.1, 0.2, 0.3], dtype=np.float32), + "wavelength": np.array([4000, 5000, 6000], dtype=np.float32), + "flux": np.array([0.5, 1.0, 1.5], dtype=np.float32), + }[key] + + result = HDF5SSPGrid.from_file(config, file_location) + + assert isinstance(result, HDF5SSPGrid) + expected_flux = np.power(10, np.array([0.5, 1.0, 1.5], dtype=np.float32)) + assert np.allclose(result.flux, expected_flux) + + def test_from_hdf5_wrong_format(): config = { "format": "wrong", @@ -365,6 +418,54 @@ def test_from_pyPipe3D_wrong_field_name(): assert str(e.value) == "Field wrong_field_name not recognized" +def test_from_pyPipe3D_handles_log_field(tmp_path): + config = { + "format": "pypipe3d", + "file_name": "pyPipe_log.fits", + "source": "http://example.com/", + "fields": { + "age": {"name": "age", "in_log": True, "units": "Gyr"}, + "metallicity": {"name": "metallicity", "in_log": False, "units": ""}, + "wavelength": {"name": "wavelength", "in_log": False, "units": "Angstrom"}, + "flux": {"name": "flux", "in_log": False, "units": "Lsun/Angstrom"}, + }, + "name": "pyPipeSSPGrid", + } + + header = fits.Header() + header["CRVAL1"] = 4000 + header["CDELT1"] = 1000 + header["NAXIS1"] = 3 + header["CRPIX1"] = 1 + header["NAXIS2"] = 2 + header["NAME0"] = "spec_ssp_1.0_z01.spec" + header["NAME1"] = "spec_ssp_1.0_z02.spec" + header["NORM0"] = 1.0 + header["NORM1"] = 1.0 + + data = np.array([[0.5, 1.0, 1.5], [0.6, 1.1, 1.6]], dtype=np.float32) + hdu = fits.PrimaryHDU(data=data, header=header) + hdul = fits.HDUList([hdu]) + file_path = tmp_path / "pyPipe_log.fits" + hdul.writeto( + file_path, + overwrite=True, + output_verify="silentfix", + ) + + with patch( + "rubix.spectra.ssp.grid.pyPipe3DSSPGrid.checkout_SSP_template", + 
return_value=str(file_path), + ): + grid = pyPipe3DSSPGrid.from_file(config, str(tmp_path)) + + expected_age = jnp.power( + 10, + jnp.array([1.0], dtype=jnp.float32), + ) + assert np.allclose(grid.age, expected_age) + + def test_from_pyPipe3D_wrong_format(): config = { "format": "wrong", @@ -650,6 +751,33 @@ def test_checkout_SSP_template_file_download_failed_HDF5SSPGrid(): HDF5SSPGrid.checkout_SSP_template(config, file_location) +def test_checkout_SSP_template_raise_for_status(monkeypatch): + config = { + "format": "hdf5", + "file_name": "test.hdf5", + "source": "http://example.com", + } + file_location = "/tmp" + + response = MagicMock() + response.raise_for_status.side_effect = requests.exceptions.HTTPError("status fail") + download_msg = "Could not download file test.hdf5 from url http://example.com/." + + with ( + patch("os.path.exists") as mock_exists, + patch("requests.get", return_value=response), + ): + mock_exists.return_value = False + + with pytest.raises( + FileNotFoundError, + match=download_msg, + ): + SSPGrid.checkout_SSP_template(config, file_location) + + response.raise_for_status.assert_called_once() + + def test_get_lookup_interpolation(): # Create a mock SSPGrid instance age = jnp.array([1e9, 2e9, 3e9]) diff --git a/tests/test_visualisation.py b/tests/test_visualisation.py new file mode 100644 index 00000000..6d63543d --- /dev/null +++ b/tests/test_visualisation.py @@ -0,0 +1,239 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import h5py +import numpy as np + +from rubix.core import visualisation + + +class DummyWave: + def __init__(self): + self.unit = "Angstrom" + + def coord(self, index=None): + if index is None: + return np.array([4000.0, 5000.0]) + return 4000.0 + index * 1000.0 + + +class DummyImage: + def __init__(self, data): + self.data = data + self.plot = MagicMock() + + +class DummySpectrum: + def __init__(self, data): + self.data = data + self.plot = MagicMock() + + +class DummyCube: + def __init__(self): + self.shape = (4, 3, 3) + self.data = np.arange(np.prod(self.shape)).reshape(self.shape) + self.wave = DummyWave() + self.slice_calls = [] + + def __getitem__(self, key): + self.slice_calls.append(key) + sliced = self.data[key] + if sliced.ndim == 3: + return DummyCubeSlice(sliced) + return DummySpectrum(sliced) + + +class DummyCubeSlice: + def __init__(self, data): + self._data = data + + def sum(self, axis=0): + return DummyImage(self._data.sum(axis=axis)) + + +def test_plot_cube_slice_and_spectrum(monkeypatch): + cube, interact_data, ax1, ax2, ax3 = _prepare_visualize_plot(monkeypatch) + plot_fn = interact_data["func"] + plot_fn(wave_index=1, wave_range=1, x=1, y=1, radius=1) + + ax1.scatter.assert_called_once() + ax1.imshow.assert_called_once() + ax2.plot.assert_called() + ax3.plot.assert_called_once() + ax2.axvspan.assert_called_once() + ax2.set_xlabel.assert_called_once() + ax2.set_ylabel.assert_called_once() + ax2.grid.assert_called_once() + ax2.legend.assert_called_once() + ax3.set_ylabel.assert_called_once() + ax3.legend.assert_called_once() + ax2.set_ylim.assert_called_with(bottom=0) + ax3.set_ylim.assert_called_with(bottom=0) + ax3.vlines.assert_called_once() + + +def test_plot_cube_slice_and_spectrum_clamps_start(monkeypatch): + cube, interact_data, _, _, _ = _prepare_visualize_plot(monkeypatch) + plot_fn = interact_data["func"] + + plot_fn(wave_index=1, wave_range=2, x=1, y=1, radius=1) + + assert cube.slice_calls + first_slice = cube.slice_calls[0] + assert isinstance(first_slice, tuple) + slice_axis = 
first_slice[0] + assert isinstance(slice_axis, slice) + assert slice_axis.start == 0 + + +def _prepare_visualize_plot(monkeypatch): + cube = DummyCube() + monkeypatch.setattr(visualisation, "Cube", lambda filename: cube) + + def fake_slider(**kwargs): + return SimpleNamespace(description=kwargs.get("description", "")) + + monkeypatch.setattr(visualisation.widgets, "IntSlider", fake_slider) + + ax1 = MagicMock() + ax2 = MagicMock() + ax3 = MagicMock() + ax2.twinx.return_value = ax3 + fig = MagicMock() + monkeypatch.setattr( + visualisation.plt, + "subplots", + lambda *args, **kwargs: (fig, (ax1, ax2)), + ) + monkeypatch.setattr(visualisation.plt, "tight_layout", MagicMock()) + monkeypatch.setattr(visualisation.plt, "show", MagicMock()) + + interact_data = {} + + def fake_interact(func, **kwargs): + interact_data["func"] = func + return "widget" + + monkeypatch.setattr(visualisation, "interact", fake_interact) + + visualisation.visualize_rubix("/tmp/cube.fits") + + return cube, interact_data, ax1, ax2, ax3 + + +def _create_star_h5(tmp_path): + path = tmp_path / "stars.h5" + with h5py.File(path, "w") as f: + stars = f.create_group("particles/stars") + stars.create_dataset("age", data=np.array([1.5, 2.0, 3.0])) + stars.create_dataset( + "coords", + data=np.array( + [ + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + ] + ), + ) + stars.create_dataset("metallicity", data=np.array([0.1, 0.2, 0.3])) + return path + + +def test_visualize_rubix_sets_up_interact(monkeypatch): + cube = MagicMock(shape=(4, 5, 6)) + monkeypatch.setattr(visualisation, "Cube", MagicMock(return_value=cube)) + + slider_calls = [] + + def fake_int_slider(**kwargs): + slider = MagicMock() + slider.description = kwargs.get("description") + slider_calls.append(kwargs) + return slider + + monkeypatch.setattr(visualisation.widgets, "IntSlider", fake_int_slider) + interact_mock = MagicMock(return_value="widget") + monkeypatch.setattr(visualisation, "interact", interact_mock) + + result = visualisation.visualize_rubix("/tmp/cube.fits") + + visualisation.Cube.assert_called_once_with(filename="/tmp/cube.fits") + assert result == "widget" + assert len(slider_calls) == 5 + interact_mock.assert_called_once() + interact_kwargs = interact_mock.call_args.kwargs + assert "wave_index" in interact_kwargs + assert interact_kwargs["wave_index"].description == "Waveindex:" + assert interact_kwargs["x"].description == "X Pixel:" + + +def test_visualize_cubeviz_loads_and_shows(monkeypatch): + cubeviz_mock = MagicMock() + monkeypatch.setattr( + visualisation, + "Cubeviz", + MagicMock(return_value=cubeviz_mock), + ) + + visualisation.visualize_cubeviz("/tmp/cube.fits") + + visualisation.Cubeviz.assert_called_once() + cubeviz_mock.load_data.assert_called_once_with("/tmp/cube.fits") + cubeviz_mock.show.assert_called_once() + + +def test_stellar_age_histogram_uses_hdf5_data(tmp_path, monkeypatch): + path = _create_star_h5(tmp_path) + plt = visualisation.plt + hist = MagicMock() + monkeypatch.setattr(plt, "figure", MagicMock()) + monkeypatch.setattr(plt, "hist", hist) + monkeypatch.setattr(plt, "xlabel", MagicMock()) + monkeypatch.setattr(plt, "ylabel", MagicMock()) + monkeypatch.setattr(plt, "grid", MagicMock()) + monkeypatch.setattr(plt, "tight_layout", MagicMock()) + monkeypatch.setattr(plt, "show", MagicMock()) + + visualisation.stellar_age_histogram(str(path)) + + hist.assert_called_once() + np.testing.assert_array_equal(hist.call_args.args[0], np.array([1.5, 2.0, 3.0])) + + +def test_star_coords_2d_scatter(monkeypatch, tmp_path): + path = 
_create_star_h5(tmp_path) + plt = visualisation.plt + scatter = MagicMock() + monkeypatch.setattr(plt, "figure", MagicMock()) + monkeypatch.setattr(plt, "scatter", scatter) + monkeypatch.setattr(plt, "xlabel", MagicMock()) + monkeypatch.setattr(plt, "ylabel", MagicMock()) + monkeypatch.setattr(plt, "grid", MagicMock()) + monkeypatch.setattr(plt, "show", MagicMock()) + + visualisation.star_coords_2D(str(path)) + + scatter.assert_called_once() + x_arg, y_arg = scatter.call_args.args[:2] + np.testing.assert_array_equal(x_arg, np.array([0.0, 3.0])) + np.testing.assert_array_equal(y_arg, np.array([1.0, 4.0])) + + +def test_star_metallicity_histogram_plots_metallicity(monkeypatch, tmp_path): + path = _create_star_h5(tmp_path) + plt = visualisation.plt + hist = MagicMock() + monkeypatch.setattr(plt, "figure", MagicMock()) + monkeypatch.setattr(plt, "hist", hist) + monkeypatch.setattr(plt, "xlabel", MagicMock()) + monkeypatch.setattr(plt, "ylabel", MagicMock()) + monkeypatch.setattr(plt, "title", MagicMock()) + monkeypatch.setattr(plt, "grid", MagicMock()) + monkeypatch.setattr(plt, "tight_layout", MagicMock()) + monkeypatch.setattr(plt, "show", MagicMock()) + + visualisation.star_metallicity_histogram(str(path)) + + hist.assert_called_once() + np.testing.assert_array_equal(hist.call_args.args[0], np.array([0.1, 0.2, 0.3]))
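+
+
+if __name__ == "__main__":  # pragma: no cover
+    # Minimal convenience sketch: the suite is normally run with
+    # ``pytest tests/test_visualisation.py -q``; pytest.main mirrors that
+    # invocation when the module is executed directly.
+    import sys
+
+    import pytest
+
+    sys.exit(pytest.main([__file__, "-q"]))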