Skip to content

Commit 2b28d3b

Browse files
Marvin182copybara-github
authored andcommitted
Internal change.
PiperOrigin-RevId: 543579662
1 parent a766965 commit 2b28d3b

File tree

3 files changed

+54
-35
lines changed

3 files changed

+54
-35
lines changed

clu/metric_writers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@
4747
from clu.metric_writers.async_writer import AsyncMultiWriter
4848
from clu.metric_writers.async_writer import AsyncWriter
4949
from clu.metric_writers.async_writer import ensure_flushes
50-
from clu.metric_writers.summary_writer import SummaryWriter
5150
from clu.metric_writers.interface import MetricWriter
5251
from clu.metric_writers.logging_writer import LoggingWriter
5352
from clu.metric_writers.multi_writer import MultiWriter
53+
from clu.metric_writers.summary_writer import SummaryWriter
5454
from clu.metric_writers.utils import create_default_writer
5555
from clu.metric_writers.utils import write_values
5656

clu/metric_writers/utils.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@
3131
from absl import logging
3232
from clu import values
3333
from clu.metric_writers.async_writer import AsyncMultiWriter
34-
from clu.metric_writers.summary_writer import SummaryWriter
3534
from clu.metric_writers.interface import MetricWriter
3635
from clu.metric_writers.logging_writer import LoggingWriter
3736
from clu.metric_writers.multi_writer import MultiWriter
37+
from clu.metric_writers.summary_writer import SummaryWriter
3838
from etils import epath
3939
import jax.numpy as jnp
4040
import numpy as np
@@ -44,17 +44,22 @@
4444

4545

4646
def _is_scalar(value: Any) -> bool:
47-
if isinstance(value, values.Scalar) or isinstance(value,
48-
(int, float, np.number)):
47+
if isinstance(value, values.Scalar) or isinstance(
48+
value, (int, float, np.number)
49+
):
4950
return True
5051
if isinstance(value, (np.ndarray, jnp.ndarray)):
5152
return value.ndim == 0 or value.size <= 1
5253
return False
5354

5455

55-
def write_values(writer: MetricWriter, step: int,
56-
metrics: Mapping[str, Union[values.Value, values.ArrayType,
57-
values.ScalarType]]):
56+
def write_values(
57+
writer: MetricWriter,
58+
step: int,
59+
metrics: Mapping[
60+
str, Union[values.Value, values.ArrayType, values.ScalarType]
61+
],
62+
):
5863
"""Writes all provided metrics.
5964
6065
Allows providing a mapping of name to Value object, where each Value
@@ -70,8 +75,9 @@ def write_values(writer: MetricWriter, step: int,
7075
histogram_num_buckets = collections.defaultdict(int)
7176
for k, v in metrics.items():
7277
if isinstance(v, values.Summary):
73-
writes[(writer.write_summaries, frozenset({"metadata": v.metadata
74-
}.items()))][k] = v.value
78+
writes[
79+
(writer.write_summaries, frozenset({"metadata": v.metadata}.items()))
80+
][k] = v.value
7581
elif _is_scalar(v):
7682
if isinstance(v, values.Scalar):
7783
writes[(writer.write_scalars, frozenset())][k] = v.value
@@ -87,8 +93,10 @@ def write_values(writer: MetricWriter, step: int,
8793
writes[(writer.write_histograms, frozenset())][k] = v.value
8894
histogram_num_buckets[k] = v.num_buckets
8995
elif isinstance(v, values.Audio):
90-
writes[(writer.write_audios,
91-
frozenset({"sample_rate": v.sample_rate}.items()))][k] = v.value
96+
writes[(
97+
writer.write_audios,
98+
frozenset({"sample_rate": v.sample_rate}.items()),
99+
)][k] = v.value
92100
else:
93101
raise ValueError("Metric: ", k, " has unsupported value: ", v)
94102

@@ -107,7 +115,8 @@ def create_default_writer(
107115
*,
108116
just_logging: bool = False,
109117
asynchronous: bool = True,
110-
collection: Optional[str] = None) -> MultiWriter:
118+
collection: Optional[str] = None,
119+
) -> MultiWriter:
111120
"""Create the default writer for the platform.
112121
113122
On most platforms this will create a MultiWriter that writes to multiple back

clu/metric_writers/utils_test.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@
2525
from clu.metric_writers import utils
2626
from clu.metric_writers.async_writer import AsyncMultiWriter
2727
from clu.metric_writers.async_writer import AsyncWriter
28-
from clu.metric_writers.summary_writer import SummaryWriter
2928
from clu.metric_writers.interface import MetricWriter
3029
from clu.metric_writers.logging_writer import LoggingWriter
3130
from clu.metric_writers.multi_writer import MultiWriter
31+
from clu.metric_writers.summary_writer import SummaryWriter
3232
import clu.metrics
3333
import flax.struct
3434
import jax.numpy as jnp
@@ -129,17 +129,20 @@ def test_write(self):
129129
"image": ImageMetric(jnp.asarray([[4, 5], [1, 2]])),
130130
}
131131
histogram_metrics = {
132-
"hist":
133-
HistogramMetric(value=jnp.asarray([7, 8]), num_buckets=num_buckets),
134-
"hist2":
135-
HistogramMetric(
136-
value=jnp.asarray([9, 10]), num_buckets=num_buckets),
132+
"hist": HistogramMetric(
133+
value=jnp.asarray([7, 8]), num_buckets=num_buckets
134+
),
135+
"hist2": HistogramMetric(
136+
value=jnp.asarray([9, 10]), num_buckets=num_buckets
137+
),
137138
}
138139
audio_metrics = {
139-
"audio":
140-
AudioMetric(value=jnp.asarray([1, 5]), sample_rate=sample_rate),
141-
"audio2":
142-
AudioMetric(value=jnp.asarray([1, 5]), sample_rate=sample_rate + 2),
140+
"audio": AudioMetric(
141+
value=jnp.asarray([1, 5]), sample_rate=sample_rate
142+
),
143+
"audio2": AudioMetric(
144+
value=jnp.asarray([1, 5]), sample_rate=sample_rate + 2
145+
),
143146
}
144147
text_metrics = {
145148
"text": TextMetric(value="hello"),
@@ -148,10 +151,10 @@ def test_write(self):
148151
"lr": HyperParamMetric(value=0.01),
149152
}
150153
summary_metrics = {
151-
"summary":
152-
SummaryMetric(value=jnp.asarray([2, 3, 10]), metadata="some info"),
153-
"summary2":
154-
SummaryMetric(value=jnp.asarray([2, 3, 10]), metadata=5),
154+
"summary": SummaryMetric(
155+
value=jnp.asarray([2, 3, 10]), metadata="some info"
156+
),
157+
"summary2": SummaryMetric(value=jnp.asarray([2, 3, 10]), metadata=5),
155158
}
156159
metrics = {
157160
**scalar_metrics,
@@ -166,29 +169,36 @@ def test_write(self):
166169
utils.write_values(writer, step, metrics)
167170

168171
writer.write_scalars.assert_called_once_with(
169-
step, {k: m.compute() for k, m in scalar_metrics.items()})
170-
writer.write_images.assert_called_once_with(step,
171-
_to_summary(image_metrics))
172+
step, {k: m.compute() for k, m in scalar_metrics.items()}
173+
)
174+
writer.write_images.assert_called_once_with(
175+
step, _to_summary(image_metrics)
176+
)
172177
writer.write_histograms.assert_called_once_with(
173178
step,
174179
_to_summary(histogram_metrics),
175-
num_buckets={k: v.num_buckets for k, v in histogram_metrics.items()})
180+
num_buckets={k: v.num_buckets for k, v in histogram_metrics.items()},
181+
)
176182
writer.write_audios.assert_called_with(
177183
step,
178184
ONEOF(_to_list_of_dicts(_to_summary(audio_metrics))),
179-
sample_rate=ONEOF([sample_rate, sample_rate + 2]))
185+
sample_rate=ONEOF([sample_rate, sample_rate + 2]),
186+
)
180187
writer.write_texts.assert_called_once_with(step, _to_summary(text_metrics))
181-
writer.write_hparams.assert_called_once_with(step,
182-
_to_summary(hparam_metrics))
188+
writer.write_hparams.assert_called_once_with(
189+
step, _to_summary(hparam_metrics)
190+
)
183191
writer.write_summaries.assert_called_with(
184192
step,
185193
ONEOF(_to_list_of_dicts(_to_summary(summary_metrics))),
186-
metadata=ONEOF(["some info", 5]))
194+
metadata=ONEOF(["some info", 5]),
195+
)
187196

188197

189198
def test_create_default_writer_summary_writer_is_added(self):
190199
writer = utils.create_default_writer(
191-
logdir=self.get_temp_dir(), asynchronous=False)
200+
logdir=self.get_temp_dir(), asynchronous=False
201+
)
192202
self.assertTrue(any(isinstance(w, SummaryWriter) for w in writer._writers))
193203
writer = utils.create_default_writer(logdir=None, asynchronous=False)
194204
self.assertFalse(any(isinstance(w, SummaryWriter) for w in writer._writers))

0 commit comments

Comments
 (0)