1
1
import pytest
2
2
3
3
from bayesflow .networks import MLP
4
+ from bayesflow .metrics import RootMeanSquaredError
4
5
5
6
6
7
@pytest .fixture ()
@@ -12,6 +13,7 @@ def diffusion_model_edm_F():
12
13
integrate_kwargs = {"method" : "rk45" , "steps" : 250 },
13
14
noise_schedule = "edm" ,
14
15
prediction_type = "F" ,
16
+ metrics = [RootMeanSquaredError ()],
15
17
)
16
18
17
19
@@ -82,22 +84,32 @@ def flow_matching():
82
84
return FlowMatching (
83
85
subnet = MLP ([8 , 8 ]),
84
86
integrate_kwargs = {"method" : "rk45" , "steps" : 100 },
87
+ metrics = [RootMeanSquaredError ()],
85
88
)
86
89
87
90
88
91
@pytest .fixture ()
89
92
def consistency_model ():
90
93
from bayesflow .networks import ConsistencyModel
91
94
92
- return ConsistencyModel (total_steps = 100 , subnet = MLP ([8 , 8 ]))
95
+ return ConsistencyModel (
96
+ total_steps = 100 ,
97
+ subnet = MLP ([8 , 8 ]),
98
+ metrics = [RootMeanSquaredError ()],
99
+ )
93
100
94
101
95
102
@pytest .fixture ()
96
103
def affine_coupling_flow ():
97
104
from bayesflow .networks import CouplingFlow
98
105
99
106
return CouplingFlow (
100
- depth = 2 , subnet = "mlp" , subnet_kwargs = dict (widths = [8 , 8 ]), transform = "affine" , transform_kwargs = dict (clamp = 1.8 )
107
+ depth = 2 ,
108
+ subnet = "mlp" ,
109
+ subnet_kwargs = dict (widths = [8 , 8 ]),
110
+ transform = "affine" ,
111
+ transform_kwargs = dict (clamp = 1.8 ),
112
+ metrics = [RootMeanSquaredError ()],
101
113
)
102
114
103
115
@@ -106,15 +118,24 @@ def spline_coupling_flow():
106
118
from bayesflow .networks import CouplingFlow
107
119
108
120
return CouplingFlow (
109
- depth = 2 , subnet = "mlp" , subnet_kwargs = dict (widths = [8 , 8 ]), transform = "spline" , transform_kwargs = dict (bins = 8 )
121
+ depth = 2 ,
122
+ subnet = "mlp" ,
123
+ subnet_kwargs = dict (widths = [8 , 8 ]),
124
+ transform = "spline" ,
125
+ transform_kwargs = dict (bins = 8 ),
126
+ metrics = [RootMeanSquaredError ()],
110
127
)
111
128
112
129
113
130
@pytest .fixture ()
114
131
def free_form_flow ():
115
132
from bayesflow .experimental import FreeFormFlow
116
133
117
- return FreeFormFlow (encoder_subnet = MLP ([16 , 16 ]), decoder_subnet = MLP ([16 , 16 ]))
134
+ return FreeFormFlow (
135
+ encoder_subnet = MLP ([16 , 16 ]),
136
+ decoder_subnet = MLP ([16 , 16 ]),
137
+ metrics = [RootMeanSquaredError ()],
138
+ )
118
139
119
140
120
141
@pytest .fixture ()
@@ -236,35 +257,35 @@ def generative_inference_network(request):
236
257
def time_series_network (summary_dim ):
237
258
from bayesflow .networks import TimeSeriesNetwork
238
259
239
- return TimeSeriesNetwork (summary_dim = summary_dim )
260
+ return TimeSeriesNetwork (summary_dim = summary_dim , metrics = [ RootMeanSquaredError ()] )
240
261
241
262
242
263
@pytest .fixture (scope = "function" )
243
264
def time_series_transformer (summary_dim ):
244
265
from bayesflow .networks import TimeSeriesTransformer
245
266
246
- return TimeSeriesTransformer (summary_dim = summary_dim )
267
+ return TimeSeriesTransformer (summary_dim = summary_dim , metrics = [ RootMeanSquaredError ()] )
247
268
248
269
249
270
@pytest .fixture (scope = "function" )
250
271
def fusion_transformer (summary_dim ):
251
272
from bayesflow .networks import FusionTransformer
252
273
253
- return FusionTransformer (summary_dim = summary_dim )
274
+ return FusionTransformer (summary_dim = summary_dim , metrics = [ RootMeanSquaredError ()] )
254
275
255
276
256
277
@pytest .fixture (scope = "function" )
257
278
def set_transformer (summary_dim ):
258
279
from bayesflow .networks import SetTransformer
259
280
260
- return SetTransformer (summary_dim = summary_dim )
281
+ return SetTransformer (summary_dim = summary_dim , metrics = [ RootMeanSquaredError ()] )
261
282
262
283
263
284
@pytest .fixture (scope = "function" )
264
285
def deep_set (summary_dim ):
265
286
from bayesflow .networks import DeepSet
266
287
267
- return DeepSet (summary_dim = summary_dim )
288
+ return DeepSet (summary_dim = summary_dim , metrics = [ RootMeanSquaredError ()] )
268
289
269
290
270
291
@pytest .fixture (
0 commit comments