forked from facebookincubator/AITemplate
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Stable Diffusion dynamic input shape, include/exclude constants, load…
… from diffusers/compvis, alternative pipeline (facebookincubator#696) Summary: * min/max height/width * include/exclude constants from module * load from diffusers model to compiled aitemplate module * load from compvis model to compiled aitemplate module * pipeline doesn't rely on StableDiffusionPipeline * set shape of output tensor according to height/width ``` ~/AITemplate/examples/05_stable_diffusion$ python scripts/compile_alt.py --min-width 64 --max-width 1536 --min-height 64 --max-height 1536 --clip-chunks 6 ``` ``` ~/AITemplate/examples/05_stable_diffusion$ python scripts/demo_alt.py INFO:aitemplate.backend.build_cache_base:Build cache disabled 2023-05-15 18:55:09,465 INFO <aitemplate.testing.detect_target> Set target to CUDA [18:55:09] model_container.cu:67: Device Runtime Version: 11060; Driver Version: 12010 [18:55:09] model_container.cu:81: Hardware accelerator device properties: Device: ASCII string identifying device: NVIDIA GeForce RTX 3060 [18:55:09] model_container.cu:85: Init AITemplate Runtime with 1 concurrency Loading PyTorch CLIP Setting constants Folding constants [18:55:19] model_container.cu:67: Device Runtime Version: 11060; Driver Version: 12010 [18:55:19] model_container.cu:81: Hardware accelerator device properties: Device: ASCII string identifying device: NVIDIA GeForce RTX 3060 [18:55:19] model_container.cu:85: Init AITemplate Runtime with 1 concurrency Loading PyTorch UNet Setting constants Folding constants [18:55:24] model_container.cu:67: Device Runtime Version: 11060; Driver Version: 12010 [18:55:24] model_container.cu:81: Hardware accelerator device properties: Device: ASCII string identifying device: NVIDIA GeForce RTX 3060 [18:55:24] model_container.cu:85: Init AITemplate Runtime with 1 concurrency Loading PyTorch VAE Mapping parameters... Setting constants Folding constants 100%|| 50/50 [00:03<00:00, 12.94it/s] ``` Pull Request resolved: facebookincubator#696 Reviewed By: terrychenism Differential Revision: D45964831 Pulled By: chenyang78 fbshipit-source-id: c126db27afb425b156e15373580a20cfbb06290a
- Loading branch information
1 parent
f738b9b
commit b9d77bd
Showing
7 changed files
with
1,611 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
import logging | ||
|
||
import click | ||
import torch | ||
from aitemplate.testing import detect_target | ||
from aitemplate.utils.import_path import import_parent | ||
from diffusers import StableDiffusionPipeline | ||
|
||
if __name__ == "__main__": | ||
import_parent(filepath=__file__, level=1) | ||
|
||
from src.compile_lib.compile_clip_alt import compile_clip | ||
from src.compile_lib.compile_unet_alt import compile_unet | ||
from src.compile_lib.compile_vae_alt import compile_vae | ||
|
||
|
||
@click.command() | ||
@click.option( | ||
"--local-dir", | ||
default="./tmp/diffusers-pipeline/runwayml/stable-diffusion-v1-5", | ||
help="the local diffusers pipeline directory", | ||
) | ||
@click.option( | ||
"--width", | ||
default=(64, 2048), | ||
type=(int, int), | ||
nargs=2, | ||
help="Minimum and maximum width", | ||
) | ||
@click.option( | ||
"--height", | ||
default=(64, 2048), | ||
type=(int, int), | ||
nargs=2, | ||
help="Minimum and maximum height", | ||
) | ||
@click.option( | ||
"--batch-size", | ||
default=(1, 4), | ||
type=(int, int), | ||
nargs=2, | ||
help="Minimum and maximum batch size", | ||
) | ||
@click.option("--clip-chunks", default=6, help="Maximum number of clip chunks") | ||
@click.option( | ||
"--include-constants", | ||
default=None, | ||
help="include constants (model weights) with compiled model", | ||
) | ||
@click.option("--use-fp16-acc", default=True, help="use fp16 accumulation") | ||
@click.option("--convert-conv-to-gemm", default=True, help="convert 1x1 conv to gemm") | ||
def compile_diffusers( | ||
local_dir, | ||
width, | ||
height, | ||
batch_size, | ||
clip_chunks, | ||
include_constants, | ||
use_fp16_acc=True, | ||
convert_conv_to_gemm=True, | ||
): | ||
logging.getLogger().setLevel(logging.INFO) | ||
torch.manual_seed(4896) | ||
|
||
if detect_target().name() == "rocm": | ||
convert_conv_to_gemm = False | ||
|
||
assert ( | ||
width[0] % 64 == 0 and width[1] % 64 == 0 | ||
), "Minimum Width and Maximum Width must be multiples of 64, otherwise, the compilation process will fail." | ||
assert ( | ||
height[0] % 64 == 0 and height[1] % 64 == 0 | ||
), "Minimum Height and Maximum Height must be multiples of 64, otherwise, the compilation process will fail." | ||
|
||
pipe = StableDiffusionPipeline.from_pretrained( | ||
local_dir, | ||
revision="fp16", | ||
torch_dtype=torch.float16, | ||
).to("cuda") | ||
|
||
# CLIP | ||
compile_clip( | ||
pipe.text_encoder, | ||
batch_size=batch_size, | ||
seqlen=77, | ||
use_fp16_acc=use_fp16_acc, | ||
convert_conv_to_gemm=convert_conv_to_gemm, | ||
depth=pipe.text_encoder.config.num_hidden_layers, | ||
num_heads=pipe.text_encoder.config.num_attention_heads, | ||
dim=pipe.text_encoder.config.hidden_size, | ||
act_layer=pipe.text_encoder.config.hidden_act, | ||
constants=True if include_constants else False, | ||
) | ||
# UNet | ||
compile_unet( | ||
pipe.unet, | ||
batch_size=batch_size, | ||
width=width, | ||
height=height, | ||
clip_chunks=clip_chunks, | ||
use_fp16_acc=use_fp16_acc, | ||
convert_conv_to_gemm=convert_conv_to_gemm, | ||
hidden_dim=pipe.unet.config.cross_attention_dim, | ||
attention_head_dim=pipe.unet.config.attention_head_dim, | ||
use_linear_projection=pipe.unet.config.get("use_linear_projection", False), | ||
constants=True if include_constants else False, | ||
) | ||
# VAE | ||
compile_vae( | ||
pipe.vae, | ||
batch_size=batch_size, | ||
width=width, | ||
height=height, | ||
use_fp16_acc=use_fp16_acc, | ||
convert_conv_to_gemm=convert_conv_to_gemm, | ||
constants=True if include_constants else False, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
compile_diffusers() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import click | ||
import torch | ||
|
||
from aitemplate.utils.import_path import import_parent | ||
|
||
if __name__ == "__main__": | ||
import_parent(filepath=__file__, level=1) | ||
|
||
from src.pipeline_stable_diffusion_ait_alt import StableDiffusionAITPipeline | ||
|
||
|
||
@click.command() | ||
@click.option( | ||
"--hf-hub-or-path", | ||
default="runwayml/stable-diffusion-v1-5", | ||
help="Model weights to apply to compiled model (with --include-constants false)", | ||
) | ||
@click.option("--ckpt", default=None, help="e.g. v1-5-pruned-emaonly.ckpt") | ||
@click.option("--width", default=512, help="Width of generated image") | ||
@click.option("--height", default=512, help="Height of generated image") | ||
@click.option("--batch", default=1, help="Batch size of generated image") | ||
@click.option("--prompt", default="A vision of paradise, Unreal Engine", help="prompt") | ||
@click.option("--negative_prompt", default="", help="prompt") | ||
@click.option("--steps", default=50, help="Number of inference steps") | ||
@click.option("--cfg", default=7.5, help="Guidance scale") | ||
def run( | ||
hf_hub_or_path, ckpt, width, height, batch, prompt, negative_prompt, steps, cfg | ||
): | ||
pipe = StableDiffusionAITPipeline( | ||
hf_hub_or_path=hf_hub_or_path, | ||
ckpt=ckpt, | ||
) | ||
|
||
prompt = [prompt] * batch | ||
negative_prompt = [negative_prompt] * batch | ||
with torch.autocast("cuda"): | ||
image = pipe( | ||
prompt=prompt, | ||
height=height, | ||
width=width, | ||
negative_prompt=negative_prompt, | ||
num_inference_steps=steps, | ||
guidance_scale=cfg, | ||
).images[0] | ||
image.save("example_ait.png") | ||
|
||
|
||
if __name__ == "__main__": | ||
run() |
90 changes: 90 additions & 0 deletions
90
examples/05_stable_diffusion/src/compile_lib/compile_clip_alt.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
from aitemplate.compiler import compile_model | ||
from aitemplate.frontend import IntVar, Tensor | ||
from aitemplate.testing import detect_target | ||
|
||
from ..modeling.clip import CLIPTextTransformer as ait_CLIPTextTransformer | ||
from .util import mark_output | ||
|
||
|
||
def map_clip_params(pt_mod, batch_size=1, seqlen=77, depth=12): | ||
params_ait = {} | ||
pt_params = dict(pt_mod.named_parameters()) | ||
for key, arr in pt_params.items(): | ||
name = key.replace("text_model.", "") | ||
ait_name = name.replace(".", "_") | ||
if name.endswith("out_proj.weight"): | ||
ait_name = ait_name.replace("out_proj", "proj") | ||
elif name.endswith("out_proj.bias"): | ||
ait_name = ait_name.replace("out_proj", "proj") | ||
elif "q_proj" in name: | ||
ait_name = ait_name.replace("q_proj", "proj_q") | ||
elif "k_proj" in name: | ||
ait_name = ait_name.replace("k_proj", "proj_k") | ||
elif "v_proj" in name: | ||
ait_name = ait_name.replace("v_proj", "proj_v") | ||
params_ait[ait_name] = arr | ||
|
||
return params_ait | ||
|
||
|
||
def compile_clip( | ||
pt_mod, | ||
batch_size=(1, 8), | ||
seqlen=64, | ||
dim=768, | ||
num_heads=12, | ||
depth=12, | ||
use_fp16_acc=False, | ||
convert_conv_to_gemm=False, | ||
act_layer="gelu", | ||
constants=True, | ||
): | ||
mask_seq = 0 | ||
causal = True | ||
|
||
ait_mod = ait_CLIPTextTransformer( | ||
num_hidden_layers=depth, | ||
hidden_size=dim, | ||
num_attention_heads=num_heads, | ||
batch_size=batch_size, | ||
seq_len=seqlen, | ||
causal=causal, | ||
mask_seq=mask_seq, | ||
act_layer=act_layer, | ||
) | ||
ait_mod.name_parameter_tensor() | ||
|
||
pt_mod = pt_mod.eval() | ||
params_ait = map_clip_params(pt_mod, batch_size, seqlen, depth) | ||
batch_size = IntVar(values=list(batch_size), name="batch_size") | ||
|
||
input_ids_ait = Tensor( | ||
[batch_size, seqlen], name="input0", dtype="int64", is_input=True | ||
) | ||
position_ids_ait = Tensor( | ||
[batch_size, seqlen], name="input1", dtype="int64", is_input=True | ||
) | ||
Y = ait_mod(input_ids=input_ids_ait, position_ids=position_ids_ait) | ||
mark_output(Y) | ||
|
||
target = detect_target( | ||
use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm | ||
) | ||
compile_model( | ||
Y, target, "./tmp", "CLIPTextModel", constants=params_ait if constants else None | ||
) |
Oops, something went wrong.