Commit d1fda83
WIP: fixing scalars in jittables / precision
1 parent: 4498486
6 files changed: +153 -157 lines changed

python/turbine_models/custom_models/sd_inference/unet.py
Lines changed: 5 additions & 5 deletions

@@ -62,15 +62,15 @@
 class UnetModel(torch.nn.Module):
-    def __init__(self, hf_model_name, hf_auth_token):
+    def __init__(self, hf_model_name):
         super().__init__()
         self.unet = UNet2DConditionModel.from_pretrained(
             hf_model_name,
             subfolder="unet",
-            token=hf_auth_token,
         )

-    def forward(self, sample, timestep, encoder_hidden_states, guidance_scale):
+    def forward(self, sample, timestep, encoder_hidden_states):
+        guidance_scale = 7.5
         samples = torch.cat([sample] * 2)
         unet_out = self.unet.forward(
             samples, timestep, encoder_hidden_states, return_dict=False

@@ -127,10 +127,10 @@ def main(
         encoder_hidden_states=AbstractTensor(
             *encoder_hidden_states_sizes, dtype=dtype
         ),
-        guidance_scale=AbstractTensor(1, dtype=dtype),
+        #guidance_scale=AbstractTensor(1, dtype=dtype),
     ):
         return jittable(unet_model.forward)(
-            sample, timestep, encoder_hidden_states, guidance_scale
+            sample, timestep, encoder_hidden_states, # guidance_scale
         )

     import_to = "INPUT" if compile_to == "linalg" else "IMPORT"

python/turbine_models/custom_models/sd_inference/unet_runner.py
Lines changed: 8 additions & 8 deletions

@@ -52,7 +52,7 @@ def run_unet(
     sample,
     timestep,
     encoder_hidden_states,
-    guidance_scale,
+    # guidance_scale,
     vmfb_path,
     hf_model_name,
     hf_auth_token,

@@ -64,7 +64,7 @@ def run_unet(
         ireert.asdevicearray(runner.config.device, sample),
         ireert.asdevicearray(runner.config.device, timestep),
         ireert.asdevicearray(runner.config.device, encoder_hidden_states),
-        ireert.asdevicearray(runner.config.device, guidance_scale),
+        # ireert.asdevicearray(runner.config.device, guidance_scale),
     ]
     results = runner.ctx.modules.compiled_unet["main"](*inputs)
     return results

@@ -90,13 +90,13 @@ def __init__(self, hf_model_name, hf_auth_token):
             )
             self.guidance_scale = 7.5

-        def forward(self, sample, timestep, encoder_hidden_states, guidance_scale):
+        def forward(self, sample, timestep, encoder_hidden_states): #, guidance_scale):
             samples = torch.cat([sample] * 2)
             unet_out = self.unet.forward(
                 samples, timestep, encoder_hidden_states, return_dict=False
             )[0]
             noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
-            noise_pred = noise_pred_uncond + guidance_scale * (
+            noise_pred = noise_pred_uncond + self.guidance_scale * (
                 noise_pred_text - noise_pred_uncond
             )
             return noise_pred

@@ -106,7 +106,7 @@ def forward(self, sample, timestep, encoder_hidden_states, guidance_scale):
         hf_auth_token,
     )
     results = unet_model.forward(
-        sample, timestep, encoder_hidden_states, guidance_scale
+        sample, timestep, encoder_hidden_states, #guidance_scale
    )
     np_torch_output = results.detach().cpu().numpy()
     return np_torch_output

@@ -118,7 +118,7 @@ def forward(self, sample, timestep, encoder_hidden_states, guidance_scale):
         args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32
     )
     timestep = torch.zeros(1, dtype=torch.float32)
-    guidance_scale = torch.Tensor([7.5], dtype=torch.float32)
+    # guidance_scale = torch.Tensor([7.5], dtype=torch.float32)
     if args.hf_model_name == "CompVis/stable-diffusion-v1-4":
         encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32)
     elif args.hf_model_name == "stabilityai/stable-diffusion-2-1-base":

@@ -129,7 +129,7 @@ def forward(self, sample, timestep, encoder_hidden_states, guidance_scale):
         sample,
         timestep,
         encoder_hidden_states,
-        guidance_scale,
+        # guidance_scale,
         args.vmfb_path,
         args.hf_model_name,
         args.hf_auth_token,

@@ -152,7 +152,7 @@ def forward(self, sample, timestep, encoder_hidden_states, guidance_scale):
         sample,
         timestep,
         encoder_hidden_states,
-        guidance_scale,
+        # guidance_scale,
     )
     print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype)
     err = utils.largest_error(torch_output, turbine_output)
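The final comparison relies on utils.largest_error, which is not shown in this diff. A minimal sketch of what such a metric typically computes (assumed implementation, names illustrative):

import numpy as np

def largest_error(expected, actual):
    # Largest absolute elementwise difference between the two outputs.
    return np.max(np.abs(np.asarray(expected) - np.asarray(actual)))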

python/turbine_models/custom_models/sd_inference/utils.py
Lines changed: 1 addition & 1 deletion

@@ -16,6 +16,7 @@ def save_external_weights(
     for name in mod_params:
         mapper["params." + name] = name
     if external_weight_file:
+        print("Saving params to", external_weight_file)
         safetensors.torch.save_file(mod_params, external_weight_file)
         print("Saved params to", external_weight_file)

@@ -35,7 +36,6 @@ def compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name):
         "--iree-llvmcpu-target-triple=x86_64-linux-gnu",
         "--iree-stream-resource-index-bits=64",
         "--iree-vm-target-index-bits=64",
-        "--iree-codegen-check-ir-before-llvm-conversion=false",
         "--iree-opt-const-expr-hoisting=False",
     ]
     if device == "cpu":

python/turbine_models/custom_models/sd_inference/vae.py
Lines changed: 1 addition & 1 deletion

@@ -63,7 +63,7 @@ def __init__(
         super().__init__()
         self.vae = None
         self.base_vae = False
-        if custom_vae == "":
+        if custom_vae in ["", None]:
             self.vae = AutoencoderKL.from_pretrained(
                 hf_model_name,
                 subfolder="vae",

python/turbine_models/custom_models/sd_inference/vae_runner.py
Lines changed: 2 additions & 2 deletions

@@ -51,7 +51,7 @@ def run_vae(device, example_input, vmfb_path, hf_model_name, external_weight_pat
     return results


-def run_torch_vae(hf_model_name, hf_auth_token, variant, example_input):
+def run_torch_vae(hf_model_name, variant, example_input):
     from diffusers import AutoencoderKL

     class VaeModel(torch.nn.Module):

@@ -89,7 +89,7 @@ def __init__(
             self.vae.load_state_dict(custom_vae)
             self.base_vae = base_vae

-        def decode_inp(self, inp):
+        def decode_inp(self, input):
             with torch.no_grad():
                 if not self.base_vae:
                     input = 1 / 0.18215 * input
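Background on the 1 / 0.18215 factor above: Stable Diffusion latents are multiplied by the VAE's scaling factor (approximately 0.18215) after encoding, so decoding has to undo that scaling first unless the raw/base VAE is used. A minimal sketch mirroring decode_inp (helper name and signature are illustrative):

import torch

def decode_latents(vae, latents, base_vae=False):
    with torch.no_grad():
        if not base_vae:
            latents = latents / 0.18215  # undo encode-time scaling
        # vae is a diffusers AutoencoderKL; decode returns the image tensor.
        return vae.decode(latents, return_dict=False)[0]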
