@@ -835,8 +835,9 @@ def _partial_restore_with_omission(
835835
836836 try :
837837 value_metadata_tree = tree_structure_utils .tree_trim (
838- serialized_item , value_metadata_tree , strict = True
838+ serialized_item , value_metadata_tree , strict = False
839839 )
840+ value_metadata_tree = value_metadata_tree .unsafe_structure
840841 except ValueError as e :
841842 raise ValueError (
842843 'Missing keys were found in the user-provided restore item.'
@@ -856,20 +857,31 @@ def _partial_restore_with_placeholders(
856857 ):
857858 """Restores leaves from `item`, except for those marked as placeholders."""
858859 serialized_item = tree_utils .serialize_tree (item , keep_empty_nodes = True )
859- diff = tree_structure_utils .tree_difference (
860- serialized_item ,
861- value_metadata_tree ,
862- is_leaf = tree_utils .is_empty_or_leaf ,
863- leaves_equal = lambda a , b : True ,
860+ diff = (
861+ tree_structure_utils .tree_difference (
862+ serialized_item ,
863+ value_metadata_tree ,
864+ is_leaf = tree_utils .is_empty_or_leaf ,
865+ leaves_equal = lambda a , b : True ,
866+ )
867+ or {}
864868 )
865- if diff is not None :
866- formatted_diff = tree_structure_utils .format_tree_diff (
867- diff , source_label = 'Item' , target_label = 'Metadata'
868- )
869- raise ValueError (
870- 'User-provided restore item and on-disk value metadata tree'
871- f' structures do not match:\n { formatted_diff } '
872- )
869+ for keypath , value_diff in tree_utils .to_flat_dict (
870+ diff , is_leaf = lambda x : isinstance (x , tree_structure_utils .Diff )
871+ ).items ():
872+ if value_diff .lhs is PLACEHOLDER and value_diff .rhs is None :
873+ parent = value_metadata_tree
874+ for key in keypath [:- 1 ]:
875+ parent = parent [key ]
876+ parent [keypath [- 1 ]] = PLACEHOLDER
877+ else :
878+ formatted_diff = tree_structure_utils .format_tree_diff (
879+ diff , source_label = 'Item' , target_label = 'Metadata'
880+ )
881+ raise ValueError (
882+ 'User-provided restore item and on-disk value metadata tree'
883+ f' structures do not match:\n { formatted_diff } '
884+ )
873885 return jax .tree .map (
874886 lambda v , i : PLACEHOLDER if type_handlers .is_placeholder (i ) else v ,
875887 value_metadata_tree ,
0 commit comments