Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion bayesflow/approximators/continuous_approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,17 @@ def compute_metrics(
else:
loss = inference_metrics.pop("loss")

if len(self.losses) > 0:
layer_loss = keras.ops.sum(self.losses)
loss += layer_loss
layer_loss_metrics = {"layer_loss": layer_loss}
else:
layer_loss_metrics = {}

inference_metrics = {f"{key}/inference_{key}": value for key, value in inference_metrics.items()}
summary_metrics = {f"{key}/summary_{key}": value for key, value in summary_metrics.items()}

metrics = {"loss": loss} | inference_metrics | summary_metrics
metrics = {"loss": loss} | layer_loss_metrics | inference_metrics | summary_metrics
return metrics

def _compute_summary_metrics(self, summary_variables: Tensor | None, stage: str) -> tuple[dict, Tensor | None]:
Expand Down
9 changes: 8 additions & 1 deletion bayesflow/approximators/model_comparison_approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,17 @@ def compute_metrics(
else:
loss = classifier_metrics.pop("loss")

if len(self.losses) > 0:
layer_loss = keras.ops.sum(self.losses)
loss += layer_loss
layer_loss_metrics = {"layer_loss": layer_loss}
else:
layer_loss_metrics = {}

classifier_metrics = {f"{key}/classifier_{key}": value for key, value in classifier_metrics.items()}
summary_metrics = {f"{key}/summary_{key}": value for key, value in summary_metrics.items()}

metrics = {"loss": loss} | classifier_metrics | summary_metrics
metrics = {"loss": loss} | layer_loss_metrics | classifier_metrics | summary_metrics
return metrics

def fit(
Expand Down
151 changes: 151 additions & 0 deletions examples/Custom_losses_with_add_loss.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "c0545c7e-d9b0-4e1d-98b9-199afe1bcc31",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import os\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a776c519-c0ac-4a14-8841-e3e64d2b1716",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2025-07-22 14:07:15.076444: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n",
"2025-07-22 14:07:15.079630: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n",
"2025-07-22 14:07:15.087697: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
"E0000 00:00:1753186035.101632 583449 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"E0000 00:00:1753186035.105502 583449 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2025-07-22 14:07:15.121221: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2025-07-22 14:07:17.542250: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)\n",
"INFO:bayesflow:Using backend 'tensorflow'\n",
"INFO:bayesflow:Fitting on dataset instance of OnlineDataset.\n",
"INFO:bayesflow:Building on a test batch.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"\u001b[1m200/200\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m30s\u001b[0m 64ms/step - layer_loss: 4.1039e-04 - loss: 2.8483\n",
"Epoch 2/5\n",
"\u001b[1m200/200\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 66ms/step - layer_loss: 0.0456 - loss: 2.5975 \n",
"Epoch 3/5\n",
"\u001b[1m200/200\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 66ms/step - layer_loss: 0.2125 - loss: 0.8778\n",
"Epoch 4/5\n",
"\u001b[1m200/200\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 66ms/step - layer_loss: 0.2350 - loss: 0.3789\n",
"Epoch 5/5\n",
"\u001b[1m200/200\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 66ms/step - layer_loss: 0.2238 - loss: 0.1804\n"
]
}
],
"source": [
"import bayesflow as bf\n",
"import keras\n",
"\n",
"class CustomTimeSeriesNetwork(bf.networks.TimeSeriesNetwork):\n",
" def call(self, x, training=False, **kwargs):\n",
" x = super().call(x, training=training, **kwargs)\n",
" self.add_loss(keras.ops.sum(x**2))\n",
" return x\n",
"\n",
"workflow = bf.BasicWorkflow(\n",
" inference_network=bf.networks.CouplingFlow(),\n",
" summary_network=CustomTimeSeriesNetwork(),\n",
" inference_variables=[\"parameters\"],\n",
" summary_variables=[\"observables\"],\n",
" simulator=bf.simulators.SIR()\n",
")\n",
"\n",
"history = workflow.fit_online(epochs=5, batch_size=32, num_batches_per_epoch=200)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "66900aa0-99e8-41ee-a08d-5e10b946deb9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<tf.Tensor 'custom_time_series_network_1/Sum:0' shape=() dtype=float32>]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"workflow.approximator.summary_network.losses"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "45e835af-8d1e-4d63-a349-f98e68a02667",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<tf.Tensor 'custom_time_series_network_1/Sum:0' shape=() dtype=float32>]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"workflow.approximator.losses"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}
48 changes: 48 additions & 0 deletions tests/test_approximators/test_add_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pytest
import keras
import io
from contextlib import redirect_stdout


@pytest.fixture()
def approximator_using_add_loss(adapter):
from bayesflow import ContinuousApproximator
from bayesflow.networks import CouplingFlow, MLP

class MLPAddedLoss(MLP):
def call(self, x, training=False, **kwargs):
x = super().call(x, training=training, **kwargs)
self.add_loss(keras.ops.sum(x**2))
return x

return ContinuousApproximator(
adapter=adapter,
inference_network=CouplingFlow(subnet=MLPAddedLoss),
summary_network=None,
)


def test_layer_loss_reported(approximator_using_add_loss, train_dataset, validation_dataset):
import os

if os.environ["KERAS_BACKEND"] == "jax":
pytest.skip(reason="With JAX backend, the compute_metrics method currently fails to consider self.losses.")

approximator = approximator_using_add_loss
approximator.compile(optimizer="AdamW")
num_epochs = 3

# Capture ostream and train model
with io.StringIO() as stream:
with redirect_stdout(stream):
approximator.fit(dataset=train_dataset, validation_data=validation_dataset, epochs=num_epochs)

output = stream.getvalue()

print(output)

# check that there is a progress bar
assert "━" in output, "no progress bar"

# check that layer_loss is reported
assert "layer_loss" in output, "no layer_loss"