
Use turbine for gen_sharktank imports. #1904


Closed · wants to merge 17 commits
2 changes: 1 addition & 1 deletion conftest.py
@@ -66,7 +66,7 @@ def pytest_addoption(parser):
"--tank_url",
type=str,
default="gs://shark_tank/nightly",
help="URL to bucket from which to download SHARK tank artifacts. Default is gs://shark_tank/latest",
help="URL to bucket from which to download SHARK tank artifacts. Default is gs://shark_tank/nightly",
)
parser.addoption(
"--tank_prefix",
1 change: 1 addition & 0 deletions requirements-importer.txt
@@ -1,6 +1,7 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre

shark-turbine
numpy>1.22.4
pytorch-triton
torchvision
19 changes: 11 additions & 8 deletions setup_venv.sh
@@ -136,18 +136,21 @@ else
fi

$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f ${PYTORCH_URL}
T_VER=$($PYTHON -m pip show torch | grep Version)
T_VER_MIN=${T_VER:14:12}
TV_VER=$($PYTHON -m pip show torchvision | grep Version)
TV_VER_MAJ=${TV_VER:9:6}
$PYTHON -m pip uninstall -y torchvision
$PYTHON -m pip install torchvision==${TV_VER_MAJ}${T_VER_MIN} --no-deps -f https://download.pytorch.org/whl/nightly/cpu/torchvision/

if [[ $(uname -s) = 'Linux' && ! -z "${IMPORTER}" ]]; then
T_VER=$($PYTHON -m pip show torch | grep Version)
T_VER_MIN=${T_VER:14:12}
TV_VER=$($PYTHON -m pip show torchvision | grep Version)
TV_VER_MAJ=${TV_VER:9:6}
$PYTHON -m pip uninstall -y torchvision
$PYTHON -m pip install torchvision==${TV_VER_MAJ}${T_VER_MIN} --no-deps -f https://download.pytorch.org/whl/nightly/cpu/torchvision/

$PYTHON -m pip uninstall -y torch torchvision
$PYTHON -m pip install torch==2.1.0 torchvision
if [ $? -eq 0 ];then
echo "Successfully Installed torch + cu118."
echo "Installed torch/torchvision for turbine importer requirements."
else
echo "Could not install torch + cu118." >&2
echo "Could not install torch version to satisfy shark-turbine requirement." >&2
fi
fi

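A note on the relocated version-pinning block: the ${T_VER:14:12} and ${TV_VER:9:6} expansions slice fixed-width substrings out of pip show output, so a torchvision nightly can be pinned to the same date as the installed torch nightly. A rough Python equivalent of the bash logic, using made-up nightly version strings for illustration:

    # Illustrative sketch of the bash substring logic above; the version
    # strings are assumed examples, not pinned requirements.
    t_ver = "Version: 2.2.0.dev20231013+cpu"    # sample `pip show torch` line
    tv_ver = "Version: 0.17.0.dev20231013+cpu"  # sample `pip show torchvision` line

    t_ver_min = t_ver[14:14 + 12]  # ".dev20231013" -- nightly date suffix of torch
    tv_ver_maj = tv_ver[9:9 + 6]   # "0.17.0"       -- torchvision release prefix

    # Combined pin, matching torchvision==${TV_VER_MAJ}${T_VER_MIN} in the script:
    print(f"torchvision=={tv_ver_maj}{t_ver_min}")  # torchvision==0.17.0.dev20231013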
8 changes: 2 additions & 6 deletions shark/shark_importer.py
@@ -162,9 +162,7 @@ def save_data(
inputs_name = "inputs.npz"
outputs_name = "golden_out.npz"
func_file_name = "function_name"
model_name_mlir = (
model_name + "_" + self.frontend + "_" + mlir_type + ".mlir"
)
model_name_mlir = model_name + "_" + self.frontend + ".mlir"
print(f"saving {model_name_mlir} to {dir}")
try:
inputs = [x.cpu().detach() for x in inputs]
@@ -207,9 +205,7 @@ def import_debug(
f"There is no input provided: {self.inputs}, please provide inputs or simply run import_mlir."
)
sys.exit(1)
model_name_mlir = (
model_name + "_" + self.frontend + "_" + mlir_type + ".mlir"
)
model_name_mlir = model_name + "_" + self.frontend + ".mlir"
artifact_path = os.path.join(dir, model_name_mlir)
imported_mlir = self.import_mlir(
is_dynamic,
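Dropping the mlir_type suffix changes the artifact names that tests and the tank downloader look up. For example (model and type taken from the CSV below, shown for illustration):

    # Before: model_name + "_" + self.frontend + "_" + mlir_type + ".mlir"
    #         -> bert-base-uncased_torch_linalg.mlir
    # After:  model_name + "_" + self.frontend + ".mlir"
    #         -> bert-base-uncased_torch.mlir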
20 changes: 10 additions & 10 deletions tank/all_models.csv
@@ -1,19 +1,19 @@
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
bert-base-uncased,torch,torch,1e-2,1e-3,default,None,False,False,False,"",""
bert-large-uncased,torch,torch,1e-2,1e-3,default,None,False,False,False,"",""
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails during iree-compile.",""
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/311",""
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"https://github.com/nod-ai/SHARK/issues/311",""
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"https://github.com/nod-ai/SHARK/issues/344","macos"
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,False,False,"https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487","macos"
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,True,"https://github.com/nod-ai/SHARK/issues/343,https://github.com/nod-ai/SHARK/issues/1487","macos"
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,True,True,"","macos"
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,True,"","macos"
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,False,False,False,"https://github.com/nod-ai/SHARK/issues/343,https://github.com/nod-ai/SHARK/issues/1487","macos"
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
resnet18,linalg,torch,1e-2,1e-3,default,None,False,False,False,"","macos"
resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,True,True,"","macos"
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"","macos"
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"https://github.com/nod-ai/SHARK/issues/1487","macos"
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"https://github.com/nod-ai/SHARK/issues/1487","macos"
t5-base,linalg,torch,1e-2,1e-3,default,None,False,False,False,"","macos"
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"","macos"
96 changes: 44 additions & 52 deletions tank/generate_sharktank.py
@@ -16,6 +16,7 @@
import hashlib
import numpy as np
from pathlib import Path
import shark_turbine.aot as aot


def create_hash(file_name):
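The body of create_hash is collapsed in this view; it is called below as create_hash(file_path) and its digest is saved next to the artifacts. A minimal sketch of a chunked file hash of that shape, assuming the digest choice (the repo's actual algorithm may differ):

    import hashlib

    def create_hash(file_name):
        # Hash the file in fixed-size chunks so large .mlir artifacts
        # never need to be read into memory all at once.
        digest = hashlib.sha256()  # assumed digest; the repo may use another
        with open(file_name, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                digest.update(chunk)
        return digest.hexdigest()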
@@ -34,22 +35,21 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
get_hf_causallm_model,
get_vision_model,
get_hf_img_cls_model,
get_fp16_model,
)
from shark.shark_importer import import_with_fx, save_mlir

with open(torch_model_list) as csvfile:
torch_reader = csv.reader(csvfile, delimiter=",")
fields = next(torch_reader)
for row in torch_reader:
if len(row) < 6:
continue
torch_model_name = row[0]
tracing_required = row[1]
dict_inputs = row[1]
model_type = row[2]
is_dynamic = row[3]
mlir_type = row[4]
is_decompose = row[5]

tracing_required = False if tracing_required == "False" else True
tracing_required = True
is_dynamic = False
print("generating artifacts for: " + torch_model_name)
model = None
@@ -64,6 +64,8 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
model, input, _ = get_hf_seq2seq_model(
torch_model_name, import_args
)
elif model_type == "hf_seqcls":
model, input, _ = get_hf_model(torch_model_name, import_args)
elif model_type == "hf_causallm":
model, input, _ = get_hf_causallm_model(
torch_model_name, import_args
@@ -88,15 +90,16 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
torch_model_dir = os.path.join(
local_tank_cache, str(torch_model_name) + "_torch"
)
os.makedirs(torch_model_dir, exist_ok=True)
file_path = os.path.join(
torch_model_dir, f"{torch_model_name}_torch.mlir"
)
if dict_inputs == "True":
from shark.shark_importer import import_with_fx

if is_decompose:
# Add decomposition to some torch ops
# TODO add op whitelist/blacklist
import_with_fx(
model,
(input,),
inputs=input,
is_f16=False,
f16_input_mask=None,
debug=True,
training=False,
return_str=False,
@@ -107,27 +110,33 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
tracing_required=True,
)
else:
mlir_importer = SharkImporter(
model,
(input,),
frontend="torch",
exported_model = aot.export(model, input)
exported_model.save_mlir(file_path=file_path)
golden_out = model.forward(input).detach().numpy()
function_name = "main"
hash_gen_attempts = 2
for i in range(hash_gen_attempts):
try:
mlir_hash = create_hash(file_path)
except FileNotFoundError as err:
if i < hash_gen_attempts:
continue
else:
raise err

np.save(
os.path.join(torch_model_dir, "hash"), np.array(mlir_hash)
)
mlir_importer.import_debug(
is_dynamic=False,
tracing_required=True,
dir=torch_model_dir,
model_name=torch_model_name,
mlir_type=mlir_type,
np.savez(os.path.join(torch_model_dir, "inputs"), input)
np.savez(
os.path.join(torch_model_dir, "golden_out"), *golden_out
)
# Generate torch dynamic models.
if is_dynamic:
mlir_importer.import_debug(
is_dynamic=True,
tracing_required=True,
dir=torch_model_dir,
model_name=torch_model_name + "_dynamic",
mlir_type=mlir_type,
)
np.save(
os.path.join(torch_model_dir, "function_name"),
np.array(function_name),
)

print(f"Finished saving artifacts for {torch_model_name}!")


def check_requirements(frontend):
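The heart of the PR is the else-branch above: models without dict inputs now go through shark-turbine's AOT export instead of SharkImporter.import_debug, and the inputs and golden outputs are saved as .npz files alongside the MLIR. A self-contained sketch of that flow, using a placeholder model and shapes:

    import numpy as np
    import torch
    import shark_turbine.aot as aot

    class TinyModel(torch.nn.Module):
        # Placeholder stand-in for a model fetched from the tank CSVs.
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(8, 4)

        def forward(self, x):
            return self.linear(x)

    model = TinyModel()
    example_input = torch.randn(1, 8)

    exported = aot.export(model, example_input)           # trace and lower to MLIR
    exported.save_mlir(file_path="TinyModel_torch.mlir")  # same call the PR uses

    # Golden outputs and inputs saved the same way the new code does:
    golden_out = model.forward(example_input).detach().numpy()
    np.savez("inputs.npz", example_input.numpy())
    np.savez("golden_out.npz", golden_out)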
@@ -156,21 +165,21 @@ def gen_shark_files(modelname, frontend, tank_dir, importer_args):
torch_model_csv = os.path.join(
os.path.dirname(__file__), "torch_model_list.csv"
)
custom_model_csv = tempfile.NamedTemporaryFile(
dir=os.path.dirname(__file__),
delete=True,
custom_model_csv = os.path.join(
os.path.dirname(__file__), "custom_model_list.csv"
)
if frontend == "torch":
with open(torch_model_csv, mode="r") as src:
reader = csv.reader(src)
for row in reader:
if row[0] == modelname:
if "_".join(row[0].split("/")) == modelname:
target = row
with open(custom_model_csv.name, mode="w") as trg:
break
with open(custom_model_csv, mode="w") as trg:
writer = csv.writer(trg)
writer.writerow(["modelname", "src"])
writer.writerow(target)
save_torch_model(custom_model_csv.name, tank_dir, import_args)
save_torch_model(custom_model_csv, tank_dir, import_args)
else:
raise NoImportException
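The rewritten row match above normalizes slashes in Hugging Face model IDs, so a sanitized modelname argument still finds its CSV row; for example (name chosen for illustration):

    row0 = "google/vit-base-patch16-224"  # as a model might be listed in the CSV
    print("_".join(row0.split("/")))      # -> google_vit-base-patch16-224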

Expand All @@ -184,23 +193,6 @@ def is_valid_file(arg):


if __name__ == "__main__":
# Note, all of these flags are overridden by the import of import_args from stable_args.py, flags are duplicated temporarily to preserve functionality
# parser = argparse.ArgumentParser()
# parser.add_argument(
# "--torch_model_csv",
# type=lambda x: is_valid_file(x),
# default="./tank/torch_model_list.csv",
# help="""Contains the file with torch_model name and args.
# Please see: https://github.com/nod-ai/SHARK/blob/main/tank/torch_model_list.csv""",
# )
# parser.add_argument(
# "--ci_tank_dir",
# type=bool,
# default=False,
# )
# parser.add_argument("--upload", type=bool, default=False)

# old_import_args = parser.parse_import_args()
import_args = {
"batch_size": 1,
}