diff --git a/.gitignore b/.gitignore index c091f83..7c10e49 100644 --- a/.gitignore +++ b/.gitignore @@ -176,3 +176,6 @@ rubix/spectra/ssp/templates/fsps.h5 notebooks/frames notebooks/frames/* notebooks/data/* + +# don´t add .env files +*.env diff --git a/notebooks/rubix_pipeline_single_function_shard_map.ipynb b/notebooks/rubix_pipeline_single_function_shard_map.ipynb new file mode 100644 index 0000000..8c3e989 --- /dev/null +++ b/notebooks/rubix_pipeline_single_function_shard_map.ipynb @@ -0,0 +1,486 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#NBVAL_SKIP\n", + "#import os\n", + "#import multiprocessing\n", + "\n", + "# Logical cores (includes hyperthreads)\n", + "#print(\"Logical cores:\", os.cpu_count())\n", + "\n", + "\n", + "# Total threads/cores via multiprocessing\n", + "#print(\"multiprocessing.cpu_count():\", multiprocessing.cpu_count())\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#NBVAL_SKIP\n", + "import os\n", + "\n", + "# Tell XLA to fake 2 host CPU devices\n", + "#os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=3'\n", + "\n", + "# Only make GPU 0 and GPU 1 visible to JAX:\n", + "#os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5'\n", + "\n", + "#os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n", + "\n", + "import jax\n", + "\n", + "# Now JAX will list two CpuDevice entries\n", + "print(jax.devices())\n", + "# → [CpuDevice(id=0), CpuDevice(id=1)]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NBVAL_SKIP\n", + "#import os\n", + "# os.environ['SPS_HOME'] = '/mnt/storage/annalena_data/sps_fsps'\n", + "#os.environ['SPS_HOME'] = '/home/annalena/sps_fsps'\n", + "os.environ['SPS_HOME'] = '/Users/annalena/Documents/GitHub/fsps'\n", + "#os.environ['SPS_HOME'] = '/export/home/aschaibl/fsps'\n", + "#os.environ['SPS_HOME'] = '/home/annalena_data/sps_fsps'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# RUBIX pipeline\n", + "\n", + "RUBIX is designed as a linear pipeline, where the individual functions are called and constructed as a pipeline. This allows as to execude the whole data transformation from a cosmological hydrodynamical simulation of a galaxy to an IFU cube in two lines of code. This notebook shows, how to execute the pipeline. To see, how the pipeline is execuded in small individual steps per individual function, we refer to the notebook `rubix_pipeline_stepwise.ipynb`.\n", + "\n", + "## How to use the Pipeline\n", + "1) Define a `config`\n", + "2) Setup the `pipeline yaml`\n", + "3) Run the RUBIX pipeline\n", + "4) Do science with the mock-data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 1: Config\n", + "\n", + "The `config` contains all the information needed to run the pipeline. Those are run specfic configurations. Currently we just support Illustris as simulation, but extensions to other simulations (e.g. NIHAO) are planned.\n", + "\n", + "For the `config` you can choose the following options:\n", + "- `pipeline`: you specify the name of the pipeline that is stored in the yaml file in rubix/config/pipeline_config.yml\n", + "- `logger`: RUBIX has implemented a logger to report the user, what is happening during the pipeline execution and give warnings\n", + "- `data - args - particle_type`: load only stars particle (\"particle_type\": [\"stars\"]) or only gas particle (\"particle_type\": [\"gas\"]) or both (\"particle_type\": [\"stars\",\"gas\"])\n", + "- `data - args - simulation`: choose the Illustris simulation (e.g. \"simulation\": \"TNG50-1\")\n", + "- `data - args - snapshot`: which time step of the simulation (99 for present day)\n", + "- `data - args - save_data_path`: set the path to save the downloaded Illustris data\n", + "- `data - load_galaxy_args - id`: define, which Illustris galaxy is downloaded\n", + "- `data - load_galaxy_args - reuse`: if True, if in th esave_data_path directory a file for this galaxy id already exists, the downloading is skipped and the preexisting file is used\n", + "- `data - subset`: only a defined number of stars/gas particles is used and stored for the pipeline. This may be helpful for quick testing\n", + "- `simulation - name`: currently only IllustrisTNG is supported\n", + "- `simulation - args - path`: where the data is stored and how the file will be named\n", + "- `output_path`: where the hdf5 file is stored, which is then the input to the RUBIX pipeline\n", + "- `telescope - name`: define the telescope instrument that is observing the simulation. Some telescopes are predefined, e.g. MUSE. If your instrument does not exist predefined, you can easily define your instrument in rubix/telescope/telescopes.yaml\n", + "- `telescope - psf`: define the point spread function that is applied to the mock data\n", + "- `telescope - lsf`: define the line spread function that is applied to the mock data\n", + "- `telescope - noise`: define the noise that is applied to the mock data\n", + "- `cosmology`: specify the cosmology you want to use, standard for RUBIX is \"PLANCK15\"\n", + "- `galaxy - dist_z`: specify at which redshift the mock-galaxy is observed\n", + "- `galaxy - rotation`: specify the orientation of the galaxy. You can set the types edge-on or face-on or specify the angles alpha, beta and gamma as rotations around x-, y- and z-axis\n", + "- `ssp - template`: specify the simple stellar population lookup template to get the stellar spectrum for each stars particle. In RUBIX frequently \"BruzualCharlot2003\" is used." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#NBVAL_SKIP\n", + "import matplotlib.pyplot as plt\n", + "from rubix.core.pipeline import RubixPipeline \n", + "import os\n", + "\n", + "galaxy_id = \"g8.13e11\"\n", + "\n", + "config_NIHAO = {\n", + " \"pipeline\":{\"name\": \"calc_ifu\"},\n", + " \n", + " \"logger\": {\n", + " \"log_level\": \"DEBUG\",\n", + " \"log_file_path\": None,\n", + " \"format\": \"%(asctime)s - %(name)s - %(levelname)s - %(message)s\",\n", + " },\n", + " \"data\": {\n", + " \"name\": \"NihaoHandler\",\n", + " \"args\": {\n", + " \"particle_type\": [\"stars\"],\n", + " \"save_data_path\": \"data\",\n", + " \"snapshot\": \"1024\",\n", + " },\n", + " \"load_galaxy_args\": {\"reuse\": True, \"id\": galaxy_id},\n", + " \"subset\": {\"use_subset\": False, \"subset_size\": 200000},\n", + " },\n", + " \"simulation\": {\n", + " \"name\": \"NIHAO\",\n", + " \"args\": {\n", + " \"path\": f'/home/_data/nihao/nihao_classic/{galaxy_id}/{galaxy_id}.01024',\n", + " \"halo_path\": f'/home/_data/nihao/nihao_classic/{galaxy_id}/{galaxy_id}.01024.z0.000.AHF_halos',\n", + " \"halo_id\": 0,\n", + " },\n", + " },\n", + " \"output_path\": \"output\",\n", + "\n", + " \"telescope\":\n", + " {\"name\": \"MUSE\",\n", + " \"psf\": {\"name\": \"gaussian\", \"size\": 5, \"sigma\": 0.6},\n", + " \"lsf\": {\"sigma\": 0.5},\n", + " \"noise\": {\"signal_to_noise\": 100,\"noise_distribution\": \"normal\"},},\n", + " \"cosmology\":\n", + " {\"name\": \"PLANCK15\"},\n", + " \n", + " \"galaxy\":\n", + " {\"dist_z\": 0.1,\n", + " \"rotation\": {\"type\": \"edge-on\"},\n", + " },\n", + " \n", + " \"ssp\": {\n", + " \"template\": {\n", + " \"name\": \"Mastar_CB19_SLOG_1_5\"\n", + " },\n", + " \"dust\": {\n", + " \"extinction_model\": \"Cardelli89\",\n", + " \"dust_to_gas_ratio\": 0.01,\n", + " \"dust_to_metals_ratio\": 0.4,\n", + " \"dust_grain_density\": 3.5,\n", + " \"Rv\": 3.1,\n", + " },\n", + " }, \n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NBVAL_SKIP\n", + "config_TNG = {\n", + " \"pipeline\":{\"name\": \"calc_ifu\"},\n", + " \n", + " \"logger\": {\n", + " \"log_level\": \"DEBUG\",\n", + " \"log_file_path\": None,\n", + " \"format\": \"%(asctime)s - %(name)s - %(levelname)s - %(message)s\",\n", + " },\n", + " \"data\": {\n", + " \"name\": \"IllustrisAPI\",\n", + " \"args\": {\n", + " \"api_key\": os.environ.get(\"ILLUSTRIS_API_KEY\"),\n", + " \"particle_type\": [\"stars\"],\n", + " \"simulation\": \"TNG50-1\",\n", + " \"snapshot\": 99,\n", + " \"save_data_path\": \"data\",\n", + " },\n", + " \n", + " \"load_galaxy_args\": {\n", + " \"id\": 12,\n", + " \"reuse\": True,\n", + " },\n", + " \n", + " \"subset\": {\n", + " \"use_subset\": True,\n", + " \"subset_size\": 2000,\n", + " },\n", + " },\n", + " \"simulation\": {\n", + " \"name\": \"IllustrisTNG\",\n", + " \"args\": {\n", + " \"path\": \"data/galaxy-id-12.hdf5\",\n", + " },\n", + " \n", + " },\n", + " \"output_path\": \"output\",\n", + "\n", + " \"telescope\":\n", + " {\"name\": \"MUSE\",\n", + " \"psf\": {\"name\": \"gaussian\", \"size\": 5, \"sigma\": 0.6},\n", + " \"lsf\": {\"sigma\": 0.5},\n", + " \"noise\": {\"signal_to_noise\": 100,\"noise_distribution\": \"normal\"},},\n", + " \"cosmology\":\n", + " {\"name\": \"PLANCK15\"},\n", + " \n", + " \"galaxy\":\n", + " {\"dist_z\": 0.1,\n", + " \"rotation\": {\"type\": \"edge-on\"},\n", + " },\n", + " \n", + " \"ssp\": {\n", + " \"template\": {\n", + " \"name\": \"Mastar_CB19_SLOG_1_5\"\n", + " },\n", + " \"dust\": {\n", + " \"extinction_model\": \"Cardelli89\",\n", + " \"dust_to_gas_ratio\": 0.01,\n", + " \"dust_to_metals_ratio\": 0.4,\n", + " \"dust_grain_density\": 3.5,\n", + " \"Rv\": 3.1,\n", + " },\n", + " }, \n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2: Pipeline yaml\n", + "\n", + "To run the RUBIX pipeline, you need a yaml file (stored in `rubix/config/pipeline_config.yml`) that defines which functions are used during the execution of the pipeline. This shows the example pipeline yaml to compute a stellar IFU cube.\n", + "\n", + "```yaml\n", + "calc_ifu:\n", + " Transformers:\n", + " rotate_galaxy:\n", + " name: rotate_galaxy\n", + " depends_on: null\n", + " args: []\n", + " kwargs:\n", + " type: \"face-on\"\n", + " filter_particles:\n", + " name: filter_particles\n", + " depends_on: rotate_galaxy\n", + " args: []\n", + " kwargs: {}\n", + " spaxel_assignment:\n", + " name: spaxel_assignment\n", + " depends_on: filter_particles\n", + " args: []\n", + " kwargs: {}\n", + "\n", + " reshape_data:\n", + " name: reshape_data\n", + " depends_on: spaxel_assignment\n", + " args: []\n", + " kwargs: {}\n", + "\n", + " calculate_spectra:\n", + " name: calculate_spectra\n", + " depends_on: reshape_data\n", + " args: []\n", + " kwargs: {}\n", + "\n", + " scale_spectrum_by_mass:\n", + " name: scale_spectrum_by_mass\n", + " depends_on: calculate_spectra\n", + " args: []\n", + " kwargs: {}\n", + " doppler_shift_and_resampling:\n", + " name: doppler_shift_and_resampling\n", + " depends_on: scale_spectrum_by_mass\n", + " args: []\n", + " kwargs: {}\n", + " calculate_datacube:\n", + " name: calculate_datacube\n", + " depends_on: doppler_shift_and_resampling\n", + " args: []\n", + " kwargs: {}\n", + " convolve_psf:\n", + " name: convolve_psf\n", + " depends_on: calculate_datacube\n", + " args: []\n", + " kwargs: {}\n", + " convolve_lsf:\n", + " name: convolve_lsf\n", + " depends_on: convolve_psf\n", + " args: []\n", + " kwargs: {}\n", + " apply_noise:\n", + " name: apply_noise\n", + " depends_on: convolve_lsf\n", + " args: []\n", + " kwargs: {}\n", + "```\n", + "\n", + "Ther is one thing you have to know about the naming of the functions in this yaml: To use the functions inside the pipeline, the functions have to be called exactly the same as they are returned from the core module function!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3: Run the pipeline\n", + "\n", + "After defining the `config` and the `pipeline_config` you can simply run the whole pipeline by these two lines of code." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#NBVAL_SKIP\n", + "pipe = RubixPipeline(config_TNG)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#NBVAL_SKIP\n", + "\n", + "devices = jax.devices()\n", + "inputdata = pipe.prepare_data()\n", + "rubixdata = pipe.run_sharded(inputdata, devices)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#NBVAL_SKIP\n", + "\n", + "#inputdata = pipe.prepare_data()\n", + "#shard_rubixdata = pipe.run_sharded_chunked(inputdata)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4: Mock-data\n", + "\n", + "Now we have our final datacube and can use the mock-data to do science. Here we have a quick look in the optical wavelengthrange of the mock-datacube and show the spectra of a central spaxel and a spatial image." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#NBVAL_SKIP\n", + "import jax.numpy as jnp\n", + "\n", + "wave = pipe.telescope.wave_seq\n", + "# get the indices of the visible wavelengths of 4000-8000 Angstroms\n", + "visible_indices = jnp.where((wave >= 4000) & (wave <= 8000))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is how you can access the spectrum of an individual spaxel, the wavelength can be accessed via `pipe.wave_seq`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#NBVAL_SKIP\n", + "wave = pipe.telescope.wave_seq\n", + "\n", + "#spectra = rubixdata#.stars.datacube # Spectra of all stars\n", + "spectra_sharded = rubixdata # Spectra of all stars\n", + "#print(spectra.shape)\n", + "\n", + "plt.figure(figsize=(10, 5))\n", + "\n", + "plt.title(\"Rubix Sharded\")\n", + "plt.xlabel(\"Wavelength [Angstrom]\")\n", + "plt.ylabel(\"Flux [erg/s/cm^2/Angstrom]\")\n", + "plt.plot(wave, spectra_sharded[12,12,:])\n", + "plt.plot(wave, spectra_sharded[8,12,:])\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot a spacial image of the data cube" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#NBVAL_SKIP\n", + "# get the spectra of the visible wavelengths from the ifu cube\n", + "#visible_spectra = rubixdata.stars.datacube[ :, :, visible_indices[0]]\n", + "#visible_spectra = rubixdata[ :, :, visible_indices[0]]\n", + "sharded_visible_spectra = rubixdata[ :, :, visible_indices[0]]\n", + "#visible_spectra.shape\n", + "\n", + "#image = jnp.sum(visible_spectra, axis=2)\n", + "sharded_image = jnp.sum(sharded_visible_spectra, axis=2)\n", + "\n", + "# Plot side by side\n", + "fig, axes = plt.subplots(1, 1, figsize=(12, 5))\n", + "\n", + "# Sharded IFU datacube image\n", + "im1 = axes.imshow(sharded_image, origin=\"lower\", cmap=\"inferno\")\n", + "axes.set_title(\"Sharded IFU Datacube\")\n", + "fig.colorbar(im1, ax=axes)\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## DONE!\n", + "\n", + "Congratulations, you have sucessfully run the RUBIX pipeline to create your own mock-observed IFU datacube! Now enjoy playing around with the RUBIX pipeline and enjoy doing amazing science with RUBIX :)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "rubix", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/rubix_pipeline_stepwise.ipynb b/notebooks/rubix_pipeline_stepwise.ipynb index e8db48e..e13ac78 100644 --- a/notebooks/rubix_pipeline_stepwise.ipynb +++ b/notebooks/rubix_pipeline_stepwise.ipynb @@ -6,6 +6,7 @@ "metadata": {}, "outputs": [], "source": [ + "# NBVAL_SKIP\n", "import os\n", "#os.environ['SPS_HOME'] = '/home/annalena/sps_fsps'\n", "os.environ['SPS_HOME'] = '/Users/annalena/Documents/GitHub/fsps'" @@ -510,7 +511,7 @@ ], "metadata": { "kernelspec": { - "display_name": "rubix", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -524,7 +525,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.8" + "version": "3.11.11" } }, "nbformat": 4, diff --git a/pyproject.toml b/pyproject.toml index 24a6326..fa512b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,9 +32,7 @@ dependencies = [ "pyaml", "jaxtyping", "equinox", - "jax[cpu]!=0.4.27", - "jax[cpu]!=0.4.36", - "jax[cpu]!=0.5.1", + "jax[cpu]>0.5.1", "interpax", "astroquery", "beartype", @@ -51,8 +49,7 @@ tests = [ "pytest-cov", "pytest-mock", "nbval", - "jax[cpu]!=0.4.27", - "jax[cpu]!=0.4.36", + "jax[cpu]>0.5.1", "pre-commit", ] docs = [ @@ -63,10 +60,8 @@ docs = [ "sphinx_mdinclude", "sphinx_rtd_theme", ] - cuda = [ - "jax[cuda]!=0.4.27", - "jax[cuda]!=0.4.36", + "jax[cuda]>0.5.1", ] diff --git a/rubix/config/pipeline_config.yml b/rubix/config/pipeline_config.yml index 19b0277..477b1fd 100644 --- a/rubix/config/pipeline_config.yml +++ b/rubix/config/pipeline_config.yml @@ -15,37 +15,14 @@ calc_ifu: depends_on: filter_particles args: [] kwargs: {} - - reshape_data: - name: reshape_data + calculate_datacube_particlewise: + name: calculate_datacube_particlewise depends_on: spaxel_assignment args: [] kwargs: {} - - calculate_spectra: - name: calculate_spectra - depends_on: reshape_data - args: [] - kwargs: {} - - scale_spectrum_by_mass: - name: scale_spectrum_by_mass - depends_on: calculate_spectra - args: [] - kwargs: {} - doppler_shift_and_resampling: - name: doppler_shift_and_resampling - depends_on: scale_spectrum_by_mass - args: [] - kwargs: {} - calculate_datacube: - name: calculate_datacube - depends_on: doppler_shift_and_resampling - args: [] - kwargs: {} convolve_psf: name: convolve_psf - depends_on: calculate_datacube + depends_on: calculate_datacube_particlewise args: [] kwargs: {} convolve_lsf: @@ -76,42 +53,47 @@ calc_dusty_ifu: depends_on: filter_particles args: [] kwargs: {} - - reshape_data: - name: reshape_data + calculate_dusty_datacube_particlewise: + name: calculate_dusty_datacube_particlewise depends_on: spaxel_assignment args: [] kwargs: {} - - calculate_spectra: - name: calculate_spectra - depends_on: reshape_data + convolve_psf: + name: convolve_psf + depends_on: calculate_dusty_datacube_particlewise args: [] kwargs: {} - - scale_spectrum_by_mass: - name: scale_spectrum_by_mass - depends_on: calculate_spectra + convolve_lsf: + name: convolve_lsf + depends_on: convolve_psf + args: [] + kwargs: {} + apply_noise: + name: apply_noise + depends_on: convolve_lsf args: [] kwargs: {} - doppler_shift_and_resampling: - name: doppler_shift_and_resampling - depends_on: scale_spectrum_by_mass + +calc_gradient: + Transformers: + rotate_galaxy: + name: rotate_galaxy + depends_on: null args: [] kwargs: {} - calculate_extinction: - name: calculate_extinction - depends_on: doppler_shift_and_resampling + spaxel_assignment: + name: spaxel_assignment + depends_on: rotate_galaxy args: [] kwargs: {} - calculate_datacube: - name: calculate_datacube - depends_on: calculate_extinction + calculate_datacube_particlewise: + name: calculate_datacube_particlewise + depends_on: spaxel_assignment args: [] kwargs: {} convolve_psf: name: convolve_psf - depends_on: calculate_datacube + depends_on: calculate_datacube_particlewise args: [] kwargs: {} convolve_lsf: diff --git a/rubix/config/pynbody_config.yml b/rubix/config/pynbody_config.yml index d25f045..802dc9e 100644 --- a/rubix/config/pynbody_config.yml +++ b/rubix/config/pynbody_config.yml @@ -34,7 +34,6 @@ units: metals: "dimensionless" #OxMassFrac: "dimensionless" #HI: "dimensionless" - metallicity: "Zsun" coords: "kpc" velocity: "km/s" mass: "Msun" diff --git a/rubix/core/data.py b/rubix/core/data.py index b1ddb7c..ce25898 100644 --- a/rubix/core/data.py +++ b/rubix/core/data.py @@ -2,7 +2,7 @@ import os from dataclasses import dataclass from functools import partial -from typing import Callable, Optional, Union +from typing import Any, Callable, Optional, Union import jax import jax.numpy as jnp @@ -15,51 +15,6 @@ from rubix.logger import get_logger from rubix.utils import load_galaxy_data, read_yaml -# class Particles: -# def __init__(self, particle_data: object): -# self.particle_data = particle_data -# self.attributes = self._filter_attributes() -# -# def _filter_attributes(self) -> list: -# """ -# Filters the attributes of the particle_data object based on the specified criteria. -# """ -# return [ -# attr -# for attr in dir(self.particle_data) -# if not attr.startswith("__") -# and not callable(getattr(self.particle_data, attr)) -# ] -# -# def get_attributes(self) -> list: -# """ -# Returns the filtered attributes. -# """ -# return self.attributes - - -# class Particles: -# def __init__(self, particle_data: object): -# self.particle_data = particle_data -# self.attributes = self._filter_attributes() -# -# def _filter_attributes(self) -> list: -# """ -# Filters the attributes of the particle_data object based on the specified criteria. -# """ -# return [ -# attr -# for attr in dir(self.particle_data) -# if not attr.startswith("__") -# and not callable(getattr(self.particle_data, attr)) -# ] -# -# def get_attributes(self) -> list: -# """ -# Returns the filtered attributes. -# """ -# return self.attributes - # Registering the dataclass with JAX for automatic tree traversal @jaxtyped(typechecker=typechecker) @@ -75,9 +30,9 @@ class Galaxy: halfmassrad_stars: Half mass radius of the stars in the galaxy """ - redshift: Optional[jnp.ndarray] = None - center: Optional[jnp.ndarray] = None - halfmassrad_stars: Optional[jnp.ndarray] = None + redshift: Optional[Any] = None + center: Optional[Any] = None + halfmassrad_stars: Optional[Any] = None def __repr__(self): representationString = ["Galaxy:"] @@ -141,16 +96,16 @@ class StarsData: """ - coords: Optional[jnp.ndarray] = None - velocity: Optional[jnp.ndarray] = None - mass: Optional[jnp.ndarray] = None - metallicity: Optional[jnp.ndarray] = None - age: Optional[jnp.ndarray] = None - pixel_assignment: Optional[jnp.ndarray] = None - spatial_bin_edges: Optional[jnp.ndarray] = None - mask: Optional[jnp.ndarray] = None - spectra: Optional[jnp.ndarray] = None - datacube: Optional[jnp.ndarray] = None + coords: Optional[Any] = None + velocity: Optional[Any] = None + mass: Optional[Any] = None + metallicity: Optional[Any] = None + age: Optional[Any] = None + pixel_assignment: Optional[Any] = None + spatial_bin_edges: Optional[Any] = None + mask: Optional[Any] = None + spectra: Optional[Any] = None + datacube: Optional[Any] = None def __repr__(self): representationString = ["StarsData:"] @@ -227,20 +182,20 @@ class GasData: datacube: IFU datacube for the gas component """ - coords: Optional[jnp.ndarray] = None - velocity: Optional[jnp.ndarray] = None - mass: Optional[jnp.ndarray] = None - density: Optional[jnp.ndarray] = None - internal_energy: Optional[jnp.ndarray] = None - metallicity: Optional[jnp.ndarray] = None - metals: Optional[jnp.ndarray] = None - sfr: Optional[jnp.ndarray] = None - electron_abundance: Optional[jnp.ndarray] = None - pixel_assignment: Optional[jnp.ndarray] = None - spatial_bin_edges: Optional[jnp.ndarray] = None - mask: Optional[jnp.ndarray] = None - spectra: Optional[jnp.ndarray] = None - datacube: Optional[jnp.ndarray] = None + coords: Optional[Any] = None + velocity: Optional[Any] = None + mass: Optional[Any] = None + density: Optional[Any] = None + internal_energy: Optional[Any] = None + metallicity: Optional[Any] = None + metals: Optional[Any] = None + sfr: Optional[Any] = None + electron_abundance: Optional[Any] = None + pixel_assignment: Optional[Any] = None + spatial_bin_edges: Optional[Any] = None + mask: Optional[Any] = None + spectra: Optional[Any] = None + datacube: Optional[Any] = None def __repr__(self): representationString = ["GasData:"] @@ -321,12 +276,6 @@ def __repr__(self): representationString.append("\n\t".join(f"{k}: {v}".split("\n"))) return "\n\t".join(representationString) - # def __post_init__(self): - # if self.stars is not None: - # self.stars = Particles(self.stars) - # if self.gas is not None: - # self.gas = Particles(self.gas) - def tree_flatten(self): """ Flattens the RubixData object into a tuple of children and auxiliary data @@ -427,13 +376,18 @@ def convert_to_rubix(config: Union[dict, str]): # If the simulationtype is IllustrisAPI, get data from IllustrisAPI # TODO: we can do this more elgantly + if "data" in config: if config["data"]["name"] == "IllustrisAPI": logger.info("Loading data from IllustrisAPI") api = IllustrisAPI(**config["data"]["args"], logger=logger) api.load_galaxy(**config["data"]["load_galaxy_args"]) + elif config["data"]["name"] == "NihaoHandler": + logger.info("Loading data from Nihao simulation") + else: + raise ValueError(f"Unknown data source: {config['data']['name']}.") - # Load the saved data into the input handler + # Load the saved data into the input handler logger.info("Loading data into input handler") input_handler = get_input_handler(config, logger=logger) input_handler.to_rubix(output_path=config["output_path"]) diff --git a/rubix/core/fits.py b/rubix/core/fits.py index a766122..075ca6f 100644 --- a/rubix/core/fits.py +++ b/rubix/core/fits.py @@ -25,15 +25,7 @@ def store_fits(config, data, filepath): logger_config = config.get("logger", None) logger = get_logger(logger_config) - if "cube_type" not in config["data"]["args"]: - datacube = data.stars.datacube - parttype = "stars" - elif config["data"]["args"]["cube_type"] == "stars": - datacube = data.stars.datacube - parttype = "stars" - elif config["data"]["args"]["cube_type"] == "gas": - datacube = data.gas.datacube - parttype = "gas" + datacube = data telescope = get_telescope(config) diff --git a/rubix/core/ifu.py b/rubix/core/ifu.py index 2388f76..daa13b2 100644 --- a/rubix/core/ifu.py +++ b/rubix/core/ifu.py @@ -3,339 +3,106 @@ import jax import jax.numpy as jnp from beartype import beartype as typechecker +from jax import lax from jaxtyping import Array, Float, jaxtyped from rubix import config as rubix_config from rubix.core.data import GasData, StarsData from rubix.logger import get_logger from rubix.spectra.ifu import ( - calculate_cube, + _velocity_doppler_shift_single, cosmological_doppler_shift, resample_spectrum, - velocity_doppler_shift, ) from .data import RubixData -from .ssp import ( - get_lookup_interpolation, - get_lookup_interpolation_pmap, - get_lookup_interpolation_vmap, - get_ssp, -) +from .ssp import get_lookup_interpolation, get_ssp from .telescope import get_telescope @jaxtyped(typechecker=typechecker) -def get_calculate_spectra(config: dict) -> Callable: +def get_calculate_datacube_particlewise(config: dict) -> Callable: """ - The function gets the lookup function that performs the lookup to the SSP model, - and parallelizes the funciton across all GPUs. + Create a function that calculates the datacube for the stars component + of a RubixData object on a per-particle basis. First, it looks up the SSP + spectrum for each star based on its age and metallicity, scales it by the + star's mass, applies a Doppler shift based on the star's velocity, resamples + the spectrum onto the telescope's wavelength grid, and finally accumulates + the resulting spectra into the appropriate pixels of the datacube. Args: - config (dict): The configuration dictionary - - Returns: - The function that calculates the spectra of the stars. - - Example - ------- - >>> config = { - ... "ssp": { - ... "template": { - ... "name": "BruzualCharlot2003" - ... }, - ... }, - ... } - - >>> from rubix.core.ifu import get_calculate_spectra - >>> calcultae_spectra = get_calculate_spectra(config) - - >>> rubixdata = calcultae_spectra(rubixdata) - >>> # Access the spectra of the stars - >>> rubixdata.stars.spectra - """ - logger = get_logger(config.get("logger", None)) - lookup_interpolation_pmap = get_lookup_interpolation_pmap(config) - # lookup_interpolation_vmap = get_lookup_interpolation_vmap(config) - lookup_interpolation = get_lookup_interpolation(config) - - @jaxtyped(typechecker=typechecker) - def calculate_spectra(rubixdata: RubixData) -> RubixData: - logger.info("Calculating IFU cube...") - logger.debug( - f"Input shapes: Metallicity: {len(rubixdata.stars.metallicity)}, Age: {len(rubixdata.stars.age)}" - ) - # Ensure metallicity and age are arrays and reshape them to be at least 1-dimensional - # age_data = jax.device_get(rubixdata.stars.age) - age_data = rubixdata.stars.age - # metallicity_data = jax.device_get(rubixdata.stars.metallicity) - metallicity_data = rubixdata.stars.metallicity - # Ensure they are not scalars or empty; convert to 1D arrays if necessary - age = jnp.atleast_1d(age_data) - metallicity = jnp.atleast_1d(metallicity_data) - - """ - spectra1 = lookup_interpolation( - # rubixdata.stars.metallicity, rubixdata.stars.age - metallicity[0][:250000], - age[0][:250000], - ) # * inputs["mass"] - spectra2 = lookup_interpolation( - # rubixdata.stars.metallicity, rubixdata.stars.age - metallicity[0][250000:500000], - age[0][250000:500000], - ) - spectra3 = lookup_interpolation( - # rubixdata.stars.metallicity, rubixdata.stars.age - metallicity[0][500000:750000], - age[0][500000:750000], - ) - spectra = jnp.concatenate([spectra1, spectra2, spectra3], axis=0) - """ - # Define the chunk size (number of particles per chunk) - chunk_size = 100000 - total_length = metallicity[0].shape[ - 0 - ] # assuming metallicity[0] is your 1D array of particles - - # List to hold the spectra chunks - spectra_chunks = [] - - # Loop over the data in chunks - for start in range(0, total_length, chunk_size): - end = min(start + chunk_size, total_length) - current_chunk = lookup_interpolation( - metallicity[0][start:end], - age[0][start:end], - ) - spectra_chunks.append(current_chunk) + config (dict): Configuration dictionary containing telescope and galaxy + parameters. - # Concatenate all the chunks along axis 0 - spectra = jnp.concatenate(spectra_chunks, axis=0) - logger.debug(f"Calculation Finished! Spectra shape: {spectra.shape}") - spectra_jax = jnp.array(spectra) - spectra_jax = jnp.expand_dims(spectra_jax, axis=0) - rubixdata.stars.spectra = spectra_jax - # setattr(rubixdata.gas, "spectra", spectra) - # jax.debug.print("Calculate Spectra: Spectra {}", spectra) - return rubixdata - - return calculate_spectra - - -@jaxtyped(typechecker=typechecker) -def get_scale_spectrum_by_mass(config: dict) -> Callable: - """ - The spectra of the stellar particles are scaled by the mass of the stars. - - Args: - config (dict): The configuration dictionary Returns: - The function that scales the spectra by the mass of the stars. - - Example - ------- - >>> from rubix.core.ifu import get_scale_spectrum_by_mass - >>> scale_spectrum_by_mass = get_scale_spectrum_by_mass(config) - - >>> rubixdata = scale_spectrum_by_mass(rubixdata) - >>> # Access the spectra of the stars, which is now scaled by the stellar mass - >>> rubixdata.stars.spectra + Callable: A function that takes a RubixData object and returns it with + the datacube calculated and added to the stars component. """ - logger = get_logger(config.get("logger", None)) - - @jaxtyped(typechecker=typechecker) - def scale_spectrum_by_mass(rubixdata: RubixData) -> RubixData: - - logger.info("Scaling Spectra by Mass...") - mass = jnp.expand_dims(rubixdata.stars.mass, axis=-1) - # rubixdata.stars.spectra = rubixdata.stars.spectra * mass - spectra_mass = rubixdata.stars.spectra * mass - setattr(rubixdata.stars, "spectra", spectra_mass) - # jax.debug.print("mass mult: Spectra {}", inputs["spectra"]) - return rubixdata - - return scale_spectrum_by_mass - - -# Vectorize the resample_spectrum function -@jaxtyped(typechecker=typechecker) -def get_resample_spectrum_vmap(target_wavelength) -> Callable: - """ - The spectra of the stars are resampled to the telescope wavelength grid. - - Args: - target_wavelength (jax.Array): The telescope wavelength grid - - Returns: - The function that resamples the spectra to the telescope wavelength grid. - """ - - @jaxtyped(typechecker=typechecker) - def resample_spectrum_vmap(initial_spectrum, initial_wavelength): - return resample_spectrum( - initial_spectrum=initial_spectrum, - initial_wavelength=initial_wavelength, - target_wavelength=target_wavelength, - ) - - return jax.vmap(resample_spectrum_vmap, in_axes=(0, 0)) - - -# Parallelize the vectorized function across devices -@jaxtyped(typechecker=typechecker) -def get_resample_spectrum_pmap(target_wavelength) -> Callable: - """ - Pmap the function that resamples the spectra of the stars to the telescope wavelength grid. - - Args: - target_wavelength (jax.Array): The telescope wavelength grid - - Returns: - The function that resamples the spectra to the telescope wavelength grid. - """ - vmapped_resample_spectrum = get_resample_spectrum_vmap(target_wavelength) - return jax.pmap(vmapped_resample_spectrum) - - -@jaxtyped(typechecker=typechecker) -def get_velocities_doppler_shift_vmap( - ssp_wave: Float[Array, "..."], velocity_direction: str -) -> Callable: - """ - The function doppler shifts the wavelength based on the velocity of the stars. - - Args: - ssp_wave (jax.Array): The wavelength of the SSP grid - velocity_direction (str): The velocity component of the stars that is used to doppler shift the wavelength - - Returns: - The function that doppler shifts the wavelength based on the velocity of the stars. - """ - - def func(velocity): - return velocity_doppler_shift( - wavelength=ssp_wave, velocity=velocity, direction=velocity_direction - ) - - return jax.vmap(func, in_axes=0) - - -@jaxtyped(typechecker=typechecker) -def get_doppler_shift_and_resampling(config: dict) -> Callable: - """ - The function doppler shifts the wavelength based on the velocity of the stars and resamples the spectra to the telescope wavelength grid. - - Args: - config (dict): The configuration dictionary - - Returns: - The function that doppler shifts the wavelength based on the velocity of the stars and resamples the spectra to the telescope wavelength grid. - - Example - ------- - >>> from rubix.core.ifu import get_doppler_shift_and_resampling - >>> doppler_shift_and_resampling = get_doppler_shift_and_resampling(config) - - >>> rubixdata = doppler_shift_and_resampling(rubixdata) - >>> # Access the spectra of the stars, which is now doppler shifted and resampled to the telescope wavelength grid - >>> rubixdata.stars.spectra - """ - logger = get_logger(config.get("logger", None)) - - # The velocity component of the stars that is used to doppler shift the wavelength - velocity_direction = rubix_config["ifu"]["doppler"]["velocity_direction"] - - # The redshift at which the user wants to observe the galaxy - galaxy_redshift = config["galaxy"]["dist_z"] - - # Get the telescope wavelength bins telescope = get_telescope(config) - telescope_wavelength = telescope.wave_seq - - # Get the SSP grid to doppler shift the wavelengths - ssp = get_ssp(config) + ns = int(telescope.sbin) + nseg = ns * ns + target_wave = telescope.wave_seq # (n_wave_tel,) - # Doppler shift the SSP wavelenght based on the cosmological distance of the observed galaxy - ssp_wave = cosmological_doppler_shift(z=galaxy_redshift, wavelength=ssp.wavelength) - logger.debug(f"SSP Wave: {ssp_wave.shape}") + # prepare SSP lookup + lookup_ssp = get_lookup_interpolation(config) - # Function to Doppler shift the wavelength based on the velocity of the stars particles - # This binds the velocity direction, such that later we only need the velocity during the pipeline - doppler_shift = get_velocities_doppler_shift_vmap(ssp_wave, velocity_direction) - - @jaxtyped(typechecker=typechecker) - def process_particle( - particle: Union[StarsData, GasData], - ) -> Union[Float[Array, "..."], None]: - if particle.spectra is not None: - # Doppler shift based on the velocity of the particle - doppler_shifted_ssp_wave = doppler_shift(particle.velocity) - logger.info(f"Doppler shifting and resampling spectra...") - logger.debug(f"Doppler Shifted SSP Wave: {doppler_shifted_ssp_wave.shape}") - logger.debug(f"Telescope Wave Seq: {telescope_wavelength.shape}") - - # Function to resample the spectrum to the telescope wavelength grid - resample_spectrum_pmap = get_resample_spectrum_pmap(telescope_wavelength) - spectrum_resampled = resample_spectrum_pmap( - particle.spectra, doppler_shifted_ssp_wave - ) - return spectrum_resampled - return particle.spectra - - @jaxtyped(typechecker=typechecker) - def doppler_shift_and_resampling(rubixdata: RubixData) -> RubixData: - for particle_name in ["stars", "gas"]: - particle = getattr(rubixdata, particle_name) - particle.spectra = process_particle(particle) - - return rubixdata - - return doppler_shift_and_resampling - - -@jaxtyped(typechecker=typechecker) -def get_calculate_datacube(config: dict) -> Callable: - """ - The function returns the function that calculates the datacube of the stars. - - Args: - config (dict): The configuration dictionary - - Returns: - The function that calculates the datacube of the stars. - - Example - ------- - >>> from rubix.core.ifu import get_calculate_datacube - >>> calculate_datacube = get_calculate_datacube(config) - - >>> rubixdata = calculate_datacube(rubixdata) - >>> # Access the datacube of the stars - >>> rubixdata.stars.datacube - """ - logger = get_logger(config.get("logger", None)) - telescope = get_telescope(config) - num_spaxels = int(telescope.sbin) - - # Bind the num_spaxels to the function - calculate_cube_fn = jax.tree_util.Partial(calculate_cube, num_spaxels=num_spaxels) - calculate_cube_pmap = jax.pmap(calculate_cube_fn) + # prepare Doppler machinery + velocity_direction = rubix_config["ifu"]["doppler"]["velocity_direction"] + z_obs = config["galaxy"]["dist_z"] + ssp_model = get_ssp(config) + ssp_wave0 = cosmological_doppler_shift( + z=z_obs, wavelength=ssp_model.wavelength + ) # (n_wave_ssp,) @jaxtyped(typechecker=typechecker) - def calculate_datacube(rubixdata: RubixData) -> RubixData: - logger.info("Calculating Data Cube...") - ifu_cubes = calculate_cube_pmap( - spectra=rubixdata.stars.spectra, - spaxel_index=rubixdata.stars.pixel_assignment, - ) - datacube = jnp.sum(ifu_cubes, axis=0) - logger.debug(f"Datacube Shape: {datacube.shape}") - # logger.debug(f"This is the datacube: {datacube}") - datacube_jax = jnp.array(datacube) - setattr(rubixdata.stars, "datacube", datacube_jax) - # rubixdata.stars.datacube = datacube + def calculate_datacube_particlewise(rubixdata: RubixData) -> RubixData: + logger.info("Calculating Data Cube (combined per‐particle)…") + + stars = rubixdata.stars + ages = stars.age # (n_stars,) + metallicity = stars.metallicity # (n_stars,) + masses = stars.mass # (n_stars,) + velocities = stars.velocity # (n_stars,) + pix_idx = stars.pixel_assignment # (n_stars,) + nstar = ages.shape[0] + + # init flat cube: (nseg, n_wave_tel) + init_cube = jnp.zeros((nseg, target_wave.shape[-1])) + + def body(cube, i): + age_i = ages[i] # scalar + Z_i = metallicity[i] # scalar + m_i = masses[i] # scalar + v_i = velocities[i] # scalar or vector + pix_i = pix_idx[i].astype(jnp.int32) + + # 1) SSP lookup + spec_ssp = lookup_ssp(Z_i, age_i) # (n_wave_ssp,) + # 2) scale by mass + spec_mass = spec_ssp * m_i # (n_wave_ssp,) + # 3) Doppler‐shift wavelengths + shifted_wave = _velocity_doppler_shift_single( + wavelength=ssp_wave0, + velocity=v_i, + direction=velocity_direction, + ) # (n_wave_ssp,) + # 4) resample onto telescope grid + spec_tel = resample_spectrum( + initial_spectrum=spec_mass, + initial_wavelength=shifted_wave, + target_wavelength=target_wave, + ) # (n_wave_tel,) + + # 5) accumulate + cube = cube.at[pix_i].add(spec_tel) + return cube, None + + cube_flat, _ = lax.scan(body, init_cube, jnp.arange(nstar, dtype=jnp.int32)) + + cube_3d = cube_flat.reshape(ns, ns, -1) + setattr(rubixdata.stars, "datacube", cube_3d) + logger.debug(f"Datacube shape: {cube_3d.shape}") return rubixdata - return calculate_datacube + return calculate_datacube_particlewise diff --git a/rubix/core/pipeline.py b/rubix/core/pipeline.py index e376bd4..3491d4c 100644 --- a/rubix/core/pipeline.py +++ b/rubix/core/pipeline.py @@ -1,24 +1,36 @@ +import dataclasses import time +from functools import partial +from types import SimpleNamespace from typing import Union import jax import jax.numpy as jnp + +# For shard_map and device mesh. +import numpy as np from beartype import beartype as typechecker -from jax import block_until_ready +from jax import block_until_ready, lax +from jax.experimental.pjit import pjit +from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +from jax.tree_util import tree_flatten, tree_map, tree_unflatten from jaxtyping import jaxtyped from rubix.logger import get_logger from rubix.pipeline import linear_pipeline as pipeline -from rubix.utils import get_config, get_pipeline_config +from rubix.utils import _pad_particles, get_config, get_pipeline_config -from .data import get_reshape_data, get_rubix_data -from .dust import get_extinction -from .ifu import ( - get_calculate_datacube, - get_calculate_spectra, - get_doppler_shift_and_resampling, - get_scale_spectrum_by_mass, +from .data import ( + Galaxy, + GasData, + RubixData, + StarsData, + get_reshape_data, + get_rubix_data, ) +from .dust import get_extinction +from .ifu import get_calculate_datacube_particlewise from .lsf import get_convolve_lsf from .noise import get_apply_noise from .psf import get_convolve_psf @@ -31,43 +43,47 @@ class RubixPipeline: """ RubixPipeline is responsible for setting up and running the data processing pipeline. - Args: - user_config (dict or str): Parsed user configuration for the pipeline. - pipeline_config (dict): Configuration for the pipeline. - logger(Logger) : Logger instance for logging messages. - ssp(object) : Stellar population synthesis model. - telescope(object) : Telescope configuration. - data (dict): Dictionary containing particle data. - func (callable): Compiled pipeline function to process data. - - Example - -------- - >>> from rubix.core.pipeline import RubixPipeline - >>> config = "path/to/config.yml" - >>> pipeline = RubixPipeline(config) - >>> output = pipeline.run() - >>> ssp_model = pipeline.ssp - >>> telescope = pipeline.telescope + Usage + ----- + >>> pipe = RubixPipeline(config) + >>> inputdata = pipe.prepare_data() + >>> # To run without sharding: + >>> output = pipe.run(inputdata) + >>> # To run with sharding using jax.shard_map: + >>> final_datacube = pipe.run_sharded(inputdata, shard_size=100000) """ def __init__(self, user_config: Union[dict, str]): + """ + Initializes the RubixPipeline with the given user configuration. + + Args: + user_config (Union[dict, str]): User configuration dictionary or path to config file. + pipeline_config (dict): Pipeline configuration dictionary. + logger: Logger instance for logging messages. + ssp: SSP model instance. + telescope: Telescope instance. + func: Compiled pipeline function. + + Returns: + None + """ self.user_config = get_config(user_config) self.pipeline_config = get_pipeline_config(self.user_config["pipeline"]["name"]) self.logger = get_logger(self.user_config["logger"]) self.ssp = get_ssp(self.user_config) self.telescope = get_telescope(self.user_config) - self.data = self._prepare_data() self.func = None - def _prepare_data(self): + def prepare_data(self): """ Prepares and loads the data for the pipeline. Returns: - Dictionary containing particle data with keys: - 'n_particles', 'coords', 'velocities', 'metallicity', 'mass', and 'age'. + Object containing particle data with attributes such as: + 'coords', 'velocities', 'mass', 'age', and 'metallicity' under stars and gas. """ - # Get the data + t1 = time.time() self.logger.info("Getting rubix data...") rubixdata = get_rubix_data(self.user_config) star_count = ( @@ -77,17 +93,8 @@ def _prepare_data(self): self.logger.info( f"Data loaded with {star_count} star particles and {gas_count} gas particles." ) - # Setup the data dictionary - # TODO: This is a temporary solution, we need to figure out a better way to handle the data - # This works, because JAX can trace through the data dictionary - # Other option may be named tuples or data classes to have fixed keys - - # self.logger.debug("Data: %s", rubixdata) - # self.logger.debug( - # "Data Shape: %s", - # {k: v.shape for k, v in rubixdata.items() if hasattr(v, "shape")}, - # ) - + t2 = time.time() + self.logger.info("Data preparation completed in %.2f seconds.", t2 - t1) return rubixdata @jaxtyped(typechecker=typechecker) @@ -101,18 +108,13 @@ def _get_pipeline_functions(self) -> list: self.logger.info("Setting up the pipeline...") self.logger.debug("Pipeline Configuration: %s", self.pipeline_config) - # TODO: maybe there is a nicer way to load the functions from the yaml config? rotate_galaxy = get_galaxy_rotation(self.user_config) filter_particles = get_filter_particles(self.user_config) spaxel_assignment = get_spaxel_assignment(self.user_config) - calculate_spectra = get_calculate_spectra(self.user_config) - reshape_data = get_reshape_data(self.user_config) - scale_spectrum_by_mass = get_scale_spectrum_by_mass(self.user_config) - doppler_shift_and_resampling = get_doppler_shift_and_resampling( + apply_extinction = get_extinction(self.user_config) + calculate_datacube_particlewise = get_calculate_datacube_particlewise( self.user_config ) - apply_extinction = get_extinction(self.user_config) - calculate_datacube = get_calculate_datacube(self.user_config) convolve_psf = get_convolve_psf(self.user_config) convolve_lsf = get_convolve_lsf(self.user_config) apply_noise = get_apply_noise(self.user_config) @@ -121,88 +123,144 @@ def _get_pipeline_functions(self) -> list: rotate_galaxy, filter_particles, spaxel_assignment, - calculate_spectra, - reshape_data, - scale_spectrum_by_mass, - doppler_shift_and_resampling, apply_extinction, - calculate_datacube, + calculate_datacube_particlewise, convolve_psf, convolve_lsf, apply_noise, ] - return functions - # TODO: currently returns dict, but later should return only the IFU cube - def run(self): + def run_sharded(self, inputdata, devices): """ - Runs the data processing pipeline. + Runs the pipeline on sharded input data in parallel using jax.shard_map. + It splits the particle arrays (e.g. under stars and gas) into shards, runs + the compiled pipeline on each shard, and then combines the resulting datacubes. + + This is the recomended method to run the pipeline in parallel at the moment!!! + + Parameters + ---------- + inputdata : object + Data prepared from the `prepare_data` method. + shard_size : int + Number of particles per shard. Returns ------- - dict - Output of the pipeline after processing the input data. + jax.numpy.ndarray + The final datacube combined from all shards. """ - # Create the pipeline time_start = time.time() + # Assemble and compile the pipeline as before. functions = self._get_pipeline_functions() self._pipeline = pipeline.LinearTransformerPipeline( self.pipeline_config, functions ) - - # Assembling the pipeline self.logger.info("Assembling the pipeline...") self._pipeline.assemble() - - # Compiling the expressions self.logger.info("Compiling the expressions...") self.func = self._pipeline.compile_expression() - # Running the pipeline - self.logger.info("Running the pipeline on the input data...") - output = self.func(self.data) + # devices = jax.devices() + num_devices = len(devices) + self.logger.info("Number of devices: %d", num_devices) + + mesh = Mesh(devices, axis_names=("data",)) + + # — sharding specs by rank — + replicate_0d = NamedSharding(mesh, P()) # for scalars + replicate_1d = NamedSharding(mesh, P(None)) # for 1-D arrays + shard_2d = NamedSharding(mesh, P("data", None)) # for (N, D) + shard_1d = NamedSharding(mesh, P("data")) # for (N,) + shard_bins = NamedSharding(mesh, P(None, None)) + replicate_3d = NamedSharding(mesh, P(None, None, None)) # for full cube + + # — 1) allocate empty instances — + galaxy_spec = object.__new__(Galaxy) + stars_spec = object.__new__(StarsData) + gas_spec = object.__new__(GasData) + rubix_spec = object.__new__(RubixData) + + # — 2) assign NamedSharding to each field — + # galaxy + galaxy_spec.redshift = replicate_0d + galaxy_spec.center = replicate_1d + galaxy_spec.halfmassrad_stars = replicate_0d + + # stars + stars_spec.coords = shard_2d + stars_spec.velocity = shard_2d + stars_spec.mass = shard_1d + stars_spec.age = shard_1d + stars_spec.metallicity = shard_1d + stars_spec.pixel_assignment = shard_1d + stars_spec.spatial_bin_edges = shard_bins + stars_spec.mask = shard_1d + stars_spec.spectra = shard_2d + stars_spec.datacube = replicate_3d + + # gas (same idea) + gas_spec.coords = shard_2d + gas_spec.velocity = shard_2d + gas_spec.mass = shard_1d + gas_spec.density = shard_1d + gas_spec.internal_energy = shard_1d + gas_spec.metallicity = shard_1d + gas_spec.metals = shard_1d + gas_spec.sfr = shard_1d + gas_spec.electron_abundance = shard_1d + gas_spec.pixel_assignment = shard_1d + gas_spec.spatial_bin_edges = shard_bins + gas_spec.mask = shard_1d + gas_spec.spectra = shard_2d + gas_spec.datacube = replicate_3d + + # — link them up — + rubix_spec.galaxy = galaxy_spec + rubix_spec.stars = stars_spec + rubix_spec.gas = gas_spec + + # 1) Make a pytree of PartitionSpec + partition_spec_tree = tree_map( + lambda s: s.spec if isinstance(s, NamedSharding) else None, rubix_spec + ) + + # if the particle number is not modulo the device number, we have to pad a few empty particles + # to make it work + n = inputdata.stars.coords.shape[0] + pad = (num_devices - (n % num_devices)) % num_devices + if pad: + self.logger.info( + "Padding particles to make the number of particles divisible by the number of devices (%d).", + num_devices, + ) + inputdata = _pad_particles(inputdata, pad) + + inputdata = jax.device_put(inputdata, rubix_spec) + + # create the sharded data + def _shard_pipeline(sharded_rubixdata): + out_local = self.func(sharded_rubixdata) + local_cube = out_local.stars.datacube # shape (25,25,5994) + # in‐XLA all‐reduce across the "data" axis: + summed_cube = lax.psum(local_cube, axis_name="data") + return summed_cube # replicated on each device + + sharded_pipeline = shard_map( + _shard_pipeline, # the function to compile + mesh=mesh, # the mesh to use + in_specs=(partition_spec_tree,), + out_specs=replicate_3d.spec, + check_rep=False, + ) + + sharded_result = sharded_pipeline(inputdata) - block_until_ready(output) time_end = time.time() self.logger.info( - "Pipeline run completed in %.2f seconds.", time_end - time_start + "Total time for sharded pipeline run: %.2f seconds.", + time_end - time_start, ) - output.galaxy.redshift_unit = self.data.galaxy.redshift_unit - output.galaxy.center_unit = self.data.galaxy.center_unit - output.galaxy.halfmassrad_stars_unit = self.data.galaxy.halfmassrad_stars_unit - - if output.stars.coords != None: - output.stars.coords_unit = self.data.stars.coords_unit - output.stars.velocity_unit = self.data.stars.velocity_unit - output.stars.mass_unit = self.data.stars.mass_unit - # output.stars.metallictiy_unit = self.data.stars.metallictiy_unit - output.stars.age_unit = self.data.stars.age_unit - output.stars.spatial_bin_edges_unit = "kpc" - # output.stars.wavelength_unit = rubix_config["ssp"]["units"]["wavelength"] - # output.stars.spectra_unit = rubix_config["ssp"]["units"]["flux"] - # output.stars.datacube_unit = rubix_config["ssp"]["units"]["flux"] - - if output.gas.coords != None: - output.gas.coords_unit = self.data.gas.coords_unit - output.gas.velocity_unit = self.data.gas.velocity_unit - output.gas.mass_unit = self.data.gas.mass_unit - output.gas.density_unit = self.data.gas.density_unit - output.gas.internal_energy_unit = self.data.gas.internal_energy_unit - # output.gas.metallicity_unit = self.data.gas.metallicity_unit - output.gas.sfr_unit = self.data.gas.sfr_unit - output.gas.electron_abundance_unit = self.data.gas.electron_abundance_unit - output.gas.spatial_bin_edges_unit = "kpc" - # output.gas.wavelength_unit = rubix_config["ssp"]["units"]["wavelength"] - # output.gas.spectra_unit = rubix_config["ssp"]["units"]["flux"] - # output.gas.datacube_unit = rubix_config["ssp"]["units"]["flux"] - - return output - - # TODO: implement gradient calculation - def gradient(self): - """ - This function will calculate the gradient of the pipeline, but is yet not implemented. - """ - raise NotImplementedError("Gradient calculation is not implemented yet") + return sharded_result diff --git a/rubix/core/rotation.py b/rubix/core/rotation.py index 05c8d3d..bb4024c 100644 --- a/rubix/core/rotation.py +++ b/rubix/core/rotation.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp from beartype import beartype as typechecker from jaxtyping import jaxtyped @@ -42,7 +43,7 @@ def get_galaxy_rotation(config: dict): # Check if type is provided if "type" in config["galaxy"]["rotation"]: # Check if type is valid: face-on or edge-on - if config["galaxy"]["rotation"]["type"] not in ["face-on", "edge-on"]: + if config["galaxy"]["rotation"]["type"] not in ["face-on", "edge-on", "matrix"]: raise ValueError("Invalid type provided in rotation information") # if type is face on, alpha = beta = gamma = 0 @@ -75,47 +76,14 @@ def get_galaxy_rotation(config: dict): @jaxtyped(typechecker=typechecker) def rotate_galaxy(rubixdata: RubixData) -> RubixData: logger.info(f"Rotating galaxy with alpha={alpha}, beta={beta}, gamma={gamma}") - """ - for particle_type in ["stars", "gas"]: - if particle_type in config["data"]["args"]["particle_type"]: - # Get the component (either stars or gas) - logger.info(f"Rotating {particle_type}") - component = getattr(rubixdata, particle_type) - - # Get the inputs - coords = component.coords - velocities = component.velocity - masses = component.mass - halfmass_radius = rubixdata.galaxy.halfmassrad_stars - - assert ( - coords is not None - ), f"Coordinates not found for {particle_type}. " - assert ( - velocities is not None - ), f"Velocities not found for {particle_type}. " - assert masses is not None, f"Masses not found for {particle_type}. " - - # Rotate the galaxy - coords, velocities = rotate_galaxy_core( - positions=coords, - velocities=velocities, - positions_stars=rubixdata.stars.coords, - masses_stars=rubixdata.stars.mass, - halfmass_radius=halfmass_radius, - alpha=alpha, - beta=beta, - gamma=gamma, - ) - - # Update the inputs - # rubixdata.stars.coords = coords - # rubixdata.stars.velocity = velocities - setattr(component, "coords", coords) - setattr(component, "velocity", velocities) + Rotates the galaxy particle data based on the specified rotation angles. - return rubixdata + Args: + rubixdata (RubixData): The RubixData object containing particle data. + + Returns: + RubixData: The rotated RubixData object. """ logger.info("Rotating galaxy for simulation: " + config["simulation"]["name"]) # Rotate gas diff --git a/rubix/core/ssp.py b/rubix/core/ssp.py index 5d205d3..23577a1 100644 --- a/rubix/core/ssp.py +++ b/rubix/core/ssp.py @@ -79,6 +79,7 @@ def get_lookup_interpolation_vmap(config: dict) -> Callable: """ lookup = get_lookup_interpolation(config) lookup_vmap = jax.vmap(lookup, in_axes=(0, 0)) + return lookup_vmap diff --git a/rubix/galaxy/alignment.py b/rubix/galaxy/alignment.py index 09b4153..2b2e104 100644 --- a/rubix/galaxy/alignment.py +++ b/rubix/galaxy/alignment.py @@ -252,6 +252,7 @@ def rotate_galaxy( alpha (float): Rotation around the x-axis in degrees beta (float): Rotation around the y-axis in degrees gamma (float): Rotation around the z-axis in degrees + key (str): The key to the particle data, e.g. "IllustrisTNG" or "NIHAO" Returns: The rotated positions and velocities as a jnp.ndarray. diff --git a/rubix/galaxy/input_handler/pynbody.py b/rubix/galaxy/input_handler/pynbody.py index 1fc1f29..9decf28 100644 --- a/rubix/galaxy/input_handler/pynbody.py +++ b/rubix/galaxy/input_handler/pynbody.py @@ -15,12 +15,20 @@ class PynbodyHandler(BaseHandler): def __init__( - self, path, halo_path=None, logger=None, config=None, dist_z=None, halo_id=None + self, + path, + halo_path=None, + rotation_path="./data", + logger=None, + config=None, + dist_z=None, + halo_id=None, ): """Initialize handler with paths to snapshot and halo files.""" self.metallicity_unit = Zsun self.path = path self.halo_path = halo_path + self.rotation_path = rotation_path self.halo_id = halo_id self.pynbody_config = config or self._load_config() self.logger = logger or self._default_logger() @@ -74,7 +82,19 @@ def load_data(self): self.logger.info(f"Simulation snapshot loaded from halo {self.halo_id}") halo = self.get_halo_data(halo_id=self.halo_id) if halo is not None: - pynbody.analysis.angmom.faceon(halo) + pynbody.analysis.angmom.faceon(halo.s) + ang_mom_vec = pynbody.analysis.angmom.ang_mom_vec(halo.s) + rotation_matrix = pynbody.analysis.angmom.calc_sideon_matrix(ang_mom_vec) + if not os.path.exists(self.rotation_path): + self.logger.info("Rotation matrix calculated and not saved.") + else: + np.save( + os.path.join(self.rotation_path, "rotation_matrix.npy"), + rotation_matrix, + ) + self.logger.info( + f"Rotation matrix calculated and saved to '{self.rotation_path}/rotation_matrix.npy'." + ) self.sim = halo fields = self.pynbody_config["fields"] @@ -89,7 +109,6 @@ def load_data(self): getattr(self.sim, cls), fields[cls], units[cls], cls ) - # Combine HI and OxMassFrac into a two-column metals field for gas hi_data = self.load_particle_data( getattr(self.sim, "gas"), {"HI": "HI"}, @@ -102,8 +121,7 @@ def load_data(self): {"OxMassFrac": u.dimensionless_unscaled}, "gas", ) - # fe_data = self.load_particle_data(getattr(self.sim, "gas"), {"FeMassFrac": "FeMassFrac"}, {"FeMassFrac": u.dimensionless_unscaled}, "gas") - # self.data["gas"]["metals"] = np.column_stack((hi_data["HI"], ox_data["OxMassFrac"])) + # Create a metals array with 10 columns, filled with zeros initially n_particles = hi_data["HI"].shape[0] metals = np.zeros((n_particles, 10), dtype=hi_data["HI"].dtype) @@ -116,9 +134,6 @@ def load_data(self): self.logger.info("Metals assigned to gas particles.") self.logger.info("Metals shape is: %s", self.data["gas"]["metals"].shape) - age_at_z0 = rubix_cosmo.age_at_z0() - self.data["stars"]["age"] = age_at_z0 * u.Gyr - self.data["stars"]["age"] - self.logger.info( f"Simulation snapshot and halo data loaded successfully for classes: {load_classes}." ) diff --git a/rubix/pipeline/linear_pipeline.py b/rubix/pipeline/linear_pipeline.py index 729c44d..bc6de13 100644 --- a/rubix/pipeline/linear_pipeline.py +++ b/rubix/pipeline/linear_pipeline.py @@ -178,7 +178,6 @@ def apply(self, *args, static_args=[], static_kwargs=[], **kwargs): ValueError _description_ """ - print("Arguments: ", *args) if len(args) == 0: raise ValueError("Cannot apply the pipeline to an empty list of arguments") diff --git a/rubix/utils.py b/rubix/utils.py index 07cf77d..66644dc 100644 --- a/rubix/utils.py +++ b/rubix/utils.py @@ -3,6 +3,7 @@ from typing import Dict, Union import h5py +import jax.numpy as jnp import yaml from astropy.cosmology import Planck15 as cosmo @@ -195,3 +196,26 @@ def load_galaxy_data(path_to_file: str): units[key][field] = f[f"particles/{key}/{field}"].attrs["unit"] return galaxy_data, units + + +def _pad_particles(inputdata, pad: int) -> "InputData": + """ + Pads the particle arrays in inputdata to make their length divisible by num_devices. + This is necessary for sharding to work correctly. + + Args: + inputdata (InputData): The input data containing particle arrays. + pad (int): The number of particles to pad. + + Returns: + InputData: The padded input data. + """ + + # pad along the first axis + inputdata.stars.coords = jnp.pad(inputdata.stars.coords, ((0, pad), (0, 0))) + inputdata.stars.velocity = jnp.pad(inputdata.stars.velocity, ((0, pad), (0, 0))) + inputdata.stars.mass = jnp.pad(inputdata.stars.mass, ((0, pad))) + inputdata.stars.age = jnp.pad(inputdata.stars.age, ((0, pad))) + inputdata.stars.metallicity = jnp.pad(inputdata.stars.metallicity, ((0, pad))) + + return inputdata diff --git a/tests/test_core_ifu.py b/tests/test_core_ifu.py index 4dd948f..21f67f7 100644 --- a/tests/test_core_ifu.py +++ b/tests/test_core_ifu.py @@ -1,17 +1,8 @@ -import jax import jax.numpy as jnp import numpy as np -from rubix.core.data import Galaxy, GasData, RubixData, StarsData, reshape_array -from rubix.core.ifu import ( - get_calculate_spectra, - get_doppler_shift_and_resampling, - get_resample_spectrum_pmap, - get_resample_spectrum_vmap, - get_scale_spectrum_by_mass, -) -from rubix.core.ssp import get_ssp -from rubix.spectra.ifu import resample_spectrum +from rubix.core.data import Galaxy, GasData, RubixData, StarsData +from rubix.core.ifu import get_calculate_datacube_particlewise, get_telescope RTOL = 1e-4 ATOL = 1e-6 @@ -29,7 +20,6 @@ print("Sample_inputs:") for key in sample_inputs: - sample_inputs[key] = reshape_array(sample_inputs[key]) print(f"Key: {key}, shape: {sample_inputs[key].shape}") @@ -78,235 +68,67 @@ def __init__(self, spectra): target_wavelength = jnp.array([4000.0, 5000.0, 6000.0]) -def _get_sample_inputs(subset=None): - ssp = get_ssp(sample_config) - """metallicity = reshape_array(ssp.metallicity) - age = reshape_array(ssp.age) - spectra = reshape_array(ssp.flux)""" - metallicity = ssp.metallicity - age = ssp.age - spectra = ssp.flux - - print("Metallicity shape: ", metallicity.shape) - print("Age shape: ", age.shape) - print("Spectra shape: ", spectra.shape) - print(".............") - - # Create meshgrid for metallicity and age to cover all combinations - metallicity_grid, age_grid = np.meshgrid( - metallicity.flatten(), age.flatten(), indexing="ij" - ) - metallicity_grid = jnp.asarray(metallicity_grid.flatten()) # Convert to jax.Array - age_grid = jnp.asarray(age_grid.flatten()) # Convert to jax.Array - metallicity_grid = reshape_array(metallicity_grid) - age_grid = reshape_array(age_grid) - metallicity_grid = jnp.array(metallicity_grid) - age_grid = jnp.array(age_grid) - print("Metallicity grid shape: ", metallicity_grid.shape) - print("Age grid shape: ", age_grid.shape) - - spectra = spectra.reshape(-1, spectra.shape[-1]) - print("spectra after reshape: ", spectra.shape) - spectra = reshape_array(spectra) - - print("spectra after reshape_array call: ", spectra.shape) - - # reshape spectra - num_combinations = metallicity_grid.shape[1] - spectra_reshaped = spectra.reshape( - spectra.shape[0], num_combinations, spectra.shape[-1] - ) - - # Create Velocities for each combination - - velocities = jnp.ones((metallicity_grid.shape[0], num_combinations, 3)) - mass = jnp.ones_like(metallicity_grid) - - if subset is not None: - metallicity_grid = metallicity_grid[:, :subset] - age_grid = age_grid[:, :subset] - velocities = velocities[:, :subset] - mass = mass[:, :subset] - spectra_reshaped = spectra_reshaped[:, :subset] - # inputs = dict( - # metallicity=metallicity_grid, age=age_grid, velocities=velocities, mass=mass - # ) - inputs = MockRubixData( - MockStarsData( - velocity=velocities, - metallicity=metallicity_grid, - mass=mass, - age=age_grid, - ), - MockGasData(spectra=None), - ) - return inputs, spectra_reshaped - - -def test_resample_spectrum_vmap(): - print("initial_spectra shape", initial_spectra.shape) - print("initial_wavelengths shape", initial_wavelengths.shape) - print("target_wavelength shape", target_wavelength.shape) - resample_spectrum_vmap = get_resample_spectrum_vmap(target_wavelength) - result_vmap = resample_spectrum_vmap(initial_spectra, initial_wavelengths) - - expected_result = jnp.stack( +def test_get_calculate_datacube_particlewise(): + # Setup config and telescope + config = { + "pipeline": {"name": "calc_ifu"}, + "logger": { + "log_level": "DEBUG", + "log_file_path": None, + "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s", + }, + "telescope": {"name": "MUSE"}, + "cosmology": {"name": "PLANCK15"}, + "galaxy": {"dist_z": 0.1}, + "ssp": {"template": {"name": "BruzualCharlot2003"}}, + } + telescope = get_telescope(config) + n_spaxels = int(telescope.sbin) + n_wave_tel = telescope.wave_seq.shape[0] + n_particles = 3 + + # Assign properties for n_particles + # Use valid values to avoid triggering issues in SSP lookup, resampling, etc. + metallicity = jnp.array([0.02, 0.01, 0.015]) + age = jnp.array([5.0, 8.0, 10.0]) + mass = jnp.array([1.0, 2.0, 0.5]) + velocity = jnp.array( [ - resample_spectrum( - initial_spectra[0], initial_wavelengths[0], target_wavelength - ), - resample_spectrum( - initial_spectra[1], initial_wavelengths[1], target_wavelength - ), + [100.0, 200.0, 300.0], + [0.0, 50.0, -100.0], + [1.0, 1.0, 1.0], ] ) - assert jnp.allclose(result_vmap, expected_result) - assert not jnp.any(jnp.isnan(result_vmap)) - - -def test_resample_spectrum_pmap(): - # For pmap we need to reshape, such that first axis is the device axis - initial_spectra = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - initial_wavelengths = jnp.array( - [[4500.0, 5500.0, 6500.0], [4500.0, 5500.0, 6500.0]] - ) - initial_spectra = reshape_array(initial_spectra) - initial_wavelengths = reshape_array(initial_wavelengths) - resample_spectrum_pmap = get_resample_spectrum_pmap(target_wavelength) - result_pmap = resample_spectrum_pmap(initial_spectra, initial_wavelengths) - - # Check how many GPUs are available, since this defines the shape of the result - if jax.device_count() > 1: - expected_result = jnp.array( - [ - resample_spectrum( - initial_spectra[0, 0], initial_wavelengths[0, 0], target_wavelength - ), - resample_spectrum( - initial_spectra[1, 0], initial_wavelengths[1, 0], target_wavelength - ), - ] - ) - expected_result = reshape_array(expected_result) - - else: - expected_result = jnp.stack( - [ - resample_spectrum( - initial_spectra[0, 0], initial_wavelengths[0, 0], target_wavelength - ), - resample_spectrum( - initial_spectra[0, 1], initial_wavelengths[0, 1], target_wavelength - ), - ] - ) - assert jnp.allclose(result_pmap, expected_result) - assert not jnp.any(jnp.isnan(result_pmap)) - - -def test_calculate_spectra(): - # Use an actual RubixData instance - mock_rubixdata = RubixData( - galaxy=Galaxy(), - stars=StarsData(), - gas=GasData(), - ) - - # Populate the RubixData object with mock data - mock_rubixdata.stars.coords = jnp.array([[1, 2, 3]]) - mock_rubixdata.stars.velocity = jnp.array([[4.0, 5.0, 6.0]]) - mock_rubixdata.stars.metallicity = jnp.array( - [[0.1]] - ) # 2D array for vmap compatibility - mock_rubixdata.stars.mass = jnp.array([[1000]]) # 2D array for vmap compatibility - mock_rubixdata.stars.age = jnp.array([[4.5]]) # 2D array for vmap compatibility - mock_rubixdata.galaxy.redshift = 0.1 - mock_rubixdata.galaxy.center = jnp.array([0, 0, 0]) - mock_rubixdata.galaxy.halfmassrad_stars = 1 - - # Obtain the calculate_spectra function - calculate_spectra = get_calculate_spectra(sample_config) - - # Mock expected spectra - expected_spectra_shape = (1, 1, 842) # Adjust shape as per your data - expected_spectra = jnp.zeros(expected_spectra_shape) - - # Call the calculate_spectra function - result = calculate_spectra(mock_rubixdata) - - # Validate the result - calculated_spectra = result.stars.spectra - - assert calculated_spectra.shape == expected_spectra.shape, "Shape mismatch" - assert jnp.allclose( - calculated_spectra, expected_spectra, rtol=RTOL, atol=ATOL - ), "Spectra values mismatch" - assert not jnp.any( - jnp.isnan(calculated_spectra) - ), "NaN values in calculated spectra" - - -def test_scale_spectrum_by_mass(): - # Use an actual RubixData instance - input = RubixData( - galaxy=Galaxy(), - stars=StarsData( - velocity=sample_inputs["velocities"], - metallicity=sample_inputs["metallicity"], - mass=sample_inputs["mass"], - age=sample_inputs["age"], - spectra=sample_inputs["spectra"], - ), - gas=GasData(spectra=None), - ) - - # Calculate expected spectra - expected_spectra = input.stars.spectra * jnp.expand_dims(input.stars.mass, axis=-1) - - # Call the function - scale_spectrum_by_mass = get_scale_spectrum_by_mass(sample_config) - result = scale_spectrum_by_mass(input) - - # Print for debugging - print("Input Mass:", input.stars.mass) - print("Input Spectra:", input.stars.spectra) - print("Result Spectra:", result.stars.spectra) - print("Expected Spectra:", expected_spectra) - - # Assertions - assert jnp.array_equal( - result.stars.spectra, expected_spectra - ), "Spectra scaling mismatch" - assert not jnp.any( - jnp.isnan(result.stars.spectra) - ), "NaN values found in result spectra" - - -def test_doppler_shift_and_resampling(): - # Obtain the function - doppler_shift_and_resampling = get_doppler_shift_and_resampling(sample_config) - - # Create an actual RubixData object - inputs = RubixData( - galaxy=Galaxy(), # Create a Galaxy instance as required - stars=StarsData( - velocity=sample_inputs["velocities"], - metallicity=sample_inputs["metallicity"], - mass=sample_inputs["mass"], - age=sample_inputs["age"], - spectra=sample_inputs["spectra"], # Assign expected spectra - ), - gas=GasData(spectra=None), - ) - - # Mock expected spectra - expected_spectra = sample_inputs["spectra"] - - # Call the function - result = doppler_shift_and_resampling(inputs) - - # Assertions - assert hasattr(result.stars, "spectra"), "Result does not have 'spectra'" - assert not jnp.any( - jnp.isnan(result.stars.spectra) - ), "NaN values found in result spectra" + # Assign each particle to a unique spaxel + pixel_assignment = jnp.array([0, 1, n_spaxels**2 - 1], dtype=jnp.int32) + + # Build the StarsData and RubixData object + stars = StarsData() + stars.metallicity = metallicity + stars.age = age + stars.mass = mass + stars.velocity = velocity + stars.pixel_assignment = pixel_assignment + + rubixdata = RubixData(galaxy=Galaxy(), stars=stars, gas=GasData()) + + # Run the particlewise datacube calculation + calc_datacube_particlewise = get_calculate_datacube_particlewise(config) + result = calc_datacube_particlewise(rubixdata) + + # Check output + assert hasattr(result.stars, "datacube") + assert result.stars.datacube.shape == (n_spaxels, n_spaxels, n_wave_tel) + # The cube must be non-negative and not NaN + assert jnp.all(result.stars.datacube >= 0) + assert not jnp.isnan(result.stars.datacube).any() + # Each particle's contribution must end up in the correct spaxel + # For a full test, you could do a partial "rebuild" as in your get_calculate_datacube test: + flat_cube = result.stars.datacube.reshape(-1, n_wave_tel) + # The nonzero spaxels should not be all zero (quick sanity check) + for pix in pixel_assignment: + assert jnp.any(flat_cube[pix] != 0) + # All spaxels not assigned should be exactly zero + mask = jnp.ones((n_spaxels**2,), dtype=bool) + mask = mask.at[pixel_assignment].set(False) + assert jnp.all(flat_cube[mask] == 0) diff --git a/tests/test_core_pipeline.py b/tests/test_core_pipeline.py index 1f6430a..06f0c13 100644 --- a/tests/test_core_pipeline.py +++ b/tests/test_core_pipeline.py @@ -1,9 +1,11 @@ import os # noqa from unittest.mock import MagicMock, patch +import jax import jax.numpy as jnp import pytest +from rubix.core.data import Galaxy, GasData, RubixData, StarsData from rubix.core.pipeline import RubixPipeline from rubix.spectra.ssp.grid import SSPGrid from rubix.telescope.base import BaseTelescope @@ -85,77 +87,43 @@ def test_rubix_pipeline_not_implemented(setup_environment): pipeline = RubixPipeline(user_config=config) # noqa -""" -def test_rubix_pipeline_gradient_not_implemented(setup_environment): - pipeline = RubixPipeline(user_config=user_config) - with pytest.raises( - NotImplementedError, match="Gradient calculation is not implemented yet" - ): - pipeline.gradient() -""" - - -def test_rubix_pipeline_gradient_not_implemented(setup_environment): - mock_rubix_data = MagicMock() - mock_rubix_data.stars.coords = jnp.array([[0, 0, 0]]) - mock_rubix_data.stars.velocities = jnp.array([[0, 0, 0]]) - mock_rubix_data.stars.metallicity = jnp.array([0.1]) - mock_rubix_data.stars.mass = jnp.array([1.0]) - mock_rubix_data.stars.age = jnp.array([1.0]) - - with patch("rubix.core.pipeline.get_rubix_data", return_value=mock_rubix_data): - pipeline = RubixPipeline(user_config=user_config) - with pytest.raises( - NotImplementedError, match="Gradient calculation is not implemented yet" - ): - pipeline.gradient() - - -def test_rubix_pipeline_run(): - pipeline = RubixPipeline(user_config=user_config) - output = pipeline.run() - - # Check if output is as expected - assert hasattr(output.stars, "coords") - assert hasattr(output.stars, "velocity") - assert hasattr(output.stars, "metallicity") - assert hasattr(output.stars, "mass") - assert hasattr(output.stars, "age") - assert hasattr(output.stars, "spectra") - - assert isinstance(pipeline.telescope, BaseTelescope) - assert isinstance(pipeline.ssp, SSPGrid) - - spectrum = output.stars.spectra - print("Spectrum shape: ", spectrum.shape) - print("Spectrum sum: ", jnp.sum(spectrum, axis=-1)) - - # Check if spectrum contains any nan values - # Only count the numby of NaN values in the spectra - is_nan = jnp.isnan(spectrum) - # check whether there are any NaN values in the spectra - - indices_nan = jnp.where(is_nan) - - # Get only the unique index of the spectra with NaN values - unique_spectra_indices = jnp.unique(indices_nan[-1]) - print("Unique indices of spectra with NaN values: ", unique_spectra_indices) - print( - "Masses of the spectra with NaN values: ", - output.stars.mass[unique_spectra_indices], - ) - print( - "Ages of the spectra with NaN values: ", - output.stars.age[unique_spectra_indices], - ) - print( - "Metallicities of the spectra with NaN values: ", - output.stars.metallicity[unique_spectra_indices], +def test_rubix_pipeline_run_sharded(): + # Use the number of devices to set up data that can be sharded + devices = jax.devices() + num_devices = len(jax.devices()) + n_particles = num_devices if num_devices > 1 else 2 # At least two for sanity + + # Mock input data + input_data = RubixData( + galaxy=Galaxy( + redshift=jnp.array([0.1]), + center=jnp.zeros((1, 3)), + halfmassrad_stars=jnp.array([1.0]), + ), + stars=StarsData( + coords=jnp.arange(n_particles * 3, dtype=jnp.float32).reshape( + n_particles, 3 + ), + velocity=jnp.arange(n_particles * 3, dtype=jnp.float32).reshape( + n_particles, 3 + ), + metallicity=jnp.linspace(0.01, 0.03, n_particles), + mass=jnp.ones(n_particles), + age=jnp.linspace(2.0, 10.0, n_particles), + pixel_assignment=jnp.arange(n_particles, dtype=jnp.int32), + ), + gas=GasData(velocity=None), ) - ssp = pipeline.ssp - print("SSP bounds age:", ssp.age.min(), ssp.age.max()) - print("SSP bounds metallicity:", ssp.metallicity.min(), ssp.metallicity.max()) - - # assert that the spectra does not contain any NaN values - assert not jnp.isnan(spectrum).any() + pipeline = RubixPipeline(user_config=user_config) + output_cube = pipeline.run_sharded(input_data, devices) + + # Output should be a jax array (the datacube) + assert isinstance(output_cube, jax.Array) + # Should have 3 dimensions (n_spaxels, n_spaxels, n_wave_tel) + assert output_cube.ndim == 3 + # Should be non-negative and not NaN + assert jnp.all(output_cube >= 0) + assert not jnp.isnan(output_cube).any() + # The cube should have nonzero values (sanity check) + assert jnp.any(output_cube != 0) diff --git a/tests/test_pynbody_handler.py b/tests/test_pynbody_handler.py index b865681..74a4f4f 100644 --- a/tests/test_pynbody_handler.py +++ b/tests/test_pynbody_handler.py @@ -97,16 +97,24 @@ def dm_getitem(key): @pytest.fixture def handler_with_mock_data(mock_simulation, mock_config): - with patch("pynbody.load", return_value=mock_simulation): - with patch("pynbody.analysis.angmom.faceon", return_value=None): - handler = PynbodyHandler( - path="mock_path", - halo_path="mock_halo_path", - config=mock_config, - dist_z=mock_config["galaxy"]["dist_z"], - halo_id=1, - ) - return handler + with ( + patch("pynbody.load", return_value=mock_simulation), + patch("pynbody.analysis.angmom.faceon", return_value=None), + patch( + "pynbody.analysis.angmom.ang_mom_vec", + return_value=np.array([0.0, 0.0, 1.0]), + ), + patch("pynbody.analysis.angmom.calc_sideon_matrix", return_value=np.eye(3)), + ): + + handler = PynbodyHandler( + path="mock_path", + halo_path="mock_halo_path", + config=mock_config, + dist_z=mock_config["galaxy"]["dist_z"], + halo_id=1, + ) + return handler def test_pynbody_handler_initialization(handler_with_mock_data): diff --git a/tests/test_ssp_fsps.py b/tests/test_ssp_fsps.py index a9dae47..42219b7 100644 --- a/tests/test_ssp_fsps.py +++ b/tests/test_ssp_fsps.py @@ -61,9 +61,7 @@ def test_retrieve_ssp_data_from_fsps(): assert isinstance(result, SSPGrid) assert np.allclose(result.metallicity, np.log10(mock_sp_instance.zlegend)) assert np.allclose(result.age, mock_sp_instance.log_age - 9.0) - assert np.allclose( - result.wavelength, np.array([4000, 4100, 4200]) - 50 - ) # because wavelengths are shifted by the calculated offset in the mock to be centered + assert np.allclose(result.wavelength, np.array([3950, 4050, 4150])) assert np.allclose( result.flux, np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]), @@ -86,9 +84,7 @@ def test_retrieve_ssp_data_from_fsps_with_kwargs(): assert isinstance(result, SSPGrid) assert np.allclose(result.metallicity, np.log10(mock_sp_instance.zlegend)) assert np.allclose(result.age, mock_sp_instance.log_age - 9.0) - assert np.allclose( - result.wavelength, np.array([4000, 4100, 4200]) - 50 - ) # because wavelengths are shifted by 50 in the mock to be centered + assert np.allclose(result.wavelength, np.array([3950, 4050, 4150])) assert np.allclose( result.flux, np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]),