Implementation of Key-Locked Rank One Editing. Project page
The selling point of this paper is the extremely low storage cost per added concept: down to about 100 KB of extra parameters.
It seems they successfully applied the rank-1 editing technique from a memory editing paper for LLMs, with a few improvements. They also identified that the keys determine the "where" of the new concept, while the values determine the "what", and they propose local / global key-locking to a superclass concept (while learning the values).
For researchers out there: if this paper checks out, the tools in this repository should work for any other text-to-<insert modality> network using cross attention conditioning. Just a thought.
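To make the rank-1 edit concrete, here is a minimal sketch of the closed-form update from the memory editing paper cited at the bottom (Meng et al.). It is not this repository's implementation, and it leaves out the improvements the Perfusion paper adds on top. The names `rank1_edit`, `concept_input` and `target_output` are made up for illustration, and random tensors stand in for the projection weight and for the text encodings used to estimate the covariance.

```python
import torch

torch.manual_seed(0)

dim_in, dim_out = 8, 4

# stand-in for a cross attention key or value projection weight (dim_out x dim_in)
W = torch.randn(dim_out, dim_in)

# stand-in for text encodings gathered over many prompts (num_tokens x dim_in)
samples = torch.randn(1000, dim_in)

# uncentered covariance of the inputs, C = E[x x^T]
C = samples.t() @ samples / samples.shape[0]

def rank1_edit(W, C, concept_input, target_output):
    # direction C^-1 i*, via a linear solve instead of an explicit inverse
    direction = torch.linalg.solve(C, concept_input)
    # gap between the desired output and what W currently produces
    residual = target_output - W @ concept_input
    # normalizer i*^T C^-1 i*, so the edited weight hits the target exactly
    scale = concept_input @ direction
    return W + torch.outer(residual, direction) / scale

concept_input = torch.randn(dim_in)    # the "where": encoding at the concept token
target_output = torch.randn(dim_out)   # the "what": desired key or value for it

W_edited = rank1_edit(W, C, concept_input, target_output)

# the edited projection now maps the concept input onto the target
print(torch.allclose(W_edited @ concept_input, target_output, atol = 1e-4))
```

The update adds a single outer product to `W`, so only `target_output` (plus the concept embedding) has to be stored per concept. Weighting the direction by the inverse covariance is what keeps the change small for inputs unrelated to the concept.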
- StabilityAI for the generous sponsorship, as well as my other sponsors out there
- Yoad Tewel for the multiple code reviews and clarifying emails
- Brad Vidler for precomputing the covariance matrix for the CLIP used in Stable Diffusion 1.5!
- All the maintainers at OpenClip, for their SOTA open sourced contrastive learning text-image models
$ pip install perfusion-pytorch
import torch
from torch import nn
from perfusion_pytorch import Rank1EditModule

# key and value projections of a cross attention layer
to_keys = nn.Linear(768, 320, bias = False)
to_values = nn.Linear(768, 320, bias = False)

# wrap each projection; the key projection is flagged so it can be key-locked
wrapped_to_keys = Rank1EditModule(
    to_keys,
    is_key_proj = True
)

wrapped_to_values = Rank1EditModule(
    to_values
)

text_enc = torch.randn(4, 77, 768)                  # regular input
text_enc_with_superclass = torch.randn(4, 77, 768)  # init_input in algorithm 1, for key-locking
concept_indices = torch.randint(0, 77, (4,))        # index where the concept or superclass concept token is in the sequence
key_pad_mask = torch.ones(4, 77).bool()             # key padding mask, for the cross attention

keys = wrapped_to_keys(
    text_enc,
    concept_indices = concept_indices,
    text_enc_with_superclass = text_enc_with_superclass,
)

values = wrapped_to_values(
    text_enc,
    concept_indices = concept_indices,
    text_enc_with_superclass = text_enc_with_superclass,
)

# after much training ...

wrapped_to_keys.eval()
wrapped_to_values.eval()

keys = wrapped_to_keys(text_enc)
values = wrapped_to_values(text_enc)
The repository also contains an EmbeddingWrapper that makes it easy to train on a new concept (and for eventual inference with multiple concepts).
import torch
from torch import nn
from perfusion_pytorch import EmbeddingWrapper

embed = nn.Embedding(49408, 512) # open clip embedding, somewhere in the module tree of stable diffusion

# wrap it, and will automatically create a new concept for learning, based on the superclass embed string

wrapped_embed = EmbeddingWrapper(
    embed,
    superclass_string = 'dog'
)

# now just pass in your prompts with the superclass id

embeds_with_new_concept, embeds_with_superclass, embed_mask, concept_indices = wrapped_embed([
    'a portrait of dog',
    'dog running through a green field',
    'a man walking his dog'
]) # (3, 77, 512), (3, 77, 512), (3, 77), (3,)

# now pass both embeds through clip text transformer
# the embed_mask needs to be passed to the cross attention as key padding mask
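As a rough illustration of that last comment, the sketch below shows one way a key padding mask can be applied inside cross attention, using plain PyTorch only. It assumes a single head, and it assumes `True` marks real tokens and `False` marks padding, which may not match the wrapper's convention. `masked_cross_attention` is a hypothetical helper, not a function from this repository, and random tensors stand in for the image queries and the projected text keys and values.

```python
import torch

torch.manual_seed(0)

batch, image_tokens, text_tokens, dim = 3, 16, 77, 320

queries = torch.randn(batch, image_tokens, dim)  # from the image features
keys = torch.randn(batch, text_tokens, dim)      # stand-in for the key projection output
values = torch.randn(batch, text_tokens, dim)    # stand-in for the value projection output

# assumed convention: True for real tokens, False for padding
embed_mask = torch.zeros(batch, text_tokens, dtype = torch.bool)
embed_mask[:, :10] = True

def masked_cross_attention(queries, keys, values, key_pad_mask):
    scale = queries.shape[-1] ** -0.5
    sim = torch.einsum('bid,bjd->bij', queries, keys) * scale

    # padded key positions get the most negative value, so softmax gives them zero weight
    mask = key_pad_mask[:, None, :]
    sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

    attn = sim.softmax(dim = -1)
    out = torch.einsum('bij,bjd->bid', attn, values)
    return out, attn

out, attn = masked_cross_attention(queries, keys, values, embed_mask)
```

Each query row of `attn` then distributes its weight over the unmasked text positions only, so padding tokens contribute nothing to `out`.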
If you can identify the CLIP instance within the stable diffusion instance, you can also pass it directly to the OpenClipEmbedWrapper to get everything you need on forward for the cross attention layers.
For example:
from perfusion_pytorch import OpenClipEmbedWrapper

texts = [
    'a portrait of dog',
    'dog running through a green field',
    'a man walking his dog'
]

wrapped_clip_with_new_concept = OpenClipEmbedWrapper(
    stable_diffusion.path.to.clip,
    superclass_string = 'dog'
)

text_enc, superclass_enc, mask, indices = wrapped_clip_with_new_concept(texts)

# (3, 77, 512), (3, 77, 512), (3, 77), (3,)
- wire up with SD 1.5, starting with xiao's dreambooth-sd
- show example in readme for inference with multiple concepts
- automatically infer where keys and values projection are if not specified for the make_key_value_proj_rank1_edit_modules_ function
- embedding wrapper should take care of substituting with super class token id and return embedding with super class
- review multiple concepts - thanks to Yoad
- offer a function that wires up the cross attention
- handle multiple concepts in one prompt at inference - summation of the sigmoid term + outputs
  - accept multiple concept indices
- offer a way to combine separately learned concepts from multiple Rank1EditModule into one for inference
  - offer function for merging Rank1EditModules
- add the zero-shot masking of concept proposed in paper
- take care of the function that takes in the dataset and text encoder and precomputes the covariance matrix needed for the rank-1 update
- instead of having the researcher worry about different learning rates, offer the fractional gradient trick from other paper (to learn the concept embedding)
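For the last item, one common formulation of a fractional gradient trick is sketched below. That this is the exact variant the item refers to is an assumption, and `frac_gradient` is a hypothetical helper name rather than a function from this repository. The idea is to mix a tensor with a detached copy of itself: the forward value stays the same, but only a fraction of the gradient flows back.

```python
import torch

def frac_gradient(t, frac):
    # forward: t * frac + t * (1 - frac) equals t
    # backward: only the first term carries gradient, so it is scaled by frac
    return t * frac + t.detach() * (1. - frac)

# stand-in for a learned concept embedding
embed = torch.tensor([1., 2., 3.], requires_grad = True)

out = frac_gradient(embed, frac = 0.1)
out.sum().backward()

print(out)         # same values as embed
print(embed.grad)  # gradient of the sum, scaled down by frac
```

Under plain SGD, scaling the gradient by `frac` has the same effect as scaling that parameter's learning rate by `frac`, which is why it can replace a separate learning rate for the concept embedding. With adaptive optimizers the correspondence is only approximate.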
@article{Tewel2023KeyLockedRO,
title = {Key-Locked Rank One Editing for Text-to-Image Personalization},
author = {Yoad Tewel and Rinon Gal and Gal Chechik and Yuval Atzmon},
journal = {ACM SIGGRAPH 2023 Conference Proceedings},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:258436985}
}
@inproceedings{Meng2022LocatingAE,
title = {Locating and Editing Factual Associations in GPT},
author = {Kevin Meng and David Bau and Alex Andonian and Yonatan Belinkov},
booktitle = {Neural Information Processing Systems},
year = {2022},
url = {https://api.semanticscholar.org/CorpusID:255825985}
}