diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6408c94..82d3d70 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ repos: - id: check-merge-conflict # Check for files that contain merge conflict strings. - id: check-added-large-files # Prevent large files from being added to the repository. - repo: https://github.com/psf/black-pre-commit-mirror - rev: 25.11.0 + rev: 25.12.0 hooks: - id: black - repo: https://github.com/pycqa/isort diff --git a/README.md b/README.md index b5ea3a7..8e91457 100644 --- a/README.md +++ b/README.md @@ -25,14 +25,16 @@ Key features include: ## Installation -The Python package `rubix` can be downloades from git and can be installed: +The Python package `rubix` is published on GitHub and can be installed alongside its runtime dependencies (including JAX) by choosing the relevant extras. For a CPU-only environment, install with: ``` git clone https://github.com/AstroAI-Lab/rubix.git cd rubix -pip install . +pip install .[cpu] ``` +If you need GPU acceleration, replace `[cpu]` with `[cuda]` (or install `jax[cuda]` following the [JAX instructions](https://github.com/google/jax#installation) before installing Rubix). The plain `pip install .` command installs the minimal package without JAX and will raise `ImportError` if you try to import `rubix` before adding `jax` manually. + ## Development installation If you want to contribute to the development of `rubix`, we recommend @@ -41,7 +43,7 @@ the following editable installation from this repository: ``` git clone https://github.com/AstroAI-Lab/rubix.git cd rubix -python -m pip install --editable .[tests] +python -m pip install --editable .[cpu,tests,dev] ``` Having done so, the test suite can be run using `pytest`: @@ -50,9 +52,21 @@ Having done so, the test suite can be run using `pytest`: python -m pytest ``` -This project depends on [jax](https://github.com/google/jax). It only installed for cpu computations with the testing dependencies. For installation instructions with gpu support, -please refer to [here](https://github.com/google/jax?tab=readme-ov-file#installation). +This project depends on [jax](https://github.com/google/jax). For the pytests we only test the `cpu` version. +For installation instructions with gpu support, +please refer to [here](https://github.com/google/jax?tab=readme-ov-file#installation) or simply use the `cuda` option when pip installing. + +## Configuration overview + +Rubix ships with two YAML files in `rubix/config/`: `rubix_config.yml` (constants, SSP templates, dust recipes, handler mappings, etc.) and `pipeline_config.yml` (pipeline graphs such as `calc_ifu` and `calc_dusty_ifu`). There is no configuration wizard — your runtime settings must supply a dictionary with the following blocks: + +- `pipeline.name`: Identifies the pipeline from `pipeline_config.yml` (e.g., `calc_ifu`, `calc_dusty_ifu`, or `calc_gradient`). +- `galaxy`: Must provide `dist_z` and a `rotation` section (`type` or explicit `alpha`, `beta`, `gamma`). +- `telescope`: Requires `name`, `psf` (currently only the `gaussian` kernel with `size` and `sigma`), `lsf` (`sigma`), and `noise` (`signal_to_noise` plus `noise_distribution`, choose from `normal` or `uniform`). +- `ssp.dust`: Must declare `extinction_model` and `Rv` before calling the dusty pipeline (see `rubix/spectra/dust/extinction_models.py` for the supported models such as `Cardelli89`). +- `data.args.particle_type`: Should include `"stars"` (and `"gas"` if you want the gas branch) so the filters and rotation functions know which components exist. +The tutorials and notebooks assume square spaxels, so the default telescope factory currently only supports `pixel_type: square`. For a working example, inspect `notebooks/rubix_pipeline_single_function_shard_map.ipynb`, which runs the exact pipeline used in the tests. ## Documentation Sphinx Documentation of all the functions is currently available under [this link](https://astro-rubix.web.app/). diff --git a/docs/installation.rst b/docs/installation.rst index 5669a7a..1483c67 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -6,7 +6,7 @@ Installation Clone the repository and navigate to the root directory of the repository. Then run ``` -pip install . +pip install .[cpu] ``` If you want to contribute to the development of `RUBIX`, we recommend the following editable installation from this repository: @@ -14,7 +14,7 @@ If you want to contribute to the development of `RUBIX`, we recommend the follow ``` git clone https://github.com/AstroAI-Lab/rubix cd rubix -pip install -e . +pip install -e .[cpu,tests,dev] ``` Having done so, the test suit can be run unsing `pytest`: @@ -22,8 +22,21 @@ Having done so, the test suit can be run unsing `pytest`: python -m pytest ``` -Note that if `JAX` is not yet installed, only the CPU version of `JAX` will be installed +Note that if `JAX` is not yet installed, with the `cpu` option only the CPU version of `JAX` will be installed as a dependency. For a GPU-compatible installation of `JAX`, please refer to the -[JAX installation guide](https://jax.readthedocs.io/en/latest/installation.html). +[JAX installation guide](https://jax.readthedocs.io/en/latest/installation.html) or use the option `cuda`. -Get started with this simple example notebooks/rubix_pipeline_single_function.ipynb. +Get started with this simple example notebooks/rubix_pipeline_single_function_shard_map.ipynb. + +Configuration +============= + +When you run the pipeline you provide a configuration dict that references the files in `rubix/config/`. The following sections are required for the default pipelines: + +- `pipeline.name`: Choose one of `calc_ifu`, `calc_dusty_ifu`, or another entry from `pipeline_config.yml`. +- `galaxy`: Must include `dist_z` and a `rotation` block (`type` or explicit `alpha`, `beta`, `gamma`). +- `telescope`: Needs `name`, a `psf` block (Gaussian kernel with both `size` and `sigma`), an `lsf` block with `sigma`, and `noise` containing `signal_to_noise` plus a `noise_distribution` (`normal` or `uniform`). +- `ssp.dust`: Declares `extinction_model` and `Rv` before the dusty pipeline can produce an extincted datacube. +- `data.args.particle_type`: Must include `"stars"` (add `"gas"` if you rely on the optional gas branch) so the filtering/rotation steps know which components to process. + +The telescopes in `rubix/telescope` currently only support square pixels, so every config should set `pixel_type: square` in the relevant telescope definition. diff --git a/notebooks/cosmology.ipynb b/notebooks/cosmology.ipynb index e956c1b..7d97377 100644 --- a/notebooks/cosmology.ipynb +++ b/notebooks/cosmology.ipynb @@ -75,8 +75,8 @@ "from rubix.cosmology.utils import trapz\n", "import jax.numpy as jnp\n", "\n", - "x = jnp.array([0, 1, 2, 3])\n", - "y = jnp.array([0, 1, 4, 9])\n", + "x = jnp.array([0.0, 1.0, 2.0, 3.0])\n", + "y = jnp.array([0.0, 1.0, 4.0, 9.0])\n", "print(trapz(x, y))" ] }, @@ -102,7 +102,7 @@ ], "metadata": { "kernelspec": { - "display_name": "rubix", + "display_name": "publishrubix", "language": "python", "name": "python3" }, @@ -116,7 +116,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.12.12" } }, "nbformat": 4, diff --git a/notebooks/gradient_age_metallicity_variational_inference.ipynb b/notebooks/gradient_age_metallicity_variational_inference.ipynb deleted file mode 100644 index 3bcc852..0000000 --- a/notebooks/gradient_age_metallicity_variational_inference.ipynb +++ /dev/null @@ -1,643 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "from jax import config\n", - "#config.update(\"jax_enable_x64\", True)\n", - "#config.update('jax_num_cpu_devices', 2)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#NBVAL_SKIP\n", - "import os\n", - "\n", - "# Tell XLA to fake 2 host CPU devices\n", - "#os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=3'\n", - "\n", - "# Only make GPU 0 and GPU 1 visible to JAX:\n", - "#os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'\n", - "\n", - "#os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n", - "\n", - "import jax\n", - "\n", - "# Now JAX will list two CpuDevice entries\n", - "print(jax.devices())\n", - "# → [CpuDevice(id=0), CpuDevice(id=1)]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "import os\n", - "#os.environ['SPS_HOME'] = '/mnt/storage/annalena_data/sps_fsps'\n", - "#os.environ['SPS_HOME'] = '/home/annalena/sps_fsps'\n", - "os.environ['SPS_HOME'] = '/Users/annalena/Documents/GitHub/fsps'\n", - "#os.environ['SPS_HOME'] = '/export/home/aschaibl/fsps'" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Load ssp template from FSPS" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "from rubix.spectra.ssp.factory import get_ssp_template\n", - "ssp_fsps = get_ssp_template(\"FSPS\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "age_values = ssp_fsps.age\n", - "print(age_values.shape)\n", - "\n", - "metallicity_values = ssp_fsps.metallicity\n", - "print(metallicity_values.shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Configure pipeline" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "from rubix.core.pipeline import RubixPipeline\n", - "import os\n", - "config = {\n", - " \"pipeline\":{\"name\": \"calc_gradient\",},\n", - " \n", - " \"logger\": {\n", - " \"log_level\": \"DEBUG\",\n", - " \"log_file_path\": None,\n", - " \"format\": \"%(asctime)s - %(name)s - %(levelname)s - %(message)s\",\n", - " },\n", - " \"data\": {\n", - " \"name\": \"IllustrisAPI\",\n", - " \"args\": {\n", - " \"api_key\": os.environ.get(\"ILLUSTRIS_API_KEY\"),\n", - " \"particle_type\": [\"stars\"],\n", - " \"simulation\": \"TNG50-1\",\n", - " \"snapshot\": 99,\n", - " \"save_data_path\": \"data\",\n", - " },\n", - " \n", - " \"load_galaxy_args\": {\n", - " \"id\": 14,\n", - " \"reuse\": True,\n", - " },\n", - " \n", - " \"subset\": {\n", - " \"use_subset\": True,\n", - " \"subset_size\": 2,\n", - " },\n", - " },\n", - " \"simulation\": {\n", - " \"name\": \"IllustrisTNG\",\n", - " \"args\": {\n", - " \"path\": \"data/galaxy-id-14.hdf5\",\n", - " },\n", - " \n", - " },\n", - " \"output_path\": \"output\",\n", - "\n", - " \"telescope\":\n", - " {\"name\": \"TESTGRADIENT\",\n", - " \"psf\": {\"name\": \"gaussian\", \"size\": 5, \"sigma\": 0.6},\n", - " \"lsf\": {\"sigma\": 1.2},\n", - " \"noise\": {\"signal_to_noise\": 100,\"noise_distribution\": \"normal\"},\n", - " },\n", - " \"cosmology\":\n", - " {\"name\": \"PLANCK15\"},\n", - " \n", - " \"galaxy\":\n", - " {\"dist_z\": 0.1,\n", - " \"rotation\": {\"type\": \"edge-on\"},\n", - " },\n", - " \n", - " \"ssp\": {\n", - " \"template\": {\n", - " \"name\": \"FSPS\"\n", - " },\n", - " \"dust\": {\n", - " \"extinction_model\": \"Cardelli89\",\n", - " \"dust_to_gas_ratio\": 0.01,\n", - " \"dust_to_metals_ratio\": 0.4,\n", - " \"dust_grain_density\": 3.5,\n", - " \"Rv\": 3.1,\n", - " },\n", - " }, \n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "pipe = RubixPipeline(config)\n", - "inputdata = pipe.prepare_data()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Gradient on the spectrum for each wavelenght" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "from rubix.pipeline import linear_pipeline as pipeline\n", - "\n", - "pipeline_instance = RubixPipeline(config)\n", - "\n", - "pipeline_instance._pipeline = pipeline.LinearTransformerPipeline(\n", - " pipeline_instance.pipeline_config, \n", - " pipeline_instance._get_pipeline_functions()\n", - ")\n", - "pipeline_instance._pipeline.assemble()\n", - "pipeline_instance.func = pipeline_instance._pipeline.compile_expression()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "# pick values\n", - "initial_age_index = 95\n", - "initial_metallicity_index = 4\n", - "age0 = age_values[initial_age_index]\n", - "Z0 = metallicity_values[initial_metallicity_index]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "print(f\"age0 = {age0}, Z0 = {Z0}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "import jax.numpy as jnp\n", - "\n", - "inputdata.stars.age = jnp.array([age_values[initial_age_index], age_values[initial_age_index]])\n", - "inputdata.stars.metallicity = jnp.array([metallicity_values[initial_metallicity_index], metallicity_values[initial_metallicity_index]])\n", - "inputdata.stars.mass = jnp.array([[1.0], [1.0]])\n", - "inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])\n", - "inputdata.stars.coords = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "import dataclasses\n", - "import jax\n", - "import jax.numpy as jnp\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "\n", - "def spectrum_1d(age, Z, base_data, pipeline_instance):\n", - " # broadcast per-star\n", - " nstar = base_data.stars.age.shape[0]\n", - " stars2 = dataclasses.replace(\n", - " base_data.stars,\n", - " age=jnp.full((nstar,), age),\n", - " metallicity=jnp.full((nstar,), Z),\n", - " )\n", - " data2 = dataclasses.replace(base_data, stars=stars2)\n", - "\n", - " out = pipeline_instance.func(data2)\n", - "\n", - " cube = out.stars.datacube # shape (…, n_lambda)\n", - " # collapse all non-wavelength axes, keep wavelength last\n", - " spec = cube.reshape((-1, cube.shape[-1])).sum(axis=0)\n", - "\n", - " return jnp.ravel(spec) " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "spec0 = spectrum_1d(age0, Z0, inputdata, pipeline_instance)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "import matplotlib.pyplot as plt\n", - "wave = pipe.telescope.wave_seq" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "from tensorflow_probability.substrates import jax as tfp\n", - "tfd = tfp.distributions\n", - "tfb = tfp.bijectors\n", - "\n", - "import tqdm\n", - "import optax\n", - "import flax.linen as nn\n", - "from flax.metrics import tensorboard" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "class AffineCoupling(nn.Module):\n", - " @nn.compact\n", - " def __call__(self, x, nunits):\n", - " net = nn.leaky_relu(nn.Dense(128)(x))\n", - " net = nn.leaky_relu(nn.Dense(128)(net))\n", - " shift = nn.Dense(nunits)(net)\n", - " scale = nn.softplus(nn.Dense(nunits)(net))\n", - " return tfb.Chain([ tfb.Shift(shift), tfb.Scale(scale)])\n", - "\n", - "def make_nvp_fn(n_layers=2, d=2):\n", - " # We alternate between permutations and flow layers\n", - " layers = [ tfb.Permute([1,0])(tfb.RealNVP(d//2,\n", - " bijector_fn=AffineCoupling(name='affine%d'%i)))\n", - " for i in range(n_layers) ]\n", - "\n", - " # We build the actual nvp from these bijectors and a standard Gaussian distribution\n", - " nvp = tfd.TransformedDistribution(\n", - " tfd.MultivariateNormalDiag(loc=jnp.zeros(2), scale_diag=0.05*jnp.ones(2)),\n", - " bijector=tfb.Chain([tfb.Shift([5,0.05])] + layers ))\n", - " # Note that we have here added a shift to the bijector\n", - " return nvp\n", - "\n", - "class NeuralSplineFlowSampler(nn.Module):\n", - " @nn.compact\n", - " def __call__(self, key, n_samples):\n", - " nvp = make_nvp_fn()\n", - " x = nvp.sample(n_samples, seed=key)\n", - " return x, nvp.log_prob(x)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "model = NeuralSplineFlowSampler()\n", - "params = model.init(jax.random.PRNGKey(42), jax.random.PRNGKey(1), 16)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "import pandas as pd\n", - "from chainconsumer import ChainConsumer, Chain, Truth\n", - "\n", - "# 1) Draw samples from the untrained bounded flow\n", - "theta0, logq0 = model.apply(params, key=jax.random.PRNGKey(1), n_samples=500)\n", - "df = pd.DataFrame(theta0, columns=[\"age\", \"Z\"])\n", - "\n", - "# 2) Optional: pick a fiducial point (for synthetic tests use your known truth)\n", - "fid_age = age0 # example: mid of [0, 20]\n", - "fid_Z = Z0 # example: inside [4.5e-5, 4.5e-2]\n", - "\n", - "# 3) Build the ChainConsumer plot\n", - "c = ChainConsumer()\n", - "c.add_chain(Chain(samples=df, name=\"Initial VI\"))\n", - "c.add_truth(Truth(location={\"age\": fid_age, \"Z\": fid_Z}))\n", - "\n", - "fig = c.plotter.plot(figsize=\"column\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "def log_prior_gaussian(theta_batch,\n", - " mu_age=6.0, sigma_age=3.0,\n", - " mu_Z=1.3e-3, sigma_Z=2e-4):\n", - " \"\"\"Gaussian prior in physical space.\"\"\"\n", - " age = theta_batch[:, 0]\n", - " Z = theta_batch[:, 1]\n", - " lp_age = -0.5 * (((age - mu_age) / sigma_age)**2\n", - " + jnp.log(2*jnp.pi*sigma_age**2))\n", - " lp_Z = -0.5 * (((Z - mu_Z) / sigma_Z)**2\n", - " + jnp.log(2*jnp.pi*sigma_Z**2))\n", - " return lp_age + lp_Z # shape (batch,)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "import jax, jax.numpy as jnp\n", - "\n", - "def log_likelihood(y, s, mask=None):\n", - " \"\"\"Full-vector Gaussian log-likelihood.\"\"\"\n", - " if mask is None:\n", - " mask = jnp.ones_like(y)\n", - " r = y - s\n", - " term = (r**2)\n", - " return jnp.sum(term * mask)\n", - "\n", - "def make_batched_loglik(y, base_data, pipeline_instance, mask=None):\n", - " \"\"\"Returns a function mapping a batch of theta -> per-sample log-likelihood.\"\"\"\n", - " def one_theta(theta):\n", - " age, Z = theta[0], theta[1]\n", - " s = spectrum_1d(age, Z, base_data, pipeline_instance) # -> (n_lambda,)\n", - " return log_likelihood(y, s, mask=mask, )\n", - " return jax.vmap(one_theta) # (batch,2) -> (batch,)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "def make_elbo_fn(y, base_data, pipeline_instance,\n", - " mask=None, \n", - " mu_age=7.0, sigma_age=2.0,\n", - " mu_Z=0.001, sigma_Z=1e-3):\n", - " batched_loglik = make_batched_loglik(y, base_data,\n", - " pipeline_instance, mask)\n", - "\n", - " def elbo(params, seed, n_samples=128):\n", - " # Draw θ ~ q_φ(θ)\n", - " theta_batch, log_q = model.apply(params, key=seed, n_samples=n_samples)\n", - " # Compute log p(θ)\n", - " log_p = log_prior_gaussian(theta_batch, mu_age, sigma_age, mu_Z, sigma_Z)\n", - " # Compute log p(y|θ)\n", - " log_lik = batched_loglik(theta_batch)\n", - " # ELBO\n", - " elbo_value = jnp.mean(log_lik + log_p - log_q)\n", - " return -elbo_value # minimize\n", - " return elbo" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "# Random key\n", - "seed = jax.random.PRNGKey(0)\n", - "\n", - "# Scheduler and optimizer\n", - "total_steps = 20_000\n", - "lr = 2e-3\n", - "# lr_scheduler = optax.piecewise_constant_schedule(\n", - "# init_value=1e-3,\n", - "# boundaries_and_scales={int(total_steps*0.5): 0.2}\n", - "# )\n", - "optimizer = optax.adam(lr) #lr_scheduler)\n", - "opt_state = optimizer.init(params)\n", - "\n", - "# TensorBoard logs\n", - "from flax.metrics import tensorboard\n", - "summary_writer = tensorboard.SummaryWriter(\"logs/elbo\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "eps = 1e-6\n", - "sigma_obs = jnp.maximum(jnp.abs(spec0) / 1000.0, eps)\n", - "y = spec0\n", - "base_data = inputdata" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "# Build once, outside update_model\n", - "elbo = make_elbo_fn(\n", - " y, # observed full flux vector\n", - " base_data,\n", - " pipeline_instance, \n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "@jax.jit\n", - "def update_model(params, opt_state, seed):#, n_samples=128):\n", - " # split RNG: first return is new seed you’ll keep, second is used to sample θ\n", - " seed, key = jax.random.split(seed)\n", - "\n", - " # loss(params) = -ELBO(params, key, n_samples)\n", - " loss, grads = jax.value_and_grad(elbo)(params, key)#, n_samples)\n", - "\n", - " # apply Adam step; passing params is safest for transforms that need them\n", - " updates, opt_state = optimizer.update(grads, opt_state, params)\n", - " params = optax.apply_updates(params, updates)\n", - "\n", - " return params, opt_state, loss, seed" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "import tqdm\n", - "\n", - "losses = []\n", - "\n", - "for i in tqdm.tqdm(range(total_steps)):\n", - " # one optimization step (minimizes -ELBO)\n", - " params, opt_state, loss, seed = update_model(params, opt_state, seed)\n", - "\n", - " losses.append(float(loss))\n", - "\n", - " # log every 10 steps\n", - " if i % 10 == 0:\n", - " summary_writer.scalar(\"neg_elbo\", float(loss), i)\n", - " #summary_writer.scalar(\"learning_rate\", float(lr_scheduler(i)), i)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "# 1) Sample posterior θ = (age, Z)\n", - "seed, sub = jax.random.split(seed)\n", - "theta, log_q = model.apply(params, key=sub, n_samples=5000) # theta.shape == (5000, 2)\n", - "age = theta[:, 0]\n", - "Z = theta[:, 1]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "c = ChainConsumer()\n", - "\n", - "# fresh RNG split so we don’t reuse training key\n", - "seed, sub = jax.random.split(seed)\n", - "\n", - "# sample θ ~ qϕ(θ)\n", - "theta, log_q = model.apply(params, key=sub, n_samples=20_000) # shape (N, 2)\n", - "age = theta[:, 0]\n", - "Z = theta[:, 1]\n", - "\n", - "# ChainConsumer expects a pandas DataFrame\n", - "df = pd.DataFrame({\"age\": age, \"Z\": Z})\n", - "\n", - "# add the VI chain\n", - "c.add_chain(Chain(samples=df, name=\"VI\"))\n", - "\n", - "# optional “truth” dot: use known synthetic truth if you have it; else posterior mean\n", - "# truth_age, truth_Z = 8.0, 1.0e-2 # <- set these if you know them\n", - "# truth_age, truth_Z = float(age.mean()), float(Z.mean())\n", - "truth_age, truth_Z = age0, Z0\n", - "c.add_truth(Truth(location={\"age\": truth_age, \"Z\": truth_Z}))\n", - "\n", - "fig = c.plotter.plot(figsize=\"column\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "plt.figure(figsize=(7,3))\n", - "plt.plot(np.arange(len(losses)), losses, lw=1)\n", - "plt.xlabel(\"Iteration\")\n", - "plt.ylabel(\"Loss\")\n", - "plt.grid(True)\n", - "plt.tight_layout()\n", - "plt.show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "rubix", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.8" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/pyproject.toml b/pyproject.toml index a53894a..49b9347 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ dependencies = [ "ipywidgets", "jdaviz", "pynbody", + "optax", "opt-einsum >=3.3.0", ] [project.optional-dependencies] @@ -91,7 +92,6 @@ tests = [ "pytest-mock", "requests-mock", "nbval", - "jax[cpu]>0.5.1", "pre-commit", ] docs = [ diff --git a/rubix/core/cosmology.py b/rubix/core/cosmology.py index 49f7399..193f6d5 100644 --- a/rubix/core/cosmology.py +++ b/rubix/core/cosmology.py @@ -21,7 +21,7 @@ def get_cosmology(config: dict) -> RubixCosmology: ValueError: When ``config["cosmology"]["name"]`` is not supported. Example: - :: + >>> config = { ... ... ... "cosmology": diff --git a/rubix/core/data.py b/rubix/core/data.py index 0d4dbf2..ebc55fe 100644 --- a/rubix/core/data.py +++ b/rubix/core/data.py @@ -2,12 +2,12 @@ import os from dataclasses import dataclass from functools import partial -from typing import Any, Callable, Optional, Union import jax import jax.numpy as jnp import numpy as np from beartype import beartype as typechecker +from beartype.typing import Any, Callable, Optional, Union from jaxtyping import jaxtyped from rubix.galaxy import IllustrisAPI, get_input_handler @@ -265,7 +265,7 @@ def convert_to_rubix(config: Union[dict, str]): ValueError: When ``config['data']['name']`` is unsupported. Example: - :: + >>> import os >>> from rubix.core.data import convert_to_rubix @@ -397,7 +397,7 @@ def prepare_input(config: Union[dict, str]) -> RubixData: ValueError: When subset mode is enabled but neither stars nor gas coordinates exist. Example: - :: + >>> import os >>> from rubix.core.data import convert_to_rubix, prepare_input @@ -430,7 +430,7 @@ def prepare_input(config: Union[dict, str]) -> RubixData: # Set the galaxy attributes rubixdata.galaxy.redshift = jnp.float64(data["redshift"]) rubixdata.galaxy.redshift_unit = units["galaxy"]["redshift"] - rubixdata.galaxy.center = jnp.array(data["subhalo_center"], dtype=jnp.float64) + rubixdata.galaxy.center = jnp.array(data["subhalo_center"], dtype=jnp.float32) rubixdata.galaxy.center_unit = units["galaxy"]["center"] rubixdata.galaxy.halfmassrad_stars = jnp.float64(data["subhalo_halfmassrad_stars"]) rubixdata.galaxy.halfmassrad_stars_unit = units["galaxy"]["halfmassrad_stars"] @@ -550,7 +550,7 @@ def get_reshape_data(config: Union[dict, str]) -> Callable: Function that reshapes a `RubixData` instance. Example: - :: + >>> from rubix.core.data import get_reshape_data >>> reshape_data = get_reshape_data(config) >>> rubixdata = reshape_data(rubixdata) diff --git a/rubix/core/dust.py b/rubix/core/dust.py index 45a38b3..e86d85a 100644 --- a/rubix/core/dust.py +++ b/rubix/core/dust.py @@ -1,6 +1,5 @@ -from typing import Callable - from beartype import beartype as typechecker +from beartype.typing import Callable from jaxtyping import jaxtyped from rubix.core.cosmology import get_cosmology diff --git a/rubix/core/ifu.py b/rubix/core/ifu.py index b7d37df..6402e84 100644 --- a/rubix/core/ifu.py +++ b/rubix/core/ifu.py @@ -1,8 +1,7 @@ -from typing import Callable, Union - import jax import jax.numpy as jnp from beartype import beartype as typechecker +from beartype.typing import Callable from jax import lax from jaxtyping import Array, Float, jaxtyped diff --git a/rubix/core/lsf.py b/rubix/core/lsf.py index 5e60760..4983987 100644 --- a/rubix/core/lsf.py +++ b/rubix/core/lsf.py @@ -1,6 +1,5 @@ -from typing import Callable - from beartype import beartype as typechecker +from beartype.typing import Callable from jaxtyping import jaxtyped from rubix.logger import get_logger @@ -24,7 +23,6 @@ def get_convolve_lsf(config: dict) -> Callable[[RubixData], RubixData]: ValueError: When the telescope LSF configuration or sigma is missing. Example: - :: >>> config = { ... ... diff --git a/rubix/core/noise.py b/rubix/core/noise.py index 472023e..d5ffc03 100644 --- a/rubix/core/noise.py +++ b/rubix/core/noise.py @@ -1,7 +1,6 @@ -from typing import Callable - import jax.numpy as jnp from beartype import beartype as typechecker +from beartype.typing import Callable from jaxtyping import jaxtyped from rubix.logger import get_logger @@ -27,7 +26,6 @@ def get_apply_noise(config: dict) -> Callable[[RubixData], RubixData]: ValueError: When required noise configuration keys are missing. Example: - :: >>> config = { ... ... diff --git a/rubix/core/pipeline.py b/rubix/core/pipeline.py index 7be1f8b..b8bd72b 100644 --- a/rubix/core/pipeline.py +++ b/rubix/core/pipeline.py @@ -1,11 +1,23 @@ import time -from typing import Any, Optional, Sequence, Union +import warnings import jax import jax.numpy as jnp from beartype import beartype as typechecker +from beartype.typing import Any, Optional, Sequence, Union from jax import lax -from jax.experimental.shard_map import shard_map + +try: + from jax.shard_map import shard_map # type: ignore[attr-defined] +except ImportError: # pragma: no cover - older JAX compatibility + warnings.filterwarnings( + "ignore", + message="jax.experimental.shard_map is deprecated in v0.8.0.*", + category=DeprecationWarning, + module=__name__, + ) + from jax.experimental.shard_map import shard_map + from jax.sharding import Mesh, NamedSharding, PartitionSpec as P from jax.tree_util import tree_map from jaxtyping import jaxtyped @@ -36,16 +48,17 @@ class RubixPipeline: Parsed configuration dictionary or path to a configuration file. Example: - :: >>> from rubix.core.pipeline import RubixPipeline >>> config = "path/to/config.yml" + >>> target_datacube = ... # Load or define your target datacube here >>> pipe = RubixPipeline(config) >>> inputdata = pipe.prepare_data() - >>> output = pipe.run(inputdata) >>> final_datacube = pipe.run_sharded(inputdata) - >>> ssp_model = pipeline.ssp - >>> telescope = pipeline.telescope + >>> ssp_model = pipe.ssp + >>> telescope = pipe.telescope + >>> loss_value = pipe.loss(inputdata, target_datacube) + >>> gradient_data = pipe.gradient(inputdata, target_datacube) """ def __init__(self, user_config: Union[dict, str]): @@ -304,6 +317,6 @@ def loss( jnp.ndarray: Scalar mean squared error value. """ - output = self.run(rubixdata) + output = self.run_sharded(rubixdata) loss_value = jnp.sum((output - targetdata) ** 2) return loss_value diff --git a/rubix/core/psf.py b/rubix/core/psf.py index 274c4ef..46dc40e 100644 --- a/rubix/core/psf.py +++ b/rubix/core/psf.py @@ -1,6 +1,5 @@ -from typing import Callable - from beartype import beartype as typechecker +from beartype.typing import Callable from jaxtyping import jaxtyped from rubix.logger import get_logger @@ -29,7 +28,6 @@ def get_convolve_psf(config: dict) -> Callable: kernel type. Example: - :: >>> config = { ... ... diff --git a/rubix/core/rotation.py b/rubix/core/rotation.py index 6023270..f2db5c3 100644 --- a/rubix/core/rotation.py +++ b/rubix/core/rotation.py @@ -26,7 +26,7 @@ def get_galaxy_rotation(config: dict): or missing. Example: - :: + >>> config = { ... ... ... "galaxy": { diff --git a/rubix/core/ssp.py b/rubix/core/ssp.py index 850f33c..dd9bb44 100644 --- a/rubix/core/ssp.py +++ b/rubix/core/ssp.py @@ -1,7 +1,6 @@ -from typing import Callable - import jax from beartype import beartype as typechecker +from beartype.typing import Callable from jaxtyping import jaxtyped from rubix.logger import get_logger diff --git a/rubix/core/telescope.py b/rubix/core/telescope.py index d9fddd6..b847f20 100644 --- a/rubix/core/telescope.py +++ b/rubix/core/telescope.py @@ -1,7 +1,6 @@ -from typing import Callable, Union - import jax.numpy as jnp from beartype import beartype as typechecker +from beartype.typing import Callable, Union from jaxtyping import Array, Float, jaxtyped from rubix.logger import get_logger @@ -153,7 +152,7 @@ def get_filter_particles(config: dict) -> Callable: Callable[[RubixData], RubixData]: Function that filters particles. Example: - :: + >>> from rubix.core.telescope import get_filter_particles >>> filter_particles = get_filter_particles(config) diff --git a/rubix/cosmology/base.py b/rubix/cosmology/base.py index a6c460e..feff600 100644 --- a/rubix/cosmology/base.py +++ b/rubix/cosmology/base.py @@ -40,7 +40,7 @@ class BaseCosmology(eqx.Module): h (jnp.float32): Dimensionless Hubble constant. Example: - :: + >>> # Create Planck15 cosmology >>> from rubix.cosmology import COSMOLOGY >>> cosmo = COSMOLOGY(0.3089, -1.0, 0.0, 0.6774) @@ -73,7 +73,7 @@ def scale_factor_to_redshift( Float[Array, "..."]: Redshift ``1/a - 1``. Example: - :: + >>> from rubix.cosmology import PLANCK15 as cosmo >>> # Convert scale factor 0.5 to redshift >>> cosmo.scale_factor_to_redshift(jnp.array(0.5)) @@ -121,7 +121,7 @@ def comoving_distance_to_z( Float[Array, "..."]: Comoving distance in Mpc. Example: - :: + >>> from rubix.cosmology import PLANCK15 as cosmo >>> # Calculate comoving distance to redshift 0.5 >>> cosmo.comoving_distance_to_z(0.5) @@ -145,7 +145,7 @@ def luminosity_distance_to_z( Float[Array, "..."]: Luminosity distance in Mpc. Example: - :: + >>> from rubix.cosmology import PLANCK15 as cosmo >>> # Compute the luminosity distance to redshift 0.5 >>> cosmo.luminosity_distance_to_z(0.5) @@ -167,7 +167,7 @@ def angular_diameter_distance_to_z( Float[Array, "..."]: Angular diameter distance in Mpc. Example: - :: + >>> from rubix.cosmology import PLANCK15 as cosmo >>> # Compute the angular diameter distance to redshift 0.5 >>> cosmo.angular_diameter_distance_to_z(0.5) @@ -189,7 +189,7 @@ def distance_modulus_to_z( Float[Array, "..."]: Distance modulus. Example: - :: + >>> from rubix.cosmology import PLANCK15 as cosmo >>> # Compute the distance modulus to redshift 0.5 >>> cosmo.distance_modulus_to_z(0.5) @@ -211,7 +211,7 @@ def _hubble_time(self, z: Union[Float[Array, "..."], float]) -> Float[Array, ".. Float[Array, "..."]: Hubble time in seconds. Example: - :: + >>> from rubix.cosmology import PLANCK15 as cosmo >>> # Calculate the Hubble time at redshift 0.5 >>> cosmo._hubble_time(0.5) @@ -235,7 +235,7 @@ def lookback_to_z( Float[Array, "..."]: Lookback time in seconds. Example: - :: + >>> from rubix.cosmology import PLANCK15 as cosmo >>> # Calculate the lookback time to redshift 0.5 >>> cosmo.lookback_to_z(0.5) @@ -256,7 +256,7 @@ def age_at_z0(self) -> Float[Array, "..."]: The age of the universe at redshift 0 (float). Example: - :: + >>> from rubix.cosmology import PLANCK15 as cosmo >>> # Calculate the age of the universe at redshift 0 >>> cosmo.age_at_z0() @@ -294,7 +294,7 @@ def age_at_z( Float[Array, "..."]: Age in seconds. Example: - :: + >>> from rubix.cosmology import PLANCK15 as cosmo >>> # Calculate the age of the universe at redshift 0.5 >>> cosmo.age_at_z(0.5) @@ -317,7 +317,7 @@ def angular_scale( Float[Array, "..."]: Angular scale in kpc/arcsec. Example: - :: + >>> from rubix.cosmology import PLANCK15 as cosmo >>> # Calculate the angular scale at redshift 0.5 >>> cosmo.angular_scale(0.5) @@ -326,34 +326,3 @@ def angular_scale( D_A = self.angular_diameter_distance_to_z(z) # in Mpc scale = D_A * (jnp.pi / (180 * 60 * 60)) * 1e3 # in kpc/arcsec return scale - - """ - I dont think we need this currently, but keeping it here for reference - @jit - def rho_crit(self, redshift): - rho_crit0 = RHO_CRIT0_KPC3_UNITY_H * self.h * self.h - rho_crit = rho_crit0 * self._Ez(redshift) ** 2 - return rho_crit - - @jit - def _integrand_oneOverEz1pz(self, z): - return 1.0 / self._Ez(z) / (1.0 + z) - - @jit - def _Om_at_z(self, z): - E = self._Ez(z) - return self.Om0 * (1.0 + z) ** 3 / E / E - - @jit - def _delta_vir(self, z): - x = self._Om(z) - 1.0 - Delta = 18 * jnp.pi**2 + 82.0 * x - 39.0 * x**2 - return Delta - - @jit - def virial_dynamical_time(self, redshift): - delta = self._delta_vir(redshift) - t_cross = 2**1.5 * self._hubble_time(redshift) * delta**-0.5 - return t_cross - -""" diff --git a/rubix/cosmology/utils.py b/rubix/cosmology/utils.py index 07b0710..be4eb49 100644 --- a/rubix/cosmology/utils.py +++ b/rubix/cosmology/utils.py @@ -71,7 +71,7 @@ def trapz( jnp.ndarray: Scalar results collected from the scan. Example: - :: + >>> from rubix.cosmology.utils import trapz >>> import jax.numpy as jnp diff --git a/rubix/galaxy/alignment.py b/rubix/galaxy/alignment.py index 7ff9c52..f201080 100644 --- a/rubix/galaxy/alignment.py +++ b/rubix/galaxy/alignment.py @@ -1,7 +1,6 @@ -from typing import Tuple, Union - import jax.numpy as jnp from beartype import beartype as typechecker +from beartype.typing import Tuple, Union from jax.scipy.spatial.transform import Rotation from jaxtyping import Array, Float, jaxtyped @@ -23,7 +22,7 @@ def center_particles(rubixdata: object, key: str) -> object: ValueError: If the galaxy center lies outside the particle bounds. Example: - :: + >>> from rubix.galaxy.alignment import center_particles >>> rubixdata = center_particles(rubixdata, "stars") """ @@ -84,7 +83,7 @@ def moment_of_inertia_tensor( Float[Array, "..."]: Moment of inertia tensor. Example: - :: + >>> from rubix.galaxy.alignment import moment_of_inertia_tensor >>> I = moment_of_inertia_tensor( ... rubixdata.stars.coords, diff --git a/rubix/galaxy/input_handler/api/illustris_api.py b/rubix/galaxy/input_handler/api/illustris_api.py index b5340c2..c35c2c2 100644 --- a/rubix/galaxy/input_handler/api/illustris_api.py +++ b/rubix/galaxy/input_handler/api/illustris_api.py @@ -224,7 +224,7 @@ def load_galaxy( unsupported particle type is configured. Example: - :: + >>> illustris_api = IllustrisAPI( ... api_key, ... simulation="TNG50-1", diff --git a/rubix/galaxy/input_handler/base.py b/rubix/galaxy/input_handler/base.py index 33941fd..b2ef4d4 100644 --- a/rubix/galaxy/input_handler/base.py +++ b/rubix/galaxy/input_handler/base.py @@ -163,25 +163,6 @@ def _check_galaxy_data(self, galaxy_data, units): if field not in units["galaxy"]: raise ValueError(f"Units for {field} not found in units") - """ - def _check_particle_data(self, particle_data, units): - # Check if all required fields are present - for key in self.config["particles"]: - if key not in particle_data: - raise ValueError(f"Missing particle type {key} in particle data") - for field in self.config["particles"][key]: - if field not in particle_data[key]: - raise ValueError( - f"Missing field {field} in particle data for particle type {key}" - ) - - # Check if the units are correct - for key in particle_data: - for field in particle_data[key]: - if field not in units[key]: - raise ValueError(f"Units for {field} not found in units") - """ - def _check_particle_data(self, particle_data, units): # Get the list of expected particle types from the configuration expected_particle_types = list(self.config["particles"].keys()) diff --git a/rubix/spectra/dust/extinction_models.py b/rubix/spectra/dust/extinction_models.py index 453ec36..8935767 100644 --- a/rubix/spectra/dust/extinction_models.py +++ b/rubix/spectra/dust/extinction_models.py @@ -39,7 +39,6 @@ class Cardelli89(BaseExtRvModel): Example: Example showing CCM89 curves for a range of R(V) values. - :: .. plot:: :include-source: @@ -209,7 +208,6 @@ class Gordon23(BaseExtRvModel): Example: Example showing G23 curves for a range of R(V) values. - :: .. plot:: :include-source: diff --git a/rubix/spectra/dust/generic_models.py b/rubix/spectra/dust/generic_models.py index 885e7ac..f1e36c4 100644 --- a/rubix/spectra/dust/generic_models.py +++ b/rubix/spectra/dust/generic_models.py @@ -117,6 +117,7 @@ def Drude1d( ValueError: If ``x_0`` is zero. Examples: + .. plot:: :include-source: @@ -214,7 +215,6 @@ def FM90( Examples: Example showing a FM90 curve with components identified. - :: .. plot:: :include-source: diff --git a/rubix/spectra/dust/helpers.py b/rubix/spectra/dust/helpers.py index 2967448..692e08c 100644 --- a/rubix/spectra/dust/helpers.py +++ b/rubix/spectra/dust/helpers.py @@ -1,8 +1,7 @@ -from typing import Final, Tuple - import jax import jax.numpy as jnp from beartype import beartype as typechecker +from beartype.typing import Final, Tuple from jaxtyping import Array, Float, jaxtyped # from jax.scipy.special import comb diff --git a/rubix/spectra/ssp/factory.py b/rubix/spectra/ssp/factory.py index b003ae9..b3f078d 100644 --- a/rubix/spectra/ssp/factory.py +++ b/rubix/spectra/ssp/factory.py @@ -22,6 +22,7 @@ def get_ssp_template(template: str) -> SSPGrid: ValueError: If the template name or source format is not supported. Example: + >>> from rubix.spectra.ssp.factory import get_ssp_template >>> ssp = get_ssp_template("FSPS") >>> ssp.age.shape diff --git a/rubix/spectra/ssp/grid.py b/rubix/spectra/ssp/grid.py index d01a2b8..88c9e59 100644 --- a/rubix/spectra/ssp/grid.py +++ b/rubix/spectra/ssp/grid.py @@ -1,6 +1,5 @@ import os from dataclasses import dataclass, fields -from typing import List, Tuple, Union # import equinox as eqx import h5py @@ -9,6 +8,7 @@ from astropy import units as u from astropy.io import fits from beartype import beartype as typechecker +from beartype.typing import List, Tuple, Union from interpax import interp2d from jax.tree_util import Partial from jaxtyping import Array, Float, Int, jaxtyped @@ -77,7 +77,7 @@ def get_lookup_interpolation( Partial: Interpolation function ``f(metallicity, age)``. Examples: - :: + >>> grid = SSPGrid(...) >>> lookup = grid.get_lookup_interpolation() >>> metallicity = 0.02 @@ -256,7 +256,6 @@ class HDF5SSPGrid(SSPGrid): flux (Float[Array, FLUX_AXES]): SSP fluxes in Lsun/Angstrom. Example: - :: >>> config = { ... "name": "Bruzual & Charlot (2003)", @@ -363,7 +362,6 @@ class pyPipe3DSSPGrid(SSPGrid): flux (Float[Array, FLUX_AXES]): SSP fluxes in Lsun/Angstrom. Example: - :: >>> config = { ... "name": "Mastar Charlot & Bruzual (2019)", diff --git a/rubix/spectra/ssp/templates.py b/rubix/spectra/ssp/templates.py index 27353c2..9230b4e 100644 --- a/rubix/spectra/ssp/templates.py +++ b/rubix/spectra/ssp/templates.py @@ -2,6 +2,7 @@ This module contains the supported templates for the SSP grid. Example: + >>> from rubix.spectra.ssp.templates import BruzualCharlot2003 >>> BruzualCharlot2003 >>> print(BruzualCharlot2003.age) diff --git a/rubix/telescope/base.py b/rubix/telescope/base.py index cbb3f80..0212150 100644 --- a/rubix/telescope/base.py +++ b/rubix/telescope/base.py @@ -1,8 +1,7 @@ -from typing import List, Optional, Union - import equinox as eqx import numpy as np from beartype import beartype as typechecker +from beartype.typing import List, Optional, Union from jaxtyping import Array, Float, Int, jaxtyped diff --git a/rubix/telescope/factory.py b/rubix/telescope/factory.py index 60db868..7649c54 100644 --- a/rubix/telescope/factory.py +++ b/rubix/telescope/factory.py @@ -6,6 +6,7 @@ from beartype import beartype as typechecker from jaxtyping import jaxtyped +from rubix.logger import get_logger from rubix.telescope.apertures import ( CIRCULAR_APERTURE, HEXAGONAL_APERTURE, @@ -22,11 +23,17 @@ class TelescopeFactory: @jaxtyped(typechecker=typechecker) def __init__(self, telescopes_config: Optional[Union[dict, str]] = None) -> None: + logger = get_logger() if telescopes_config is None: + logger.info( + "No telescope config provided, falling back to %s", + TELESCOPE_CONFIG_PATH, + ) warnings.warn( - "No telescope config provided, using default stored in {}".format( + ("No telescope config provided, " "using default stored in {}").format( TELESCOPE_CONFIG_PATH - ) + ), + UserWarning, ) self.telescopes_config = read_yaml(TELESCOPE_CONFIG_PATH) elif isinstance(telescopes_config, str): @@ -46,7 +53,8 @@ def create_telescope(self, name: str) -> BaseTelescope: The telescope object as BaseTelescope. Raises: - ValueError: If the telescope name is not present in the configuration. + ValueError: If the telescope name is not present in the + configuration. Example 1 (Uses the defined telescope configuration) ----------------------------------------------------- diff --git a/rubix/telescope/utils.py b/rubix/telescope/utils.py index 2400e51..6a9e321 100644 --- a/rubix/telescope/utils.py +++ b/rubix/telescope/utils.py @@ -1,8 +1,7 @@ -from typing import List, Tuple, Union - import jax.numpy as jnp import numpy as np from beartype import beartype as typechecker +from beartype.typing import List, Tuple, Union from jaxtyping import Array, Bool, Float, Int, jaxtyped from rubix.cosmology.base import BaseCosmology diff --git a/rubix/utils.py b/rubix/utils.py index edc08e2..09829e8 100644 --- a/rubix/utils.py +++ b/rubix/utils.py @@ -180,7 +180,7 @@ def load_galaxy_data( Tuple[Dict[str, Any], Dict[str, Any]]: Galaxy data and associated units Example: - :: + >>> from rubix.utils import load_galaxy_data >>> galaxy_data, units = load_galaxy_data("path/to/file.hdf5") """ diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 615a8f7..fcc433d 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -538,10 +538,10 @@ def test_loss_uses_run(simple_pipeline): target = jnp.array([1.0, 2.0]) output = jnp.array([3.0, 4.0]) - pipeline.run = MagicMock(return_value=output) + pipeline.run_sharded = MagicMock(return_value=output) loss_value = pipeline.loss(rubixdata, target) - pipeline.run.assert_called_once_with(rubixdata) + pipeline.run_sharded.assert_called_once_with(rubixdata) expected = jnp.sum((output - target) ** 2) assert jnp.allclose(loss_value, expected)