Skip to content

Commit

Permalink
Update unit test and fix format.
Browse files Browse the repository at this point in the history
  • Loading branch information
qlzh727 committed Mar 12, 2024
1 parent e95675a commit d11fb86
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion keras_nlp/models/gemma/gemma_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def get_layout_map(device_mesh, model_parallel_dim_name="model"):
model_dim = model_parallel_dim_name
# The sharding is set to replicated the hidden_dim of the model.
# So that the contrasting dimensions for qkv matmul are replicated.
# and will be run as local computation.
# and will be run as local computation.
# See https://github.com/keras-team/keras-nlp/issues/1464 for more details.
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = (model_dim, None)
Expand Down
16 changes: 8 additions & 8 deletions keras_nlp/models/gemma/gemma_backbone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,26 +106,26 @@ def test_distribution(self):

for w in model.weights:
if "token_embedding/embeddings" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
if "attention/query/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), (None, "model", None)
tuple(w.value.sharding.spec), (None, None, "model")
)
if "attention/key/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), (None, "model", None)
tuple(w.value.sharding.spec), (None, None, "model")
)
if "attention/value/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), (None, "model", None)
tuple(w.value.sharding.spec), (None, None, "model")
)
if "attention/attention_output/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), (None, None, "model")
tuple(w.value.sharding.spec), (None, "model", None)
)
if "ffw_gating/kernel" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
if "ffw_gating_2/kernel" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
if "ffw_linearl" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
if "ffw_linearl" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), ("model", None))

0 comments on commit d11fb86

Please sign in to comment.