
Commit

try more changes
saienduri committed Mar 6, 2024
1 parent 0e30bb8 commit 9e59c06
Showing 3 changed files with 16 additions and 34 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_models.yml
@@ -20,7 +20,7 @@ jobs:
     strategy:
       matrix:
         version: [3.11]
-        os: [nodai-amdgpu-w7900-x86-64]
+        os: [nodai-ubuntu-builder-large]
 
     runs-on: ${{matrix.os}}
     steps:
46 changes: 13 additions & 33 deletions models/turbine_models/custom_models/llm_runner.py
@@ -168,12 +168,14 @@ def run_llm(
     streaming_llm=False,
     chat_mode=False,
     chat_sys_prompt=DEFAULT_CHAT_SYS_PROMPT,
+    tokenizer=None,
 ):
-    tokenizer = AutoTokenizer.from_pretrained(
-        hf_model_name,
-        use_fast=False,
-        token=hf_auth_token,
-    )
+    if tokenizer == None:
+        tokenizer = AutoTokenizer.from_pretrained(
+            hf_model_name,
+            use_fast=False,
+            token=hf_auth_token,
+        )
     llm = SharkLLM(
         device=device,
         vmfb_path=vmfb_path,
@@ -206,55 +208,33 @@ def run_torch_llm(
     chat_sys_prompt=DEFAULT_CHAT_SYS_PROMPT,
     model=None,
     tokenizer=None,
-):
-    from turbine_models.model_builder import HFTransformerBuilder
-    from transformers import AutoModelForCausalLM
-
-    if model == None:
-        model_builder = HFTransformerBuilder(
-            example_input=None,
-            hf_id=hf_model_name,
-            auto_model=AutoModelForCausalLM,
-            hf_auth_token=hf_auth_token,
-            auto_tokenizer=AutoTokenizer,
-        )
-    else:
-        model_builder = HFTransformerBuilder(
-            example_input=None,
-            hf_id=hf_model_name,
-            auto_model=AutoModelForCausalLM,
-            hf_auth_token=hf_auth_token,
-            auto_tokenizer=AutoTokenizer,
-            model=model,
-            tokenizer=tokenizer,
-        )
-
+):
     if streaming_llm is True:
-        enable_llama_pos_shift_attention(model_builder.model)
+        enable_llama_pos_shift_attention(model)
 
     def get_token_from_logits(logits):
         return torch.argmax(logits[:, -1, :], dim=1)
 
     prompt = append_user_prompt(chat_sys_prompt, prompt)
-    initial_input = model_builder.tokenizer(prompt, return_tensors="pt")
+    initial_input = tokenizer(prompt, return_tensors="pt")
     example_input_id = initial_input.input_ids
 
-    model_results = model_builder.model.forward(example_input_id)
+    model_results = model.forward(example_input_id)
     model_token = get_token_from_logits(model_results.logits)
 
     pkv = model_results.past_key_values
 
     torch_results = []
     torch_results.append(int(model_token))
     while model_token != 2:
-        model_results = model_builder.model.forward(
+        model_results = model.forward(
             torch.unsqueeze(model_token, 0), past_key_values=pkv
         )
         model_token = get_token_from_logits(model_results.logits)
         pkv = model_results.past_key_values
         torch_results.append(int(model_token[0]))
 
-    return model_builder.tokenizer.decode(torch_results)
+    return tokenizer.decode(torch_results)
 
 
 if __name__ == "__main__":
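With this change a caller can build the tokenizer once and hand it to the runners, instead of each call downloading its own via AutoTokenizer.from_pretrained. A minimal sketch of the new calling pattern (the positional parameter order is an assumption read off the bodies shown in the diff above, and the prompt is a placeholder; this is not code from the commit):

    # Sketch only, not code from this commit. Assumes run_torch_llm's
    # first parameters are hf_model_name / hf_auth_token / prompt, as
    # suggested by the function bodies visible in the diff above.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from turbine_models.custom_models.llm_runner import run_torch_llm

    hf_model_name = "Trelis/Llama-2-7b-chat-hf-function-calling-v2"

    # Build the tokenizer once; after this commit the runners only call
    # AutoTokenizer.from_pretrained themselves when tokenizer is None.
    tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(hf_model_name)

    torch_str = run_torch_llm(
        hf_model_name,
        None,                 # hf_auth_token (assumed positional slot)
        "hi, how are you?",   # prompt (placeholder)
        model=model,
        tokenizer=tokenizer,
    )
    print(torch_str)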
2 changes: 2 additions & 0 deletions models/turbine_models/tests/stateless_llama_test.py
@@ -109,6 +109,7 @@ def test_vmfb_comparison(self):
             "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
             None,
             f"Llama_2_7b_chat_hf_function_calling_v2_{precision}_{quantization}.safetensors",
+            tokenizer=tokenizer
         )
         check_output_string(torch_str, turbine_str)
 
@@ -160,6 +161,7 @@ def test_streaming_vmfb_comparison(self):
             None,
             f"Llama_2_7b_chat_hf_function_calling_v2_{precision}_{quantization}.safetensors",
             streaming_llm=True,
+            tokenizer=tokenizer
         )
         check_output_string(torch_str, turbine_str)
 
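A note on the test change: building one tokenizer up front and passing it to both the Torch reference path and the compiled-module path means check_output_string compares outputs produced from identical tokenization, and (presumably the motivation here) the tests avoid a redundant AutoTokenizer.from_pretrained download mid-run.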
