Skip to content

Commit b5d4ed8

Browse files
cpgaffney1Flax Authors
authored and
Flax Authors
committed
Remove usages of orbax_utils.save_args_from_target, as this function does nothing (it used to control a checkpointing behavior that has since been optimized away).
PiperOrigin-RevId: 718571228
1 parent e4418e2 commit b5d4ed8

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

flax/training/checkpoints.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -690,9 +690,8 @@ def save_checkpoint(
690690
' https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#if-you-don-t-save-pytrees'
691691
)
692692

693-
save_args = orbax_utils.save_args_from_target(target)
694693
orbax_checkpointer.save(
695-
ckpt_path, target, save_args=save_args, force=overwrite
694+
ckpt_path, target, force=overwrite
696695
)
697696
# Do a process check here in case people call this for multihost.
698697
if process_index() == 0:
@@ -843,9 +842,8 @@ def save_checkpoint_multiprocess(
843842
_remove_invalid_ckpts(
844843
ckpt_path, base_path, keep, overwrite, keep_every_n_steps, True
845844
)
846-
save_args = orbax_utils.save_args_from_target(target)
847845
orbax_checkpointer.save(
848-
ckpt_path, target, save_args=save_args, force=overwrite
846+
ckpt_path, target, force=overwrite
849847
)
850848
end_time = time.time()
851849
monitoring.record_event_duration_secs(

0 commit comments

Comments
 (0)