Skip to content

Commit

Permalink
Fix Llama2 on CPU (#2133)
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetters-amd authored Apr 29, 2024
1 parent e003d0a commit 81d6e05
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 6 deletions.
20 changes: 14 additions & 6 deletions apps/shark_studio/api/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from transformers import AutoTokenizer, AutoModelForCausalLM

llm_model_map = {
"llama2_7b": {
"meta-llama/Llama-2-7b-chat-hf": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
"compile_flags": ["--iree-opt-const-expr-hoisting=False"],
Expand Down Expand Up @@ -258,7 +258,8 @@ def format_out(results):

history.append(format_out(token))
while (
format_out(token) != llm_model_map["llama2_7b"]["stop_token"]
format_out(token)
!= llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]
and len(history) < self.max_tokens
):
dec_time = time.time()
Expand All @@ -272,7 +273,10 @@ def format_out(results):

self.prev_token_len = token_len + len(history)

if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]:
if (
format_out(token)
== llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]
):
break

for i in range(len(history)):
Expand Down Expand Up @@ -306,7 +310,7 @@ def chat_hf(self, prompt):
self.first_input = False

history.append(int(token))
while token != llm_model_map["llama2_7b"]["stop_token"]:
while token != llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]:
dec_time = time.time()
result = self.hf_mod(token.reshape([1, 1]), past_key_values=pkv)
history.append(int(token))
Expand All @@ -317,7 +321,7 @@ def chat_hf(self, prompt):

self.prev_token_len = token_len + len(history)

if token == llm_model_map["llama2_7b"]["stop_token"]:
if token == llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]:
break
for i in range(len(history)):
if type(history[i]) != int:
Expand Down Expand Up @@ -347,7 +351,11 @@ def llm_chat_api(InputData: dict):
else:
print(f"prompt : {InputData['prompt']}")

model_name = InputData["model"] if "model" in InputData.keys() else "llama2_7b"
model_name = (
InputData["model"]
if "model" in InputData.keys()
else "meta-llama/Llama-2-7b-chat-hf"
)
model_path = llm_model_map[model_name]
device = InputData["device"] if "device" in InputData.keys() else "cpu"
precision = "fp16"
Expand Down
2 changes: 2 additions & 0 deletions apps/shark_studio/web/ui/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
llm_model_map,
LanguageModel,
)
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
import apps.shark_studio.web.utils.globals as global_obj

B_SYS, E_SYS = "<s>", "</s>"
Expand Down Expand Up @@ -64,6 +65,7 @@ def chat_fn(
external_weights="safetensors",
use_system_prompt=prompt_prefix,
streaming_llm=streaming_llm,
hf_auth_token=cmd_opts.hf_auth_token,
)
history[-1][-1] = "Getting the model ready... Done"
yield history, ""
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ safetensors==0.3.1
py-cpuinfo
pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions
mpmath==1.3.0
optimum

# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile
Expand Down

0 comments on commit 81d6e05

Please sign in to comment.