Skip to content

Commit 1ea451b

Browse files
committed
make metrics serializable
It seems that metrics do not store their state, I'm not sure yet if this is intended behavior.
1 parent 619d14b commit 1ea451b

File tree

7 files changed

+70
-10
lines changed

7 files changed

+70
-10
lines changed

bayesflow/metrics/functional/maximum_mean_discrepancy.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,20 @@
22

33
from bayesflow.types import Tensor
44
from bayesflow.utils import issue_url
5+
from bayesflow.utils.serialization import serializable
6+
7+
from typing import Literal
58

69
from .kernels import gaussian, inverse_multiquadratic
710

811

12+
@serializable("bayesflow.metrics")
913
def maximum_mean_discrepancy(
10-
x: Tensor, y: Tensor, kernel: str = "inverse_multiquadratic", unbiased: bool = False, **kwargs
14+
x: Tensor,
15+
y: Tensor,
16+
kernel: Literal["inverse_multiquadratic", "gaussian"] = "inverse_multiquadratic",
17+
unbiased: bool = False,
18+
**kwargs,
1119
) -> Tensor:
1220
"""Computes a mixture of Gaussian radial basis functions (RBFs) between the samples of x and y.
1321

bayesflow/metrics/functional/root_mean_squared_error.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
from keras import ops
33

44
from bayesflow.types import Tensor
5+
from bayesflow.utils.serialization import serializable
56

67

8+
@serializable("bayesflow.metrics")
79
def root_mean_squared_error(x1: Tensor, x2: Tensor, normalize: bool = False, **kwargs) -> Tensor:
810
"""Computes the (normalized) root mean squared error between samples x1 and x2.
911

bayesflow/metrics/maximum_mean_discrepancy.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
from functools import partial
2-
31
import keras
42

5-
from bayesflow.utils.serialization import serializable
3+
from bayesflow.utils.serialization import deserialize, serializable, serialize
64
from .functional import maximum_mean_discrepancy
75

86

@@ -17,10 +15,22 @@ def __init__(
1715
):
1816
super().__init__(name=name, **kwargs)
1917
self.mmd = self.add_variable(shape=(), initializer="zeros", name="mmd")
20-
self.mmd_fn = partial(maximum_mean_discrepancy, kernel=kernel, unbiased=unbiased)
18+
self.kernel = kernel
19+
self.unbiased = unbiased
2120

2221
def update_state(self, x, y):
23-
self.mmd.assign(keras.ops.cast(self.mmd_fn(x, y), self.dtype))
22+
self.mmd.assign(
23+
keras.ops.cast(maximum_mean_discrepancy(x, y, kernel=self.kernel, unbiased=self.unbiased), self.dtype)
24+
)
2425

2526
def result(self):
2627
return self.mmd.value
28+
29+
def get_config(self):
30+
base_config = super().get_config()
31+
config = {"kernel": self.kernel, "unbiased": self.unbiased}
32+
return base_config | serialize(config)
33+
34+
@classmethod
35+
def from_config(cls, config, custom_objects=None):
36+
return cls(**deserialize(config, custom_objects=custom_objects))
Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
1-
from functools import partial
21
import keras
32

4-
from bayesflow.utils.serialization import serializable
3+
from bayesflow.utils.serialization import deserialize, serializable
54
from .functional import root_mean_squared_error
65

76

87
@serializable("bayesflow.metrics")
98
class RootMeanSquaredError(keras.metrics.MeanMetricWrapper):
109
def __init__(self, name="root_mean_squared_error", dtype=None, **kwargs):
11-
fn = partial(root_mean_squared_error, **kwargs)
12-
super().__init__(fn, name=name, dtype=dtype)
10+
super().__init__(root_mean_squared_error, name=name, dtype=dtype, **kwargs)
11+
12+
def get_config(self):
13+
base_config = super().get_config()
14+
# fn is fixed and passed directly in the constructor
15+
base_config.pop("fn")
16+
return base_config
17+
18+
@classmethod
19+
def from_config(cls, config, custom_objects=None):
20+
return cls(**deserialize(config, custom_objects=custom_objects))

tests/test_metrics/__init__.py

Whitespace-only changes.

tests/test_metrics/conftest.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import pytest
2+
3+
4+
@pytest.fixture()
5+
def root_mean_squared_error():
6+
from bayesflow.metrics import RootMeanSquaredError
7+
8+
return RootMeanSquaredError(normalize=True, name="rmse", dtype="float64")
9+
10+
11+
@pytest.fixture()
12+
def maximum_mean_discrepancy():
13+
from bayesflow.metrics import MaximumMeanDiscrepancy
14+
15+
return MaximumMeanDiscrepancy(name="mmd", kernel="gaussian", unbiased=True, dtype="float64")
16+
17+
18+
@pytest.fixture(params=["root_mean_squared_error", "maximum_mean_discrepancy"])
19+
def metric(request):
20+
return request.getfixturevalue(request.param)

tests/test_metrics/test_metrics.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from bayesflow.utils.serialization import serialize, deserialize
2+
import keras
3+
4+
5+
def test_serialize_deserialize(metric, random_samples):
6+
metric.update_state(keras.random.normal((2, 3)), keras.random.normal((2, 3)))
7+
8+
serialized = serialize(metric)
9+
deserialized = deserialize(serialized)
10+
reserialized = serialize(deserialized)
11+
12+
assert reserialized == serialized

0 commit comments

Comments
 (0)