Skip to content

Commit 136d395

Browse files
committed
[docs] Adds PipelineOp plotting to quickstart example
1 parent fb76338 commit 136d395

File tree

2 files changed

+209
-110
lines changed

2 files changed

+209
-110
lines changed

AFL/double_agent/PipelineOp.py

+45-33
Original file line numberDiff line numberDiff line change
@@ -27,21 +27,23 @@ class PipelineOp(ABC):
2727
Prefix for output variables when using pattern matching
2828
"""
2929

30-
def __init__(self,
31-
name: Optional[str] | List[str] = None,
32-
input_variable: Optional[str] | List[str] = None,
33-
output_variable: Optional[str] | List[str] = None,
34-
input_prefix: Optional[str] | List[str] = None,
35-
output_prefix: Optional[str] | List[str] = None):
30+
def __init__(
31+
self,
32+
name: Optional[str] | List[str] = None,
33+
input_variable: Optional[str] | List[str] = None,
34+
output_variable: Optional[str] | List[str] = None,
35+
input_prefix: Optional[str] | List[str] = None,
36+
output_prefix: Optional[str] | List[str] = None,
37+
):
3638

3739
if all(x is None for x in [input_variable, output_variable, input_prefix, output_prefix]):
3840
warnings.warn(
39-
'No input/output information set for PipelineOp...this is likely an error',
40-
stacklevel=2
41+
"No input/output information set for PipelineOp...this is likely an error",
42+
stacklevel=2,
4143
)
4244

4345
if name is None:
44-
self.name = 'PipelineOp'
46+
self.name = "PipelineOp"
4547
else:
4648
self.name = name
4749

@@ -60,22 +62,22 @@ def __init__(self,
6062
pass
6163

6264
# variables to exclude when constructing attrs dict for xarray
63-
self._banned_from_attrs = ['output', '_banned_from_attrs']
65+
self._banned_from_attrs = ["output", "_banned_from_attrs"]
6466

6567
@abstractmethod
6668
def calculate(self, dataset: xr.Dataset) -> Self:
6769
pass
6870

6971
def __repr__(self) -> str:
70-
return f'<PipelineOp:{self.name}>'
72+
return f"<PipelineOp:{self.name}>"
7173

7274
def copy(self) -> Self:
7375
return copy.deepcopy(self)
7476

7577
def _prefix_output(self, variable_name: str) -> str:
7678
prefixed_variable = copy.deepcopy(variable_name)
7779
if self.output_prefix is not None:
78-
prefixed_variable = f'{self.output_prefix}_{prefixed_variable}'
80+
prefixed_variable = f"{self.output_prefix}_{prefixed_variable}"
7981
return prefixed_variable
8082

8183
def _get_attrs(self) -> Dict:
@@ -86,7 +88,7 @@ def _get_attrs(self) -> Dict:
8688
except KeyError:
8789
pass
8890

89-
#sanitize
91+
# sanitize
9092
for key in output_dict.keys():
9193
output_dict[key] = str(output_dict[key])
9294
# if output_dict[key] is None:
@@ -98,16 +100,20 @@ def _get_attrs(self) -> Dict:
98100

99101
def _get_variable(self, dataset: xr.Dataset) -> xr.DataArray:
100102
if self.input_variable is None and self.input_prefix is None:
101-
raise ValueError((
102-
"""Can't get variable for {self.name} without input_variable """
103-
"""or input_prefix specified in constructor """
104-
))
103+
raise ValueError(
104+
(
105+
"""Can't get variable for {self.name} without input_variable """
106+
"""or input_prefix specified in constructor """
107+
)
108+
)
105109

106110
if self.input_variable is not None and self.input_prefix is not None:
107-
raise ValueError((
108-
"""Both input_variable and input_prefix were specified in constructor. """
109-
"""Only one should be specified to avoid ambiguous operation"""
110-
))
111+
raise ValueError(
112+
(
113+
"""Both input_variable and input_prefix were specified in constructor. """
114+
"""Only one should be specified to avoid ambiguous operation"""
115+
)
116+
)
111117

112118
if self.input_variable is not None:
113119
output = dataset[self.input_variable].copy()
@@ -137,10 +143,12 @@ def add_to_dataset(self, dataset, copy_dataset=True):
137143
value.attrs.update(self._get_attrs())
138144
dataset1[name] = value
139145
else:
140-
raise ValueError((
141-
f"""Items in output dictionary of PipelineOp {self.name} must be xr.Dataset or xr.DataArray """
142-
f"""Found variable named {name} of type {type(value)}."""
143-
))
146+
raise ValueError(
147+
(
148+
f"""Items in output dictionary of PipelineOp {self.name} must be xr.Dataset or xr.DataArray """
149+
f"""Found variable named {name} of type {type(value)}."""
150+
)
151+
)
144152
return dataset1
145153

146154
def add_to_tiled(self, tiled_data):
@@ -150,20 +158,24 @@ def add_to_tiled(self, tiled_data):
150158
# for name, dataarray in self.output.items():
151159
# tiled_data.add_array(name, value.values)
152160

153-
def plot(self,**mpl_kwargs) -> plt.Figure:
161+
def plot(self, sample_dim: str = "sample", **mpl_kwargs) -> plt.Figure:
162+
"""Plots the output of the PipelineOp.
163+
164+
This method attempts to guess how to plot the data produced by the operation.
165+
"""
154166
n = len(self.output)
155-
if n>0:
156-
fig, axes = plt.subplots(n,1,figsize=(8,n*4))
157-
if n>1:
167+
if n > 0:
168+
fig, axes = plt.subplots(n, 1, figsize=(6, n * 3))
169+
if n > 1:
158170
axes = list(axes.flatten())
159171
else:
160172
axes = [axes]
161173

162-
for i,(name,data) in enumerate(self.output.items()):
163-
if 'sample' in data.dims:
164-
data = data.plot(hue='sample',ax=axes[i],**mpl_kwargs)
174+
for i, (name, data) in enumerate(self.output.items()):
175+
if data.ndim > 1 and (sample_dim in data.dims):
176+
data.plot(hue=sample_dim, ax=axes[i], **mpl_kwargs)
165177
else:
166-
data.plot(ax=axes[i],**mpl_kwargs)
178+
data.plot(ax=axes[i], **mpl_kwargs)
167179
axes[i].set(title=name)
168180
return fig
169181
else:

0 commit comments

Comments
 (0)