Skip to content

Commit

Permalink
Update colab.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Beyondo committed Dec 17, 2022
1 parent 323c926 commit 17af0c2
Showing 1 changed file with 100 additions and 12 deletions.
112 changes: 100 additions & 12 deletions colab.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import patcher, torch, random, time
import patcher, torch, random, time, importlib, os
importlib.reload(patcher)
from IPython import display
from IPython.display import HTML
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline
from diffusers.schedulers import PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DDPMScheduler
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer, CLIPTextModel, CLIPTextConfig
import ClipGuided
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
# Module-level state shared by the functions in this file.
# Name of the model most recently passed to init().
model_name = ""
# Reset to False at the start of init(); set to True once the pipelines are built.
ready = False
# Declared global in init() but not assigned anywhere in the visible code.
tokenizer = None
# StableDiffusionPipeline loaded by init().
pipeline = None
# Set by init() to the same object as `pipeline`.
text2img = None
# Built by init() from `pipeline.components`.
img2img = None
inpaint = None
Expand All @@ -17,29 +24,84 @@ def get_current_image_seed():
return settings['InitialSeed'] + image_id
def get_current_image_uid():
    """Return the uid string for the current image.

    The uid is the literal prefix ``text2img-`` followed by the value of
    ``get_current_image_seed()`` formatted with ``%d``.
    """
    current_seed = get_current_image_seed()
    uid = "text2img-%d" % current_seed
    return uid
# v1.4 = laion/CLIP-ViT-B-32-laion2B-s34B-b79K
# v1.5 = sentence-transformers/clip-ViT-L-14
def create_guided_pipeline(pipeline, clip_model_name="laion/CLIP-ViT-B-32-laion2B-s34B-b79K"):
    """Build a CLIP-guided pipeline that reuses an existing pipeline's parts.

    Args:
        pipeline: Object exposing ``unet``, ``vae``, ``tokenizer`` and
            ``text_encoder`` attributes; they are passed through unchanged.
        clip_model_name: Checkpoint name handed to ``from_pretrained`` for
            both the CLIP model and its feature extractor. Defaults to the
            checkpoint that was previously hard-coded, so existing callers
            are unaffected.

    Returns:
        A ``ClipGuided.CLIPGuidedStableDiffusion`` instance.
    """
    # The CLIP model is loaded in half precision and moved to the first GPU.
    clip_model = CLIPModel.from_pretrained(clip_model_name, torch_dtype=torch.float16).to("cuda:0")
    feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_name, torch_dtype=torch.float16)
    # A fresh scheduler is created instead of reusing the one on `pipeline`.
    scheduler = PNDMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
        skip_prk_steps=True,
    )
    guided_pipeline = ClipGuided.CLIPGuidedStableDiffusion(
        unet=pipeline.unet,
        vae=pipeline.vae,
        tokenizer=pipeline.tokenizer,
        text_encoder=pipeline.text_encoder,
        scheduler=scheduler,
        clip_model=clip_model,
        feature_extractor=feature_extractor,
    )
    return guided_pipeline
def modify_clip_limit(limit):
    """Resize the text encoder's position-embedding table to ``limit`` rows.

    Operates on the module-level ``pipeline``. The existing position
    embedding weights are copied into a new table of ``limit`` rows; rows
    beyond the old table are zero-filled, and when ``limit`` is smaller than
    the old table only the first ``limit`` rows are kept. The tokenizer's
    ``model_max_length`` is set to the same limit.

    Args:
        limit: New number of position-embedding rows / tokenizer max length.
    """
    global pipeline
    text_model = pipeline.text_encoder.text_model
    embeddings = text_model.embeddings

    # Text Encoder
    old_weights = embeddings.position_embedding.weight.data.to("cuda:0")
    input_embeddings = embeddings.token_embedding
    pipeline.text_encoder.config.max_position_embeddings = limit
    # Known issue: re-running the text model's constructor with the updated
    # config, i.e.
    #     pipeline.text_encoder.text_model.__init__(config=pipeline.text_encoder.config)
    # was meant to make the model reload everything from the new config, but
    # it makes the model generate random images. That happens even with the
    # default value of 77, so it is not only because the model was never
    # trained on N tokens. There may still be a way to make it work, but it
    # is not as trivial as it first looked.
    # NOTE(review): the embeddings module may also hold a position-id buffer
    # sized to the old limit - confirm against the installed transformers.
    text_model.to("cuda:0")
    embeddings.token_embedding = input_embeddings

    # Take the embedding width from the existing table instead of assuming 768.
    old_rows, embedding_dim = old_weights.shape
    resized = torch.nn.Embedding(limit, embedding_dim).to(device="cuda:0", dtype=old_weights.dtype)
    # nn.Embedding starts with random weights; clear them so the extra rows
    # really are zero padding.
    resized.weight.data.zero_()
    # Copy only the overlapping rows so a limit below the old size also works.
    kept_rows = min(limit, old_rows)
    resized.weight.data[:kept_rows] = old_weights[:kept_rows]
    embeddings.position_embedding = resized

    # Tokenizer
    pipeline.tokenizer.model_max_length = limit
    pipeline.text_encoder.resize_token_embeddings(len(pipeline.tokenizer))

def init(ModelName):
    """Load the named Stable Diffusion model and build the shared pipelines.

    Records the model name in the module globals and in ``settings``. If no
    CUDA device is available, prints a hint and returns with ``ready`` left
    False. Otherwise it installs the vendor tools, loads the model in half
    precision onto ``cuda:0``, patches it, and derives the img2img and
    inpaint pipelines from the same components. Any exception raised while
    doing so is reported on stdout rather than propagated.

    Args:
        ModelName: Model identifier passed to
            ``StableDiffusionPipeline.from_pretrained``.
    """
    global model_name, ready, pipeline, tokenizer, text2img, img2img, inpaint
    ready = False
    model_name = ModelName
    settings['ModelName'] = ModelName
    if not torch.cuda.is_available():
        print("No GPU found. If you are on Colab, go to Runtime -> Change runtime type, and choose \"GPU\" then click Save.")
        return
    print("Running on -> ", end="")
    print(torch.cuda.get_device_name("cuda:0") + ".")
    try:
        install_vendor()
        print("Initializing model " + model_name + ":")
        torch.set_default_dtype(torch.float16)
        # Pick the repository revision to load; an empty string means the
        # revision argument is left out entirely.
        if model_name == "naclbit/trinart_stable_diffusion_v2":
            rev = "diffusers-115k"
        elif model_name == "prompthero/openjourney":
            rev = ""
        else:
            rev = "fp16"
        # Hook VOIDPipeline to StableDiffusionPipeline
        #import VOIDPipeline, importlib
        #importlib.reload(VOIDPipeline)
        #VOIDPipeline.Hook()
        load_kwargs = {"torch_dtype": torch.float16}
        if rev != "":
            load_kwargs["revision"] = rev
        pipeline = StableDiffusionPipeline.from_pretrained(model_name, **load_kwargs).to("cuda:0")
        #modify_clip_limit(77)
        patcher.patch(pipeline)
        # All three modes share the same underlying components.
        text2img = pipeline
        img2img = StableDiffusionImg2ImgPipeline(**pipeline.components)
        inpaint = StableDiffusionInpaintPipeline(**pipeline.components)
        print("Done.")
        ready = True
        #from IPython.display import clear_output; clear_output()
        display.display(HTML("Model <strong><span style='color: green'>%s</span></strong> has been selected." % model_name))
    except Exception as e:
        # A "502" in the error text is reported as a Huggingface outage.
        # This message has no placeholder, so it must not be %-formatted:
        # doing so raised TypeError here and hid the original error.
        if "502" in str(e):
            print("Received 502 Server Error: Huggingface is currently down.")
        print("Failed to initialize model %s with error %s" % (model_name, e))

def prepare(mode):
Expand All @@ -53,4 +115,30 @@ def prepare(mode):
else:
settings['InitialSeed'] = settings['Seed']
current_mode = mode
torch.cuda.empty_cache()
torch.cuda.empty_cache()

def install_vendor():
    """Clone GFPGAN into ``vendor/`` and install its dependencies and weights.

    Does nothing beyond printing a message when a ``vendor`` path already
    exists. Otherwise creates the directory and runs each setup command
    through the active IPython shell. Any exception is printed rather than
    raised.
    """
    print("Installing vendors -> ", end="")
    import os
    import IPython
    if os.path.exists("vendor"):
        print("Vendor already installed.")
        return
    setup_commands = (
        # GFPGAN
        "git clone https://github.com/TencentARC/GFPGAN.git vendor/GFPGAN &> /dev/null",
        "pip install basicsr &> /dev/null",
        "pip install facexlib &> /dev/null",
        "pip install -q -r vendor/GFPGAN/requirements.txt &> /dev/null",
        "python vendor/GFPGAN/setup.py develop &> /dev/null",
        # used for enhancing the background (non-face) regions
        "pip install realesrgan &> /dev/null",
        # Pretrained weights. wget needs upper-case -P to set the download
        # directory; lower-case -p means --page-requisites and would treat
        # the directory as a second URL.
        "wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth -P experiments/pretrained_models &> /dev/null",
        "wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.8/GFPGANv1.3.pth -P experiments/pretrained_models &> /dev/null",
    )
    try:
        # mkdir raises if the directory appeared in the meantime, so nothing
        # can exist under vendor/ at this point and no cleanup is needed.
        os.mkdir("vendor")
        shell = IPython.get_ipython()
        for command in setup_commands:
            shell.system(command)
        print("Done.")
    except Exception as e:
        print("Error: %s" % e)

0 comments on commit 17af0c2

Please sign in to comment.