@@ -33,7 +33,7 @@ def fetch(self, data: dict[str, Any], keys: Iterable[str]):
33
33
if isinstance (current_key , int ):
34
34
return self .fetch (data [current_key ], keys )
35
35
if len (keys ) == 0 :
36
- return data [ current_key ]
36
+ return data . get ( current_key , None )
37
37
elif isinstance (data , list ):
38
38
assert current_key == "{dim}" , current_key
39
39
values = []
@@ -75,19 +75,38 @@ def parse(
75
75
else self .fetch (data , self .units_attr .split (self .nesting_delimiter ))
76
76
)
77
77
78
- if self .trim_channel_transforms :
79
- if len (axis_names ) == len (units ) == len (voxel_size ) == len (offset ):
80
- offset = [o for o , axis in zip (offset , axis_names ) if "^" not in axis ]
81
- voxel_size = [
82
- v for v , axis in zip (voxel_size , axis_names ) if "^" not in axis
83
- ]
84
- units = [u for u , axis in zip (units , axis_names ) if "^" not in axis ]
78
+ if self .trim_channel_transforms and axis_names is not None :
79
+ channel_dims = [True if "^" in axis else False for axis in axis_names ]
80
+ if sum (channel_dims ) > 0 :
81
+ if offset is not None and len (offset ) == len (channel_dims ):
82
+ offset = [
83
+ o
84
+ for o , channel_dim in zip (offset , channel_dims )
85
+ if not channel_dim
86
+ ]
87
+ if voxel_size is not None and len (voxel_size ) == len (channel_dims ):
88
+ voxel_size = [
89
+ v
90
+ for v , channel_dim in zip (voxel_size , channel_dims )
91
+ if not channel_dim
92
+ ]
93
+ if units is not None and len (units ) == len (channel_dims ):
94
+ units = [
95
+ u
96
+ for u , channel_dim in zip (units , channel_dims )
97
+ if not channel_dim
98
+ ]
99
+
100
+ offset = Coordinate (offset ) if offset is not None else None
101
+ voxel_size = Coordinate (voxel_size ) if voxel_size is not None else None
102
+ axis_names = list (axis_names ) if axis_names is not None else None
103
+ units = list (units ) if units is not None else None
85
104
86
105
metadata = MetaData (
87
- offset = Coordinate ( offset ) ,
88
- voxel_size = Coordinate ( voxel_size ) ,
89
- axis_names = list ( axis_names ) ,
90
- units = [ unit if unit is not None else "" for unit in units ] ,
106
+ offset = offset ,
107
+ voxel_size = voxel_size ,
108
+ axis_names = axis_names ,
109
+ units = units ,
91
110
)
92
111
metadata .validate ()
93
112
0 commit comments