From 8359038a0307f0abd9476a951c28de5de491512e Mon Sep 17 00:00:00 2001
From: George Petterson
Date: Thu, 19 Sep 2024 11:03:53 -0500
Subject: [PATCH] Change Llama2 from the Turbine implementation to the Sharktank one

---
 apps/shark_studio/api/llm.py     | 79 ++++++++++++++++++++++----------
 apps/shark_studio/api/sd.py      | 18 ++++----
 apps/shark_studio/web/ui/chat.py |  2 +-
 requirements.txt                 |  5 ++
 4 files changed, 70 insertions(+), 34 deletions(-)

diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py
index f6d33adcb6..aeaff8b917 100644
--- a/apps/shark_studio/api/llm.py
+++ b/apps/shark_studio/api/llm.py
@@ -1,4 +1,9 @@
-from turbine_models.custom_models import stateless_llama
+from shark_turbine.aot import *
+from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
+import sharktank
+import huggingface_hub
+
+# from turbine_models.custom_models import stateless_llama
 from turbine_models.model_runner import vmfbRunner
 from turbine_models.gen_external_params.gen_external_params import gen_external_params
 import time
@@ -19,7 +24,7 @@

 llm_model_map = {
     "meta-llama/Llama-2-7b-chat-hf": {
-        "initializer": stateless_llama.export_transformer_model,
+        # "initializer": stateless_llama.export_transformer_model,
         "hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
         "compile_flags": ["--iree-opt-const-expr-hoisting=False"],
         "stop_token": 2,
@@ -27,7 +32,7 @@
         "system_prompt": """[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
     },
     "Trelis/Llama-2-7b-chat-hf-function-calling-v2": {
-        "initializer": stateless_llama.export_transformer_model,
+        # "initializer": stateless_llama.export_transformer_model,
         "hf_model_name": "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
         "compile_flags": ["--iree-opt-const-expr-hoisting=False"],
         "stop_token": 2,
@@ -35,7 +40,7 @@
         "system_prompt": """[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
     },
     "TinyPixel/small-llama2": {
-        "initializer": stateless_llama.export_transformer_model,
+        # "initializer": stateless_llama.export_transformer_model,
         "hf_model_name": "TinyPixel/small-llama2",
         "compile_flags": ["--iree-opt-const-expr-hoisting=True"],
         "stop_token": 2,
@@ -130,13 +135,18 @@ def __init__(
             print(
                 f"External weight file {self.external_weight_file} does not exist. Generating..."
             )
-            gen_external_params(
-                hf_model_name=self.hf_model_name,
-                quantization=self.quantization,
-                weight_path=self.external_weight_file,
-                hf_auth_token=hf_auth_token,
-                precision=self.precision,
+            # gen_external_params(
+            #     hf_model_name=self.hf_model_name,
+            #     quantization=self.quantization,
+            #     weight_path=self.external_weight_file,
+            #     hf_auth_token=hf_auth_token,
+            #     precision=self.precision,
+            # )
+            cache_dir = os.path.join(".", str(self.hf_model_name).replace("/", "_"))
+            huggingface_hub.snapshot_download(
+                repo_id=self.hf_model_name, cache_dir=cache_dir
             )
+            # TODO: Convert to gguf, delete cache
         else:
             print(
                 f"External weight file {self.external_weight_file} found for {self.vmfb_name}"
@@ -161,20 +171,39 @@ def __init__(
                 use_auth_token=hf_auth_token,
             )
         elif not os.path.exists(self.tempfile_name):
-            self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name][
-                "initializer"
-            ](
-                self.hf_model_name,
-                hf_auth_token,
-                compile_to="torch",
-                external_weights=external_weights,
-                precision=self.precision,
-                quantization=self.quantization,
-                streaming_llm=self.streaming_llm,
-                decomp_attn=True,
+            # self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name][
+            #     "initializer"
+            # ](
+            #     self.hf_model_name,
+            #     hf_auth_token,
+            #     compile_to="torch",
+            #     external_weights=external_weights,
+            #     precision=self.precision,
+            #     quantization=self.quantization,
+            #     streaming_llm=self.streaming_llm,
+            #     decomp_attn=True,
+            # )
+
+            dataset = sharktank.types.Dataset.load(
+                self.external_weight_file, file_type="gguf"
+            )
+            hp = sharktank.layers.configs.LlamaHParams.from_gguf_props(
+                dataset.properties
             )
-            with open(self.tempfile_name, "w+") as f:
-                f.write(self.torch_ir)
+            llama_config = sharktank.models.llama.llama.LlamaModelConfig(hp)
+            llama_config.use_hf = False
+            llama_config.static_tables = (
+                False  # Rely on the compiler for hoisting tables.
+            )
+            llama_config.kv_cache_type = "direct"  # if args.bs == [1] else "paged"
+            model = PagedLlamaModelV1(dataset.root_theta, llama_config)
+
+            fxb = FxProgramsBuilder(model)
+            self.torch_ir = export(fxb)
+            self.torch_ir.save_mlir(self.tempfile_name)
+
+            # with open(self.tempfile_name, "w+") as f:
+            #     f.write(self.torch_ir)
             del self.torch_ir
             gc.collect()
         self.compile()
@@ -413,7 +442,7 @@ def llm_chat_api(InputData: dict):
             hf_auth_token=cmd_opts.hf_auth_token,
             device=device,
             quantization=cmd_opts.quantization,
-            external_weights="safetensors",
+            external_weights="gguf",
             use_system_prompt=True,
             streaming_llm=False,
         )
@@ -467,7 +496,7 @@ def llm_chat_api(InputData: dict):
         "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
         hf_auth_token=None,
         device="cpu-task",
-        external_weights="safetensors",
+        external_weights="gguf",
     )

     print("model loaded")
diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py
index 1b535b66b2..fcc1ccf317 100644
--- a/apps/shark_studio/api/sd.py
+++ b/apps/shark_studio/api/sd.py
@@ -169,14 +169,14 @@ def __init__(
             batch_size=batch_size,
             num_inference_steps=steps,
             device=target_backend,
-            iree_target_triple=triple,
+            target=triple,  # iree_target_triple=triple,
             ireec_flags=EMPTY_FLAGS,
             attn_spec=attn_spec,
             decomp_attn=decomp_attn,
             pipeline_dir=self.pipeline_dir,
             external_weights_dir=self.weights_path,
             external_weights=external_weights,
-            custom_vae=custom_vae,
+            # custom_vae=custom_vae,
         )
         print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
         gc.collect()
@@ -237,13 +237,15 @@ def prepare_pipe(
             )
             weights[key] = save_irpa(vae_weights_path, "vae.")

-        vmfbs, weights = self.sd_pipe.check_prepared(
-            mlirs, vmfbs, weights, interactive=False
-        )
+        # vmfbs, weights = self.sd_pipe.check_prepared(
+        #     mlirs, vmfbs, weights, interactive=False
+        # )
+        self.sd_pipe.prepare_all()
         print(f"\n[LOG] Loading pipeline to device {self.rt_device}.")
-        self.sd_pipe.load_pipeline(
-            vmfbs, weights, self.rt_device, self.compiled_pipeline
-        )
+        # self.sd_pipe.load_pipeline(
+        #     vmfbs, weights, self.rt_device, self.compiled_pipeline
+        # )
+        self.sd_pipe.load_map()
         print(
             "\n[LOG] Pipeline successfully prepared for runtime. Generating images..."
         )
diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py
index cad9f4cb00..0b3398fc37 100644
--- a/apps/shark_studio/web/ui/chat.py
+++ b/apps/shark_studio/web/ui/chat.py
@@ -62,7 +62,7 @@ def chat_fn(
         model,
         device=device,
         precision=precision,
-        external_weights="safetensors",
+        external_weights="gguf",
         use_system_prompt=prompt_prefix,
         streaming_llm=streaming_llm,
         hf_auth_token=cmd_opts.hf_auth_token,
diff --git a/requirements.txt b/requirements.txt
index 404c1db9b1..8223bd0eda 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -39,7 +39,12 @@ py-cpuinfo
 pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions
 mpmath==1.3.0
 optimum
+fastapi<0.113.0

 # Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
 pefile
 pyinstaller
+
+sharktank
+gguf
+huggingface_hub
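
Note (not part of the patch): the new export path wired into LanguageModel.__init__ above reduces to the flow sketched below. This is a minimal, illustrative sketch that reuses only the sharktank/shark_turbine APIs appearing in the diff; the weight file name and output path are hypothetical, and the comment about registering entry points is an assumption about follow-up work, not something this patch does. It also shows why external_weights switches from "safetensors" to "gguf" throughout: Dataset.load here consumes gguf weight files directly.

# Minimal sketch of the sharktank-based export flow used above (illustrative only).
from shark_turbine.aot import FxProgramsBuilder, export
from sharktank.layers.configs import LlamaHParams
from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
from sharktank.types import Dataset

weight_file = "llama-2-7b-chat.f16.gguf"  # hypothetical local gguf weights

# Load weights (theta) and hyperparameters from the gguf file's properties.
dataset = Dataset.load(weight_file, file_type="gguf")
hp = LlamaHParams.from_gguf_props(dataset.properties)

# Mirror the configuration chosen in the patch.
llama_config = LlamaModelConfig(hp)
llama_config.use_hf = False
llama_config.static_tables = False  # rely on the compiler to hoist tables
llama_config.kv_cache_type = "direct"

model = PagedLlamaModelV1(dataset.root_theta, llama_config)

# Export to MLIR for later IREE compilation. In a full pipeline one would also
# register prefill/decode entry points on the builder before calling export
# (an assumption here); the patch leaves that step for follow-up work.
fxb = FxProgramsBuilder(model)
module = export(fxb)
module.save_mlir("llama-2-7b-chat.mlir")  # hypothetical output path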