Skip to content

Commit f52d01a

Browse files
BlaziusMaximusOrbax Authors
authored andcommitted
Allow unexpected keys in restore_item during partial restoration.
PiperOrigin-RevId: 823229099
1 parent e4d8241 commit f52d01a

File tree

2 files changed

+26
-38
lines changed

2 files changed

+26
-38
lines changed

checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer_test_utils.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -767,30 +767,6 @@ def test_partial_restore_with_omission(self):
767767
)
768768
test_utils.assert_tree_equal(self, expected, restored)
769769

770-
with self.subTest('extra_leaf'):
771-
with self.checkpointer(
772-
PyTreeCheckpointHandler()
773-
) as restore_checkpointer:
774-
reference_item = {
775-
'a': 0,
776-
'c': {
777-
'a': 0,
778-
},
779-
'z': 0,
780-
}
781-
with self.assertRaisesRegex(
782-
ValueError,
783-
'Missing keys were found in the user-provided restore item.',
784-
):
785-
restore_checkpointer.restore(
786-
directory,
787-
args=pytree_checkpoint_handler.PyTreeRestoreArgs(
788-
item=reference_item,
789-
restore_args=self.pytree_restore_args,
790-
partial_restore=True,
791-
),
792-
)
793-
794770
def test_restore_logs_read_event(self):
795771
"""Tests that restore logs a read event to DM Sawmill log."""
796772
with self.checkpointer(PyTreeCheckpointHandler()) as checkpointer:

checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -835,8 +835,9 @@ def _partial_restore_with_omission(
835835

836836
try:
837837
value_metadata_tree = tree_structure_utils.tree_trim(
838-
serialized_item, value_metadata_tree, strict=True
838+
serialized_item, value_metadata_tree, strict=False
839839
)
840+
value_metadata_tree = value_metadata_tree.unsafe_structure
840841
except ValueError as e:
841842
raise ValueError(
842843
'Missing keys were found in the user-provided restore item.'
@@ -856,20 +857,31 @@ def _partial_restore_with_placeholders(
856857
):
857858
"""Restores leaves from `item`, except for those marked as placeholders."""
858859
serialized_item = tree_utils.serialize_tree(item, keep_empty_nodes=True)
859-
diff = tree_structure_utils.tree_difference(
860-
serialized_item,
861-
value_metadata_tree,
862-
is_leaf=tree_utils.is_empty_or_leaf,
863-
leaves_equal=lambda a, b: True,
860+
diff = (
861+
tree_structure_utils.tree_difference(
862+
serialized_item,
863+
value_metadata_tree,
864+
is_leaf=tree_utils.is_empty_or_leaf,
865+
leaves_equal=lambda a, b: True,
866+
)
867+
or {}
864868
)
865-
if diff is not None:
866-
formatted_diff = tree_structure_utils.format_tree_diff(
867-
diff, source_label='Item', target_label='Metadata'
868-
)
869-
raise ValueError(
870-
'User-provided restore item and on-disk value metadata tree'
871-
f' structures do not match:\n{formatted_diff}'
872-
)
869+
for keypath, value_diff in tree_utils.to_flat_dict(
870+
diff, is_leaf=lambda x: isinstance(x, tree_structure_utils.Diff)
871+
).items():
872+
if value_diff.lhs is PLACEHOLDER and value_diff.rhs is None:
873+
parent = value_metadata_tree
874+
for key in keypath[:-1]:
875+
parent = parent[key]
876+
parent[keypath[-1]] = PLACEHOLDER
877+
else:
878+
formatted_diff = tree_structure_utils.format_tree_diff(
879+
diff, source_label='Item', target_label='Metadata'
880+
)
881+
raise ValueError(
882+
'User-provided restore item and on-disk value metadata tree'
883+
f' structures do not match:\n{formatted_diff}'
884+
)
873885
return jax.tree.map(
874886
lambda v, i: PLACEHOLDER if type_handlers.is_placeholder(i) else v,
875887
value_metadata_tree,

0 commit comments

Comments
 (0)