From c459519a8b2c68588c3cde1bc06f3a838bf930d4 Mon Sep 17 00:00:00 2001 From: Aritra Roy Gosthipaty Date: Tue, 25 Jun 2024 02:35:05 +0530 Subject: [PATCH] Covert a `safetensor` checkpoint from Hugging Face hub (#1662) * chore: adding gemma and llama3 * chore: adding init * chore: removing hard coded values * chore: using backbone properties * chore: reformat * chore: review changes * chore: removing einops with custom np operations * fix: variable name * check: none type for reshape and transpose patterns * chore: fixing the nesting of reshape and transpose patterns * fixing nesting of patterns * chore: gemma weight rearrange fix * chore: adding a hook function to reshape and transpose the hf tensors to match the keras weights * fix: variable to assign * fix: gemma port * chore: adding tests * review comments * adding safetensors as a dep * chore: adding jax memory cleanup * utf 8 encoding * chore: changing tests * chore: fixing tests * fix tests * chore: adding guard rails for None types * Trigger Build * review suggestions * fix raising ValueError * fix error message --- keras_nlp/src/models/backbone.py | 9 +- keras_nlp/src/models/backbone_test.py | 4 +- keras_nlp/src/models/preprocessor.py | 11 +- keras_nlp/src/models/preprocessor_test.py | 5 +- keras_nlp/src/models/task.py | 14 +- keras_nlp/src/models/task_test.py | 4 +- keras_nlp/src/tokenizers/tokenizer.py | 8 +- keras_nlp/src/tokenizers/tokenizer_test.py | 5 +- keras_nlp/src/utils/preset_utils.py | 19 +- keras_nlp/src/utils/preset_utils_test.py | 6 +- keras_nlp/src/utils/transformers/__init__.py | 13 ++ keras_nlp/src/utils/transformers/convert.py | 48 ++++ .../src/utils/transformers/convert_gemma.py | 179 +++++++++++++++ .../utils/transformers/convert_gemma_test.py | 27 +++ .../src/utils/transformers/convert_llama3.py | 206 ++++++++++++++++++ .../utils/transformers/convert_llama3_test.py | 27 +++ .../utils/transformers/safetensor_utils.py | 44 ++++ requirements-common.txt | 1 + 18 files changed, 600 insertions(+), 30 deletions(-) create mode 100644 keras_nlp/src/utils/transformers/__init__.py create mode 100644 keras_nlp/src/utils/transformers/convert.py create mode 100644 keras_nlp/src/utils/transformers/convert_gemma.py create mode 100644 keras_nlp/src/utils/transformers/convert_gemma_test.py create mode 100644 keras_nlp/src/utils/transformers/convert_llama3.py create mode 100644 keras_nlp/src/utils/transformers/convert_llama3_test.py create mode 100644 keras_nlp/src/utils/transformers/safetensor_utils.py diff --git a/keras_nlp/src/models/backbone.py b/keras_nlp/src/models/backbone.py index 3e052d75a2..468a24f0ca 100644 --- a/keras_nlp/src/models/backbone.py +++ b/keras_nlp/src/models/backbone.py @@ -20,6 +20,7 @@ from keras_nlp.src.utils.preset_utils import CONFIG_FILE from keras_nlp.src.utils.preset_utils import MODEL_WEIGHTS_FILE from keras_nlp.src.utils.preset_utils import check_config_class +from keras_nlp.src.utils.preset_utils import check_format from keras_nlp.src.utils.preset_utils import get_file from keras_nlp.src.utils.preset_utils import jax_memory_cleanup from keras_nlp.src.utils.preset_utils import list_presets @@ -27,8 +28,8 @@ from keras_nlp.src.utils.preset_utils import load_serialized_object from keras_nlp.src.utils.preset_utils import save_metadata from keras_nlp.src.utils.preset_utils import save_serialized_object -from keras_nlp.src.utils.preset_utils import validate_metadata from keras_nlp.src.utils.python_utils import classproperty +from keras_nlp.src.utils.transformers.convert import load_transformers_backbone @keras_nlp_export("keras_nlp.models.Backbone") @@ -173,7 +174,11 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from ) ``` """ - validate_metadata(preset) + format = check_format(preset) + + if format == "transformers": + return load_transformers_backbone(cls, preset, load_weights) + preset_cls = check_config_class(preset) if not issubclass(preset_cls, cls): raise ValueError( diff --git a/keras_nlp/src/models/backbone_test.py b/keras_nlp/src/models/backbone_test.py index 98d1d2d35d..639d730b19 100644 --- a/keras_nlp/src/models/backbone_test.py +++ b/keras_nlp/src/models/backbone_test.py @@ -50,9 +50,7 @@ def test_from_preset(self): def test_from_preset_errors(self): with self.assertRaises(ValueError): GPT2Backbone.from_preset("bert_tiny_en_uncased", load_weights=False) - with self.assertRaisesRegex( - FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`" - ): + with self.assertRaises(ValueError): # No loading on a non-keras model. Backbone.from_preset("hf://google-bert/bert-base-uncased") diff --git a/keras_nlp/src/models/preprocessor.py b/keras_nlp/src/models/preprocessor.py index a3a3e13beb..d8cb6eb2fc 100644 --- a/keras_nlp/src/models/preprocessor.py +++ b/keras_nlp/src/models/preprocessor.py @@ -22,11 +22,11 @@ from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE from keras_nlp.src.utils.preset_utils import check_config_class from keras_nlp.src.utils.preset_utils import check_file_exists +from keras_nlp.src.utils.preset_utils import check_format from keras_nlp.src.utils.preset_utils import list_presets from keras_nlp.src.utils.preset_utils import list_subclasses from keras_nlp.src.utils.preset_utils import load_serialized_object from keras_nlp.src.utils.preset_utils import save_serialized_object -from keras_nlp.src.utils.preset_utils import validate_metadata from keras_nlp.src.utils.python_utils import classproperty @@ -128,7 +128,14 @@ def from_preset( ) ``` """ - validate_metadata(preset) + format = check_format(preset) + + if format == "transformers": + if cls.tokenizer_cls is None: + raise ValueError("Tokenizer class is None") + tokenizer = cls.tokenizer_cls.from_preset(preset) + return cls(tokenizer=tokenizer, **kwargs) + if cls == Preprocessor: raise ValueError( "Do not call `Preprocessor.from_preset()` directly. Instead call a " diff --git a/keras_nlp/src/models/preprocessor_test.py b/keras_nlp/src/models/preprocessor_test.py index 5401cd8b48..fdfaccc5fe 100644 --- a/keras_nlp/src/models/preprocessor_test.py +++ b/keras_nlp/src/models/preprocessor_test.py @@ -27,7 +27,6 @@ RobertaPreprocessor, ) from keras_nlp.src.tests.test_case import TestCase -from keras_nlp.src.utils.preset_utils import METADATA_FILE from keras_nlp.src.utils.preset_utils import PREPROCESSOR_CONFIG_FILE from keras_nlp.src.utils.preset_utils import TOKENIZER_ASSET_DIR from keras_nlp.src.utils.preset_utils import check_config_class @@ -67,9 +66,7 @@ def test_from_preset_errors(self): with self.assertRaises(ValueError): # No loading on an incorrect class. BertPreprocessor.from_preset("gpt2_base_en") - with self.assertRaisesRegex( - FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`" - ): + with self.assertRaises(ValueError): # No loading on a non-keras model. Preprocessor.from_preset("hf://google-bert/bert-base-uncased") diff --git a/keras_nlp/src/models/task.py b/keras_nlp/src/models/task.py index 4803d25266..8a59fe274f 100644 --- a/keras_nlp/src/models/task.py +++ b/keras_nlp/src/models/task.py @@ -28,13 +28,13 @@ from keras_nlp.src.utils.preset_utils import TASK_WEIGHTS_FILE from keras_nlp.src.utils.preset_utils import check_config_class from keras_nlp.src.utils.preset_utils import check_file_exists +from keras_nlp.src.utils.preset_utils import check_format from keras_nlp.src.utils.preset_utils import get_file from keras_nlp.src.utils.preset_utils import jax_memory_cleanup from keras_nlp.src.utils.preset_utils import list_presets from keras_nlp.src.utils.preset_utils import list_subclasses from keras_nlp.src.utils.preset_utils import load_serialized_object from keras_nlp.src.utils.preset_utils import save_serialized_object -from keras_nlp.src.utils.preset_utils import validate_metadata from keras_nlp.src.utils.python_utils import classproperty @@ -187,7 +187,17 @@ def from_preset( ) ``` """ - validate_metadata(preset) + format = check_format(preset) + + if format == "transformers": + if cls.backbone_cls is None: + raise ValueError("Backbone class is None") + if cls.preprocessor_cls is None: + raise ValueError("Preprocessor class is None") + + backbone = cls.backbone_cls.from_preset(preset) + preprocessor = cls.preprocessor_cls.from_preset(preset) + return cls(backbone=backbone, preprocessor=preprocessor, **kwargs) if cls == Task: raise ValueError( diff --git a/keras_nlp/src/models/task_test.py b/keras_nlp/src/models/task_test.py index 14272694de..e1e235fa01 100644 --- a/keras_nlp/src/models/task_test.py +++ b/keras_nlp/src/models/task_test.py @@ -79,9 +79,7 @@ def test_from_preset_errors(self): with self.assertRaises(ValueError): # No loading on an incorrect class. BertClassifier.from_preset("gpt2_base_en", load_weights=False) - with self.assertRaisesRegex( - FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`" - ): + with self.assertRaises(ValueError): # No loading on a non-keras model. CausalLM.from_preset("hf://google-bert/bert-base-uncased") diff --git a/keras_nlp/src/tokenizers/tokenizer.py b/keras_nlp/src/tokenizers/tokenizer.py index 7b918f5a99..aac0f4b917 100644 --- a/keras_nlp/src/tokenizers/tokenizer.py +++ b/keras_nlp/src/tokenizers/tokenizer.py @@ -20,14 +20,15 @@ from keras_nlp.src.utils.preset_utils import TOKENIZER_ASSET_DIR from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE from keras_nlp.src.utils.preset_utils import check_config_class +from keras_nlp.src.utils.preset_utils import check_format from keras_nlp.src.utils.preset_utils import get_file from keras_nlp.src.utils.preset_utils import list_presets from keras_nlp.src.utils.preset_utils import list_subclasses from keras_nlp.src.utils.preset_utils import load_serialized_object from keras_nlp.src.utils.preset_utils import save_serialized_object from keras_nlp.src.utils.preset_utils import save_tokenizer_assets -from keras_nlp.src.utils.preset_utils import validate_metadata from keras_nlp.src.utils.python_utils import classproperty +from keras_nlp.src.utils.transformers.convert import load_transformers_tokenizer @keras_nlp_export( @@ -215,7 +216,10 @@ class like `keras_nlp.models.Tokenizer.from_preset()`, or from tokenizer.detokenize([5, 6, 7, 8, 9]) ``` """ - validate_metadata(preset) + format = check_format(preset) + if format == "transformers": + return load_transformers_tokenizer(cls, preset) + preset_cls = check_config_class( preset, config_file=TOKENIZER_CONFIG_FILE ) diff --git a/keras_nlp/src/tokenizers/tokenizer_test.py b/keras_nlp/src/tokenizers/tokenizer_test.py index 1558db44a2..acc22b4b5d 100644 --- a/keras_nlp/src/tokenizers/tokenizer_test.py +++ b/keras_nlp/src/tokenizers/tokenizer_test.py @@ -31,7 +31,6 @@ from keras_nlp.src.models.roberta.roberta_tokenizer import RobertaTokenizer from keras_nlp.src.tests.test_case import TestCase from keras_nlp.src.tokenizers.tokenizer import Tokenizer -from keras_nlp.src.utils.preset_utils import METADATA_FILE from keras_nlp.src.utils.preset_utils import TOKENIZER_ASSET_DIR from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE from keras_nlp.src.utils.preset_utils import check_config_class @@ -70,9 +69,7 @@ def test_from_preset(self): def test_from_preset_errors(self): with self.assertRaises(ValueError): GPT2Tokenizer.from_preset("bert_tiny_en_uncased") - with self.assertRaisesRegex( - FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`" - ): + with self.assertRaises(ValueError): # No loading on a non-keras model. Tokenizer.from_preset("hf://google-bert/bert-base-uncased") diff --git a/keras_nlp/src/utils/preset_utils.py b/keras_nlp/src/utils/preset_utils.py index 09783eadfa..789cfdb896 100644 --- a/keras_nlp/src/utils/preset_utils.py +++ b/keras_nlp/src/utils/preset_utils.py @@ -59,16 +59,19 @@ # Config file names. CONFIG_FILE = "config.json" +HF_CONFIG_FILE = "config.json" TOKENIZER_CONFIG_FILE = "tokenizer.json" TASK_CONFIG_FILE = "task.json" PREPROCESSOR_CONFIG_FILE = "preprocessor.json" METADATA_FILE = "metadata.json" +SAFETENSOR_CONFIG_FILE = "model.safetensors.index.json" README_FILE = "README.md" # Weight file names. MODEL_WEIGHTS_FILE = "model.weights.h5" TASK_WEIGHTS_FILE = "task.weights.h5" +SAFETENSOR_FILE = "model.safetensors" # Global state for preset registry. BUILTIN_PRESETS = {} @@ -324,7 +327,7 @@ def _validate_tokenizer(preset, allow_incomplete=False): ) config_path = get_file(preset, TOKENIZER_CONFIG_FILE) try: - with open(config_path) as config_file: + with open(config_path, encoding="utf-8") as config_file: config = json.load(config_file) except Exception as e: raise ValueError( @@ -357,7 +360,7 @@ def _validate_backbone(preset): f"`{CONFIG_FILE}` is missing from the preset directory `{preset}`." ) try: - with open(config_path) as config_file: + with open(config_path, encoding="utf-8") as config_file: json.load(config_file) except Exception as e: raise ValueError( @@ -530,12 +533,17 @@ def upload_preset( def load_config(preset, config_file=CONFIG_FILE): config_path = get_file(preset, config_file) - with open(config_path) as config_file: + with open(config_path, encoding="utf-8") as config_file: config = json.load(config_file) return config -def validate_metadata(preset): +def check_format(preset): + if check_file_exists(preset, SAFETENSOR_FILE) or check_file_exists( + preset, SAFETENSOR_CONFIG_FILE + ): + return "transformers" + if not check_file_exists(preset, METADATA_FILE): raise FileNotFoundError( f"The preset directory `{preset}` doesn't have a file named `{METADATA_FILE}`, " @@ -548,6 +556,7 @@ def validate_metadata(preset): f"`{METADATA_FILE}` in the preset directory `{preset}` doesn't have `keras_version`. " "Please verify that the model you are trying to load is a Keras model." ) + return "keras" def load_serialized_object( @@ -566,7 +575,7 @@ def check_config_class( ): """Validate a preset is being loaded on the correct class.""" config_path = get_file(preset, config_file) - with open(config_path) as config_file: + with open(config_path, encoding="utf-8") as config_file: config = json.load(config_file) return keras.saving.get_registered_object(config["registered_name"]) diff --git a/keras_nlp/src/utils/preset_utils_test.py b/keras_nlp/src/utils/preset_utils_test.py index 9a55f07ee6..9185f4b6da 100644 --- a/keras_nlp/src/utils/preset_utils_test.py +++ b/keras_nlp/src/utils/preset_utils_test.py @@ -26,7 +26,7 @@ from keras_nlp.src.utils.preset_utils import CONFIG_FILE from keras_nlp.src.utils.preset_utils import METADATA_FILE from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE -from keras_nlp.src.utils.preset_utils import validate_metadata +from keras_nlp.src.utils.preset_utils import check_format class PresetUtilsTest(TestCase): @@ -100,7 +100,7 @@ def test_missing_metadata(self): with self.assertRaisesRegex( FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`" ): - validate_metadata(preset_dir) + check_format(preset_dir) def test_incorrect_metadata(self): temp_dir = self.get_temp_dir() @@ -112,4 +112,4 @@ def test_incorrect_metadata(self): json.dump(data, f) with self.assertRaisesRegex(ValueError, "doesn't have `keras_version`"): - validate_metadata(preset_dir) + check_format(preset_dir) diff --git a/keras_nlp/src/utils/transformers/__init__.py b/keras_nlp/src/utils/transformers/__init__.py new file mode 100644 index 0000000000..ba0c2545e4 --- /dev/null +++ b/keras_nlp/src/utils/transformers/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/utils/transformers/convert.py b/keras_nlp/src/utils/transformers/convert.py new file mode 100644 index 0000000000..25ee54bbf1 --- /dev/null +++ b/keras_nlp/src/utils/transformers/convert.py @@ -0,0 +1,48 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert huggingface models to KerasNLP.""" + + +from keras_nlp.src.utils.transformers.convert_gemma import load_gemma_backbone +from keras_nlp.src.utils.transformers.convert_gemma import load_gemma_tokenizer +from keras_nlp.src.utils.transformers.convert_llama3 import load_llama3_backbone +from keras_nlp.src.utils.transformers.convert_llama3 import ( + load_llama3_tokenizer, +) + + +def load_transformers_backbone(cls, preset, load_weights): + if cls is None: + raise ValueError("Backbone class is None") + if cls.__name__ == "GemmaBackbone": + return load_gemma_backbone(cls, preset, load_weights) + if cls.__name__ == "Llama3Backbone": + return load_llama3_backbone(cls, preset, load_weights) + raise ValueError( + f"{cls} has not been ported from the Hugging Face format yet. " + "Please check Hugging Face Hub for the Keras model. " + ) + + +def load_transformers_tokenizer(cls, preset): + if cls is None: + raise ValueError("Tokenizer class is None") + if cls.__name__ == "GemmaTokenizer": + return load_gemma_tokenizer(cls, preset) + if cls.__name__ == "Llama3Tokenizer": + return load_llama3_tokenizer(cls, preset) + raise ValueError( + f"{cls} has not been ported from the Hugging Face format yet. " + "Please check Hugging Face Hub for the Keras model. " + ) diff --git a/keras_nlp/src/utils/transformers/convert_gemma.py b/keras_nlp/src/utils/transformers/convert_gemma.py new file mode 100644 index 0000000000..e8094ad480 --- /dev/null +++ b/keras_nlp/src/utils/transformers/convert_gemma.py @@ -0,0 +1,179 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial + +import numpy as np + +from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE +from keras_nlp.src.utils.preset_utils import SAFETENSOR_CONFIG_FILE +from keras_nlp.src.utils.preset_utils import get_file +from keras_nlp.src.utils.preset_utils import jax_memory_cleanup +from keras_nlp.src.utils.preset_utils import load_config +from keras_nlp.src.utils.transformers.safetensor_utils import set_keras_weight + + +def load_gemma_backbone(cls, preset, load_weights): + """ + Load and initialize the Gemma backbone model. + + Args: + cls (class): Keras model class. + preset (str): Preset configuration name. + load_weights (bool): Whether to load the weights. + + Returns: + backbone: Initialized Keras model backbone. + """ + transformers_config = load_config(preset, HF_CONFIG_FILE) + + backbone = cls( + vocabulary_size=transformers_config["vocab_size"], + num_layers=transformers_config["num_hidden_layers"], + num_query_heads=transformers_config["num_attention_heads"], + num_key_value_heads=transformers_config["num_key_value_heads"], + hidden_dim=transformers_config["hidden_size"], + intermediate_dim=transformers_config["intermediate_size"] * 2, + head_dim=transformers_config["head_dim"], + ) + + if not load_weights: + return backbone + + jax_memory_cleanup(backbone) + # Code to port the weights from safetensors into the keras nlp model + safetensor_config = load_config(preset, SAFETENSOR_CONFIG_FILE) + safetensor_files = { + fname: get_file(preset, fname) + for fname in set(safetensor_config["weight_map"].values()) + } + port_weight = partial( + set_keras_weight, + safetensor_files=safetensor_files, + safetensor_config=safetensor_config, + ) + + # Embedding layer + port_weight( + keras_variable=backbone.get_layer("token_embedding").variables[0], + hf_weight_key="model.embed_tokens.weight", + ) + + # Attention blocks + for i in range(backbone.num_layers): + decoder_layer = backbone.get_layer(f"decoder_block_{i}") + # Norm layers + port_weight( + keras_variable=decoder_layer.pre_attention_norm.variables[0], + hf_weight_key=f"model.layers.{i}.input_layernorm.weight", + ) + port_weight( + keras_variable=decoder_layer.pre_ffw_norm.variables[0], + hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight", + ) + + # Attention layers + port_weight( + keras_variable=decoder_layer.attention.query_dense.variables[0], + hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight", + # rearrange_patterns="(a c) b -> a b c", + # rearrange_dims={"a": backbone.num_query_heads}, + hook_fn=lambda hf_tensor, keras_shape: np.transpose( + np.reshape( + hf_tensor, + (keras_shape[0], keras_shape[2], keras_shape[1]), + ), + axes=(0, 2, 1), + ), + ) + port_weight( + keras_variable=decoder_layer.attention.key_dense.variables[0], + hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight", + # rearrange_patterns="(a c) b -> a b c", + # rearrange_dims={"a": backbone.num_key_value_heads}, + hook_fn=lambda hf_tensor, keras_shape: np.transpose( + np.reshape( + hf_tensor, + (keras_shape[0], keras_shape[2], keras_shape[1]), + ), + axes=(0, 2, 1), + ), + ) + port_weight( + keras_variable=decoder_layer.attention.value_dense.variables[0], + hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight", + # rearrange_patterns="(a c) b -> a b c", + # rearrange_dims={"a": backbone.num_key_value_heads}, + hook_fn=lambda hf_tensor, keras_shape: np.transpose( + np.reshape( + hf_tensor, + (keras_shape[0], keras_shape[2], keras_shape[1]), + ), + axes=(0, 2, 1), + ), + ) + port_weight( + keras_variable=decoder_layer.attention.output_dense.variables[0], + hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight", + # rearrange_patterns="c (a b) -> a b c", + # rearrange_dims={"a": backbone.num_query_heads}, + hook_fn=lambda hf_tensor, keras_shape: np.transpose( + np.reshape( + hf_tensor, + (keras_shape[2], keras_shape[0], keras_shape[1]), + ), + axes=(1, 2, 0), + ), + ) + + # MLP layers + port_weight( + keras_variable=decoder_layer.gating_ffw.variables[0], + hf_weight_key=f"model.layers.{i}.mlp.gate_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + port_weight( + keras_variable=decoder_layer.gating_ffw_2.variables[0], + hf_weight_key=f"model.layers.{i}.mlp.up_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + port_weight( + keras_variable=decoder_layer.ffw_linear.variables[0], + hf_weight_key=f"model.layers.{i}.mlp.down_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + + # Final normalization layer + port_weight( + keras_variable=backbone.get_layer("final_normalization").variables[0], + hf_weight_key="model.norm.weight", + ) + + return backbone + + +def load_gemma_tokenizer(cls, preset): + """ + Load the Gemma tokenizer. + + Args: + cls (class): Tokenizer class. + preset (str): Preset configuration name. + + Returns: + tokenizer: Initialized tokenizer. + """ + return cls(get_file(preset, "tokenizer.model")) diff --git a/keras_nlp/src/utils/transformers/convert_gemma_test.py b/keras_nlp/src/utils/transformers/convert_gemma_test.py new file mode 100644 index 0000000000..a7e5c795c7 --- /dev/null +++ b/keras_nlp/src/utils/transformers/convert_gemma_test.py @@ -0,0 +1,27 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from keras_nlp.src.models.gemma.gemma_causal_lm import GemmaCausalLM +from keras_nlp.src.tests.test_case import TestCase + + +class TestTask(TestCase): + @pytest.mark.large + def test_convert_tiny_preset(self): + model = GemmaCausalLM.from_preset("hf://ariG23498/tiny-gemma-test") + prompt = "What is your favorite condiment?" + model.generate([prompt], max_length=15) + + # TODO: compare numerics with huggingface model diff --git a/keras_nlp/src/utils/transformers/convert_llama3.py b/keras_nlp/src/utils/transformers/convert_llama3.py new file mode 100644 index 0000000000..6ce954f37e --- /dev/null +++ b/keras_nlp/src/utils/transformers/convert_llama3.py @@ -0,0 +1,206 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial + +import numpy as np + +from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE +from keras_nlp.src.utils.preset_utils import SAFETENSOR_CONFIG_FILE +from keras_nlp.src.utils.preset_utils import get_file +from keras_nlp.src.utils.preset_utils import jax_memory_cleanup +from keras_nlp.src.utils.preset_utils import load_config +from keras_nlp.src.utils.transformers.safetensor_utils import set_keras_weight + + +def load_llama3_backbone(cls, preset, load_weights): + """ + Load and initialize the Llama3 backbone model. + + Args: + cls (class): Keras model class. + preset (str): Preset configuration name. + load_weights (bool): Whether to load the weights. + + Returns: + backbone: Initialized Keras model backbone. + """ + transformers_config = load_config(preset, HF_CONFIG_FILE) + + backbone = cls( + vocabulary_size=transformers_config["vocab_size"], + num_layers=transformers_config["num_hidden_layers"], + num_query_heads=transformers_config["num_attention_heads"], + hidden_dim=transformers_config["hidden_size"], + intermediate_dim=transformers_config["intermediate_size"], + num_key_value_heads=transformers_config["num_key_value_heads"], + ) + + if not load_weights: + return backbone + + jax_memory_cleanup(backbone) + # Code to port the weights from safetensors into the keras nlp model + safetensor_config = load_config(preset, SAFETENSOR_CONFIG_FILE) + safetensor_files = { + fname: get_file(preset, fname) + for fname in set(safetensor_config["weight_map"].values()) + } + port_weight = partial( + set_keras_weight, + safetensor_files=safetensor_files, + safetensor_config=safetensor_config, + ) + + # Embedding layers + port_weight( + keras_variable=backbone.get_layer("token_embedding").variables[0], + hf_weight_key="model.embed_tokens.weight", + ) + port_weight( + keras_variable=backbone.get_layer("token_embedding").variables[1], + hf_weight_key="lm_head.weight", + # rearrange_pattern="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + + # Attention blocks + for i in range(backbone.num_layers): + decoder_layer = backbone.get_layer(f"transformer_layer_{i}") + # Norm layers + port_weight( + keras_variable=decoder_layer._self_attention_layernorm.variables[0], + hf_weight_key=f"model.layers.{i}.input_layernorm.weight", + ) + port_weight( + keras_variable=decoder_layer._feedforward_layernorm.variables[0], + hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight", + ) + + # Attention layers + port_weight( + keras_variable=decoder_layer._self_attention_layer._query_dense.variables[ + 0 + ], + hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight", + # rearrange_patterns="(b c) a -> a b c, + # rearrange_dims={"b": backbone.num_query_heads}, + hook_fn=lambda hf_tensor, keras_shape: np.transpose( + np.reshape( + hf_tensor, + (keras_shape[1], keras_shape[2], keras_shape[0]), + ), + axes=(2, 0, 1), + ), + ) + port_weight( + keras_variable=decoder_layer._self_attention_layer._key_dense.variables[ + 0 + ], + hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight", + # rearrange_patterns="(b c) a -> a b c", + # rearrange_dims={"b": backbone.num_key_value_heads}, + hook_fn=lambda hf_tensor, keras_shape: np.transpose( + np.reshape( + hf_tensor, + (keras_shape[1], keras_shape[2], keras_shape[0]), + ), + axes=(2, 0, 1), + ), + ) + port_weight( + keras_variable=decoder_layer._self_attention_layer._value_dense.variables[ + 0 + ], + hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight", + # rearrange_patterns="(b c) a -> a b c", + # rearrange_dims={"b": backbone.num_key_value_heads}, + hook_fn=lambda hf_tensor, keras_shape: np.transpose( + np.reshape( + hf_tensor, + (keras_shape[1], keras_shape[2], keras_shape[0]), + ), + axes=(2, 0, 1), + ), + ) + port_weight( + keras_variable=decoder_layer._self_attention_layer._output_dense.variables[ + 0 + ], + hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight", + # rearrange_patterns="c (a b) -> a b c", + # rearrange_dims={"a": backbone.num_query_heads}, + hook_fn=lambda hf_tensor, keras_shape: np.transpose( + np.reshape( + hf_tensor, + (keras_shape[2], keras_shape[0], keras_shape[1]), + ), + axes=(1, 2, 0), + ), + ) + + # MLP layers + port_weight( + keras_variable=decoder_layer._feedforward_gate_dense.variables[0], + hf_weight_key=f"model.layers.{i}.mlp.gate_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + port_weight( + keras_variable=decoder_layer._feedforward_intermediate_dense.variables[ + 0 + ], + hf_weight_key=f"model.layers.{i}.mlp.up_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + port_weight( + keras_variable=decoder_layer._feedforward_output_dense.variables[0], + hf_weight_key=f"model.layers.{i}.mlp.down_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + + # Final normalization layer + port_weight( + keras_variable=backbone.get_layer( + "sequence_output_layernorm" + ).variables[0], + hf_weight_key="model.norm.weight", + ) + + return backbone + + +def load_llama3_tokenizer(cls, preset): + """ + Load the Llama3 tokenizer. + + Args: + cls (class): Tokenizer class. + preset (str): Preset configuration name. + + Returns: + tokenizer: Initialized tokenizer. + """ + tokenizer_config = load_config(preset, "tokenizer.json") + vocab = tokenizer_config["model"]["vocab"] + merges = tokenizer_config["model"]["merges"] + + bot = tokenizer_config["added_tokens"][0] # begin of text + eot = tokenizer_config["added_tokens"][1] # end of text + + vocab[bot["content"]] = bot["id"] + vocab[eot["content"]] = eot["id"] + + return cls(vocabulary=vocab, merges=merges) diff --git a/keras_nlp/src/utils/transformers/convert_llama3_test.py b/keras_nlp/src/utils/transformers/convert_llama3_test.py new file mode 100644 index 0000000000..856fc2a946 --- /dev/null +++ b/keras_nlp/src/utils/transformers/convert_llama3_test.py @@ -0,0 +1,27 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from keras_nlp.src.models.llama3.llama3_causal_lm import Llama3CausalLM +from keras_nlp.src.tests.test_case import TestCase + + +class TestTask(TestCase): + @pytest.mark.large + def test_convert_tiny_preset(self): + model = Llama3CausalLM.from_preset("hf://ariG23498/tiny-llama3-test") + prompt = "What is your favorite condiment?" + model.generate([prompt], max_length=15) + + # TODO: compare numerics with huggingface model diff --git a/keras_nlp/src/utils/transformers/safetensor_utils.py b/keras_nlp/src/utils/transformers/safetensor_utils.py new file mode 100644 index 0000000000..db9c5fe584 --- /dev/null +++ b/keras_nlp/src/utils/transformers/safetensor_utils.py @@ -0,0 +1,44 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + import safetensors +except ImportError: + safetensors = None + + +def set_keras_weight( + safetensor_files, + safetensor_config, + keras_variable, + hf_weight_key, + hook_fn=None, +): + if safetensors is None: + raise ImportError( + "Converting from the huggingface/transformers model format" + "requires the safetensors package." + "Please install with `pip install safetensors`." + ) + else: + from safetensors import safe_open + + safetensor_file = safetensor_files[ + safetensor_config["weight_map"][hf_weight_key] + ] + with safe_open(safetensor_file, framework="np") as f: + hf_tensor = f.get_tensor(hf_weight_key) + + if hook_fn: + hf_tensor = hook_fn(hf_tensor, list(keras_variable.shape)) + keras_variable.assign(hf_tensor) diff --git a/requirements-common.txt b/requirements-common.txt index 5c2a8a3d90..4e90ca9fab 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -17,3 +17,4 @@ namex rouge-score sentencepiece tensorflow-datasets +safetensors