From 044844f7b8626b36cd9d88a1025103fb2dbfa5bc Mon Sep 17 00:00:00 2001 From: aviator19941 Date: Tue, 19 Dec 2023 03:56:23 +0000 Subject: [PATCH 1/7] [WIP] tracing_mode="symbolic" fails for VAE encode --- python/shark_turbine/aot/passes/functorch.py | 2 +- .../custom_models/sd_inference/vae.py | 14 ++++++++++- .../custom_models/sd_inference/vae_runner.py | 25 ++++++++++++++----- python/turbine_models/tests/sd_test.py | 3 ++- 4 files changed, 35 insertions(+), 9 deletions(-) diff --git a/python/shark_turbine/aot/passes/functorch.py b/python/shark_turbine/aot/passes/functorch.py index 52f3feb66..3e9452268 100644 --- a/python/shark_turbine/aot/passes/functorch.py +++ b/python/shark_turbine/aot/passes/functorch.py @@ -47,7 +47,7 @@ def functorch_functionalize(gm: GraphModule, *args) -> GraphModule: new_gm = proxy_tensor.make_fx( functionalized_callable, decomposition_table={}, - tracing_mode="symbolic", + #tracing_mode="symbolic", _allow_non_fake_inputs=True, _allow_fake_constant=False, )(*args) diff --git a/python/turbine_models/custom_models/sd_inference/vae.py b/python/turbine_models/custom_models/sd_inference/vae.py index 50a788f64..696fea943 100644 --- a/python/turbine_models/custom_models/sd_inference/vae.py +++ b/python/turbine_models/custom_models/sd_inference/vae.py @@ -53,6 +53,7 @@ help="Specify vulkan target triple or rocm/cuda target device.", ) parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") +parser.add_argument("--variant", type=str, default="decode") class VaeModel(torch.nn.Module): @@ -69,6 +70,10 @@ def forward(self, inp): x = self.vae.decode(inp, return_dict=False)[0] return x + def encode_inp(self, inp): + latents = self.vae.encode(inp).latent_dist.sample() + return 0.18215 * latents + def export_vae_model( vae_model, @@ -83,6 +88,7 @@ def export_vae_model( device=None, target_triple=None, max_alloc=None, + variant="decode", ): mapper = {} utils.save_external_weights( @@ -90,12 +96,17 @@ def export_vae_model( ) sample = (batch_size, 4, height // 8, width // 8) + if variant == "encode": + sample = (1, 3, 512, 512) class CompiledVae(CompiledModule): params = export_parameters(vae_model) def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)): - return jittable(vae_model.forward)(inp) + if variant == "decode": + return jittable(vae_model.forward)(inp) + elif variant == "encode": + return jittable(vae_model.encode_inp)(inp) import_to = "INPUT" if compile_to == "linalg" else "IMPORT" inst = CompiledVae(context=Context(), import_to=import_to) @@ -127,6 +138,7 @@ def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)): args.device, args.iree_target_triple, args.vulkan_max_allocation, + args.variant, ) safe_name = utils.create_safe_name(args.hf_model_name, "-vae") with open(f"{safe_name}.mlir", "w+") as f: diff --git a/python/turbine_models/custom_models/sd_inference/vae_runner.py b/python/turbine_models/custom_models/sd_inference/vae_runner.py index 77b196ac0..de3dd8cf9 100644 --- a/python/turbine_models/custom_models/sd_inference/vae_runner.py +++ b/python/turbine_models/custom_models/sd_inference/vae_runner.py @@ -45,6 +45,7 @@ "--height", type=int, default=512, help="Height of Stable Diffusion" ) parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion") +parser.add_argument("--variant", type=str, default="decode") def run_vae( @@ -57,7 +58,7 @@ def run_vae( return results -def run_torch_vae(hf_model_name, hf_auth_token, example_input): +def run_torch_vae(hf_model_name, hf_auth_token, variant, 
example_input): from diffusers import AutoencoderKL class VaeModel(torch.nn.Module): @@ -73,22 +74,34 @@ def forward(self, inp): with torch.no_grad(): x = self.vae.decode(inp, return_dict=False)[0] return x + + def encode_inp(self, inp): + latents = self.vae.encode(inp).latent_dist.sample() + return 0.18215 * latents vae_model = VaeModel( hf_model_name, hf_auth_token, ) - results = vae_model.forward(example_input) + if variant == "decode": + results = vae_model.forward(example_input) + elif variant == "encode": + results = vae_model.encode_inp(example_input) np_torch_output = results.detach().cpu().numpy() return np_torch_output if __name__ == "__main__": args = parser.parse_args() - example_input = torch.rand( - args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32 - ) + if args.variant == "decode": + example_input = torch.rand( + args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32 + ) + elif args.variant == "encode": + example_input = torch.rand( + args.batch_size, 3, args.height, args.width, dtype=torch.float32 + ) print("generating turbine output:") turbine_results = run_vae( args.device, @@ -109,7 +122,7 @@ def forward(self, inp): from turbine_models.custom_models.sd_inference import utils torch_output = run_torch_vae( - args.hf_model_name, args.hf_auth_token, example_input + args.hf_model_name, args.hf_auth_token, args.variant, example_input ) print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) err = utils.largest_error(torch_output, turbine_results) diff --git a/python/turbine_models/tests/sd_test.py b/python/turbine_models/tests/sd_test.py index b8dca64f5..0213d9b4c 100644 --- a/python/turbine_models/tests/sd_test.py +++ b/python/turbine_models/tests/sd_test.py @@ -148,6 +148,7 @@ def testExportVaeModel(self): "safetensors", "stable_diffusion_v1_4_vae.safetensors", "cpu", + "decode", ) self.assertEqual(cm.exception.code, None) arguments["external_weight_path"] = "stable_diffusion_v1_4_vae.safetensors" @@ -168,7 +169,7 @@ def testExportVaeModel(self): arguments["external_weight_path"], ) torch_output = vae_runner.run_torch_vae( - arguments["hf_model_name"], arguments["hf_auth_token"], example_input + arguments["hf_model_name"], arguments["hf_auth_token"], "decode", example_input ) err = utils.largest_error(torch_output, turbine) assert err < 9e-5 From 13b682c3a181da98128d691866ba52c7d6f1df2c Mon Sep 17 00:00:00 2001 From: aviator19941 Date: Wed, 20 Dec 2023 08:19:15 +0000 Subject: [PATCH 2/7] [WIP] Update VAE encode accuracy after diffusers change --- python/shark_turbine/aot/passes/functorch.py | 2 +- python/turbine_models/custom_models/sd_inference/vae_runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/shark_turbine/aot/passes/functorch.py b/python/shark_turbine/aot/passes/functorch.py index 3e9452268..52f3feb66 100644 --- a/python/shark_turbine/aot/passes/functorch.py +++ b/python/shark_turbine/aot/passes/functorch.py @@ -47,7 +47,7 @@ def functorch_functionalize(gm: GraphModule, *args) -> GraphModule: new_gm = proxy_tensor.make_fx( functionalized_callable, decomposition_table={}, - #tracing_mode="symbolic", + tracing_mode="symbolic", _allow_non_fake_inputs=True, _allow_fake_constant=False, )(*args) diff --git a/python/turbine_models/custom_models/sd_inference/vae_runner.py b/python/turbine_models/custom_models/sd_inference/vae_runner.py index de3dd8cf9..56f11aea1 100644 --- a/python/turbine_models/custom_models/sd_inference/vae_runner.py +++ 
b/python/turbine_models/custom_models/sd_inference/vae_runner.py @@ -127,7 +127,7 @@ def encode_inp(self, inp): print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) err = utils.largest_error(torch_output, turbine_results) print("Largest Error: ", err) - assert err < 9e-5 + assert err < 2e-3 # TODO: Figure out why we occasionally segfault without unlinking output variables turbine_results = None From 06943908251e490f1220d15368712d1fd57e29d2 Mon Sep 17 00:00:00 2001 From: aviator19941 Date: Thu, 21 Dec 2023 09:53:27 +0000 Subject: [PATCH 3/7] [WIP] Add temp decomp for randn.generator to bypass VAE encode error --- python/shark_turbine/dynamo/passes.py | 39 ++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/python/shark_turbine/dynamo/passes.py b/python/shark_turbine/dynamo/passes.py index 91ea40211..315efd7a3 100644 --- a/python/shark_turbine/dynamo/passes.py +++ b/python/shark_turbine/dynamo/passes.py @@ -1,9 +1,15 @@ import torch from torch.fx.experimental.proxy_tensor import make_fx from torch._decomp import get_decompositions, register_decomposition +from torch._prims_common.wrappers import out_wrapper +from torch._prims_common import ( + DeviceLikeType, + TensorLikeType, +) +import torch._refs as _refs from torch.func import functionalize from torch import Tensor -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Optional # default decompositions pulled from SHARK / torch._decomp DEFAULT_DECOMPOSITIONS = [ @@ -103,6 +109,37 @@ def scaled_dot_product_flash_attention( ) +# manually add decomposition to bypass the error that comes +# from VAE encode(inp).latent_dist.sample() failing to symbolically +# trace from torch fx. +# diffusers side issue: https://github.com/huggingface/diffusers/issues/6239 +# temporary torch fix: https://github.com/pytorch/pytorch/issues/107170 +@register_decomposition(torch.ops.aten.randn.generator) +@out_wrapper() +def randn_generator( + *shape, + generator: Optional[torch.Generator] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + layout: Optional[torch.layout] = None, + requires_grad: bool = False, + pin_memory: bool = False, +) -> TensorLikeType: + # We should eventually support the generator overload. 
+ # However, if someone passes in a None generator explicitly, + # we can jut fall back to randn.default + if generator is None: + return _refs.randn( + *shape, + dtype=dtype, + device=device, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + ) + return NotImplemented + + def apply_decompositions( gm: torch.fx.GraphModule, example_inputs, From 7ae597c1883ae81c02e83dfa60b88e4f966ba1c3 Mon Sep 17 00:00:00 2001 From: aviator19941 Date: Thu, 21 Dec 2023 10:34:20 +0000 Subject: [PATCH 4/7] Update sd_test to include vae encode --- .../custom_models/sd_inference/vae.py | 6 +-- .../custom_models/sd_inference/vae_runner.py | 4 +- python/turbine_models/tests/sd_test.py | 46 ++++++++++++++++++- 3 files changed, 49 insertions(+), 7 deletions(-) diff --git a/python/turbine_models/custom_models/sd_inference/vae.py b/python/turbine_models/custom_models/sd_inference/vae.py index 696fea943..03ef85556 100644 --- a/python/turbine_models/custom_models/sd_inference/vae.py +++ b/python/turbine_models/custom_models/sd_inference/vae.py @@ -65,7 +65,7 @@ def __init__(self, hf_model_name, hf_auth_token): token=hf_auth_token, ) - def forward(self, inp): + def decode_inp(self, inp): with torch.no_grad(): x = self.vae.decode(inp, return_dict=False)[0] return x @@ -97,14 +97,14 @@ def export_vae_model( sample = (batch_size, 4, height // 8, width // 8) if variant == "encode": - sample = (1, 3, 512, 512) + sample = (batch_size, 3, height, width) class CompiledVae(CompiledModule): params = export_parameters(vae_model) def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)): if variant == "decode": - return jittable(vae_model.forward)(inp) + return jittable(vae_model.decode_inp)(inp) elif variant == "encode": return jittable(vae_model.encode_inp)(inp) diff --git a/python/turbine_models/custom_models/sd_inference/vae_runner.py b/python/turbine_models/custom_models/sd_inference/vae_runner.py index 56f11aea1..e4a24929e 100644 --- a/python/turbine_models/custom_models/sd_inference/vae_runner.py +++ b/python/turbine_models/custom_models/sd_inference/vae_runner.py @@ -70,7 +70,7 @@ def __init__(self, hf_model_name, hf_auth_token): token=hf_auth_token, ) - def forward(self, inp): + def decode_inp(self, inp): with torch.no_grad(): x = self.vae.decode(inp, return_dict=False)[0] return x @@ -85,7 +85,7 @@ def encode_inp(self, inp): ) if variant == "decode": - results = vae_model.forward(example_input) + results = vae_model.decode_inp(example_input) elif variant == "encode": results = vae_model.encode_inp(example_input) np_torch_output = results.detach().cpu().numpy() diff --git a/python/turbine_models/tests/sd_test.py b/python/turbine_models/tests/sd_test.py index 0213d9b4c..4d515e0fb 100644 --- a/python/turbine_models/tests/sd_test.py +++ b/python/turbine_models/tests/sd_test.py @@ -134,7 +134,7 @@ def testExportUnetModel(self): os.remove("stable_diffusion_v1_4_unet.safetensors") os.remove("stable_diffusion_v1_4_unet.vmfb") - def testExportVaeModel(self): + def testExportVaeModelDecode(self): with self.assertRaises(SystemExit) as cm: vae.export_vae_model( vae_model, @@ -148,7 +148,7 @@ def testExportVaeModel(self): "safetensors", "stable_diffusion_v1_4_vae.safetensors", "cpu", - "decode", + variant="decode", ) self.assertEqual(cm.exception.code, None) arguments["external_weight_path"] = "stable_diffusion_v1_4_vae.safetensors" @@ -176,6 +176,48 @@ def testExportVaeModel(self): os.remove("stable_diffusion_v1_4_vae.safetensors") os.remove("stable_diffusion_v1_4_vae.vmfb") + def 
testExportVaeModelEncode(self): + with self.assertRaises(SystemExit) as cm: + vae.export_vae_model( + vae_model, + # This is a public model, so no auth required + "CompVis/stable-diffusion-v1-4", + arguments["batch_size"], + arguments["height"], + arguments["width"], + None, + "vmfb", + "safetensors", + "stable_diffusion_v1_4_vae.safetensors", + "cpu", + variant="encode", + ) + self.assertEqual(cm.exception.code, None) + arguments["external_weight_path"] = "stable_diffusion_v1_4_vae.safetensors" + arguments["vmfb_path"] = "stable_diffusion_v1_4_vae.vmfb" + example_input = torch.rand( + arguments["batch_size"], + 3, + arguments["height"], + arguments["width"], + dtype=torch.float32, + ) + turbine = vae_runner.run_vae( + arguments["device"], + example_input, + arguments["vmfb_path"], + arguments["hf_model_name"], + arguments["hf_auth_token"], + arguments["external_weight_path"], + ) + torch_output = vae_runner.run_torch_vae( + arguments["hf_model_name"], arguments["hf_auth_token"], "encode", example_input + ) + err = utils.largest_error(torch_output, turbine) + assert err < 2e-3 + os.remove("stable_diffusion_v1_4_vae.safetensors") + os.remove("stable_diffusion_v1_4_vae.vmfb") + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) From 669b145b0339eb93e3c091651182e07a120feced Mon Sep 17 00:00:00 2001 From: aviator19941 Date: Wed, 27 Dec 2023 21:14:05 +0000 Subject: [PATCH 5/7] Fix black formatting and keep decomps in separate file --- python/shark_turbine/dynamo/passes.py | 92 +------------------ python/shark_turbine/dynamo/utils.py | 91 ++++++++++++++++++ .../custom_models/sd_inference/vae_runner.py | 2 +- python/turbine_models/tests/sd_test.py | 10 +- 4 files changed, 102 insertions(+), 93 deletions(-) create mode 100644 python/shark_turbine/dynamo/utils.py diff --git a/python/shark_turbine/dynamo/passes.py b/python/shark_turbine/dynamo/passes.py index 315efd7a3..ace933529 100644 --- a/python/shark_turbine/dynamo/passes.py +++ b/python/shark_turbine/dynamo/passes.py @@ -1,15 +1,8 @@ import torch from torch.fx.experimental.proxy_tensor import make_fx -from torch._decomp import get_decompositions, register_decomposition -from torch._prims_common.wrappers import out_wrapper -from torch._prims_common import ( - DeviceLikeType, - TensorLikeType, -) -import torch._refs as _refs +from shark_turbine.dynamo import utils from torch.func import functionalize -from torch import Tensor -from typing import Dict, List, Tuple, Optional +from typing import List # default decompositions pulled from SHARK / torch._decomp DEFAULT_DECOMPOSITIONS = [ @@ -59,87 +52,6 @@ ] -@register_decomposition(torch.ops.aten._scaled_dot_product_flash_attention.default) -def scaled_dot_product_flash_attention( - query, - key, - value, - dropout_p: float = 0.0, - is_causal: bool = False, - return_debug_mask: bool = False, - *, - scale: float = None, -) -> Tuple[Tensor, Tensor, Tensor, Tensor, int, int, Tensor, Tensor, Tensor]: - dtype = query.dtype - batchSize, num_head, qSize, headSize = ( - query.shape[0], - query.shape[1], - query.shape[2], - query.shape[3], - ) - - logsumexp = torch.empty([batchSize, qSize, num_head, headSize], dtype=torch.float) - cum_seq_q, cum_seq_k = torch.empty([], dtype=torch.long), torch.empty( - [], dtype=torch.long - ) - max_q, max_k = 0, 0 - philox_seed, philox_offset = torch.empty([], dtype=torch.long), torch.empty( - [], dtype=torch.long - ) - debug_attn_mask = torch.empty( - [], - dtype=query.dtype, - device="cpu", - requires_grad=query.requires_grad, - ) - output, _ = 
torch.ops.aten._scaled_dot_product_attention_math.default( - query, key, value, None, dropout_p, is_causal, None, scale=scale - ) - output = output.transpose(1, 2).contiguous(memory_format=torch.contiguous_format) - return ( - output.transpose(1, 2), - logsumexp, - cum_seq_q, - cum_seq_k, - max_q, - max_k, - philox_seed, - philox_offset, - debug_attn_mask, - ) - - -# manually add decomposition to bypass the error that comes -# from VAE encode(inp).latent_dist.sample() failing to symbolically -# trace from torch fx. -# diffusers side issue: https://github.com/huggingface/diffusers/issues/6239 -# temporary torch fix: https://github.com/pytorch/pytorch/issues/107170 -@register_decomposition(torch.ops.aten.randn.generator) -@out_wrapper() -def randn_generator( - *shape, - generator: Optional[torch.Generator] = None, - dtype: Optional[torch.dtype] = None, - device: Optional[DeviceLikeType] = None, - layout: Optional[torch.layout] = None, - requires_grad: bool = False, - pin_memory: bool = False, -) -> TensorLikeType: - # We should eventually support the generator overload. - # However, if someone passes in a None generator explicitly, - # we can jut fall back to randn.default - if generator is None: - return _refs.randn( - *shape, - dtype=dtype, - device=device, - layout=layout, - requires_grad=requires_grad, - pin_memory=pin_memory, - ) - return NotImplemented - - def apply_decompositions( gm: torch.fx.GraphModule, example_inputs, diff --git a/python/shark_turbine/dynamo/utils.py b/python/shark_turbine/dynamo/utils.py new file mode 100644 index 000000000..ef9789531 --- /dev/null +++ b/python/shark_turbine/dynamo/utils.py @@ -0,0 +1,91 @@ +import torch +from torch._prims_common.wrappers import out_wrapper +from torch._prims_common import ( + DeviceLikeType, + TensorLikeType, +) +import torch._refs as _refs +from torch._decomp import get_decompositions, register_decomposition +from torch import Tensor +from typing import Dict, List, Tuple, Optional + + +@register_decomposition(torch.ops.aten._scaled_dot_product_flash_attention.default) +def scaled_dot_product_flash_attention( + query, + key, + value, + dropout_p: float = 0.0, + is_causal: bool = False, + return_debug_mask: bool = False, + *, + scale: float = None, +) -> Tuple[Tensor, Tensor, Tensor, Tensor, int, int, Tensor, Tensor, Tensor]: + dtype = query.dtype + batchSize, num_head, qSize, headSize = ( + query.shape[0], + query.shape[1], + query.shape[2], + query.shape[3], + ) + + logsumexp = torch.empty([batchSize, qSize, num_head, headSize], dtype=torch.float) + cum_seq_q, cum_seq_k = torch.empty([], dtype=torch.long), torch.empty( + [], dtype=torch.long + ) + max_q, max_k = 0, 0 + philox_seed, philox_offset = torch.empty([], dtype=torch.long), torch.empty( + [], dtype=torch.long + ) + debug_attn_mask = torch.empty( + [], + dtype=query.dtype, + device="cpu", + requires_grad=query.requires_grad, + ) + output, _ = torch.ops.aten._scaled_dot_product_attention_math.default( + query, key, value, None, dropout_p, is_causal, None, scale=scale + ) + output = output.transpose(1, 2).contiguous(memory_format=torch.contiguous_format) + return ( + output.transpose(1, 2), + logsumexp, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + philox_seed, + philox_offset, + debug_attn_mask, + ) + + +# manually add decomposition to bypass the error that comes +# from VAE encode(inp).latent_dist.sample() failing to symbolically +# trace from torch fx. 
+# diffusers side issue: https://github.com/huggingface/diffusers/issues/6239 +# temporary torch fix: https://github.com/pytorch/pytorch/issues/107170 +@register_decomposition(torch.ops.aten.randn.generator) +@out_wrapper() +def randn_generator( + *shape, + generator: Optional[torch.Generator] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + layout: Optional[torch.layout] = None, + requires_grad: bool = False, + pin_memory: bool = False, +) -> TensorLikeType: + # We should eventually support the generator overload. + # However, if someone passes in a None generator explicitly, + # we can jut fall back to randn.default + if generator is None: + return _refs.randn( + *shape, + dtype=dtype, + device=device, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + ) + return NotImplemented diff --git a/python/turbine_models/custom_models/sd_inference/vae_runner.py b/python/turbine_models/custom_models/sd_inference/vae_runner.py index e4a24929e..77acaedcb 100644 --- a/python/turbine_models/custom_models/sd_inference/vae_runner.py +++ b/python/turbine_models/custom_models/sd_inference/vae_runner.py @@ -74,7 +74,7 @@ def decode_inp(self, inp): with torch.no_grad(): x = self.vae.decode(inp, return_dict=False)[0] return x - + def encode_inp(self, inp): latents = self.vae.encode(inp).latent_dist.sample() return 0.18215 * latents diff --git a/python/turbine_models/tests/sd_test.py b/python/turbine_models/tests/sd_test.py index 4d515e0fb..125f97d82 100644 --- a/python/turbine_models/tests/sd_test.py +++ b/python/turbine_models/tests/sd_test.py @@ -169,7 +169,10 @@ def testExportVaeModelDecode(self): arguments["external_weight_path"], ) torch_output = vae_runner.run_torch_vae( - arguments["hf_model_name"], arguments["hf_auth_token"], "decode", example_input + arguments["hf_model_name"], + arguments["hf_auth_token"], + "decode", + example_input, ) err = utils.largest_error(torch_output, turbine) assert err < 9e-5 @@ -211,7 +214,10 @@ def testExportVaeModelEncode(self): arguments["external_weight_path"], ) torch_output = vae_runner.run_torch_vae( - arguments["hf_model_name"], arguments["hf_auth_token"], "encode", example_input + arguments["hf_model_name"], + arguments["hf_auth_token"], + "encode", + example_input, ) err = utils.largest_error(torch_output, turbine) assert err < 2e-3 From 1ba10b7a049f13c29288ca75d1a32085f73036c2 Mon Sep 17 00:00:00 2001 From: aviator19941 Date: Wed, 27 Dec 2023 21:18:58 +0000 Subject: [PATCH 6/7] Add get_decompositions import for passes.py --- python/shark_turbine/dynamo/passes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/shark_turbine/dynamo/passes.py b/python/shark_turbine/dynamo/passes.py index ace933529..88c08f6ad 100644 --- a/python/shark_turbine/dynamo/passes.py +++ b/python/shark_turbine/dynamo/passes.py @@ -1,5 +1,6 @@ import torch from torch.fx.experimental.proxy_tensor import make_fx +from torch._decomp import get_decompositions from shark_turbine.dynamo import utils from torch.func import functionalize from typing import List From f97b1ea560178371a1e14c0fd9ae496aedc4af62 Mon Sep 17 00:00:00 2001 From: aviator19941 Date: Wed, 27 Dec 2023 23:01:38 +0000 Subject: [PATCH 7/7] Add expected torch stable version --- python/shark_turbine/dynamo/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/shark_turbine/dynamo/utils.py b/python/shark_turbine/dynamo/utils.py index ef9789531..6429c2444 100644 --- a/python/shark_turbine/dynamo/utils.py +++ 
b/python/shark_turbine/dynamo/utils.py @@ -63,8 +63,9 @@ def scaled_dot_product_flash_attention( # manually add decomposition to bypass the error that comes # from VAE encode(inp).latent_dist.sample() failing to symbolically # trace from torch fx. +# Expected Torch stable version: > 2.1.0 # diffusers side issue: https://github.com/huggingface/diffusers/issues/6239 -# temporary torch fix: https://github.com/pytorch/pytorch/issues/107170 +# temporary Torch fix: https://github.com/pytorch/pytorch/issues/107170 @register_decomposition(torch.ops.aten.randn.generator) @out_wrapper() def randn_generator(
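
Patches 3/7, 5/7, and 7/7 leave the randn.generator workaround registered in torch._decomp's global table as a side effect of importing shark_turbine.dynamo.utils (which passes.py does after 5/7). Below is a minimal sketch, assuming the environment these patches target (torch stable > 2.1.0, shark_turbine importable), of how that registration can be checked and exercised; it is a sketch for illustration, not code from the patches:

import torch
from torch._decomp import get_decompositions

# Importing the module runs the @register_decomposition decorators added above,
# so aten.randn.generator now has a Python decomposition in the global table.
import shark_turbine.dynamo.utils  # noqa: F401

decomps = get_decompositions([torch.ops.aten.randn.generator])
assert torch.ops.aten.randn.generator in decomps

# With an explicit generator=None, the case VAE encode(inp).latent_dist.sample()
# hits, the decomposition falls back to a plain randn, which traces cleanly.
latent_noise = decomps[torch.ops.aten.randn.generator](
    1, 4, 64, 64, generator=None, dtype=torch.float32
)
print(latent_noise.shape)  # torch.Size([1, 4, 64, 64])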