# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from absl import logging

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import ops
from keras_nlp.models.llama.llama_preprocessor import LlamaPreprocessor
from keras_nlp.utils.keras_utils import (
    convert_inputs_to_list_of_tensor_segments,
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@keras_nlp_export("keras_nlp.models.LlamaCausalLMPreprocessor")
class LlamaCausalLMPreprocessor(LlamaPreprocessor):
    """Llama Causal LM preprocessor.

    This preprocessing layer is meant for use with
    `keras_nlp.models.LlamaCausalLM`. By default, it will take in batches of
    strings, and return outputs in a `(x, y, sample_weight)` format, where the
    `y` label is the next token id in the `x` sequence.

    For use with generation, the layer also exposes two methods
    `generate_preprocess()` and `generate_postprocess()`. When this
    preprocessor is attached to a `keras_nlp.models.LlamaCausalLM` instance,
    these methods will be called implicitly in `generate()`. They can also be
    called standalone (e.g. to precompute preprocessing inputs for generation
    in a separate process).

    Args:
        tokenizer: A `keras_nlp.models.LlamaTokenizer` instance.
        sequence_length: The length of the packed inputs.
        add_start_token: If `True`, the preprocessor will prepend the tokenizer
            start token to each input sequence. Default is `True`.
        add_end_token: If `True`, the preprocessor will append the tokenizer
            end token to each input sequence. Default is `False`.

    Call arguments:
        x: A string, `tf.Tensor` or list of python strings.
        y: Label data. Should always be `None` as the layer generates labels.
        sample_weight: Label weights. Should always be `None` as the layer
            generates label weights.
        sequence_length: Pass to override the configured `sequence_length` of
            the layer.

    Examples:
    ```python
    # Load the preprocessor from a preset.
    preprocessor = keras_nlp.models.LlamaCausalLMPreprocessor.from_preset(
        "llama_base_en"
    )

    # Tokenize and pack a single sentence.
    sentence = tf.constant("League of legends")
    preprocessor(sentence)
    # Same output.
    preprocessor("League of legends")

    # Tokenize a batch of sentences.
    sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
    preprocessor(sentences)
    # Same output.
    preprocessor(["Taco tuesday", "Fish taco please!"])

    # Map a dataset to preprocess a single sentence.
    features = tf.constant(
        [
            "Avatar 2 is amazing!",
            "Well, I am not sure.",
        ]
    )
    labels = tf.constant([1, 0])
    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

    # Map a dataset to preprocess unlabeled sentences.
    ds = tf.data.Dataset.from_tensor_slices(features)
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
    ```
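
    A generation round trip (a minimal sketch; the `"llama_base_en"` preset
    name is reused from the example above and the prompt text is illustrative):
    ```python
    preprocessor = keras_nlp.models.LlamaCausalLMPreprocessor.from_preset(
        "llama_base_en"
    )
    # Tokenize and pack a prompt. No end token is appended, so generation can
    # continue from the end of the prompt.
    features = preprocessor.generate_preprocess(["The quick brown fox"])
    # `features` is a dict with "token_ids" and "padding_mask" entries.
    # Strip padding and special tokens and convert token ids back to strings.
    preprocessor.generate_postprocess(features)
    ```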
    """

    def call(
        self,
        x,
        y=None,
        sample_weight=None,
        sequence_length=None,
    ):
        if y is not None or sample_weight is not None:
            logging.warning(
                "`LlamaCausalLMPreprocessor` generates `y` and "
                "`sample_weight` based on your input data, but your data "
                "already contains `y` or `sample_weight`. Your `y` and "
                "`sample_weight` will be ignored."
            )
        sequence_length = sequence_length or self.sequence_length

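        # Causal LM preprocessing operates on a single segment of text, so
        # take the first segment of the converted inputs.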
        x = convert_inputs_to_list_of_tensor_segments(x)[0]
        x = self.tokenizer(x)
        # Pad with one extra token to account for the truncation below.
        token_ids, padding_mask = self.packer(
            x,
            sequence_length=sequence_length + 1,
            add_start_value=self.add_start_token,
            add_end_value=self.add_end_token,
        )
        # The last token does not have a next token, so we truncate it out.
        x = {
            "token_ids": token_ids[..., :-1],
            "padding_mask": padding_mask[..., :-1],
        }
        # Target `y` will be the next token.
        y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
        return pack_x_y_sample_weight(x, y, sample_weight)

    def generate_preprocess(
        self,
        x,
        sequence_length=None,
    ):
        """Convert strings to integer token input for generation.

        Similar to calling the layer for training, this method takes in
        strings or tensor strings, tokenizes and packs the input, and computes
        a padding mask marking every position that is filled with a real token
        rather than padding.

        Unlike calling the layer for training, this method does not compute
        labels and will never append a `tokenizer.end_token_id` to the end of
        the sequence (as generation is expected to continue at the end of the
        input prompt).
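
        A minimal sketch of the expected output (reusing `preprocessor` from
        the class docstring above; the prompt text is illustrative):
        ```python
        features = preprocessor.generate_preprocess(["The quick brown fox"])
        features["token_ids"]     # Token ids, padded to the packed length.
        features["padding_mask"]  # Booleans marking the non-padding positions.
        ```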
        """
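        # `generate()` may call this method before the layer has been built;
        # building here ensures the packer exists before it is used below.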
        if not self.built:
            self.build(None)

        x = convert_inputs_to_list_of_tensor_segments(x)[0]
        x = self.tokenizer(x)
        token_ids, padding_mask = self.packer(
            x, sequence_length=sequence_length, add_end_value=False
        )
        return {
            "token_ids": token_ids,
            "padding_mask": padding_mask,
        }

    def generate_postprocess(
        self,
        x,
    ):
        """Convert integer token output to strings for generation.

        This method reverses `generate_preprocess()` by first removing all
        padding and start/end tokens, and then converting the integer sequence
        back to a string.
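
        A minimal sketch (reusing `preprocessor` and `features` from the
        examples above): round-tripping a prompt through `generate_preprocess()`
        and `generate_postprocess()` recovers the prompt text.
        ```python
        # Returns the prompt text with padding and special tokens stripped.
        preprocessor.generate_postprocess(features)
        ```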
        """
        token_ids, padding_mask = x["token_ids"], x["padding_mask"]
        # Convert the inputs to numpy arrays if they aren't tensors already.
        if not isinstance(token_ids, tf.Tensor):
            token_ids = ops.convert_to_numpy(token_ids)
            # Make sure the numpy array has type `int32` since
            # `SentencePieceProcessor.detokenize` only accepts `int32` arrays.
            token_ids = token_ids.astype("int32")
        if not isinstance(padding_mask, tf.Tensor):
            padding_mask = ops.convert_to_numpy(padding_mask)
            padding_mask = padding_mask.astype("bool")
        # Strip any special tokens during detokenization (e.g. the start and
        # end markers). In the future we could make this configurable.
        padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
        padding_mask = padding_mask & (
            token_ids != self.tokenizer.start_token_id
        )
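        # A ragged mask keeps only the unmasked tokens in each row, so the
        # sequences in a batch may end up with different lengths.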
        token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
        return self.tokenizer.detokenize(token_ids)