-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathPerformancePipeline.py
21 lines (20 loc) · 1 KB
/
PerformancePipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# PerformancePipeline Or PP for convenience (Hahah- ok, I'm sorry)
from diffusers import StableDiffusionPipeline
import torch
def from_pretrained(model_name, safety_checker=None):
    """Load a Stable Diffusion pipeline in float16 and move it to CUDA.

    Side effect: sets torch's process-wide default dtype to ``torch.float16``
    before loading.

    A Hub revision is chosen per model:
      * "diffusers-115k" for naclbit/trinart_stable_diffusion_v2,
      * no revision argument for SG161222/Realistic_Vision_V2.0,
      * "fp16" for every other model.
    If loading with a revision fails, the model is loaded once more without
    a revision argument.

    Args:
        model_name: Model identifier passed to
            ``StableDiffusionPipeline.from_pretrained``.
        safety_checker: Forwarded unchanged to
            ``StableDiffusionPipeline.from_pretrained``. The default ``None``
            is passed through as-is. NOTE(review): in diffusers this is
            commonly how the safety checker is disabled; confirm.

    Returns:
        The loaded pipeline, or ``None`` if no load attempt succeeded. If the
        pipeline loaded but ``pipe.to("cuda")`` failed, the error is printed
        and the already-loaded pipeline is returned (not moved to CUDA).
    """
    # NOTE(review): this affects every tensor created afterwards in the
    # process, not only this pipeline. The torch docs list only float32/float64
    # as supported default dtypes; kept for backward compatibility. Confirm
    # whether callers rely on it.
    torch.set_default_dtype(torch.float16)

    # Per-model revision overrides; None means "omit the revision argument".
    revision = {
        "naclbit/trinart_stable_diffusion_v2": "diffusers-115k",
        "SG161222/Realistic_Vision_V2.0": None,
    }.get(model_name, "fp16")
    load_kwargs = {"torch_dtype": torch.float16, "safety_checker": safety_checker}

    pipe = None
    try:
        if revision is None:
            pipe = StableDiffusionPipeline.from_pretrained(model_name, **load_kwargs)
        else:
            try:
                pipe = StableDiffusionPipeline.from_pretrained(
                    model_name, revision=revision, **load_kwargs
                )
            except Exception:
                # The requested revision may not exist for this repository
                # (e.g. no such branch); retry once without a revision.
                # Only the load is retried, so a later CUDA failure does not
                # trigger a second full model load.
                pipe = StableDiffusionPipeline.from_pretrained(model_name, **load_kwargs)
        pipe.to("cuda")
    except Exception as e:
        # Best effort: report the failure and hand back whatever was loaded
        # (possibly None) rather than raising.
        # `except Exception` (not a bare except) lets Ctrl-C and SystemExit
        # propagate instead of being swallowed.
        print("Failed to load model %s: %s" % (model_name, e))
    return pipe