From b58b56e4430ec3d6256c6230c0d8db5872c16a2b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20G=C3=B6rner?=
Date: Fri, 21 Jun 2024 17:10:38 +0200
Subject: [PATCH] fix GemmaBackbone.get_layout_map + test (#1669)

* fix to default GemmaBackbone.get_layout_map() to use fixed regexes as per
  https://github.com/keras-team/keras/issues/19496#issuecomment-2089424525

* fix to default GemmaBackbone.get_layout_map() to use fixed regexes as per
  https://github.com/keras-team/keras/issues/19496#issuecomment-2089424525

* Also fixing forgotten ffw_gating_2 in GemmaBackbone.get_layout_map.
  The sharding spec ("batch", "model") is the one that provides the best
  training performance. ("model", "batch") and (None, None) are slower
  (the first one by 40%, the second by 2%).
  Fixing the test too, including the typo ffw_linearl => ffw_linear.

* changed the test_architecture_characteristics test to follow the 4->8 heads
  change, which is necessary for the test to work on TPUs. Also fixed
  formatting.

* Update gemma_backbone_test.py

  Better test messages

---------

Co-authored-by: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
---
 keras_nlp/src/models/gemma/gemma_backbone.py  |  9 ++--
 .../src/models/gemma/gemma_backbone_test.py   | 41 +++++++++++++++++--
 2 files changed, 42 insertions(+), 8 deletions(-)

diff --git a/keras_nlp/src/models/gemma/gemma_backbone.py b/keras_nlp/src/models/gemma/gemma_backbone.py
index 30e70e4311..8a89251aa4 100644
--- a/keras_nlp/src/models/gemma/gemma_backbone.py
+++ b/keras_nlp/src/models/gemma/gemma_backbone.py
@@ -255,17 +255,18 @@ def get_layout_map(
         # See https://arxiv.org/abs/2403.08295
         layout_map = keras.distribution.LayoutMap(device_mesh)
         layout_map["token_embedding/embeddings"] = (model_dim, data_dim)
-        layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
+        layout_map["decoder_block.*attention.*(query|key|value).kernel"] = (
             model_dim,
             data_dim,
             None,
         )
-        layout_map["decoder_block.*attention_output.*kernel"] = (
+        layout_map["decoder_block.*attention_output.kernel"] = (
             model_dim,
             None,
             data_dim,
         )
-        layout_map["decoder_block.*ffw_gating.*kernel"] = (data_dim, model_dim)
-        layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, data_dim)
+        layout_map["decoder_block.*ffw_gating.kernel"] = (data_dim, model_dim)
+        layout_map["decoder_block.*ffw_gating_2.kernel"] = (data_dim, model_dim)
+        layout_map["decoder_block.*ffw_linear.kernel"] = (model_dim, data_dim)
 
         return layout_map
diff --git a/keras_nlp/src/models/gemma/gemma_backbone_test.py b/keras_nlp/src/models/gemma/gemma_backbone_test.py
index d15843c74e..c680c1abbe 100644
--- a/keras_nlp/src/models/gemma/gemma_backbone_test.py
+++ b/keras_nlp/src/models/gemma/gemma_backbone_test.py
@@ -24,8 +24,8 @@ def setUp(self):
         self.init_kwargs = {
             "vocabulary_size": 256128,
             "num_layers": 2,
-            "num_query_heads": 4,
-            "num_key_value_heads": 4,
+            "num_query_heads": 8,
+            "num_key_value_heads": 8,
             "hidden_dim": 128,
             "intermediate_dim": 256,
             "head_dim": 128,
@@ -82,7 +82,7 @@ def test_all_presets(self):
 
     def test_architecture_characteristics(self):
         model = GemmaBackbone(**self.init_kwargs)
-        self.assertEqual(model.count_params(), 33407616)
+        self.assertEqual(model.count_params(), 33931904)
         self.assertEqual(len(model.layers), 6)
 
     def test_distribution(self):
@@ -132,7 +132,40 @@ def test_distribution(self):
                 self.assertEqual(
                     tuple(w.value.sharding.spec), ("batch", "model")
                 )
-            if "ffw_linearl" in w.path:
+            if "ffw_linear" in w.path:
                 self.assertEqual(
                     tuple(w.value.sharding.spec), ("model", "batch")
                 )
+
+    def test_distribution_with_lora(self):
+        if keras.backend.backend() != "jax":
+            self.skipTest("`ModelParallel` testing requires the Jax backend.")
+        devices = keras.distribution.list_devices("CPU")
+        if len(devices) == 1:
+            # Need more than 1 device for distribution testing.
+            self.skipTest("`ModelParallel` testing requires multiple devices.")
+        device_mesh = keras.distribution.DeviceMesh(
+            shape=(1, len(devices)),
+            axis_names=("batch", "model"),
+            devices=devices,
+        )
+
+        layout_map = GemmaBackbone.get_layout_map(device_mesh)
+        distribution = keras.distribution.ModelParallel(device_mesh, layout_map)
+        with distribution.scope():
+            model = GemmaBackbone(**self.init_kwargs)
+            model.enable_lora(rank=4)
+
+        for w in model.weights:
+            if "attention/query/lora_kernel_a" in w.path:
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), (None, None, None)
+                )
+            if "attention/query/lora_kernel_b" in w.path:
+                self.assertEqual(tuple(w.value.sharding.spec), (None, None))
+            if "attention/value/lora_kernel_a" in w.path:
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), (None, None, None)
+                )
+            if "attention/value/lora_kernel_b" in w.path:
+                self.assertEqual(tuple(w.value.sharding.spec), (None, None))
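
Note for reviewers (illustration only, not part of the applied diff): the patterns
are tightened because keras.distribution.LayoutMap keys are regular expressions
matched against variable paths, and the old ".*kernel" suffix also matches LoRA
kernel paths such as "lora_kernel_a" (the case the new test_distribution_with_lora
covers), which would shard them unintentionally. The sketch below uses hypothetical
path strings modelled on the substrings checked in that test, and plain re.search
as a stand-in for LayoutMap's matching, to show the difference:

    import re

    # Hypothetical variable paths, in the style of the paths the new test checks.
    dense_kernel = "decoder_block_0/attention/query/kernel"
    lora_kernel = "decoder_block_0/attention/query/lora_kernel_a"

    loose = r"decoder_block.*attention.*(query|key|value).*kernel"
    tight = r"decoder_block.*attention.*(query|key|value).kernel"

    assert re.search(loose, dense_kernel)      # intended match
    assert re.search(loose, lora_kernel)       # unintended: LoRA kernel matched too
    assert re.search(tight, dense_kernel)      # still matches the dense kernel
    assert not re.search(tight, lora_kernel)   # LoRA kernel no longer matched

With the tightened patterns the LoRA kernels are left on the default layout, which
is what test_distribution_with_lora asserts via the (None, None) and
(None, None, None) sharding specs.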