Vae encode example with test #294

Merged 7 commits on Dec 27, 2023
56 changes: 3 additions & 53 deletions python/shark_turbine/dynamo/passes.py
@@ -1,9 +1,9 @@
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions, register_decomposition
from torch._decomp import get_decompositions
from shark_turbine.dynamo import utils
from torch.func import functionalize
from torch import Tensor
from typing import Dict, List, Tuple
from typing import List

# default decompositions pulled from SHARK / torch._decomp
DEFAULT_DECOMPOSITIONS = [
@@ -53,56 +53,6 @@
]


@register_decomposition(torch.ops.aten._scaled_dot_product_flash_attention.default)
def scaled_dot_product_flash_attention(
query,
key,
value,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: float = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, int, int, Tensor, Tensor, Tensor]:
dtype = query.dtype
batchSize, num_head, qSize, headSize = (
query.shape[0],
query.shape[1],
query.shape[2],
query.shape[3],
)

logsumexp = torch.empty([batchSize, qSize, num_head, headSize], dtype=torch.float)
cum_seq_q, cum_seq_k = torch.empty([], dtype=torch.long), torch.empty(
[], dtype=torch.long
)
max_q, max_k = 0, 0
philox_seed, philox_offset = torch.empty([], dtype=torch.long), torch.empty(
[], dtype=torch.long
)
debug_attn_mask = torch.empty(
[],
dtype=query.dtype,
device="cpu",
requires_grad=query.requires_grad,
)
output, _ = torch.ops.aten._scaled_dot_product_attention_math.default(
query, key, value, None, dropout_p, is_causal, None, scale=scale
)
output = output.transpose(1, 2).contiguous(memory_format=torch.contiguous_format)
return (
output.transpose(1, 2),
logsumexp,
cum_seq_q,
cum_seq_k,
max_q,
max_k,
philox_seed,
philox_offset,
debug_attn_mask,
)


def apply_decompositions(
gm: torch.fx.GraphModule,
example_inputs,
91 changes: 91 additions & 0 deletions python/shark_turbine/dynamo/utils.py
@@ -0,0 +1,91 @@
import torch
from torch._prims_common.wrappers import out_wrapper
from torch._prims_common import (
DeviceLikeType,
TensorLikeType,
)
import torch._refs as _refs
from torch._decomp import get_decompositions, register_decomposition
from torch import Tensor
from typing import Dict, List, Tuple, Optional


@register_decomposition(torch.ops.aten._scaled_dot_product_flash_attention.default)
(Contributor review comment: I don't love this being here, but I can't think of a better place for now.)

def scaled_dot_product_flash_attention(
query,
key,
value,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: float = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, int, int, Tensor, Tensor, Tensor]:
dtype = query.dtype
batchSize, num_head, qSize, headSize = (
query.shape[0],
query.shape[1],
query.shape[2],
query.shape[3],
)

logsumexp = torch.empty([batchSize, qSize, num_head, headSize], dtype=torch.float)
cum_seq_q, cum_seq_k = torch.empty([], dtype=torch.long), torch.empty(
[], dtype=torch.long
)
max_q, max_k = 0, 0
philox_seed, philox_offset = torch.empty([], dtype=torch.long), torch.empty(
[], dtype=torch.long
)
debug_attn_mask = torch.empty(
[],
dtype=query.dtype,
device="cpu",
requires_grad=query.requires_grad,
)
output, _ = torch.ops.aten._scaled_dot_product_attention_math.default(
query, key, value, None, dropout_p, is_causal, None, scale=scale
)
output = output.transpose(1, 2).contiguous(memory_format=torch.contiguous_format)
return (
output.transpose(1, 2),
logsumexp,
cum_seq_q,
cum_seq_k,
max_q,
max_k,
philox_seed,
philox_offset,
debug_attn_mask,
)
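
A quick way to sanity-check this decomposition outside the PR: the output is transposed twice and only made contiguous in between, so its values should match the reference math kernel exactly. A minimal sketch (the shapes and tensors below are my own, not from the PR):

# Sanity sketch (not part of this PR): compare the decomposition's attention output
# against the reference math kernel it delegates to.
import torch
q = torch.rand(1, 8, 16, 64)
k = torch.rand(1, 8, 16, 64)
v = torch.rand(1, 8, 16, 64)
out = scaled_dot_product_flash_attention(q, k, v)[0]
ref, _ = torch.ops.aten._scaled_dot_product_attention_math.default(
    q, k, v, None, 0.0, False, None, scale=None
)
torch.testing.assert_close(out, ref)  # only the memory layout differs; values are identical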


# Manually register this decomposition to work around the error from
# VAE encode(inp).latent_dist.sample() failing to trace symbolically through torch.fx.
# diffusers side issue: https://github.com/huggingface/diffusers/issues/6239
# temporary torch fix: https://github.com/pytorch/pytorch/issues/107170
@register_decomposition(torch.ops.aten.randn.generator)
@out_wrapper()
def randn_generator(
*shape,
generator: Optional[torch.Generator] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[DeviceLikeType] = None,
layout: Optional[torch.layout] = None,
requires_grad: bool = False,
pin_memory: bool = False,
) -> TensorLikeType:
# We should eventually support the generator overload.
# However, if someone passes in a None generator explicitly,
# we can just fall back to randn.default
if generator is None:
return _refs.randn(
*shape,
dtype=dtype,
device=device,
layout=layout,
requires_grad=requires_grad,
pin_memory=pin_memory,
)
return NotImplemented
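
To illustrate why this registration unblocks tracing, here is a minimal sketch (not part of the PR; sample_latents is a hypothetical stand-in for what latent_dist.sample() does): pulling the decomposition into make_fx rewrites aten.randn.generator calls with a None generator into plain randn.

# Tracing sketch (assumption: importing this module registers the decomposition above).
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions

def sample_latents(mean, std):
    # Hypothetical stand-in for latent_dist.sample().
    noise = torch.ops.aten.randn.generator(mean.shape, generator=None, dtype=mean.dtype)
    return mean + std * noise

decomps = get_decompositions([torch.ops.aten.randn.generator])
gm = make_fx(sample_latents, decomposition_table=decomps)(
    torch.zeros(1, 4, 64, 64), torch.ones(1, 4, 64, 64)
)
print(gm.graph)  # randn.generator should no longer appear; it is decomposed to plain randn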
16 changes: 14 additions & 2 deletions python/turbine_models/custom_models/sd_inference/vae.py
@@ -53,6 +53,7 @@
help="Specify vulkan target triple or rocm/cuda target device.",
)
parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")
parser.add_argument("--variant", type=str, default="decode")


class VaeModel(torch.nn.Module):
@@ -64,11 +65,15 @@ def __init__(self, hf_model_name, hf_auth_token):
token=hf_auth_token,
)

def forward(self, inp):
def decode_inp(self, inp):
with torch.no_grad():
x = self.vae.decode(inp, return_dict=False)[0]
return x

def encode_inp(self, inp):
latents = self.vae.encode(inp).latent_dist.sample()
return 0.18215 * latents


def export_vae_model(
vae_model,
@@ -83,19 +88,25 @@ def export_vae_model(
device=None,
target_triple=None,
max_alloc=None,
variant="decode",
):
mapper = {}
utils.save_external_weights(
mapper, vae_model, external_weights, external_weight_path
)

sample = (batch_size, 4, height // 8, width // 8)
if variant == "encode":
sample = (batch_size, 3, height, width)

class CompiledVae(CompiledModule):
params = export_parameters(vae_model)

def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)):
return jittable(vae_model.forward)(inp)
if variant == "decode":
return jittable(vae_model.decode_inp)(inp)
elif variant == "encode":
return jittable(vae_model.encode_inp)(inp)

import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
inst = CompiledVae(context=Context(), import_to=import_to)
@@ -127,6 +138,7 @@ def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)):
args.device,
args.iree_target_triple,
args.vulkan_max_allocation,
args.variant,
)
safe_name = utils.create_safe_name(args.hf_model_name, "-vae")
with open(f"{safe_name}.mlir", "w+") as f:
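For reference, the two variants take different input shapes: decode consumes latents of shape (batch, 4, height // 8, width // 8), while encode consumes images of shape (batch, 3, height, width) and multiplies the sampled latents by the Stable Diffusion scaling factor 0.18215. A minimal eager-mode sketch using the VaeModel class above (the 512x512 size is illustrative, not from the PR):

# Eager-mode sketch (assumes diffusers is installed; VaeModel is the class defined above).
import torch

model = VaeModel("CompVis/stable-diffusion-v1-4", hf_auth_token=None)  # public model, no token

image = torch.rand(1, 3, 512, 512)             # encode input: (batch, 3, H, W)
latents = model.encode_inp(image)              # -> (1, 4, 64, 64), already scaled by 0.18215

decoded = model.decode_inp(latents / 0.18215)  # SD convention: undo the scaling before decode
print(latents.shape, decoded.shape)            # (1, 4, 64, 64) and (1, 3, 512, 512)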
29 changes: 21 additions & 8 deletions python/turbine_models/custom_models/sd_inference/vae_runner.py
@@ -45,6 +45,7 @@
"--height", type=int, default=512, help="Height of Stable Diffusion"
)
parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion")
parser.add_argument("--variant", type=str, default="decode")


def run_vae(
@@ -57,7 +58,7 @@ def run_vae(
return results


def run_torch_vae(hf_model_name, hf_auth_token, example_input):
def run_torch_vae(hf_model_name, hf_auth_token, variant, example_input):
from diffusers import AutoencoderKL

class VaeModel(torch.nn.Module):
@@ -69,26 +70,38 @@ def __init__(self, hf_model_name, hf_auth_token):
token=hf_auth_token,
)

def forward(self, inp):
def decode_inp(self, inp):
with torch.no_grad():
x = self.vae.decode(inp, return_dict=False)[0]
return x

def encode_inp(self, inp):
latents = self.vae.encode(inp).latent_dist.sample()
return 0.18215 * latents

vae_model = VaeModel(
hf_model_name,
hf_auth_token,
)

results = vae_model.forward(example_input)
if variant == "decode":
results = vae_model.decode_inp(example_input)
elif variant == "encode":
results = vae_model.encode_inp(example_input)
np_torch_output = results.detach().cpu().numpy()
return np_torch_output


if __name__ == "__main__":
args = parser.parse_args()
example_input = torch.rand(
args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32
)
if args.variant == "decode":
example_input = torch.rand(
args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32
)
elif args.variant == "encode":
example_input = torch.rand(
args.batch_size, 3, args.height, args.width, dtype=torch.float32
)
print("generating turbine output:")
turbine_results = run_vae(
args.device,
@@ -109,12 +122,12 @@ def forward(self, inp):
from turbine_models.custom_models.sd_inference import utils

torch_output = run_torch_vae(
args.hf_model_name, args.hf_auth_token, example_input
args.hf_model_name, args.hf_auth_token, args.variant, example_input
)
print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype)
err = utils.largest_error(torch_output, turbine_results)
print("Largest Error: ", err)
assert err < 9e-5
assert err < 2e-3

# TODO: Figure out why we occasionally segfault without unlinking output variables
turbine_results = None
53 changes: 51 additions & 2 deletions python/turbine_models/tests/sd_test.py
@@ -134,7 +134,7 @@ def testExportUnetModel(self):
os.remove("stable_diffusion_v1_4_unet.safetensors")
os.remove("stable_diffusion_v1_4_unet.vmfb")

def testExportVaeModel(self):
def testExportVaeModelDecode(self):
with self.assertRaises(SystemExit) as cm:
vae.export_vae_model(
vae_model,
@@ -148,6 +148,7 @@ def testExportVaeModel(self):
"safetensors",
"stable_diffusion_v1_4_vae.safetensors",
"cpu",
variant="decode",
)
self.assertEqual(cm.exception.code, None)
arguments["external_weight_path"] = "stable_diffusion_v1_4_vae.safetensors"
@@ -168,13 +169,61 @@ def testExportVaeModel(self):
arguments["external_weight_path"],
)
torch_output = vae_runner.run_torch_vae(
arguments["hf_model_name"], arguments["hf_auth_token"], example_input
arguments["hf_model_name"],
arguments["hf_auth_token"],
"decode",
example_input,
)
err = utils.largest_error(torch_output, turbine)
assert err < 9e-5
os.remove("stable_diffusion_v1_4_vae.safetensors")
os.remove("stable_diffusion_v1_4_vae.vmfb")

def testExportVaeModelEncode(self):
with self.assertRaises(SystemExit) as cm:
vae.export_vae_model(
vae_model,
# This is a public model, so no auth required
"CompVis/stable-diffusion-v1-4",
arguments["batch_size"],
arguments["height"],
arguments["width"],
None,
"vmfb",
"safetensors",
"stable_diffusion_v1_4_vae.safetensors",
"cpu",
variant="encode",
)
self.assertEqual(cm.exception.code, None)
arguments["external_weight_path"] = "stable_diffusion_v1_4_vae.safetensors"
arguments["vmfb_path"] = "stable_diffusion_v1_4_vae.vmfb"
example_input = torch.rand(
arguments["batch_size"],
3,
arguments["height"],
arguments["width"],
dtype=torch.float32,
)
turbine = vae_runner.run_vae(
arguments["device"],
example_input,
arguments["vmfb_path"],
arguments["hf_model_name"],
arguments["hf_auth_token"],
arguments["external_weight_path"],
)
torch_output = vae_runner.run_torch_vae(
arguments["hf_model_name"],
arguments["hf_auth_token"],
"encode",
example_input,
)
err = utils.largest_error(torch_output, turbine)
assert err < 2e-3
os.remove("stable_diffusion_v1_4_vae.safetensors")
os.remove("stable_diffusion_v1_4_vae.vmfb")


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)