Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
saienduri committed Nov 27, 2023
1 parent 05aab43 commit 43a7067
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions python/turbine_models/custom_models/stateless_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):

self.global_seq_step = self.global_seq_step + 1
return token

def run_all(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)):
init_const = [x.dynamic_dim(1) < MAX_STEP_SEQ]
token, *state = self.initialize(x, constraints=init_const)
Expand All @@ -204,7 +204,9 @@ def run_all(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)):
]
+ [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]]
)
token, *state_update = self.forward(x, *state_arg, constraints=forw_const)
token, *state_update = self.forward(
x, *state_arg, constraints=forw_const
)
for i in range(HEADS * 2):
update = IREE.tensor_reshape(
state_update[i], 1, 1, 1, HEADS, HIDDEN_DIM
Expand Down Expand Up @@ -290,6 +292,7 @@ def forward(token0: torch.Tensor, *state0_flat):
print("saved to ", safe_name + ".vmfb")
exit()


def run_benchmark(args):
config = ireert.Config("local-task")

Expand Down Expand Up @@ -317,11 +320,15 @@ def run_benchmark(args):
example_input_id = initial_input.input_ids
input = np.asarray(example_input_id, dtype=None, order="C")
input = np.reshape(input, (1,) + (input.shape))
results = ireert.benchmark_module(mod, "run_all", input, parameters=f"model={weights}")

results = ireert.benchmark_module(
mod, "run_all", input, parameters=f"model={weights}"
)

for benchmark_result in results:
print(f"benchmark_name: {benchmark_result.benchmark_name}, time: {benchmark_result.time}, cpu_time: {benchmark_result.cpu_time}, iterations: {benchmark_result.iterations}, user_counters: {benchmark_result.user_counters}")

print(
f"benchmark_name: {benchmark_result.benchmark_name}, time: {benchmark_result.time}, cpu_time: {benchmark_result.cpu_time}, iterations: {benchmark_result.iterations}, user_counters: {benchmark_result.user_counters}"
)


def run_vmfb_comparison(args):
config = ireert.Config("local-task")
Expand Down

0 comments on commit 43a7067

Please sign in to comment.