From 8d296e99c8add3f7adc235a476bc80d16340f50d Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Fri, 30 May 2025 09:00:21 +0000 Subject: [PATCH 01/13] fix: move additional metrics from approximator to networks Supplying the additional metrics for inference and summary networks via the approximators compile method caused problems during deseralization (#497). This can be resolved nicely by moving the metrics directly to the networks' constructors, analogous to how Keras normally handles custom metrics in layers. As summary networks and inference networks inherit from the respective base classes, this change only requires minor adaptations. --- .../approximators/continuous_approximator.py | 34 +++++++--------- .../diffusion_model/diffusion_model.py | 4 +- .../consistency_models/consistency_model.py | 5 ++- .../networks/coupling_flow/coupling_flow.py | 2 - .../networks/flow_matching/flow_matching.py | 2 - bayesflow/networks/inference_network.py | 16 +++++++- bayesflow/networks/summary_network.py | 15 +++++-- tests/test_networks/conftest.py | 39 ++++++++++++++----- tests/test_two_moons/conftest.py | 3 +- tests/test_two_moons/test_two_moons.py | 6 +-- 10 files changed, 79 insertions(+), 47 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index 0d87cda2b..34d226489 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -1,13 +1,14 @@ from collections.abc import Mapping, Sequence, Callable import numpy as np +import warnings import keras from bayesflow.adapters import Adapter from bayesflow.networks import InferenceNetwork, SummaryNetwork from bayesflow.types import Tensor -from bayesflow.utils import filter_kwargs, logging, split_arrays, squeeze_inner_estimates_dict +from bayesflow.utils import filter_kwargs, split_arrays, squeeze_inner_estimates_dict from bayesflow.utils.serialization import serialize, deserialize, serializable from .approximator import Approximator @@ -97,18 +98,21 @@ def build_adapter( def compile( self, *args, - inference_metrics: Sequence[keras.Metric] = None, - summary_metrics: Sequence[keras.Metric] = None, **kwargs, ): - if inference_metrics: - self.inference_network._metrics = inference_metrics + if "inference_metrics" in kwargs: + warnings.warn( + "Supplying inference metrics to the approximator is no longer supported. " + "Please pass the metrics directly to the network using the metrics parameter.", + DeprecationWarning, + ) - if summary_metrics: - if self.summary_network is None: - logging.warning("Ignoring summary metrics because there is no summary network.") - else: - self.summary_network._metrics = summary_metrics + if "summary_metrics" in kwargs: + warnings.warn( + "Supplying summary metrics to the approximator is no longer supported. " + "Please pass the metrics directly to the network using the metrics parameter.", + DeprecationWarning, + ) return super().compile(*args, **kwargs) @@ -227,16 +231,6 @@ def get_config(self): return base_config | serialize(config) - def get_compile_config(self): - base_config = super().get_compile_config() or {} - - config = { - "inference_metrics": self.inference_network._metrics, - "summary_metrics": self.summary_network._metrics if self.summary_network is not None else None, - } - - return base_config | serialize(config) - def estimate( self, conditions: Mapping[str, np.ndarray], diff --git a/bayesflow/experimental/diffusion_model/diffusion_model.py b/bayesflow/experimental/diffusion_model/diffusion_model.py index bcff50fb0..cc5de82ba 100644 --- a/bayesflow/experimental/diffusion_model/diffusion_model.py +++ b/bayesflow/experimental/diffusion_model/diffusion_model.py @@ -10,7 +10,6 @@ expand_right_as, find_network, jacobian_trace, - layer_kwargs, weighted_mean, integrate, integrate_stochastic, @@ -141,7 +140,8 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None: def get_config(self): base_config = super().get_config() - base_config = layer_kwargs(base_config) + # base distribution is fixed and passed in constructor + base_config.pop("base_distribution") config = { "subnet": self.subnet, diff --git a/bayesflow/networks/consistency_models/consistency_model.py b/bayesflow/networks/consistency_models/consistency_model.py index 74d6acd6a..6baa52e42 100644 --- a/bayesflow/networks/consistency_models/consistency_model.py +++ b/bayesflow/networks/consistency_models/consistency_model.py @@ -4,7 +4,7 @@ import numpy as np from bayesflow.types import Tensor -from bayesflow.utils import find_network, layer_kwargs, weighted_mean +from bayesflow.utils import find_network, weighted_mean from bayesflow.utils.serialization import deserialize, serializable, serialize from ..inference_network import InferenceNetwork @@ -109,7 +109,8 @@ def from_config(cls, config, custom_objects=None): def get_config(self): base_config = super().get_config() - base_config = layer_kwargs(base_config) + # base distribution is fixed and passed in constructor + base_config.pop("base_distribution") config = { "total_steps": self.total_steps, diff --git a/bayesflow/networks/coupling_flow/coupling_flow.py b/bayesflow/networks/coupling_flow/coupling_flow.py index adfb4953b..de162e2b8 100644 --- a/bayesflow/networks/coupling_flow/coupling_flow.py +++ b/bayesflow/networks/coupling_flow/coupling_flow.py @@ -3,7 +3,6 @@ from bayesflow.types import Tensor from bayesflow.utils import ( find_permutation, - layer_kwargs, weighted_mean, ) from bayesflow.utils.serialization import deserialize, serializable, serialize @@ -131,7 +130,6 @@ def from_config(cls, config, custom_objects=None): def get_config(self): base_config = super().get_config() - base_config = layer_kwargs(base_config) config = { "subnet": self.subnet, diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py index 781f9374d..9d88b009b 100644 --- a/bayesflow/networks/flow_matching/flow_matching.py +++ b/bayesflow/networks/flow_matching/flow_matching.py @@ -9,7 +9,6 @@ find_network, integrate, jacobian_trace, - layer_kwargs, optimal_transport, weighted_mean, ) @@ -138,7 +137,6 @@ def from_config(cls, config, custom_objects=None): def get_config(self): base_config = super().get_config() - base_config = layer_kwargs(base_config) config = { "subnet": self.subnet, diff --git a/bayesflow/networks/inference_network.py b/bayesflow/networks/inference_network.py index b092ce2cb..b3875d225 100644 --- a/bayesflow/networks/inference_network.py +++ b/bayesflow/networks/inference_network.py @@ -1,12 +1,16 @@ import keras +from collections.abc import Sequence from bayesflow.types import Shape, Tensor from bayesflow.utils import layer_kwargs, find_distribution from bayesflow.utils.decorators import allow_batch_size +from bayesflow.utils.serialization import deserialize, serializable, serialize +@serializable("bayesflow.networks") class InferenceNetwork(keras.Layer): - def __init__(self, base_distribution: str = "normal", **kwargs): + def __init__(self, base_distribution: str = "normal", *, metrics: Sequence[keras.Metric] = None, **kwargs): + self.custom_metrics = metrics super().__init__(**layer_kwargs(kwargs)) self.base_distribution = find_distribution(base_distribution) @@ -72,3 +76,13 @@ def compute_metrics( metrics[metric.name] = metric(samples, x) return metrics + + def get_config(self): + base_config = super().get_config() + base_config = layer_kwargs(base_config) + config = {"metrics": self.custom_metrics, "base_distribution": self.base_distribution} + return base_config | serialize(config) + + @classmethod + def from_config(cls, config, custom_objects=None): + return cls(**deserialize(config, custom_objects=custom_objects)) diff --git a/bayesflow/networks/summary_network.py b/bayesflow/networks/summary_network.py index e821be3f3..54cdd8e84 100644 --- a/bayesflow/networks/summary_network.py +++ b/bayesflow/networks/summary_network.py @@ -1,14 +1,17 @@ import keras +from collections.abc import Sequence from bayesflow.metrics.functional import maximum_mean_discrepancy from bayesflow.types import Tensor from bayesflow.utils import layer_kwargs, find_distribution from bayesflow.utils.decorators import sanitize_input_shape -from bayesflow.utils.serialization import deserialize +from bayesflow.utils.serialization import deserialize, serializable, serialize +@serializable("bayesflow.networks") class SummaryNetwork(keras.Layer): - def __init__(self, base_distribution: str = None, **kwargs): + def __init__(self, base_distribution: str = None, *, metrics: Sequence[keras.Metric] = None, **kwargs): + self.custom_metrics = metrics super().__init__(**layer_kwargs(kwargs)) self.base_distribution = find_distribution(base_distribution) @@ -17,7 +20,7 @@ def build(self, input_shape): x = keras.ops.zeros(input_shape) z = self.call(x) - if self.base_distribution is not None: + if self.base_distribution is not None and not self.base_distribution.built: self.base_distribution.build(keras.ops.shape(z)) @sanitize_input_shape @@ -51,6 +54,12 @@ def compute_metrics(self, x: Tensor, stage: str = "training", **kwargs) -> dict[ return metrics + def get_config(self): + base_config = super().get_config() + base_config = layer_kwargs(base_config) + config = {"base_distribution": self.base_distribution, "metrics": self.custom_metrics} + return base_config | serialize(config) + @classmethod def from_config(cls, config, custom_objects=None): return cls(**deserialize(config, custom_objects=custom_objects)) diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py index 678029d92..e25c8520a 100644 --- a/tests/test_networks/conftest.py +++ b/tests/test_networks/conftest.py @@ -1,6 +1,7 @@ import pytest from bayesflow.networks import MLP +from bayesflow.metrics import RootMeanSquaredError @pytest.fixture() @@ -12,6 +13,7 @@ def diffusion_model_edm_F(): integrate_kwargs={"method": "rk45", "steps": 250}, noise_schedule="edm", prediction_type="F", + metrics=[RootMeanSquaredError()], ) @@ -82,6 +84,7 @@ def flow_matching(): return FlowMatching( subnet=MLP([8, 8]), integrate_kwargs={"method": "rk45", "steps": 100}, + metrics=[RootMeanSquaredError()], ) @@ -89,7 +92,11 @@ def flow_matching(): def consistency_model(): from bayesflow.networks import ConsistencyModel - return ConsistencyModel(total_steps=100, subnet=MLP([8, 8])) + return ConsistencyModel( + total_steps=100, + subnet=MLP([8, 8]), + metrics=[RootMeanSquaredError()], + ) @pytest.fixture() @@ -97,7 +104,12 @@ def affine_coupling_flow(): from bayesflow.networks import CouplingFlow return CouplingFlow( - depth=2, subnet="mlp", subnet_kwargs=dict(widths=[8, 8]), transform="affine", transform_kwargs=dict(clamp=1.8) + depth=2, + subnet="mlp", + subnet_kwargs=dict(widths=[8, 8]), + transform="affine", + transform_kwargs=dict(clamp=1.8), + metrics=[RootMeanSquaredError()], ) @@ -106,7 +118,12 @@ def spline_coupling_flow(): from bayesflow.networks import CouplingFlow return CouplingFlow( - depth=2, subnet="mlp", subnet_kwargs=dict(widths=[8, 8]), transform="spline", transform_kwargs=dict(bins=8) + depth=2, + subnet="mlp", + subnet_kwargs=dict(widths=[8, 8]), + transform="spline", + transform_kwargs=dict(bins=8), + metrics=[RootMeanSquaredError()], ) @@ -114,7 +131,11 @@ def spline_coupling_flow(): def free_form_flow(): from bayesflow.experimental import FreeFormFlow - return FreeFormFlow(encoder_subnet=MLP([16, 16]), decoder_subnet=MLP([16, 16])) + return FreeFormFlow( + encoder_subnet=MLP([16, 16]), + decoder_subnet=MLP([16, 16]), + metrics=[RootMeanSquaredError()], + ) @pytest.fixture() @@ -236,35 +257,35 @@ def generative_inference_network(request): def time_series_network(summary_dim): from bayesflow.networks import TimeSeriesNetwork - return TimeSeriesNetwork(summary_dim=summary_dim) + return TimeSeriesNetwork(summary_dim=summary_dim, metrics=[RootMeanSquaredError()]) @pytest.fixture(scope="function") def time_series_transformer(summary_dim): from bayesflow.networks import TimeSeriesTransformer - return TimeSeriesTransformer(summary_dim=summary_dim) + return TimeSeriesTransformer(summary_dim=summary_dim, metrics=[RootMeanSquaredError()]) @pytest.fixture(scope="function") def fusion_transformer(summary_dim): from bayesflow.networks import FusionTransformer - return FusionTransformer(summary_dim=summary_dim) + return FusionTransformer(summary_dim=summary_dim, metrics=[RootMeanSquaredError()]) @pytest.fixture(scope="function") def set_transformer(summary_dim): from bayesflow.networks import SetTransformer - return SetTransformer(summary_dim=summary_dim) + return SetTransformer(summary_dim=summary_dim, metrics=[RootMeanSquaredError()]) @pytest.fixture(scope="function") def deep_set(summary_dim): from bayesflow.networks import DeepSet - return DeepSet(summary_dim=summary_dim) + return DeepSet(summary_dim=summary_dim, metrics=[RootMeanSquaredError()]) @pytest.fixture( diff --git a/tests/test_two_moons/conftest.py b/tests/test_two_moons/conftest.py index 5cd6f59db..282354f23 100644 --- a/tests/test_two_moons/conftest.py +++ b/tests/test_two_moons/conftest.py @@ -4,8 +4,9 @@ @pytest.fixture() def inference_network(): from bayesflow.networks import CouplingFlow + from bayesflow.metrics import MaximumMeanDiscrepancy - return CouplingFlow(depth=2, subnet="mlp", subnet_kwargs=dict(widths=(32, 32))) + return CouplingFlow(depth=2, subnet="mlp", subnet_kwargs=dict(widths=(32, 32)), metrics=[MaximumMeanDiscrepancy()]) @pytest.fixture() diff --git a/tests/test_two_moons/test_two_moons.py b/tests/test_two_moons/test_two_moons.py index f71dc7fe0..20245ca66 100644 --- a/tests/test_two_moons/test_two_moons.py +++ b/tests/test_two_moons/test_two_moons.py @@ -13,13 +13,9 @@ def test_compile(approximator, random_samples, jit_compile): def test_fit(approximator, train_dataset, validation_dataset, batch_size): - from bayesflow.metrics import MaximumMeanDiscrepancy from bayesflow.networks import PointInferenceNetwork - inference_metrics = [] - if not isinstance(approximator.inference_network, PointInferenceNetwork): - inference_metrics += [MaximumMeanDiscrepancy()] - approximator.compile(inference_metrics=inference_metrics) + approximator.compile() mock_data = train_dataset[0] mock_data = keras.tree.map_structure(keras.ops.convert_to_tensor, mock_data) From 51dff0df7e1d5247d38c7f61accc8f0ef4f95941 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Fri, 30 May 2025 12:53:21 +0000 Subject: [PATCH 02/13] Adapt Keras' auto-config mechanism for inference/summary networks This change makes it more capable for our purposes by allowing any serializable value, not only the base types in the auto-config. We have to check if this brings any footguns/downsides, or whether this is fine for our setting. It also replaces Keras' functions with our custom serialization functions. --- .../diffusion_model/diffusion_model.py | 2 - bayesflow/networks/base_layer/__init__.py | 1 + bayesflow/networks/base_layer/base_layer.py | 210 ++++++++++++++++++ .../consistency_models/consistency_model.py | 2 - .../networks/fusion_network/fusion_network.py | 7 +- bayesflow/networks/inference_network.py | 15 +- bayesflow/networks/summary_network.py | 16 +- tests/test_adapters/test_adapters.py | 8 +- .../test_fusion_network.py | 4 +- .../test_networks/test_inference_networks.py | 4 +- tests/test_networks/test_summary_networks.py | 4 +- tests/utils/__init__.py | 1 + tests/utils/normalize.py | 25 +++ 13 files changed, 255 insertions(+), 44 deletions(-) create mode 100644 bayesflow/networks/base_layer/__init__.py create mode 100644 bayesflow/networks/base_layer/base_layer.py create mode 100644 tests/utils/normalize.py diff --git a/bayesflow/experimental/diffusion_model/diffusion_model.py b/bayesflow/experimental/diffusion_model/diffusion_model.py index cc5de82ba..d1658b3ae 100644 --- a/bayesflow/experimental/diffusion_model/diffusion_model.py +++ b/bayesflow/experimental/diffusion_model/diffusion_model.py @@ -140,8 +140,6 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None: def get_config(self): base_config = super().get_config() - # base distribution is fixed and passed in constructor - base_config.pop("base_distribution") config = { "subnet": self.subnet, diff --git a/bayesflow/networks/base_layer/__init__.py b/bayesflow/networks/base_layer/__init__.py new file mode 100644 index 000000000..9ce365658 --- /dev/null +++ b/bayesflow/networks/base_layer/__init__.py @@ -0,0 +1 @@ +from .base_layer import BaseLayer diff --git a/bayesflow/networks/base_layer/base_layer.py b/bayesflow/networks/base_layer/base_layer.py new file mode 100644 index 000000000..e07a33662 --- /dev/null +++ b/bayesflow/networks/base_layer/base_layer.py @@ -0,0 +1,210 @@ +import keras +import inspect +import textwrap +from functools import wraps + +from keras.src import dtype_policies +from keras.src import tree +from keras.src.backend.common.name_scope import current_path +from keras.src.utils import python_utils +from keras import Operation +from keras.saving import get_registered_name, get_registered_object + +from bayesflow.utils.serialization import serialize, deserialize + + +class BayesFlowSerializableDict: + def __init__(self, **config): + self.config = config + + def serialize(self): + return serialize(self.config) + + +class BaseLayer(keras.Layer): + def __new__(cls, *args, **kwargs): + """We override __new__ to saving serializable constructor arguments. + + These arguments are used to auto-generate an object serialization + config, which enables user-created subclasses to be serializable + out of the box in most cases without forcing the user + to manually implement `get_config()`. + """ + + # Adapted from keras.Operation.__new__, to support all serializable objects, instead + # of only basic types. + + instance = super(Operation, cls).__new__(cls) + + # Generate a config to be returned by default by `get_config()`. + arg_names = inspect.getfullargspec(cls.__init__).args + kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args))) + + # Explicitly serialize `dtype` to support auto_config + dtype = kwargs.get("dtype", None) + if dtype is not None and isinstance(dtype, dtype_policies.DTypePolicy): + # For backward compatibility, we use a str (`name`) for + # `DTypePolicy` + if dtype.quantization_mode is None: + kwargs["dtype"] = dtype.name + # Otherwise, use `dtype_policies.serialize` + else: + kwargs["dtype"] = dtype_policies.serialize(dtype) + + # Adaptation: we allow all registered serializable objects + supported_types = (str, int, float, bool, type(None)) + try: + flat_arg_values = tree.flatten(kwargs) + auto_config = True + for value in flat_arg_values: + is_serializable = get_registered_object(get_registered_name(type(value))) is not None + is_class = inspect.isclass(value) + if not (isinstance(value, supported_types) or is_serializable or is_class): + auto_config = False + break + except TypeError: + auto_config = False + try: + instance._lock = False + if auto_config: + instance._auto_config = BayesFlowSerializableDict(**kwargs) + else: + instance._auto_config = None + instance._lock = True + except RecursionError: + # Setting an instance attribute in __new__ has the potential + # to trigger an infinite recursion if a subclass overrides + # setattr in an unsafe way. + pass + + ### from keras.Layer.__new__ + + # Wrap the user-provided `build` method in the `build_wrapper` + # to add name scope support and serialization support. + original_build_method = instance.build + + @wraps(original_build_method) + def build_wrapper(*args, **kwargs): + with instance._open_name_scope(): + instance._path = current_path() + original_build_method(*args, **kwargs) + # Record build config. + signature = inspect.signature(original_build_method) + instance._build_shapes_dict = signature.bind(*args, **kwargs).arguments + # Set built, post build actions, and lock state. + instance.built = True + instance._post_build() + instance._lock_state() + + instance.build = build_wrapper + + # Wrap the user-provided `quantize` method in the `quantize_wrapper` + # to add tracker support. + original_quantize_method = instance.quantize + + @wraps(original_quantize_method) + def quantize_wrapper(mode, **kwargs): + instance._check_quantize_args(mode, instance.compute_dtype) + instance._tracker.unlock() + try: + original_quantize_method(mode, **kwargs) + except Exception: + raise + finally: + instance._tracker.lock() + + instance.quantize = quantize_wrapper + + return instance + + @python_utils.default + def get_config(self): + """Returns the config of the object. + + An object config is a Python dictionary (serializable) + containing the information needed to re-instantiate it. + """ + + # Adapted from Operations.get_config to support specifying a default configuration in + # subclasses, without giving up on the automatic config functionality. + config = super().get_config() + if not python_utils.is_default(self.get_config): + # In this case the subclass implements get_config() + return config + + # In this case the subclass doesn't implement get_config(): + # Let's see if we can autogenerate it. + if getattr(self, "_auto_config", None) is not None: + xtra_args = set(config.keys()) + config.update(self._auto_config.config) + # Remove args non explicitly supported + argspec = inspect.getfullargspec(self.__init__) + if argspec.varkw != "kwargs": + for key in xtra_args - xtra_args.intersection(argspec.args[1:]): + config.pop(key, None) + return config + else: + raise NotImplementedError( + textwrap.dedent( + f""" + Object {self.__class__.__name__} was created by passing + non-serializable argument values in `__init__()`, + and therefore the object must override `get_config()` in + order to be serializable. Please implement `get_config()`. + + Example: + + class CustomLayer(keras.layers.Layer): + def __init__(self, arg1, arg2, **kwargs): + super().__init__(**kwargs) + self.arg1 = arg1 + self.arg2 = arg2 + + def get_config(self): + config = super().get_config() + config.update({{ + "arg1": self.arg1, + "arg2": self.arg2, + }}) + return config""" + ) + ) + + @classmethod + def from_config(cls, config): + """Creates an operation from its config. + + This method is the reverse of `get_config`, capable of instantiating the + same operation from the config dictionary. + + Note: If you override this method, you might receive a serialized dtype + config, which is a `dict`. You can deserialize it as follows: + + ```python + if "dtype" in config and isinstance(config["dtype"], dict): + policy = dtype_policies.deserialize(config["dtype"]) + ``` + + Args: + config: A Python dictionary, typically the output of `get_config`. + + Returns: + An operation instance. + """ + # Adapted from keras.Operation.from_config to use our deserialize function + # Explicitly deserialize dtype config if needed. This enables users to + # directly interact with the instance of `DTypePolicy`. + if "dtype" in config and isinstance(config["dtype"], dict): + config = config.copy() + policy = dtype_policies.deserialize(config["dtype"]) + if not isinstance(policy, dtype_policies.DTypePolicyMap) and policy.quantization_mode is None: + # For backward compatibility, we use a str (`name`) for + # `DTypePolicy` + policy = policy.name + config["dtype"] = policy + try: + return cls(**deserialize(config)) + except Exception as e: + raise TypeError( + f"Error when deserializing class '{cls.__name__}' using config={config}.\n\nException encountered: {e}" + ) diff --git a/bayesflow/networks/consistency_models/consistency_model.py b/bayesflow/networks/consistency_models/consistency_model.py index 6baa52e42..bfdca1b0b 100644 --- a/bayesflow/networks/consistency_models/consistency_model.py +++ b/bayesflow/networks/consistency_models/consistency_model.py @@ -109,8 +109,6 @@ def from_config(cls, config, custom_objects=None): def get_config(self): base_config = super().get_config() - # base distribution is fixed and passed in constructor - base_config.pop("base_distribution") config = { "total_steps": self.total_steps, diff --git a/bayesflow/networks/fusion_network/fusion_network.py b/bayesflow/networks/fusion_network/fusion_network.py index 8d0132975..f5d81afc1 100644 --- a/bayesflow/networks/fusion_network/fusion_network.py +++ b/bayesflow/networks/fusion_network/fusion_network.py @@ -1,6 +1,6 @@ from collections.abc import Mapping from ..summary_network import SummaryNetwork -from bayesflow.utils.serialization import deserialize, serializable, serialize +from bayesflow.utils.serialization import serializable, serialize from bayesflow.types import Tensor, Shape import keras from keras import ops @@ -116,8 +116,3 @@ def get_config(self) -> dict: "head": self.head, } return base_config | serialize(config) - - @classmethod - def from_config(cls, config: dict, custom_objects=None): - config = deserialize(config, custom_objects=custom_objects) - return cls(**config) diff --git a/bayesflow/networks/inference_network.py b/bayesflow/networks/inference_network.py index b3875d225..97b4bcb7a 100644 --- a/bayesflow/networks/inference_network.py +++ b/bayesflow/networks/inference_network.py @@ -4,11 +4,12 @@ from bayesflow.types import Shape, Tensor from bayesflow.utils import layer_kwargs, find_distribution from bayesflow.utils.decorators import allow_batch_size -from bayesflow.utils.serialization import deserialize, serializable, serialize +from bayesflow.utils.serialization import serializable +from .base_layer import BaseLayer @serializable("bayesflow.networks") -class InferenceNetwork(keras.Layer): +class InferenceNetwork(BaseLayer): def __init__(self, base_distribution: str = "normal", *, metrics: Sequence[keras.Metric] = None, **kwargs): self.custom_metrics = metrics super().__init__(**layer_kwargs(kwargs)) @@ -76,13 +77,3 @@ def compute_metrics( metrics[metric.name] = metric(samples, x) return metrics - - def get_config(self): - base_config = super().get_config() - base_config = layer_kwargs(base_config) - config = {"metrics": self.custom_metrics, "base_distribution": self.base_distribution} - return base_config | serialize(config) - - @classmethod - def from_config(cls, config, custom_objects=None): - return cls(**deserialize(config, custom_objects=custom_objects)) diff --git a/bayesflow/networks/summary_network.py b/bayesflow/networks/summary_network.py index 54cdd8e84..868369937 100644 --- a/bayesflow/networks/summary_network.py +++ b/bayesflow/networks/summary_network.py @@ -5,11 +5,12 @@ from bayesflow.types import Tensor from bayesflow.utils import layer_kwargs, find_distribution from bayesflow.utils.decorators import sanitize_input_shape -from bayesflow.utils.serialization import deserialize, serializable, serialize +from bayesflow.utils.serialization import serializable +from .base_layer import BaseLayer @serializable("bayesflow.networks") -class SummaryNetwork(keras.Layer): +class SummaryNetwork(BaseLayer): def __init__(self, base_distribution: str = None, *, metrics: Sequence[keras.Metric] = None, **kwargs): self.custom_metrics = metrics super().__init__(**layer_kwargs(kwargs)) @@ -17,6 +18,7 @@ def __init__(self, base_distribution: str = None, *, metrics: Sequence[keras.Met @sanitize_input_shape def build(self, input_shape): + print("SN build", self, input_shape) x = keras.ops.zeros(input_shape) z = self.call(x) @@ -53,13 +55,3 @@ def compute_metrics(self, x: Tensor, stage: str = "training", **kwargs) -> dict[ metrics[metric.name] = metric(outputs, samples) return metrics - - def get_config(self): - base_config = super().get_config() - base_config = layer_kwargs(base_config) - config = {"base_distribution": self.base_distribution, "metrics": self.custom_metrics} - return base_config | serialize(config) - - @classmethod - def from_config(cls, config, custom_objects=None): - return cls(**deserialize(config, custom_objects=custom_objects)) diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index 23721a938..c951aa0e8 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -4,6 +4,7 @@ import keras from bayesflow.utils.serialization import deserialize, serialize +from tests.utils import normalize_config import bayesflow as bf @@ -29,7 +30,7 @@ def test_serialize_deserialize(adapter, random_data): deserialized = deserialize(serialized) reserialized = serialize(deserialized) - assert keras.tree.lists_to_tuples(serialized) == keras.tree.lists_to_tuples(reserialized) + assert normalize_config(serialized) == normalize_config(reserialized) random_data["foo"] = random_data["x1"] deserialized_processed = deserialized(random_data) @@ -122,7 +123,6 @@ def test_simple_transforms(random_data): def test_custom_transform(): # test that transform raises errors in all relevant cases - import keras from bayesflow.adapters.transforms import SerializableCustomTransform from copy import deepcopy @@ -335,7 +335,7 @@ def test_nnpe(random_data): deserialized = deserialize(serialized) reserialized = serialize(deserialized) - assert keras.tree.lists_to_tuples(serialized) == keras.tree.lists_to_tuples(reserialized) + assert normalize_config(serialized) == normalize_config(reserialized) # check that only x1 is changed assert "x1" in result_training @@ -365,7 +365,7 @@ def test_nnpe(random_data): serialized_auto = serialize(ad_auto) deserialized_auto = deserialize(serialized_auto) reserialized_auto = serialize(deserialized_auto) - assert keras.tree.lists_to_tuples(serialized_auto) == keras.tree.lists_to_tuples(serialize(reserialized_auto)) + assert normalize_config(serialized_auto) == normalize_config(serialize(reserialized_auto)) # Test dimensionwise versus global noise application (per_dimension=True vs per_dimension=False) # Create data with second dimension having higher variance diff --git a/tests/test_networks/test_fusion_network/test_fusion_network.py b/tests/test_networks/test_fusion_network/test_fusion_network.py index f1dbfa1c0..827c5b20b 100644 --- a/tests/test_networks/test_fusion_network/test_fusion_network.py +++ b/tests/test_networks/test_fusion_network/test_fusion_network.py @@ -2,7 +2,7 @@ import pytest import keras -from tests.utils import assert_layers_equal, allclose +from tests.utils import assert_layers_equal, allclose, normalize_config @pytest.mark.parametrize("automatic", [True, False]) @@ -57,7 +57,7 @@ def test_serialize_deserialize(fusion_network, multimodal_data): deserialized = deserialize(serialized) reserialized = serialize(deserialized) - assert keras.tree.lists_to_tuples(serialized) == keras.tree.lists_to_tuples(reserialized) + assert normalize_config(serialized) == normalize_config(reserialized) def test_save_and_load(tmp_path, fusion_network, multimodal_data): diff --git a/tests/test_networks/test_inference_networks.py b/tests/test_networks/test_inference_networks.py index 7766b29f9..752fb68aa 100644 --- a/tests/test_networks/test_inference_networks.py +++ b/tests/test_networks/test_inference_networks.py @@ -4,7 +4,7 @@ from bayesflow.utils.serialization import serialize, deserialize -from tests.utils import assert_allclose, assert_layers_equal +from tests.utils import assert_allclose, assert_layers_equal, normalize_config def test_build(inference_network, random_samples, random_conditions): @@ -137,7 +137,7 @@ def test_serialize_deserialize(inference_network, random_samples, random_conditi deserialized = deserialize(serialized) reserialized = serialize(deserialized) - assert keras.tree.lists_to_tuples(serialized) == keras.tree.lists_to_tuples(reserialized) + assert normalize_config(serialized) == normalize_config(reserialized) def test_save_and_load(tmp_path, inference_network, random_samples, random_conditions): diff --git a/tests/test_networks/test_summary_networks.py b/tests/test_networks/test_summary_networks.py index 74ce1f5fd..f40012fc4 100644 --- a/tests/test_networks/test_summary_networks.py +++ b/tests/test_networks/test_summary_networks.py @@ -4,7 +4,7 @@ from bayesflow.utils.serialization import deserialize, serialize -from tests.utils import assert_layers_equal +from tests.utils import assert_layers_equal, normalize_config @pytest.mark.parametrize("automatic", [True, False]) @@ -85,7 +85,7 @@ def test_serialize_deserialize(summary_network, random_set): deserialized = deserialize(serialized) reserialized = serialize(deserialized) - assert keras.tree.lists_to_tuples(serialized) == keras.tree.lists_to_tuples(reserialized) + assert normalize_config(serialized) == normalize_config(reserialized) def test_save_and_load(tmp_path, summary_network, random_set): diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index 9c2affc22..5c7806cc9 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -3,4 +3,5 @@ from .check_combinations import * from .jupyter import * from .networks import * +from .normalize import * from .ops import * diff --git a/tests/utils/normalize.py b/tests/utils/normalize.py new file mode 100644 index 000000000..96198dc93 --- /dev/null +++ b/tests/utils/normalize.py @@ -0,0 +1,25 @@ +from copy import deepcopy +import keras + + +def normalize_dtype(config): + """Convert dtypes with DTypePolicy to simple strings""" + config = deepcopy(config) + + def walk_dictionary(cur_dict): + # walks the dicitonary and modifies entries in-place + for key, value in cur_dict.items(): + if key == "dtype" and isinstance(value, dict): + if value.get("class_name", "") == "DTypePolicy": + cur_dict[key] = value["config"]["name"] + continue + if isinstance(value, dict): + walk_dictionary(value) + + walk_dictionary(config) + return config + + +def normalize_config(config): + config = normalize_dtype(config) + config = keras.tree.lists_to_tuples(config) From 1b4ba12997b42613d4eaa0ea3a3f9869aaf42108 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Fri, 6 Jun 2025 09:12:32 +0000 Subject: [PATCH 03/13] move auto-config from superclass to monkey-patched decorator --- bayesflow/networks/base_layer/__init__.py | 1 - bayesflow/networks/base_layer/base_layer.py | 210 ------------------ .../networks/fusion_network/fusion_network.py | 7 +- bayesflow/networks/inference_network.py | 3 +- bayesflow/networks/summary_network.py | 4 +- bayesflow/utils/serialization.py | 72 +++++- 6 files changed, 79 insertions(+), 218 deletions(-) delete mode 100644 bayesflow/networks/base_layer/__init__.py delete mode 100644 bayesflow/networks/base_layer/base_layer.py diff --git a/bayesflow/networks/base_layer/__init__.py b/bayesflow/networks/base_layer/__init__.py deleted file mode 100644 index 9ce365658..000000000 --- a/bayesflow/networks/base_layer/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .base_layer import BaseLayer diff --git a/bayesflow/networks/base_layer/base_layer.py b/bayesflow/networks/base_layer/base_layer.py deleted file mode 100644 index e07a33662..000000000 --- a/bayesflow/networks/base_layer/base_layer.py +++ /dev/null @@ -1,210 +0,0 @@ -import keras -import inspect -import textwrap -from functools import wraps - -from keras.src import dtype_policies -from keras.src import tree -from keras.src.backend.common.name_scope import current_path -from keras.src.utils import python_utils -from keras import Operation -from keras.saving import get_registered_name, get_registered_object - -from bayesflow.utils.serialization import serialize, deserialize - - -class BayesFlowSerializableDict: - def __init__(self, **config): - self.config = config - - def serialize(self): - return serialize(self.config) - - -class BaseLayer(keras.Layer): - def __new__(cls, *args, **kwargs): - """We override __new__ to saving serializable constructor arguments. - - These arguments are used to auto-generate an object serialization - config, which enables user-created subclasses to be serializable - out of the box in most cases without forcing the user - to manually implement `get_config()`. - """ - - # Adapted from keras.Operation.__new__, to support all serializable objects, instead - # of only basic types. - - instance = super(Operation, cls).__new__(cls) - - # Generate a config to be returned by default by `get_config()`. - arg_names = inspect.getfullargspec(cls.__init__).args - kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args))) - - # Explicitly serialize `dtype` to support auto_config - dtype = kwargs.get("dtype", None) - if dtype is not None and isinstance(dtype, dtype_policies.DTypePolicy): - # For backward compatibility, we use a str (`name`) for - # `DTypePolicy` - if dtype.quantization_mode is None: - kwargs["dtype"] = dtype.name - # Otherwise, use `dtype_policies.serialize` - else: - kwargs["dtype"] = dtype_policies.serialize(dtype) - - # Adaptation: we allow all registered serializable objects - supported_types = (str, int, float, bool, type(None)) - try: - flat_arg_values = tree.flatten(kwargs) - auto_config = True - for value in flat_arg_values: - is_serializable = get_registered_object(get_registered_name(type(value))) is not None - is_class = inspect.isclass(value) - if not (isinstance(value, supported_types) or is_serializable or is_class): - auto_config = False - break - except TypeError: - auto_config = False - try: - instance._lock = False - if auto_config: - instance._auto_config = BayesFlowSerializableDict(**kwargs) - else: - instance._auto_config = None - instance._lock = True - except RecursionError: - # Setting an instance attribute in __new__ has the potential - # to trigger an infinite recursion if a subclass overrides - # setattr in an unsafe way. - pass - - ### from keras.Layer.__new__ - - # Wrap the user-provided `build` method in the `build_wrapper` - # to add name scope support and serialization support. - original_build_method = instance.build - - @wraps(original_build_method) - def build_wrapper(*args, **kwargs): - with instance._open_name_scope(): - instance._path = current_path() - original_build_method(*args, **kwargs) - # Record build config. - signature = inspect.signature(original_build_method) - instance._build_shapes_dict = signature.bind(*args, **kwargs).arguments - # Set built, post build actions, and lock state. - instance.built = True - instance._post_build() - instance._lock_state() - - instance.build = build_wrapper - - # Wrap the user-provided `quantize` method in the `quantize_wrapper` - # to add tracker support. - original_quantize_method = instance.quantize - - @wraps(original_quantize_method) - def quantize_wrapper(mode, **kwargs): - instance._check_quantize_args(mode, instance.compute_dtype) - instance._tracker.unlock() - try: - original_quantize_method(mode, **kwargs) - except Exception: - raise - finally: - instance._tracker.lock() - - instance.quantize = quantize_wrapper - - return instance - - @python_utils.default - def get_config(self): - """Returns the config of the object. - - An object config is a Python dictionary (serializable) - containing the information needed to re-instantiate it. - """ - - # Adapted from Operations.get_config to support specifying a default configuration in - # subclasses, without giving up on the automatic config functionality. - config = super().get_config() - if not python_utils.is_default(self.get_config): - # In this case the subclass implements get_config() - return config - - # In this case the subclass doesn't implement get_config(): - # Let's see if we can autogenerate it. - if getattr(self, "_auto_config", None) is not None: - xtra_args = set(config.keys()) - config.update(self._auto_config.config) - # Remove args non explicitly supported - argspec = inspect.getfullargspec(self.__init__) - if argspec.varkw != "kwargs": - for key in xtra_args - xtra_args.intersection(argspec.args[1:]): - config.pop(key, None) - return config - else: - raise NotImplementedError( - textwrap.dedent( - f""" - Object {self.__class__.__name__} was created by passing - non-serializable argument values in `__init__()`, - and therefore the object must override `get_config()` in - order to be serializable. Please implement `get_config()`. - - Example: - - class CustomLayer(keras.layers.Layer): - def __init__(self, arg1, arg2, **kwargs): - super().__init__(**kwargs) - self.arg1 = arg1 - self.arg2 = arg2 - - def get_config(self): - config = super().get_config() - config.update({{ - "arg1": self.arg1, - "arg2": self.arg2, - }}) - return config""" - ) - ) - - @classmethod - def from_config(cls, config): - """Creates an operation from its config. - - This method is the reverse of `get_config`, capable of instantiating the - same operation from the config dictionary. - - Note: If you override this method, you might receive a serialized dtype - config, which is a `dict`. You can deserialize it as follows: - - ```python - if "dtype" in config and isinstance(config["dtype"], dict): - policy = dtype_policies.deserialize(config["dtype"]) - ``` - - Args: - config: A Python dictionary, typically the output of `get_config`. - - Returns: - An operation instance. - """ - # Adapted from keras.Operation.from_config to use our deserialize function - # Explicitly deserialize dtype config if needed. This enables users to - # directly interact with the instance of `DTypePolicy`. - if "dtype" in config and isinstance(config["dtype"], dict): - config = config.copy() - policy = dtype_policies.deserialize(config["dtype"]) - if not isinstance(policy, dtype_policies.DTypePolicyMap) and policy.quantization_mode is None: - # For backward compatibility, we use a str (`name`) for - # `DTypePolicy` - policy = policy.name - config["dtype"] = policy - try: - return cls(**deserialize(config)) - except Exception as e: - raise TypeError( - f"Error when deserializing class '{cls.__name__}' using config={config}.\n\nException encountered: {e}" - ) diff --git a/bayesflow/networks/fusion_network/fusion_network.py b/bayesflow/networks/fusion_network/fusion_network.py index f5d81afc1..8d0132975 100644 --- a/bayesflow/networks/fusion_network/fusion_network.py +++ b/bayesflow/networks/fusion_network/fusion_network.py @@ -1,6 +1,6 @@ from collections.abc import Mapping from ..summary_network import SummaryNetwork -from bayesflow.utils.serialization import serializable, serialize +from bayesflow.utils.serialization import deserialize, serializable, serialize from bayesflow.types import Tensor, Shape import keras from keras import ops @@ -116,3 +116,8 @@ def get_config(self) -> dict: "head": self.head, } return base_config | serialize(config) + + @classmethod + def from_config(cls, config: dict, custom_objects=None): + config = deserialize(config, custom_objects=custom_objects) + return cls(**config) diff --git a/bayesflow/networks/inference_network.py b/bayesflow/networks/inference_network.py index 97b4bcb7a..434773112 100644 --- a/bayesflow/networks/inference_network.py +++ b/bayesflow/networks/inference_network.py @@ -5,11 +5,10 @@ from bayesflow.utils import layer_kwargs, find_distribution from bayesflow.utils.decorators import allow_batch_size from bayesflow.utils.serialization import serializable -from .base_layer import BaseLayer @serializable("bayesflow.networks") -class InferenceNetwork(BaseLayer): +class InferenceNetwork(keras.Layer): def __init__(self, base_distribution: str = "normal", *, metrics: Sequence[keras.Metric] = None, **kwargs): self.custom_metrics = metrics super().__init__(**layer_kwargs(kwargs)) diff --git a/bayesflow/networks/summary_network.py b/bayesflow/networks/summary_network.py index 868369937..f61cc576d 100644 --- a/bayesflow/networks/summary_network.py +++ b/bayesflow/networks/summary_network.py @@ -6,11 +6,10 @@ from bayesflow.utils import layer_kwargs, find_distribution from bayesflow.utils.decorators import sanitize_input_shape from bayesflow.utils.serialization import serializable -from .base_layer import BaseLayer @serializable("bayesflow.networks") -class SummaryNetwork(BaseLayer): +class SummaryNetwork(keras.Layer): def __init__(self, base_distribution: str = None, *, metrics: Sequence[keras.Metric] = None, **kwargs): self.custom_metrics = metrics super().__init__(**layer_kwargs(kwargs)) @@ -18,7 +17,6 @@ def __init__(self, base_distribution: str = None, *, metrics: Sequence[keras.Met @sanitize_input_shape def build(self, input_shape): - print("SN build", self, input_shape) x = keras.ops.zeros(input_shape) z = self.call(x) diff --git a/bayesflow/utils/serialization.py b/bayesflow/utils/serialization.py index 5be0e0e1d..be63c7972 100644 --- a/bayesflow/utils/serialization.py +++ b/bayesflow/utils/serialization.py @@ -3,12 +3,16 @@ import builtins import inspect import keras +import functools import numpy as np import sys from warnings import warn # this import needs to be exactly like this to work with monkey patching -from keras.saving import deserialize_keras_object +from keras.saving import deserialize_keras_object, get_registered_object, get_registered_name +from keras.src.saving.serialization_lib import SerializableDict +from keras import dtype_policies +from keras import tree from .context_managers import monkey_patch from .decorators import allow_args @@ -95,6 +99,10 @@ def deserialize(config: dict, custom_objects=None, safe_mode=True, **kwargs): return obj +def _deserializing_from_config(cls, config, custom_objects=None): + return cls(**deserialize(config, custom_objects=custom_objects)) + + @allow_args def serializable(cls, package: str, name: str | None = None, disable_module_check: bool = False): """Register class as Keras serializable. @@ -143,6 +151,68 @@ def serializable(cls, package: str, name: str | None = None, disable_module_chec if name is None: name = copy(cls.__name__) + def init_decorator(original_init): + # Adds auto-config behavior after the __init__ function. This extends the auto-config capabilities provided + # by keras.Operation (base class of keras.Layer) with support for all serializable objects. + # This produces a serialized config that has to be deserialized properly, see below. + @functools.wraps(original_init) + def wrapper(instance, *args, **kwargs): + original_init(instance, *args, **kwargs) + + # Generate a config to be returned by default by `get_config()`. + # Adapted from keras.Operation. + kwargs = kwargs.copy() + arg_names = inspect.getfullargspec(original_init).args + kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args))) + + # Explicitly serialize `dtype` to support auto_config + dtype = kwargs.get("dtype", None) + if dtype is not None and isinstance(dtype, dtype_policies.DTypePolicy): + # For backward compatibility, we use a str (`name`) for + # `DTypePolicy` + if dtype.quantization_mode is None: + kwargs["dtype"] = dtype.name + # Otherwise, use `dtype_policies.serialize` + else: + kwargs["dtype"] = dtype_policies.serialize(dtype) + + # supported basic types + supported_types = (str, int, float, bool, type(None)) + + flat_arg_values = tree.flatten(kwargs) + auto_config = True + for value in flat_arg_values: + # adaptation: we allow all registered serializable objects + is_serializable_object = ( + isinstance(value, supported_types) + or get_registered_object(get_registered_name(type(value))) is not None + ) + # adaptation: we allow all registered serializable objects + try: + is_serializable_class = inspect.isclass(value) and deserialize(serialize(value)) + except ValueError: + # deserializtion of type failed, probably not registered + is_serializable_class = False + if not (is_serializable_object or is_serializable_class): + auto_config = False + break + + if auto_config: + with monkey_patch(keras.saving.serialize_keras_object, serialize): + instance._auto_config = SerializableDict(**kwargs) + else: + instance._auto_config = None + + return wrapper + + cls.__init__ = init_decorator(cls.__init__) + + if hasattr(cls, "from_config") and cls.from_config.__func__ == keras.Layer.from_config.__func__: + # By default, keras.Layer.from_config does not deserializte the config. For this class, there is a + # from_config method that is identical to keras.Layer.config, so we replace it with a variant that applies + # deserialization to the config. + cls.from_config = classmethod(_deserializing_from_config) + # register subclasses as keras serializable return keras.saving.register_keras_serializable(package=package, name=name)(cls) From 52980fa1ca573b148139ec2e9123d16ced3687af Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sun, 8 Jun 2025 16:10:08 +0000 Subject: [PATCH 04/13] deprecate passing metrics to model comparison approximator --- .../model_comparison_approximator.py | 26 ++++++++----------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/bayesflow/approximators/model_comparison_approximator.py b/bayesflow/approximators/model_comparison_approximator.py index c0554ad66..f34314b79 100644 --- a/bayesflow/approximators/model_comparison_approximator.py +++ b/bayesflow/approximators/model_comparison_approximator.py @@ -2,6 +2,7 @@ import keras import numpy as np +import warnings from bayesflow.adapters import Adapter from bayesflow.datasets import OnlineDataset @@ -110,13 +111,18 @@ def compile( **kwargs, ): if classifier_metrics: - self.classifier_network._metrics = classifier_metrics + warnings.warn( + "Supplying classifier metrics to the approximator is no longer supported. " + "Please pass the metrics directly to the network using the metrics parameter.", + DeprecationWarning, + ) if summary_metrics: - if self.summary_network is None: - logging.warning("Ignoring summary metrics because there is no summary network.") - else: - self.summary_network._metrics = summary_metrics + warnings.warn( + "Supplying summary metrics to the approximator is no longer supported. " + "Please pass the metrics directly to the network using the metrics parameter.", + DeprecationWarning, + ) return super().compile(*args, **kwargs) @@ -270,16 +276,6 @@ def get_config(self): return base_config | serialize(config) - def get_compile_config(self): - base_config = super().get_compile_config() or {} - - config = { - "classifier_metrics": self.classifier_network._metrics, - "summary_metrics": self.summary_network._metrics if self.summary_network is not None else None, - } - - return base_config | serialize(config) - def predict( self, *, From 1eda62c73110ba54be920078e298e6725d4f4841 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sun, 8 Jun 2025 17:13:20 +0000 Subject: [PATCH 05/13] Add support for custom metrics to MLP --- bayesflow/networks/mlp/mlp.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bayesflow/networks/mlp/mlp.py b/bayesflow/networks/mlp/mlp.py index 7184070af..51c3656cd 100644 --- a/bayesflow/networks/mlp/mlp.py +++ b/bayesflow/networks/mlp/mlp.py @@ -29,6 +29,7 @@ def __init__( dropout: Literal[0, None] | float = 0.05, norm: Literal["batch", "layer"] | keras.Layer = None, spectral_normalization: bool = False, + metrics: Sequence[keras.Metric] | None = None, **kwargs, ): """ @@ -60,6 +61,7 @@ def __init__( **kwargs Additional keyword arguments passed to the Keras layer initialization. """ + self.custom_metrics = metrics self.widths = list(widths) self.activation = activation self.kernel_initializer = kernel_initializer @@ -90,6 +92,7 @@ def get_config(self): "dropout": self.dropout, "norm": self.norm, "spectral_normalization": self.spectral_normalization, + "metrics": self.custom_metrics, } return base_config | serialize(config) From af3f19c2cab4b6a536c37511c0373ab831df9cc8 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sun, 8 Jun 2025 17:14:03 +0000 Subject: [PATCH 06/13] adjust signatures, extend tests --- .../approximators/model_comparison_approximator.py | 12 ++++++------ bayesflow/networks/inference_network.py | 2 +- bayesflow/networks/summary_network.py | 2 +- tests/test_approximators/conftest.py | 5 ++++- tests/test_approximators/test_fit.py | 4 ++++ .../test_model_comparison_approximator/conftest.py | 6 ++++-- .../test_model_comparison_approximator.py | 5 ++++- 7 files changed, 24 insertions(+), 12 deletions(-) diff --git a/bayesflow/approximators/model_comparison_approximator.py b/bayesflow/approximators/model_comparison_approximator.py index c3b06628f..64c7b8914 100644 --- a/bayesflow/approximators/model_comparison_approximator.py +++ b/bayesflow/approximators/model_comparison_approximator.py @@ -106,18 +106,16 @@ def build_dataset( def compile( self, *args, - classifier_metrics: Sequence[keras.Metric] = None, - summary_metrics: Sequence[keras.Metric] = None, **kwargs, ): - if classifier_metrics: + if "classifier_metrics" in kwargs: warnings.warn( "Supplying classifier metrics to the approximator is no longer supported. " "Please pass the metrics directly to the network using the metrics parameter.", DeprecationWarning, ) - if summary_metrics: + if "summary_metrics" in kwargs: warnings.warn( "Supplying summary metrics to the approximator is no longer supported. " "Please pass the metrics directly to the network using the metrics parameter.", @@ -166,8 +164,10 @@ def compute_metrics( classifier_metrics |= { metric.name: metric(model_indices, predictions) for metric in self.classifier_network.metrics } - - loss = classifier_metrics.get("loss", keras.ops.zeros(())) + summary_metrics.get("loss", keras.ops.zeros(())) + if "loss" in summary_metrics: + loss = classifier_metrics["loss"] + summary_metrics["loss"] + else: + loss = classifier_metrics.pop("loss") classifier_metrics = {f"{key}/classifier_{key}": value for key, value in classifier_metrics.items()} summary_metrics = {f"{key}/summary_{key}": value for key, value in summary_metrics.items()} diff --git a/bayesflow/networks/inference_network.py b/bayesflow/networks/inference_network.py index 434773112..bcae87f2a 100644 --- a/bayesflow/networks/inference_network.py +++ b/bayesflow/networks/inference_network.py @@ -9,7 +9,7 @@ @serializable("bayesflow.networks") class InferenceNetwork(keras.Layer): - def __init__(self, base_distribution: str = "normal", *, metrics: Sequence[keras.Metric] = None, **kwargs): + def __init__(self, base_distribution: str = "normal", *, metrics: Sequence[keras.Metric] | None = None, **kwargs): self.custom_metrics = metrics super().__init__(**layer_kwargs(kwargs)) self.base_distribution = find_distribution(base_distribution) diff --git a/bayesflow/networks/summary_network.py b/bayesflow/networks/summary_network.py index f61cc576d..a2ef3e83b 100644 --- a/bayesflow/networks/summary_network.py +++ b/bayesflow/networks/summary_network.py @@ -10,7 +10,7 @@ @serializable("bayesflow.networks") class SummaryNetwork(keras.Layer): - def __init__(self, base_distribution: str = None, *, metrics: Sequence[keras.Metric] = None, **kwargs): + def __init__(self, base_distribution: str = None, *, metrics: Sequence[keras.Metric] | None = None, **kwargs): self.custom_metrics = metrics super().__init__(**layer_kwargs(kwargs)) self.base_distribution = find_distribution(base_distribution) diff --git a/tests/test_approximators/conftest.py b/tests/test_approximators/conftest.py index 3c4d2fd4c..4c67e9321 100644 --- a/tests/test_approximators/conftest.py +++ b/tests/test_approximators/conftest.py @@ -20,8 +20,11 @@ def summary_network(): @pytest.fixture() def inference_network(): from bayesflow.networks import CouplingFlow + from bayesflow.metrics import RootMeanSquaredError - return CouplingFlow(subnet="mlp", depth=2, subnet_kwargs=dict(widths=(32, 32))) + return CouplingFlow( + subnet="mlp", depth=2, subnet_kwargs=dict(widths=(32, 32)), metrics=[RootMeanSquaredError(name="rmse")] + ) @pytest.fixture() diff --git a/tests/test_approximators/test_fit.py b/tests/test_approximators/test_fit.py index b561efb77..8d416c1e2 100644 --- a/tests/test_approximators/test_fit.py +++ b/tests/test_approximators/test_fit.py @@ -49,3 +49,7 @@ def test_loss_progress(approximator, train_dataset, validation_dataset): # check that the shown loss is not nan or zero assert re.search(r"\bnan\b", output) is None, "found nan in output" assert re.search(r"\bloss: 0\.0000e\+00\b", output) is None, "found zero loss in output" + + # check that additional metric is present + assert "val_rmse/inference_rmse" in output, "custom metric (RMSE) not shown" + assert re.search(r"\bval_rmse/inference_rmse: \d+\.\d+", output) is not None, "custom metric not correctly shown" diff --git a/tests/test_approximators/test_model_comparison_approximator/conftest.py b/tests/test_approximators/test_model_comparison_approximator/conftest.py index 8bb4a97d3..57a4a027f 100644 --- a/tests/test_approximators/test_model_comparison_approximator/conftest.py +++ b/tests/test_approximators/test_model_comparison_approximator/conftest.py @@ -51,15 +51,17 @@ def adapter(): @pytest.fixture def summary_network(): from bayesflow.networks import DeepSet + from bayesflow.metrics import RootMeanSquaredError - return DeepSet(summary_dim=2, depth=1) + return DeepSet(summary_dim=2, depth=1, base_distribution="normal", metrics=[RootMeanSquaredError(name="rmse")]) @pytest.fixture def classifier_network(): from bayesflow.networks import MLP + from keras.metrics import CategoricalAccuracy - return MLP(widths=[32, 32]) + return MLP(widths=[32, 32], metrics=[CategoricalAccuracy(name="categorical_accuracy")]) @pytest.fixture diff --git a/tests/test_approximators/test_model_comparison_approximator/test_model_comparison_approximator.py b/tests/test_approximators/test_model_comparison_approximator/test_model_comparison_approximator.py index 0246ee7b7..5ca1d28e0 100644 --- a/tests/test_approximators/test_model_comparison_approximator/test_model_comparison_approximator.py +++ b/tests/test_approximators/test_model_comparison_approximator/test_model_comparison_approximator.py @@ -55,7 +55,10 @@ def test_fit(approximator, train_dataset, validation_dataset): output = stream.getvalue() # check that the loss is shown - assert "loss" in output + assert "loss/summary_loss" in output + assert "loss/classifier_loss" in output + assert "val_categorical_accuracy/classifier_categorical_accuracy" in output + assert "val_rmse/summary_rmse" in output def test_save_and_load(tmp_path, approximator, train_dataset, validation_dataset): From 07f754661e2451171231c30db7d3d2931c908eee Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sun, 8 Jun 2025 17:37:17 +0000 Subject: [PATCH 07/13] pass probs insteads of predictions to classifier metrics This is what most classifier metrics expect, and contains more detail for the metrics to work with. --- bayesflow/approximators/model_comparison_approximator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesflow/approximators/model_comparison_approximator.py b/bayesflow/approximators/model_comparison_approximator.py index 64c7b8914..49a5ba85a 100644 --- a/bayesflow/approximators/model_comparison_approximator.py +++ b/bayesflow/approximators/model_comparison_approximator.py @@ -160,9 +160,9 @@ def compute_metrics( if stage != "training" and any(self.classifier_network.metrics): # compute sample-based metrics - predictions = keras.ops.argmax(logits, axis=-1) + probs = keras.ops.softmax(logits) classifier_metrics |= { - metric.name: metric(model_indices, predictions) for metric in self.classifier_network.metrics + metric.name: metric(model_indices, probs) for metric in self.classifier_network.metrics } if "loss" in summary_metrics: loss = classifier_metrics["loss"] + summary_metrics["loss"] From 44415f4ef0b4fa43ddf8d2a3b70cc3c54754ceb3 Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Tue, 1 Jul 2025 05:07:46 -0400 Subject: [PATCH 08/13] Update docs and typehints --- bayesflow/networks/inference_network.py | 29 ++++++++++++++++++++++--- bayesflow/networks/mlp/mlp.py | 4 +++- bayesflow/networks/summary_network.py | 23 ++++++++++++++++++++ 3 files changed, 52 insertions(+), 4 deletions(-) diff --git a/bayesflow/networks/inference_network.py b/bayesflow/networks/inference_network.py index bcae87f2a..af7998e22 100644 --- a/bayesflow/networks/inference_network.py +++ b/bayesflow/networks/inference_network.py @@ -1,6 +1,8 @@ -import keras +from typing import Literal from collections.abc import Sequence +import keras + from bayesflow.types import Shape, Tensor from bayesflow.utils import layer_kwargs, find_distribution from bayesflow.utils.decorators import allow_batch_size @@ -9,7 +11,28 @@ @serializable("bayesflow.networks") class InferenceNetwork(keras.Layer): - def __init__(self, base_distribution: str = "normal", *, metrics: Sequence[keras.Metric] | None = None, **kwargs): + def __init__( + self, + base_distribution: Literal["normal", "student", "mixture"] | keras.Layer = "normal", + *, + metrics: Sequence[keras.Metric] | None = None, + **kwargs, + ): + """ + Constructs an inference network using a specified base distribution and optional custom metrics. + Use this interface for custom inference networks. + + Parameters + ---------- + base_distribution : Literal["normal", "student", "mixture"] or keras.Layer + Name or the actual base distribution to use. Passed to `find_distribution` to + obtain the corresponding distribution object. + metrics : Sequence[keras.Metric] or None, optional + Sequence of custom Keras Metric instances to compute during training + and evaluation. If `None`, no custom metrics are used. + **kwargs + Additional keyword arguments forwarded to the `keras.Layer` constructor. + """ self.custom_metrics = metrics super().__init__(**layer_kwargs(kwargs)) self.base_distribution = find_distribution(base_distribution) @@ -70,7 +93,7 @@ def compute_metrics( if stage != "training" and any(self.metrics): # compute sample-based metrics - samples = self.sample((keras.ops.shape(x)[0],), conditions=conditions) + samples = self.sample(batch_shape=(keras.ops.shape(x)[0],), conditions=conditions) for metric in self.metrics: metrics[metric.name] = metric(samples, x) diff --git a/bayesflow/networks/mlp/mlp.py b/bayesflow/networks/mlp/mlp.py index 51c3656cd..e2455e85a 100644 --- a/bayesflow/networks/mlp/mlp.py +++ b/bayesflow/networks/mlp/mlp.py @@ -55,9 +55,11 @@ def __init__( dropout : float or None, optional Dropout rate applied within the MLP layers for regularization. Default is 0.05. norm: str, optional - + Type of learnable normalization to be used (e.g., "batch" or "layer"). Default is None. spectral_normalization : bool, optional Whether to apply spectral normalization to stabilize training. Default is False. + metrics: Sequence[keras.Metric], optional + A sequence of callable metrics following keras' `Metric` signature. Default is None. **kwargs Additional keyword arguments passed to the Keras layer initialization. """ diff --git a/bayesflow/networks/summary_network.py b/bayesflow/networks/summary_network.py index a2ef3e83b..4d6ecda0a 100644 --- a/bayesflow/networks/summary_network.py +++ b/bayesflow/networks/summary_network.py @@ -11,6 +11,29 @@ @serializable("bayesflow.networks") class SummaryNetwork(keras.Layer): def __init__(self, base_distribution: str = None, *, metrics: Sequence[keras.Metric] | None = None, **kwargs): + """ + Builds a summary network with an optional base distribution and custom metrics. Use this class + as an interface for custom summary networks. + + Important: If a base distribution is passed, the summary outputs will be optimized to follow + said distribution, as described in [1]. + + [1] Schmitt, M., Bürkner, P. C., Köthe, U., & Radev, S. T. (2023). + Detecting model misspecification in amortized Bayesian inference with neural networks. + In DAGM German Conference on Pattern Recognition (pp. 541-557). Cham: Springer Nature Switzerland. + + Parameters + ---------- + base_distribution : str or None, default None + Name of the base distribution to use. If `None`, a default distribution + is chosen. Passed to `find_distribution` to obtain the corresponding + distribution object. + metrics : Sequence[keras.Metric] or None, optional + Sequence of custom Keras Metric instances to compute during training + and evaluation. If `None`, no custom metrics are used. + **kwargs + Additional keyword arguments forwarded to the `keras.Layer` constructor. + """ self.custom_metrics = metrics super().__init__(**layer_kwargs(kwargs)) self.base_distribution = find_distribution(base_distribution) From 66f8ca7a3b55784db57bed1bcd8cdf170980ecd9 Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Tue, 1 Jul 2025 05:17:56 -0400 Subject: [PATCH 09/13] Better docs and comments [skip ci] --- bayesflow/networks/inference_network.py | 9 ++++++-- bayesflow/networks/mlp/mlp.py | 7 +++++-- bayesflow/networks/summary_network.py | 28 +++++++++++++++++-------- 3 files changed, 31 insertions(+), 13 deletions(-) diff --git a/bayesflow/networks/inference_network.py b/bayesflow/networks/inference_network.py index af7998e22..432c86941 100644 --- a/bayesflow/networks/inference_network.py +++ b/bayesflow/networks/inference_network.py @@ -11,6 +11,11 @@ @serializable("bayesflow.networks") class InferenceNetwork(keras.Layer): + """ + Constructs an inference network using a specified base distribution and optional custom metrics. + Use this interface for custom inference networks. + """ + def __init__( self, base_distribution: Literal["normal", "student", "mixture"] | keras.Layer = "normal", @@ -19,8 +24,8 @@ def __init__( **kwargs, ): """ - Constructs an inference network using a specified base distribution and optional custom metrics. - Use this interface for custom inference networks. + Creates the network with provided arguments. Optional user-supplied metrics will be stored + in a `custom_metrics` attribute. A special `metrics` attribute will be created internally by `keras.Layer`. Parameters ---------- diff --git a/bayesflow/networks/mlp/mlp.py b/bayesflow/networks/mlp/mlp.py index e2455e85a..935933836 100644 --- a/bayesflow/networks/mlp/mlp.py +++ b/bayesflow/networks/mlp/mlp.py @@ -33,8 +33,8 @@ def __init__( **kwargs, ): """ - Implements a flexible multi-layer perceptron (MLP) with optional residual connections, dropout, and - spectral normalization. + Creates a flexible multi-layer perceptron (MLP) with optional residual connections, dropout, + spectral normalization, and metrics. This MLP can be used as a general-purpose feature extractor or function approximator, supporting configurable depth, width, activation functions, and weight initializations. @@ -42,6 +42,9 @@ def __init__( If `residual` is enabled, each layer includes a skip connection for improved gradient flow. The model also supports dropout for regularization and spectral normalization for stability in learning smooth functions. + Optional user-supplied metrics will be stored in a `custom_metrics` attribute. A special `metrics` attribute + will be created internally by `keras.Layer`. + Parameters ---------- widths : Sequence[int], optional diff --git a/bayesflow/networks/summary_network.py b/bayesflow/networks/summary_network.py index 4d6ecda0a..6834824bd 100644 --- a/bayesflow/networks/summary_network.py +++ b/bayesflow/networks/summary_network.py @@ -10,17 +10,27 @@ @serializable("bayesflow.networks") class SummaryNetwork(keras.Layer): + """ + Builds a summary network with an optional base distribution and custom metrics. Use this class + as an interface for custom summary networks. + + Important + --------- + If a base distribution is passed, the summary outputs will be optimized to follow + that distribution, as described in [1]. + + References + ---------- + [1] Schmitt, M., Bürkner, P. C., Köthe, U., & Radev, S. T. (2023). + Detecting model misspecification in amortized Bayesian inference with neural networks. + In DAGM German Conference on Pattern Recognition (pp. 541-557). + Cham: Springer Nature Switzerland. + """ + def __init__(self, base_distribution: str = None, *, metrics: Sequence[keras.Metric] | None = None, **kwargs): """ - Builds a summary network with an optional base distribution and custom metrics. Use this class - as an interface for custom summary networks. - - Important: If a base distribution is passed, the summary outputs will be optimized to follow - said distribution, as described in [1]. - - [1] Schmitt, M., Bürkner, P. C., Köthe, U., & Radev, S. T. (2023). - Detecting model misspecification in amortized Bayesian inference with neural networks. - In DAGM German Conference on Pattern Recognition (pp. 541-557). Cham: Springer Nature Switzerland. + Creates the network with provided arguments. Optional user-supplied metrics will be stored + in a `custom_metrics` attribute. A special `metrics` attribute will be created internally by `keras.Layer`. Parameters ---------- From 5f38e86d34d3e3f7bbf110eb7b6ec3ee84009994 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Tue, 1 Jul 2025 12:51:34 +0000 Subject: [PATCH 10/13] compare metric in assert layer/model equal --- tests/utils/assertions.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/utils/assertions.py b/tests/utils/assertions.py index fc4219206..727b0f867 100644 --- a/tests/utils/assertions.py +++ b/tests/utils/assertions.py @@ -13,6 +13,11 @@ def assert_models_equal(model1: keras.Model, model2: keras.Model): else: assert_layers_equal(layer1, layer2) + assert len(model1.metrics) == len(model2.metrics) + for metric1, metric2 in zip(model1.metrics, model2.metrics): + assert type(metric1) is type(metric2) + assert metric1.name == metric2.name + def assert_layers_equal(layer1: keras.Layer, layer2: keras.Layer): msg = f"Layers {layer1.name} and {layer2.name} have different types." @@ -40,3 +45,8 @@ def assert_layers_equal(layer1: keras.Layer, layer2: keras.Layer): # this is turned off for now, see https://github.com/bayesflow-org/bayesflow/issues/412 msg = f"Layers {layer1.name} and {layer2.name} have a different name." # assert layer1.name == layer2.name, msg + + assert len(layer1.metrics) == len(layer2.metrics), f"metrics do not match: {layer1.metrics}!={layer2.metrics}" + for metric1, metric2 in zip(layer1.metrics, layer2.metrics): + assert type(metric1) is type(metric2) + assert metric1.name == metric2.name From 12a6d4b8f8e9442ca91358957d9c3b005e2a2cf2 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Tue, 1 Jul 2025 12:52:37 +0000 Subject: [PATCH 11/13] add get_config for network base classes - get config has to be manually specified in the base classes, so that the config is stored even when to subclass overrides get_config - to preserve the auto_config behavior, we have to use the `python_utils.default` decorator from, which marks them as default methods. This allows detecting if a subclass has overridden them. This is the same mechanism that Keras uses - moved setting the `custom_metrics` parameter after the `super().__init__` calls, as the tracking is managed in setattr - extended some tests to use metrics --- .../experimental/diffusion_model/diffusion_model.py | 3 +++ .../experimental/free_form_flow/free_form_flow.py | 2 -- .../networks/consistency_models/consistency_model.py | 3 +++ bayesflow/networks/coupling_flow/coupling_flow.py | 2 +- bayesflow/networks/inference_network.py | 12 ++++++++++-- bayesflow/networks/point_inference_network.py | 11 ++++++++--- bayesflow/networks/summary_network.py | 12 ++++++++++-- tests/test_approximators/conftest.py | 7 ++++++- 8 files changed, 41 insertions(+), 11 deletions(-) diff --git a/bayesflow/experimental/diffusion_model/diffusion_model.py b/bayesflow/experimental/diffusion_model/diffusion_model.py index f703568ba..4988878f1 100644 --- a/bayesflow/experimental/diffusion_model/diffusion_model.py +++ b/bayesflow/experimental/diffusion_model/diffusion_model.py @@ -141,6 +141,9 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None: def get_config(self): base_config = super().get_config() + # base distribution is passed manually to InferenceNetwork parent class, do not store it here + base_config.pop("base_distribution") + config = { "subnet": self.subnet, "noise_schedule": self.noise_schedule, diff --git a/bayesflow/experimental/free_form_flow/free_form_flow.py b/bayesflow/experimental/free_form_flow/free_form_flow.py index d1e826864..bc324770f 100644 --- a/bayesflow/experimental/free_form_flow/free_form_flow.py +++ b/bayesflow/experimental/free_form_flow/free_form_flow.py @@ -8,7 +8,6 @@ find_network, jacobian, jvp, - model_kwargs, vjp, weighted_mean, ) @@ -240,7 +239,6 @@ def from_config(cls, config, custom_objects=None): def get_config(self): base_config = super().get_config() - base_config = model_kwargs(base_config) config = { "beta": self.beta, diff --git a/bayesflow/networks/consistency_models/consistency_model.py b/bayesflow/networks/consistency_models/consistency_model.py index bfdca1b0b..41cba2d41 100644 --- a/bayesflow/networks/consistency_models/consistency_model.py +++ b/bayesflow/networks/consistency_models/consistency_model.py @@ -110,6 +110,9 @@ def from_config(cls, config, custom_objects=None): def get_config(self): base_config = super().get_config() + # base distribution is passed manually to InferenceNetwork parent class, do not store it here + base_config.pop("base_distribution") + config = { "total_steps": self.total_steps, "subnet": self.subnet, diff --git a/bayesflow/networks/coupling_flow/coupling_flow.py b/bayesflow/networks/coupling_flow/coupling_flow.py index b63909557..31a03837b 100644 --- a/bayesflow/networks/coupling_flow/coupling_flow.py +++ b/bayesflow/networks/coupling_flow/coupling_flow.py @@ -90,7 +90,7 @@ def __init__( Keyword arguments forwarded to the affine or spline transforms (e.g., bins for splines) **kwargs - Additional keyword arguments passed to `InvertibleLayer`. + Additional keyword arguments passed to `InferenceNetwork`. """ super().__init__(base_distribution=base_distribution, **kwargs) diff --git a/bayesflow/networks/inference_network.py b/bayesflow/networks/inference_network.py index 432c86941..3469d434f 100644 --- a/bayesflow/networks/inference_network.py +++ b/bayesflow/networks/inference_network.py @@ -2,11 +2,12 @@ from collections.abc import Sequence import keras +from keras.src.utils import python_utils from bayesflow.types import Shape, Tensor from bayesflow.utils import layer_kwargs, find_distribution from bayesflow.utils.decorators import allow_batch_size -from bayesflow.utils.serialization import serializable +from bayesflow.utils.serialization import serializable, serialize @serializable("bayesflow.networks") @@ -38,8 +39,8 @@ def __init__( **kwargs Additional keyword arguments forwarded to the `keras.Layer` constructor. """ - self.custom_metrics = metrics super().__init__(**layer_kwargs(kwargs)) + self.custom_metrics = metrics self.base_distribution = find_distribution(base_distribution) def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None: @@ -104,3 +105,10 @@ def compute_metrics( metrics[metric.name] = metric(samples, x) return metrics + + @python_utils.default + def get_config(self): + base_config = super().get_config() + + config = {"metrics": self.custom_metrics, "base_distribution": self.base_distribution} + return base_config | serialize(config) diff --git a/bayesflow/networks/point_inference_network.py b/bayesflow/networks/point_inference_network.py index 402632355..4e1b499f4 100644 --- a/bayesflow/networks/point_inference_network.py +++ b/bayesflow/networks/point_inference_network.py @@ -1,4 +1,6 @@ +from collections.abc import Sequence import keras +from keras.src.utils import python_utils from bayesflow.utils import model_kwargs, find_network from bayesflow.utils.serialization import deserialize, serializable, serialize @@ -17,9 +19,12 @@ def __init__( self, scores: dict[str, ScoringRule], subnet: str | keras.Layer = "mlp", + *, + metrics: Sequence[keras.Metric] | None = None, **kwargs, ): super().__init__(**model_kwargs(kwargs)) + self.custom_metrics = metrics self.scores = scores @@ -28,6 +33,7 @@ def __init__( self.config = { "subnet": serialize(subnet), "scores": serialize(scores), + "metrics": serialize(metrics), **kwargs, } @@ -106,6 +112,7 @@ def build_from_config(self, config): for head_key, head in self.heads[score_key].items(): head.name = config["heads"][score_key][head_key] + @python_utils.default def get_config(self): base_config = super().get_config() @@ -114,9 +121,7 @@ def get_config(self): @classmethod def from_config(cls, config): config = config.copy() - config["scores"] = deserialize(config["scores"]) - config["subnet"] = deserialize(config["subnet"]) - return cls(**config) + return cls(**deserialize(config)) def call( self, diff --git a/bayesflow/networks/summary_network.py b/bayesflow/networks/summary_network.py index 6834824bd..77b4306a1 100644 --- a/bayesflow/networks/summary_network.py +++ b/bayesflow/networks/summary_network.py @@ -1,11 +1,12 @@ import keras +from keras.src.utils import python_utils from collections.abc import Sequence from bayesflow.metrics.functional import maximum_mean_discrepancy from bayesflow.types import Tensor from bayesflow.utils import layer_kwargs, find_distribution from bayesflow.utils.decorators import sanitize_input_shape -from bayesflow.utils.serialization import serializable +from bayesflow.utils.serialization import serializable, serialize @serializable("bayesflow.networks") @@ -44,8 +45,8 @@ def __init__(self, base_distribution: str = None, *, metrics: Sequence[keras.Met **kwargs Additional keyword arguments forwarded to the `keras.Layer` constructor. """ - self.custom_metrics = metrics super().__init__(**layer_kwargs(kwargs)) + self.custom_metrics = metrics self.base_distribution = find_distribution(base_distribution) @sanitize_input_shape @@ -86,3 +87,10 @@ def compute_metrics(self, x: Tensor, stage: str = "training", **kwargs) -> dict[ metrics[metric.name] = metric(outputs, samples) return metrics + + @python_utils.default + def get_config(self): + base_config = super().get_config() + + config = {"metrics": self.custom_metrics, "base_distribution": self.base_distribution} + return base_config | serialize(config) diff --git a/tests/test_approximators/conftest.py b/tests/test_approximators/conftest.py index 4c67e9321..31f17404d 100644 --- a/tests/test_approximators/conftest.py +++ b/tests/test_approximators/conftest.py @@ -40,6 +40,7 @@ def continuous_approximator(adapter, inference_network, summary_network): @pytest.fixture() def point_inference_network(): + from bayesflow.metrics import RootMeanSquaredError from bayesflow.networks import PointInferenceNetwork from bayesflow.scores import NormedDifferenceScore, QuantileScore, MultivariateNormalScore @@ -51,11 +52,13 @@ def point_inference_network(): ), subnet="mlp", subnet_kwargs=dict(widths=(32, 32)), + metrics=[RootMeanSquaredError(name="rmse")], ) @pytest.fixture() def point_inference_network_with_multiple_parametric_scores(): + from bayesflow.metrics import RootMeanSquaredError from bayesflow.networks import PointInferenceNetwork from bayesflow.scores import MultivariateNormalScore @@ -64,6 +67,7 @@ def point_inference_network_with_multiple_parametric_scores(): mvn1=MultivariateNormalScore(), mvn2=MultivariateNormalScore(), ), + metrics=[RootMeanSquaredError(name="rmse")], ) @@ -181,9 +185,10 @@ def validation_dataset(batch_size, adapter, simulator): @pytest.fixture() def mean_std_summary_network(): + from bayesflow.metrics import MaximumMeanDiscrepancy from tests.utils import MeanStdSummaryNetwork - return MeanStdSummaryNetwork() + return MeanStdSummaryNetwork(metrics=[MaximumMeanDiscrepancy("mmd")]) @pytest.fixture(params=["continuous_approximator", "point_approximator", "model_comparison_approximator"]) From 15e11c1d2ee48656bbab705992c74a4a51c20a9b Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Tue, 1 Jul 2025 13:17:44 +0000 Subject: [PATCH 12/13] add missing normalize config call to test --- .../test_point_inference_network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_networks/test_point_inference_network/test_point_inference_network.py b/tests/test_networks/test_point_inference_network/test_point_inference_network.py index 38ba8ea4e..7e4ca920b 100644 --- a/tests/test_networks/test_point_inference_network/test_point_inference_network.py +++ b/tests/test_networks/test_point_inference_network/test_point_inference_network.py @@ -3,7 +3,7 @@ deserialize_keras_object as deserialize, serialize_keras_object as serialize, ) -from tests.utils import assert_layers_equal +from tests.utils import assert_layers_equal, normalize_config import pytest @@ -72,7 +72,7 @@ def test_save_and_load_quantile(tmp_path, quantile_point_inference_network, rand loaded = keras.saving.load_model(tmp_path / "model.keras") print(net.get_config()) - assert net.get_config() == loaded.get_config() + assert normalize_config(net.get_config()) == normalize_config(loaded.get_config()) assert_layers_equal(net, loaded) From f536064703564a38d025d327bebee53c9aea6f52 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Tue, 1 Jul 2025 13:58:59 +0000 Subject: [PATCH 13/13] add assert_configs_equal test utility --- tests/test_adapters/test_adapters.py | 8 ++++---- .../test_fusion_network/test_fusion_network.py | 4 ++-- tests/test_networks/test_inference_networks.py | 4 ++-- .../test_point_inference_network.py | 4 ++-- tests/test_networks/test_summary_networks.py | 4 ++-- tests/utils/assertions.py | 5 +++++ 6 files changed, 17 insertions(+), 12 deletions(-) diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index c951aa0e8..ae271038e 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -4,7 +4,7 @@ import keras from bayesflow.utils.serialization import deserialize, serialize -from tests.utils import normalize_config +from tests.utils import assert_configs_equal import bayesflow as bf @@ -30,7 +30,7 @@ def test_serialize_deserialize(adapter, random_data): deserialized = deserialize(serialized) reserialized = serialize(deserialized) - assert normalize_config(serialized) == normalize_config(reserialized) + assert_configs_equal(serialized, reserialized) random_data["foo"] = random_data["x1"] deserialized_processed = deserialized(random_data) @@ -335,7 +335,7 @@ def test_nnpe(random_data): deserialized = deserialize(serialized) reserialized = serialize(deserialized) - assert normalize_config(serialized) == normalize_config(reserialized) + assert_configs_equal(serialized, reserialized) # check that only x1 is changed assert "x1" in result_training @@ -365,7 +365,7 @@ def test_nnpe(random_data): serialized_auto = serialize(ad_auto) deserialized_auto = deserialize(serialized_auto) reserialized_auto = serialize(deserialized_auto) - assert normalize_config(serialized_auto) == normalize_config(serialize(reserialized_auto)) + assert_configs_equal(serialized_auto, serialize(reserialized_auto)) # Test dimensionwise versus global noise application (per_dimension=True vs per_dimension=False) # Create data with second dimension having higher variance diff --git a/tests/test_networks/test_fusion_network/test_fusion_network.py b/tests/test_networks/test_fusion_network/test_fusion_network.py index 827c5b20b..b73e7c426 100644 --- a/tests/test_networks/test_fusion_network/test_fusion_network.py +++ b/tests/test_networks/test_fusion_network/test_fusion_network.py @@ -2,7 +2,7 @@ import pytest import keras -from tests.utils import assert_layers_equal, allclose, normalize_config +from tests.utils import assert_layers_equal, assert_configs_equal, allclose @pytest.mark.parametrize("automatic", [True, False]) @@ -57,7 +57,7 @@ def test_serialize_deserialize(fusion_network, multimodal_data): deserialized = deserialize(serialized) reserialized = serialize(deserialized) - assert normalize_config(serialized) == normalize_config(reserialized) + assert_configs_equal(serialized, reserialized) def test_save_and_load(tmp_path, fusion_network, multimodal_data): diff --git a/tests/test_networks/test_inference_networks.py b/tests/test_networks/test_inference_networks.py index bc42942fb..86b934b08 100644 --- a/tests/test_networks/test_inference_networks.py +++ b/tests/test_networks/test_inference_networks.py @@ -4,7 +4,7 @@ from bayesflow.utils.serialization import serialize, deserialize -from tests.utils import assert_allclose, assert_layers_equal, normalize_config +from tests.utils import assert_allclose, assert_layers_equal, assert_configs_equal def test_build(inference_network, random_samples, random_conditions): @@ -140,7 +140,7 @@ def test_serialize_deserialize(inference_network, random_samples, random_conditi deserialized = deserialize(serialized) reserialized = serialize(deserialized) - assert normalize_config(serialized) == normalize_config(reserialized) + assert_configs_equal(serialized, reserialized) def test_save_and_load(tmp_path, inference_network, random_samples, random_conditions): diff --git a/tests/test_networks/test_point_inference_network/test_point_inference_network.py b/tests/test_networks/test_point_inference_network/test_point_inference_network.py index 7e4ca920b..ac924bc1d 100644 --- a/tests/test_networks/test_point_inference_network/test_point_inference_network.py +++ b/tests/test_networks/test_point_inference_network/test_point_inference_network.py @@ -3,7 +3,7 @@ deserialize_keras_object as deserialize, serialize_keras_object as serialize, ) -from tests.utils import assert_layers_equal, normalize_config +from tests.utils import assert_layers_equal, assert_configs_equal import pytest @@ -72,7 +72,7 @@ def test_save_and_load_quantile(tmp_path, quantile_point_inference_network, rand loaded = keras.saving.load_model(tmp_path / "model.keras") print(net.get_config()) - assert normalize_config(net.get_config()) == normalize_config(loaded.get_config()) + assert_configs_equal(net.get_config(), loaded.get_config()) assert_layers_equal(net, loaded) diff --git a/tests/test_networks/test_summary_networks.py b/tests/test_networks/test_summary_networks.py index f40012fc4..6b6518452 100644 --- a/tests/test_networks/test_summary_networks.py +++ b/tests/test_networks/test_summary_networks.py @@ -4,7 +4,7 @@ from bayesflow.utils.serialization import deserialize, serialize -from tests.utils import assert_layers_equal, normalize_config +from tests.utils import assert_layers_equal, assert_configs_equal @pytest.mark.parametrize("automatic", [True, False]) @@ -85,7 +85,7 @@ def test_serialize_deserialize(summary_network, random_set): deserialized = deserialize(serialized) reserialized = serialize(deserialized) - assert normalize_config(serialized) == normalize_config(reserialized) + assert_configs_equal(serialized, reserialized) def test_save_and_load(tmp_path, summary_network, random_set): diff --git a/tests/utils/assertions.py b/tests/utils/assertions.py index 727b0f867..240502984 100644 --- a/tests/utils/assertions.py +++ b/tests/utils/assertions.py @@ -1,4 +1,5 @@ import keras +from .normalize import normalize_config def assert_models_equal(model1: keras.Model, model2: keras.Model): @@ -50,3 +51,7 @@ def assert_layers_equal(layer1: keras.Layer, layer2: keras.Layer): for metric1, metric2 in zip(layer1.metrics, layer2.metrics): assert type(metric1) is type(metric2) assert metric1.name == metric2.name + + +def assert_configs_equal(config1: dict, config2: dict): + assert normalize_config(config1) == normalize_config(config2)