From a4bebf44ac0753c19abf438921b8da1cc3e26b57 Mon Sep 17 00:00:00 2001
From: dan
Date: Thu, 12 Oct 2023 04:06:25 +0000
Subject: [PATCH] Some fixes to stateless

Adds support functions for debuggability.
---
 examples/llama2_inference/stateless_llama.py | 35 ++++++++++++++++----
 1 file changed, 29 insertions(+), 6 deletions(-)

diff --git a/examples/llama2_inference/stateless_llama.py b/examples/llama2_inference/stateless_llama.py
index 83f252167..f4ad586bd 100644
--- a/examples/llama2_inference/stateless_llama.py
+++ b/examples/llama2_inference/stateless_llama.py
@@ -281,8 +281,8 @@ def run_initialize(
         token, *state = self.initialize(x, constraints=init_const)
         updates = []
         self.global_seq_step = IREE.tensor_dim(
-            state[0], 3
-        )  # 3rd dimension of arbitrarily 0th kv tensor
+            state[0], 2
+        )  # 2nd dimension of arbitrarily 0th kv tensor
         for i in range(HEADS * 2):
             slice_of_state = IREE.tensor_reshape(
                 state[i], 1, 1, HEADS, self.global_seq_step, HIDDEN_DIM
@@ -305,6 +305,7 @@ def run_forward(self, x=AbstractTensor(1, None, dtype=torch.int64)):
         token, *state_update = self.forward(
             x, *state_arg, constraints=forw_const
         )
+        self.global_seq_step = self.global_seq_step + 1
         res = update_state(
             self.global_state,
             state_update,
@@ -312,9 +313,15 @@
             HEADS,
             HIDDEN_DIM,
         )
-        self.global_seq_step = self.global_seq_step + 1
+
         return token
 
+    def get_global_state(self):
+        return self.global_state
+
+    def get_seq_step(self):
+        return self.global_seq_step
+
     @jittable
     def initialize(input_ids):
         result = mod.forward(input_ids)
@@ -391,24 +398,40 @@ def run_vmfb_comparison(args):
     initial_input = tokenizer(prompt, return_tensors="pt")
     example_input_id = initial_input.input_ids
     device_inputs = [ireert.asdevicearray(config.device, example_input_id)]
+
+    step0 = ModuleCompiled["get_seq_step"]()
+    print("step0 :" + str(step0))
     results = ModuleCompiled["run_initialize"](*device_inputs)
+    pkv = ModuleCompiled["get_global_state"]().to_host()
+    step = ModuleCompiled["get_seq_step"]()
+    print(f"step: {step}")
+    sliced = pkv[0, :, :, :step, :]
+
     def format_out(results):
         return torch.tensor(results.to_host()[0][0])
 
-    print(tokenizer.decode(format_out(results)))
-    for i in range(100):
+    # print(tokenizer.decode(format_out(results)))
+    for i in range(10):
         results = ModuleCompiled["run_forward"](results)
+        step = ModuleCompiled["get_seq_step"]()
+        print(f"step: {step}")
         print(tokenizer.decode(format_out(results)))
 
     model = InferenceModel(args)
+
     def get_token_from_logits(logits):
         return torch.argmax(logits[:, -1, :], dim=1)
 
     base_model_results = model.base_model.forward(example_input_id)
     base_model_token = get_token_from_logits(base_model_results.logits)
-    print(tokenizer.decode(base_model_token))
+    # print(tokenizer.decode(base_model_token))
+    matcher = base_model_results.past_key_values[0][0]
+    print(sliced)
+    print(matcher)
+    print(sliced.shape)
+    print(matcher.shape)
 
 if __name__ == "__main__":
     args = parser.parse_args()
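
A note on the first hunk, for review: for LLaMA-2 under Hugging Face transformers, each past_key_values entry has shape (batch, num_heads, seq_len, head_dim), so the sequence length lives at dimension 2 of the kv tensors, which is what the change from tensor_dim(state[0], 3) to tensor_dim(state[0], 2) tracks. A minimal sketch of how that layout can be confirmed against the eager model; the checkpoint name is illustrative, not taken from the patch:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Illustrative checkpoint; substitute whatever model the example loads.
    name = "meta-llama/Llama-2-7b-hf"
    tok = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name)

    ids = tok("hello world", return_tensors="pt").input_ids
    out = model(ids, use_cache=True)

    k0 = out.past_key_values[0][0]  # key cache of layer 0
    # Layout is (batch, num_heads, seq_len, head_dim): dim 2 is seq_len.
    assert k0.shape[2] == ids.shape[1]
    print(k0.shape)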
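
The last hunk prints sliced (the compiled module's layer-0 kv state, cut to the current step) next to matcher (the eager model's past_key_values[0][0]) so they can be compared by eye. A hedged sketch of how that eyeball check could become a numeric one, assuming the two tensors end up with matching shapes as the shape prints suggest; the tolerances are illustrative:

    import numpy as np
    import torch

    # sliced comes back from IREE's to_host() as a NumPy array;
    # matcher is a torch tensor from the eager baseline.
    sliced_t = torch.from_numpy(np.asarray(sliced)).to(matcher.dtype)

    print("max abs diff:", (sliced_t - matcher).abs().max().item())
    if torch.allclose(sliced_t, matcher, atol=1e-4, rtol=1e-4):
        print("kv caches match within tolerance")
    else:
        print("kv caches diverge")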