-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathupscale.py
134 lines (111 loc) · 4.63 KB
/
upscale.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import logging
import os
from typing import List, Optional, Tuple
import PIL
import torch
from diffusers import StableDiffusionUpscalePipeline
from huggingface_hub import file_download
from PIL import ImageFile
from app.pipelines.base import Pipeline
from app.pipelines.utils import (
SafetyChecker,
get_model_dir,
get_torch_device,
is_lightning_model,
is_turbo_model,
)
from app.utils.errors import InferenceError
# Tell Pillow to load truncated image files (module-wide Pillow setting).
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class UpscalePipeline(Pipeline):
    """Image upscaling pipeline backed by ``StableDiffusionUpscalePipeline``.

    The model is loaded once at construction time. Optional optimizations are
    controlled through environment variables read in ``__init__``:

    * ``SFAST`` -- compile the pipeline with stable-fast.
    * ``SFAST_WARMUP`` -- not supported here; only triggers a warning.
    * ``DEEPCACHE`` -- enable DeepCache (skipped for Lightning/Turbo models).
    * ``SAFETY_CHECKER_DEVICE`` -- device passed to ``SafetyChecker``
      (defaults to ``"cuda"``).
    """

    def __init__(self, model_id: str):
        """Load the upscale model and apply the configured optimizations.

        Args:
            model_id: Model repository ID passed to
                ``StableDiffusionUpscalePipeline.from_pretrained``.
        """
        self.model_id = model_id
        model_dir = get_model_dir()
        kwargs = {"cache_dir": model_dir}

        torch_device = get_torch_device()
        folder_name = file_download.repo_folder_name(
            repo_id=model_id, repo_type="model"
        )
        folder_path = os.path.join(model_dir, folder_name)
        # Look through the local model folder for any fp16 safetensors
        # weights; if the folder does not exist, os.walk yields nothing.
        has_fp16_variant = any(
            ".fp16.safetensors" in fname
            for _, _, files in os.walk(folder_path)
            for fname in files
        )
        # Half precision is only requested off-CPU and when fp16 weights exist.
        if torch_device.type != "cpu" and has_fp16_variant:
            logger.info("UpscalePipeline loading fp16 variant for %s", model_id)
            kwargs["torch_dtype"] = torch.float16
            kwargs["variant"] = "fp16"

        self.ldm = StableDiffusionUpscalePipeline.from_pretrained(
            model_id, **kwargs
        ).to(torch_device)

        sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
        deepcache_enabled = os.getenv("DEEPCACHE", "").strip().lower() == "true"
        if sfast_enabled and deepcache_enabled:
            logger.warning(
                "Both 'SFAST' and 'DEEPCACHE' are enabled. This is not recommended "
                "as it may lead to suboptimal performance. Please disable one of them."
            )

        if sfast_enabled:
            logger.info(
                "UpscalePipeline will be dynamically compiled with stable-fast "
                "for %s",
                model_id,
            )
            # Imported lazily so the optimization module is only loaded when
            # SFAST is actually enabled.
            from app.pipelines.optim.sfast import compile_model

            self.ldm = compile_model(self.ldm)

            # Warm-up the pipeline.
            # TODO: Not yet supported for UpscalePipeline.
            # Parsed like SFAST/DEEPCACHE (whitespace- and case-insensitive).
            if os.getenv("SFAST_WARMUP", "true").strip().lower() == "true":
                logger.warning(
                    "The 'SFAST_WARMUP' flag is not yet supported for the "
                    "UpscalePipeline and will be ignored. As a result the first "
                    "call may be slow if 'SFAST' is enabled."
                )

        if deepcache_enabled and not (
            is_lightning_model(model_id) or is_turbo_model(model_id)
        ):
            logger.info(
                "UpscalePipeline will be optimized with DeepCache for %s",
                model_id,
            )
            # Imported lazily so the optimization module is only loaded when
            # DEEPCACHE is actually enabled.
            from app.pipelines.optim.deepcache import enable_deepcache

            self.ldm = enable_deepcache(self.ldm)
        elif deepcache_enabled:
            logger.warning(
                "DeepCache is not supported for Lightning or Turbo models. "
                "UpscalePipeline will NOT be optimized with DeepCache for %s",
                model_id,
            )

        safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
        self._safety_checker = SafetyChecker(device=safety_checker_device)

    def __call__(
        self, prompt: str, image: PIL.Image.Image, **kwargs
    ) -> Tuple[List[PIL.Image.Image], List[Optional[bool]]]:
        """Upscale ``image`` guided by ``prompt``.

        Args:
            prompt: Text prompt forwarded to the underlying pipeline.
            image: Input image forwarded to the underlying pipeline.
            **kwargs: Extra arguments forwarded to the underlying pipeline.
                The following keys are handled here first:

                * ``seed`` (int or list of int): converted into one or more
                  seeded ``torch.Generator`` objects passed as ``generator``.
                  Values of any other type are dropped without effect.
                * ``safety_check`` (bool, default ``True``): whether to run
                  the NSFW safety checker on the outputs.
                * ``num_inference_steps``: removed when ``None`` or below 1 so
                  the underlying pipeline falls back to its own default.

        Returns:
            A tuple ``(images, has_nsfw_concept)``. ``has_nsfw_concept`` holds
            one entry per image: the safety checker result, or ``None`` when
            the safety check is disabled.

        Raises:
            torch.cuda.OutOfMemoryError: Re-raised unchanged.
            InferenceError: Wraps any other exception raised by the pipeline.
        """
        seed = kwargs.pop("seed", None)
        safety_check = kwargs.pop("safety_check", True)

        if seed is not None:
            device = get_torch_device()
            if isinstance(seed, int):
                kwargs["generator"] = torch.Generator(device).manual_seed(seed)
            elif isinstance(seed, list):
                kwargs["generator"] = [
                    torch.Generator(device).manual_seed(s) for s in seed
                ]

        # Drop an unset/invalid step count rather than passing it through.
        if "num_inference_steps" in kwargs:
            steps = kwargs["num_inference_steps"]
            if steps is None or steps < 1:
                del kwargs["num_inference_steps"]

        try:
            outputs = self.ldm(prompt, image=image, **kwargs)
        except torch.cuda.OutOfMemoryError:
            # Out-of-memory errors are deliberately not wrapped.
            raise
        except Exception as e:
            raise InferenceError(original_exception=e) from e

        if safety_check:
            _, has_nsfw_concept = self._safety_checker.check_nsfw_images(
                outputs.images
            )
        else:
            has_nsfw_concept = [None] * len(outputs.images)

        return outputs.images, has_nsfw_concept

    def __str__(self) -> str:
        """Return a short description including the model ID."""
        return f"UpscalePipeline model_id={self.model_id}"