Restructure pipelinefunctions #147
Conversation
…, padding the input data; only the typechecking for the RubixData class has to be commented out because it conflicts with NamedSharding. Now test on single and multiple GPUs.
…, when directly adding to the cube, but hopefully more memory efficient; will be tested as soon as jarvis is back online.
TobiBu
left a comment
there is still some cleaning to be done.
notebooks/debug_spectra_lookup.ipynb
Outdated
do we need this just for debugging and can it be deleted once we are about to merge? Or shall it stay here?
true, we can delete this
same here. is this only for testing during development or shall it stay forever?
yes, can be deleted
rubix/config/pipeline_config.yml
Outdated
    args: []
    kwargs: {}

  calc_ifu_memory:
will this be called calc_ifu_memory forever or do we rename once we are done developing this feature?
maybe we should rename it, otherwise it could be confusing. I'll delete the original calc_ifu config, and then we also have to change it in the notebooks.
  density: "rho"
  temperature: "temp"
  metallicity: "metals"
  metals: "metals"
we have metals and metallicity. what's the difference? there should be none, right?
That's how it is in the config; metallicity should probably be removed? metals is, however, set by hand again in input_handler/pynbody, but yes, we should take another look at whether metallicity can perhaps be dropped entirely.
rubix/core/data.py
Outdated
else:
    representationString.append(f"{k}: None")
return "\n\t".join(representationString)
# def __repr__(self):
shall those commented lines stay, or can they be removed? I think this was from William's experiments, right?
Carrying this through the pipeline does not work with the sharding; I'll remove it, as it is commented out anyway.
rubix/core/pipeline.py
Outdated
# if the particle number is not divisible by the device number, we have to pad
# a few empty particles to make it work
# this is a bit of a hack, but it works
n = inputdata.stars.coords.shape[0]
pad = (num_devices - (n % num_devices)) % num_devices

if pad:
    # pad along the first axis
    inputdata.stars.coords = jnp.pad(inputdata.stars.coords, ((0, pad), (0, 0)))
    inputdata.stars.velocity = jnp.pad(
        inputdata.stars.velocity, ((0, pad), (0, 0))
    )
    inputdata.stars.mass = jnp.pad(inputdata.stars.mass, (0, pad))
    inputdata.stars.age = jnp.pad(inputdata.stars.age, (0, pad))
    inputdata.stars.metallicity = jnp.pad(
        inputdata.stars.metallicity, (0, pad)
    )
are we gonna implement this?
local_cube = out_local.stars.datacube  # shape (25, 25, 5994)
# in-XLA all-reduce across the "data" axis:
summed_cube = lax.psum(local_cube, axis_name="data")
return summed_cube  # replicated on each device
do we really have to replicate on each device???
Not sure how it works otherwise. Feel free to change it; this was the first thing that worked.
I don't quite get the logic behind this. wouldn't this allocate a lot of memory on each device? what happens on the devices afterwards that would need the result on all devices? could this be a source of memory issues?
I would propose to refactor the run_sharded function to support a more flexible approach here.
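Roughly, such a refactor could look like the sketch below (illustrative only, not the current rubix API; pipeline_fn, the reduce_on flag, and the mesh setup are placeholders, and it assumes the particle data is a single array split along its first axis, whereas the real RubixData pytree would need matching in_specs). The point is that the caller decides whether the summed cube ends up replicated on every device or gathered and reduced once on the host.

import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

def run_sharded(pipeline_fn, particles, reduce_on="device"):
    # one mesh axis "data" over all local devices
    mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

    def per_shard(local_particles):
        local_cube = pipeline_fn(local_particles)  # partial cube from this shard
        if reduce_on == "device":
            # all-reduce: every device ends up holding the full summed cube
            return jax.lax.psum(local_cube, axis_name="data")
        return local_cube[None]  # keep the partial cube, one per device

    out_spec = P() if reduce_on == "device" else P("data")
    sharded = shard_map(per_shard, mesh=mesh, in_specs=P("data"), out_specs=out_spec)

    cube = sharded(particles)
    if reduce_on == "host":
        # gather the per-device partial cubes and sum them once on the host,
        # so no device has to hold a replicated full-size cube
        cube = np.asarray(jax.device_get(cube)).sum(axis=0)
    return cube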
did this refactor happen or shall we open an issue for this?
    getattr(self.sim, cls), fields[cls], units[cls], cls
)

# for cls in self.data:
do we still need all those commented lines?
removed it
tests/test_core_ifu.py
Outdated
print("Sample_inputs:")
for key in sample_inputs:
    sample_inputs[key] = reshape_array(sample_inputs[key])
    # sample_inputs[key] = reshape_array(sample_inputs[key])
do we need this commented line?
no, removed it
could we maybe move the notebooks into a separate PR to reduce the size a little?
MaHaWo
left a comment
a few comments here and there, but I would prefer if the run_sharded function could be refactored to allow passing a device configuration and perhaps an output device. Maybe the actual sharding could happen in a separate function too, to structure the code a little better.
notebooks/rubix_pipeline_sharding.py
Outdated
@@ -0,0 +1,114 @@
import os
I wouldn't put test code like this file into the code base.
| "source": [ | ||
| "# NBVAL_SKIP\n", | ||
| "#import os\n", | ||
| "# os.environ['SPS_HOME'] = '/mnt/storage/annalena_data/sps_fsps'\n", |
maybe don't put in hardcoded paths anywhere. this is not executable by anyone other than yourself. rather leave it open and explain what to do here perhaps?
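One possible way to do that (illustrative only; the exact message and variable handling are up to you): read the template location from the environment and fail with a hint instead of hardcoding a machine-specific path.

import os

# SPS_HOME must point to your local SSP/FSPS template directory; set it in your
# shell (e.g. export SPS_HOME=/path/to/sps_templates) before running the notebook
if "SPS_HOME" not in os.environ:
    raise RuntimeError(
        "Please set the SPS_HOME environment variable to your SSP template directory."
    )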
@@ -0,0 +1,377 @@
{
This looks to me a bit like it is experimental? In general, code that just tests functionality or exists for experimenting should not be merged into main imho. Rather, put it somewhere else or remove it again before the final version of the branch is merged.
| "met = inputdata.stars.metallicity\n", | ||
| "factor = 1\n", | ||
| "inputdata.stars.coords = jnp.concatenate([coords]*factor, axis=0)\n", | ||
| "inputdata.stars.velocity = jnp.concatenate([vel]*factor, axis=0)\n", |
dummy data? maybe add a markdown cell to make clear the purpose?
| "source": [ | ||
| "# NBVAL_SKIP\n", | ||
| "import jax.numpy as jnp\n", | ||
| "gpu_number = jnp.array([1, 2, 3, 4, 5, 6, 7])\n", |
what are all these hardcoded numbers? if these are performance experiments, that should not go here...
pynbody.analysis.angmom.faceon(halo.s)
ang_mom_vec = pynbody.analysis.angmom.ang_mom_vec(halo.s)
rotation_matrix = pynbody.analysis.angmom.calc_sideon_matrix(ang_mom_vec)
if not os.path.exists("./data"):
is this path general/robust enough? In general, it would be better to avoid hardcoding paths. Rather draw them from an environment variable or config ...
tests/test_pynbody_handler.py
Outdated
def handler_with_mock_data(mock_simulation, mock_config):
    """
    with patch("pynbody.load", return_value=mock_simulation):
        with patch("pynbody.analysis.angmom.faceon", return_value=None):
same as before, what is this multiline comment for?
expected_result = jnp.stack(
def test_get_calculate_datacube_particlewise():
    # Setup config and telescope
    config = {
it could make sense to turn these into fixtures, along with the other classes and configs that exist in this file.
tests/test_core_pipeline.py
Outdated
# The cube should have nonzero values (sanity check)
assert jnp.any(output_cube != 0)

print("run_sharded output shape:", output_cube.shape)
could we remove these?
n_particles = num_devices if num_devices > 1 else 2  # At least two for sanity

# Mock input data
input_data = RubixData(
could this conceivably be a fixture?
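One possible shape for such a fixture (illustrative only; the field names are taken from the diff above, and the actual RubixData construction should keep mirroring what the test currently does inline):

import jax
import jax.numpy as jnp
import pytest

@pytest.fixture
def mock_star_arrays():
    """Minimal star arrays sized to the local device count, shared across tests."""
    num_devices = jax.device_count()
    n_particles = num_devices if num_devices > 1 else 2  # at least two for sanity
    return {
        "coords": jnp.zeros((n_particles, 3)),
        "velocity": jnp.zeros((n_particles, 3)),
        "mass": jnp.ones(n_particles),
        "age": jnp.ones(n_particles),
        "metallicity": jnp.full(n_particles, 0.02),
    }

Each test would then build its RubixData mock from these arrays, keeping the setup in one place.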
local_cube = out_local.stars.datacube  # shape (25, 25, 5994)
# in-XLA all-reduce across the "data" axis:
summed_cube = lax.psum(local_cube, axis_name="data")
return summed_cube  # replicated on each device
did this refactor happen or shall we open an issue for this?
When calculating the datacube for many stellar particles (on the order of hundreds of thousands), we had memory issues, and the final datacube had negative flux in a lot of spaxels.
We thought this was solved by switching from pmap to shard_map, but the error still occurred. The speculation was that we got an over/underflow at the point where the pipeline was assigning the spectra to the stellar particles: rubix held all spectra for the individual particles in memory, which led to a memory spike, and on GPUs the code even failed because it ran out of memory.
Therefore this branch changes the spectra assignment and the datacube calculation. Rubix now looks up the spectrum for one particle, mass-weights it, Doppler-shifts it, resamples it, and then adds the spectrum to the spaxel in the datacube according to the spaxel_assignment; we lax.scan over all particles. Testing on the MaStar SSP template, this removes our issue with negative flux, and at the same time the computation time does not increase for this method. For comparison, see the notebook rubix_pipeline_single_function_shard_map_memory.ipynb.
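A minimal, self-contained sketch of that accumulation pattern (stand-in names and a nearest-neighbour lookup in place of the real SSP lookup, Doppler shift, and resampling; these are not the actual rubix functions): only one spectrum is in flight at a time, and the carried cube is the only full-size array in memory during the scan.

import jax
import jax.numpy as jnp

def lookup_spectrum(age, metallicity, templates, age_grid, met_grid):
    # nearest-neighbour stand-in for the real lookup + Doppler shift + resampling
    i = jnp.argmin(jnp.abs(age_grid - age))
    j = jnp.argmin(jnp.abs(met_grid - metallicity))
    return templates[i, j]

def build_datacube(ages, mets, masses, spaxel_idx, templates, age_grid, met_grid, n_spaxels):
    def add_particle(cube, particle):
        age, met, mass, idx = particle
        spectrum = lookup_spectrum(age, met, templates, age_grid, met_grid)
        # mass-weight the spectrum and add it to the particle's assigned spaxel
        cube = cube.at[idx].add(mass * spectrum)
        return cube, None  # no per-step output, only the carried cube

    init = jnp.zeros((n_spaxels, templates.shape[-1]))
    cube, _ = jax.lax.scan(add_particle, init, (ages, mets, masses, spaxel_idx))
    return cube

# toy usage with random stand-in data: a 20x10 template grid, 842 wavelength bins
key = jax.random.PRNGKey(0)
templates = jax.random.uniform(key, (20, 10, 842))
age_grid = jnp.linspace(0.0, 14.0, 20)
met_grid = jnp.linspace(-2.0, 0.5, 10)
n_part = 1000
ages = jax.random.uniform(key, (n_part,), minval=0.0, maxval=14.0)
mets = jax.random.uniform(key, (n_part,), minval=-2.0, maxval=0.5)
masses = jnp.ones(n_part)
spaxels = jax.random.randint(key, (n_part,), 0, 25 * 25)
cube = build_datacube(ages, mets, masses, spaxels, templates, age_grid, met_grid, 25 * 25)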
I am opening the pull request already to get feedback at an early stage of this major change in the code structure. Things still to do on this branch before we merge it into main:
Left: old method; right: new method with lax.scan using the MaStar template.