
Commit ba5913a

Add a preprocessor for the Llama backbone
1 parent a59a26f commit ba5913a

2 files changed (+248 -0 lines changed)
Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker
from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.utils.keras_utils import (
    convert_inputs_to_list_of_tensor_segments,
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
from keras_nlp.utils.python_utils import classproperty


@keras_nlp_export("keras_nlp.models.LlamaPreprocessor")
class LlamaPreprocessor(Preprocessor):
    """A Llama preprocessing layer which tokenizes and packs inputs.

    This preprocessing layer will do three things:

    1. Tokenize any number of input segments using the `tokenizer`.
    2. Pack the inputs together using a `keras_nlp.layers.StartEndPacker`
       with the appropriate tokens.
    3. Construct a dictionary with keys `"token_ids"` and `"padding_mask"`
       that can be passed directly to `keras_nlp.models.LlamaBackbone`.

    This layer can be used directly with `tf.data.Dataset.map` to preprocess
    string data in the `(x, y, sample_weight)` format used by
    `keras.Model.fit`.

    Args:
        tokenizer: A `keras_nlp.models.LlamaTokenizer` instance.
        sequence_length: The length of the packed inputs.
        add_start_token: If `True`, the preprocessor will prepend the tokenizer
            start token to each input sequence. Default is `True`.
        add_end_token: If `True`, the preprocessor will append the tokenizer
            end token to each input sequence. Default is `False`.

    Call arguments:
        x: A tensor of single string sequences, or a tuple of multiple
            tensor sequences to be packed together. Inputs may be batched or
            unbatched. For single sequences, raw python inputs will be converted
            to tensors. For multiple sequences, pass tensors directly.
        y: Any label data. Will be passed through unaltered.
        sample_weight: Any label weight data. Will be passed through unaltered.
        sequence_length: Pass to override the configured `sequence_length` of
            the layer.

    Examples:

    Directly calling the layer on data.
    ```python
    preprocessor = keras_nlp.models.LlamaPreprocessor.from_preset(
        "llama_base_en"
    )

    # Tokenize and pack a single sentence.
    preprocessor("The quick brown fox jumped.")

    # Tokenize and pack a batch of single sentences.
preprocessor(["The quick brown fox jumped.", "Call me Ishmael."])
72+
73+
# Preprocess a batch of sentence pairs.
74+
# When handling multiple sequences, always convert to tensors first!
75+
first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."])
76+
second = tf.constant(["The fox tripped.", "Oh look, a whale."])
77+
preprocessor((first, second))
78+
```
79+
80+
Mapping with `tf.data.Dataset`.
81+
```python
82+
preprocessor = keras_nlp.models.LlamaPreprocessor.from_preset(
83+
"llama_base_en"
84+
)
85+
first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."])
86+
second = tf.constant(["The fox tripped.", "Oh look, a whale."])
87+
label = tf.constant([1, 1])
88+
89+
# Map labeled single sentences.
90+
ds = tf.data.Dataset.from_tensor_slices((first, label))
91+
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
92+
93+
# Map unlabeled single sentences.
94+
ds = tf.data.Dataset.from_tensor_slices(first)
95+
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
96+
97+
# Map labeled sentence pairs.
98+
ds = tf.data.Dataset.from_tensor_slices(((first, second), label))
99+
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
100+
101+
# Map unlabeled sentence pairs.
102+
ds = tf.data.Dataset.from_tensor_slices((first, second))
103+
104+
# Watch out for tf.data's default unpacking of tuples here!
105+
# Best to invoke the `preprocessor` directly in this case.
106+
ds = ds.map(
107+
lambda first, second: preprocessor(x=(first, second)),
108+
num_parallel_calls=tf.data.AUTOTUNE,
109+
)
110+
```
111+
"""
112+
113+
def __init__(
114+
self,
115+
tokenizer,
116+
sequence_length=1024,
117+
add_start_token=True,
118+
add_end_token=False,
119+
**kwargs,
120+
):
121+
super().__init__(**kwargs)
122+
self.tokenizer = tokenizer
123+
self.packer = None
124+
self.add_start_token = add_start_token
125+
self.add_end_token = add_end_token
126+
self.sequence_length = sequence_length
127+
128+
def build(self, input_shape):
129+
# Defer packer creation to `build()` so that we can be sure tokenizer
130+
# assets have loaded when restoring a saved model.
131+
self.packer = StartEndPacker(
132+
start_value=self.tokenizer.start_token_id,
133+
end_value=self.tokenizer.end_token_id,
134+
sequence_length=self.sequence_length,
135+
return_padding_mask=True,
136+
)
137+
self.built = True
138+
139+
def get_config(self):
140+
config = super().get_config()
141+
config.update(
142+
{
143+
"sequence_length": self.sequence_length,
144+
"add_start_token": self.add_start_token,
145+
"add_end_token": self.add_end_token,
146+
}
147+
)
148+
return config
149+
150+
def call(
151+
self,
152+
x,
153+
y=None,
154+
sample_weight=None,
155+
sequence_length=None,
156+
):
157+
x = convert_inputs_to_list_of_tensor_segments(x)
158+
if len(x) != 1:
159+
raise ValueError(
160+
"Llama requires each input feature to contain only "
161+
f"one segment, but received {len(x)}. If you are using Llama"
162+
" for a multi-segment classification task, please refer to "
163+
"classification models like BERT or RoBERTa."
164+
)
165+
sequence_length = sequence_length or self.sequence_length
166+
token_ids, padding_mask = self.packer(
167+
self.tokenizer(x[0]),
168+
sequence_length=sequence_length,
169+
add_start_value=self.add_start_token,
170+
add_end_value=self.add_end_token,
171+
)
172+
x = {
173+
"token_ids": token_ids,
174+
"padding_mask": padding_mask,
175+
}
176+
return pack_x_y_sample_weight(x, y, sample_weight)
177+
178+
@property
179+
def sequence_length(self):
180+
"""The padded length of model input sequences."""
181+
return self._sequence_length
182+
183+
@sequence_length.setter
184+
def sequence_length(self, value):
185+
self._sequence_length = value
186+
if self.packer is not None:
187+
self.packer.sequence_length = value
188+
189+
@classproperty
190+
def tokenizer_cls(cls):
191+
return LlamaTokenizer
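
To see how the pieces above fit together, here is a minimal usage sketch. It assumes a SentencePiece proto at the hypothetical path `llama_vocab.spm` (the test file below loads a small one generated by create_llama_test_proto.py); everything else uses only the constructor arguments and output keys defined in this diff.

```python
# A minimal sketch, assuming a SentencePiece proto at the hypothetical
# path "llama_vocab.spm"; presets ship their own vocabulary assets.
from keras_nlp.models.llama.llama_preprocessor import LlamaPreprocessor
from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer

tokenizer = LlamaTokenizer(proto="llama_vocab.spm")
preprocessor = LlamaPreprocessor(
    tokenizer=tokenizer,
    sequence_length=8,     # pad or truncate every example to 8 tokens
    add_start_token=True,  # prepend tokenizer.start_token_id
    add_end_token=False,   # do not append tokenizer.end_token_id
)

# With no labels, the layer returns just the feature dict that
# `keras_nlp.models.LlamaBackbone` expects; labels and sample weights,
# when present, are passed through by `pack_x_y_sample_weight`.
features = preprocessor(["the quick brown fox"])
print(features["token_ids"].shape)     # (1, 8)
print(features["padding_mask"].shape)  # (1, 8)
```

Deferring the `StartEndPacker` construction to `build()` means `start_token_id` and `end_token_id` are only read once the tokenizer's assets have loaded, which keeps saved-model restoration working.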
Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from keras_nlp.models.llama.llama_preprocessor import LlamaPreprocessor
from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer
from keras_nlp.tests.test_case import TestCase


class LlamaPreprocessorTest(TestCase):
    def setUp(self):
        self.tokenizer = LlamaTokenizer(
            # Generated using create_llama_test_proto.py
            proto=os.path.join(self.get_test_data_dir(), "llama_test_vocab.spm")
        )
        self.init_kwargs = {
            "tokenizer": self.tokenizer,
            "sequence_length": 8,
        }
        self.input_data = (
            ["the quick brown fox"],
            [1],  # Pass through labels.
            [1.0],  # Pass through sample_weights.
        )

    def test_preprocessor_basics(self):
        self.run_preprocessor_test(
            cls=LlamaPreprocessor,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output=(
                {
                    "token_ids": [[1, 3, 8, 4, 6, 0, 0, 0]],
                    "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]],
                },
                [1],  # Pass through labels.
                [1.0],  # Pass through sample_weights.
            ),
        )

    def test_errors_for_2d_list_input(self):
        preprocessor = LlamaPreprocessor(**self.init_kwargs)
        ambiguous_input = [["one", "two"], ["three", "four"]]
        with self.assertRaises(ValueError):
            preprocessor(ambiguous_input)
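
A quick note on the `expected_output` above: with `sequence_length=8` and the default `add_start_token=True`, the packed ids are the start token followed by one id per word and zero padding. A rough sketch of that arithmetic, assuming the tiny test proto maps each of the four words to a single id and reserves id 1 for the start token and id 0 for padding:

```python
# Reconstructing the expected test output by hand (the word ids are assumed
# to come from the test proto; only the packing logic is the point here).
sequence_length = 8
word_ids = [3, 8, 4, 6]              # "the quick brown fox" -> four ids (assumed)
token_ids = [1] + word_ids           # add_start_token=True prepends the start id
padding_mask = [1] * len(token_ids)  # real tokens are marked with 1
token_ids += [0] * (sequence_length - len(token_ids))        # pad to length 8
padding_mask += [0] * (sequence_length - len(padding_mask))  # padding marked 0

assert token_ids == [1, 3, 8, 4, 6, 0, 0, 0]
assert padding_mask == [1, 1, 1, 1, 1, 0, 0, 0]
```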
