Fix missing config param in attention layers
taylorhansen committed Sep 24, 2023
1 parent c7cf3c7 commit a411263
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion src/py/models/utils/attention.py
@@ -89,6 +89,7 @@ def __init__(
         super().__init__(**kwargs)
         self.num_heads = num_heads
         self.depth = depth
+        self.use_bias = use_bias

         dim = num_heads * depth
         self.query = tf.keras.layers.Dense(
@@ -234,7 +235,11 @@ def get_config(self):
         return (
             super().get_config()
             | self.get_initializer_config()
-            | {"num_heads": self.num_heads, "depth": self.depth}
+            | {
+                "num_heads": self.num_heads,
+                "depth": self.depth,
+                "use_bias": self.use_bias,
+            }
         )


@@ -331,6 +336,7 @@ def get_config(self):
                 "depth": self.mha.depth,
                 "rff": self.rff,
                 "use_layer_norm": self.use_layer_norm,
+                "use_bias": self.mha.use_bias,
             }
         )

@@ -409,6 +415,7 @@ def get_config(self):
                 "depth": self.mab.mha.depth,
                 "rff": self.mab.rff,
                 "use_layer_norm": self.mab.use_layer_norm,
+                "use_bias": self.mab.mha.use_bias,
             }
         )

@@ -514,6 +521,7 @@ def get_config(self):
                 "rff": self.mab.rff,
                 "rff_s": self.rff_s,
                 "use_layer_norm": self.mab.use_layer_norm,
+                "use_bias": self.mab.mha.use_bias,
                 "seed_initializer": tf.keras.initializers.serialize(
                     self.seed_initializer
                 ),
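
Why the fix matters: Keras rebuilds a layer from its get_config() dict when a model is saved and reloaded, so any constructor argument missing from that dict silently reverts to its default on deserialization. A minimal sketch of the round trip, using a hypothetical MiniAttention layer (not the repository's actual class):

import tensorflow as tf

class MiniAttention(tf.keras.layers.Layer):
    """Toy stand-in for the patched attention layer (hypothetical)."""

    def __init__(self, num_heads, depth, use_bias=True, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.depth = depth
        self.use_bias = use_bias
        self.query = tf.keras.layers.Dense(num_heads * depth, use_bias=use_bias)

    def call(self, inputs):
        return self.query(inputs)

    def get_config(self):
        # Every constructor argument must appear here; otherwise
        # from_config() rebuilds the layer with the default value.
        return super().get_config() | {
            "num_heads": self.num_heads,
            "depth": self.depth,
            "use_bias": self.use_bias,  # the param this commit adds
        }

layer = MiniAttention(num_heads=4, depth=8, use_bias=False)
clone = MiniAttention.from_config(layer.get_config())
assert clone.use_bias is False  # would fail if get_config() dropped use_bias

Before this commit, reloading a saved model would have rebuilt each attention layer with the default use_bias rather than the value it was trained with, which is exactly the silent drift the assert above catches.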
