fix GemmaBackbone.get_layout_map + test (#1669)
* Fix the default GemmaBackbone.get_layout_map() to use corrected regexes, as per keras-team/keras#19496 (comment).

* Also fix the forgotten ffw_gating_2 entry in GemmaBackbone.get_layout_map. The sharding spec ("batch", "model") gives the best training performance; ("model", "batch") and (None, None) are slower (the former by 40%, the latter by 2%).
Fix the test too, including the typo ffw_linearl => ffw_linear.

* Change the test_architecture_characteristics test to follow the 4->8 heads change needed for the test to work on TPUs; a parameter recount sketch follows the change summary below. Also fix formatting.

* Update gemma_backbone_test.py with better test messages.

---------

Co-authored-by: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
martin-gorner and mattdangerw authored Jun 21, 2024
1 parent e0efbc8 commit b58b56e
Showing 2 changed files with 42 additions and 8 deletions.
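
For context on the count_params change in the test diff below (33407616 -> 33931904 after the 4->8 heads change), here is a back-of-the-envelope recount. It is only a sketch, not the library's own accounting: it assumes bias-free projection kernels, two gating projections of width intermediate_dim // 2, and RMSNorm layers carrying a single scale vector of size hidden_dim.

# Hypothetical recount of GemmaBackbone parameters for the test config below.
# Assumptions (not taken from this diff): bias-free projection kernels, two
# gating projections of width intermediate_dim // 2, and RMSNorm layers with a
# single scale vector of size hidden_dim.
vocab, layers = 256128, 2
heads, kv_heads, head_dim = 8, 8, 128
hidden, intermediate = 128, 256

embedding = vocab * hidden                                      # token embeddings
attention = (heads + 2 * kv_heads + heads) * head_dim * hidden  # query, key, value, output kernels
ffw = 3 * hidden * (intermediate // 2)                          # ffw_gating, ffw_gating_2, ffw_linear
norms = 2 * hidden                                              # pre-attention + pre-ffw RMSNorm scales
per_layer = attention + ffw + norms

total = embedding + layers * per_layer + hidden                 # + final RMSNorm scale
print(total)                                                    # 33931904

The same arithmetic with 4 query and 4 key/value heads gives 33407616, the value the test asserted before this change.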
9 changes: 5 additions & 4 deletions keras_nlp/src/models/gemma/gemma_backbone.py
@@ -255,17 +255,18 @@ def get_layout_map(
         # See https://arxiv.org/abs/2403.08295
         layout_map = keras.distribution.LayoutMap(device_mesh)
         layout_map["token_embedding/embeddings"] = (model_dim, data_dim)
-        layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
+        layout_map["decoder_block.*attention.*(query|key|value).kernel"] = (
             model_dim,
             data_dim,
             None,
         )
-        layout_map["decoder_block.*attention_output.*kernel"] = (
+        layout_map["decoder_block.*attention_output.kernel"] = (
             model_dim,
             None,
             data_dim,
         )
-        layout_map["decoder_block.*ffw_gating.*kernel"] = (data_dim, model_dim)
-        layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, data_dim)
+        layout_map["decoder_block.*ffw_gating.kernel"] = (data_dim, model_dim)
+        layout_map["decoder_block.*ffw_gating_2.kernel"] = (data_dim, model_dim)
+        layout_map["decoder_block.*ffw_linear.kernel"] = (model_dim, data_dim)
 
         return layout_map
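
A quick illustration of what the tightened patterns change (a standalone sketch; the variable paths are hypothetical, modeled on the paths asserted in the test file below, and the exact matching semantics live in keras.distribution.LayoutMap):

import re

dense_path = "decoder_block_0/attention/query/kernel"        # hypothetical variable path
lora_path = "decoder_block_0/attention/query/lora_kernel_a"  # hypothetical variable path

old_pattern = r"decoder_block.*attention.*(query|key|value).*kernel"
new_pattern = r"decoder_block.*attention.*(query|key|value).kernel"

assert re.search(old_pattern, dense_path)      # intended match
assert re.search(old_pattern, lora_path)       # unintended match: the LoRA kernel would be sharded too
assert re.search(new_pattern, dense_path)      # still matches; '.' covers the '/' separator
assert not re.search(new_pattern, lora_path)   # LoRA kernels fall back to the default (replicated) layout

With the old ".*kernel" suffix, the LoRA kernels added by enable_lora() would pick up the sharding specs intended for the dense kernels; the tightened ".kernel" suffix leaves them on the default layout, which is what test_distribution_with_lora below asserts.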
41 changes: 37 additions & 4 deletions keras_nlp/src/models/gemma/gemma_backbone_test.py
@@ -24,8 +24,8 @@ def setUp(self):
         self.init_kwargs = {
             "vocabulary_size": 256128,
             "num_layers": 2,
-            "num_query_heads": 4,
-            "num_key_value_heads": 4,
+            "num_query_heads": 8,
+            "num_key_value_heads": 8,
             "hidden_dim": 128,
             "intermediate_dim": 256,
             "head_dim": 128,
@@ -82,7 +82,7 @@ def test_all_presets(self):
 
     def test_architecture_characteristics(self):
         model = GemmaBackbone(**self.init_kwargs)
-        self.assertEqual(model.count_params(), 33407616)
+        self.assertEqual(model.count_params(), 33931904)
         self.assertEqual(len(model.layers), 6)
 
     def test_distribution(self):
@@ -132,7 +132,40 @@ def test_distribution(self):
                 self.assertEqual(
                     tuple(w.value.sharding.spec), ("batch", "model")
                 )
-            if "ffw_linearl" in w.path:
+            if "ffw_linear" in w.path:
                 self.assertEqual(
                     tuple(w.value.sharding.spec), ("model", "batch")
                 )
+
+    def test_distribution_with_lora(self):
+        if keras.backend.backend() != "jax":
+            self.skipTest("`ModelParallel` testing requires the Jax backend.")
+        devices = keras.distribution.list_devices("CPU")
+        if len(devices) == 1:
+            # Need more than 1 device for distribution testing.
+            self.skipTest("`ModelParallel` testing requires multiple devices.")
+        device_mesh = keras.distribution.DeviceMesh(
+            shape=(1, len(devices)),
+            axis_names=("batch", "model"),
+            devices=devices,
+        )
+
+        layout_map = GemmaBackbone.get_layout_map(device_mesh)
+        distribution = keras.distribution.ModelParallel(device_mesh, layout_map)
+        with distribution.scope():
+            model = GemmaBackbone(**self.init_kwargs)
+            model.enable_lora(rank=4)
+
+        for w in model.weights:
+            if "attention/query/lora_kernel_a" in w.path:
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), (None, None, None)
+                )
+            if "attention/query/lora_kernel_b" in w.path:
+                self.assertEqual(tuple(w.value.sharding.spec), (None, None))
+            if "attention/value/lora_kernel_a" in w.path:
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), (None, None, None)
+                )
+            if "attention/value/lora_kernel_b" in w.path:
+                self.assertEqual(tuple(w.value.sharding.spec), (None, None))
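
For reference, outside of a test the fixed layout map would typically be wired up as follows (a minimal sketch mirroring the tests above: the constructor arguments are copied from setUp(), and the global set_distribution() call stands in for the scope() context used in the tests):

import keras
import keras_nlp

# Build a 1 x N ("batch", "model") mesh over all local devices.
devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, len(devices)),
    axis_names=("batch", "model"),
    devices=devices,
)

# The corrected regexes shard the dense kernels and leave LoRA kernels replicated.
layout_map = keras_nlp.models.GemmaBackbone.get_layout_map(device_mesh)
distribution = keras.distribution.ModelParallel(device_mesh, layout_map)
keras.distribution.set_distribution(distribution)

model = keras_nlp.models.GemmaBackbone(
    vocabulary_size=256128,
    num_layers=2,
    num_query_heads=8,
    num_key_value_heads=8,
    hidden_dim=128,
    intermediate_dim=256,
    head_dim=128,
)
model.enable_lora(rank=4)  # optional; LoRA kernels keep the default layout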
