1
- from functools import partial
2
-
3
1
import keras
4
2
5
- from bayesflow .utils .serialization import serializable
3
+ from bayesflow .utils .serialization import deserialize , serializable , serialize
6
4
from .functional import maximum_mean_discrepancy
7
5
8
6
@@ -17,10 +15,22 @@ def __init__(
17
15
):
18
16
super ().__init__ (name = name , ** kwargs )
19
17
self .mmd = self .add_variable (shape = (), initializer = "zeros" , name = "mmd" )
20
- self .mmd_fn = partial (maximum_mean_discrepancy , kernel = kernel , unbiased = unbiased )
18
+ self .kernel = kernel
19
+ self .unbiased = unbiased
21
20
22
21
def update_state (self , x , y ):
23
- self .mmd .assign (keras .ops .cast (self .mmd_fn (x , y ), self .dtype ))
22
+ self .mmd .assign (
23
+ keras .ops .cast (maximum_mean_discrepancy (x , y , kernel = self .kernel , unbiased = self .unbiased ), self .dtype )
24
+ )
24
25
25
26
def result (self ):
26
27
return self .mmd .value
28
+
29
+ def get_config (self ):
30
+ base_config = super ().get_config ()
31
+ config = {"kernel" : self .kernel , "unbiased" : self .unbiased }
32
+ return base_config | serialize (config )
33
+
34
+ @classmethod
35
+ def from_config (cls , config , custom_objects = None ):
36
+ return cls (** deserialize (config , custom_objects = custom_objects ))
0 commit comments