
Commit 5813043

Complete SD pipeline.

1 parent: b0151a7

8 files changed, +434 -212 lines changed

apps/shark_studio/api/sd.py

Lines changed: 316 additions & 94 deletions
Large diffs are not rendered by default.

apps/shark_studio/modules/img_processing.py

Lines changed: 3 additions & 3 deletions
@@ -75,9 +75,9 @@ def save_output_img(output_img, img_seed, extra_info=None):
         "parameters",
         f"{extra_info['prompt'][0]}"
         f"\nNegative prompt: {extra_info['negative_prompt'][0]}"
-        f"\nSteps: {extra_info['steps'][0]},"
-        f"Sampler: {extra_info['scheduler'][0]}, "
-        f"CFG scale: {extra_info['guidance_scale'][0]}, "
+        f"\nSteps: {extra_info['steps']},"
+        f"Sampler: {extra_info['scheduler']}, "
+        f"CFG scale: {extra_info['guidance_scale']}, "
         f"Seed: {img_seed},"
         f"Size: {png_size_text}, "
         f"Model: {img_model}, "

apps/shark_studio/modules/pipeline.py

Lines changed: 54 additions & 45 deletions
@@ -1,5 +1,10 @@
 from msvcrt import kbhit
-from shark.iree_utils.compile_utils import get_iree_compiled_module, load_vmfb_using_mmap
+from shark.iree_utils.compile_utils import (
+    get_iree_compiled_module,
+    load_vmfb_using_mmap,
+    clean_device_info,
+    get_iree_target_triple,
+)
 from apps.shark_studio.web.utils.file_utils import (
     get_checkpoints_path,
     get_resource_path,
@@ -32,8 +37,8 @@ def __init__(
         self.model_map = model_map
         self.static_kwargs = static_kwargs
         self.base_model_id = base_model_id
-        self.device_name = device
-        self.device = device.split("=>")[-1].strip(" ")
+        self.triple = get_iree_target_triple(device)
+        self.device, self.device_id = clean_device_info(device)
         self.import_mlir = import_mlir
         self.iree_module_dict = {}
         self.tempfiles = {}
@@ -46,22 +51,24 @@ def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
         # initialization. As soon as you have a pipeline ID unique to your static torch IR parameters,
         # and your model map is populated with any IR - unique model IDs and their static params,
         # call this method to get the artifacts associated with your map.
-        self.pipe_id = pipe_id
+        self.pipe_id = self.safe_name(pipe_id)
         self.pipe_vmfb_path = Path(os.path.join(get_checkpoints_path(".."), self.pipe_id))
         self.pipe_vmfb_path.mkdir(parents=True, exist_ok=True)
-        print("\n[LOG] Checking for pre-compiled artifacts.")
         if submodel == "None":
+            print("\n[LOG] Gathering any pre-compiled artifacts....")
             for key in self.model_map:
                 self.get_compiled_map(pipe_id, submodel=key)
         else:
             self.get_precompiled(pipe_id, submodel)
             ireec_flags = []
             if submodel in self.iree_module_dict:
                 if "vmfb" in self.iree_module_dict[submodel]:
-                    print(f"[LOG] Found executable for {submodel} at {self.iree_module_dict[submodel]['vmfb']}...")
+                    print(f"\n[LOG] Executable for {submodel} already loaded...")
                     return
+            elif "vmfb_path" in self.model_map[submodel]:
+                return
             elif submodel not in self.tempfiles:
-                print(f"[LOG] Tempfile for {submodel} not found. Fetching torch IR...")
+                print(f"\n[LOG] Tempfile for {submodel} not found. Fetching torch IR...")
                 if submodel in self.static_kwargs:
                     init_kwargs = self.static_kwargs[submodel]
                 for key in self.static_kwargs["pipe"]:
@@ -90,16 +97,6 @@ def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
         return
 
 
-    def hijack_weights(self, weights_path, submodel="None"):
-        if submodel == "None":
-            for i in self.model_map:
-                self.hijack_weights(weights_path, i)
-        else:
-            if submodel in self.iree_module_dict:
-                self.model_map[submodel]["external_weights_file"] = weights_path
-        return
-
-
     def get_precompiled(self, pipe_id, submodel="None"):
         if submodel == "None":
             for model in self.model_map:
@@ -112,33 +109,10 @@
                 break
         for file in vmfbs:
             if submodel in file:
-                print(f"Found existing .vmfb at {file}")
-                self.iree_module_dict[submodel] = {}
-                (
-                    self.iree_module_dict[submodel]["vmfb"],
-                    self.iree_module_dict[submodel]["config"],
-                    self.iree_module_dict[submodel]["temp_file_to_unlink"],
-                ) = load_vmfb_using_mmap(
-                    os.path.join(vmfbs_path, file),
-                    self.device,
-                    device_idx=0,
-                    rt_flags=[],
-                    external_weight_file=self.model_map[submodel]['external_weight_file'],
-                )
+                self.model_map[submodel]["vmfb_path"] = os.path.join(vmfbs_path, file)
         return
 
 
-    def safe_dict(self, kwargs: dict):
-        flat_args = {}
-        for i in kwargs:
-            if isinstance(kwargs[i], dict) and "pass_dict" not in kwargs[i]:
-                flat_args[i] = [kwargs[i][j] for j in kwargs[i]]
-            else:
-                flat_args[i] = kwargs[i]
-
-        return flat_args
-
-
     def import_torch_ir(self, submodel, kwargs):
         torch_ir = self.model_map[submodel]["initializer"](
             **self.safe_dict(kwargs), compile_to="torch"
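
The rewritten get_precompiled no longer mmap-loads a found .vmfb on the spot; it just records the path in the model map, and load_submodels (below) performs the actual load. A toy sketch of this two-phase pattern, with placeholder names rather than the real API:

model_map = {"unet": {}, "clip": {}}

def discover(submodel, path):
    # Cheap discovery pass: only bookkeeping, no file is opened yet.
    model_map[submodel]["vmfb_path"] = path

def load(submodel):
    # Deferred load: only submodels with a recorded path pay the mmap cost.
    if "vmfb_path" in model_map[submodel]:
        print(f"mmap-loading {model_map[submodel]['vmfb_path']}")

discover("unet", "checkpoints/pipe/unet.vmfb")
load("unet")   # mmap-loading checkpoints/pipe/unet.vmfb
load("clip")   # no path recorded; the real code falls back to compilation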
@@ -160,18 +134,53 @@ def import_torch_ir(self, submodel, kwargs):
     def load_submodels(self, submodels: list):
         for submodel in submodels:
             if submodel in self.iree_module_dict:
+                print(f"\n[LOG] {submodel} is ready for inference.")
+            if "vmfb_path" in self.model_map[submodel]:
                 print(
-                    f"\n[LOG] Loading .vmfb for {submodel} from {self.iree_module_dict[submodel]['vmfb']}"
+                    f"\n[LOG] Loading .vmfb for {submodel} from {self.model_map[submodel]['vmfb_path']}"
+                )
+                self.iree_module_dict[submodel] = {}
+                (
+                    self.iree_module_dict[submodel]["vmfb"],
+                    self.iree_module_dict[submodel]["config"],
+                    self.iree_module_dict[submodel]["temp_file_to_unlink"],
+                ) = load_vmfb_using_mmap(
+                    self.model_map[submodel]["vmfb_path"],
+                    self.device,
+                    device_idx=0,
+                    rt_flags=[],
+                    external_weight_file=self.model_map[submodel]['external_weight_file'],
                 )
             else:
                 self.get_compiled_map(self.pipe_id, submodel)
         return
 
 
+    def unload_submodels(self, submodels: list):
+        for submodel in submodels:
+            if submodel in self.iree_module_dict:
+                del self.iree_module_dict[submodel]
+                gc.collect()
+        return
+
+
     def run(self, submodel, inputs):
-        inp = [ireert.asdevicearray(self.iree_module_dict[submodel]["config"].device, inputs)]
+        if not isinstance(inputs, list):
+            inputs = [inputs]
+        inp = [ireert.asdevicearray(self.iree_module_dict[submodel]["config"].device, input) for input in inputs]
         return self.iree_module_dict[submodel]['vmfb']['main'](*inp)
 
 
-def safe_name(name):
-    return name.replace("/", "_").replace("-", "_")
+    def safe_name(self, name):
+        return name.replace("/", "_").replace("-", "_").replace("\\", "_")
+
+
+    def safe_dict(self, kwargs: dict):
+        flat_args = {}
+        for i in kwargs:
+            if isinstance(kwargs[i], dict) and "pass_dict" not in kwargs[i]:
+                flat_args[i] = [kwargs[i][j] for j in kwargs[i]]
+            else:
+                flat_args[i] = kwargs[i]
+
+        return flat_args
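
run() now normalizes its argument, so single-input submodels (like "clip") and multi-input ones (like a UNet taking latents, timestep, and embeddings) share one call path. A standalone sketch of the normalization, with illustrative placeholder shapes:

import numpy as np

def normalize_inputs(inputs):
    # Mirrors the new run(): a bare array becomes a one-element list, and
    # each element is then moved to the device individually.
    if not isinstance(inputs, list):
        inputs = [inputs]
    return inputs

token_ids = np.ones((1, 77), dtype=np.int64)              # e.g. a "clip" call
unet_args = [np.zeros((1, 4, 64, 64), dtype=np.float16),  # e.g. a "unet" call
             np.asarray(981)]

print(len(normalize_inputs(token_ids)))  # 1
print(len(normalize_inputs(unet_args)))  # 2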

apps/shark_studio/modules/prompt_encoding.py

Lines changed: 13 additions & 45 deletions
@@ -3,6 +3,7 @@
 from iree import runtime as ireert
 import re
 import torch
+import numpy as np
 
 re_attention = re.compile(
     r"""
@@ -161,7 +162,7 @@ def pad_tokens_and_weights(
     r"""
     Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
     """
-    max_embeddings_multiples = 8
+    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
     weights_length = (
         max_length
         if no_boseos_middle
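
The new expression derives the chunk count from the actual padded length instead of hard-coding 8: each encoder chunk contributes chunk_length - 2 usable tokens once its BOS/EOS slots are reserved, and the padded prompt offers max_length - 2. A worked example, assuming CLIP's 77-token window:

chunk_length = 77   # CLIP text encoder window: 75 usable tokens + BOS/EOS
max_length = 227    # padded prompt: 3 * 75 usable tokens + BOS/EOS

max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
print(max_embeddings_multiples)  # 3 -> the prompt is encoded in three chunks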
@@ -194,13 +195,16 @@
 
     return tokens, weights
 
-
 def get_unweighted_text_embeddings(
     pipe,
-    text_input: torch.Tensor,
+    text_input,
     chunk_length: int,
     no_boseos_middle: Optional[bool] = True,
 ):
+    """
+    When the length of tokens is a multiple of the capacity of the text encoder,
+    it should be split into chunks and sent to the text encoder individually.
+    """
     max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
     if max_embeddings_multiples > 1:
         text_embeddings = []
@@ -214,7 +218,7 @@
             text_input_chunk[:, 0] = text_input[0, 0]
             text_input_chunk[:, -1] = text_input[0, -1]
 
-            text_embedding = pipe.run("clip", text_input_chunk)[0]
+            text_embedding = pipe.run("clip", text_input_chunk)[0].to_host()
 
             if no_boseos_middle:
                 if i == 0:
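
The added .to_host() copies each chunk's embedding off the IREE device so the numpy concatenation later in the function can run on host buffers. The no_boseos_middle trimming that follows is easiest to see with toy shapes: only the first chunk keeps its BOS slot and only the last keeps its EOS slot, so the concatenated result reads like one long sequence. A standalone numpy sketch of that slicing:

import numpy as np

chunks = [np.arange(5).reshape(1, 5) for _ in range(3)]  # 3 toy chunk embeddings
trimmed = []
for i, emb in enumerate(chunks):
    if i == 0:
        emb = emb[:, :-1]    # first chunk: discard the ending (EOS) slot
    elif i == len(chunks) - 1:
        emb = emb[:, 1:]     # last chunk: discard the starting (BOS) slot
    else:
        emb = emb[:, 1:-1]   # middle chunks: discard both
    trimmed.append(emb)

print(np.concatenate(trimmed, axis=1).shape)  # (1, 11): 4 + 3 + 4 columns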
@@ -231,50 +235,14 @@
         # SHARK: Convert the result to tensor
         # text_embeddings = torch.concat(text_embeddings, axis=1)
         text_embeddings_np = np.concatenate(np.array(text_embeddings))
-        text_embeddings = torch.from_numpy(text_embeddings_np)[None, :]
+        text_embeddings = torch.from_numpy(text_embeddings_np)
     else:
         text_embeddings = pipe.run("clip", text_input)[0]
-        # text_embeddings = torch.from_numpy(text_embeddings)[None, :]
-    return torch.from_numpy(text_embeddings.to_host())
-    """
-    When the length of tokens is a multiple of the capacity of the text encoder,
-    it should be split into chunks and sent to the text encoder individually.
-    """
-    max_embeddings_multiples = 8
-    text_embeddings = []
-    for i in range(max_embeddings_multiples):
-        # extract the i-th chunk
-        text_input_chunk = text_input[
-            :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
-        ].clone()
-
-        # cover the head and the tail by the starting and the ending tokens
-        text_input_chunk[:, 0] = text_input[0, 0]
-        text_input_chunk[:, -1] = text_input[0, -1]
-        # text_embedding = pipe.text_encoder(text_input_chunk)[0]
-
-        print(text_input_chunk)
-        breakpoint()
-        text_embedding = pipe.run("clip", text_input_chunk)
-        if no_boseos_middle:
-            if i == 0:
-                # discard the ending token
-                text_embedding = text_embedding[:, :-1]
-            elif i == max_embeddings_multiples - 1:
-                # discard the starting token
-                text_embedding = text_embedding[:, 1:]
-            else:
-                # discard both starting and ending tokens
-                text_embedding = text_embedding[:, 1:-1]
-
-        text_embeddings.append(text_embedding)
-    # SHARK: Convert the result to tensor
-    # text_embeddings = torch.concat(text_embeddings, axis=1)
-    text_embeddings_np = np.concatenate(np.array(text_embeddings))
-    text_embeddings = torch.from_numpy(text_embeddings_np)[None, :]
+        text_embeddings = torch.from_numpy(text_embeddings.to_host())
     return text_embeddings
 
 
+
 # This function deals with NoneType values occuring in tokens after padding
 # It switches out None with 49407 as truncating None values causes matrix dimension errors,
 def filter_nonetype_tokens(tokens: List[List]):
@@ -286,7 +254,7 @@ def get_weighted_text_embeddings(
     prompt: List[str],
     uncond_prompt: List[str] = None,
     max_embeddings_multiples: Optional[int] = 8,
-    no_boseos_middle: Optional[bool] = False,
+    no_boseos_middle: Optional[bool] = True,
     skip_parsing: Optional[bool] = False,
     skip_weighting: Optional[bool] = False,
 ):
@@ -325,12 +293,12 @@
         max_length = max(
             max_length, max([len(token) for token in uncond_tokens])
         )
-
         max_embeddings_multiples = min(
             max_embeddings_multiples,
             (max_length - 1) // (pipe.model_max_length - 2) + 1,
         )
         max_embeddings_multiples = max(1, max_embeddings_multiples)
+
     max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
 
     # pad the length of tokens and weights
apps/shark_studio/modules/schedulers.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 
 def get_schedulers(model_id):
     #TODO: switch over to turbine and run all on GPU
-    print(f"[LOG] Initializing schedulers from model id: {model_id}")
+    print(f"\n[LOG] Initializing schedulers from model id: {model_id}")
     schedulers = dict()
     schedulers["PNDM"] = PNDMScheduler.from_pretrained(
         model_id,

apps/shark_studio/web/configs/default_sd_config.json

Lines changed: 11 additions & 10 deletions
@@ -1,23 +1,24 @@
 {
     "prompt": [ "a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))" ],
     "negative_prompt": [ "watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped" ],
-    "sd_init_image": [ "None" ],
+    "sd_init_image": [ null ],
     "height": 512,
     "width": 512,
-    "steps": [ 50 ],
-    "strength": [ 0.8 ],
-    "guidance_scale": [ 7.5 ],
-    "seed": [ -1 ],
+    "steps": 50,
+    "strength": 0.8,
+    "guidance_scale": 7.5,
+    "seed": -1,
     "batch_count": 1,
     "batch_size": 1,
-    "scheduler": [ "EulerDiscrete" ],
+    "scheduler": "EulerDiscrete",
     "base_model_id": "runwayml/stable-diffusion-v1-5",
-    "custom_weights": "",
-    "custom_vae": "",
+    "custom_weights": null,
+    "custom_vae": null,
+    "use_base_vae": false,
     "precision": "fp16",
     "device": "vulkan",
-    "ondemand": "False",
-    "repeatable_seeds": "False",
+    "ondemand": false,
+    "repeatable_seeds": false,
     "resample_type": "Nearest Neighbor",
     "controlnets": {},
     "embeddings": {}