From 6bf7d8c59fc0c6c0936a187760324d36cc3e5dd6 Mon Sep 17 00:00:00 2001 From: dan Date: Mon, 9 Oct 2023 19:05:42 +0000 Subject: [PATCH] adds a minimal jupyter llama example --- examples/llama2_inference/stateless_llama.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/llama2_inference/stateless_llama.py b/examples/llama2_inference/stateless_llama.py index 83f252167..22e4c0a99 100644 --- a/examples/llama2_inference/stateless_llama.py +++ b/examples/llama2_inference/stateless_llama.py @@ -305,6 +305,7 @@ def run_forward(self, x=AbstractTensor(1, None, dtype=torch.int64)): token, *state_update = self.forward( x, *state_arg, constraints=forw_const ) + self.global_seq_step = self.global_seq_step + 1 res = update_state( self.global_state, state_update, @@ -312,9 +313,15 @@ def run_forward(self, x=AbstractTensor(1, None, dtype=torch.int64)): HEADS, HIDDEN_DIM, ) - self.global_seq_step = self.global_seq_step + 1 + return token + def get_global_state(self): + return self.global_state + + def get_seq_step(self): + return self.seq_step + @jittable def initialize(input_ids): result = mod.forward(input_ids)