diff --git a/diffusion/inference/inference_model.py b/diffusion/inference/inference_model.py index c4d95b83..aecf07c5 100644 --- a/diffusion/inference/inference_model.py +++ b/diffusion/inference/inference_model.py @@ -35,10 +35,16 @@ class StableDiffusionInference(): Default: ``None``. """ - def __init__(self, pretrained: bool = False): + def __init__(self, pretrained: bool = False, prediction_type: str = 'epsilon'): self.device = torch.cuda.current_device() - model = stable_diffusion_2(pretrained=pretrained, encode_latents_in_fp16=True, fsdp=False) + model = stable_diffusion_2( + pretrained=pretrained, + prediction_type=prediction_type, + encode_latents_in_fp16=True, + fsdp=False, + ) + if not pretrained: state_dict = torch.load(LOCAL_CHECKPOINT_PATH) for key in list(state_dict['state']['model'].keys()):