
Refactor Pipeline parallelization using jax.sharding #141

@ufuk-cakir

Description


The Problem

The current parallelization strategy in the pipeline uses jax.pmap and has some limitations:

  1. Manual reshaping during the pipeline -- a reshape_data step is required to manually prepare the data arrays with a leading dimension matching the number of devices.
  2. Limited flexibility -- implementing more complex strategies in the future (like model parallelism) is less natural.
  3. Potential optimizations -- jax.sharding integrates more deeply with the XLA compiler, with potential performance optimizations.
  4. jit(pmap) causes issues -- pmap is called inside the pipeline, and the final function is then jitted. This is a known problem, see jit(pmap(f)) causes inefficient behavior jax-ml/jax#2926.
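
For context, the pattern described in points 1 and 4 looks roughly like the following (a simplified sketch, not the actual pipeline code; pipeline_step and reshape_data are illustrative names):

import jax
import jax.numpy as jnp

num_devices = jax.device_count()

def pipeline_step(x):
    # illustrative per-device computation
    return jnp.sqrt(x)

def reshape_data(arr):
    # manually add a leading device axis:
    # (num_particles, ...) -> (num_devices, num_particles // num_devices, ...)
    return arr.reshape(num_devices, -1, *arr.shape[1:])

# pmap is applied inside the pipeline and the result is jitted again,
# i.e. the jit(pmap(f)) pattern flagged in jax-ml/jax#2926
pipeline = jax.jit(jax.pmap(pipeline_step))

particle_data = jnp.ones((8 * num_devices, 3))
result = pipeline(reshape_data(particle_data))   # shape (num_devices, 8, 3)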

Proposed Solution

Implement the parallelism using the jax.sharding API. For a general introduction, see the JAX documentation here.

  1. Define a logical Mesh representing the GPU devices:
    We need a 1D mesh, since we only want to split the data along the particle axis.
devices = jax.devices()
mesh = jax.sharding.Mesh(devices, axis_names=("data",))  # 1D mesh with a single axis named "data"
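
For reference, the resulting mesh can be inspected directly; on a hypothetical 4-GPU node:

print(mesh.axis_names)     # ('data',)
print(mesh.devices.shape)  # (4,) -- one entry per GPU
print(mesh.shape)          # maps the axis name "data" to its size, here 4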
  2. Use PartitionSpec and NamedSharding to define how data should be distributed (sharded) or replicated across the mesh. Assuming the particle data has shape (num_particles, ...), we want to split it along the first axis, which we named "data".
    Here we need to decide which arrays are split across devices and which are copied. For example, we want to shard e.g. mass and metallicity, but keep e.g. spatial_bin_edges as a copy on all GPU devices.
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec

particle_sharding = NamedSharding(mesh, PartitionSpec("data", None))

This means the array is split along its first dimension across the devices, while the remaining dimensions stay intact. With N particles per device:

GPU 0 gets: particle_data[0:N, :]
GPU 1 gets: particle_data[N:2N, :]
GPU 2 gets: particle_data[2N:3N, :]
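
One way to check that the placement actually matches this picture is to device_put a test array and visualize its sharding; jax.debug.visualize_array_sharding prints which device holds which slice (particle_sharding is defined above, the test array is just a stand-in):

import jax.numpy as jnp

particle_data = jnp.zeros((12_000, 3))
particle_data = jax.device_put(particle_data, particle_sharding)

print(particle_data.sharding)                       # NamedSharding with spec PartitionSpec('data', None)
jax.debug.visualize_array_sharding(particle_data)   # one block of rows per GPU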

For other components that we need to copy to all devices, the sharding is defined with None for every dimension:

replicated_sharding = NamedSharding(mesh, PartitionSpec(None))           # for scalars and 1D values
replicated_2d_sharding = NamedSharding(mesh, PartitionSpec(None, None))  # for 2D values, e.g. spatial_bin_edges
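
A replicated array placed with one of these shardings ends up as a full copy on every device; a quick check (sketch, the array here is just a stand-in):

import jax.numpy as jnp

bin_edges = jnp.linspace(-15.0, 15.0, 26).reshape(1, -1)       # stand-in 2D array
bin_edges = jax.device_put(bin_edges, replicated_2d_sharding)

print(bin_edges.sharding.is_fully_replicated)   # True
print(len(bin_edges.addressable_shards))        # one shard per local device, each holding the full array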

We then need to define this for all the fields in the RubixData pytree, which could look like the following:

particle_sharding = NamedSharding(mesh, PartitionSpec('data', None))
replicated_sharding = NamedSharding(mesh, PartitionSpec(None))
datacube_replicated_sharding = NamedSharding(mesh, PartitionSpec(None, None, None))
replicated_2d_sharding = NamedSharding(mesh, PartitionSpec(None, None))


# Use None for fields not present in a specific run
galaxy_sharding_spec = Galaxy(
    redshift=replicated_sharding,
    center=replicated_2d_sharding,  # replicated; use PartitionSpec(None) if center is 1D, (None, None) if 2D
    halfmassrad_stars=replicated_sharding
)

stars_sharding_spec = StarsData(
    coords=particle_sharding,           # Shard particle fields
    velocity=particle_sharding,
    mass=particle_sharding,
    metallicity=particle_sharding,
    age=particle_sharding,
    pixel_assignment=particle_sharding,
    spatial_bin_edges=replicated_sharding, # Replicate global info
    mask=particle_sharding,
    spectra=particle_sharding,          # Intermediate spectra are sharded
    datacube=datacube_replicated_sharding # Final cube replicated
)

# Define also for gas
gas_sharding_spec = ...

# The complete sharding specification Pytree
rubixdata_sharding_spec = RubixData(
    galaxy=galaxy_sharding_spec,
    stars=stars_sharding_spec,
    gas=gas_sharding_spec 
)
  3. Place the data onto the devices with the specified sharding:
particle_data_sharded = jax.device_put(particle_data_cpu, particle_sharding)
scalar_param_replicated = jax.device_put(scalar_param_cpu, replicated_sharding)
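
Since RubixData, Galaxy, StarsData, etc. are pytrees, the whole container can in principle be placed in a single call by passing the matching pytree of shardings to jax.device_put (a sketch, assuming rubixdata is a populated RubixData instance whose structure matches rubixdata_sharding_spec):

# distribute/replicate every leaf of the RubixData pytree according to the
# sharding spec defined above; leaves and shardings are matched by pytree structure
rubixdata_sharded = jax.device_put(rubixdata, rubixdata_sharding_spec)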
  4. Replace the pmap calls with jax.jit and annotate the function inputs and outputs with their respective NamedSharding specifications (see the snippet below).
  5. Use jax.lax.psum within the parallelized function for cross-device reduction (e.g. summing partial datacubes); a sketch is given after this list.
from functools import partial

@partial(jax.jit,
         # how the inputs ARE sharded when the function is called
         in_shardings=(particle_sharding, replicated_sharding),
         # how the output SHOULD be sharded
         out_shardings=particle_sharding)
def some_function(data, param):
    ...
  6. Remove the reshape_data step entirely.
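
Regarding point 5: jax.lax.psum needs a named mesh axis, which under jit is typically obtained via jax.experimental.shard_map rather than pmap. A minimal sketch of the datacube reduction under these assumptions, reusing mesh and particle_data_sharded from the snippets above; compute_partial_datacube and the cube shape are hypothetical stand-ins:

from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P

def compute_partial_datacube(particles):
    # hypothetical stand-in: reduce the local particle shard to a partial datacube
    return jnp.zeros((25, 25, 100)) + jnp.sum(particles)

@partial(shard_map, mesh=mesh, in_specs=P("data", None), out_specs=P())
def accumulate_datacube(particle_shard):
    # each device reduces its own particle shard ...
    partial_cube = compute_partial_datacube(particle_shard)
    # ... and the partial cubes are summed across the "data" axis,
    # leaving an identical, fully replicated datacube on every device
    return jax.lax.psum(partial_cube, axis_name="data")

datacube = jax.jit(accumulate_datacube)(particle_data_sharded)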
