From 0d699321e8449c863465923a890dd62d106ad29f Mon Sep 17 00:00:00 2001 From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com> Date: Tue, 31 Oct 2023 13:02:08 -0500 Subject: [PATCH] [llama] transpose global state (#104) Transposed the global state to allow continuous memory access while updating. Previously only the 0th idx of the 2nd dim was getting updated on initialization. Also update needs to be stored back to the global variable. Couple other minor bugfixes. --- examples/aot_mlp/mlp_export_simple.py | 1 + examples/llama2_inference/stateless_llama.py | 106 +++++++++---------- 2 files changed, 53 insertions(+), 54 deletions(-) diff --git a/examples/aot_mlp/mlp_export_simple.py b/examples/aot_mlp/mlp_export_simple.py index fed4795d4..f6345cd5e 100644 --- a/examples/aot_mlp/mlp_export_simple.py +++ b/examples/aot_mlp/mlp_export_simple.py @@ -34,6 +34,7 @@ def forward(self, x: torch.Tensor): model = MLP() example_x = torch.empty(97, 8, dtype=torch.float32) exported = aot.export(model, example_x) +aot.CompiledModule.run_import(exported.compiled_module) exported.print_readable() compiled_binary = exported.compile(save_to=None) diff --git a/examples/llama2_inference/stateless_llama.py b/examples/llama2_inference/stateless_llama.py index f4ad586bd..ea7fc5c6e 100644 --- a/examples/llama2_inference/stateless_llama.py +++ b/examples/llama2_inference/stateless_llama.py @@ -49,7 +49,6 @@ Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <> hi what are you? [/INST] """ - class InferenceModel(torch.nn.Module): def __init__( self, @@ -212,27 +211,15 @@ def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim): all_pkv_tensors = [] for i in range(heads * 2): sliced = IREE.tensor_slice( - global_pkv, i, 0, (0, heads), (0, seq_step), (0, hidden_dim) + global_pkv, i, 0, (0, seq_step), (0, heads), (0, hidden_dim) ) # sequence context dim all_pkv_tensors.append( - IREE.tensor_reshape(sliced, 1, heads, seq_step, hidden_dim) + IREE.tensor_reshape(sliced, 1, seq_step, heads, hidden_dim) ) return all_pkv_tensors -def update_state(state, state_updates, seq_step, heads, hidden_dim): - all_updates = [] - for i in range(heads * 2): - update = IREE.tensor_reshape( - state_updates[i], 1, 1, heads, 1, hidden_dim - ) - all_updates.append( - IREE.tensor_update(state, update, i, 0, 0, seq_step, 0) - ) - return all_updates - - def export_transformer_model( state_schema_path, hf_model_name, hf_auth_token, compile_to ): @@ -262,7 +249,7 @@ def export_transformer_model( HIDDEN_DIM = 128 BATCH_SIZE = 1 global_pkv = torch.zeros( - size=(HEADS * 2, BATCH_SIZE, HEADS, MAX_STEP_SEQ, HIDDEN_DIM), + size=(HEADS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM), dtype=torch.float32, ) seq_step = AbstractIndex @@ -281,16 +268,14 @@ def run_initialize( token, *state = self.initialize(x, constraints=init_const) updates = [] self.global_seq_step = IREE.tensor_dim( - state[0], 2 - ) # 2nd dimension of arbitrarily 0th kv tensor + state[0], 1 + ) # ? dimension of arbitrarily 0th kv tensor for i in range(HEADS * 2): slice_of_state = IREE.tensor_reshape( - state[i], 1, 1, HEADS, self.global_seq_step, HIDDEN_DIM + state[i], 1, 1, self.global_seq_step, HEADS, HIDDEN_DIM ) - updates.append( - IREE.tensor_update( - self.global_state, slice_of_state, i, 0, 0, 0, 0 - ) + self.global_state = IREE.tensor_update( + self.global_state, slice_of_state, i, 0, 0, 0, 0 ) return token @@ -298,22 +283,22 @@ def run_forward(self, x=AbstractTensor(1, None, dtype=torch.int64)): state_arg = slice_up_to_step( self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM ) - forw_const = [state_arg[0].dynamic_dim(2) < MAX_STEP_SEQ] + [ - x.dynamic_dim(2) == (state_arg[0].dynamic_dim(2)) + forw_const = [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ] + [ + x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1)) for x in state_arg[1:] ] token, *state_update = self.forward( x, *state_arg, constraints=forw_const ) - self.global_seq_step = self.global_seq_step + 1 - res = update_state( - self.global_state, - state_update, - self.global_seq_step, - HEADS, - HIDDEN_DIM, - ) + for i in range(HEADS * 2): + update = IREE.tensor_reshape( + state_update[i], 1, 1, 1, HEADS, HIDDEN_DIM + ) + self.global_state = IREE.tensor_update( + self.global_state, update, i, 0, self.global_seq_step, 0, 0 + ) + self.global_seq_step = self.global_seq_step + 1 return token def get_global_state(self): @@ -328,15 +313,19 @@ def initialize(input_ids): state1_flat, _ = pytree.tree_flatten(result.past_key_values) token1 = torch.argmax(result.logits[:, -1, :], dim=1) token1 = token1[None, :] + state1_flat = [torch.transpose(x, 1, 2) for x in state1_flat] return token1, *state1_flat @jittable def forward(token0: torch.Tensor, *state0_flat): # Unpad the states. + state0_flat = [torch.transpose(x, 1, 2) for x in state0_flat] state0 = pytree.tree_unflatten(state0_flat, state_schema) result = mod.forward(token0, past_key_values=state0) state1_flat, _ = pytree.tree_flatten(result.past_key_values) - state1_flat = [x[:, :, -1:, :] for x in state1_flat] + state1_flat = [ + torch.transpose(x[:, :, -1:, :], 1, 2) for x in state1_flat + ] token1 = torch.argmax(result.logits[:, -1, :], dim=1) token1 = token1[None, :] return token1, *state1_flat @@ -399,39 +388,48 @@ def run_vmfb_comparison(args): example_input_id = initial_input.input_ids device_inputs = [ireert.asdevicearray(config.device, example_input_id)] - step0 = ModuleCompiled["get_seq_step"]() - print("step0 :"+str(step0)) results = ModuleCompiled["run_initialize"](*device_inputs) - pkv = ModuleCompiled["get_global_state"]().to_host() - step = ModuleCompiled["get_seq_step"]() - print(f"step: {step}") - sliced = pkv[0,:,:,:step,:] - def format_out(results): return torch.tensor(results.to_host()[0][0]) -# print(tokenizer.decode(format_out(results))) - for i in range(10): - results = ModuleCompiled["run_forward"](results) - step = ModuleCompiled["get_seq_step"]() - print(f"step: {step}") - print(tokenizer.decode(format_out(results))) model = InferenceModel(args) - def get_token_from_logits(logits): return torch.argmax(logits[:, -1, :], dim=1) base_model_results = model.base_model.forward(example_input_id) base_model_token = get_token_from_logits(base_model_results.logits) -# print(tokenizer.decode(base_model_token)) + bm_pkv = base_model_results.past_key_values + turbine_results = [] + torch_results = [] + turbine_results.append(format_out(results)) + torch_results.append(int(base_model_token)) + while base_model_token != 2: + results = ModuleCompiled["run_forward"](results) + step = ModuleCompiled["get_seq_step"]() + pkv = ModuleCompiled["get_global_state"]().to_host() + # print(f"turbine: {tokenizer.decode(format_out(results))}") + base_model_results = model.base_model.forward( + torch.unsqueeze(base_model_token, 0), past_key_values=bm_pkv + ) + base_model_token = int( + get_token_from_logits(base_model_results.logits)[0] + ) + bm_pkv = base_model_results.past_key_values + # print(f"pytorch: {tokenizer.decode(base_model_token)}") + turbine_results.append(format_out(results)) + torch_results.append(base_model_token) + + print("\n\n") + print("what is the best hardware company?") + print("\n\n") + + print("turbine output: ") + print(tokenizer.decode(turbine_results)) + print("torch output: ") + print(tokenizer.decode(torch_results)) - matcher = base_model_results.past_key_values[0][0] - print(sliced) - print(matcher) - print(sliced.shape) - print(matcher.shape) if __name__ == "__main__": args = parser.parse_args()