Skip to content

Update gemma_backbone.py for sharding config. #1491

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions keras_nlp/models/gemma/gemma_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,20 +249,23 @@ def get_layout_map(device_mesh, model_parallel_dim_name="model"):
f"device_mesh.axis_names. {device_mesh.axis_name=}"
)
model_dim = model_parallel_dim_name
# The sharding is partition for the hidden_dim of the model.
# The sharding is set to replicated the hidden_dim of the model.
# So that the contrasting dimensions for qkv matmul are replicated.
# and will be run as local computation.
# See https://github.com/keras-team/keras-nlp/issues/1464 for more details.
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = (None, model_dim)
layout_map["token_embedding/embeddings"] = (model_dim, None)
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
None,
model_dim,
None,
model_dim,
)
layout_map["decoder_block.*attention_output.*kernel"] = (
None,
None,
model_dim,
None,
)
layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None)
layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

return layout_map
16 changes: 8 additions & 8 deletions keras_nlp/models/gemma/gemma_backbone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,26 +106,26 @@ def test_distribution(self):

for w in model.weights:
if "token_embedding/embeddings" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
if "attention/query/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), (None, "model", None)
tuple(w.value.sharding.spec), (None, None, "model")
)
if "attention/key/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), (None, "model", None)
tuple(w.value.sharding.spec), (None, None, "model")
)
if "attention/value/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), (None, "model", None)
tuple(w.value.sharding.spec), (None, None, "model")
)
if "attention/attention_output/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), (None, None, "model")
tuple(w.value.sharding.spec), (None, "model", None)
)
if "ffw_gating/kernel" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
if "ffw_gating_2/kernel" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
if "ffw_linearl" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
if "ffw_linearl" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), ("model", None))