diff --git a/keras_nlp/models/gemma/gemma_backbone.py b/keras_nlp/models/gemma/gemma_backbone.py
index 6e69c1848c..8266f61d4f 100644
--- a/keras_nlp/models/gemma/gemma_backbone.py
+++ b/keras_nlp/models/gemma/gemma_backbone.py
@@ -251,7 +251,7 @@ def get_layout_map(device_mesh, model_parallel_dim_name="model"):
         model_dim = model_parallel_dim_name
         # The sharding is set to replicated the hidden_dim of the model.
         # So that the contrasting dimensions for qkv matmul are replicated.
-        # and will be run as local computation.
+        # and will be run as local computation.
         # See https://github.com/keras-team/keras-nlp/issues/1464 for more details.
         layout_map = keras.distribution.LayoutMap(device_mesh)
         layout_map["token_embedding/embeddings"] = (model_dim, None)
diff --git a/keras_nlp/models/gemma/gemma_backbone_test.py b/keras_nlp/models/gemma/gemma_backbone_test.py
index 855d49658b..a23551df59 100644
--- a/keras_nlp/models/gemma/gemma_backbone_test.py
+++ b/keras_nlp/models/gemma/gemma_backbone_test.py
@@ -106,26 +106,26 @@ def test_distribution(self):
         for w in model.weights:
             if "token_embedding/embeddings" in w.path:
-                self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
+                self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
             if "attention/query/kernel" in w.path:
                 self.assertEqual(
-                    tuple(w.value.sharding.spec), (None, "model", None)
+                    tuple(w.value.sharding.spec), (None, None, "model")
                 )
             if "attention/key/kernel" in w.path:
                 self.assertEqual(
-                    tuple(w.value.sharding.spec), (None, "model", None)
+                    tuple(w.value.sharding.spec), (None, None, "model")
                 )
             if "attention/value/kernel" in w.path:
                 self.assertEqual(
-                    tuple(w.value.sharding.spec), (None, "model", None)
+                    tuple(w.value.sharding.spec), (None, None, "model")
                 )
             if "attention/attention_output/kernel" in w.path:
                 self.assertEqual(
-                    tuple(w.value.sharding.spec), (None, None, "model")
+                    tuple(w.value.sharding.spec), (None, "model", None)
                )
             if "ffw_gating/kernel" in w.path:
-                self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
+                self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
             if "ffw_gating_2/kernel" in w.path:
-                self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
-            if "ffw_linearl" in w.path:
                 self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
+            if "ffw_linearl" in w.path:
+                self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
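
For context, the specs asserted in the updated test are consumed through the Keras 3 distribution API. Below is a minimal usage sketch, not part of this change: it assumes a JAX backend with eight accelerators, the "gemma_2b_en" preset, and the ModelParallel signature contemporary with this patch (device_mesh, layout_map, batch_dim_name); the 1x8 mesh shape is illustrative.

    import keras
    import keras_nlp

    # A 1 x 8 mesh: no data parallelism, eight-way model parallelism
    # along the "model" axis used by get_layout_map().
    device_mesh = keras.distribution.DeviceMesh(
        shape=(1, 8),
        axis_names=("batch", "model"),
        devices=keras.distribution.list_devices(),
    )

    # get_layout_map() returns the LayoutMap patched above; per the comment
    # in gemma_backbone.py, the hidden_dim stays replicated so the qkv
    # matmuls contract locally on each device.
    layout_map = keras_nlp.models.GemmaBackbone.get_layout_map(device_mesh)

    distribution = keras.distribution.ModelParallel(
        device_mesh, layout_map, batch_dim_name="batch"
    )
    keras.distribution.set_distribution(distribution)

    # Weights created from here on are sharded according to the layout map,
    # matching the w.value.sharding.spec values checked in test_distribution.
    model = keras_nlp.models.GemmaBackbone.from_preset("gemma_2b_en")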