
Commit 0738f3e

Add causal lm preprocessor for the Llama backbone
1 parent: ba5913a

2 files changed: 275 additions, 0 deletions
keras_nlp/models/llama/llama_causal_lm_preprocessor.py
Lines changed: 185 additions & 0 deletions

@@ -0,0 +1,185 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from absl import logging

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import ops
from keras_nlp.models.llama.llama_preprocessor import LlamaPreprocessor
from keras_nlp.utils.keras_utils import (
    convert_inputs_to_list_of_tensor_segments,
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@keras_nlp_export("keras_nlp.models.LlamaCausalLMPreprocessor")
class LlamaCausalLMPreprocessor(LlamaPreprocessor):
    """Llama Causal LM preprocessor.

    This preprocessing layer is meant for use with
    `keras_nlp.models.LlamaCausalLM`. By default, it will take in batches of
    strings, and return outputs in a `(x, y, sample_weight)` format, where the
    `y` label is the next token id in the `x` sequence.

    For use with generation, the layer also exposes two methods
    `generate_preprocess()` and `generate_postprocess()`. When this preprocessor
    is attached to a `keras_nlp.models.LlamaCausalLM` instance, these methods
    will be called implicitly in `generate()`. They can also be called
    standalone (e.g. to precompute preprocessing inputs for generation in a
    separate process).

    Args:
        tokenizer: A `keras_nlp.models.LlamaTokenizer` instance.
        sequence_length: The length of the packed inputs.
        add_start_token: If `True`, the preprocessor will prepend the tokenizer
            start token to each input sequence. Default is `True`.
        add_end_token: If `True`, the preprocessor will append the tokenizer
            end token to each input sequence. Default is `False`.

    Call arguments:
        x: A string, `tf.Tensor` or list of python strings.
        y: Label data. Should always be `None` as the layer generates labels.
        sample_weight: Label weights. Should always be `None` as the layer
            generates label weights.
        sequence_length: Pass to override the configured `sequence_length` of
            the layer.

    Examples:
    ```python
    # Load the preprocessor from a preset.
    preprocessor = keras_nlp.models.LlamaCausalLMPreprocessor.from_preset(
        "llama_base_en"
    )

    # Tokenize and pack a single sentence.
    sentence = tf.constant("League of legends")
    preprocessor(sentence)
    # Same output.
    preprocessor("League of legends")

    # Tokenize a batch of sentences.
    sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
    preprocessor(sentences)
    # Same output.
    preprocessor(["Taco tuesday", "Fish taco please!"])

    # Map a dataset to preprocess a single sentence.
    features = tf.constant(
        [
            "Avatar 2 is amazing!",
            "Well, I am not sure.",
        ]
    )
    labels = tf.constant([1, 0])
    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

    # Map a dataset to preprocess unlabeled sentences.
    ds = tf.data.Dataset.from_tensor_slices(features)
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
    ```
    """

    def call(
        self,
        x,
        y=None,
        sample_weight=None,
        sequence_length=None,
    ):
        if y is not None or sample_weight is not None:
            logging.warning(
                "`LlamaCausalLMPreprocessor` generates `y` and "
                "`sample_weight` based on your input data, but your data "
                "already contains `y` or `sample_weight`. Your `y` and "
                "`sample_weight` will be ignored."
            )
        sequence_length = sequence_length or self.sequence_length

        x = convert_inputs_to_list_of_tensor_segments(x)[0]
        x = self.tokenizer(x)
        # Pad with one extra token to account for the truncation below.
        token_ids, padding_mask = self.packer(
            x,
            sequence_length=sequence_length + 1,
            add_start_value=self.add_start_token,
            add_end_value=self.add_end_token,
        )
        # The last token does not have a next token, so we truncate it out.
        x = {
            "token_ids": token_ids[..., :-1],
            "padding_mask": padding_mask[..., :-1],
        }
        # Target `y` will be the next token.
        y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
        return pack_x_y_sample_weight(x, y, sample_weight)

    def generate_preprocess(
        self,
        x,
        sequence_length=None,
    ):
        """Convert strings to integer token input for generation.

        Similar to calling the layer for training, this method takes in strings
        or tensor strings, tokenizes and packs the input, and computes a padding
        mask masking all inputs not filled in with a padded value.

        Unlike calling the layer for training, this method does not compute
        labels and will never append a `tokenizer.end_token_id` to the end of
        the sequence (as generation is expected to continue at the end of the
        inputted prompt).
        """
        if not self.built:
            self.build(None)

        x = convert_inputs_to_list_of_tensor_segments(x)[0]
        x = self.tokenizer(x)
        token_ids, padding_mask = self.packer(
            x, sequence_length=sequence_length, add_end_value=False
        )
        return {
            "token_ids": token_ids,
            "padding_mask": padding_mask,
        }

    def generate_postprocess(
        self,
        x,
    ):
        """Convert integer token output to strings for generation.

        This method reverses `generate_preprocess()`, by first removing all
        padding and start/end tokens, and then converting the integer sequence
        back to a string.
        """
        token_ids, padding_mask = x["token_ids"], x["padding_mask"]
        # Convert the inputs to numpy arrays if they aren't a tensor already.
        if not isinstance(token_ids, tf.Tensor):
            token_ids = ops.convert_to_numpy(token_ids)
            # Make sure the numpy array has type `int32` since
            # `SentencePieceProcessor.detokenize` only accepts `int32` arrays.
            token_ids = token_ids.astype("int32")
        if not isinstance(padding_mask, tf.Tensor):
            padding_mask = ops.convert_to_numpy(padding_mask)
            padding_mask = padding_mask.astype("bool")
        # Strip any special tokens during detokenization (e.g. the start and
        # end markers). In the future we could make this configurable.
        padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
        padding_mask = padding_mask & (
            token_ids != self.tokenizer.start_token_id
        )
        token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
        return self.tokenizer.detokenize(token_ids)
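
To make the label construction in `call()` concrete, here is a minimal sketch (not part of the commit) of what the preprocessor returns for a single prompt. It assumes the small `llama_test_vocab.spm` test vocabulary used by the unit tests added in this commit, and the commented values mirror the expectations asserted in those tests.

```python
# Hedged sketch: the proto path below points at this commit's test vocabulary
# and would need to resolve on disk; the token ids are the values asserted in
# the unit tests further down.
from keras_nlp.models.llama.llama_causal_lm_preprocessor import (
    LlamaCausalLMPreprocessor,
)
from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer

tokenizer = LlamaTokenizer(proto="llama_test_vocab.spm")
preprocessor = LlamaCausalLMPreprocessor(tokenizer=tokenizer, sequence_length=8)

x, y, sample_weight = preprocessor(["the quick brown fox"])
# x["token_ids"]    -> [[1, 3, 8, 4, 6, 0, 0, 0]]  start token + prompt + padding
# x["padding_mask"] -> [[1, 1, 1, 1, 1, 0, 0, 0]]
# y                 -> [[3, 8, 4, 6, 0, 0, 0, 0]]  token ids shifted left by one
# sample_weight     -> [[1, 1, 1, 1, 0, 0, 0, 0]]  padding positions weighted 0
```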
keras_nlp/models/llama/llama_causal_lm_preprocessor_test.py
Lines changed: 90 additions & 0 deletions

@@ -0,0 +1,90 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest

from keras_nlp.models.llama.llama_causal_lm_preprocessor import (
    LlamaCausalLMPreprocessor,
)
from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer
from keras_nlp.tests.test_case import TestCase


class LlamaCausalLMPreprocessorTest(TestCase):
    def setUp(self):
        self.tokenizer = LlamaTokenizer(
            # Generated using create_llama_test_proto.py
            proto=os.path.join(self.get_test_data_dir(), "llama_test_vocab.spm")
        )
        self.init_kwargs = {
            "tokenizer": self.tokenizer,
            "sequence_length": 8,
        }
        self.input_data = (["the quick brown fox"],)

    def test_preprocessor_basics(self):
        self.run_preprocessor_test(
            cls=LlamaCausalLMPreprocessor,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output=(
                {
                    "token_ids": [[1, 3, 8, 4, 6, 0, 0, 0]],
                    "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]],
                },
                [[3, 8, 4, 6, 0, 0, 0, 0]],  # Pass through labels.
                [[1, 1, 1, 1, 0, 0, 0, 0]],  # Pass through sample_weights.
            ),
        )

    def test_no_start_end_token(self):
        input_data = ["the quick brown fox"] * 4

        preprocessor = LlamaCausalLMPreprocessor(
            **self.init_kwargs,
            add_start_token=False,
            add_end_token=False,
        )
        x, y, sw = preprocessor(input_data)
        self.assertAllEqual(x["token_ids"], [[3, 8, 4, 6, 0, 0, 0, 0]] * 4)
        self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4)
        self.assertAllEqual(y, [[8, 4, 6, 0, 0, 0, 0, 0]] * 4)
        self.assertAllEqual(sw, [[1, 1, 1, 0, 0, 0, 0, 0]] * 4)

    def test_generate_preprocess(self):
        input_data = "the quick brown fox"
        preprocessor = LlamaCausalLMPreprocessor(**self.init_kwargs)
        x = preprocessor.generate_preprocess(input_data)
        self.assertAllEqual(x["token_ids"], [1, 3, 8, 4, 6, 0, 0, 0])
        self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0])

    def test_generate_postprocess(self):
        input_data = {
            "token_ids": [1, 3, 8, 4, 6, 0, 0, 0],
            "padding_mask": [1, 1, 1, 1, 1, 0, 0, 0],
        }
        preprocessor = LlamaCausalLMPreprocessor(**self.init_kwargs)
        x = preprocessor.generate_postprocess(input_data)
        self.assertAllEqual(x, "the quick brown fox")

    @pytest.mark.extra_large
    def test_all_presets(self):
        for preset in LlamaCausalLMPreprocessor.presets:
            self.run_preset_test(
                cls=LlamaCausalLMPreprocessor,
                preset=preset,
                input_data=self.input_data,
            )
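
Putting the two generation tests together, a hedged sketch of the standalone round trip described in the class docstring (tokenize a prompt for generation, then decode token ids back to text) could look like the following; it reuses the same test tokenizer and the literal values asserted above.

```python
# Hedged sketch of the generate_preprocess / generate_postprocess round trip;
# `tokenizer` is the same test tokenizer built in setUp above.
preprocessor = LlamaCausalLMPreprocessor(tokenizer=tokenizer, sequence_length=8)

prompt = preprocessor.generate_preprocess("the quick brown fox")
# prompt["token_ids"]    -> [1, 3, 8, 4, 6, 0, 0, 0]  no end token appended
# prompt["padding_mask"] -> [1, 1, 1, 1, 1, 0, 0, 0]

# ... a causal LM would extend prompt["token_ids"] with sampled ids here ...

text = preprocessor.generate_postprocess(prompt)
# -> "the quick brown fox"  start/end/padding tokens are stripped
```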
