From f7858612f0e59affc0156c374c1719e2c6926653 Mon Sep 17 00:00:00 2001 From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> Date: Thu, 15 Feb 2024 14:11:03 +0100 Subject: [PATCH 1/5] Create starcode_kv_cache_injection --- starcode_kv_cache_injection | 1 + 1 file changed, 1 insertion(+) create mode 100644 starcode_kv_cache_injection diff --git a/starcode_kv_cache_injection b/starcode_kv_cache_injection new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/starcode_kv_cache_injection @@ -0,0 +1 @@ + From 22d10a7dfca08c43970dc94887ee1b41ea365f62 Mon Sep 17 00:00:00 2001 From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> Date: Thu, 15 Feb 2024 14:12:21 +0100 Subject: [PATCH 2/5] Delete starcode_kv_cache_injection --- starcode_kv_cache_injection | 1 - 1 file changed, 1 deletion(-) delete mode 100644 starcode_kv_cache_injection diff --git a/starcode_kv_cache_injection b/starcode_kv_cache_injection deleted file mode 100644 index 8b137891791..00000000000 --- a/starcode_kv_cache_injection +++ /dev/null @@ -1 +0,0 @@ - From c01f768856b486f8cd2365d081d87bac9c03b24b Mon Sep 17 00:00:00 2001 From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> Date: Thu, 15 Feb 2024 14:12:58 +0100 Subject: [PATCH 3/5] Add files via upload --- starcode_kv_cache_injection/__init__.py | 0 .../kv_cache_injection.py | 216 ++++++++++++++++++ starcode_kv_cache_injection/validation.py | 210 +++++++++++++++++ 3 files changed, 426 insertions(+) create mode 100644 starcode_kv_cache_injection/__init__.py create mode 100644 starcode_kv_cache_injection/kv_cache_injection.py create mode 100644 starcode_kv_cache_injection/validation.py diff --git a/starcode_kv_cache_injection/__init__.py b/starcode_kv_cache_injection/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/starcode_kv_cache_injection/kv_cache_injection.py b/starcode_kv_cache_injection/kv_cache_injection.py new file mode 100644 index 00000000000..2cc52a52da6 --- /dev/null +++ b/starcode_kv_cache_injection/kv_cache_injection.py @@ -0,0 +1,216 @@ +from transformers import AutoTokenizer, AutoConfig + +import onnx +import logging +import os +from typing import List, Optional +from onnx import TensorProto, ModelProto, helper, NodeProto +from sparseml.onnx.utils import ONNXGraph +from sparseml.exporters.transforms.kv_cache.transforms_codegen import AdditionalTransformsCodeGen + +_LOGGER = logging.getLogger(__name__) + +class AdditionalTransformsBigCode(AdditionalTransformsCodeGen): + """ + Since the entries of the causal mask are similar in their values + and layout to the CodeGen causal mask, I inherit from the + AdditionalTransformsCodeGen class + """ + + # position ids are created by a Sub node (the one that is folllowed by a Where node + # in the onnx graph) + POSITION_IDS_MATCHING_PATTERN = dict(op_type="Sub", children_ops=[["Where"]]) + # causal mask is created by a Unsqueeze node (the one that is folllowed by a Where node + # in the onnx graph) + CAUSAL_MASK_MATCHING_PATTERN = dict(op_type="Unsqueeze", children_ops=[["Where", "Softmax"]]) + + def add_causal_mask_input(self, model: ModelProto) -> ModelProto: + """ + reformulating this method (originally part of the AdditionalTransformsBase class) + so that the causal mask has shape [batch_size, input_ids_length, 1, sequence_length] + vs the original shape [batch_size, 1, input_ids_length, sequence_length] + """ + + input_ids = self._get_input_proto(model, "input_ids") + attention_mask = self._get_input_proto(model, "attention_mask") + + 
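        # Concretely (shapes inferred from the declaration a few lines below): for
        # input_ids of shape (1, 4) and attention_mask of shape (1, 6), the injected
        # causal_mask input is laid out as (1, 4, 1, 6), i.e.
        # [batch_size, input_ids_length, 1, sequence_length], whereas the base class
        # declares [batch_size, 1, input_ids_length, sequence_length].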
batch_size = input_ids.type.tensor_type.shape.dim[0].dim_param + input_ids_length = input_ids.type.tensor_type.shape.dim[1].dim_value + sequence_length = attention_mask.type.tensor_type.shape.dim[1].dim_param + + causal_mask_input = helper.make_tensor_value_info( + name=self.CAUSAL_MASK_NAME, + elem_type=TensorProto.INT64, + # this is de-facto the only change from the original method + shape=[batch_size, input_ids_length, 1, sequence_length], + ) + model.graph.input.append(causal_mask_input) + _LOGGER.info(f"Inserted {self.CAUSAL_MASK_NAME} input to the ONNX model") + return model + + def swap_nodes_for_input( + self, + model: ModelProto, + nodes: List[NodeProto], + input_name: str, + nodes_parent_op_type: Optional[str] = None, + ) -> ModelProto: + + """ + Injects the specified input to the graph, replacing the specified nodes. + + :param model: the ONNX model to inject the input into + :param nodes: the nodes to replace with the input + :param input_name: the name of the input to replace the nodes with + :param nodes_parent_op_type: the parent op type of the nodes to replace + + :return: the updated model + """ + + graph = ONNXGraph(model) + for node in nodes: + # edits so that we can have multiple children nodes + children_nodes = graph.get_node_children(node) + for child_node in children_nodes: + if nodes_parent_op_type: + assert child_node.op_type == nodes_parent_op_type, ( + f"Expected to find {nodes_parent_op_type} node, " + f"found {child_node.op_type}" + ) + output_to_replace = node.output[0] + self.log_match(node) + for idx, input_name_child_node in enumerate(child_node.input): + if input_name_child_node == output_to_replace: + graph.update_node_input(child_node, input_name, idx) + + graph.delete_orphaned_node_branches() + + _LOGGER.info( + f"Successfully swapped {len(nodes)} nodes for input '{input_name}'" + ) + + return model + + def transform(self, model: ModelProto) -> ModelProto: + """ + 1. Adds `positions` as an input to the model + 2. Adds `causal_mask` as an input to the model + 2. Finds the node that initially creates the `position_ids` tensor + 3. Updates the node to use the positions input instead of + computing it from the Range op + 4. Finds the nodes that initially create the `causal_mask` tensors + 5. 
Updates the nodes to use the causal_mask input instead of + computing it from the Slice op + + :param model: model to update + :return: updated model + """ + model = self.add_positions_input(model) + model = self.add_causal_mask_input(model) + + position_ids_nodes = self.find_nodes_by_pattern( + model, pattern=self.POSITION_IDS_MATCHING_PATTERN + ) + if len(position_ids_nodes) != 1: + raise ValueError( + "Expected to find exactly one node matching " + f"the pattern {self.POSITION_IDS_MATCHING_PATTERN}, " + f"found {len(position_ids_nodes)}" + ) + + model = self.inject_positions(model, position_ids_nodes, "Where") + + causal_mask_nodes = self.find_nodes_by_pattern( + model, pattern=self.CAUSAL_MASK_MATCHING_PATTERN + ) + model = self.inject_causal_mask(model, causal_mask_nodes, "Where") + model = self.adjust_causal_mask(model) + return model + +def inject_kv_cache_inputs_outputs(model: ModelProto, names_nodes_producing_kv_tensors, hidden_size_kv_cache, batch_size = 1): + graph = ONNXGraph(model) + + inputs_to_add = [] + outputs_to_add = [] + + attention_layer_idx = 0 + + for node in model.graph.node: + if node.name in names_nodes_producing_kv_tensors: + + # inject kv cache input/output + cache_input_name_concat = f"past_key_values.{attention_layer_idx}" + cache_output_name_concat = f"present.{attention_layer_idx}" + + cache_input_info = onnx.helper.make_tensor_value_info( + cache_input_name_concat, + TensorProto.FLOAT, + [ + batch_size, + "past_sequence_len", + hidden_size_kv_cache, + ] + ) + + cache_output_info = onnx.helper.make_tensor_value_info( + cache_output_name_concat, + TensorProto.FLOAT, + [ + batch_size, + "past_sequence_len + 1", + hidden_size_kv_cache, + ] + ) + + cache_parent = node + concat_axis = 1 # concat over length axis + + concat_node = onnx.helper.make_node( + op_type="Concat", + inputs=[cache_input_name_concat, cache_parent.output[1]], + outputs=[cache_output_name_concat], + axis=concat_axis, + name=f"concat.{cache_input_name_concat}", + ) + + for _node in model.graph.node: + for input_idx, input_id in enumerate(_node.input): + if input_id == cache_parent.output[1] and _node.name != concat_node.name: + _node.input[input_idx] = cache_output_name_concat + + graph.add_node(concat_node) + inputs_to_add.extend([cache_input_info]) + outputs_to_add.extend([cache_output_info]) + + attention_layer_idx += 1 + _LOGGER.info(f"Injected kv cache input/output for attention layer {attention_layer_idx}") + + model.graph.input.extend(inputs_to_add) + model.graph.output.extend(outputs_to_add) + return model + + +def main(deployment_folder_path, save_name_injected_model): + onnx_model = onnx.load(os.path.join(deployment_folder_path, "model.onnx"), load_external_data=False) + config = AutoConfig.from_pretrained(os.path.join(deployment_folder_path, "config.json")) + # KV Cache injection + onnx_model = inject_kv_cache_inputs_outputs(model = onnx_model, + names_nodes_producing_kv_tensors=[f"/transformer/h.{i}/attn/Split" for i in range(config.n_layer)], + hidden_size_kv_cache=2 * config.n_embd // config.n_head) + # Adjustment of causal masks and positions + transformation = AdditionalTransformsBigCode() + onnx_model = transformation.transform(model = onnx_model) + # Save the model + _LOGGER.info(f"Saved injected model to {os.path.join(deployment_folder_path, save_name_injected_model)}") + onnx.save_model(onnx_model, os.path.join(deployment_folder_path, save_name_injected_model)) + + + +if __name__ == "__main__": + PATH_TO_DEPLOYMENT_FOLDER = 
"/Users/damian/Code/nm/sparseml/tiny_starcoder_py/deployment/" + # model created by running: + # sparseml.export /Users/damian/Code/nm/sparseml/tiny_starcoder_py/ --task text-generation --integration transformers --sequence_length 256 --trust_remote_code True + NAME_INJECTED_MODEL = "test.onnx" + main(PATH_TO_DEPLOYMENT_FOLDER, NAME_INJECTED_MODEL) + + diff --git a/starcode_kv_cache_injection/validation.py b/starcode_kv_cache_injection/validation.py new file mode 100644 index 00000000000..cd3150d9ba0 --- /dev/null +++ b/starcode_kv_cache_injection/validation.py @@ -0,0 +1,210 @@ +import onnxruntime as ort +import numpy as np +import onnx +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig +from onnx.tools import update_model_dims +from sparseml.onnx.utils import ONNXGraph +import logging +import numpy +from typing import List, Union + +_LOGGER = logging.getLogger(__name__) + + +def create_causal_mask( + input_ids: Union[numpy.ndarray, List[int]], + attention_mask: Union[numpy.ndarray, List[int]], + dtype: numpy.dtype = numpy.int64, +) -> numpy.ndarray: + """ + Compute a causal mask from a set of module inputs. + In transformers, a causal mask is a boolean mask that is used to + prevent information from future positions in a sequence from + being used to predict the current position. Each element of the mask + is set to 1 if the corresponding position in the input sequence + is allowed to attend to positions up to and including that position, + and 0 otherwise. + + in case of single-token input, the causal mask is an array + of shape [1, 1, 1, sequence_length], + (essentially the reshaped attention_mask) + + in case of a multi-token input, the causal mask is an array + of shape [batch_size, 1, input_ids_length, sequence_length] + it is a concatenation of a: + - past (cache) causal mask + - and a causal mask (a lower triangular matrix of 1's and 0's) + e.g + ``` + input_ids = [[1,2,3,4]] + attention_mask = [[1,1,1,1,1,1]] + + causal_mask = [[[[ 1 1 | 1 0 0 0 ], + [ 1 1 | 1 1 0 0 ], + [ 1 1 | 1 1 1 0 ], + [ 1 1 | 1 1 1 1 ]]]] + ``` + or + ``` + input_ids = [[1,2,3,4]] + attention_mask = [[0,0,1,1,1,1,1]] + + causal_mask = [[[[ 0 0 1 1 | 1 0 0 0 ], + [ 0 0 1 1 | 1 1 0 0 ], + [ 0 0 1 1 | 1 1 1 0 ], + [ 0 0 1 1 | 1 1 1 1 ]]]] + ``` + + :param input_ids: input ids of the model input + :param attention_mask: attention mask of the model input + :param dtype: data type of the mask + :return: causal mask + """ + if isinstance(input_ids, numpy.ndarray): + batch_size, input_ids_length = input_ids.shape + + else: + batch_size, input_ids_length = 1, len(input_ids) + + if isinstance(attention_mask, numpy.ndarray): + sequence_length = attention_mask.shape[1] + else: + sequence_length = len(attention_mask) + attention_mask = numpy.array(attention_mask)[None, ...] 
+ + if input_ids_length == 1: + causal_mask = numpy.reshape(attention_mask, (batch_size, 1, 1, sequence_length)) + return causal_mask.astype(dtype) + + causal_mask = numpy.tril( + numpy.ones((batch_size, 1, input_ids_length, input_ids_length), dtype=dtype), 0 + ) + past_causal_mask = numpy.ones( + (batch_size, 1, input_ids_length, sequence_length - input_ids_length), + dtype=dtype, + ) + causal_mask = numpy.concatenate((past_causal_mask, causal_mask), axis=-1) + + num_zeros = numpy.count_nonzero(attention_mask == 0) + + # changes to the original function + causal_mask[:, :, num_zeros:, :] = 0 + causal_mask = causal_mask.reshape(1, sequence_length, 1, -1) + + return causal_mask + +def apply_input_shapes(model, onnx_model_path, sequence_length, config): + kv_cache_hidden_dim = config.n_embd // config.n_head + cache_changes_in = {n.name: [1, "dynamic_len_1", 2 * kv_cache_hidden_dim] for n in model.graph.input if n.name.startswith("past_key_values")} + cache_changes_out = {n.name: [1, "dynamic_len_2", 2 * kv_cache_hidden_dim] for n in model.graph.output if n.name.startswith("present")} + graph = ONNXGraph(model) + + graph.delete_unused_initializers() + graph.delete_orphaned_node_branches() + graph.sort_nodes_topologically() + + model = update_model_dims.update_inputs_outputs_dims(model, + {"input_ids": [1, "dynamic_len_3"], + "positions": [1, "dynamic_len_4"], + "attention_mask": [1, sequence_length], + "causal_mask": [1, "dynamic_len_5", 1, "dynamic_len_6"], + **cache_changes_in}, + + {"logits": [1, "dynamic_len_6", config.vocab_size], **cache_changes_out}) + + onnx.save(model, onnx_model_path) + return model + + +def multitoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequence_length, logits_gt, kv_cache_gt): + # feed the whole sequence to the model so that we can initially validate + # the correctness of the kv cache injected model + kv_cache_hidden_dim = config.n_embd // config.n_head + inputs = tokenizer(prompt, return_tensors="np", padding='max_length', max_length=sequence_length) + input_ids = inputs.input_ids # (1, sequence_length) + attention_mask = inputs.attention_mask # (1, sequence_length) + kv_cache = {f"past_key_values.{i}": np.zeros((1, 0, 2 * kv_cache_hidden_dim), dtype=np.float32) for i in + range(config.n_layer)} # (1, 0, 2 * embedding [because we have k and v's concatenated]) + causal_mask = create_causal_mask(input_ids, attention_mask) # (1, sequence_length, 1, sequence_length) + positions = attention_mask.cumsum(-1) - 1 # (1, sequence_length) + + session = ort.InferenceSession(onnx_model_path) + + out = session.run( + None, + { + "input_ids": input_ids, + "attention_mask": attention_mask, + **kv_cache, + "causal_mask": causal_mask, + "positions": positions, + }, + ) + logits, *kv_cache = out + + num_tokens_processed = logits_gt.shape[1] # only test the relevant, non-padded tokens + assert np.allclose(logits[:, :num_tokens_processed, :], logits_gt, atol=1e-3) + assert all(np.allclose(x[:, :num_tokens_processed, :], y, atol=1e-3) for x, y in zip(kv_cache, kv_cache_gt)) + +def singletoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequence_length, logits_gt, kv_cache_gt): + # feed the model one token at a time to validate the correctness of the kv cache injected model + model = onnx.load(onnx_model_path, load_external_data=True) + apply_input_shapes(model, onnx_model_path, sequence_length, config) + + kv_cache_hidden_dim = config.n_embd // config.n_head + inputs = tokenizer(prompt, return_tensors="np") + attention_mask = np.zeros((1, 
sequence_length), dtype=np.int64) + kv_cache = {f"past_key_values.{i}": np.zeros((1,sequence_length-1, 2 * kv_cache_hidden_dim), dtype=np.float32) for i in range(config.n_layer)} + session = ort.InferenceSession(onnx_model_path) + + for idx, token in enumerate(inputs.input_ids[0]): + if token == tokenizer.pad_token_id: + break + attention_mask[:, -(idx + 1):] = 1 + positions = np.array([[idx]]) + input_ids = np.array([[token]]) + causal_mask = create_causal_mask(input_ids, attention_mask) + + outputs = session .run(None, { + "input_ids": input_ids, + "attention_mask": attention_mask, + "positions": positions, + "causal_mask": causal_mask, + **kv_cache + }) + # will not run without throwing an error, there are some missing pieces that need to be addressed + +def get_baseline(prompt, hf_model_name, tokenizer): + model = AutoModelForCausalLM.from_pretrained(hf_model_name) + tokens = tokenizer.encode(prompt, return_tensors="pt") + out = model(tokens, return_dict=True) + logits_gt = out.logits.detach().numpy() + kv_cache_gt = [t.detach().numpy() for t in out.past_key_values] + return logits_gt, kv_cache_gt + +def main(prompt, hf_model_name, onnx_model_path, sequence_length): + config = AutoConfig.from_pretrained(hf_model_name) + tokenizer = AutoTokenizer.from_pretrained(hf_model_name) + tokenizer.pad_token = tokenizer.eos_token + + logits_gt, kv_cache_gt = get_baseline(prompt, hf_model_name, tokenizer) + + multitoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequence_length, logits_gt, kv_cache_gt) + _LOGGER.info("Successfully ran multi-token inference on the kv cache injected model") + singletoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequence_length, logits_gt, kv_cache_gt) + _LOGGER.info("Successfully ran single-token inference on the kv cache injected model") + + + +if __name__ == "__main__": + PROMPT = "def eight_queens():\n if True:\n return 1\n " + HF_MODEL_NAME = "bigcode/tiny_starcoder_py" + ONNX_MODEL_PATH = "/Users/damian/Code/nm/sparseml/tiny_starcoder_py/deployment/test.onnx" + SEQUENCE_LENGTH = 256 + main(PROMPT, HF_MODEL_NAME, ONNX_MODEL_PATH, SEQUENCE_LENGTH) + + + + + + From 8c7f7992073c7982f48424419df24533d8e67bb4 Mon Sep 17 00:00:00 2001 From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> Date: Fri, 16 Feb 2024 19:15:00 +0100 Subject: [PATCH 4/5] Add files via upload --- .../kv_cache_injection.py | 88 +++++++++++++++++++ starcode_kv_cache_injection/validation.py | 14 +-- 2 files changed, 97 insertions(+), 5 deletions(-) diff --git a/starcode_kv_cache_injection/kv_cache_injection.py b/starcode_kv_cache_injection/kv_cache_injection.py index 2cc52a52da6..4b01ee0dd1a 100644 --- a/starcode_kv_cache_injection/kv_cache_injection.py +++ b/starcode_kv_cache_injection/kv_cache_injection.py @@ -7,6 +7,7 @@ from onnx import TensorProto, ModelProto, helper, NodeProto from sparseml.onnx.utils import ONNXGraph from sparseml.exporters.transforms.kv_cache.transforms_codegen import AdditionalTransformsCodeGen +from sparseml.onnx.utils.helpers import get_nodes_by_output_id _LOGGER = logging.getLogger(__name__) @@ -24,6 +25,60 @@ class AdditionalTransformsBigCode(AdditionalTransformsCodeGen): # in the onnx graph) CAUSAL_MASK_MATCHING_PATTERN = dict(op_type="Unsqueeze", children_ops=[["Where", "Softmax"]]) + def swap_nodes_for_input( + self, + model: ModelProto, + nodes: List[NodeProto], + input_name: str, + nodes_parent_op_type: Optional[str] = None, + ) -> ModelProto: + + """ + Injects the specified input to the graph, replacing the specified 
nodes. + + :param model: the ONNX model to inject the input into + :param nodes: the nodes to replace with the input + :param input_name: the name of the input to replace the nodes with + :param nodes_parent_op_type: the parent op type of the nodes to replace + + :return: the updated model + """ + + graph = ONNXGraph(model) + for node in nodes: + child_node = graph.get_node_children(node)[0] + + if nodes_parent_op_type: + assert child_node.op_type == nodes_parent_op_type, ( + f"Expected to find {nodes_parent_op_type} node, " + f"found {child_node.op_type}" + ) + output_to_replace = node.output[0] + self.log_match(node) + for idx, input_name_child_node in enumerate(child_node.input): + if input_name_child_node == output_to_replace: + graph.update_node_input(child_node, input_name, idx) + children_nodes = graph.get_node_children(node) + for child_node in children_nodes: + if nodes_parent_op_type: + assert child_node.op_type == nodes_parent_op_type, ( + f"Expected to find {nodes_parent_op_type} node, " + f"found {child_node.op_type}" + ) + output_to_replace = node.output[0] + self.log_match(node) + for idx, input_name_child_node in enumerate(child_node.input): + if input_name_child_node == output_to_replace: + graph.update_node_input(child_node, input_name, idx) + + graph.delete_orphaned_node_branches() + + _LOGGER.info( + f"Successfully swapped {len(nodes)} nodes for input '{input_name}'" + ) + + return model + def add_causal_mask_input(self, model: ModelProto) -> ModelProto: """ reformulating this method (originally part of the AdditionalTransformsBase class) @@ -91,6 +146,36 @@ def swap_nodes_for_input( return model + def add_constant_reshape_node(self, model: ModelProto) -> ModelProto: + """ + Adds positions as an input to the model. + + Positions is a tensor of shape and dtype + equal to input_ids. + + :param model: model to update + :return: updated model + """ + graph = ONNXGraph(model) + # create a constant node that will feed value (1, 256, 768) to the reshape node + constant_node = onnx.helper.make_node( + "Constant", + inputs=[], + name="abc", + outputs=["reshape_input"], + value=onnx.helper.make_tensor( + name="const_tensor", + data_type=TensorProto.INT64, + dims=[3], + vals=[1, 256, 768], + ), + ) + graph.add_node(constant_node) + reshape_node = get_nodes_by_output_id(model, "/transformer/Reshape_2_output_0")[0] + reshape_node.input[1] = "reshape_input" + _LOGGER.info(f"Inserted constant reshape node to the ONNX model") + return model + def transform(self, model: ModelProto) -> ModelProto: """ 1. 
Adds `positions` as an input to the model @@ -107,6 +192,8 @@ def transform(self, model: ModelProto) -> ModelProto: """ model = self.add_positions_input(model) model = self.add_causal_mask_input(model) + model = self.add_constant_reshape_node(model) + position_ids_nodes = self.find_nodes_by_pattern( model, pattern=self.POSITION_IDS_MATCHING_PATTERN @@ -125,6 +212,7 @@ def transform(self, model: ModelProto) -> ModelProto: ) model = self.inject_causal_mask(model, causal_mask_nodes, "Where") model = self.adjust_causal_mask(model) + return model def inject_kv_cache_inputs_outputs(model: ModelProto, names_nodes_producing_kv_tensors, hidden_size_kv_cache, batch_size = 1): diff --git a/starcode_kv_cache_injection/validation.py b/starcode_kv_cache_injection/validation.py index cd3150d9ba0..c8eec7a9426 100644 --- a/starcode_kv_cache_injection/validation.py +++ b/starcode_kv_cache_injection/validation.py @@ -73,7 +73,7 @@ def create_causal_mask( attention_mask = numpy.array(attention_mask)[None, ...] if input_ids_length == 1: - causal_mask = numpy.reshape(attention_mask, (batch_size, 1, 1, sequence_length)) + causal_mask = numpy.reshape(attention_mask, (batch_size, sequence_length, 1, -1)) return causal_mask.astype(dtype) causal_mask = numpy.tril( @@ -164,19 +164,23 @@ def singletoken_inference_test(onnx_model_path, prompt, config, tokenizer, seque positions = np.array([[idx]]) input_ids = np.array([[token]]) causal_mask = create_causal_mask(input_ids, attention_mask) - - outputs = session .run(None, { + outputs = session.run(None, { "input_ids": input_ids, "attention_mask": attention_mask, "positions": positions, "causal_mask": causal_mask, **kv_cache }) + logits, *kv_cache = outputs + for _idx, (cache_gt, cache) in enumerate(zip(kv_cache_gt, kv_cache)): + if np.allclose(cache_gt[:,idx,:], cache[:,-(idx + 1)],atol=1e-3): + print(f"Cache {_idx} matches for iteration {idx}") # will not run without throwing an error, there are some missing pieces that need to be addressed def get_baseline(prompt, hf_model_name, tokenizer): model = AutoModelForCausalLM.from_pretrained(hf_model_name) tokens = tokenizer.encode(prompt, return_tensors="pt") + model.generate(tokens[:,:1], max_length=256) out = model(tokens, return_dict=True) logits_gt = out.logits.detach().numpy() kv_cache_gt = [t.detach().numpy() for t in out.past_key_values] @@ -189,8 +193,8 @@ def main(prompt, hf_model_name, onnx_model_path, sequence_length): logits_gt, kv_cache_gt = get_baseline(prompt, hf_model_name, tokenizer) - multitoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequence_length, logits_gt, kv_cache_gt) - _LOGGER.info("Successfully ran multi-token inference on the kv cache injected model") + #multitoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequence_length, logits_gt, kv_cache_gt) + #_LOGGER.info("Successfully ran multi-token inference on the kv cache injected model") singletoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequence_length, logits_gt, kv_cache_gt) _LOGGER.info("Successfully ran single-token inference on the kv cache injected model") From c6ad1916ce50791b23b20c7222cb6f70c81f2ed7 Mon Sep 17 00:00:00 2001 From: "bogunowicz@arrival.com" Date: Wed, 6 Mar 2024 17:27:40 +0100 Subject: [PATCH 5/5] producing running (hopefully) but incorrect model --- .../kv_cache_injection.py | 150 +++++++----------- starcode_kv_cache_injection/run_model.py | 10 ++ starcode_kv_cache_injection/validation.py | 36 +++-- 3 files changed, 92 insertions(+), 104 deletions(-) create mode 100644 
starcode_kv_cache_injection/run_model.py diff --git a/starcode_kv_cache_injection/kv_cache_injection.py b/starcode_kv_cache_injection/kv_cache_injection.py index 4b01ee0dd1a..907ed0a6e7f 100644 --- a/starcode_kv_cache_injection/kv_cache_injection.py +++ b/starcode_kv_cache_injection/kv_cache_injection.py @@ -6,6 +6,7 @@ from typing import List, Optional from onnx import TensorProto, ModelProto, helper, NodeProto from sparseml.onnx.utils import ONNXGraph +from sparseml.exporters.transforms.kv_cache.cache_keys_and_values import reshape_kv_cache_inputs_outputs from sparseml.exporters.transforms.kv_cache.transforms_codegen import AdditionalTransformsCodeGen from sparseml.onnx.utils.helpers import get_nodes_by_output_id @@ -25,84 +26,6 @@ class AdditionalTransformsBigCode(AdditionalTransformsCodeGen): # in the onnx graph) CAUSAL_MASK_MATCHING_PATTERN = dict(op_type="Unsqueeze", children_ops=[["Where", "Softmax"]]) - def swap_nodes_for_input( - self, - model: ModelProto, - nodes: List[NodeProto], - input_name: str, - nodes_parent_op_type: Optional[str] = None, - ) -> ModelProto: - - """ - Injects the specified input to the graph, replacing the specified nodes. - - :param model: the ONNX model to inject the input into - :param nodes: the nodes to replace with the input - :param input_name: the name of the input to replace the nodes with - :param nodes_parent_op_type: the parent op type of the nodes to replace - - :return: the updated model - """ - - graph = ONNXGraph(model) - for node in nodes: - child_node = graph.get_node_children(node)[0] - - if nodes_parent_op_type: - assert child_node.op_type == nodes_parent_op_type, ( - f"Expected to find {nodes_parent_op_type} node, " - f"found {child_node.op_type}" - ) - output_to_replace = node.output[0] - self.log_match(node) - for idx, input_name_child_node in enumerate(child_node.input): - if input_name_child_node == output_to_replace: - graph.update_node_input(child_node, input_name, idx) - children_nodes = graph.get_node_children(node) - for child_node in children_nodes: - if nodes_parent_op_type: - assert child_node.op_type == nodes_parent_op_type, ( - f"Expected to find {nodes_parent_op_type} node, " - f"found {child_node.op_type}" - ) - output_to_replace = node.output[0] - self.log_match(node) - for idx, input_name_child_node in enumerate(child_node.input): - if input_name_child_node == output_to_replace: - graph.update_node_input(child_node, input_name, idx) - - graph.delete_orphaned_node_branches() - - _LOGGER.info( - f"Successfully swapped {len(nodes)} nodes for input '{input_name}'" - ) - - return model - - def add_causal_mask_input(self, model: ModelProto) -> ModelProto: - """ - reformulating this method (originally part of the AdditionalTransformsBase class) - so that the causal mask has shape [batch_size, input_ids_length, 1, sequence_length] - vs the original shape [batch_size, 1, input_ids_length, sequence_length] - """ - - input_ids = self._get_input_proto(model, "input_ids") - attention_mask = self._get_input_proto(model, "attention_mask") - - batch_size = input_ids.type.tensor_type.shape.dim[0].dim_param - input_ids_length = input_ids.type.tensor_type.shape.dim[1].dim_value - sequence_length = attention_mask.type.tensor_type.shape.dim[1].dim_param - - causal_mask_input = helper.make_tensor_value_info( - name=self.CAUSAL_MASK_NAME, - elem_type=TensorProto.INT64, - # this is de-facto the only change from the original method - shape=[batch_size, input_ids_length, 1, sequence_length], - ) - model.graph.input.append(causal_mask_input) - 
_LOGGER.info(f"Inserted {self.CAUSAL_MASK_NAME} input to the ONNX model") - return model - def swap_nodes_for_input( self, model: ModelProto, @@ -175,6 +98,31 @@ def add_constant_reshape_node(self, model: ModelProto) -> ModelProto: reshape_node.input[1] = "reshape_input" _LOGGER.info(f"Inserted constant reshape node to the ONNX model") return model + + def add_causal_mask_reshape_node(self, model: ModelProto) -> ModelProto: + """ + Adds positions as an input to the model. + + Positions is a tensor of shape and dtype + equal to input_ids. + + :param model: model to update + :return: updated model + """ + graph = ONNXGraph(model) + + transpose_node = onnx.helper.make_node( + op_type="Transpose", + inputs=["causal_mask"], + outputs=["causal_mask_transpose"], + name=f"causal_mask_transpose", + perm=(0,3,2,1), + ) + graph.add_node(transpose_node) + reshape_node = get_nodes_by_output_id(model, "causal_mask_adjusted")[0] + reshape_node.input[0] = "causal_mask_transpose" + _LOGGER.info(f"Inserted transpose to the causal mask in the ONNX model") + return model def transform(self, model: ModelProto) -> ModelProto: """ @@ -212,29 +160,31 @@ def transform(self, model: ModelProto) -> ModelProto: ) model = self.inject_causal_mask(model, causal_mask_nodes, "Where") model = self.adjust_causal_mask(model) - + model = self.add_causal_mask_reshape_node(model) return model -def inject_kv_cache_inputs_outputs(model: ModelProto, names_nodes_producing_kv_tensors, hidden_size_kv_cache, batch_size = 1): +def inject_kv_cache_inputs_outputs(model: ModelProto, names_nodes: List[str], hidden_size_kv_cache, batch_size = 1, key: bool = True, output_num:int=0): graph = ONNXGraph(model) inputs_to_add = [] outputs_to_add = [] - + num_attention_heads = 1 attention_layer_idx = 0 for node in model.graph.node: - if node.name in names_nodes_producing_kv_tensors: + if node.name in names_nodes: # inject kv cache input/output - cache_input_name_concat = f"past_key_values.{attention_layer_idx}" - cache_output_name_concat = f"present.{attention_layer_idx}" + cache_name = "key" if key else "value" + cache_input_name_concat = f"past_key_values.{attention_layer_idx}.{cache_name}" + cache_output_name_concat = f"present.{attention_layer_idx}.{cache_name}" cache_input_info = onnx.helper.make_tensor_value_info( cache_input_name_concat, TensorProto.FLOAT, [ batch_size, + num_attention_heads, "past_sequence_len", hidden_size_kv_cache, ] @@ -245,17 +195,30 @@ def inject_kv_cache_inputs_outputs(model: ModelProto, names_nodes_producing_kv_t TensorProto.FLOAT, [ batch_size, + num_attention_heads, "past_sequence_len + 1", hidden_size_kv_cache, ] ) + model, cache_input_dims_concat, cache_input_name_concat, cache_output_name_concat = reshape_kv_cache_inputs_outputs( + model=model, + cache_input_name=cache_input_name_concat, + cache_output_name=cache_output_name_concat, + cache_input_dims= [ + batch_size, + num_attention_heads, + "past_sequence_len", + hidden_size_kv_cache, + ], + batch_size=batch_size, + num_attention_heads=1, + ) cache_parent = node concat_axis = 1 # concat over length axis - concat_node = onnx.helper.make_node( op_type="Concat", - inputs=[cache_input_name_concat, cache_parent.output[1]], + inputs=[cache_input_name_concat, cache_parent.output[output_num]], outputs=[cache_output_name_concat], axis=concat_axis, name=f"concat.{cache_input_name_concat}", @@ -263,7 +226,7 @@ def inject_kv_cache_inputs_outputs(model: ModelProto, names_nodes_producing_kv_t for _node in model.graph.node: for input_idx, input_id in enumerate(_node.input): 
- if input_id == cache_parent.output[1] and _node.name != concat_node.name: + if input_id == cache_parent.output[output_num] and _node.name != concat_node.name: _node.input[input_idx] = cache_output_name_concat graph.add_node(concat_node) @@ -271,7 +234,7 @@ def inject_kv_cache_inputs_outputs(model: ModelProto, names_nodes_producing_kv_t outputs_to_add.extend([cache_output_info]) attention_layer_idx += 1 - _LOGGER.info(f"Injected kv cache input/output for attention layer {attention_layer_idx}") + print(f"Injected kv cache input/output for {attention_layer_idx}:{cache_name}") model.graph.input.extend(inputs_to_add) model.graph.output.extend(outputs_to_add) @@ -283,8 +246,15 @@ def main(deployment_folder_path, save_name_injected_model): config = AutoConfig.from_pretrained(os.path.join(deployment_folder_path, "config.json")) # KV Cache injection onnx_model = inject_kv_cache_inputs_outputs(model = onnx_model, - names_nodes_producing_kv_tensors=[f"/transformer/h.{i}/attn/Split" for i in range(config.n_layer)], - hidden_size_kv_cache=2 * config.n_embd // config.n_head) + names_nodes=[f"/transformer/h.{i}/attn/Split_1" for i in range(config.n_layer)], + hidden_size_kv_cache= config.n_embd // config.n_head, + key=True, + output_num=0) + onnx_model = inject_kv_cache_inputs_outputs(model = onnx_model, + names_nodes=[f"/transformer/h.{i}/attn/Split_1" for i in range(config.n_layer)], + hidden_size_kv_cache= config.n_embd // config.n_head, + key=False, + output_num=1) # Adjustment of causal masks and positions transformation = AdditionalTransformsBigCode() onnx_model = transformation.transform(model = onnx_model) diff --git a/starcode_kv_cache_injection/run_model.py b/starcode_kv_cache_injection/run_model.py new file mode 100644 index 00000000000..170fef0632a --- /dev/null +++ b/starcode_kv_cache_injection/run_model.py @@ -0,0 +1,10 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer + +checkpoint = "bigcode/tiny_starcoder_py" +device="cpu" +tokenizer = AutoTokenizer.from_pretrained(checkpoint) +model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device) + +inputs = tokenizer.encode("def print_hello_world():", return_tensors="pt").to(device) +outputs = model.generate(inputs, max_new_tokens=10) +print(tokenizer.decode(outputs[0])) \ No newline at end of file diff --git a/starcode_kv_cache_injection/validation.py b/starcode_kv_cache_injection/validation.py index c8eec7a9426..89e734cabd8 100644 --- a/starcode_kv_cache_injection/validation.py +++ b/starcode_kv_cache_injection/validation.py @@ -73,7 +73,7 @@ def create_causal_mask( attention_mask = numpy.array(attention_mask)[None, ...] 
if input_ids_length == 1: - causal_mask = numpy.reshape(attention_mask, (batch_size, sequence_length, 1, -1)) + causal_mask = numpy.reshape(attention_mask, (batch_size, 1, 1, sequence_length)) return causal_mask.astype(dtype) causal_mask = numpy.tril( @@ -89,14 +89,13 @@ def create_causal_mask( # changes to the original function causal_mask[:, :, num_zeros:, :] = 0 - causal_mask = causal_mask.reshape(1, sequence_length, 1, -1) return causal_mask def apply_input_shapes(model, onnx_model_path, sequence_length, config): kv_cache_hidden_dim = config.n_embd // config.n_head - cache_changes_in = {n.name: [1, "dynamic_len_1", 2 * kv_cache_hidden_dim] for n in model.graph.input if n.name.startswith("past_key_values")} - cache_changes_out = {n.name: [1, "dynamic_len_2", 2 * kv_cache_hidden_dim] for n in model.graph.output if n.name.startswith("present")} + cache_changes_in = {n.name: [1, 1,"dynamic_len_1", kv_cache_hidden_dim] for n in model.graph.input if n.name.startswith("past_key_values")} + cache_changes_out = {n.name: [1, 1,"dynamic_len_2", kv_cache_hidden_dim] for n in model.graph.output if n.name.startswith("present")} graph = ONNXGraph(model) graph.delete_unused_initializers() @@ -107,9 +106,8 @@ def apply_input_shapes(model, onnx_model_path, sequence_length, config): {"input_ids": [1, "dynamic_len_3"], "positions": [1, "dynamic_len_4"], "attention_mask": [1, sequence_length], - "causal_mask": [1, "dynamic_len_5", 1, "dynamic_len_6"], + "causal_mask": [1, 1, "dynamic_len_5", "dynamic_len_6"], **cache_changes_in}, - {"logits": [1, "dynamic_len_6", config.vocab_size], **cache_changes_out}) onnx.save(model, onnx_model_path) @@ -123,8 +121,11 @@ def multitoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequen inputs = tokenizer(prompt, return_tensors="np", padding='max_length', max_length=sequence_length) input_ids = inputs.input_ids # (1, sequence_length) attention_mask = inputs.attention_mask # (1, sequence_length) - kv_cache = {f"past_key_values.{i}": np.zeros((1, 0, 2 * kv_cache_hidden_dim), dtype=np.float32) for i in - range(config.n_layer)} # (1, 0, 2 * embedding [because we have k and v's concatenated]) + kv_cache_value = {f"past_key_values.{i}.value": np.zeros((1, 1, 0, kv_cache_hidden_dim), dtype=np.float32) for i in + range(config.n_layer)} # (1, 0, embedding) + kv_cache_keys = {f"past_key_values.{i}.key": np.zeros((1, 1, 0, kv_cache_hidden_dim), dtype=np.float32) for i in + range(config.n_layer)} # (1, 0, embedding) + kv_cache = {**kv_cache_keys, **kv_cache_value} causal_mask = create_causal_mask(input_ids, attention_mask) # (1, sequence_length, 1, sequence_length) positions = attention_mask.cumsum(-1) - 1 # (1, sequence_length) @@ -154,7 +155,9 @@ def singletoken_inference_test(onnx_model_path, prompt, config, tokenizer, seque kv_cache_hidden_dim = config.n_embd // config.n_head inputs = tokenizer(prompt, return_tensors="np") attention_mask = np.zeros((1, sequence_length), dtype=np.int64) - kv_cache = {f"past_key_values.{i}": np.zeros((1,sequence_length-1, 2 * kv_cache_hidden_dim), dtype=np.float32) for i in range(config.n_layer)} + kv_cache_keys = {f"past_key_values.{i}.key": np.zeros((1,1,sequence_length-1, kv_cache_hidden_dim), dtype=np.float32) for i in range(config.n_layer)} + kv_cache_values = {f"past_key_values.{i}.value": np.zeros((1,1,sequence_length-1, kv_cache_hidden_dim), dtype=np.float32) for i in range(config.n_layer)} + kv_cache = {**kv_cache_keys, **kv_cache_values} session = ort.InferenceSession(onnx_model_path) for idx, token in 
enumerate(inputs.input_ids[0]): @@ -164,6 +167,11 @@ def singletoken_inference_test(onnx_model_path, prompt, config, tokenizer, seque positions = np.array([[idx]]) input_ids = np.array([[token]]) causal_mask = create_causal_mask(input_ids, attention_mask) + print(causal_mask.shape) + print(input_ids.shape) + print(attention_mask.shape) + print(positions) + print(kv_cache["past_key_values.0.key"].shape) outputs = session.run(None, { "input_ids": input_ids, "attention_mask": attention_mask, @@ -171,10 +179,10 @@ def singletoken_inference_test(onnx_model_path, prompt, config, tokenizer, seque "causal_mask": causal_mask, **kv_cache }) - logits, *kv_cache = outputs - for _idx, (cache_gt, cache) in enumerate(zip(kv_cache_gt, kv_cache)): - if np.allclose(cache_gt[:,idx,:], cache[:,-(idx + 1)],atol=1e-3): - print(f"Cache {_idx} matches for iteration {idx}") + #logits, *kv_cache = outputs + #for _idx, (cache_gt, cache) in enumerate(zip(kv_cache_gt, kv_cache)): + # if np.allclose(cache_gt[:,idx,:], cache[:,-(idx + 1)],atol=1e-3): + # print(f"Cache {_idx} matches for iteration {idx}") # will not run without throwing an error, there are some missing pieces that need to be addressed def get_baseline(prompt, hf_model_name, tokenizer): @@ -194,7 +202,7 @@ def main(prompt, hf_model_name, onnx_model_path, sequence_length): logits_gt, kv_cache_gt = get_baseline(prompt, hf_model_name, tokenizer) #multitoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequence_length, logits_gt, kv_cache_gt) - #_LOGGER.info("Successfully ran multi-token inference on the kv cache injected model") + # _LOGGER.info("Successfully ran multi-token inference on the kv cache injected model") singletoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequence_length, logits_gt, kv_cache_gt) _LOGGER.info("Successfully ran single-token inference on the kv cache injected model")
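
For reference, the two scripts in this series are intended to run back to back: inject the KV cache inputs/outputs (plus the `positions` and `causal_mask` inputs) first, then validate the result against the Hugging Face baseline. A minimal driver along those lines might look like the sketch below; the deployment path, injected-model name, and prompt are placeholders (the patches hardcode a local checkout under /Users/damian/...), and it assumes the starcode_kv_cache_injection package is importable from the repository root.

# Hedged sketch of an end-to-end driver for this patch series; paths are placeholders.
import os

from starcode_kv_cache_injection import kv_cache_injection, validation

DEPLOYMENT_DIR = "./tiny_starcoder_py/deployment/"  # produced by `sparseml.export ... --task text-generation`
INJECTED_NAME = "model_kv_cache.onnx"

# 1. Inject KV cache inputs/outputs and the positions/causal_mask inputs into model.onnx.
kv_cache_injection.main(DEPLOYMENT_DIR, INJECTED_NAME)

# 2. Compare the injected ONNX model against the bigcode/tiny_starcoder_py baseline.
validation.main(
    prompt="def print_hello_world():",
    hf_model_name="bigcode/tiny_starcoder_py",
    onnx_model_path=os.path.join(DEPLOYMENT_DIR, INJECTED_NAME),
    sequence_length=256,
)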