Labels: enhancement (New feature or request)
Description
The Problem
The current parallelization strategy in the pipeline uses `jax.pmap` and has some limitations:
- Manual reshaping during the pipeline -- requires the `reshape_data` step to manually prepare data arrays with a leading dimension matching the number of devices.
- Limited flexibility -- implementing more complex strategies in the future (like model parallelism) is less natural.
- Potential optimizations -- `jax.sharding` integrates more deeply with the XLA compiler, with potential performance optimizations.
- `jit(pmap)` causes issues -- `pmap` is called inside the pipeline, and the final function is then jitted. This causes known issues; see "`jit(pmap(f))` causes inefficient behavior" jax-ml/jax#2926.
Proposed Solution
Implement the parallelism using the `jax.sharding` API (see the JAX documentation for a general introduction).
- Define a logical `Mesh` representing the GPU devices. We need a 1D mesh where we split the data across the particle axis:

```python
import numpy as np
import jax
from jax.sharding import Mesh

devices = np.array(jax.devices())
num_devices = len(devices)
mesh = Mesh(devices, axis_names=("data",))
```

- Use `PartitionSpec` and `NamedSharding` to define how data should be distributed (sharded) or replicated across the mesh. Assuming particle data has shape `(num_particles, ...)`, we want to split the data along the first axis, which we named `"data"`. Here we need to define which arrays are split across devices and which are copied: for example, we want to split e.g. `mass` and `metallicity`, but keep e.g. `spatial_bin_edges` as a copy on all GPU devices.
```python
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P

particle_sharding = NamedSharding(mesh, P("data", None))
```

This means: split the array along its first dimension across the devices, but keep everything else intact, i.e.

```
GPU 0 gets: particle_data[0:N, :]
GPU 1 gets: particle_data[N:2N, :]
GPU 2 gets: particle_data[2N:3N, :]
```
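To make the split concrete, here is a minimal self-contained sketch (it also runs on a single device). The array `particle_data` and its shape are placeholders for illustration, not the pipeline's actual data:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
particle_sharding = NamedSharding(mesh, P("data", None))

# Toy particle array: num_particles rows, 3 columns (e.g. coords).
num_particles = 8 * len(devices)
particle_data = jnp.arange(num_particles * 3, dtype=jnp.float32).reshape(num_particles, 3)

sharded = jax.device_put(particle_data, particle_sharding)
# The global shape is unchanged; each device holds one slice along axis 0.
print(sharded.shape)
print(len(sharded.addressable_shards))
```

`jax.Array.addressable_shards` is a handy way to verify that each device really holds only its slice of the data.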
For other components that we need to copy on all devices, the pattern is (note that the `PartitionSpec` rank must match the array rank, so a scalar uses the empty spec `P()`):

```python
replicated_sharding = NamedSharding(mesh, P())               # for scalars
replicated_1d_sharding = NamedSharding(mesh, P(None))        # for 1D arrays, e.g. spatial_bin_edges
replicated_2d_sharding = NamedSharding(mesh, P(None, None))  # for 2D arrays
```

We then need to define this for all the fields in the `RubixData` pytree, which could look like the following:
```python
particle_sharding = NamedSharding(mesh, P("data", None))
replicated_sharding = NamedSharding(mesh, P())
replicated_1d_sharding = NamedSharding(mesh, P(None))
replicated_2d_sharding = NamedSharding(mesh, P(None, None))
datacube_replicated_sharding = NamedSharding(mesh, P(None, None, None))

# Use None for fields not present in a specific run
galaxy_sharding_spec = Galaxy(
    redshift=replicated_sharding,
    center=replicated_2d_sharding,  # center is 2D, e.g. (1, 3); use replicated_1d_sharding if stored as (3,)
    halfmassrad_stars=replicated_sharding
)
stars_sharding_spec = StarsData(
    coords=particle_sharding,       # shard particle fields
    velocity=particle_sharding,
    mass=particle_sharding,
    metallicity=particle_sharding,
    age=particle_sharding,
    pixel_assignment=particle_sharding,
    spatial_bin_edges=replicated_1d_sharding,  # replicate global info
    mask=particle_sharding,
    spectra=particle_sharding,      # intermediate spectra are sharded
    datacube=datacube_replicated_sharding      # final cube replicated
)
# Define also for gas
gas_sharding_spec = ...

# The complete sharding specification pytree
rubixdata_sharding_spec = RubixData(
    galaxy=galaxy_sharding_spec,
    stars=stars_sharding_spec,
    gas=gas_sharding_spec
)
```

- Place the data onto the devices with the specified sharding:
```python
particle_data_sharded = jax.device_put(particle_data_cpu, particle_sharding)
scalar_param_replicated = jax.device_put(scalar_param_cpu, replicated_sharding)
```

- Replace `pmap` calls with `jax.jit` and annotate the function inputs and outputs with their respective `NamedSharding` specifications.
- Use `jax.lax.psum` within the jitted function for cross-device reduction (e.g. summing partial datacubes):
```python
from functools import partial

@partial(jax.jit,
         # how the inputs ARE sharded when the function is called
         in_shardings=(particle_sharding, replicated_sharding),
         # how the output SHOULD be sharded
         out_shardings=particle_sharding)
def some_function(data, param):
    ...
```

- Remove the `reshape_data` step entirely.
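One caveat on the `psum` step: `jax.lax.psum` needs a named mesh axis, which plain `jit` with shardings does not provide; the usual way to get one under the `jax.sharding` API is `shard_map` (which can itself be wrapped in `jit`). A hedged sketch of the datacube reduction, where the function and array names are illustrative rather than the pipeline's actual API:

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map            # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

@partial(shard_map, mesh=mesh, in_specs=P("data", None), out_specs=P(None))
def total_datacube(partial_contribs):
    # Each device sums its local block of per-particle contributions ...
    local = jnp.sum(partial_contribs, axis=0)
    # ... then psum reduces the partial sums across the "data" axis,
    # leaving the result replicated on every device.
    return jax.lax.psum(local, axis_name="data")

contribs = jnp.ones((8 * len(devices), 4))
cube = total_datacube(contribs)  # shape (4,), replicated on all devices
```

Alternatively, with fully automatic partitioning one can simply write a global `jnp.sum` under `jit` and let the compiler insert the cross-device reduction; `shard_map` + `psum` just makes the communication explicit.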