
Commit 669b145

Fix black formatting and keep decomps in separate file
1 parent 7ae597c commit 669b145

File tree

4 files changed: +102 -93 lines changed


python/shark_turbine/dynamo/passes.py

Lines changed: 2 additions & 90 deletions
@@ -1,15 +1,8 @@
 import torch
 from torch.fx.experimental.proxy_tensor import make_fx
-from torch._decomp import get_decompositions, register_decomposition
-from torch._prims_common.wrappers import out_wrapper
-from torch._prims_common import (
-    DeviceLikeType,
-    TensorLikeType,
-)
-import torch._refs as _refs
+from shark_turbine.dynamo import utils
 from torch.func import functionalize
-from torch import Tensor
-from typing import Dict, List, Tuple, Optional
+from typing import List
 
 # default decompositions pulled from SHARK / torch._decomp
 DEFAULT_DECOMPOSITIONS = [
@@ -59,87 +52,6 @@
 ]
 
 
-@register_decomposition(torch.ops.aten._scaled_dot_product_flash_attention.default)
-def scaled_dot_product_flash_attention(
-    query,
-    key,
-    value,
-    dropout_p: float = 0.0,
-    is_causal: bool = False,
-    return_debug_mask: bool = False,
-    *,
-    scale: float = None,
-) -> Tuple[Tensor, Tensor, Tensor, Tensor, int, int, Tensor, Tensor, Tensor]:
-    dtype = query.dtype
-    batchSize, num_head, qSize, headSize = (
-        query.shape[0],
-        query.shape[1],
-        query.shape[2],
-        query.shape[3],
-    )
-
-    logsumexp = torch.empty([batchSize, qSize, num_head, headSize], dtype=torch.float)
-    cum_seq_q, cum_seq_k = torch.empty([], dtype=torch.long), torch.empty(
-        [], dtype=torch.long
-    )
-    max_q, max_k = 0, 0
-    philox_seed, philox_offset = torch.empty([], dtype=torch.long), torch.empty(
-        [], dtype=torch.long
-    )
-    debug_attn_mask = torch.empty(
-        [],
-        dtype=query.dtype,
-        device="cpu",
-        requires_grad=query.requires_grad,
-    )
-    output, _ = torch.ops.aten._scaled_dot_product_attention_math.default(
-        query, key, value, None, dropout_p, is_causal, None, scale=scale
-    )
-    output = output.transpose(1, 2).contiguous(memory_format=torch.contiguous_format)
-    return (
-        output.transpose(1, 2),
-        logsumexp,
-        cum_seq_q,
-        cum_seq_k,
-        max_q,
-        max_k,
-        philox_seed,
-        philox_offset,
-        debug_attn_mask,
-    )
-
-
-# manually add decomposition to bypass the error that comes
-# from VAE encode(inp).latent_dist.sample() failing to symbolically
-# trace from torch fx.
-# diffusers side issue: https://github.com/huggingface/diffusers/issues/6239
-# temporary torch fix: https://github.com/pytorch/pytorch/issues/107170
-@register_decomposition(torch.ops.aten.randn.generator)
-@out_wrapper()
-def randn_generator(
-    *shape,
-    generator: Optional[torch.Generator] = None,
-    dtype: Optional[torch.dtype] = None,
-    device: Optional[DeviceLikeType] = None,
-    layout: Optional[torch.layout] = None,
-    requires_grad: bool = False,
-    pin_memory: bool = False,
-) -> TensorLikeType:
-    # We should eventually support the generator overload.
-    # However, if someone passes in a None generator explicitly,
-    # we can just fall back to randn.default
-    if generator is None:
-        return _refs.randn(
-            *shape,
-            dtype=dtype,
-            device=device,
-            layout=layout,
-            requires_grad=requires_grad,
-            pin_memory=pin_memory,
-        )
-    return NotImplemented
-
-
 def apply_decompositions(
     gm: torch.fx.GraphModule,
     example_inputs,
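
Because the decompositions are now registered in shark_turbine.dynamo.utils, passes.py only needs to import that module for the @register_decomposition decorators to run and populate torch's global decomposition registry. The sketch below is illustrative only (apply_decompositions is not shown in full in this hunk, so the exact in-tree usage may differ); it shows one way the registered decomps can be pulled into an FX trace:

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions

# Importing the module runs its @register_decomposition decorators,
# adding the ops it covers to torch's global decomposition registry.
from shark_turbine.dynamo import utils  # noqa: F401

decomp_table = get_decompositions(
    [
        torch.ops.aten._scaled_dot_product_flash_attention.default,
        torch.ops.aten.randn.generator,
    ]
)


def attn(q, k, v):
    # Call the flash-attention op directly so the decomposition is exercised.
    return torch.ops.aten._scaled_dot_product_flash_attention.default(q, k, v)[0]


q = k = v = torch.randn(1, 2, 8, 4)
gm = make_fx(attn, decomposition_table=decomp_table)(q, k, v)
print(gm.graph)  # flash attention is lowered to _scaled_dot_product_attention_math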

python/shark_turbine/dynamo/utils.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+import torch
+from torch._prims_common.wrappers import out_wrapper
+from torch._prims_common import (
+    DeviceLikeType,
+    TensorLikeType,
+)
+import torch._refs as _refs
+from torch._decomp import get_decompositions, register_decomposition
+from torch import Tensor
+from typing import Dict, List, Tuple, Optional
+
+
+@register_decomposition(torch.ops.aten._scaled_dot_product_flash_attention.default)
+def scaled_dot_product_flash_attention(
+    query,
+    key,
+    value,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    return_debug_mask: bool = False,
+    *,
+    scale: float = None,
+) -> Tuple[Tensor, Tensor, Tensor, Tensor, int, int, Tensor, Tensor, Tensor]:
+    dtype = query.dtype
+    batchSize, num_head, qSize, headSize = (
+        query.shape[0],
+        query.shape[1],
+        query.shape[2],
+        query.shape[3],
+    )
+
+    logsumexp = torch.empty([batchSize, qSize, num_head, headSize], dtype=torch.float)
+    cum_seq_q, cum_seq_k = torch.empty([], dtype=torch.long), torch.empty(
+        [], dtype=torch.long
+    )
+    max_q, max_k = 0, 0
+    philox_seed, philox_offset = torch.empty([], dtype=torch.long), torch.empty(
+        [], dtype=torch.long
+    )
+    debug_attn_mask = torch.empty(
+        [],
+        dtype=query.dtype,
+        device="cpu",
+        requires_grad=query.requires_grad,
+    )
+    output, _ = torch.ops.aten._scaled_dot_product_attention_math.default(
+        query, key, value, None, dropout_p, is_causal, None, scale=scale
+    )
+    output = output.transpose(1, 2).contiguous(memory_format=torch.contiguous_format)
+    return (
+        output.transpose(1, 2),
+        logsumexp,
+        cum_seq_q,
+        cum_seq_k,
+        max_q,
+        max_k,
+        philox_seed,
+        philox_offset,
+        debug_attn_mask,
+    )
+
+
+# manually add decomposition to bypass the error that comes
+# from VAE encode(inp).latent_dist.sample() failing to symbolically
+# trace from torch fx.
+# diffusers side issue: https://github.com/huggingface/diffusers/issues/6239
+# temporary torch fix: https://github.com/pytorch/pytorch/issues/107170
+@register_decomposition(torch.ops.aten.randn.generator)
+@out_wrapper()
+def randn_generator(
+    *shape,
+    generator: Optional[torch.Generator] = None,
+    dtype: Optional[torch.dtype] = None,
+    device: Optional[DeviceLikeType] = None,
+    layout: Optional[torch.layout] = None,
+    requires_grad: bool = False,
+    pin_memory: bool = False,
+) -> TensorLikeType:
+    # We should eventually support the generator overload.
+    # However, if someone passes in a None generator explicitly,
+    # we can just fall back to randn.default
+    if generator is None:
+        return _refs.randn(
+            *shape,
+            dtype=dtype,
+            device=device,
+            layout=layout,
+            requires_grad=requires_grad,
+            pin_memory=pin_memory,
+        )
+    return NotImplemented
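
The randn.generator decomposition matters when a traced model reaches torch.randn with an explicit generator argument, which is what diffusers' latent_dist.sample() does and what previously broke FX tracing. A small, hypothetical check (not part of this commit) of the fallback path when generator is None:

import torch
from shark_turbine.dynamo.utils import randn_generator

# With generator=None the decomposition falls through to the plain randn
# reference implementation, so tracing no longer trips over the generator overload.
out = randn_generator(2, 3, generator=None, dtype=torch.float32, device="cpu")
assert out.shape == (2, 3)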

python/turbine_models/custom_models/sd_inference/vae_runner.py

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ def decode_inp(self, inp):
         with torch.no_grad():
             x = self.vae.decode(inp, return_dict=False)[0]
             return x
-
+
     def encode_inp(self, inp):
         latents = self.vae.encode(inp).latent_dist.sample()
         return 0.18215 * latents

python/turbine_models/tests/sd_test.py

Lines changed: 8 additions & 2 deletions
@@ -169,7 +169,10 @@ def testExportVaeModelDecode(self):
             arguments["external_weight_path"],
         )
         torch_output = vae_runner.run_torch_vae(
-            arguments["hf_model_name"], arguments["hf_auth_token"], "decode", example_input
+            arguments["hf_model_name"],
+            arguments["hf_auth_token"],
+            "decode",
+            example_input,
         )
         err = utils.largest_error(torch_output, turbine)
         assert err < 9e-5
@@ -211,7 +214,10 @@ def testExportVaeModelEncode(self):
             arguments["external_weight_path"],
         )
         torch_output = vae_runner.run_torch_vae(
-            arguments["hf_model_name"], arguments["hf_auth_token"], "encode", example_input
+            arguments["hf_model_name"],
+            arguments["hf_auth_token"],
+            "encode",
+            example_input,
        )
         err = utils.largest_error(torch_output, turbine)
         assert err < 2e-3
