Commit

Small fixes for unifying pipelines.
monorimet committed May 23, 2024
1 parent 6c63f3b commit 9658431
Showing 1 changed file with 14 additions and 37 deletions.
apps/shark_studio/api/sd.py: 14 additions & 37 deletions
@@ -4,12 +4,12 @@
 import os
 import json
 import numpy as np
+import copy
 from tqdm.auto import tqdm
 
 from pathlib import Path
 from random import randint
-from turbine_models.custom_models.sd_inference import clip, unet, vae
-from turbine_models.custom_models.sdxl_inference import sdxl_compiled_pipeline
+from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline
 from apps.shark_studio.api.controlnet import control_adapter_map
 from apps.shark_studio.api.utils import parse_device
 from apps.shark_studio.web.utils.state import status_label
@@ -34,16 +34,11 @@
     process_custom_pipe_weights,
 )
 
-sd_model_map = {
-    "clip": {
-        "initializer": clip.export_clip_model,
-    },
-    "unet": {
-        "initializer": unet.export_unet_model,
-    },
-    "vae_decode": {
-        "initializer": vae.export_vae_model,
-    },
+
+EMPTY_SD_MAP = {
+    "clip": None,
+    "unet": None,
+    "vae_decode": None,
 }
 
 EMPTY_FLAGS = {
@@ -75,7 +70,6 @@ def __init__(
         num_loras: int = 0,
         import_ir: bool = True,
         is_controlled: bool = False,
-        hf_auth_token=None,
     ):
         self.compiled_pipeline = False
         self.base_model_id = base_model_id
@@ -102,7 +96,7 @@ def __init__(
         )
         if not os.path.exists(self.weights_path):
             os.mkdir(self.weights_path)
-        self.sd_pipe = sdxl_compiled_pipeline.SharkSDXLPipeline(
+        self.sd_pipe = SharkSDPipeline(
             hf_model_name=base_model_id,
             scheduler_id=scheduler,
             height=height,
@@ -125,28 +119,10 @@
 
     def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img):
         print(f"\n[LOG] Preparing pipeline...")
-        self.is_img2img = is_img2img
-        mlirs = {
-            "prompt_encoder": None,
-            "scheduled_unet": None,
-            "vae_decode": None,
-            "pipeline": None,
-            "full_pipeline": None,
-        }
-        vmfbs = {
-            "prompt_encoder": None,
-            "scheduled_unet": None,
-            "vae_decode": None,
-            "pipeline": None,
-            "full_pipeline": None,
-        }
-        weights = {
-            "prompt_encoder": None,
-            "scheduled_unet": None,
-            "vae_decode": None,
-            "pipeline": None,
-            "full_pipeline": None,
-        }
+        self.is_img2img = False
+        mlirs = copy.deepcopy(EMPTY_SD_MAP)
+        vmfbs = copy.deepcopy(EMPTY_SD_MAP)
+        weights = copy.deepcopy(EMPTY_SD_MAP)
         vmfbs, weights = self.sd_pipe.check_prepared(mlirs, vmfbs, weights, interactive=False)
         print(f"\n[LOG] Loading pipeline to device {self.rt_device}.")
         self.sd_pipe.load_pipeline(vmfbs, weights, self.rt_device, self.compiled_pipeline)
@@ -235,6 +211,7 @@ def shark_sd_fn(
     control_mode = None
     hints = []
     num_loras = 0
+    import_ir=True
     for i in embeddings:
         num_loras += 1 if embeddings[i] else 0
     if "model" in controlnets:
@@ -268,7 +245,7 @@ def shark_sd_fn(
         "device": device,
         "custom_vae": custom_vae,
         "num_loras": num_loras,
-        "import_ir": cmd_opts.import_mlir,
+        "import_ir": import_ir,
         "is_controlled": is_controlled,
         "steps": steps,
         "scheduler": scheduler,
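The net effect of the diff: submodel preparation no longer goes through the per-model sd_model_map exporters (clip, unet, vae); the studio hands empty artifact maps to the unified SharkSDPipeline and lets it fill in and load whatever is missing. Below is a minimal sketch of that call sequence, using only the names visible in this diff; the prepare_and_load helper and its rt_device/compiled_pipeline parameters are illustrative, not part of the commit.

import copy

# Empty per-submodel artifact map introduced by this commit.
EMPTY_SD_MAP = {
    "clip": None,
    "unet": None,
    "vae_decode": None,
}


def prepare_and_load(sd_pipe, rt_device, compiled_pipeline=False):
    # Mirrors the slimmed-down prepare_pipe() flow: start from empty
    # mlir/vmfb/weight maps instead of enumerating submodels by hand.
    mlirs = copy.deepcopy(EMPTY_SD_MAP)
    vmfbs = copy.deepcopy(EMPTY_SD_MAP)
    weights = copy.deepcopy(EMPTY_SD_MAP)
    # check_prepared() fills in any artifacts that are still missing
    # (interactive=False exactly as in the diff).
    vmfbs, weights = sd_pipe.check_prepared(mlirs, vmfbs, weights, interactive=False)
    # load_pipeline() loads the prepared artifacts for the requested runtime device.
    sd_pipe.load_pipeline(vmfbs, weights, rt_device, compiled_pipeline)
    return sd_pipe

Using copy.deepcopy on the shared EMPTY_SD_MAP template keeps the three maps independent of one another and of the module-level constant, so the pipeline can populate each without side effects.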
