Convert a safetensors checkpoint from the Hugging Face Hub (#1662)
* chore: adding gemma and llama3

* chore: adding init

* chore: removing hard coded values

* chore: using backbone properties

* chore: reformat

* chore: review changes

* chore: replacing einops with custom np operations

* fix: variable name

* check: none type for reshape and transpose patterns

* chore: fixing the nesting of reshape and transpose patterns

* fixing nesting of patterns

* chore: gemma weight rearrange fix

* chore: adding a hook function to reshape and transpose the hf tensors to match the keras weights (see the sketch after this list)

* fix: variable to assign

* fix: gemma port

* chore: adding tests

* review comments

* adding safetensors as a dep

* chore: adding jax memory cleanup

* utf-8 encoding

* chore: changing tests

* chore: fixing tests

* fix tests

* chore: adding guard rails for None types

* Trigger Build

* review suggestions

* fix raising ValueError

* fix error message
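The "hook function" bullet above is the core of the port. A minimal sketch of the idea, assuming illustrative shapes and axis order (the real converters live in `convert_gemma.py` and `convert_llama3.py`, which are not shown on this page):

```python
import numpy as np

def transpose_and_reshape(hf_tensor, num_heads, head_dim, hidden_dim):
    """Hypothetical hook: HF fused-head layout -> Keras attention layout.

    Takes a (num_heads * head_dim, hidden_dim) projection weight and
    rearranges it to (hidden_dim, num_heads, head_dim). Shapes and axis
    order are illustrative, not the exact ones used in the port.
    """
    x = np.asarray(hf_tensor)
    x = x.reshape(num_heads, head_dim, hidden_dim)  # split the fused head axis
    x = x.transpose(2, 0, 1)  # hidden_dim first, to match the Keras kernel
    return x
```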
ariG23498 authored Jun 24, 2024
1 parent b58b56e commit c459519
Showing 18 changed files with 600 additions and 30 deletions.
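What the change enables, as a usage sketch (the `hf://` handle is illustrative; only architectures with converters, Gemma and Llama 3 in this commit, are supported):

```python
import keras_nlp

# from_preset() now detects a safetensors checkpoint on the Hugging Face Hub
# and converts its weights to the Keras layout on the fly.
backbone = keras_nlp.models.GemmaBackbone.from_preset("hf://google/gemma-2b")
tokenizer = keras_nlp.models.GemmaTokenizer.from_preset("hf://google/gemma-2b")
```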
9 changes: 7 additions & 2 deletions keras_nlp/src/models/backbone.py
@@ -20,15 +20,16 @@
 from keras_nlp.src.utils.preset_utils import CONFIG_FILE
 from keras_nlp.src.utils.preset_utils import MODEL_WEIGHTS_FILE
 from keras_nlp.src.utils.preset_utils import check_config_class
+from keras_nlp.src.utils.preset_utils import check_format
 from keras_nlp.src.utils.preset_utils import get_file
 from keras_nlp.src.utils.preset_utils import jax_memory_cleanup
 from keras_nlp.src.utils.preset_utils import list_presets
 from keras_nlp.src.utils.preset_utils import list_subclasses
 from keras_nlp.src.utils.preset_utils import load_serialized_object
 from keras_nlp.src.utils.preset_utils import save_metadata
 from keras_nlp.src.utils.preset_utils import save_serialized_object
-from keras_nlp.src.utils.preset_utils import validate_metadata
 from keras_nlp.src.utils.python_utils import classproperty
+from keras_nlp.src.utils.transformers.convert import load_transformers_backbone


 @keras_nlp_export("keras_nlp.models.Backbone")
@@ -173,7 +174,11 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from
         )
         ```
         """
-        validate_metadata(preset)
+        format = check_format(preset)
+
+        if format == "transformers":
+            return load_transformers_backbone(cls, preset, load_weights)
+
         preset_cls = check_config_class(preset)
         if not issubclass(preset_cls, cls):
             raise ValueError(
4 changes: 1 addition & 3 deletions keras_nlp/src/models/backbone_test.py
@@ -50,9 +50,7 @@ def test_from_preset(self):
     def test_from_preset_errors(self):
         with self.assertRaises(ValueError):
             GPT2Backbone.from_preset("bert_tiny_en_uncased", load_weights=False)
-        with self.assertRaisesRegex(
-            FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`"
-        ):
+        with self.assertRaises(ValueError):
             # No loading on a non-keras model.
             Backbone.from_preset("hf://google-bert/bert-base-uncased")

11 changes: 9 additions & 2 deletions keras_nlp/src/models/preprocessor.py
@@ -22,11 +22,11 @@
 from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
 from keras_nlp.src.utils.preset_utils import check_config_class
 from keras_nlp.src.utils.preset_utils import check_file_exists
+from keras_nlp.src.utils.preset_utils import check_format
 from keras_nlp.src.utils.preset_utils import list_presets
 from keras_nlp.src.utils.preset_utils import list_subclasses
 from keras_nlp.src.utils.preset_utils import load_serialized_object
 from keras_nlp.src.utils.preset_utils import save_serialized_object
-from keras_nlp.src.utils.preset_utils import validate_metadata
 from keras_nlp.src.utils.python_utils import classproperty


@@ -128,7 +128,14 @@ def from_preset(
         )
         ```
         """
-        validate_metadata(preset)
+        format = check_format(preset)
+
+        if format == "transformers":
+            if cls.tokenizer_cls is None:
+                raise ValueError("Tokenizer class is None")
+            tokenizer = cls.tokenizer_cls.from_preset(preset)
+            return cls(tokenizer=tokenizer, **kwargs)
+
         if cls == Preprocessor:
             raise ValueError(
                 "Do not call `Preprocessor.from_preset()` directly. Instead call a "
5 changes: 1 addition & 4 deletions keras_nlp/src/models/preprocessor_test.py
@@ -27,7 +27,6 @@
     RobertaPreprocessor,
 )
 from keras_nlp.src.tests.test_case import TestCase
-from keras_nlp.src.utils.preset_utils import METADATA_FILE
 from keras_nlp.src.utils.preset_utils import PREPROCESSOR_CONFIG_FILE
 from keras_nlp.src.utils.preset_utils import TOKENIZER_ASSET_DIR
 from keras_nlp.src.utils.preset_utils import check_config_class
@@ -67,9 +66,7 @@ def test_from_preset_errors(self):
         with self.assertRaises(ValueError):
             # No loading on an incorrect class.
             BertPreprocessor.from_preset("gpt2_base_en")
-        with self.assertRaisesRegex(
-            FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`"
-        ):
+        with self.assertRaises(ValueError):
             # No loading on a non-keras model.
             Preprocessor.from_preset("hf://google-bert/bert-base-uncased")

14 changes: 12 additions & 2 deletions keras_nlp/src/models/task.py
@@ -28,13 +28,13 @@
 from keras_nlp.src.utils.preset_utils import TASK_WEIGHTS_FILE
 from keras_nlp.src.utils.preset_utils import check_config_class
 from keras_nlp.src.utils.preset_utils import check_file_exists
+from keras_nlp.src.utils.preset_utils import check_format
 from keras_nlp.src.utils.preset_utils import get_file
 from keras_nlp.src.utils.preset_utils import jax_memory_cleanup
 from keras_nlp.src.utils.preset_utils import list_presets
 from keras_nlp.src.utils.preset_utils import list_subclasses
 from keras_nlp.src.utils.preset_utils import load_serialized_object
 from keras_nlp.src.utils.preset_utils import save_serialized_object
-from keras_nlp.src.utils.preset_utils import validate_metadata
 from keras_nlp.src.utils.python_utils import classproperty


@@ -187,7 +187,17 @@ def from_preset(
         )
         ```
         """
-        validate_metadata(preset)
+        format = check_format(preset)
+
+        if format == "transformers":
+            if cls.backbone_cls is None:
+                raise ValueError("Backbone class is None")
+            if cls.preprocessor_cls is None:
+                raise ValueError("Preprocessor class is None")
+
+            backbone = cls.backbone_cls.from_preset(preset)
+            preprocessor = cls.preprocessor_cls.from_preset(preset)
+            return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)

         if cls == Task:
             raise ValueError(
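A sketch of what the new `transformers` branch above does for a task: it loads the backbone and the preprocessor from the same preset, then wires both into the task class (the handle is illustrative):

```python
import keras_nlp

# Equivalent, per the diff above, to:
#   backbone = cls.backbone_cls.from_preset(preset)
#   preprocessor = cls.preprocessor_cls.from_preset(preset)
#   cls(backbone=backbone, preprocessor=preprocessor)
causal_lm = keras_nlp.models.GemmaCausalLM.from_preset("hf://google/gemma-2b")
```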
4 changes: 1 addition & 3 deletions keras_nlp/src/models/task_test.py
@@ -79,9 +79,7 @@ def test_from_preset_errors(self):
         with self.assertRaises(ValueError):
             # No loading on an incorrect class.
             BertClassifier.from_preset("gpt2_base_en", load_weights=False)
-        with self.assertRaisesRegex(
-            FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`"
-        ):
+        with self.assertRaises(ValueError):
             # No loading on a non-keras model.
             CausalLM.from_preset("hf://google-bert/bert-base-uncased")

8 changes: 6 additions & 2 deletions keras_nlp/src/tokenizers/tokenizer.py
@@ -20,14 +20,15 @@
 from keras_nlp.src.utils.preset_utils import TOKENIZER_ASSET_DIR
 from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
 from keras_nlp.src.utils.preset_utils import check_config_class
+from keras_nlp.src.utils.preset_utils import check_format
 from keras_nlp.src.utils.preset_utils import get_file
 from keras_nlp.src.utils.preset_utils import list_presets
 from keras_nlp.src.utils.preset_utils import list_subclasses
 from keras_nlp.src.utils.preset_utils import load_serialized_object
 from keras_nlp.src.utils.preset_utils import save_serialized_object
 from keras_nlp.src.utils.preset_utils import save_tokenizer_assets
-from keras_nlp.src.utils.preset_utils import validate_metadata
 from keras_nlp.src.utils.python_utils import classproperty
+from keras_nlp.src.utils.transformers.convert import load_transformers_tokenizer


 @keras_nlp_export(
@@ -215,7 +216,10 @@ class like `keras_nlp.models.Tokenizer.from_preset()`, or from
         tokenizer.detokenize([5, 6, 7, 8, 9])
         ```
         """
-        validate_metadata(preset)
+        format = check_format(preset)
+        if format == "transformers":
+            return load_transformers_tokenizer(cls, preset)
+
         preset_cls = check_config_class(
             preset, config_file=TOKENIZER_CONFIG_FILE
         )
5 changes: 1 addition & 4 deletions keras_nlp/src/tokenizers/tokenizer_test.py
@@ -31,7 +31,6 @@
 from keras_nlp.src.models.roberta.roberta_tokenizer import RobertaTokenizer
 from keras_nlp.src.tests.test_case import TestCase
 from keras_nlp.src.tokenizers.tokenizer import Tokenizer
-from keras_nlp.src.utils.preset_utils import METADATA_FILE
 from keras_nlp.src.utils.preset_utils import TOKENIZER_ASSET_DIR
 from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
 from keras_nlp.src.utils.preset_utils import check_config_class
@@ -70,9 +69,7 @@ def test_from_preset(self):
     def test_from_preset_errors(self):
         with self.assertRaises(ValueError):
             GPT2Tokenizer.from_preset("bert_tiny_en_uncased")
-        with self.assertRaisesRegex(
-            FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`"
-        ):
+        with self.assertRaises(ValueError):
             # No loading on a non-keras model.
             Tokenizer.from_preset("hf://google-bert/bert-base-uncased")

19 changes: 14 additions & 5 deletions keras_nlp/src/utils/preset_utils.py
@@ -59,16 +59,19 @@

 # Config file names.
 CONFIG_FILE = "config.json"
+HF_CONFIG_FILE = "config.json"
 TOKENIZER_CONFIG_FILE = "tokenizer.json"
 TASK_CONFIG_FILE = "task.json"
 PREPROCESSOR_CONFIG_FILE = "preprocessor.json"
 METADATA_FILE = "metadata.json"
+SAFETENSOR_CONFIG_FILE = "model.safetensors.index.json"

 README_FILE = "README.md"

 # Weight file names.
 MODEL_WEIGHTS_FILE = "model.weights.h5"
 TASK_WEIGHTS_FILE = "task.weights.h5"
+SAFETENSOR_FILE = "model.safetensors"

 # Global state for preset registry.
 BUILTIN_PRESETS = {}
@@ -324,7 +327,7 @@ def _validate_tokenizer(preset, allow_incomplete=False):
         )
     config_path = get_file(preset, TOKENIZER_CONFIG_FILE)
     try:
-        with open(config_path) as config_file:
+        with open(config_path, encoding="utf-8") as config_file:
             config = json.load(config_file)
     except Exception as e:
         raise ValueError(
@@ -357,7 +360,7 @@ def _validate_backbone(preset):
             f"`{CONFIG_FILE}` is missing from the preset directory `{preset}`."
         )
     try:
-        with open(config_path) as config_file:
+        with open(config_path, encoding="utf-8") as config_file:
             json.load(config_file)
     except Exception as e:
         raise ValueError(
@@ -530,12 +533,17 @@ def upload_preset(

 def load_config(preset, config_file=CONFIG_FILE):
     config_path = get_file(preset, config_file)
-    with open(config_path) as config_file:
+    with open(config_path, encoding="utf-8") as config_file:
         config = json.load(config_file)
     return config


-def validate_metadata(preset):
+def check_format(preset):
+    if check_file_exists(preset, SAFETENSOR_FILE) or check_file_exists(
+        preset, SAFETENSOR_CONFIG_FILE
+    ):
+        return "transformers"
+
     if not check_file_exists(preset, METADATA_FILE):
         raise FileNotFoundError(
             f"The preset directory `{preset}` doesn't have a file named `{METADATA_FILE}`, "
@@ -548,6 +556,7 @@ def validate_metadata(preset):
             f"`{METADATA_FILE}` in the preset directory `{preset}` doesn't have `keras_version`. "
             "Please verify that the model you are trying to load is a Keras model."
         )
+    return "keras"


def load_serialized_object(
Expand All @@ -566,7 +575,7 @@ def check_config_class(
):
"""Validate a preset is being loaded on the correct class."""
config_path = get_file(preset, config_file)
with open(config_path) as config_file:
with open(config_path, encoding="utf-8") as config_file:
config = json.load(config_file)
return keras.saving.get_registered_object(config["registered_name"])

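The dispatch contract introduced by `check_format`, sketched with illustrative handles (the return values are grounded in the diff above):

```python
from keras_nlp.src.utils.preset_utils import check_format

# A preset that ships model.safetensors (or its index) is routed to the
# transformers converters; anything else must carry a valid metadata.json.
fmt = check_format("hf://google/gemma-2b")  # -> "transformers"
fmt = check_format("gemma_2b_en")           # -> "keras"
```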
6 changes: 3 additions & 3 deletions keras_nlp/src/utils/preset_utils_test.py
@@ -26,7 +26,7 @@
 from keras_nlp.src.utils.preset_utils import CONFIG_FILE
 from keras_nlp.src.utils.preset_utils import METADATA_FILE
 from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
-from keras_nlp.src.utils.preset_utils import validate_metadata
+from keras_nlp.src.utils.preset_utils import check_format


 class PresetUtilsTest(TestCase):
@@ -100,7 +100,7 @@ def test_missing_metadata(self):
         with self.assertRaisesRegex(
             FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`"
         ):
-            validate_metadata(preset_dir)
+            check_format(preset_dir)

     def test_incorrect_metadata(self):
         temp_dir = self.get_temp_dir()
@@ -112,4 +112,4 @@ def test_incorrect_metadata(self):
         json.dump(data, f)

         with self.assertRaisesRegex(ValueError, "doesn't have `keras_version`"):
-            validate_metadata(preset_dir)
+            check_format(preset_dir)
13 changes: 13 additions & 0 deletions keras_nlp/src/utils/transformers/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
48 changes: 48 additions & 0 deletions keras_nlp/src/utils/transformers/convert.py
@@ -0,0 +1,48 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert huggingface models to KerasNLP."""
+
+
+from keras_nlp.src.utils.transformers.convert_gemma import load_gemma_backbone
+from keras_nlp.src.utils.transformers.convert_gemma import load_gemma_tokenizer
+from keras_nlp.src.utils.transformers.convert_llama3 import load_llama3_backbone
+from keras_nlp.src.utils.transformers.convert_llama3 import (
+    load_llama3_tokenizer,
+)
+
+
+def load_transformers_backbone(cls, preset, load_weights):
+    if cls is None:
+        raise ValueError("Backbone class is None")
+    if cls.__name__ == "GemmaBackbone":
+        return load_gemma_backbone(cls, preset, load_weights)
+    if cls.__name__ == "Llama3Backbone":
+        return load_llama3_backbone(cls, preset, load_weights)
+    raise ValueError(
+        f"{cls} has not been ported from the Hugging Face format yet. "
+        "Please check Hugging Face Hub for the Keras model. "
+    )
+
+
+def load_transformers_tokenizer(cls, preset):
+    if cls is None:
+        raise ValueError("Tokenizer class is None")
+    if cls.__name__ == "GemmaTokenizer":
+        return load_gemma_tokenizer(cls, preset)
+    if cls.__name__ == "Llama3Tokenizer":
+        return load_llama3_tokenizer(cls, preset)
+    raise ValueError(
+        f"{cls} has not been ported from the Hugging Face format yet. "
+        "Please check Hugging Face Hub for the Keras model. "
+    )
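A usage sketch of the converter's name-based dispatch (the backbone import path is assumed from the repository layout):

```python
from keras_nlp.src.models.gemma.gemma_backbone import GemmaBackbone
from keras_nlp.src.utils.transformers.convert import load_transformers_backbone

# Dispatches on cls.__name__; an architecture without a converter raises a
# ValueError pointing the user back to the Hugging Face Hub.
backbone = load_transformers_backbone(
    GemmaBackbone, "hf://google/gemma-2b", load_weights=True
)
```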