diff --git a/diffusion/inference/inference_model.py b/diffusion/inference/inference_model.py index aecf07c5..6576e390 100644 --- a/diffusion/inference/inference_model.py +++ b/diffusion/inference/inference_model.py @@ -35,10 +35,14 @@ class StableDiffusionInference(): Default: ``None``. """ - def __init__(self, pretrained: bool = False, prediction_type: str = 'epsilon'): + def __init__(self, + model_name: str = 'stabilityai/stable-diffusion-2-base', + pretrained: bool = False, + prediction_type: str = 'epsilon'): self.device = torch.cuda.current_device() model = stable_diffusion_2( + model_name=model_name, pretrained=pretrained, prediction_type=prediction_type, encode_latents_in_fp16=True, @@ -68,12 +72,14 @@ def predict(self, model_requests: List[Dict[str, Any]]): # Prompts and negative prompts if available if isinstance(inputs, str): prompts.append(inputs) - elif isinstance(input, Dict): - if 'prompt' not in req: + elif isinstance(inputs, Dict): + if 'prompt' not in inputs: raise RuntimeError('"prompt" must be provided to generate call if using a dict as input') prompts.append(inputs['prompt']) - if 'negative_prompt' in req: + if 'negative_prompt' in inputs: negative_prompts.append(inputs['negative_prompt']) + else: + raise RuntimeError(f'Input must be of type string or dict, but it is type: {type(inputs)}') generate_kwargs = req['parameters']