From 541572a1d005770798234b4e9c02e25f0a6f6424 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Tue, 22 Oct 2024 02:45:46 +0200 Subject: [PATCH] Replace `shark_turbine` with `iree.turbine` (#870) --- models/turbine_models/custom_models/resnet_18.py | 4 ++-- .../turbine_models/custom_models/sd3_inference/sd3_full.py | 2 +- .../turbine_models/custom_models/sd3_inference/sd3_mmdit.py | 6 +++--- .../custom_models/sd3_inference/sd3_mmdit_runner.py | 2 +- .../custom_models/sd3_inference/sd3_schedulers.py | 6 +++--- .../custom_models/sd3_inference/sd3_text_encoders.py | 4 ++-- .../turbine_models/custom_models/sd3_inference/sd3_vae.py | 4 ++-- .../custom_models/sd3_inference/text_encoder_impls.py | 2 +- models/turbine_models/custom_models/sd_inference/clip.py | 4 ++-- .../turbine_models/custom_models/sd_inference/schedulers.py | 6 +++--- models/turbine_models/custom_models/sd_inference/unet.py | 6 +++--- models/turbine_models/custom_models/sd_inference/vae.py | 6 +++--- models/turbine_models/custom_models/sdxl_inference/clip.py | 2 +- .../custom_models/sdxl_inference/sdxl_prompt_encoder.py | 4 ++-- .../custom_models/sdxl_inference/sdxl_scheduled_unet.py | 4 ++-- .../sdxl_inference/sdxl_scheduled_unet_runner.py | 2 +- models/turbine_models/custom_models/sdxl_inference/unet.py | 4 ++-- models/turbine_models/custom_models/sdxl_inference/vae.py | 4 ++-- models/turbine_models/custom_models/stateless_llama.py | 4 ++-- models/turbine_models/model_builder.py | 2 +- models/turbine_models/tests/pipeline_test.py | 4 ++-- models/turbine_models/tests/stateless_llama_test.py | 2 +- 22 files changed, 42 insertions(+), 42 deletions(-) diff --git a/models/turbine_models/custom_models/resnet_18.py b/models/turbine_models/custom_models/resnet_18.py index 3b560e058..bb2085368 100644 --- a/models/turbine_models/custom_models/resnet_18.py +++ b/models/turbine_models/custom_models/resnet_18.py @@ -4,11 +4,11 @@ from transformers import AutoFeatureExtractor, AutoModelForImageClassification import torch -from shark_turbine.aot import * +from iree.turbine.aot import * from iree.compiler.ir import Context import iree.runtime as rt from turbine_models.custom_models.sd_inference import utils -import shark_turbine.ops.iree as ops +import iree.turbine.ops.iree as ops import argparse parser = argparse.ArgumentParser() diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_full.py b/models/turbine_models/custom_models/sd3_inference/sd3_full.py index f88cda03f..cafab40ff 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_full.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_full.py @@ -10,7 +10,7 @@ from iree import runtime as ireert from iree.compiler.ir import Context import numpy as np -from shark_turbine.aot import * +from iree.turbine.aot import * from turbine_models.custom_models.sd_inference import utils import torch import torch._dynamo as dynamo diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py index b71d3129e..44840367e 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py @@ -13,11 +13,11 @@ from iree import runtime as ireert from iree.compiler.ir import Context import numpy as np -from shark_turbine.aot import * -from shark_turbine.dynamo.passes import ( +from iree.turbine.aot import * +from iree.turbine.dynamo.passes import ( DEFAULT_DECOMPOSITIONS, ) -from shark_turbine.transforms.general.add_metadata import AddMetadataPass +from iree.turbine.transforms.general.add_metadata import AddMetadataPass from turbine_models.custom_models.sd_inference import utils import torch import torch._dynamo as dynamo diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py index 06100eab3..14b4c2bfd 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py @@ -5,7 +5,7 @@ import torch import numpy as np from tqdm.auto import tqdm -from shark_turbine.ops.iree import trace_tensor +from iree.turbine.ops.iree import trace_tensor torch.random.manual_seed(0) diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py index 2c1d04cf1..815c8459c 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py @@ -10,9 +10,9 @@ import torch from typing import Any, Callable, Dict, List, Optional, Union -from shark_turbine.aot import * -import shark_turbine.ops.iree as ops -from shark_turbine.transforms.general.add_metadata import AddMetadataPass +from iree.turbine.aot import * +import iree.turbine.ops.iree as ops +from iree.turbine.transforms.general.add_metadata import AddMetadataPass from iree.compiler.ir import Context import iree.runtime as ireert import numpy as np diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py index d3e4ecb54..11a28ba5a 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py @@ -12,8 +12,8 @@ import iree.compiler as ireec from iree.compiler.ir import Context import numpy as np -from shark_turbine.aot import * -from shark_turbine.transforms.general.add_metadata import AddMetadataPass +from iree.turbine.aot import * +from iree.turbine.transforms.general.add_metadata import AddMetadataPass from turbine_models.custom_models.sd_inference import utils import torch from turbine_models.custom_models.sd3_inference.text_encoder_impls import ( diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_vae.py b/models/turbine_models/custom_models/sd3_inference/sd3_vae.py index ff24864a6..a978bfd1a 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_vae.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_vae.py @@ -11,8 +11,8 @@ from iree import runtime as ireert from iree.compiler.ir import Context import numpy as np -from shark_turbine.aot import * -from shark_turbine.dynamo.passes import ( +from iree.turbine.aot import * +from iree.turbine.dynamo.passes import ( DEFAULT_DECOMPOSITIONS, ) from turbine_models.custom_models.sd_inference import utils diff --git a/models/turbine_models/custom_models/sd3_inference/text_encoder_impls.py b/models/turbine_models/custom_models/sd3_inference/text_encoder_impls.py index 747b60d9b..c825009f2 100644 --- a/models/turbine_models/custom_models/sd3_inference/text_encoder_impls.py +++ b/models/turbine_models/custom_models/sd3_inference/text_encoder_impls.py @@ -3,7 +3,7 @@ import torch, math from torch import nn from transformers import CLIPTokenizer, T5TokenizerFast -from shark_turbine import ops +from iree.turbine import ops ################################################################################################# ### Core/Utility diff --git a/models/turbine_models/custom_models/sd_inference/clip.py b/models/turbine_models/custom_models/sd_inference/clip.py index 11705a916..271a1032e 100644 --- a/models/turbine_models/custom_models/sd_inference/clip.py +++ b/models/turbine_models/custom_models/sd_inference/clip.py @@ -8,8 +8,8 @@ import re from iree.compiler.ir import Context -from shark_turbine.aot import * -from shark_turbine.transforms.general.add_metadata import AddMetadataPass +from iree.turbine.aot import * +from iree.turbine.transforms.general.add_metadata import AddMetadataPass from turbine_models.custom_models.sd_inference import utils import torch from transformers import CLIPTextModel, CLIPTokenizer, CLIPProcessor diff --git a/models/turbine_models/custom_models/sd_inference/schedulers.py b/models/turbine_models/custom_models/sd_inference/schedulers.py index 0a6e36cc1..2f4a0f7fc 100644 --- a/models/turbine_models/custom_models/sd_inference/schedulers.py +++ b/models/turbine_models/custom_models/sd_inference/schedulers.py @@ -8,9 +8,9 @@ from typing import List import torch -from shark_turbine.aot import * -import shark_turbine.ops.iree as ops -from shark_turbine.transforms.general.add_metadata import AddMetadataPass +from iree.turbine.aot import * +import iree.turbine.ops.iree as ops +from iree.turbine.transforms.general.add_metadata import AddMetadataPass from iree.compiler.ir import Context import iree.runtime as ireert import numpy as np diff --git a/models/turbine_models/custom_models/sd_inference/unet.py b/models/turbine_models/custom_models/sd_inference/unet.py index dac967b8a..46854d35e 100644 --- a/models/turbine_models/custom_models/sd_inference/unet.py +++ b/models/turbine_models/custom_models/sd_inference/unet.py @@ -11,11 +11,11 @@ from iree import runtime as ireert from iree.compiler.ir import Context import numpy as np -from shark_turbine.aot import * -from shark_turbine.dynamo.passes import ( +from iree.turbine.aot import * +from iree.turbine.dynamo.passes import ( DEFAULT_DECOMPOSITIONS, ) -from shark_turbine.transforms.general.add_metadata import AddMetadataPass +from iree.turbine.transforms.general.add_metadata import AddMetadataPass from turbine_models.custom_models.sd_inference import utils import torch import torch._dynamo as dynamo diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py index 7ccd12c48..a29472393 100644 --- a/models/turbine_models/custom_models/sd_inference/vae.py +++ b/models/turbine_models/custom_models/sd_inference/vae.py @@ -9,11 +9,11 @@ from iree.compiler.ir import Context import numpy as np -from shark_turbine.aot import * -from shark_turbine.dynamo.passes import ( +from iree.turbine.aot import * +from iree.turbine.dynamo.passes import ( DEFAULT_DECOMPOSITIONS, ) -from shark_turbine.transforms.general.add_metadata import AddMetadataPass +from iree.turbine.transforms.general.add_metadata import AddMetadataPass from turbine_models.custom_models.sd_inference import utils import torch import torch._dynamo as dynamo diff --git a/models/turbine_models/custom_models/sdxl_inference/clip.py b/models/turbine_models/custom_models/sdxl_inference/clip.py index 2740745ed..d7c19abba 100644 --- a/models/turbine_models/custom_models/sdxl_inference/clip.py +++ b/models/turbine_models/custom_models/sdxl_inference/clip.py @@ -11,7 +11,7 @@ import iree.compiler as ireec from iree.compiler.ir import Context import numpy as np -from shark_turbine.aot import * +from iree.turbine.aot import * from turbine_models.custom_models.sd_inference import utils import torch from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py index 40ce6c2e5..0630070eb 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py @@ -11,8 +11,8 @@ import iree.compiler as ireec from iree.compiler.ir import Context import numpy as np -from shark_turbine.aot import * -from shark_turbine.transforms.general.add_metadata import AddMetadataPass +from iree.turbine.aot import * +from iree.turbine.transforms.general.add_metadata import AddMetadataPass from turbine_models.custom_models.sd_inference import utils import torch diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py index fd9adaa8f..5a79b1f3a 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py @@ -19,8 +19,8 @@ from iree import runtime as ireert from iree.compiler.ir import Context -from shark_turbine.aot import * -import shark_turbine.ops as ops +from iree.turbine.aot import * +import iree.turbine.ops as ops from turbine_models.custom_models.sd_inference import utils from turbine_models.custom_models.sd_inference.schedulers import get_scheduler diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet_runner.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet_runner.py index 5e90596d9..f93d0d1a0 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet_runner.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet_runner.py @@ -5,7 +5,7 @@ import torch import numpy as np from tqdm.auto import tqdm -from shark_turbine.ops.iree import trace_tensor +from iree.turbine.ops.iree import trace_tensor torch.random.manual_seed(0) diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index bd36db763..9cb028a14 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -12,8 +12,8 @@ from iree import runtime as ireert from iree.compiler.ir import Context import numpy as np -from shark_turbine.aot import * -from shark_turbine.transforms.general.add_metadata import AddMetadataPass +from iree.turbine.aot import * +from iree.turbine.transforms.general.add_metadata import AddMetadataPass from turbine_models.custom_models.sd_inference import utils diff --git a/models/turbine_models/custom_models/sdxl_inference/vae.py b/models/turbine_models/custom_models/sdxl_inference/vae.py index 753cbb9e7..8c741d225 100644 --- a/models/turbine_models/custom_models/sdxl_inference/vae.py +++ b/models/turbine_models/custom_models/sdxl_inference/vae.py @@ -11,8 +11,8 @@ from iree import runtime as ireert from iree.compiler.ir import Context import numpy as np -from shark_turbine.aot import * -from shark_turbine.dynamo.passes import ( +from iree.turbine.aot import * +from iree.turbine.dynamo.passes import ( DEFAULT_DECOMPOSITIONS, ) from turbine_models.custom_models.sd_inference import utils diff --git a/models/turbine_models/custom_models/stateless_llama.py b/models/turbine_models/custom_models/stateless_llama.py index 4ad00e911..933e3593c 100644 --- a/models/turbine_models/custom_models/stateless_llama.py +++ b/models/turbine_models/custom_models/stateless_llama.py @@ -8,7 +8,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM import torch from torch.utils import _pytree as pytree -from shark_turbine.aot import * +from iree.turbine.aot import * from iree.compiler.ir import Context from turbine_models.custom_models.llm_optimizations.streaming_llm.modify_llama import ( enable_llama_pos_shift_attention, @@ -458,7 +458,7 @@ def evict_kvcache_space(self): # TODO: Integrate with external parameters to actually be able to run # TODO: Make more generalizable to be able to quantize with all compile_to options if quantization == "int4" and not compile_to == "linalg": - from shark_turbine.transforms.quantization import mm_group_quant + from iree.turbine.transforms.quantization import mm_group_quant mm_group_quant.MMGroupQuantRewriterPass( CompiledModule.get_mlir_module(inst).operation diff --git a/models/turbine_models/model_builder.py b/models/turbine_models/model_builder.py index 4d06d8623..81fcb21b4 100644 --- a/models/turbine_models/model_builder.py +++ b/models/turbine_models/model_builder.py @@ -1,6 +1,6 @@ from transformers import AutoModel, AutoTokenizer, AutoConfig import torch -import shark_turbine.aot as aot +import iree.turbine.aot as aot from turbine_models.turbine_tank import turbine_tank import os import re diff --git a/models/turbine_models/tests/pipeline_test.py b/models/turbine_models/tests/pipeline_test.py index 658402652..e7507c033 100644 --- a/models/turbine_models/tests/pipeline_test.py +++ b/models/turbine_models/tests/pipeline_test.py @@ -11,13 +11,13 @@ import os import numpy as np from iree.compiler.ir import Context -from shark_turbine.aot import * +from iree.turbine.aot import * from turbine_models.custom_models.sd_inference import utils from turbine_models.custom_models.pipeline_base import ( PipelineComponent, TurbinePipelineBase, ) -from shark_turbine.transforms.general.add_metadata import AddMetadataPass +from iree.turbine.transforms.general.add_metadata import AddMetadataPass model_metadata_forward = { "model_name": "TestModel2xLinear", diff --git a/models/turbine_models/tests/stateless_llama_test.py b/models/turbine_models/tests/stateless_llama_test.py index 4b1ffef73..a6d04108b 100644 --- a/models/turbine_models/tests/stateless_llama_test.py +++ b/models/turbine_models/tests/stateless_llama_test.py @@ -16,7 +16,7 @@ import tempfile os.environ["TORCH_LOGS"] = "dynamic" -from shark_turbine.aot import * +from iree.turbine.aot import * from turbine_models.custom_models import llm_runner from turbine_models.gen_external_params.gen_external_params import (