From 41f49c3c9eb42b7a4a0b6da9d81074d016bf57a5 Mon Sep 17 00:00:00 2001 From: coryMosaicML <83666378+coryMosaicML@users.noreply.github.com> Date: Wed, 9 Aug 2023 10:13:40 -0700 Subject: [PATCH] Update deployment code to use explicit downloader (#54) --- diffusion/inference/inference_model.py | 19 +++++++++++++------ diffusion/inference/mosaic_inference.yaml | 5 +++++ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/diffusion/inference/inference_model.py b/diffusion/inference/inference_model.py index 99b19e0b..c4d95b83 100644 --- a/diffusion/inference/inference_model.py +++ b/diffusion/inference/inference_model.py @@ -5,7 +5,7 @@ import base64 import io -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List import torch from composer.utils.file_helpers import get_file @@ -17,6 +17,15 @@ LOCAL_CHECKPOINT_PATH = '/tmp/model.pt' +def download_checkpoint(chkpt_path: str): + """Downloads the Stable Diffusion checkpoint to the local filesystem. + + Args: + chkpt_path (str): The path to the local folder, URL or object score that contains the checkpoint. + """ + get_file(path=chkpt_path, destination=LOCAL_CHECKPOINT_PATH) + + class StableDiffusionInference(): """Inference endpoint class for Stable Diffusion. @@ -26,13 +35,11 @@ class StableDiffusionInference(): Default: ``None``. """ - def __init__(self, chkpt_path: Optional[str] = None): - pretrained_flag = chkpt_path is None + def __init__(self, pretrained: bool = False): self.device = torch.cuda.current_device() - model = stable_diffusion_2(pretrained=pretrained_flag, encode_latents_in_fp16=True, fsdp=False) - if not pretrained_flag: - get_file(path=chkpt_path, destination=LOCAL_CHECKPOINT_PATH) + model = stable_diffusion_2(pretrained=pretrained, encode_latents_in_fp16=True, fsdp=False) + if not pretrained: state_dict = torch.load(LOCAL_CHECKPOINT_PATH) for key in list(state_dict['state']['model'].keys()): if 'val_metrics.' in key: diff --git a/diffusion/inference/mosaic_inference.yaml b/diffusion/inference/mosaic_inference.yaml index 39d70323..38f17c0d 100644 --- a/diffusion/inference/mosaic_inference.yaml +++ b/diffusion/inference/mosaic_inference.yaml @@ -10,7 +10,12 @@ integrations: git_branch: main pip_install: .[all] model: + downloader: diffusion.inference.inference_model.download_checkpoint + download_parameters: + chkpt_path: # Path to download the checkpoint to evaluate model_handler: diffusion.inference.inference_model.StableDiffusionInference + model_parameters: + pretrained: false command: | export PYTHONPATH=$PYTHONPATH:/code/diffusion rm /usr/lib/python3/dist-packages/packaging-23.1.dist-info/REQUESTED