
Commit 575bcd0

Cleanup comments and redundant code.
1 parent 12b91f4 commit 575bcd0

8 files changed: +18 −358 lines changed


core/shark_turbine/aot/builtins/jittable.py

Lines changed: 1 addition & 8 deletions

@@ -214,13 +214,6 @@ def flat_wrapped_f(*args):
         if "functorch_functionalize" in self._passes:
             transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args)
 
-        for node in transformed_f.graph.nodes:  # type: ignore
-            if node.op == "call_function":
-                if node.target == torch._ops.ops.aten.lift_fresh_copy.default:
-                    print(f"replaced lift_fresh_copy")
-                    node.target = torch._ops.ops.aten.clone.default
-        transformed_f.recompile()  # type: ignore
-
         # Ask dynamo to give us an aten graph.
         # TODO: Cache this for repeated calls.
         logger.debug("Performing dynamo.export(constraints=%r)", constraints)
@@ -233,7 +226,7 @@ def flat_wrapped_f(*args):
         )
         logger.debug("Invoking dynamo trace")
         gm, guards = exported_f(*flat_pytorch_args)
-        logger.debug("Dyanmo trace complete")
+        logger.debug("Dynamo trace complete")
 
         # TODO: Add debug logging for the exported graph module.
         # gm.print_readable()
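
For context, the block removed above was a piece of FX graph surgery: it walked the traced graph and retargeted aten.lift_fresh_copy calls to aten.clone before recompiling the module. A minimal, self-contained sketch of that pattern (the helper name is illustrative, not part of the repo):

import torch
import torch.fx


def replace_lift_fresh_copy_with_clone(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Retarget matching call_function nodes in place; both ops take a single
    # tensor argument, so swapping the target is enough.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.lift_fresh_copy.default:
            node.target = torch.ops.aten.clone.default
    # Regenerate the GraphModule's forward() from the mutated graph.
    gm.recompile()
    return gm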

models/turbine_models/custom_models/sd_inference/sdxl_split_schedulers.py

Lines changed: 0 additions & 280 deletions
This file was deleted.

models/turbine_models/custom_models/sd_inference/unet.py

Lines changed: 8 additions & 6 deletions

@@ -101,15 +101,17 @@ def export_unet_model(
     target_triple=None,
     max_alloc=None,
     upload_ir=False,
+    decomp_attn=True,
 ):
     mapper = {}
     decomp_list = DEFAULT_DECOMPOSITIONS
-    decomp_list.extend(
-        [
-            torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
-            torch.ops.aten._scaled_dot_product_flash_attention.default,
-        ]
-    )
+    if decomp_attn:
+        decomp_list.extend(
+            [
+                torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
+                torch.ops.aten._scaled_dot_product_flash_attention.default,
+            ]
+        )
     dtype = torch.float16 if precision == "fp16" else torch.float32
     unet_model = unet_model.to(dtype)
     utils.save_external_weights(
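
The new decomp_attn flag defaults to True, so existing callers keep decomposing the fused scaled-dot-product-attention ops; passing False leaves them intact, e.g. for a backend with a native SDPA lowering. A hypothetical call site follows; only precision, upload_ir, and decomp_attn are visible in the diff, and the remaining names are assumptions:

# Hypothetical usage sketch -- argument names other than those visible in the
# diff (precision, upload_ir, decomp_attn) are assumptions for illustration.
from turbine_models.custom_models.sd_inference import unet

unet.export_unet_model(
    unet_model,            # assumed: the UNet wrapper to export
    precision="fp16",      # per the diff, selects torch.float16
    upload_ir=False,
    decomp_attn=False,     # keep aten._scaled_dot_product_flash_attention fused
)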

models/turbine_models/custom_models/sd_inference/utils.py

Lines changed: 1 addition & 2 deletions

@@ -8,12 +8,10 @@
     EulerDiscreteScheduler,
 )
 
-winograd_params = "keys=unet.down_blocks.2.resnets.0.conv2.weight keys=unet.down_blocks.2.resnets.1.conv1.weight keys=unet.down_blocks.2.resnets.1.conv2.weight keys=unet.mid_block.resnets.0.conv1.weight keys=unet.mid_block.resnets.0.conv2.weight keys=unet.mid_block.resnets.1.conv1.weight keys=unet.mid_block.resnets.1.conv2.weight keys=unet.up_blocks.0.resnets.0.conv2.weight keys=unet.up_blocks.0.resnets.1.conv2.weight keys=unet.up_blocks.0.resnets.2.conv2.weight keys=unet.up_blocks.0.resnets.0.conv1.weight keys=unet.up_blocks.0.resnets.1.conv1.weight keys=unet.up_blocks.0.resnets.2.conv1.weight keys=unet.up_blocks.0.upsamplers.0.conv.weight"
 # If flags are verified to work on a specific model and improve performance without regressing numerics, add them to this dictionary. If you are working with bleeding edge flags, please add them manually with the --ireec_flags argument.
 gfx94X_flags = {
     "all": [
         "--iree-global-opt-propagate-transposes=true",
-        "--iree-opt-const-eval=false",
         "--iree-opt-outer-dim-concat=true",
         "--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
         "--iree-vm-target-truncate-unsupported-floats",
@@ -95,6 +93,7 @@ def compile_to_vmfb(
                 "--iree-hal-target-backends=rocm",
                 "--iree-rocm-target-chip=" + target_triple,
                 "--verify=false",
+                "--iree-opt-const-eval=false",
             ]
         )
     elif device == "cuda":
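
The --iree-opt-const-eval=false flag moves from the shared gfx94X "all" list into the ROCm branch of compile_to_vmfb, so it is now applied per compile invocation for that backend rather than as a blanket gfx94X default. A rough sketch of the flag-assembly pattern (function and variable names assumed; only the ROCm flags are taken from the diff):

def build_backend_flags(device: str, target_triple: str) -> list:
    # Assemble device-specific IREE compiler flags, mirroring the structure of
    # compile_to_vmfb. Only the "rocm" flags below appear in the diff; the
    # "cuda" branch is an assumed placeholder.
    flags = []
    if device == "rocm":
        flags.extend(
            [
                "--iree-hal-target-backends=rocm",
                "--iree-rocm-target-chip=" + target_triple,
                "--verify=false",
                "--iree-opt-const-eval=false",  # now only on the ROCm path
            ]
        )
    elif device == "cuda":
        flags.append("--iree-hal-target-backends=cuda")
    return flags

# The assembled flags would then be handed to the IREE compiler, e.g. via
# iree.compiler.compile_str(mlir_str, extra_args=build_backend_flags("rocm", "gfx942")).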

models/turbine_models/custom_models/sd_inference/vae.py

Lines changed: 8 additions & 6 deletions

@@ -116,15 +116,17 @@ def export_vae_model(
     max_alloc=None,
     variant="decode",
     upload_ir=False,
+    decomp_attn=True,
 ):
     mapper = {}
     decomp_list = DEFAULT_DECOMPOSITIONS
-    decomp_list.extend(
-        [
-            torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
-            torch.ops.aten._scaled_dot_product_flash_attention.default,
-        ]
-    )
+    if decomp_attn:
+        decomp_list.extend(
+            [
+                torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
+                torch.ops.aten._scaled_dot_product_flash_attention.default,
+            ]
+        )
     dtype = torch.float16 if precision == "fp16" else torch.float32
     vae_model = vae_model.to(dtype)
     utils.save_external_weights(
