25
25
from clu .metric_writers import utils
26
26
from clu .metric_writers .async_writer import AsyncMultiWriter
27
27
from clu .metric_writers .async_writer import AsyncWriter
28
- from clu .metric_writers .summary_writer import SummaryWriter
29
28
from clu .metric_writers .interface import MetricWriter
30
29
from clu .metric_writers .logging_writer import LoggingWriter
31
30
from clu .metric_writers .multi_writer import MultiWriter
31
+ from clu .metric_writers .summary_writer import SummaryWriter
32
32
import clu .metrics
33
33
import flax .struct
34
34
import jax .numpy as jnp
@@ -129,17 +129,20 @@ def test_write(self):
129
129
"image" : ImageMetric (jnp .asarray ([[4 , 5 ], [1 , 2 ]])),
130
130
}
131
131
histogram_metrics = {
132
- "hist" :
133
- HistogramMetric (value = jnp .asarray ([7 , 8 ]), num_buckets = num_buckets ),
134
- "hist2" :
135
- HistogramMetric (
136
- value = jnp .asarray ([9 , 10 ]), num_buckets = num_buckets ),
132
+ "hist" : HistogramMetric (
133
+ value = jnp .asarray ([7 , 8 ]), num_buckets = num_buckets
134
+ ),
135
+ "hist2" : HistogramMetric (
136
+ value = jnp .asarray ([9 , 10 ]), num_buckets = num_buckets
137
+ ),
137
138
}
138
139
audio_metrics = {
139
- "audio" :
140
- AudioMetric (value = jnp .asarray ([1 , 5 ]), sample_rate = sample_rate ),
141
- "audio2" :
142
- AudioMetric (value = jnp .asarray ([1 , 5 ]), sample_rate = sample_rate + 2 ),
140
+ "audio" : AudioMetric (
141
+ value = jnp .asarray ([1 , 5 ]), sample_rate = sample_rate
142
+ ),
143
+ "audio2" : AudioMetric (
144
+ value = jnp .asarray ([1 , 5 ]), sample_rate = sample_rate + 2
145
+ ),
143
146
}
144
147
text_metrics = {
145
148
"text" : TextMetric (value = "hello" ),
@@ -148,10 +151,10 @@ def test_write(self):
148
151
"lr" : HyperParamMetric (value = 0.01 ),
149
152
}
150
153
summary_metrics = {
151
- "summary" :
152
- SummaryMetric ( value = jnp .asarray ([2 , 3 , 10 ]), metadata = "some info" ),
153
- "summary2" :
154
- SummaryMetric (value = jnp .asarray ([2 , 3 , 10 ]), metadata = 5 ),
154
+ "summary" : SummaryMetric (
155
+ value = jnp .asarray ([2 , 3 , 10 ]), metadata = "some info"
156
+ ),
157
+ "summary2" : SummaryMetric (value = jnp .asarray ([2 , 3 , 10 ]), metadata = 5 ),
155
158
}
156
159
metrics = {
157
160
** scalar_metrics ,
@@ -166,29 +169,36 @@ def test_write(self):
166
169
utils .write_values (writer , step , metrics )
167
170
168
171
writer .write_scalars .assert_called_once_with (
169
- step , {k : m .compute () for k , m in scalar_metrics .items ()})
170
- writer .write_images .assert_called_once_with (step ,
171
- _to_summary (image_metrics ))
172
+ step , {k : m .compute () for k , m in scalar_metrics .items ()}
173
+ )
174
+ writer .write_images .assert_called_once_with (
175
+ step , _to_summary (image_metrics )
176
+ )
172
177
writer .write_histograms .assert_called_once_with (
173
178
step ,
174
179
_to_summary (histogram_metrics ),
175
- num_buckets = {k : v .num_buckets for k , v in histogram_metrics .items ()})
180
+ num_buckets = {k : v .num_buckets for k , v in histogram_metrics .items ()},
181
+ )
176
182
writer .write_audios .assert_called_with (
177
183
step ,
178
184
ONEOF (_to_list_of_dicts (_to_summary (audio_metrics ))),
179
- sample_rate = ONEOF ([sample_rate , sample_rate + 2 ]))
185
+ sample_rate = ONEOF ([sample_rate , sample_rate + 2 ]),
186
+ )
180
187
writer .write_texts .assert_called_once_with (step , _to_summary (text_metrics ))
181
- writer .write_hparams .assert_called_once_with (step ,
182
- _to_summary (hparam_metrics ))
188
+ writer .write_hparams .assert_called_once_with (
189
+ step , _to_summary (hparam_metrics )
190
+ )
183
191
writer .write_summaries .assert_called_with (
184
192
step ,
185
193
ONEOF (_to_list_of_dicts (_to_summary (summary_metrics ))),
186
- metadata = ONEOF (["some info" , 5 ]))
194
+ metadata = ONEOF (["some info" , 5 ]),
195
+ )
187
196
188
197
189
198
def test_create_default_writer_summary_writer_is_added (self ):
190
199
writer = utils .create_default_writer (
191
- logdir = self .get_temp_dir (), asynchronous = False )
200
+ logdir = self .get_temp_dir (), asynchronous = False
201
+ )
192
202
self .assertTrue (any (isinstance (w , SummaryWriter ) for w in writer ._writers ))
193
203
writer = utils .create_default_writer (logdir = None , asynchronous = False )
194
204
self .assertFalse (any (isinstance (w , SummaryWriter ) for w in writer ._writers ))
0 commit comments