33 changes: 13 additions & 20 deletions bayesflow/approximators/continuous_approximator.py
@@ -1,6 +1,7 @@
 from collections.abc import Mapping, Sequence, Callable
 
 import numpy as np
+import warnings
 
 import keras

@@ -9,7 +10,6 @@
 from bayesflow.types import Tensor
 from bayesflow.utils import (
     filter_kwargs,
-    logging,
     split_arrays,
     squeeze_inner_estimates_dict,
     concatenate_valid,
@@ -148,18 +148,21 @@ def build_adapter(
     def compile(
         self,
         *args,
-        inference_metrics: Sequence[keras.Metric] = None,
-        summary_metrics: Sequence[keras.Metric] = None,
         **kwargs,
     ):
-        if inference_metrics:
-            self.inference_network._metrics = inference_metrics
+        if "inference_metrics" in kwargs:
+            warnings.warn(
+                "Supplying inference metrics to the approximator is no longer supported. "
+                "Please pass the metrics directly to the network using the metrics parameter.",
+                DeprecationWarning,
+            )
 
-        if summary_metrics:
-            if self.summary_network is None:
-                logging.warning("Ignoring summary metrics because there is no summary network.")
-            else:
-                self.summary_network._metrics = summary_metrics
+        if "summary_metrics" in kwargs:
+            warnings.warn(
+                "Supplying summary metrics to the approximator is no longer supported. "
+                "Please pass the metrics directly to the network using the metrics parameter.",
+                DeprecationWarning,
+            )
 
         return super().compile(*args, **kwargs)

@@ -329,16 +332,6 @@ def get_config(self):

         return base_config | serialize(config)
 
-    def get_compile_config(self):
-        base_config = super().get_compile_config() or {}
-
-        config = {
-            "inference_metrics": self.inference_network._metrics,
-            "summary_metrics": self.summary_network._metrics if self.summary_network is not None else None,
-        }
-
-        return base_config | serialize(config)
-
     def estimate(
         self,
         conditions: Mapping[str, np.ndarray],
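Taken together, the changes to this file deprecate the `inference_metrics` and `summary_metrics` arguments of `compile()`: metrics now belong to the networks themselves. A minimal migration sketch, assuming a typical BayesFlow 2.x workflow; the variable names and the bare `Adapter()` are illustrative, not part of this diff:

import keras
import bayesflow as bf

# old style, now emits a DeprecationWarning:
# approximator.compile(optimizer="adam", inference_metrics=[keras.metrics.MeanSquaredError()])

# new style: attach metrics to the network at construction time
inference_network = bf.networks.CouplingFlow(metrics=[keras.metrics.MeanSquaredError()])

approximator = bf.approximators.ContinuousApproximator(
    adapter=bf.adapters.Adapter(),  # illustrative placeholder adapter
    inference_network=inference_network,
)
approximator.compile(optimizer="adam")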
37 changes: 16 additions & 21 deletions bayesflow/approximators/model_comparison_approximator.py
@@ -2,6 +2,7 @@
 
 import keras
 import numpy as np
+import warnings
 
 from bayesflow.adapters import Adapter
 from bayesflow.datasets import OnlineDataset
@@ -151,18 +152,21 @@ def build_dataset(
     def compile(
         self,
         *args,
-        classifier_metrics: Sequence[keras.Metric] = None,
-        summary_metrics: Sequence[keras.Metric] = None,
         **kwargs,
     ):
-        if classifier_metrics:
-            self.classifier_network._metrics = classifier_metrics
+        if "classifier_metrics" in kwargs:
+            warnings.warn(
+                "Supplying classifier metrics to the approximator is no longer supported. "
+                "Please pass the metrics directly to the network using the metrics parameter.",
+                DeprecationWarning,
+            )
 
-        if summary_metrics:
-            if self.summary_network is None:
-                logging.warning("Ignoring summary metrics because there is no summary network.")
-            else:
-                self.summary_network._metrics = summary_metrics
+        if "summary_metrics" in kwargs:
+            warnings.warn(
+                "Supplying summary metrics to the approximator is no longer supported. "
+                "Please pass the metrics directly to the network using the metrics parameter.",
+                DeprecationWarning,
+            )
 
         return super().compile(*args, **kwargs)

@@ -223,9 +227,10 @@ def compute_metrics(
         classifier_metrics = {"loss": cross_entropy}
 
         if stage != "training" and any(self.classifier_network.metrics):
-            predictions = keras.ops.argmax(logits, axis=-1)
+            # compute sample-based metrics
+            probabilities = keras.ops.softmax(logits)
             classifier_metrics |= {
-                metric.name: metric(model_indices, predictions) for metric in self.classifier_network.metrics
+                metric.name: metric(model_indices, probabilities) for metric in self.classifier_network.metrics
             }
 
         if "loss" in summary_metrics:
@@ -342,16 +347,6 @@ def get_config(self):

         return base_config | serialize(config)
 
-    def get_compile_config(self):
-        base_config = super().get_compile_config() or {}
-
-        config = {
-            "classifier_metrics": self.classifier_network._metrics,
-            "summary_metrics": self.summary_network._metrics if self.summary_network is not None else None,
-        }
-
-        return base_config | serialize(config)
-
     def predict(
         self,
         *,
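Note the switch in `compute_metrics` from hard argmax predictions to softmax probabilities: standard Keras classification metrics expect per-class scores and reduce them internally. A self-contained illustration of the new call pattern (the integer-encoded labels are an assumption of this sketch, not taken from the diff):

import numpy as np
import keras

logits = np.array([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]])
model_indices = np.array([0, 1])  # true model labels, integer-encoded for this sketch

probabilities = keras.ops.softmax(logits)  # per-class scores, shape (batch, num_models)

# the metric reduces per-class scores internally; pre-argmaxed labels of
# shape (batch,) would not match the y_pred format it expects
accuracy = keras.metrics.SparseCategoricalAccuracy()
print(float(accuracy(model_indices, probabilities)))  # 1.0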
2 changes: 0 additions & 2 deletions bayesflow/experimental/free_form_flow/free_form_flow.py
@@ -8,7 +8,6 @@
     find_network,
     jacobian,
     jvp,
-    model_kwargs,
     vjp,
     weighted_mean,
 )
@@ -240,7 +239,6 @@ def from_config(cls, config, custom_objects=None):
 
     def get_config(self):
         base_config = super().get_config()
-        base_config = model_kwargs(base_config)
 
         config = {
             "beta": self.beta,
6 changes: 4 additions & 2 deletions bayesflow/networks/consistency_models/consistency_model.py
@@ -4,7 +4,7 @@
 import numpy as np
 
 from bayesflow.types import Tensor
-from bayesflow.utils import find_network, layer_kwargs, weighted_mean, tensor_utils, expand_right_as
+from bayesflow.utils import find_network, weighted_mean, tensor_utils, expand_right_as
 from bayesflow.utils.serialization import deserialize, serializable, serialize
 
 from ..inference_network import InferenceNetwork
@@ -115,7 +115,9 @@ def from_config(cls, config, custom_objects=None):
 
     def get_config(self):
         base_config = super().get_config()
-        base_config = layer_kwargs(base_config)
+
+        # base distribution is passed manually to InferenceNetwork parent class, do not store it here
+        base_config.pop("base_distribution")
 
         config = {
             "total_steps": self.total_steps,
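The added `base_config.pop("base_distribution")` pairs with the new `InferenceNetwork.get_config` later in this diff: the parent now serializes `base_distribution`, but `ConsistencyModel` supplies it to the parent itself, so the key must be dropped or `from_config` would pass the argument twice. A toy sketch of the pattern, with hypothetical `Parent`/`Child` classes:

# hypothetical classes, only to illustrate popping a parent-owned config key
class Parent:
    def __init__(self, base_distribution="normal"):
        self.base_distribution = base_distribution

    def get_config(self):
        return {"base_distribution": self.base_distribution}

class Child(Parent):
    def __init__(self, total_steps=10, **kwargs):
        # the child chooses the base distribution itself
        super().__init__(base_distribution="student", **kwargs)
        self.total_steps = total_steps

    def get_config(self):
        base_config = super().get_config()
        # drop the parent-owned key so Child(**config) does not receive
        # base_distribution twice on restore
        base_config.pop("base_distribution")
        return base_config | {"total_steps": self.total_steps}

restored = Child(**Child(total_steps=20).get_config())  # round-trips cleanly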
4 changes: 1 addition & 3 deletions bayesflow/networks/coupling_flow/coupling_flow.py
@@ -3,7 +3,6 @@
 from bayesflow.types import Tensor
 from bayesflow.utils import (
     find_permutation,
-    layer_kwargs,
     weighted_mean,
 )
 from bayesflow.utils.serialization import deserialize, serializable, serialize
@@ -91,7 +90,7 @@ def __init__(
             Keyword arguments forwarded to the affine or spline transforms
             (e.g., bins for splines)
         **kwargs
-            Additional keyword arguments passed to `InvertibleLayer`.
+            Additional keyword arguments passed to `InferenceNetwork`.
 
         """
         super().__init__(base_distribution=base_distribution, **kwargs)
@@ -131,7 +130,6 @@ def from_config(cls, config, custom_objects=None):
 
     def get_config(self):
         base_config = super().get_config()
-        base_config = layer_kwargs(base_config)
 
         config = {
             "subnet": self.subnet,
5 changes: 3 additions & 2 deletions bayesflow/networks/diffusion_model/diffusion_model.py
@@ -10,7 +10,6 @@
     expand_right_as,
     find_network,
     jacobian_trace,
-    layer_kwargs,
     weighted_mean,
     integrate,
     integrate_stochastic,
@@ -156,7 +155,9 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
 
     def get_config(self):
         base_config = super().get_config()
-        base_config = layer_kwargs(base_config)
+
+        # base distribution is passed manually to InferenceNetwork parent class, do not store it here
+        base_config.pop("base_distribution")
 
         config = {
             "subnet": self.subnet,
2 changes: 0 additions & 2 deletions bayesflow/networks/flow_matching/flow_matching.py
@@ -9,7 +9,6 @@
     find_network,
     integrate,
     jacobian_trace,
-    layer_kwargs,
     optimal_transport,
     weighted_mean,
     tensor_utils,

def get_config(self):
base_config = super().get_config()
base_config = layer_kwargs(base_config)

config = {
"subnet": self.subnet,
Expand Down
44 changes: 42 additions & 2 deletions bayesflow/networks/inference_network.py
@@ -1,13 +1,46 @@
+from typing import Literal
+from collections.abc import Sequence
+
 import keras
+from keras.src.utils import python_utils
 
 from bayesflow.types import Shape, Tensor
 from bayesflow.utils import layer_kwargs, find_distribution
 from bayesflow.utils.decorators import allow_batch_size
 from bayesflow.utils.serialization import serializable, serialize
 
 
 @serializable("bayesflow.networks")
 class InferenceNetwork(keras.Layer):
-    def __init__(self, base_distribution: str = "normal", **kwargs):
+    """
+    Constructs an inference network using a specified base distribution and optional custom metrics.
+    Use this interface for custom inference networks.
+    """
+
+    def __init__(
+        self,
+        base_distribution: Literal["normal", "student", "mixture"] | keras.Layer = "normal",
+        *,
+        metrics: Sequence[keras.Metric] | None = None,
+        **kwargs,
+    ):
+        """
+        Creates the network with provided arguments. Optional user-supplied metrics will be stored
+        in a `custom_metrics` attribute. A special `metrics` attribute will be created internally by `keras.Layer`.
+
+        Parameters
+        ----------
+        base_distribution : Literal["normal", "student", "mixture"] or keras.Layer
+            Name or the actual base distribution to use. Passed to `find_distribution` to
+            obtain the corresponding distribution object.
+        metrics : Sequence[keras.Metric] or None, optional
+            Sequence of custom Keras Metric instances to compute during training
+            and evaluation. If `None`, no custom metrics are used.
+        **kwargs
+            Additional keyword arguments forwarded to the `keras.Layer` constructor.
+        """
         super().__init__(**layer_kwargs(kwargs))
+        self.custom_metrics = metrics
         self.base_distribution = find_distribution(base_distribution)
 
     def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
@@ -66,9 +99,16 @@ def compute_metrics(

         if stage != "training" and any(self.metrics):
             # compute sample-based metrics
-            samples = self.sample((keras.ops.shape(x)[0],), conditions=conditions)
+            samples = self.sample(batch_shape=(keras.ops.shape(x)[0],), conditions=conditions)
 
             for metric in self.metrics:
                 metrics[metric.name] = metric(samples, x)
 
         return metrics
+
+    @python_utils.default
+    def get_config(self):
+        base_config = super().get_config()
+
+        config = {"metrics": self.custom_metrics, "base_distribution": self.base_distribution}
+        return base_config | serialize(config)
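With `metrics` accepted at construction and serialized by `get_config`, custom metrics now travel with the network rather than with the approximator's `compile()` call. A usage sketch, assuming a concrete subclass such as `CouplingFlow` forwards its keyword arguments to `InferenceNetwork`:

import keras
from bayesflow.networks import CouplingFlow

network = CouplingFlow(
    base_distribution="normal",
    metrics=[keras.metrics.MeanSquaredError()],  # stored on `custom_metrics`
)

# the metrics are part of the layer config and survive a config round-trip
config = network.get_config()
restored = CouplingFlow.from_config(config)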
14 changes: 11 additions & 3 deletions bayesflow/networks/mlp/mlp.py
@@ -29,18 +29,22 @@ def __init__(
         dropout: Literal[0, None] | float = 0.05,
         norm: Literal["batch", "layer"] | keras.Layer = None,
         spectral_normalization: bool = False,
+        metrics: Sequence[keras.Metric] | None = None,
         **kwargs,
     ):
         """
-        Implements a flexible multi-layer perceptron (MLP) with optional residual connections, dropout, and
-        spectral normalization.
+        Creates a flexible multi-layer perceptron (MLP) with optional residual connections, dropout,
+        spectral normalization, and metrics.
 
         This MLP can be used as a general-purpose feature extractor or function approximator, supporting configurable
         depth, width, activation functions, and weight initializations.
 
         If `residual` is enabled, each layer includes a skip connection for improved gradient flow. The model also
         supports dropout for regularization and spectral normalization for stability in learning smooth functions.
 
+        Optional user-supplied metrics will be stored in a `custom_metrics` attribute. A special `metrics` attribute
+        will be created internally by `keras.Layer`.
+
         Parameters
         ----------
         widths : Sequence[int], optional
@@ -54,12 +58,15 @@
         dropout : float or None, optional
             Dropout rate applied within the MLP layers for regularization. Default is 0.05.
         norm: str, optional
-
             Type of learnable normalization to be used (e.g., "batch" or "layer"). Default is None.
         spectral_normalization : bool, optional
             Whether to apply spectral normalization to stabilize training. Default is False.
+        metrics: Sequence[keras.Metric], optional
+            A sequence of callable metrics following keras' `Metric` signature. Default is None.
         **kwargs
             Additional keyword arguments passed to the Keras layer initialization.
         """
+        self.custom_metrics = metrics
         self.widths = list(widths)
         self.activation = activation
         self.kernel_initializer = kernel_initializer
@@ -90,6 +97,7 @@ def get_config(self):
"dropout": self.dropout,
"norm": self.norm,
"spectral_normalization": self.spectral_normalization,
"metrics": self.custom_metrics,
}

return base_config | serialize(config)
Expand Down
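`MLP` gains the same `metrics` parameter, so metrics are attached when the network is constructed and serialized with it. A brief construction sketch with illustrative widths:

import keras
from bayesflow.networks import MLP

mlp = MLP(
    widths=(128, 128),
    dropout=0.05,
    metrics=[keras.metrics.MeanAbsoluteError()],  # stored on `custom_metrics`
)

config = mlp.get_config()  # now contains a "metrics" entry alongside the other settings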
11 changes: 8 additions & 3 deletions bayesflow/networks/point_inference_network.py
@@ -1,4 +1,6 @@
+from collections.abc import Sequence
 import keras
+from keras.src.utils import python_utils
 
 from bayesflow.utils import model_kwargs, find_network
 from bayesflow.utils.serialization import deserialize, serializable, serialize
@@ -17,9 +19,12 @@ def __init__(
         self,
         scores: dict[str, ScoringRule],
         subnet: str | keras.Layer = "mlp",
+        *,
+        metrics: Sequence[keras.Metric] | None = None,
         **kwargs,
     ):
         super().__init__(**model_kwargs(kwargs))
+        self.custom_metrics = metrics
 
         self.scores = scores
 
@@ -28,6 +33,7 @@ def __init__(
         self.config = {
             "subnet": serialize(subnet),
             "scores": serialize(scores),
+            "metrics": serialize(metrics),
             **kwargs,
         }
 
@@ -106,6 +112,7 @@ def build_from_config(self, config):
             for head_key, head in self.heads[score_key].items():
                 head.name = config["heads"][score_key][head_key]
 
+    @python_utils.default
     def get_config(self):
         base_config = super().get_config()
 
@@ -114,9 +121,7 @@ def get_config(self):
     @classmethod
     def from_config(cls, config):
         config = config.copy()
-        config["scores"] = deserialize(config["scores"])
-        config["subnet"] = deserialize(config["subnet"])
-        return cls(**config)
+        return cls(**deserialize(config))
 
     def call(
         self,
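The slimmer `from_config` leans on `deserialize` to walk the entire config mapping in one pass, so per-key deserialization lines are unnecessary and newly serialized entries such as "metrics" are picked up automatically. The idea in miniature, assuming `serialize`/`deserialize` map recursively over plain containers (a hypothetical round-trip, not taken from the diff):

from bayesflow.utils.serialization import deserialize, serialize

# before: config["scores"] and config["subnet"] were deserialized by hand,
# and any newly serialized entry needed its own line
config = serialize({"subnet": "mlp", "metrics": None})

# after: one pass restores every entry, ready for cls(**kwargs)
kwargs = deserialize(config)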