Update setup_venv.sh with simpler install path. (#2144)
* Update requirements.txt for iree-turbine (#2130)

* Fix Llama2 on CPU (#2133)

* Filesystem cleanup and custom model fixes (#2127)

* Initial filesystem cleanup

* More filesystem cleanup

* Fix some formatting issues

* Address comments

* Remove IREE pin (fixes exe issue) (#2126)

* Diagnose a build issue

* Remove IREE pin

* Revert the build on pull request change

* Update find links for IREE packages (#2136)

* (Studio2) Refactors SD pipeline to rely on turbine-models pipeline, fixes to LLM, gitignore (#2129)

* Shark Studio SDXL support, HIP driver support, simpler device info, small fixes

* Fixups to llm API/UI and ignore user config files.

* Small fixes for unifying pipelines.

* Abstract out SD pipelines from Studio Webui (WIP)

* Switch from pin to minimum torch version and fix index url

* Fix device parsing.

* Fix linux setup

* Fix custom weights.

---------

Co-authored-by: saienduri <77521230+saienduri@users.noreply.github.com>
Co-authored-by: gpetters-amd <159576198+gpetters-amd@users.noreply.github.com>
Co-authored-by: gpetters94 <gpetters@protonmail.com>

* Remove leftover merge conflict line from setup script. (#2141)

* Add a few requirements for ensured parity with turbine-models requirements. (#2142)

* Add scipy to requirements.

Adds diffusers req and a note for torchsde.

* Update linux setup script.

* Move brevitas install

---------

Co-authored-by: saienduri <77521230+saienduri@users.noreply.github.com>
Co-authored-by: gpetters-amd <159576198+gpetters-amd@users.noreply.github.com>
Co-authored-by: gpetters94 <gpetters@protonmail.com>
4 people authored May 28, 2024
1 parent 78160b8 commit 78b6e4f
Showing 24 changed files with 377 additions and 470 deletions.
1 change: 1 addition & 0 deletions .github/workflows/nightly.yml
@@ -53,6 +53,7 @@ jobs:
python process_skipfiles.py
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
pip install -e .
+ pip freeze -l
pyinstaller .\apps\shark_studio\shark_studio.spec
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
2 changes: 0 additions & 2 deletions .github/workflows/test-studio.yml
@@ -81,6 +81,4 @@ jobs:
source shark.venv/bin/activate
pip install -r requirements.txt --no-cache-dir
pip install -e .
- pip uninstall -y torch
- pip install torch==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
python apps/shark_studio/tests/api_test.py
8 changes: 7 additions & 1 deletion .gitignore
@@ -164,14 +164,15 @@ cython_debug/
# vscode related
.vscode

- # Shark related artefacts
+ # Shark related artifacts
*venv/
shark_tmp/
*.vmfb
.use-iree
tank/dict_configs.py
*.csv
reproducers/
+ apps/shark_studio/web/configs

# ORT related artefacts
cache_models/
@@ -188,6 +189,11 @@ variants.json
# models folder
apps/stable_diffusion/web/models/

+ # model artifacts (SHARK)
+ *.tempfile
+ *.mlir
+ *.vmfb

# Stencil annotators.
stencil_annotator/

4 changes: 2 additions & 2 deletions apps/shark_studio/api/initializers.py
@@ -53,11 +53,11 @@ def initialize():
clear_tmp_imgs()

from apps.shark_studio.web.utils.file_utils import (
- create_checkpoint_folders,
+ create_model_folders,
)

# Create custom models folders if they don't exist
- create_checkpoint_folders()
+ create_model_folders()

import gradio as gr

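For context, a helper like create_model_folders typically just ensures the expected model directories exist before the UI starts. A minimal sketch under assumed names; the folder layout and root path are illustrative, not the repository's actual implementation:

import os

def create_model_folders(subfolders=("checkpoints", "vae", "lora", "embeddings")):
    # Create each expected model subdirectory if it is missing (illustrative layout).
    models_root = os.path.join(os.getcwd(), "models")  # assumed root location
    for name in subfolders:
        os.makedirs(os.path.join(models_root, name), exist_ok=True)
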
20 changes: 13 additions & 7 deletions apps/shark_studio/api/llm.py
@@ -13,7 +13,7 @@
from transformers import AutoTokenizer, AutoModelForCausalLM

llm_model_map = {
"llama2_7b": {
"meta-llama/Llama-2-7b-chat-hf": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
"compile_flags": ["--iree-opt-const-expr-hoisting=False"],
@@ -155,7 +155,9 @@ def __init__(
use_auth_token=hf_auth_token,
)
elif not os.path.exists(self.tempfile_name):
- self.torch_ir, self.tokenizer = llm_model_map[model_name]["initializer"](
+ self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name][
+ "initializer"
+ ](
self.hf_model_name,
hf_auth_token,
compile_to="torch",
@@ -258,7 +260,7 @@ def format_out(results):

history.append(format_out(token))
while (
- format_out(token) != llm_model_map["llama2_7b"]["stop_token"]
+ format_out(token) != llm_model_map[self.hf_model_name]["stop_token"]
and len(history) < self.max_tokens
):
dec_time = time.time()
@@ -272,7 +274,7 @@ def format_out(results):

self.prev_token_len = token_len + len(history)

- if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]:
+ if format_out(token) == llm_model_map[self.hf_model_name]["stop_token"]:
break

for i in range(len(history)):
@@ -306,7 +308,7 @@ def chat_hf(self, prompt):
self.first_input = False

history.append(int(token))
- while token != llm_model_map["llama2_7b"]["stop_token"]:
+ while token != llm_model_map[self.hf_model_name]["stop_token"]:
dec_time = time.time()
result = self.hf_mod(token.reshape([1, 1]), past_key_values=pkv)
history.append(int(token))
@@ -317,7 +319,7 @@ def format_out(results):

self.prev_token_len = token_len + len(history)

- if token == llm_model_map["llama2_7b"]["stop_token"]:
+ if token == llm_model_map[self.hf_model_name]["stop_token"]:
break
for i in range(len(history)):
if type(history[i]) != int:
@@ -347,7 +349,11 @@ def llm_chat_api(InputData: dict):
else:
print(f"prompt : {InputData['prompt']}")

- model_name = InputData["model"] if "model" in InputData.keys() else "llama2_7b"
+ model_name = (
+ InputData["model"]
+ if "model" in InputData.keys()
+ else "meta-llama/Llama-2-7b-chat-hf"
+ )
model_path = llm_model_map[model_name]
device = InputData["device"] if "device" in InputData.keys() else "cpu"
precision = "fp16"
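
Since the default model key now matches the Hugging Face id, a request to llm_chat_api that omits "model" resolves to "meta-llama/Llama-2-7b-chat-hf". A minimal sketch of a request payload, assuming only the dict fields visible in the diff above (the prompt text is purely an example):

# Hypothetical request payload for llm_chat_api; only the "model" default and
# the "prompt"/"device" field names come from the diff above.
request = {
    "prompt": "What is the capital of Canada?",
    "model": "meta-llama/Llama-2-7b-chat-hf",  # new default key in llm_model_map
    "device": "cpu",  # the diff shows "cpu" as the fallback device
}
# response = llm_chat_api(request)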
