diff --git a/notebooks/dust_extinction.ipynb b/notebooks/dust_extinction.ipynb index 55099ca..3b5ed3c 100644 --- a/notebooks/dust_extinction.ipynb +++ b/notebooks/dust_extinction.ipynb @@ -9,7 +9,7 @@ "# NBVAL_SKIP\n", "import os\n", "# os.environ['SPS_HOME'] = '/home/annalena/sps_fsps'\n", - "# os.environ['SPS_HOME'] = '/Users/annalena/Documents/GitHub/fsps'" + "os.environ['SPS_HOME'] = '/Users/annalena/Documents/GitHub/fsps'" ] }, { @@ -272,17 +272,6 @@ "In order to comapre a dusty and non dusty IFU cube, we first run a normal RUBIX pipeline." ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#import os\n", - "#os.environ[\"SPS_HOME\"] = '/Users/buck/Documents/Nexus/codes/fsps'\n", - "#ILLUSTRIS_API_KEY = 'c0112e1fa11489ef0e6164480643d1c8'" - ] - }, { "cell_type": "code", "execution_count": null, @@ -368,7 +357,8 @@ "#NBVAL_SKIP\n", "pipe = RubixPipeline(config)\n", "\n", - "rubixdata = pipe.run()" + "inputdata = pipe.prepare_data()\n", + "rubixdata = pipe.run_sharded(inputdata)" ] }, { @@ -472,7 +462,8 @@ "#NBVAL_SKIP\n", "pipe = RubixPipeline(config)\n", "\n", - "rubixdata_dust = pipe.run()" + "inputdata = pipe.prepare_data()\n", + "rubixdata_dust = pipe.run_sharded(inputdata)" ] }, { @@ -491,8 +482,8 @@ "#NBVAL_SKIP\n", "wave = pipe.telescope.wave_seq\n", "\n", - "spectra = rubixdata.stars.datacube # Spectra of all stars\n", - "dusty_spectra = rubixdata_dust.stars.datacube # Spectra of all stars\n", + "spectra = rubixdata # Spectra of all stars\n", + "dusty_spectra = rubixdata_dust # Spectra of all stars\n", "print(spectra.shape)\n", "print(dusty_spectra.shape)\n", "\n", @@ -536,7 +527,7 @@ "source": [ "# NBVAL_SKIP\n", "wave = pipe.telescope.wave_seq\n", - "filters,images = curves.apply_filter_curves(rubixdata_dust.stars.datacube, wave).values()\n", + "filters,images = curves.apply_filter_curves(rubixdata_dust, wave).values()\n", "\n", "for i_dust,name in zip(images, filters):\n", " plt.figure()\n", @@ -559,8 +550,23 @@ "outputs": [], "source": [ "# NBVAL_SKIP\n", - "idx = np.where(rubixdata.gas.mass[0] != 0)\n", - "gas_map = np.histogram2d(rubixdata.gas.coords[0,:,0][idx], rubixdata.gas.coords[0,:,1][idx], bins=(25,25), weights=np.squeeze(rubixdata.gas.mass)[idx])" + "# The input data are rotated in the same way as the particles are rotaterd to calculate the IFU. \n", + "# This step is necessary, because we only have the raw input data and the pipeline only returns the datacube and not the per particle information.\n", + "from rubix.core.rotation import get_galaxy_rotation\n", + "rotate = get_galaxy_rotation(config)\n", + "\n", + "inputdatadata = rotate(inputdata)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NBVAL_SKIP\n", + "idx = np.where(inputdata.gas.mass != 0)\n", + "gas_map = np.histogram2d(inputdata.gas.coords[:,0][idx], inputdata.gas.coords[:,1][idx], bins=(25,25), weights=np.squeeze(inputdata.gas.mass)[idx])" ] }, { @@ -605,7 +611,7 @@ "source": [ "# NBVAL_SKIP\n", "wave = pipe.telescope.wave_seq\n", - "filters,images = curves.apply_filter_curves(rubixdata.stars.datacube, wave).values()\n", + "filters,images = curves.apply_filter_curves(rubixdata, wave).values()\n", "\n", "for i,name in zip(images, filters):\n", " plt.figure()\n", @@ -624,7 +630,7 @@ ], "metadata": { "kernelspec": { - "display_name": "rubix-test", + "display_name": "rubix", "language": "python", "name": "python3" }, @@ -638,7 +644,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.0" + "version": "3.12.8" } }, "nbformat": 4, diff --git a/notebooks/fits_file.ipynb b/notebooks/fits_file.ipynb deleted file mode 100644 index 72a05f6..0000000 --- a/notebooks/fits_file.ipynb +++ /dev/null @@ -1,330 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "import os\n", - "os.environ['SPS_HOME'] = '/mnt/storage/annalena_data/sps_fsps'\n", - "#os.environ['SPS_HOME'] = '/home/annalena/sps_fsps'\n", - "#os.environ['SPS_HOME'] = '/Users/annalena/Documents/GitHub/fsps'" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Fits files\n", - "\n", - "In this notebook we show, how you can store your mock datacube in a fits file, which is the common format in which are observational data handled. We firtss create a mock IFU cube by running the RUBIX pipeline, store it then in a fits file and then lod the data from the fits file." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "import matplotlib.pyplot as plt\n", - "import os\n", - "from rubix.core.pipeline import RubixPipeline\n", - "\n", - "# Define Illustris configuration\n", - "config_illustris = {\n", - " \"pipeline\": {\"name\": \"calc_ifu\"},\n", - " \"logger\": {\"log_level\": \"DEBUG\", \"log_file_path\": None, \"format\": \"%(asctime)s - %(name)s - %(levelname)s - %(message)s\"},\n", - " \"data\": {\n", - " \"name\": \"IllustrisAPI\",\n", - " \"args\": {\n", - " \"api_key\": os.environ.get(\"ILLUSTRIS_API_KEY\"),\n", - " \"particle_type\": [\"stars\", \"gas\"],\n", - " \"simulation\": \"TNG50-1\",\n", - " \"snapshot\": 99,\n", - " \"save_data_path\": \"data\",\n", - " },\n", - " \"load_galaxy_args\": {\"id\": 422754, \"reuse\": True},\n", - " \"subset\": {\"use_subset\": False, \"subset_size\": 750000},\n", - " },\n", - " \"simulation\": {\"name\": \"IllustrisTNG\", \"args\": {\"path\": \"data/galaxy-id-422754.hdf5\"}},\n", - " \"output_path\": \"output\",\n", - " \"telescope\": {\"name\": \"MUSE\", \"psf\": {\"name\": \"gaussian\", \"size\": 5, \"sigma\": 0.6}, \n", - " \"lsf\": {\"sigma\": 0.5}, \"noise\": {\"signal_to_noise\": 100, \"noise_distribution\": \"normal\"}},\n", - " \"cosmology\": {\"name\": \"PLANCK15\"},\n", - " \"galaxy\": {\"dist_z\": 0.1, \"rotation\": {\"type\": \"edge-on\"}},\n", - " \"ssp\": {\"template\": {\"name\": \"FSPS\"}, #\"Mastar_CB19_SLOG_1_5\"},\n", - " \"dust\": {\n", - " \"extinction_model\": \"Cardelli89\", #\"Gordon23\", \n", - " \"dust_to_gas_ratio\": 0.01, # need to check Remyer's paper\n", - " \"dust_to_metals_ratio\": 0.4, # do we need this ratio if we set the dust_to_gas_ratio?\n", - " \"dust_grain_density\": 3.5, # g/cm^3 #check this value\n", - " \"Rv\": 3.1,\n", - " },\n", - " },\n", - "}\n", - "\n", - "\n", - "# Run pipeline\n", - "pipe = RubixPipeline(config_illustris)\n", - "data = pipe.run()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "data.stars.spectra.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "data.stars.spectra.max()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "import numpy as np\n", - "plt.plot(np.linspace(1, 10, data.stars.spectra.shape[2]), data.stars.spectra[:,:750000,:].sum(axis=1)[1])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#NBVAL_SKIP\n", - "datacube = data.stars.datacube\n", - "\n", - "img = datacube.sum(axis=2)\n", - "plt.imshow(img, origin=\"lower\")\n", - "plt.plot(12,12, 'ro')\n", - "plt.plot(17,12, 'x', color=\"blue\")\n", - "plt.plot(7,12, 'x', color=\"orange\")\n", - "plt.colorbar()\n", - "print(img.min(), img.max())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "wave = pipe.telescope.wave_seq\n", - "#plt.plot(wave, data.stars.datacube[12, 12, :], color=\"red\", label=\"Spectrum\")\n", - "plt.vlines(4861.333, 0, 3000, color='r', label=\"Hbeta=4861.333A\")\n", - "plt.vlines(4861.333*1.1, 0, 3000, color='y', label=\"line obs=Hbeta*(1+z)\")\n", - "plt.plot(wave, data.stars.datacube[7, 12, :], color=\"orange\", label=\"Spectrum 7,12\")\n", - "plt.plot(wave, data.stars.datacube[17, 12, :], color=\"blue\", label=\"Spectrum 17,12\")\n", - "#plt.xlim(5300, 5400)\n", - "plt.legend()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "wave = pipe.telescope.wave_seq\n", - "#plt.plot(wave, data.stars.datacube[12, 12, :], color=\"red\", label=\"Spectrum\")\n", - "plt.vlines(4861.333, 0, 370, color='r', label=\"Hbeta=4861.333A\")\n", - "plt.vlines(4861.333*1.1, 0, 370, color='y', label=\"line obs=Hbeta*(1+z)\")\n", - "plt.plot(wave, data.stars.datacube[17, 12, :], color=\"blue\", label=\"Spectrum 2,12\")\n", - "plt.plot(wave, data.stars.datacube[7, 12, :], color=\"orange\", label=\"Spectrum 22,12\")\n", - "plt.xlim(5300, 5400)\n", - "plt.legend()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "import matplotlib.pyplot as plt\n", - "\n", - "# Plot a histogram of the velocities\n", - "plt.hist(data.stars.velocity[0,:,2], bins=30, edgecolor='black')\n", - "plt.xlabel('Velocity')\n", - "plt.ylabel('Frequency')\n", - "plt.title('Histogram of Star Velocities')\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "\n", - "# Assuming your data arrays are defined as follows:\n", - "pixel_assignment = np.asarray(np.squeeze(data.stars.pixel_assignment))\n", - "velocities = np.asarray(data.stars.velocity[0, :, 2])\n", - "\n", - "# Compute the sum of velocities and count per pixel using np.bincount\n", - "sum_velocity = np.bincount(pixel_assignment, weights=velocities)\n", - "counts = np.bincount(pixel_assignment)\n", - "\n", - "# Calculate mean velocity; note: division by zero is avoided if every pixel has at least one star.\n", - "mean_velocity = sum_velocity / counts\n", - "\n", - "# If you know the pixel grid dimensions (for example, a square grid)\n", - "n_pixels = len(mean_velocity)\n", - "grid_size = int(np.sqrt(n_pixels))\n", - "if grid_size * grid_size != n_pixels:\n", - " raise ValueError(\"The total number of pixels is not a perfect square; please specify the grid shape explicitly.\")\n", - "\n", - "# Reshape the mean_velocity into a 2D array for imshow\n", - "velocity_map = mean_velocity.reshape((grid_size, grid_size))\n", - "\n", - "# Plot the result\n", - "plt.figure(figsize=(6, 5))\n", - "plt.imshow(velocity_map, origin='lower', interpolation='nearest', cmap='seismic')\n", - "plt.colorbar(label='Mean Velocity')\n", - "plt.title('Mean Velocity per Pixel')\n", - "plt.xlabel('X pixel index')\n", - "plt.ylabel('Y pixel index')\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Store datacube in a fits file with header\n", - "\n", - "In RUBIX we implemented a function that automaticly takes the relevant information from the config and writes it into the header. Then the header and data are stored in a fits file. All is done with the store_fits function from the rubix.core.fits module." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#NBVAL_SKIP\n", - "from rubix.core.fits import store_fits\n", - "\n", - "store_fits(config_illustris, data, \"output/\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Load datacube from fits file\n", - "\n", - "We implemented a function to load a fits file. It is based on MPDAF, which is a package to handle MUSE IFU cubes. You can load your datacube by the following line and access all kind of information from the fitsfile." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#NBVAL_SKIP\n", - "from rubix.core.fits import load_fits\n", - "\n", - "cube = load_fits(\"output/IllustrisTNG_id11_snap99_stars_subsetTrue.fits\") #if you use NIHAO, you have to insert the NIHAO fits file" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#NBVAL_SKIP\n", - "cube.shape" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#NBVAL_SKIP\n", - "cube.info()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#NBVAL_SKIP\n", - "cube.primary_header" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#NBVAL_SKIP\n", - "import matplotlib.pyplot as plt\n", - "\n", - "image1 = cube[0,:,:]\n", - "\n", - "plt.figure()\n", - "image1.plot(colorbar='v', title = '$\\lambda$ = %.1f (%s)' %(cube.wave.coord(1000), cube.wave.unit))\n", - "plt.show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.13.2" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/rubix_pipeline_stepwise.ipynb b/notebooks/rubix_pipeline_stepwise.ipynb index e13ac78..80c17f9 100644 --- a/notebooks/rubix_pipeline_stepwise.ipynb +++ b/notebooks/rubix_pipeline_stepwise.ipynb @@ -242,66 +242,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Step 6: Reshape data\n", + "## Step 6: Data cube calculation\n", "\n", - "At the moment we have to reshape the rubix data that we can split the data on multiple GPUs. We plan to move from pmap to shard_map. Then this step should not be necessary any more. This step has purely computational reason and no physics motivated reason." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "from rubix.core.data import get_reshape_data\n", - "reshape_data = get_reshape_data(config)\n", - "\n", - "rubixdata = reshape_data(rubixdata)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 7: Spectra calculation\n", + "This is the heart of the `pipeline`. Now we do the lookup for the spectrum for each stellar particle. For the simple stellar population model by `BruzualCharlot2003`, each stellar particle gets a spectrum assigned based on its age and metallicity.\n", "\n", - "This is the heart of the `pipeline`. Now we do the lookup for the spectrum for each stellar particle. For the simple stellar population model by `BruzualCharlot2003`, each stellar particle gets a spectrum assigned based on its age and metallicity. In the plot we can see that the spectrum differs for different stellar particles." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "from rubix.core.ifu import get_calculate_spectra\n", - "calcultae_spectra = get_calculate_spectra(config)\n", + "We scale the stellar particle spectra by its mass. The stellar spectra have to be scaled by the stellar mass. Later heavier stellar particles should contribute more to the spectrum in a spaxel than lighter stellar particles.\n", "\n", - "rubixdata = calcultae_spectra(rubixdata)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "import jax.numpy as jnp\n", + "The stellar particles are not at rest and therefore the emitted light is doppler shifted with respect to the observer. Before adding all stellar spectra in each spaxel, we dopplershift the spectra according to their particle velocity and we resample the spectra to the wavelength grid of the observing instrument.\n", "\n", - "plt.plot(jnp.arange(len(rubixdata.stars.spectra[0][0][:])), rubixdata.stars.spectra[0][0][:])\n", - "plt.plot(jnp.arange(len(rubixdata.stars.spectra[0][0][:])), rubixdata.stars.spectra[0][1][:])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 8: Scaling by mass\n", + "Each stellar particle falls into one spaxel in the datacube. We ad the stellar particles spectra contribution to the according spaxel in the datacube. We do these steps for all atellar particles.\n", "\n", - "The stellar spectra have to be scaled by the stellar mass. Later heavier stellar particles should contribute more to the spectrum in a spaxel than lighter stellar particles." + "The first plot shows the spectra for two different spaxels.\n", + "The second plot shows the spatial dimension of the `datacube`, where we summed over the wavelength dimension." ] }, { @@ -311,32 +263,10 @@ "outputs": [], "source": [ "# NBVAL_SKIP\n", - "from rubix.core.ifu import get_scale_spectrum_by_mass\n", - "scale_spectrum_by_mass = get_scale_spectrum_by_mass(config)\n", - "\n", - "rubixdata = scale_spectrum_by_mass(rubixdata)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 9: Doppler shifting and resampling\n", + "from rubix.core.ifu import get_calculate_datacube_particlewise\n", + "calculate_datacube_particlewise = get_calculate_datacube_particlewise(config)\n", "\n", - "The stellar particles are not at rest and therefore the emitted light is doppler shifted with respect to the observer. Before adding all stellar spectra in each spaxel, we dopplershift the spectra according to their particle velocity and we resample the spectra to the wavelength grid of the observing instrument." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "from rubix.core.ifu import get_doppler_shift_and_resampling\n", - "doppler_shift_and_resampling = get_doppler_shift_and_resampling(config)\n", - "\n", - "rubixdata = doppler_shift_and_resampling(rubixdata)" + "rubixdata = calculate_datacube_particlewise(rubixdata)" ] }, { @@ -352,32 +282,10 @@ "\n", "wave = pipe.telescope.wave_seq\n", "print(wave)\n", - "print(rubixdata.stars.spectra[0][0][:])\n", - "\n", - "plt.plot(wave, rubixdata.stars.spectra[0][0][:])\n", - "plt.plot(wave, rubixdata.stars.spectra[0][1][:])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 10: Datacube\n", - "\n", - "Now we can add all stellar spectra that contribute to one spaxel and get the IFU datacube. The plot shows the spatial dimension of the `datacube`, where we summed over the wavelength dimension." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "from rubix.core.ifu import get_calculate_datacube\n", - "calculate_datacube = get_calculate_datacube(config)\n", + "print(rubixdata.stars.datacube[0][0][:])\n", "\n", - "rubixdata = calculate_datacube(rubixdata)" + "plt.plot(wave, rubixdata.stars.datacube[12][12][:])\n", + "plt.plot(wave, rubixdata.stars.datacube[10][5][:])" ] }, { @@ -396,7 +304,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Step 11: PSF\n", + "## Step 7: PSF\n", "\n", "The instrument and the earth athmosphere affect the spatial resolution of the observation data and smooth in spatial dimention. To take this effect into account we convolve our datacube with a point spread function (PSF)." ] @@ -434,14 +342,14 @@ "source": [ "# NBVAL_SKIP\n", "plt.plot(wave, datacube[12,12,:])\n", - "plt.plot(wave, datacube[0,0,:])" + "plt.plot(wave, datacube[10,5,:])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Step 11: LSF\n", + "## Step 8: LSF\n", "\n", "The instrument affects the spectral resolution of the observation data and smooth in spectral dimention. To take this effect into account we convolve our datacube with a line spread function (LSF)." ] @@ -459,14 +367,14 @@ "rubixdata = convolve_lsf(rubixdata)\n", "\n", "plt.plot(wave, rubixdata.stars.datacube[12,12,:])\n", - "plt.plot(wave, rubixdata.stars.datacube[0,0,:])" + "plt.plot(wave, rubixdata.stars.datacube[10,5,:])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Step 12: Noise\n", + "## Step 9: Noise\n", "\n", "Observational data are never noise-free. We apply noise to our mock-datacube to mimic real measurements." ] @@ -496,7 +404,7 @@ "source": [ "# NBVAL_SKIP\n", "plt.plot(wave, rubixdata.stars.datacube[12,12,:])\n", - "plt.plot(wave, rubixdata.stars.datacube[0,0,:])" + "plt.plot(wave, rubixdata.stars.datacube[10,5,:])" ] }, { @@ -507,11 +415,64 @@ "\n", "Congratulations, you have now created step by step your own mock-observed IFU datacube! Now enjoy playing around with the RUBIX pipeline and enjoy doing amazing science with RUBIX :)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Store datacube in a fits file with header\n", + "\n", + "Keep in mind that this it the luminosity cube. If you want to have a flux cube, you have to convert it. You can do this with the `rubix.spectra.ifu.convert_luminoisty_to_flux` function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NBVAL_SKIP\n", + "from rubix.core.fits import store_fits\n", + "\n", + "store_fits(config, rubixdata.stars.datacube, \"./output/\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Visualisation of the datacube\n", + "\n", + "We show, how you can visualize the datacube and see the image and spectra and explore the datacube. We show our own build tool to visualize the datacube and second we provide the opportunity to load the datacube with the Cubeviz module from jdaviz.\n", + "\n", + "`visualize_rubix` uses mpdaf to load the datacube. This is a package specialized to load MUSE datacubes. The function will thisplay you on the left an image collapsed along the wavelength and on the right a spectrum for a certain pixel or aperture. \n", + "\n", + "Explanation of the sliders:\n", + "- Waveindex: Waveindex, which wavelength slice is plotted in the image.\n", + "- Wavelengthrange: Range in wavelength that is collapsed to the image.\n", + "- X Pixel: X coordinate of the displayed spectrum and x coordinate of the center of the aperture.\n", + "- Y Pixel: Y coordinate of the displayed spectrum and y coordinate of the center of the aperture.\n", + "- Radius: size of the circular aperture in pixels. If this value is set to zerro, only the spaxel specified in the x and y pixel is considered for the spectrum plot.\n", + "\n", + "Now you can explore your datacube with the sliders!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NBVAL_SKIP\n", + "from rubix.core.visualisation import visualize_rubix\n", + "\n", + "visualize_rubix(\"./output/IllustrisTNG_id11_snap99_stars_subsetTrue.fits\")" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "rubix", "language": "python", "name": "python3" }, @@ -525,7 +486,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.11" + "version": "3.12.8" } }, "nbformat": 4, diff --git a/notebooks/spaxel_assignment.ipynb b/notebooks/spaxel_assignment.ipynb deleted file mode 100644 index 8798c80..0000000 --- a/notebooks/spaxel_assignment.ipynb +++ /dev/null @@ -1,166 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Spaxel assignment\n", - "\n", - "This notebook shows the principle, how stellar particles or gas particles are assigned to the different spaxels. We show this here for squared spaxels.\n", - "\n", - "We start with two particles and assign them to the spatial matching spaxels." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "from rubix.telescope.utils import square_spaxel_assignment\n", - "import matplotlib.pyplot as plt\n", - "from matplotlib.colors import ListedColormap\n", - "from jaxtyping import Float, Array \n", - "import jax.numpy as jnp\n", - "import numpy as np\n", - "\n", - "# Define the particle coordinates\n", - "coords = np.array([[0.5, 1.5], [2.5, 3.5]])\n", - "print(\"coords\", coords)\n", - "\n", - "# Define the spatial bin edges\n", - "spatial_bin_edges = np.array([0, 1, 2, 3, 4])\n", - "\n", - "# Compute the pixel assignments\n", - "pixel_assignments = square_spaxel_assignment(coords, spatial_bin_edges)\n", - "\n", - "# Create a discrete colormap\n", - "max_assignment = np.max(pixel_assignments)\n", - "colors = plt.cm.viridis(np.linspace(0, 1, int(max_assignment) + 1))\n", - "cmap = ListedColormap(colors)\n", - "\n", - "# Plot the results\n", - "plt.figure(figsize=(10, 5))\n", - "\n", - "# Plotting the particles with labels\n", - "plt.subplot(1, 2, 1)\n", - "scatter = plt.scatter(coords[:, 0], coords[:, 1], c=pixel_assignments, cmap=cmap, edgecolor='k')\n", - "plt.colorbar(scatter, ticks=np.arange(0, max_assignment + 1))\n", - "plt.title('Particle Coordinates and Pixel Assignments')\n", - "plt.xlabel('X Coordinate')\n", - "plt.ylabel('Y Coordinate')\n", - "plt.xlim(spatial_bin_edges[0], spatial_bin_edges[-1])\n", - "plt.ylim(spatial_bin_edges[0], spatial_bin_edges[-1])\n", - "\n", - "\n", - "# Label each point with its pixel index\n", - "for i, (x, y) in enumerate(coords[:, :2]):\n", - " plt.text(x, y, str(pixel_assignments[i]), color='red', fontsize=8)\n", - "\n", - "#create the bins\n", - "for edge in spatial_bin_edges:\n", - " plt.axvline(edge, color='k', linestyle='--')\n", - " plt.axhline(edge, color='k', linestyle='--')\n", - "\n", - "plt.tight_layout()\n", - "plt.show()\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we do the same with a lot more random points." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "#create random data\n", - "n_stars = 1000\n", - "coords = np.random.normal(2, 0.5, (n_stars, 3))\n", - "coords = jnp.array(coords)\n", - "\n", - "# Compute the pixel assignments\n", - "pixel_assignments = square_spaxel_assignment(coords, spatial_bin_edges)\n", - "\n", - "# Plot the results\n", - "plt.figure(figsize=(10, 5))\n", - "\n", - "\n", - "# Plot the results\n", - "plt.figure(figsize=(10, 5))\n", - "\n", - "# Plotting the particles with labels\n", - "plt.subplot(1, 2, 1)\n", - "scatter = plt.scatter(coords[:, 0], coords[:, 1], c=pixel_assignments, cmap=cmap, edgecolor='k')\n", - "plt.colorbar(scatter, ticks=np.arange(0, max_assignment + 1))\n", - "plt.title('Particle Coordinates and Pixel Assignments')\n", - "plt.xlabel('X Coordinate')\n", - "plt.ylabel('Y Coordinate')\n", - "plt.xlim(spatial_bin_edges[0], spatial_bin_edges[-1])\n", - "plt.ylim(spatial_bin_edges[0], spatial_bin_edges[-1])\n", - "\n", - "\n", - "#create the bins\n", - "for edge in spatial_bin_edges:\n", - " plt.axvline(edge, color='k', linestyle='--')\n", - " plt.axhline(edge, color='k', linestyle='--')\n", - "\n", - "plt.tight_layout()\n", - "plt.show()\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "And the last plot shows how many particles fall in each spaxel." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "image = np.zeros((len(spatial_bin_edges) - 1, len(spatial_bin_edges) - 1))\n", - "\n", - "# Count the number of particles in each pixel\n", - "for i in range(len(spatial_bin_edges) - 1):\n", - " for j in range(len(spatial_bin_edges) - 1):\n", - " image[i, j] = np.sum(pixel_assignments == (i + (len(spatial_bin_edges) - 1) * j))\n", - " \n", - " \n", - "plt.imshow(image, cmap='viridis', origin='lower')" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "rubix", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.13" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/ssp_template_fsps.ipynb b/notebooks/ssp_template_fsps.ipynb index 80208ec..0a01568 100644 --- a/notebooks/ssp_template_fsps.ipynb +++ b/notebooks/ssp_template_fsps.ipynb @@ -54,6 +54,52 @@ "ssp" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NBVAL_SKIP\n", + "config = {\n", + " \"name\": \"FSPS (Conroy et al. 2009)\",\n", + " # more information on how those models are synthesized: https://github.com/cconroy20/fsps\n", + " # and https://dfm.io/python-fsps/current/\n", + " \"format\": \"fsps\", # Format of the template\n", + " \"source\": \"load_from_file\", # the source can be \"load_from_file\" or \"rerun_from_scratch\"\n", + " # \"load_from_file\" is the default and loads the template from a pre-existing file in h5 format specified by \"file_name\"\n", + " # if that file is not found, it will automatically run fsps and save the output to disk in h5 format under the \"file_name\" given.\n", + " # \"rerun_from_scratch\" # note: this is just meant for the case in which you really want to rerun your template library.\n", + " # You should be aware that fsps templates will silently be overwritten by this. Use with caution.\n", + " \"file_name\": \"fsps.h5\", # File name of the template, stored in templates directory\n", + " # Define the Fields in the template and their units\n", + " # This is used to convert them to the required units\n", + " \"fields\":{ # Fields in the template and their units\n", + " # Name defines the name of the key stored in the hdf5 file\n", + " \"age\":{\n", + " \"name\": \"age\",\n", + " \"units\": \"Gyr\", # Age of the template\n", + " \"in_log\": True # If the field is stored in log scale\n", + " },\n", + " \"metallicity\":{\n", + " \"name\": \"metallicity\",\n", + " \"units\": \"\", # Metallicity of the template\n", + " \"in_log\": True # If the field is stored in log scale\n", + " },\n", + " \"wavelength\":{\n", + " \"name\": \"wavelength\",\n", + " \"units\": \"Angstrom\", # Wavelength of the template\n", + " \"in_log\": False # If the field is stored in log scale\n", + " },\n", + " \"flux\":{\n", + " \"name\": \"flux\",\n", + " \"units\": \"Lsun/Angstrom\", # Luminosity of the template as per pyFSPS documentation\n", + " \"in_log\": False # If the field is stored in log scale\n", + " }\n", + " }\n", + "}" + ] + }, { "cell_type": "code", "execution_count": null, @@ -230,13 +276,6 @@ "# NBVAL_SKIP\n", "ssp.wavelength.shape == ssp2.wavelength.shape" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -255,7 +294,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.12.8" } }, "nbformat": 4, diff --git a/notebooks/visualisation.ipynb b/notebooks/visualisation.ipynb deleted file mode 100644 index ea67da9..0000000 --- a/notebooks/visualisation.ipynb +++ /dev/null @@ -1,78 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Visualisation of the datacube\n", - "\n", - "In this notebook, we show, how you can visualize the datacube and see the image and spectra and explore the datacube. First we show our own build tool to visualize the datacube and second we provide the opportunity to load the datacube with the Cubeviz module from jdaviz.\n", - "\n", - "`visualize_rubix` uses mpdaf to load the datacube. This is a package specialized to load MUSE datacubes. The function will thisplay you on the left an image collapsed along the wavelength and on the right a spectrum for a certain pixel or aperture. \n", - "\n", - "Explanation of the sliders:\n", - "- Waveindex: Waveindex, which wavelength slice is plortted in the image.\n", - "- Wavelengthrange: Range in wavelength that is collapsed to the image.\n", - "- X Pixel: X coordinate of the displayed spectrum and x coordinate of the center of the aperture.\n", - "- Y Pixel: Y coordinate of the displayed spectrum and y coordinate of the center of the aperture.\n", - "- Radius: size of the circular aperture in pixels. If this value is set to zerro, only the spaxel specified in the x and y pixel is considered for the spectrum plot.\n", - "\n", - "Now you can explore your datacube with the sliders!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "from rubix.core.visualisation import visualize_rubix\n", - "\n", - "visualize_rubix(\"./output/IllustrisTNGid11_stars2_sn100.fits\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Visualisation with Cubeviz\n", - "\n", - "Cubeviz is a common tool to explore datacubes. We integrated Cubeviz into RUBIX that you can easily use Cubeviz for fits files." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NBVAL_SKIP\n", - "from rubix.core.visualisation import visualize_cubeviz\n", - "\n", - "visualize_cubeviz(\"./output/IllustrisTNGid11_stars2_sn100.fits\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "rubix", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.13" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/pyproject.toml b/pyproject.toml index b5cd7e9..655bb60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,8 +10,10 @@ build-backend = "setuptools.build_meta" # single source of truth for rubix's version [project] name = "rubix" -description = "Add short description here" +description = "A toolkit for simulating and analyzing integral field spectroscopic data cubes of astronomical sources." readme = "README.md" +authors = [{ name = "AstroAI-Lab", email = "astroai@iwr.uni-heidelberg.de" }, { name = "Tobias Buck", email = "tobias.buck@iwr.uni-heidelberg.de" }, + { name = "Ufuk Cakir", email = "ufukcakir@robots.ox.ac.uk" }, { name = "Anna Lena Schaible", email = "annalena.schaible@iwr.uni-heidelberg.de"}] maintainers = [{ name = "AstroAI-Lab", email = "astroai@iwr.uni-heidelberg.de" }] dynamic = ["version"] requires-python = ">=3.9" @@ -21,6 +23,9 @@ classifiers = [ "Operating System :: OS Independent", "License :: OSI Approved :: MIT License", ] + +urls = { Homepage = "https://astro-rubix.web.app", Repository = "https://github.com/AstroAI-Lab/rubix", Issues = "https://github.com/AstroAI-Lab/rubix/issues" } + dependencies = [ "requests", "requests-mock", diff --git a/rubix/config/pipeline_config.yml b/rubix/config/pipeline_config.yml index 477b1fd..1d30fa0 100644 --- a/rubix/config/pipeline_config.yml +++ b/rubix/config/pipeline_config.yml @@ -53,9 +53,14 @@ calc_dusty_ifu: depends_on: filter_particles args: [] kwargs: {} + calculate_extinction: + name: calculate_extinction + depends_on: spaxel_assignment + args: [] + kwargs: {} calculate_dusty_datacube_particlewise: name: calculate_dusty_datacube_particlewise - depends_on: spaxel_assignment + depends_on: calculate_extinction args: [] kwargs: {} convolve_psf: diff --git a/rubix/core/data.py b/rubix/core/data.py index 25449c8..d9baf48 100644 --- a/rubix/core/data.py +++ b/rubix/core/data.py @@ -104,6 +104,7 @@ class StarsData: pixel_assignment: Optional[Any] = None spatial_bin_edges: Optional[Any] = None mask: Optional[Any] = None + extinction: Optional[Any] = None spectra: Optional[Any] = None datacube: Optional[Any] = None @@ -138,6 +139,7 @@ def tree_flatten(self): self.pixel_assignment, self.spatial_bin_edges, self.mask, + self.extinction, self.spectra, self.datacube, ) diff --git a/rubix/core/dust.py b/rubix/core/dust.py index fe15fcc..90545a1 100644 --- a/rubix/core/dust.py +++ b/rubix/core/dust.py @@ -56,7 +56,7 @@ def calculate_extinction(rubixdata: RubixData) -> RubixData: """Apply the dust extinction to the spaxel data.""" logger.info("Applying dust extinction to the spaxel data...") - rubixdata.stars.spectra = apply_spaxel_extinction( + rubixdata.stars.extinction = apply_spaxel_extinction( config, rubixdata, wavelength, n_spaxels, spaxel_area ) diff --git a/rubix/core/ifu.py b/rubix/core/ifu.py index daa13b2..90249da 100644 --- a/rubix/core/ifu.py +++ b/rubix/core/ifu.py @@ -9,6 +9,7 @@ from rubix import config as rubix_config from rubix.core.data import GasData, StarsData from rubix.logger import get_logger +from rubix.spectra.dust.extinction_models import * from rubix.spectra.ifu import ( _velocity_doppler_shift_single, cosmological_doppler_shift, @@ -106,3 +107,110 @@ def body(cube, i): return rubixdata return calculate_datacube_particlewise + + +@jaxtyped(typechecker=typechecker) +def get_calculate_dusty_datacube_particlewise(config: dict) -> Callable: + """ + Create a function that calculates the datacube for the stars component + of a RubixData object on a per-particle basis. First, it looks up the SSP + spectrum for each star based on its age and metallicity, scales it by the + star's mass, applies a Doppler shift based on the star's velocity, resamples + the spectrum onto the telescope's wavelength grid, and finally accumulates + the resulting spectra into the appropriate pixels of the datacube. + + Args: + config (dict): Configuration dictionary containing telescope and galaxy + parameters. + + Returns: + Callable: A function that takes a RubixData object and returns it with + the datacube calculated and added to the stars component. + """ + logger = get_logger(config.get("logger", None)) + telescope = get_telescope(config) + ns = int(telescope.sbin) + nseg = ns * ns + target_wave = telescope.wave_seq # (n_wave_tel,) + + # prepare SSP lookup + lookup_ssp = get_lookup_interpolation(config) + + # prepare Doppler machinery + velocity_direction = rubix_config["ifu"]["doppler"]["velocity_direction"] + z_obs = config["galaxy"]["dist_z"] + ssp_model = get_ssp(config) + ssp_wave0 = cosmological_doppler_shift( + z=z_obs, wavelength=ssp_model.wavelength + ) # (n_wave_ssp,) + + @jaxtyped(typechecker=typechecker) + def calculate_dusty_datacube_particlewise(rubixdata: RubixData) -> RubixData: + logger.info("Calculating Data Cube (combined per‐particle)…") + + stars = rubixdata.stars + ages = stars.age # (n_stars,) + metallicity = stars.metallicity # (n_stars,) + masses = stars.mass # (n_stars,) + velocities = stars.velocity # (n_stars,) + pix_idx = stars.pixel_assignment # (n_stars,) + Av_array = stars.extinction # (n_stars, n_wave_ssp) + nstar = ages.shape[0] + + # dust model + ext_model = config["ssp"]["dust"]["extinction_model"] + Rv = config["ssp"]["dust"]["Rv"] + # Dynamically choose the extinction model based on the string name + if ext_model not in RV_MODELS: + raise ValueError( + f"Extinction model '{ext_model}' is not available. Choose from {RV_MODELS}." + ) + + ext_model_class = Rv_model_dict[ext_model] + ext = ext_model_class(Rv=Rv) + + # init flat cube: (nseg, n_wave_tel) + init_cube = jnp.zeros((nseg, target_wave.shape[-1])) + + def body(cube, i): + age_i = ages[i] # scalar + Z_i = metallicity[i] # scalar + m_i = masses[i] # scalar + v_i = velocities[i] # scalar or vector + pix_i = pix_idx[i].astype(jnp.int32) + av_i = Av_array[i] # (n_wave_ssp,) + + # 1) SSP lookup + spec_ssp = lookup_ssp(Z_i, age_i) # (n_wave_ssp,) + # 2) scale by mass + spec_mass = spec_ssp * m_i # (n_wave_ssp,) + # 3) Doppler‐shift wavelengths + shifted_wave = _velocity_doppler_shift_single( + wavelength=ssp_wave0, + velocity=v_i, + direction=velocity_direction, + ) # (n_wave_ssp,) + # 4) resample onto telescope grid + spec_tel = resample_spectrum( + initial_spectrum=spec_mass, + initial_wavelength=shifted_wave, + target_wavelength=target_wave, + ) # (n_wave_tel,) + + # apply extinction + extinction = ext.extinguish(target_wave / 1e4, av_i) + + spec_extincted = spec_tel * extinction # (n_wave_tel,) + + # 5) accumulate + cube = cube.at[pix_i].add(spec_extincted) + return cube, None + + cube_flat, _ = lax.scan(body, init_cube, jnp.arange(nstar, dtype=jnp.int32)) + + cube_3d = cube_flat.reshape(ns, ns, -1) + setattr(rubixdata.stars, "datacube", cube_3d) + logger.debug(f"Datacube shape: {cube_3d.shape}") + return rubixdata + + return calculate_dusty_datacube_particlewise diff --git a/rubix/core/pipeline.py b/rubix/core/pipeline.py index 90a0a46..8f24406 100644 --- a/rubix/core/pipeline.py +++ b/rubix/core/pipeline.py @@ -26,11 +26,13 @@ GasData, RubixData, StarsData, - get_reshape_data, get_rubix_data, ) from .dust import get_extinction -from .ifu import get_calculate_datacube_particlewise +from .ifu import ( + get_calculate_datacube_particlewise, + get_calculate_dusty_datacube_particlewise, +) from .lsf import get_convolve_lsf from .noise import get_apply_noise from .psf import get_convolve_psf @@ -123,10 +125,13 @@ def _get_pipeline_functions(self) -> list: rotate_galaxy = get_galaxy_rotation(self.user_config) filter_particles = get_filter_particles(self.user_config) spaxel_assignment = get_spaxel_assignment(self.user_config) - apply_extinction = get_extinction(self.user_config) + calculate_extinction = get_extinction(self.user_config) calculate_datacube_particlewise = get_calculate_datacube_particlewise( self.user_config ) + calculate_dusty_datacube_particlewise = ( + get_calculate_dusty_datacube_particlewise(self.user_config) + ) convolve_psf = get_convolve_psf(self.user_config) convolve_lsf = get_convolve_lsf(self.user_config) apply_noise = get_apply_noise(self.user_config) @@ -135,15 +140,16 @@ def _get_pipeline_functions(self) -> list: rotate_galaxy, filter_particles, spaxel_assignment, - apply_extinction, + calculate_extinction, calculate_datacube_particlewise, + calculate_dusty_datacube_particlewise, convolve_psf, convolve_lsf, apply_noise, ] return functions - def run_sharded(self, inputdata, devices): + def run_sharded(self, inputdata, devices=None): """ Runs the pipeline on sharded input data in parallel using jax.shard_map. It splits the particle arrays (e.g. under stars and gas) into shards, runs @@ -174,8 +180,11 @@ def run_sharded(self, inputdata, devices): self.logger.info("Compiling the expressions...") self.func = self._pipeline.compile_expression() - # devices = jax.devices() - num_devices = len(devices) + if devices is None: + devices = jax.devices() + num_devices = len(devices) + else: + num_devices = len(devices) self.logger.info("Number of devices: %d", num_devices) mesh = Mesh(devices, axis_names=("data",)) diff --git a/rubix/galaxy/alignment.py b/rubix/galaxy/alignment.py index 2b2e104..9fe47cf 100644 --- a/rubix/galaxy/alignment.py +++ b/rubix/galaxy/alignment.py @@ -231,8 +231,8 @@ def apply_rotation( @jaxtyped(typechecker=typechecker) def rotate_galaxy( - positions: Float[Array, "* 3"], - velocities: Float[Array, "* 3"], + positions: Float[Array, "..."], + velocities: Float[Array, "..."], positions_stars: Float[Array, "..."], masses_stars: Float[Array, "..."], halfmass_radius: Union[Float[Array, "..."], float], diff --git a/rubix/spectra/dust/dust_extinction.py b/rubix/spectra/dust/dust_extinction.py index f74ef85..6008ee5 100644 --- a/rubix/spectra/dust/dust_extinction.py +++ b/rubix/spectra/dust/dust_extinction.py @@ -172,7 +172,7 @@ def apply_spaxel_extinction( wavelength: Float[Array, "n_wave"], n_spaxels: int, spaxel_area: Float[Array, "..."], -) -> Float[Array, "1 n_star n_wave"]: +) -> Float[Array, "..."]: r""" Calculate the extinction for each star in the spaxel and apply dust extinction to it's associated SSP. @@ -236,10 +236,10 @@ def apply_spaxel_extinction( # sort the arrays by pixel assignment and z position gas_sorted_idx = jnp.lexsort( - (rubixdata.gas.coords[0, :, 2], rubixdata.gas.pixel_assignment[0]) + (rubixdata.gas.coords[:, 2], rubixdata.gas.pixel_assignment) ) stars_sorted_idx = jnp.lexsort( - (rubixdata.stars.coords[0, :, 2], rubixdata.stars.pixel_assignment[0]) + (rubixdata.stars.coords[:, 2], rubixdata.stars.pixel_assignment) ) # determine the segment boundaries @@ -248,7 +248,7 @@ def apply_spaxel_extinction( gas_segment_boundaries = jnp.concatenate( [ jnp.searchsorted( - rubixdata.gas.pixel_assignment[0][gas_sorted_idx], + rubixdata.gas.pixel_assignment[gas_sorted_idx], spaxel_IDs, side="left", ), @@ -258,7 +258,7 @@ def apply_spaxel_extinction( stars_segment_boundaries = jnp.concatenate( [ jnp.searchsorted( - rubixdata.stars.pixel_assignment[0][stars_sorted_idx], + rubixdata.stars.pixel_assignment[stars_sorted_idx], spaxel_IDs, side="left", ), @@ -277,14 +277,14 @@ def apply_spaxel_extinction( # with this we can calculate the dust mass # we need to correct by factor of 16 for the difference in atomic mass log_OH = 12 + jnp.log10( - rubixdata.gas.metals[0, :, 4] / (16 * rubixdata.gas.metals[0, :, 0]) + rubixdata.gas.metals[:, 4] / (16 * rubixdata.gas.metals[:, 0]) ) dust_to_gas_ratio = calculate_dust_to_gas_ratio( log_OH, rubix_config["ssp"]["dust"]["dust_to_gas_model"], rubix_config["ssp"]["dust"]["Xco"], ) - dust_mass = rubixdata.gas.mass[0] * dust_to_gas_ratio + dust_mass = rubixdata.gas.mass * dust_to_gas_ratio dust_grain_density = config["ssp"]["dust"]["dust_grain_density"] extinction = ( @@ -293,7 +293,7 @@ def apply_spaxel_extinction( ) # Preallocate arrays - Av_array = jnp.zeros_like(rubixdata.stars.mass[0]) + Av_array = jnp.zeros_like(rubixdata.stars.mass) def body_fn(carry, idx): Av_array = carry @@ -319,14 +319,14 @@ def body_fn(carry, idx): cumulative_dust_mass = jnp.cumsum(extinction * gas_mask) * gas_mask # resort the arrays as jnp.interp requires sorted arrays and our approach of using masks to select the segment is not compatible with this requirement. - xp_arr = rubixdata.gas.coords[0, :, 2][gas_sorted_idx] * gas_mask2 + xp_arr = rubixdata.gas.coords[:, 2][gas_sorted_idx] * gas_mask2 fp_arr = cumulative_dust_mass xp_arr, fp_arr = jax.lax.sort_key_val(xp_arr, fp_arr) interpolated_column_density = ( jnp.interp( - rubixdata.stars.coords[0, :, 2][stars_sorted_idx], + rubixdata.stars.coords[:, 2][stars_sorted_idx], xp_arr, fp_arr, left="extrapolate", @@ -350,9 +350,10 @@ def body_fn(carry, idx): # undo the sorting of the stars undo_sort = jnp.argsort(stars_sorted_idx) - extinction = extinction[undo_sort] + Av_array = Av_array[undo_sort] + # extinction = extinction[undo_sort] # Apply the extinction to the SSP fluxes - extincted_ssp_template_fluxes = rubixdata.stars.spectra * extinction + # extincted_ssp_template_fluxes = rubixdata.stars.spectra * extinction - return extincted_ssp_template_fluxes + return Av_array diff --git a/tests/test_dust_extinction.py b/tests/test_dust_extinction.py index eccbad2..4e4b253 100644 --- a/tests/test_dust_extinction.py +++ b/tests/test_dust_extinction.py @@ -37,19 +37,19 @@ def mock_config(): def mock_rubixdata(): class MockGas: def __init__(self): - self.coords = jnp.array([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]]) + self.coords = jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) self.pixel_assignment = jnp.array([[0, 1]]) self.metals = jnp.array( - [[[0.01, 0.02, 0.03, 0.04, 0.05], [0.06, 0.07, 0.08, 0.09, 0.1]]] + [[0.01, 0.02, 0.03, 0.04, 0.05], [0.06, 0.07, 0.08, 0.09, 0.1]] ) - self.mass = jnp.array([[1.0, 2.0]]) + self.mass = jnp.array([1.0, 2.0]) class MockStars: def __init__(self): - self.coords = jnp.array([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]]) - self.pixel_assignment = jnp.array([[0, 1]]) - self.mass = jnp.array([[1.0, 2.0]]) - self.spectra = jnp.array([[[1.0, 2.0], [3.0, 4.0]]]) + self.coords = jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + self.pixel_assignment = jnp.array([0, 1]) + self.mass = jnp.array([1.0, 2.0]) + self.spectra = jnp.array([[1.0, 2.0], [3.0, 4.0]]) class MockRubixData(RubixData): def __init__(self): @@ -68,7 +68,7 @@ def test_spaxel_extinction_Cardelli(mock_config, mock_rubixdata): mock_config, mock_rubixdata, wavelength, n_spaxels, spaxel_area ) - assert result.shape == (1, 2, 2) + assert result.shape == (2,) assert jnp.all(result >= 0) @@ -102,7 +102,7 @@ def test_spaxel_extinction_Gordon(mock_config, mock_rubixdata): mock_config, mock_rubixdata, wavelength, n_spaxels, spaxel_area ) - assert result.shape == (1, 2, 2) + assert result.shape == (2,) assert jnp.all(result >= 0) @@ -131,19 +131,19 @@ def mock_config(): def mock_rubixdata(): class MockGas: def __init__(self): - self.coords = jnp.array([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]]) - self.pixel_assignment = jnp.array([[0, 1]]) + self.coords = jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + self.pixel_assignment = jnp.array([0, 1]) self.metals = jnp.array( - [[[0.01, 0.02, 0.03, 0.04, 0.05], [0.06, 0.07, 0.08, 0.09, 0.1]]] + [[0.01, 0.02, 0.03, 0.04, 0.05], [0.06, 0.07, 0.08, 0.09, 0.1]] ) - self.mass = jnp.array([[1.0, 2.0]]) + self.mass = jnp.array([1.0, 2.0]) class MockStars: def __init__(self): - self.coords = jnp.array([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]]) - self.pixel_assignment = jnp.array([[0, 1]]) - self.mass = jnp.array([[1.0, 2.0]]) - self.spectra = jnp.array([[[1.0, 2.0], [3.0, 4.0]]]) + self.coords = jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + self.pixel_assignment = jnp.array([0, 1]) + self.mass = jnp.array([1.0, 2.0]) + self.spectra = jnp.array([[1.0, 2.0], [3.0, 4.0]]) class MockRubixData(RubixData): def __init__(self): diff --git a/tests/test_telescope_factory.py b/tests/test_telescope_factory.py index 044159b..81a9caf 100644 --- a/tests/test_telescope_factory.py +++ b/tests/test_telescope_factory.py @@ -12,9 +12,7 @@ SQUARE_APERTURE, ) from rubix.telescope.base import BaseTelescope -from rubix.telescope.factory import ( - TelescopeFactory, -) +from rubix.telescope.factory import TelescopeFactory jax.config.update("jax_platform_name", "cpu")