Change Llama2 from the Turbine implementation to the Sharktank one #2170

Draft · wants to merge 1 commit into base: main
79 changes: 54 additions & 25 deletions apps/shark_studio/api/llm.py
@@ -1,4 +1,9 @@
from turbine_models.custom_models import stateless_llama
from shark_turbine.aot import *
from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
import sharktank
import huggingface_hub

# from turbine_models.custom_models import stateless_llama
from turbine_models.model_runner import vmfbRunner
from turbine_models.gen_external_params.gen_external_params import gen_external_params
import time
@@ -19,23 +24,23 @@

llm_model_map = {
"meta-llama/Llama-2-7b-chat-hf": {
"initializer": stateless_llama.export_transformer_model,
# "initializer": stateless_llama.export_transformer_model,
"hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
"compile_flags": ["--iree-opt-const-expr-hoisting=False"],
"stop_token": 2,
"max_tokens": 4096,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
},
"Trelis/Llama-2-7b-chat-hf-function-calling-v2": {
"initializer": stateless_llama.export_transformer_model,
# "initializer": stateless_llama.export_transformer_model,
"hf_model_name": "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
"compile_flags": ["--iree-opt-const-expr-hoisting=False"],
"stop_token": 2,
"max_tokens": 4096,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
},
"TinyPixel/small-llama2": {
"initializer": stateless_llama.export_transformer_model,
# "initializer": stateless_llama.export_transformer_model,
"hf_model_name": "TinyPixel/small-llama2",
"compile_flags": ["--iree-opt-const-expr-hoisting=True"],
"stop_token": 2,
@@ -130,13 +135,18 @@ def __init__(
print(
f"External weight file {self.external_weight_file} does not exist. Generating..."
)
gen_external_params(
hf_model_name=self.hf_model_name,
quantization=self.quantization,
weight_path=self.external_weight_file,
hf_auth_token=hf_auth_token,
precision=self.precision,
# gen_external_params(
# hf_model_name=self.hf_model_name,
# quantization=self.quantization,
# weight_path=self.external_weight_file,
# hf_auth_token=hf_auth_token,
# precision=self.precision,
# )
cache_dir = os.path.join(".", str(self.hf_model_name).replace("/", "_"))
huggingface_hub.snapshot_download(
repo_id=self.hf_model_name, cache_dir=cache_dir
)
# TODO: Convert to gguf, delete cache
else:
Contributor Author:
The way sharktank recommends generating the .gguf file is with a CLI tool from llama.cpp. Is that still the best way to produce it, or do we have a way to do it using sharktank?
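If the llama.cpp route is kept, the TODO in the hunk above ("Convert to gguf, delete cache") might look roughly like the sketch below. This is only a guess at the shape of the step, not part of the PR: the converter script name (convert-hf-to-gguf.py), its flags, and the llama_cpp_dir / path handling are assumptions that vary across llama.cpp versions.

```python
# Hedged sketch: convert the downloaded HF snapshot to gguf with llama.cpp's
# converter script, then remove the HF cache. Script name and flags are
# assumptions and differ across llama.cpp versions.
import os
import shutil
import subprocess


def convert_snapshot_to_gguf(snapshot_dir: str, gguf_path: str, llama_cpp_dir: str):
    converter = os.path.join(llama_cpp_dir, "convert-hf-to-gguf.py")  # assumed location
    subprocess.run(
        ["python", converter, snapshot_dir, "--outfile", gguf_path, "--outtype", "f16"],
        check=True,
    )
    # "delete cache" step from the TODO in the diff
    shutil.rmtree(snapshot_dir)
```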

print(
f"External weight file {self.external_weight_file} found for {self.vmfb_name}"
@@ -161,20 +171,39 @@ def __init__(
use_auth_token=hf_auth_token,
)
elif not os.path.exists(self.tempfile_name):
self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name][
"initializer"
](
self.hf_model_name,
hf_auth_token,
compile_to="torch",
external_weights=external_weights,
precision=self.precision,
quantization=self.quantization,
streaming_llm=self.streaming_llm,
decomp_attn=True,
# self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name][
# "initializer"
# ](
# self.hf_model_name,
# hf_auth_token,
# compile_to="torch",
# external_weights=external_weights,
# precision=self.precision,
# quantization=self.quantization,
# streaming_llm=self.streaming_llm,
# decomp_attn=True,
# )

dataset = sharktank.types.Dataset.load(
self.external_weight_file, file_type="gguf"
)
hp = sharktank.layers.configs.LlamaHParams.from_gguf_props(
dataset.properties
)
with open(self.tempfile_name, "w+") as f:
f.write(self.torch_ir)
llama_config = sharktank.models.llama.llama.LlamaModelConfig(hp)
llama_config.use_hf = False
llama_config.static_tables = (
False # Rely on the compiler for hoisting tables.
)
llama_config.kv_cache_type = "direct" # if args.bs == [1] else "paged"
model = PagedLlamaModelV1(dataset.root_theta, llama_config)

fxb = FxProgramsBuilder(model)
self.torch_ir = export(fxb)
self.torch_ir.save_mlir(self.tempfile_name)
Contributor Author:
Not sure why, but this is producing an empty module. Any idea what I'm missing?
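One guess, not verified against this branch: export(fxb) only emits the programs that were registered on the FxProgramsBuilder, so calling export on a builder with no @fxb.export_program entries produces an empty module. sharktank's export_paged_llm_v1 example registers prefill/decode programs before exporting. A minimal sketch of that registration pattern with a toy module; the module, shapes, and program name below are illustrative assumptions, not the PR's code.

```python
# Hedged sketch of the FxProgramsBuilder API with a toy module: the exported
# MLIR only contains programs registered via fxb.export_program, so a builder
# with none registered saves an empty module.
import torch
from shark_turbine.aot import FxProgramsBuilder, export


class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2


toy = Toy()
fxb = FxProgramsBuilder(toy)
example_x = torch.zeros(1, 64, dtype=torch.int64)  # assumed example input


@fxb.export_program(name="forward", args=(example_x,))
def _(module, x):
    return module.forward(x)


output = export(fxb)        # now holds the registered "forward" program
output.save_mlir("toy.mlir")
```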


# with open(self.tempfile_name, "w+") as f:
# f.write(self.torch_ir)
del self.torch_ir
gc.collect()
self.compile()
@@ -413,7 +442,7 @@ def llm_chat_api(InputData: dict):
hf_auth_token=cmd_opts.hf_auth_token,
device=device,
quantization=cmd_opts.quantization,
external_weights="safetensors",
external_weights="gguf",
use_system_prompt=True,
streaming_llm=False,
)
@@ -467,7 +496,7 @@ def llm_chat_api(InputData: dict):
"Trelis/Llama-2-7b-chat-hf-function-calling-v2",
hf_auth_token=None,
device="cpu-task",
external_weights="safetensors",
external_weights="gguf",
)

print("model loaded")
18 changes: 10 additions & 8 deletions apps/shark_studio/api/sd.py
@@ -169,14 +169,14 @@ def __init__(
batch_size=batch_size,
num_inference_steps=steps,
device=target_backend,
iree_target_triple=triple,
target=triple, # iree_target_triple=triple,
ireec_flags=EMPTY_FLAGS,
attn_spec=attn_spec,
decomp_attn=decomp_attn,
pipeline_dir=self.pipeline_dir,
external_weights_dir=self.weights_path,
external_weights=external_weights,
custom_vae=custom_vae,
# custom_vae=custom_vae,
)
print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
gc.collect()
@@ -237,13 +237,15 @@ def prepare_pipe(
)
weights[key] = save_irpa(vae_weights_path, "vae.")

vmfbs, weights = self.sd_pipe.check_prepared(
mlirs, vmfbs, weights, interactive=False
)
# vmfbs, weights = self.sd_pipe.check_prepared(
# mlirs, vmfbs, weights, interactive=False
# )
self.sd_pipe.prepare_all()
print(f"\n[LOG] Loading pipeline to device {self.rt_device}.")
self.sd_pipe.load_pipeline(
vmfbs, weights, self.rt_device, self.compiled_pipeline
)
# self.sd_pipe.load_pipeline(
# vmfbs, weights, self.rt_device, self.compiled_pipeline
# )
self.sd_pipe.load_map()
print(
"\n[LOG] Pipeline successfully prepared for runtime. Generating images..."
)
2 changes: 1 addition & 1 deletion apps/shark_studio/web/ui/chat.py
@@ -62,7 +62,7 @@ def chat_fn(
model,
device=device,
precision=precision,
external_weights="safetensors",
external_weights="gguf",
use_system_prompt=prompt_prefix,
streaming_llm=streaming_llm,
hf_auth_token=cmd_opts.hf_auth_token,
5 changes: 5 additions & 0 deletions requirements.txt
@@ -39,7 +39,12 @@ py-cpuinfo
pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions
mpmath==1.3.0
optimum
fastapi<0.113.0

# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile
pyinstaller

sharktank
gguf
huggingface_hub