Skip to content

Commit

Permalink
Fix SharkEulerDiscrete (#2022)
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet authored Dec 6, 2023
1 parent c74b55f commit 2d6f488
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
steps_offset,
)
# TODO: make it dynamic so we dont have to worry about batch size
self.batch_size = None
self.batch_size = 1

def compile(self, batch_size=1):
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
Expand Down Expand Up @@ -171,8 +171,9 @@ def _import(self):
_import(self)

def scale_model_input(self, sample, timestep):
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
return self.scaling_model(
"forward",
(
Expand Down Expand Up @@ -213,21 +214,25 @@ def step(
else noise_pred
)

if gamma > 0:
noise = randn_tensor(
torch.Size(noise_pred.shape),
dtype=torch.float16,
device="cpu",
generator=generator,
)
noise = randn_tensor(
torch.Size(noise_pred.shape),
dtype=torch.float16,
device="cpu",
generator=generator,
)

eps = noise * s_noise

eps = noise * s_noise
if gamma > 0:
latent = latent + eps * (sigma_hat**2 - sigma**2) ** 0.5

if self.config.prediction_type == "v_prediction":
sigma_hat = sigma

dt = self.sigmas[self.step_index + 1] - sigma_hat

self._step_index += 1

return self.step_model(
"forward",
(
Expand Down

0 comments on commit 2d6f488

Please sign in to comment.