Commit

fix: replace KSampler2 with KSampler and optimize CFG calculations
Aatrick committed Mar 1, 2025
1 parent 08172d9 commit b940e26
Showing 5 changed files with 377 additions and 251 deletions.
34 changes: 27 additions & 7 deletions modules/sample/CFG.py
@@ -30,10 +30,15 @@ def cfg_function(
#### Returns:
- `torch.Tensor`: The CFG result.
"""
# Check for custom sampler CFG function first
if "sampler_cfg_function" in model_options:
# Precompute differences to avoid redundant operations
cond_diff = x - cond_pred
uncond_diff = x - uncond_pred

args = {
"cond": x - cond_pred,
"uncond": x - uncond_pred,
"cond": cond_diff,
"uncond": uncond_diff,
"cond_scale": cond_scale,
"timestep": timestep,
"input": x,
@@ -45,9 +50,18 @@ def cfg_function(
}
cfg_result = x - model_options["sampler_cfg_function"](args)
else:
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale

for fn in model_options.get("sampler_post_cfg_function", []):
# Standard CFG calculation - optimized to avoid intermediate tensor allocation
# When cond_scale = 1.0, we can just return cond_pred without computation
if math.isclose(cond_scale, 1.0):
cfg_result = cond_pred
else:
# Fused operation: uncond_pred + (cond_pred - uncond_pred) * cond_scale
# Equivalent to: uncond_pred * (1 - cond_scale) + cond_pred * cond_scale
cfg_result = torch.lerp(uncond_pred, cond_pred, cond_scale)

# Apply post-CFG functions if any
post_cfg_functions = model_options.get("sampler_post_cfg_function", [])
if post_cfg_functions:
args = {
"denoised": cfg_result,
"cond": cond,
@@ -59,7 +73,12 @@ def cfg_function(
"model_options": model_options,
"input": x,
}
cfg_result = fn(args)

# Apply each post-CFG function in sequence
for fn in post_cfg_functions:
cfg_result = fn(args)
# Update the denoised result for the next function
args["denoised"] = cfg_result

return cfg_result

@@ -128,6 +147,7 @@ def sampling_function(

class CFGGuider:
"""#### Class for guiding the sampling process with CFG."""

def __init__(self, model_patcher, flux=False):
"""#### Initialize the CFGGuider.
@@ -315,4 +335,4 @@ def sample(
del self.inner_model
del self.conds
del self.loaded_models
return output
return output
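
For reference, a minimal sketch (not part of this commit) of the two ideas in the cfg_function change: torch.lerp fuses the guidance blend previously written as uncond_pred + (cond_pred - uncond_pred) * cond_scale, and a cond_scale of 1.0 lets the blend be skipped entirely. Tensor shapes are invented and the args dict is trimmed to a few illustrative keys; the real function passes more fields than shown here.

import math
import torch

def apply_cfg(cond_pred, uncond_pred, cond_scale, post_cfg_functions=()):
    # Skip the blend when the scale is (numerically) 1.0.
    if math.isclose(cond_scale, 1.0):
        cfg_result = cond_pred
    else:
        # torch.lerp(start, end, w) == start + w * (end - start), fused in one call.
        cfg_result = torch.lerp(uncond_pred, cond_pred, cond_scale)

    # Chain post-CFG callbacks, feeding each one the previous result.
    args = {"denoised": cfg_result, "cond_denoised": cond_pred, "uncond_denoised": uncond_pred}
    for fn in post_cfg_functions:
        cfg_result = fn(args)
        args["denoised"] = cfg_result
    return cfg_result

# Equivalence check against the unfused expression.
cond_pred, uncond_pred = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)
for scale in (1.0, 3.5, 7.5):
    reference = uncond_pred + (cond_pred - uncond_pred) * scale
    assert torch.allclose(apply_cfg(cond_pred, uncond_pred, scale), reference)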
53 changes: 43 additions & 10 deletions modules/sample/ksampler_util.py
@@ -46,6 +46,7 @@ def pre_run_control(model: torch.nn.Module, conds: list) -> None:

def percent_to_timestep_function(a):
return s.percent_to_sigma(a)

if "control" in x:
x["control"].pre_run(model, percent_to_timestep_function)

@@ -158,6 +159,7 @@ def normal_scheduler(
sigs += [0.0]
return torch.FloatTensor(sigs)


def simple_scheduler(model_sampling: torch.nn.Module, steps: int) -> torch.FloatTensor:
"""#### Create a simple scheduler.
@@ -176,21 +178,52 @@ def simple_scheduler(model_sampling: torch.nn.Module, steps: int) -> torch.Float
sigs += [0.0]
return torch.FloatTensor(sigs)


# Implemented based on: https://arxiv.org/abs/2407.12173
def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
total_timesteps = (len(model_sampling.sigmas) - 1)
ts = 1 - np.linspace(0, 1, steps, endpoint=False)
ts = np.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)
"""Creates a beta scheduler for noise levels based on the beta distribution.
This implementation computes the schedule with vectorized NumPy operations,
deduplicating repeated timesteps while preserving their order.
Args:
model_sampling: Model sampling module
steps: Number of steps
alpha: Alpha parameter for beta distribution
beta: Beta parameter for beta distribution
Returns:
torch.FloatTensor: Tensor of sigma values for each step
"""
# Calculate total timesteps once
total_timesteps = len(model_sampling.sigmas) - 1

# Keep a local reference to the model's sigma table to avoid repeated attribute lookups
model_sigmas = model_sampling.sigmas

# Generate evenly spaced values in [0,1) interval
ts_normalized = np.linspace(0, 1, steps, endpoint=False)

# Apply beta inverse CDF to get sampled time points - vectorized operation
ts_beta = scipy.stats.beta.ppf(1 - ts_normalized, alpha, beta)

# Scale to timestep indices and round to integers
ts_indices = np.rint(ts_beta * total_timesteps).astype(np.int32)

# Use numpy's unique function with return_index to efficiently find unique values
# while preserving order
unique_ts, indices = np.unique(ts_indices, return_index=True)
ordered_unique_ts = unique_ts[np.argsort(indices)]

# Map indices to sigma values efficiently
sigs = [float(model_sigmas[idx]) for idx in ordered_unique_ts]

# Add final sigma value of 0.0
sigs.append(0.0)

sigs = []
last_t = -1
for t in ts:
if t != last_t:
sigs += [float(model_sampling.sigmas[int(t)])]
last_t = t
sigs += [0.0]
return torch.FloatTensor(sigs)


def calculate_sigmas(
model_sampling: torch.nn.Module, scheduler_name: str, steps: int
) -> torch.Tensor:
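
As a standalone illustration of the rewritten beta scheduler (again not part of the commit), the core steps are: push evenly spaced quantiles through the inverse CDF of a Beta(alpha, beta) distribution, scale and round to timestep indices, deduplicate them while keeping the descending order, then look up the corresponding sigmas and append 0.0. The toy sigma table below is invented; in the repository the values come from model_sampling.sigmas.

import numpy as np
import scipy.stats
import torch

def beta_sigmas(model_sigmas, steps, alpha=0.6, beta=0.6):
    total_timesteps = len(model_sigmas) - 1
    # Evenly spaced quantiles in [0, 1), mapped through the Beta inverse CDF.
    ts_normalized = np.linspace(0, 1, steps, endpoint=False)
    ts_indices = np.rint(scipy.stats.beta.ppf(1 - ts_normalized, alpha, beta) * total_timesteps).astype(np.int64)
    # np.unique sorts ascending, so reorder by first occurrence to keep the descending schedule.
    unique_ts, first_pos = np.unique(ts_indices, return_index=True)
    ordered = unique_ts[np.argsort(first_pos)]
    sigs = [float(model_sigmas[i]) for i in ordered] + [0.0]
    return torch.FloatTensor(sigs)

# Toy ascending sigma table; real models provide their own.
toy_sigmas = torch.linspace(0.03, 14.6, 1000)
print(beta_sigmas(toy_sigmas, steps=12))  # descending noise levels ending in 0.0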
