@@ -38,17 +38,24 @@ class HasSharding(tp.Protocol):
38
38
def _has_sharding (x : tp .Any ) -> tp .TypeGuard [HasSharding ]:
39
39
return hasattr (x , 'sharding' ) and x .sharding is not None
40
40
41
- def add_axis (tree : A , index : int , params : tp .Mapping ) -> A :
42
- axis_name = _get_partition_name (params )
41
+ def add_axis (tree : A , index : int , transform_metadata : tp .Mapping ) -> A :
42
+ axis_name , other_meta = _get_partition_name_and_metadata (transform_metadata )
43
+
44
+ def insert_field (fields , index , value ):
45
+ iterable = list (fields )
46
+ while len (iterable ) < index :
47
+ iterable .append (None )
48
+ iterable .insert (index , value )
49
+ return tuple (iterable )
43
50
44
51
def _add_axis (x : tp .Any ):
45
52
if isinstance (x , variablelib .VariableState ):
46
53
if _has_sharding (x ) and x .sharding is not None :
47
- sharding : list [ str | None ] = list (x .sharding )
48
- while len ( sharding ) < index :
49
- sharding . append ( None )
50
- sharding . insert ( index , axis_name )
51
- x . sharding = tuple ( sharding ) # type: ignore
54
+ x . sharding = insert_field (x .sharding , index , axis_name )
55
+
56
+ for k , v in other_meta . items ():
57
+ if hasattr ( x , k ) and ( t := getattr ( x , k )) and isinstance ( t , tuple ):
58
+ setattr ( x , k , insert_field ( t , index , v ))
52
59
53
60
assert isinstance (x , variablelib .VariableState )
54
61
x .add_axis (index , axis_name )
@@ -59,15 +66,23 @@ def _add_axis(x: tp.Any):
59
66
)
60
67
61
68
62
- def remove_axis (tree : A , index : int , params : tp .Mapping [tp .Any , tp .Any ]) -> A :
63
- axis_name = _get_partition_name (params )
69
+ def remove_axis (tree : A , index : int , transform_metadata : tp .Mapping [tp .Any , tp .Any ]) -> A :
70
+ axis_name , other_meta = _get_partition_name_and_metadata (transform_metadata )
71
+
72
+ def remove_field (fields , index , value ):
73
+ iterable = list (fields )
74
+ assert iterable .pop (index ) == value
75
+ return tuple (iterable )
64
76
65
77
def _remove_axis (x : tp .Any ):
66
78
if isinstance (x , variablelib .VariableState ):
67
79
if hasattr (x , 'sharding' ) and x .sharding is not None :
68
- sharding = list (x .sharding )
69
- assert sharding .pop (index ) == axis_name
70
- x .sharding = tuple (sharding )
80
+ x .sharding = remove_field (x .sharding , index , axis_name )
81
+
82
+ for k , v in other_meta .items ():
83
+ if hasattr (x , k ) and (t := getattr (x , k )) and isinstance (t , tuple ):
84
+ setattr (x , k , remove_field (t , index , v ))
85
+
71
86
x .remove_axis (index , axis_name )
72
87
return x
73
88
@@ -78,13 +93,17 @@ def _remove_axis(x: tp.Any):
78
93
)
79
94
80
95
81
- def _get_partition_name (params : tp .Mapping [tp .Any , tp .Any ]) -> str :
82
- if PARTITION_NAME not in params :
96
+ def _get_partition_name_and_metadata (
97
+ transform_metadata : tp .Mapping [tp .Any , tp .Any ]
98
+ ) -> tuple [str , tp .Mapping [tp .Any , tp .Any ]]:
99
+ if PARTITION_NAME not in transform_metadata :
83
100
raise ValueError (
84
101
'Trying to transform a Partitioned variable but "partition_name" '
85
- f'is not specified in scan_metadata : { params } '
102
+ f'is not specified in transform_metadata : { transform_metadata } '
86
103
)
87
- return params [PARTITION_NAME ]
104
+ other_meta = dict (transform_metadata ) # shallow copy
105
+ other_meta .pop (PARTITION_NAME )
106
+ return transform_metadata [PARTITION_NAME ], other_meta
88
107
89
108
90
109
def get_partition_spec (tree : A ) -> A :
0 commit comments