Draft

Changes from all commits (22 commits)
211b20f  Draft of ApproximatorEnsemble (han-ol, Jul 2, 2025)
d8fa571  Slice training batch, giving every ensemble member independent traini… (han-ol, Jul 4, 2025)
a146e6a  OfflineEnsembleDataset: independent slices of indices for each ensemb… (han-ol, Jul 4, 2025)
e12eb08  Example notebook training an ApproximatorEnsemble with an OfflineEnse… (han-ol, Jul 4, 2025)
70bb230  Make ApproximatorEnsemble directly importable (han-ol, Jul 8, 2025)
711cd7b  Add log_prob, estimate and predict wrapper methods; flexible building (han-ol, Jul 8, 2025)
955ac79  Make use of approximator.build_from_data() in fit (han-ol, Jul 8, 2025)
d8d84c8  Fix build_from_data in ModelComparisonApproximator (elseml, Jul 9, 2025)
7daaab2  Make use of approximator.build_from_data() in fit (han-ol, Jul 8, 2025)
71d8382  Fix build_from_data in ModelComparisonApproximator (elseml, Jul 9, 2025)
4d3130b  Add predict wrapper method (han-ol, Jul 17, 2025)
e60a25e  Tests for ApproximatorEnsemble (han-ol, Jul 17, 2025)
df7b5d2  Make OfflineEnsembleDataset directly importable (han-ol, Jul 17, 2025)
0693ec5  Merge remote-tracking branch 'upstream/ensembles' into ensembles (han-ol, Jul 17, 2025)
1ca2a76  Fit test for ensembles (han-ol, Jul 17, 2025)
59cf297  Tests for OfflineEnsembleDataset (han-ol, Jul 17, 2025)
3590441  Merge branch 'dev' into ensembles (han-ol, Jul 22, 2025)
a079bd7  Merge remote-tracking branch 'upstream/dev' into ensembles (vpratz, Aug 5, 2025)
4f20680  make ApproximatorEnsemble serializable [no ci] (vpratz, Aug 5, 2025)
2929bc4  Merge remote-tracking branch 'upstream/dev' into ensembles [no ci] (vpratz, Aug 13, 2025)
d6120c3  support for nested summary variables [no ci] (vpratz, Aug 16, 2025)
310c9bf  Merge remote-tracking branch 'upstream/dev' into ensembles [no ci] (vpratz, Sep 17, 2025)
4 changes: 2 additions & 2 deletions bayesflow/__init__.py
@@ -115,7 +115,7 @@ def __init__(self, display_name, package_name, env_name, install_url, priority):
 )

 from .adapters import Adapter
-from .approximators import ContinuousApproximator, PointApproximator
-from .datasets import OfflineDataset, OnlineDataset, DiskDataset
+from .approximators import ContinuousApproximator, PointApproximator, ApproximatorEnsemble
+from .datasets import OfflineDataset, OnlineDataset, DiskDataset, OfflineEnsembleDataset
 from .simulators import make_simulator
 from .workflows import BasicWorkflow
2 changes: 2 additions & 0 deletions bayesflow/approximators/__init__.py
@@ -8,6 +8,8 @@
 from .point_approximator import PointApproximator
 from .model_comparison_approximator import ModelComparisonApproximator

+from .approximator_ensemble import ApproximatorEnsemble
+
 from ..utils._docs import _add_imports_to_all

 _add_imports_to_all(include_modules=[])
3 changes: 1 addition & 2 deletions bayesflow/approximators/approximator.py
@@ -133,7 +133,6 @@ def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = N
 logging.info("Building on a test batch.")
 mock_data = dataset[0]
 mock_data = keras.tree.map_structure(keras.ops.convert_to_tensor, mock_data)
-mock_data_shapes = keras.tree.map_structure(keras.ops.shape, mock_data)
-self.build(mock_data_shapes)
+self.build_from_data(mock_data)

 return super().fit(dataset=dataset, **kwargs)
156 changes: 156 additions & 0 deletions bayesflow/approximators/approximator_ensemble.py
@@ -0,0 +1,156 @@
from collections.abc import Mapping

import numpy as np

import keras

from bayesflow.types import Tensor
from bayesflow.utils.serialization import deserialize, serializable, serialize


from .approximator import Approximator
from .model_comparison_approximator import ModelComparisonApproximator


@serializable("bayesflow.approximators")
class ApproximatorEnsemble(Approximator):
def __init__(self, approximators: dict[str, Approximator], **kwargs):
super().__init__(**kwargs)

self.approximators = approximators

self.num_approximators = len(self.approximators)

def build_from_data(self, adapted_data: dict[str, any]):
data_shapes = keras.tree.map_structure(keras.ops.shape, adapted_data)
if len(data_shapes["inference_variables"]) > 2:
# Remove the ensemble dimension from data_shapes. This assumes data_shapes describes a batch
# of training data whose second axis indexes the individual ensemble members.
data_shapes = keras.tree.map_shape_structure(lambda shape: shape[:1] + shape[2:], data_shapes)
self.build(data_shapes)

def build(self, input_shape: dict[str, tuple[int] | dict[str, dict]]) -> None:
for approximator in self.approximators.values():
approximator.build(input_shape)

def compute_metrics(
self,
inference_variables: Tensor,
inference_conditions: Tensor = None,
summary_variables: Tensor = None,
sample_weight: Tensor = None,
stage: str = "training",
) -> dict[str, dict[str, Tensor]]:
# Prepare empty dict for metrics
metrics = {}

# Start each per-member input from the full (unsliced) inputs; optional inputs may be None
_inference_variables = inference_variables
_inference_conditions = inference_conditions
_summary_variables = summary_variables
_sample_weight = sample_weight

for i, (approx_name, approximator) in enumerate(self.approximators.items()):
# During training each approximator receives its own separate slice
if stage == "training" and inference_variables.ndim > 2:
# Pick out the correct slice for each ensemble member
_inference_variables = inference_variables[:, i]
if inference_conditions is not None:
_inference_conditions = inference_conditions[:, i]
if summary_variables is not None:
_summary_variables = keras.tree.map_structure(lambda v: v[:, i], summary_variables)
if sample_weight is not None:
_sample_weight = sample_weight[:, i]

metrics[approx_name] = approximator.compute_metrics(
inference_variables=_inference_variables,
inference_conditions=_inference_conditions,
summary_variables=_summary_variables,
sample_weight=_sample_weight,
stage=stage,
)

# Flatten metrics dict
joint_metrics = {}
for approx_name in metrics.keys():
for metric_key, value in metrics[approx_name].items():
joint_metrics[f"{approx_name}/{metric_key}"] = value
metrics = joint_metrics

# Sum over losses
losses = [v for k, v in metrics.items() if "loss" in k]
metrics["loss"] = keras.ops.sum(losses)

return metrics

def sample(
self,
*,
num_samples: int,
conditions: Mapping[str, np.ndarray],
split: bool = False,
**kwargs,
) -> dict[str, dict[str, np.ndarray]]:
samples = {}
for approx_name, approximator in self.approximators.items():
if self._has_obj_method(approximator, "sample"):
samples[approx_name] = approximator.sample(
num_samples=num_samples, conditions=conditions, split=split, **kwargs
)
return samples

def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> dict[str, np.ndarray]:
log_prob = {}
for approx_name, approximator in self.approximators.items():
if self._has_obj_method(approximator, "log_prob"):
log_prob[approx_name] = approximator.log_prob(data=data, **kwargs)
return log_prob

def estimate(
self,
conditions: Mapping[str, np.ndarray],
split: bool = False,
**kwargs,
) -> dict[str, dict[str, dict[str, np.ndarray]]]:
estimates = {}
for approx_name, approximator in self.approximators.items():
if self._has_obj_method(approximator, "estimate"):
estimates[approx_name] = approximator.estimate(conditions=conditions, split=split, **kwargs)
return estimates

def predict(
self,
*,
conditions: Mapping[str, np.ndarray],
probs: bool = True,
**kwargs,
) -> dict[str, np.ndarray]:
predictions = {}
for approx_name, approximator in self.approximators.items():
if isinstance(approximator, ModelComparisonApproximator):
predictions[approx_name] = approximator.predict(conditions=conditions, probs=probs, **kwargs)
return predictions

def _has_obj_method(self, obj, name):
method = getattr(obj, name, None)
return callable(method)

def _batch_size_from_data(self, data: Mapping[str, any]) -> int:
"""
Fetches the current batch size from an input dictionary. Can only be used during training when
inference variables are present.
"""
return keras.ops.shape(data["inference_variables"])[0]

def get_config(self):
base_config = super().get_config()
config = {"approximators": self.approximators}
return base_config | serialize(config)

@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**deserialize(config, custom_objects=custom_objects))

def build_from_config(self, config):
# the approximators are already built
pass
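
For orientation, a minimal usage sketch of the new class (not part of the diff): it wires two ContinuousApproximator members into an ensemble and draws posterior samples per member. The adapter, conditions, and network choices below are illustrative assumptions, not prescribed by the PR.

import bayesflow as bf

# Sketch only: assumes a suitable `adapter` and a `conditions` dict exist.
ensemble = bf.ApproximatorEnsemble(
    approximators={
        "flow_a": bf.ContinuousApproximator(adapter=adapter, inference_network=bf.networks.CouplingFlow()),
        "flow_b": bf.ContinuousApproximator(adapter=adapter, inference_network=bf.networks.CouplingFlow()),
    }
)

# Wrapper methods return one entry per ensemble member,
# e.g. {"flow_a": {...}, "flow_b": {...}}
samples = ensemble.sample(num_samples=500, conditions=conditions)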
2 changes: 1 addition & 1 deletion bayesflow/approximators/model_comparison_approximator.py
@@ -102,7 +102,7 @@ def build(self, data_shapes: dict[str, tuple[int] | dict[str, dict]]) -> None:
 self.built = True

 def build_from_data(self, adapted_data: dict[str, any]):
-    self.build(keras.tree.map_structure(keras.ops.shape(adapted_data)))
+    self.build(keras.tree.map_structure(keras.ops.shape, adapted_data))

 @classmethod
 def build_adapter(
1 change: 1 addition & 0 deletions bayesflow/datasets/__init__.py
@@ -5,6 +5,7 @@
 """

 from .offline_dataset import OfflineDataset
+from .offline_ensemble_dataset import OfflineEnsembleDataset
 from .online_dataset import OnlineDataset
 from .disk_dataset import DiskDataset

30 changes: 30 additions & 0 deletions bayesflow/datasets/offline_ensemble_dataset.py
@@ -0,0 +1,30 @@
import numpy as np

from .offline_dataset import OfflineDataset


class OfflineEnsembleDataset(OfflineDataset):
"""
A dataset that is pre-simulated and stored in memory, extending :py:class:`OfflineDataset`.

The only difference is that it allows training an :py:class:`ApproximatorEnsemble` in parallel by returning
batches with ``num_ensemble`` different random subsets of the available data.
"""

def __init__(self, num_ensemble: int, **kwargs):
super().__init__(**kwargs)
self.num_ensemble = num_ensemble

# Create indices with shape (num_samples, num_ensemble)
_indices = np.arange(self.num_samples, dtype="int64")
_indices = np.repeat(_indices[:, None], self.num_ensemble, axis=1)

# Shuffle independently along second axis
for i in range(self.num_ensemble):
np.random.shuffle(_indices[:, i])

self.indices = _indices

# Shuffle first axis
if self._shuffle:
self.shuffle()
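
A companion training sketch (again not part of the diff), continuing the ensemble example above; simulator and adapter are assumed to be defined elsewhere, and the remaining keyword arguments are assumed to be inherited from OfflineDataset:

# Sketch only: `simulator`, `adapter`, and `ensemble` as in the previous example.
training_data = simulator.sample(2000)

dataset = bf.OfflineEnsembleDataset(
    data=training_data,
    batch_size=64,
    adapter=adapter,
    num_ensemble=len(ensemble.approximators),
)

# Each batch carries an extra ensemble axis, so every member is fit on its
# own independently shuffled subset of the pre-simulated data.
ensemble.fit(dataset=dataset, epochs=30)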