
Conversation

@anschaible
Collaborator

When calculating the datacube for many stellar particles (on the order of hundreds of thousands), we ran into memory issues: the final datacube contained negative flux in many spaxels.

We thought this was solved by switching from pmap to shard_map, but the error still occurred. Our suspicion was an over/underflow at the point where the pipeline assigns the spectra to the stellar particles: rubix held all spectra for the individual particles in memory at once, which led to a memory spike, and on GPUs the code even failed because it ran out of memory.

Therefore this branch changes the spectra assignment and the datacube calculation. Rubix now looks up the spectrum for one particle, mass-weights it, Doppler-shifts it, resamples it, and then adds the spectrum to the datacube at the spaxel given by the spaxel_assignment, using lax.scan over all particles. Testing on the MaStar SSP template, this removes our issue with negative flux, and the computation time does not increase with this method. For a comparison see the notebook rubix_pipeline_single_function_shard_map_memory.ipynb.
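For illustration, a minimal sketch of this particle-wise accumulation; the helpers `lookup_spectrum` and `doppler_shift_and_resample` as well as the argument names are hypothetical placeholders, not the actual rubix API:

```python
import jax.numpy as jnp
from jax import lax

def calc_datacube_particlewise(masses, velocities, spaxel_assignment,
                               lookup_spectrum, doppler_shift_and_resample,
                               n_spaxels, n_wave):
    """Accumulate one spectrum at a time into the datacube via lax.scan."""
    init_cube = jnp.zeros((n_spaxels, n_wave))

    def body(cube, i):
        spec = lookup_spectrum(i)                               # SSP template lookup
        spec = spec * masses[i]                                 # mass weighting
        spec = doppler_shift_and_resample(spec, velocities[i])  # shift + resample
        cube = cube.at[spaxel_assignment[i]].add(spec)          # add to assigned spaxel
        return cube, None

    cube, _ = lax.scan(body, init_cube, jnp.arange(masses.shape[0]))
    return cube
```

Because only one spectrum is materialised per scan step instead of one per particle, the peak memory stays flat while lax.scan keeps the whole loop inside a single compiled computation.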

I am opening the pull request already to get feedback at an early stage of this major change in the code structure. Things still to do on this branch before we merge it into main:

  • test whether the memory issue is also resolved for GPUs (so far only tested on CPUs)
  • fsps is still behaving strangely, even for very few particles (e.g. 100); look into the template to see what is going wrong there
  • fix the pytests
  • clean up the code by removing old functions once we have agreed on a code version

Left: old method; right: new method with lax.scan, using the MaStar template.


anschaible and others added 30 commits March 26, 2025 18:00
…, padding the input data; only the typechecking for the RubixData class has to be commented out because it conflicts with NamedSharding; now tested on single and multiple GPUs
…, when directly adding to the cube, but hopefully more memory efficient; will be tested as soon as jarvis is back online
Collaborator

@TobiBu left a comment

there is still some cleaning to be done.

Collaborator

do we need this just for debugging and can it be deleted once we are about to merge? Or shall it stay here?

Collaborator Author

true, we can delete this

Collaborator

same here. is this only for testing during development or shall it stay forever?

Collaborator Author

yes, can be deleted

args: []
kwargs: {}

calc_ifu_memory:
Collaborator

will this be called calc_ifu_memory forever or do we rename once we are done developing this feature?

Collaborator Author

Maybe we should rename it, otherwise it could be confusing. I will delete the original calc_ifu config, and then we also have to change it in the notebooks.

density: "rho"
temperature: "temp"
metallicity: "metals"
metals: "metals"
Collaborator

we have metals and metallicity. what's the difference? there should be none, right?

Collaborator Author

That is how it is in the config; metallicity should probably be removed? metals is, however, set by hand again in input_handler/pynbody, but yes, we should take another look at whether metallicity could be dropped entirely.

else:
representationString.append(f"{k}: None")
return "\n\t".join(representationString)
# def __repr__(self):
Collaborator

Shall those commented lines stay, or can they be removed? I think this was from William's experiments, right?

Collaborator Author

Carrying this through the pipeline does not work with the sharding; I will remove it, as it is commented out anyway.

Comment on lines 285 to 301
# if the particle number is not divisible by the device number, we have to pad
# a few empty particles to make it work
# this is a bit of a hack, but it works
n = inputdata.stars.coords.shape[0]
pad = (num_devices - (n % num_devices)) % num_devices

if pad:
    # pad along the first axis with empty particles
    inputdata.stars.coords = jnp.pad(inputdata.stars.coords, ((0, pad), (0, 0)))
    inputdata.stars.velocity = jnp.pad(
        inputdata.stars.velocity, ((0, pad), (0, 0))
    )
    inputdata.stars.mass = jnp.pad(inputdata.stars.mass, (0, pad))
    inputdata.stars.age = jnp.pad(inputdata.stars.age, (0, pad))
    inputdata.stars.metallicity = jnp.pad(
        inputdata.stars.metallicity, (0, pad)
    )
Collaborator

are we gonna implement this?

local_cube = out_local.stars.datacube  # shape (25, 25, 5994)
# in-XLA all-reduce across the "data" axis:
summed_cube = lax.psum(local_cube, axis_name="data")
return summed_cube  # replicated on each device
Collaborator

do we really have to replicate on each device???

Collaborator Author

Not sure how it works otherwise. Feel free to change it; this was the first thing that worked.

Collaborator

I don't quite get the logic behind this. Wouldn't this allocate a lot of memory on each device? What happens on the devices afterwards that would need the result on all devices? Could this be a source of memory issues?

I would propose to refactor the run_sharded function to support a more flexible approach here.
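For context, a minimal sketch of the pattern in question (the mesh name, `compute_local_cube` and `particles` are hypothetical placeholders): with `out_specs=P()` the `psum` result is replicated, so every device ends up holding a full copy of the summed cube.

```python
import numpy as np
import jax
from jax import lax
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

def per_device(particles_shard):
    local_cube = compute_local_cube(particles_shard)  # hypothetical per-shard cube
    # all-reduce over the device axis; together with out_specs=P() this
    # leaves an identical full cube on every device
    return lax.psum(local_cube, axis_name="data")

cube = shard_map(per_device, mesh, in_specs=P("data"), out_specs=P())(particles)
```

If the full copy per device turns out to be the memory problem, something like `lax.psum_scatter` combined with a sharded `out_specs` could keep only a slice of the cube on each device, but that would need the refactor discussed here.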

Collaborator

did this refactor happen or shall we open an issue for this?

getattr(self.sim, cls), fields[cls], units[cls], cls
)

# for cls in self.data:
Collaborator

do we still need all those commented lines?

Collaborator Author

removed it

print("Sample_inputs:")
for key in sample_inputs:
sample_inputs[key] = reshape_array(sample_inputs[key])
# sample_inputs[key] = reshape_array(sample_inputs[key])
Collaborator

do we need this commented line?

Collaborator Author

no, removed it

@anschaible linked an issue Jul 4, 2025 that may be closed by this pull request
@MaHaWo
Collaborator

MaHaWo commented Jul 17, 2025

could we maybe move the notebooks into a separate PR to reduce the size a little?

MaHaWo previously requested changes Jul 22, 2025

Collaborator

@MaHaWo left a comment

A few comments here and there, but I would prefer if the run_sharded function could be refactored to allow passing a device configuration and perhaps an output device. Maybe the actual sharding could happen in a separate function too, to structure the code a little better.
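A rough sketch of what such a signature could look like; everything here is a hypothetical proposal, not existing rubix code:

```python
import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

def run_sharded(pipeline_fn, inputdata, devices=None, out_specs=P()):
    """Run pipeline_fn sharded over the given devices.

    devices   -- which devices to use (defaults to all available ones)
    out_specs -- desired layout of the result, e.g. P() for replicated
    """
    devices = jax.devices() if devices is None else devices
    mesh = Mesh(np.array(devices), axis_names=("data",))
    sharded_fn = shard_map(pipeline_fn, mesh, in_specs=P("data"), out_specs=out_specs)
    return sharded_fn(inputdata)
```

The caller would then decide both where the computation runs and whether the final cube is replicated or sharded.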

@@ -0,0 +1,114 @@
import os
Collaborator

I wouldn't put test code like this file into the code base.

"source": [
"# NBVAL_SKIP\n",
"#import os\n",
"# os.environ['SPS_HOME'] = '/mnt/storage/annalena_data/sps_fsps'\n",
Collaborator

Maybe don't put hardcoded paths anywhere; this is not executable by anyone other than yourself. Rather leave it open and explain what to do here, perhaps?
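As a concrete illustration of the suggestion (assuming the cell needs `SPS_HOME`, as the commented line suggests), something like this could replace the hardcoded path:

```python
import os

# hypothetical sketch: require the user to point SPS_HOME at their own template directory
if "SPS_HOME" not in os.environ:
    raise RuntimeError(
        "Please set the SPS_HOME environment variable to your local SPS/FSPS template directory."
    )
```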

@@ -0,0 +1,377 @@
{
Collaborator

This looks to me a bit like it is experimental? In general, code that just tests functionality or exists for experimenting is not something that should be merged into main, imho. Rather, put it somewhere else or remove it again before the final version of the branch is merged.

"met = inputdata.stars.metallicity\n",
"factor = 1\n",
"inputdata.stars.coords = jnp.concatenate([coords]*factor, axis=0)\n",
"inputdata.stars.velocity = jnp.concatenate([vel]*factor, axis=0)\n",
Collaborator

Dummy data? Maybe add a markdown cell to make the purpose clear?

"source": [
"# NBVAL_SKIP\n",
"import jax.numpy as jnp\n",
"gpu_number = jnp.array([1, 2, 3, 4, 5, 6, 7])\n",
Collaborator

What are all these hardcoded numbers? If these are performance experiments, they should not go here...

pynbody.analysis.angmom.faceon(halo.s)
ang_mom_vec = pynbody.analysis.angmom.ang_mom_vec(halo.s)
rotation_matrix = pynbody.analysis.angmom.calc_sideon_matrix(ang_mom_vec)
if not os.path.exists("./data"):
Collaborator

is this path general/robust enough? In general, it would be better to avoid hardcoding paths. Rather draw them from an environment variable or config ...

def handler_with_mock_data(mock_simulation, mock_config):
"""
with patch("pynbody.load", return_value=mock_simulation):
with patch("pynbody.analysis.angmom.faceon", return_value=None):
Collaborator

same as before, what is this multiline comment for?

expected_result = jnp.stack(
def test_get_calculate_datacube_particlewise():
# Setup config and telescope
config = {
Collaborator

It could make sense to turn these things into fixtures, as well as the other classes and configs that exist in this file.
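A minimal sketch of the fixture idea; the config keys below are placeholders, not the actual rubix schema:

```python
import pytest

@pytest.fixture
def telescope_config():
    # hypothetical minimal config shared by the datacube tests
    return {
        "telescope": {"name": "MUSE"},
        "ssp": {"template": {"name": "MaStar"}},
    }

def test_get_calculate_datacube_particlewise(telescope_config):
    config = telescope_config
    # ... build the telescope and run the particle-wise datacube calculation ...
```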

# The cube should have nonzero values (sanity check)
assert jnp.any(output_cube != 0)

print("run_sharded output shape:", output_cube.shape)
Collaborator

could we remove these?

n_particles = num_devices if num_devices > 1 else 2 # At least two for sanity

# Mock input data
input_data = RubixData(
Collaborator

could this conceivably be a fixture?

local_cube = out_local.stars.datacube  # shape (25, 25, 5994)
# in-XLA all-reduce across the "data" axis:
summed_cube = lax.psum(local_cube, axis_name="data")
return summed_cube  # replicated on each device
Collaborator

did this refactor happen or shall we open an issue for this?

@TobiBu merged commit d9041c9 into main Nov 10, 2025
6 checks passed


Development

Successfully merging this pull request may close these issues.

Refactor Pipeline parallelization using jax.sharding

4 participants