@@ -27,21 +27,23 @@ class PipelineOp(ABC):
27
27
Prefix for output variables when using pattern matching
28
28
"""
29
29
30
- def __init__ (self ,
31
- name : Optional [str ] | List [str ] = None ,
32
- input_variable : Optional [str ] | List [str ] = None ,
33
- output_variable : Optional [str ] | List [str ] = None ,
34
- input_prefix : Optional [str ] | List [str ] = None ,
35
- output_prefix : Optional [str ] | List [str ] = None ):
30
+ def __init__ (
31
+ self ,
32
+ name : Optional [str ] | List [str ] = None ,
33
+ input_variable : Optional [str ] | List [str ] = None ,
34
+ output_variable : Optional [str ] | List [str ] = None ,
35
+ input_prefix : Optional [str ] | List [str ] = None ,
36
+ output_prefix : Optional [str ] | List [str ] = None ,
37
+ ):
36
38
37
39
if all (x is None for x in [input_variable , output_variable , input_prefix , output_prefix ]):
38
40
warnings .warn (
39
- ' No input/output information set for PipelineOp...this is likely an error' ,
40
- stacklevel = 2
41
+ " No input/output information set for PipelineOp...this is likely an error" ,
42
+ stacklevel = 2 ,
41
43
)
42
44
43
45
if name is None :
44
- self .name = ' PipelineOp'
46
+ self .name = " PipelineOp"
45
47
else :
46
48
self .name = name
47
49
@@ -60,22 +62,22 @@ def __init__(self,
60
62
pass
61
63
62
64
# variables to exclude when constructing attrs dict for xarray
63
- self ._banned_from_attrs = [' output' , ' _banned_from_attrs' ]
65
+ self ._banned_from_attrs = [" output" , " _banned_from_attrs" ]
64
66
65
67
@abstractmethod
66
68
def calculate (self , dataset : xr .Dataset ) -> Self :
67
69
pass
68
70
69
71
def __repr__ (self ) -> str :
70
- return f' <PipelineOp:{ self .name } >'
72
+ return f" <PipelineOp:{ self .name } >"
71
73
72
74
def copy (self ) -> Self :
73
75
return copy .deepcopy (self )
74
76
75
77
def _prefix_output (self , variable_name : str ) -> str :
76
78
prefixed_variable = copy .deepcopy (variable_name )
77
79
if self .output_prefix is not None :
78
- prefixed_variable = f' { self .output_prefix } _{ prefixed_variable } '
80
+ prefixed_variable = f" { self .output_prefix } _{ prefixed_variable } "
79
81
return prefixed_variable
80
82
81
83
def _get_attrs (self ) -> Dict :
@@ -86,7 +88,7 @@ def _get_attrs(self) -> Dict:
86
88
except KeyError :
87
89
pass
88
90
89
- #sanitize
91
+ # sanitize
90
92
for key in output_dict .keys ():
91
93
output_dict [key ] = str (output_dict [key ])
92
94
# if output_dict[key] is None:
@@ -98,16 +100,20 @@ def _get_attrs(self) -> Dict:
98
100
99
101
def _get_variable (self , dataset : xr .Dataset ) -> xr .DataArray :
100
102
if self .input_variable is None and self .input_prefix is None :
101
- raise ValueError ((
102
- """Can't get variable for {self.name} without input_variable """
103
- """or input_prefix specified in constructor """
104
- ))
103
+ raise ValueError (
104
+ (
105
+ """Can't get variable for {self.name} without input_variable """
106
+ """or input_prefix specified in constructor """
107
+ )
108
+ )
105
109
106
110
if self .input_variable is not None and self .input_prefix is not None :
107
- raise ValueError ((
108
- """Both input_variable and input_prefix were specified in constructor. """
109
- """Only one should be specified to avoid ambiguous operation"""
110
- ))
111
+ raise ValueError (
112
+ (
113
+ """Both input_variable and input_prefix were specified in constructor. """
114
+ """Only one should be specified to avoid ambiguous operation"""
115
+ )
116
+ )
111
117
112
118
if self .input_variable is not None :
113
119
output = dataset [self .input_variable ].copy ()
@@ -137,10 +143,12 @@ def add_to_dataset(self, dataset, copy_dataset=True):
137
143
value .attrs .update (self ._get_attrs ())
138
144
dataset1 [name ] = value
139
145
else :
140
- raise ValueError ((
141
- f"""Items in output dictionary of PipelineOp { self .name } must be xr.Dataset or xr.DataArray """
142
- f"""Found variable named { name } of type { type (value )} ."""
143
- ))
146
+ raise ValueError (
147
+ (
148
+ f"""Items in output dictionary of PipelineOp { self .name } must be xr.Dataset or xr.DataArray """
149
+ f"""Found variable named { name } of type { type (value )} ."""
150
+ )
151
+ )
144
152
return dataset1
145
153
146
154
def add_to_tiled (self , tiled_data ):
@@ -150,20 +158,24 @@ def add_to_tiled(self, tiled_data):
150
158
# for name, dataarray in self.output.items():
151
159
# tiled_data.add_array(name, value.values)
152
160
153
- def plot (self ,** mpl_kwargs ) -> plt .Figure :
161
+ def plot (self , sample_dim : str = "sample" , ** mpl_kwargs ) -> plt .Figure :
162
+ """Plots the output of the PipelineOp.
163
+
164
+ This method attempts to guess how to plot the data produced by the operation.
165
+ """
154
166
n = len (self .output )
155
- if n > 0 :
156
- fig , axes = plt .subplots (n ,1 , figsize = (8 , n * 4 ))
157
- if n > 1 :
167
+ if n > 0 :
168
+ fig , axes = plt .subplots (n , 1 , figsize = (6 , n * 3 ))
169
+ if n > 1 :
158
170
axes = list (axes .flatten ())
159
171
else :
160
172
axes = [axes ]
161
173
162
- for i ,(name ,data ) in enumerate (self .output .items ()):
163
- if 'sample' in data .dims :
164
- data = data .plot (hue = 'sample' , ax = axes [i ],** mpl_kwargs )
174
+ for i , (name , data ) in enumerate (self .output .items ()):
175
+ if data . ndim > 1 and ( sample_dim in data .dims ) :
176
+ data .plot (hue = sample_dim , ax = axes [i ], ** mpl_kwargs )
165
177
else :
166
- data .plot (ax = axes [i ],** mpl_kwargs )
178
+ data .plot (ax = axes [i ], ** mpl_kwargs )
167
179
axes [i ].set (title = name )
168
180
return fig
169
181
else :
0 commit comments