Commit fe5a53b

Add Falcon Preprocessor. (#1498)
1 parent 7ef18a1 · commit fe5a53b

4 files changed: +547 −0 lines changed
keras_nlp/models/falcon/falcon_causal_lm_preprocessor.py
Lines changed: 178 additions & 0 deletions
@@ -0,0 +1,178 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from absl import logging

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import ops
from keras_nlp.models.falcon.falcon_preprocessor import FalconPreprocessor
from keras_nlp.utils.keras_utils import (
    convert_inputs_to_list_of_tensor_segments,
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@keras_nlp_export("keras_nlp.models.FalconCausalLMPreprocessor")
class FalconCausalLMPreprocessor(FalconPreprocessor):
    """Falcon Causal LM preprocessor.

    This preprocessing layer is meant for use with
    `keras_nlp.models.FalconCausalLM`. By default, it will take in batches of
    strings, and return outputs in a `(x, y, sample_weight)` format, where the
    `y` label is the next token id in the `x` sequence.

    For use with generation, the layer also exposes two methods,
    `generate_preprocess()` and `generate_postprocess()`. When this
    preprocessor is attached to a `keras_nlp.models.FalconCausalLM` instance,
    these methods will be called implicitly in `generate()`. They can also be
    called standalone (e.g. to precompute preprocessing inputs for generation
    in a separate process).

    Args:
        tokenizer: A `keras_nlp.models.FalconTokenizer` instance.
        sequence_length: The length of the packed inputs.
        add_start_token: If `True`, the preprocessor will prepend the tokenizer
            start token to each input sequence.
        add_end_token: If `True`, the preprocessor will append the tokenizer
            end token to each input sequence.

    Call arguments:
        x: A string, `tf.Tensor` or list of Python strings.
        y: Label data. Should always be `None` as the layer generates labels.
        sample_weight: Label weights. Should always be `None` as the layer
            generates label weights.
        sequence_length: Pass to override the configured `sequence_length` of
            the layer.

    Examples:
    ```python
    # Load the preprocessor from a preset.
    preprocessor = keras_nlp.models.FalconCausalLMPreprocessor.from_preset(
        "falcon_refinedweb_1b_en"
    )

    # Tokenize and pack a single sentence.
    sentence = tf.constant("League of legends")
    preprocessor(sentence)
    # Same output.
    preprocessor("League of legends")

    # Tokenize a batch of sentences.
    sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
    preprocessor(sentences)
    # Same output.
    preprocessor(["Taco tuesday", "Fish taco please!"])

    # Map a dataset to preprocess a single sentence.
    features = tf.constant(
        [
            "Avatar 2 is amazing!",
            "Well, I am not sure.",
        ]
    )
    labels = tf.constant([1, 0])
    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

    # Map a dataset to preprocess unlabeled sentences.
    ds = tf.data.Dataset.from_tensor_slices(features)
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
    ```
    """

    def call(
        self,
        x,
        y=None,
        sample_weight=None,
        sequence_length=None,
    ):
        if y is not None or sample_weight is not None:
            logging.warning(
                "`FalconCausalLMPreprocessor` generates `y` and `sample_weight` "
                "based on your input data, but your data already contains `y` "
                "or `sample_weight`. Your `y` and `sample_weight` will be "
                "ignored."
            )
        sequence_length = sequence_length or self.sequence_length

        x = convert_inputs_to_list_of_tensor_segments(x)[0]
        x = self.tokenizer(x)
        # Pad with one extra token to account for the truncation below.
        token_ids, padding_mask = self.packer(
            x,
            sequence_length=sequence_length + 1,
            add_start_value=self.add_start_token,
            add_end_value=self.add_end_token,
        )
        # The last token does not have a next token, so we truncate it out.
        x = {
            "token_ids": token_ids[..., :-1],
            "padding_mask": padding_mask[..., :-1],
        }
        # Target `y` will be the next token.
        y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
        return pack_x_y_sample_weight(x, y, sample_weight)

    def generate_preprocess(
        self,
        x,
        sequence_length=None,
    ):
        """Convert strings to integer token input for generation.

        Similar to calling the layer for training, this method takes in strings
        or tensor strings, tokenizes and packs the input, and computes a
        padding mask that marks every position not filled with a pad value.

        Unlike calling the layer for training, this method does not compute
        labels and will never append a `tokenizer.end_token_id` to the end of
        the sequence (as generation is expected to continue at the end of the
        input prompt).
        """
        if not self.built:
            self.build(None)

        x = convert_inputs_to_list_of_tensor_segments(x)[0]
        x = self.tokenizer(x)
        token_ids, padding_mask = self.packer(
            x, sequence_length=sequence_length, add_end_value=False
        )
        return {
            "token_ids": token_ids,
            "padding_mask": padding_mask,
        }

    def generate_postprocess(
        self,
        x,
    ):
        """Convert integer token output to strings for generation.

        This method reverses `generate_preprocess()`, by first removing all
        padding and start/end tokens, and then converting the integer sequence
        back to a string.
        """
        if not self.built:
            self.build(None)

        token_ids, padding_mask = x["token_ids"], x["padding_mask"]
        token_ids = ops.convert_to_numpy(token_ids)
        padding_mask = ops.convert_to_numpy(padding_mask)
        # Strip any special tokens during detokenization (e.g. the start and
        # end markers). In the future we could make this configurable.
        padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
        token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
        return self.tokenizer.detokenize(token_ids)
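The label construction in `call()` above is just a one-position shift: the packer emits `sequence_length + 1` tokens, the inputs keep all but the last, and the labels keep all but the first. Below is a minimal NumPy sketch of that shift, using the packed ids from the test file further down (start/end token id 6, pad id 0); the array names here are illustrative only, not part of the API.

```python
import numpy as np

# Packed to sequence_length + 1 = 9, as the packer would produce for
# "airplane at airport" with start/end token id 6 and pad id 0.
token_ids = np.array([[6, 1, 3, 4, 2, 5, 6, 0, 0]])
padding_mask = np.array([[1, 1, 1, 1, 1, 1, 1, 0, 0]])

# Inputs drop the last position: it has no next token to predict.
x = {
    "token_ids": token_ids[..., :-1],        # [[6, 1, 3, 4, 2, 5, 6, 0]]
    "padding_mask": padding_mask[..., :-1],  # [[1, 1, 1, 1, 1, 1, 1, 0]]
}
# Labels are the same ids shifted left by one, so y[i] is the token the
# model should predict after reading x["token_ids"][: i + 1].
y = token_ids[..., 1:]                       # [[1, 3, 4, 2, 5, 6, 0, 0]]
sample_weight = padding_mask[..., 1:]        # [[1, 1, 1, 1, 1, 1, 0, 0]]
```

Note that `generate_preprocess()` skips this shift entirely and passes `add_end_value=False`, since generation should continue from the end of the prompt rather than stop at an end token; `generate_postprocess()` then masks out pad and end tokens before detokenizing.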
keras_nlp/models/falcon/falcon_causal_lm_preprocessor_test.py
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from keras_nlp.models.falcon.falcon_causal_lm_preprocessor import (
    FalconCausalLMPreprocessor,
)
from keras_nlp.models.falcon.falcon_tokenizer import FalconTokenizer
from keras_nlp.tests.test_case import TestCase


class FalconCausalLMPreprocessorTest(TestCase):
    def setUp(self):
        self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
        self.vocab += ["<|endoftext|>"]
        self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
        self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
        self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
        self.merges += ["Ġai r", "Ġa i", "pla ne"]
        self.tokenizer = FalconTokenizer(
            vocabulary=self.vocab,
            merges=self.merges,
        )
        self.init_kwargs = {
            "tokenizer": self.tokenizer,
            "sequence_length": 8,
        }
        self.input_data = ["airplane at airport"]

    def test_preprocessor_basics(self):
        self.run_preprocessor_test(
            cls=FalconCausalLMPreprocessor,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output=(
                {
                    "token_ids": [[6, 1, 3, 4, 2, 5, 6, 0]],
                    "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]],
                },
                [[1, 3, 4, 2, 5, 6, 0, 0]],  # Generated labels (next tokens).
                [[1, 1, 1, 1, 1, 1, 0, 0]],  # Generated sample_weights.
            ),
        )

    def test_no_start_end_token(self):
        input_data = ["airplane at airport"] * 4

        preprocessor = FalconCausalLMPreprocessor(
            **self.init_kwargs,
            add_start_token=False,
            add_end_token=False,
        )
        x, y, sw = preprocessor(input_data)
        self.assertAllEqual(x["token_ids"], [[1, 3, 4, 2, 5, 0, 0, 0]] * 4)
        self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4)
        self.assertAllEqual(y, [[3, 4, 2, 5, 0, 0, 0, 0]] * 4)
        self.assertAllEqual(sw, [[1, 1, 1, 1, 0, 0, 0, 0]] * 4)

    def test_generate_preprocess(self):
        input_data = "airplane at airport"
        preprocessor = FalconCausalLMPreprocessor(**self.init_kwargs)
        x = preprocessor.generate_preprocess(input_data)
        self.assertAllEqual(x["token_ids"], [6, 1, 3, 4, 2, 5, 0, 0])
        self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0])

    def test_generate_postprocess(self):
        input_data = {
            "token_ids": [6, 1, 3, 4, 2, 5, 0, 0],
            "padding_mask": [1, 1, 1, 1, 1, 1, 0, 0],
        }
        preprocessor = FalconCausalLMPreprocessor(**self.init_kwargs)
        x = preprocessor.generate_postprocess(input_data)
        self.assertAllEqual(x, "airplane at airport")

    @pytest.mark.extra_large
    def test_all_presets(self):
        for preset in FalconCausalLMPreprocessor.presets:
            self.run_preset_test(
                cls=FalconCausalLMPreprocessor,
                preset=preset,
                input_data=self.input_data,
            )
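The expected ids in these tests follow directly from the toy vocabulary built in `setUp()`, where a token's id is its index in the list. A quick sanity check, assuming only what the test itself defines:

```python
# Toy vocab from setUp(); index in the list is the token id.
vocab = ["!", "air", "Ġair", "plane", "Ġat", "port", "<|endoftext|>"]
ids = {token: i for i, token in enumerate(vocab)}

# Under the test merges, "airplane at airport" splits into the BPE
# pieces ["air", "plane", "Ġat", "Ġair", "port"].
print([ids[t] for t in ["air", "plane", "Ġat", "Ġair", "port"]])
# [1, 3, 4, 2, 5]

# With <|endoftext|> (id 6) used as both start and end token and pad
# id 0, packing to sequence_length=8 gives [6, 1, 3, 4, 2, 5, 6, 0],
# the `token_ids` asserted in test_preprocessor_basics.
```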
