[JAX] Remove code that sets or tests --jax_coordination_service.
--jax_coordination_service defaults to True and has for some time, and support for the non-coordination service case will be removed shortly.

PiperOrigin-RevId: 551932242
hawkinsp authored and jax authors committed Jul 28, 2023
1 parent 9a21ff0 commit ddfdb7a
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions jax/experimental/multihost_utils.py
@@ -169,8 +169,7 @@ def reached_preemption_sync_point(step_id: int) -> bool:
 uses the next step id (i.e., max + 1) as the safe step to save a checkpoint.
 All hosts should continue training more steps until this method returns True,
 indicating that the `step_id` is equal to the safe step and the hosts should
-start saving a checkpoint. This feature requires enabling
-`jax.config.jax_coordination_service`.
+start saving a checkpoint.
 To use this API, all hosts must start training from the same step and call at
 every training step. Example usage:
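The docstring's own usage example is collapsed in this view. As a rough illustration of the calling pattern it describes (every host calls the predicate at every step and saves a checkpoint once it returns True), here is a minimal sketch. The names `run_training`, `fake_sync_point`, and `save_checkpoint` are hypothetical and do not come from the commit, and the predicate is passed in as a parameter so the loop can be tried without a multi-host setup.

```python
from typing import Callable, List


def run_training(
    start_step: int,
    num_steps: int,
    reached_sync_point: Callable[[int], bool],
    save_checkpoint: Callable[[int], None],
) -> List[int]:
    """Runs a toy loop and returns the steps at which a checkpoint was saved."""
    saved_at: List[int] = []
    for step_id in range(start_step, start_step + num_steps):
        # A real job would run one training step here (omitted in this sketch).
        # The predicate is called at every step, as the docstring requires.
        if reached_sync_point(step_id):
            save_checkpoint(step_id)
            saved_at.append(step_id)
            break  # Stopping after the checkpoint is a policy choice of this sketch.
    return saved_at


def fake_sync_point(safe_step: int) -> Callable[[int], bool]:
    """Hypothetical stand-in: reports True only at the given safe step."""
    return lambda step_id: step_id == safe_step


if __name__ == "__main__":
    saved: List[int] = []
    print(run_training(0, 10, fake_sync_point(4), saved.append))
```

In a real multi-host job, the stand-in would presumably be replaced by `jax.experimental.multihost_utils.reached_preemption_sync_point`, called after `jax.distributed.initialize()`, with all hosts starting from the same step.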