@@ -518,46 +518,63 @@ class UNIPCBH2(Sampler):
518
518
def sample (self , model_wrap , sigmas , extra_args , callback , noise , latent_image = None , denoise_mask = None , disable_pbar = False ):
519
519
return uni_pc .sample_unipc (model_wrap , noise , latent_image , sigmas , max_denoise = self .max_denoise (model_wrap , sigmas ), extra_args = extra_args , noise_mask = denoise_mask , callback = callback , variant = 'bh2' , disable = disable_pbar )
520
520
521
- KSAMPLER_NAMES = ["euler" , "euler_ancestral" , "heun" , "dpm_2" , "dpm_2_ancestral" ,
521
+ KSAMPLER_NAMES = ["euler" , "euler_ancestral" , "heun" , "heunpp2" , " dpm_2" , "dpm_2_ancestral" ,
522
522
"lms" , "dpm_fast" , "dpm_adaptive" , "dpmpp_2s_ancestral" , "dpmpp_sde" , "dpmpp_sde_gpu" ,
523
523
"dpmpp_2m" , "dpmpp_2m_sde" , "dpmpp_2m_sde_gpu" , "dpmpp_3m_sde" , "dpmpp_3m_sde_gpu" , "ddpm" , "lcm" ]
524
524
525
- def ksampler (sampler_name , extra_options = {}, inpaint_options = {}):
526
- class KSAMPLER (Sampler ):
527
- def sample (self , model_wrap , sigmas , extra_args , callback , noise , latent_image = None , denoise_mask = None , disable_pbar = False ):
528
- extra_args ["denoise_mask" ] = denoise_mask
529
- model_k = KSamplerX0Inpaint (model_wrap )
530
- model_k .latent_image = latent_image
531
- if inpaint_options .get ("random" , False ): #TODO: Should this be the default?
532
- generator = torch .manual_seed (extra_args .get ("seed" , 41 ) + 1 )
533
- model_k .noise = torch .randn (noise .shape , generator = generator , device = "cpu" ).to (noise .dtype ).to (noise .device )
534
- else :
535
- model_k .noise = noise
525
+ class KSAMPLER (Sampler ):
526
+ def __init__ (self , sampler_function , extra_options = {}, inpaint_options = {}):
527
+ self .sampler_function = sampler_function
528
+ self .extra_options = extra_options
529
+ self .inpaint_options = inpaint_options
536
530
537
- if self .max_denoise (model_wrap , sigmas ):
538
- noise = noise * torch .sqrt (1.0 + sigmas [0 ] ** 2.0 )
539
- else :
540
- noise = noise * sigmas [0 ]
531
+ def sample (self , model_wrap , sigmas , extra_args , callback , noise , latent_image = None , denoise_mask = None , disable_pbar = False ):
532
+ extra_args ["denoise_mask" ] = denoise_mask
533
+ model_k = KSamplerX0Inpaint (model_wrap )
534
+ model_k .latent_image = latent_image
535
+ if self .inpaint_options .get ("random" , False ): #TODO: Should this be the default?
536
+ generator = torch .manual_seed (extra_args .get ("seed" , 41 ) + 1 )
537
+ model_k .noise = torch .randn (noise .shape , generator = generator , device = "cpu" ).to (noise .dtype ).to (noise .device )
538
+ else :
539
+ model_k .noise = noise
541
540
542
- k_callback = None
543
- total_steps = len (sigmas ) - 1
544
- if callback is not None :
545
- k_callback = lambda x : callback (x ["i" ], x ["denoised" ], x ["x" ], total_steps )
541
+ if self .max_denoise (model_wrap , sigmas ):
542
+ noise = noise * torch .sqrt (1.0 + sigmas [0 ] ** 2.0 )
543
+ else :
544
+ noise = noise * sigmas [0 ]
545
+
546
+ k_callback = None
547
+ total_steps = len (sigmas ) - 1
548
+ if callback is not None :
549
+ k_callback = lambda x : callback (x ["i" ], x ["denoised" ], x ["x" ], total_steps )
550
+
551
+ if latent_image is not None :
552
+ noise += latent_image
546
553
554
+ samples = self .sampler_function (model_k , noise , sigmas , extra_args = extra_args , callback = k_callback , disable = disable_pbar , ** self .extra_options )
555
+ return samples
556
+
557
+
558
+ def ksampler (sampler_name , extra_options = {}, inpaint_options = {}):
559
+ if sampler_name == "dpm_fast" :
560
+ def dpm_fast_function (model , noise , sigmas , extra_args , callback , disable ):
547
561
sigma_min = sigmas [- 1 ]
548
562
if sigma_min == 0 :
549
563
sigma_min = sigmas [- 2 ]
564
+ total_steps = len (sigmas ) - 1
565
+ return k_diffusion_sampling .sample_dpm_fast (model , noise , sigma_min , sigmas [0 ], total_steps , extra_args = extra_args , callback = callback , disable = disable )
566
+ sampler_function = dpm_fast_function
567
+ elif sampler_name == "dpm_adaptive" :
568
+ def dpm_adaptive_function (model , noise , sigmas , extra_args , callback , disable ):
569
+ sigma_min = sigmas [- 1 ]
570
+ if sigma_min == 0 :
571
+ sigma_min = sigmas [- 2 ]
572
+ return k_diffusion_sampling .sample_dpm_adaptive (model , noise , sigma_min , sigmas [0 ], extra_args = extra_args , callback = callback , disable = disable )
573
+ sampler_function = dpm_adaptive_function
574
+ else :
575
+ sampler_function = getattr (k_diffusion_sampling , "sample_{}" .format (sampler_name ))
550
576
551
- if latent_image is not None :
552
- noise += latent_image
553
- if sampler_name == "dpm_fast" :
554
- samples = k_diffusion_sampling .sample_dpm_fast (model_k , noise , sigma_min , sigmas [0 ], total_steps , extra_args = extra_args , callback = k_callback , disable = disable_pbar )
555
- elif sampler_name == "dpm_adaptive" :
556
- samples = k_diffusion_sampling .sample_dpm_adaptive (model_k , noise , sigma_min , sigmas [0 ], extra_args = extra_args , callback = k_callback , disable = disable_pbar )
557
- else :
558
- samples = getattr (k_diffusion_sampling , "sample_{}" .format (sampler_name ))(model_k , noise , sigmas , extra_args = extra_args , callback = k_callback , disable = disable_pbar , ** extra_options )
559
- return samples
560
- return KSAMPLER
577
+ return KSAMPLER (sampler_function , extra_options , inpaint_options )
561
578
562
579
def wrap_model (model ):
563
580
model_denoise = CFGNoisePredictor (model )
@@ -618,11 +635,11 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps):
618
635
print ("error invalid scheduler" , self .scheduler )
619
636
return sigmas
620
637
621
- def sampler_class (name ):
638
+ def sampler_object (name ):
622
639
if name == "uni_pc" :
623
- sampler = UNIPC
640
+ sampler = UNIPC ()
624
641
elif name == "uni_pc_bh2" :
625
- sampler = UNIPCBH2
642
+ sampler = UNIPCBH2 ()
626
643
elif name == "ddim" :
627
644
sampler = ksampler ("euler" , inpaint_options = {"random" : True })
628
645
else :
@@ -687,6 +704,6 @@ def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=N
687
704
else :
688
705
return torch .zeros_like (noise )
689
706
690
- sampler = sampler_class (self .sampler )
707
+ sampler = sampler_object (self .sampler )
691
708
692
- return sample (self .model , noise , positive , negative , cfg , self .device , sampler () , sigmas , self .model_options , latent_image = latent_image , denoise_mask = denoise_mask , callback = callback , disable_pbar = disable_pbar , seed = seed )
709
+ return sample (self .model , noise , positive , negative , cfg , self .device , sampler , sigmas , self .model_options , latent_image = latent_image , denoise_mask = denoise_mask , callback = callback , disable_pbar = disable_pbar , seed = seed )
0 commit comments