Implementation of Key-Locked Rank One Editing, from Nvidia AI

lucidrains/perfusion-pytorch

Perfusion - Pytorch

Implementation of Key-Locked Rank One Editing. Project page

The selling point of this paper is the extremely small number of extra parameters per added concept, down to 100 KB.

It seems they successfully applied the Rank-1 editing technique from a memory editing paper for LLMs, with a few improvements. They also identified that the keys determine the "where" of the new concept, while the values determine the "what", and they propose local / global key-locking to a superclass concept (while learning the values).
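To make the rank-1 idea concrete, here is a minimal sketch of the closed-form update from the memory editing paper, as I understand it: given a key `k` (the "where"), a target output `v` (the "what") and an input covariance `C`, the edited weight is `W + (v - W k) (C^-1 k)^T / (k^T C^-1 k)`. The function name `rank1_edit` and the toy shapes are assumptions for illustration only. This is not the `Rank1EditModule` implementation, and it leaves out the gating and key-locking that the paper adds on top.

```python
import torch

def rank1_edit(weight, key, target_value, cov):
    """Return an edited copy of `weight` that maps `key` to `target_value`.

    weight:       (d_out, d_in) projection matrix
    key:          (d_in,)  input direction to edit (the "where")
    target_value: (d_out,) desired output for that key (the "what")
    cov:          (d_in, d_in) symmetric covariance of typical inputs
    """
    cov_inv_key = torch.linalg.solve(cov, key)   # C^-1 k
    residual = target_value - weight @ key       # v - W k
    scale = key @ cov_inv_key                    # k^T C^-1 k (a scalar)
    return weight + torch.outer(residual, cov_inv_key) / scale

# toy example with hypothetical sizes, in float64 to keep the check tight
torch.manual_seed(0)
d_in, d_out = 16, 8

W = torch.randn(d_out, d_in, dtype=torch.float64)
samples = torch.randn(256, d_in, dtype=torch.float64)
C = samples.T @ samples / samples.shape[0]       # stand-in covariance

k = torch.randn(d_in, dtype=torch.float64)
v = torch.randn(d_out, dtype=torch.float64)

W_edited = rank1_edit(W, k, v, C)
edited_output = W_edited @ k                     # matches v up to rounding
```

The change to `W` is a single outer product, which is why the per-concept storage cost can stay so small.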

For researchers out there, if this paper checks out, the tools in this repository should work for any other text-to-<insert modality> network using cross attention conditioning. Just a thought.

Appreciation

  • StabilityAI for the generous sponsorship, as well as my other sponsors out there

  • Yoad Tewel for the multiple code reviews and clarifying emails

  • Brad Vidler for precomputing the covariance matrix for the CLIP used in Stable Diffusion 1.5!

  • All the maintainers at OpenClip, for their SOTA open sourced contrastive learning text-image models

Install

$ pip install perfusion-pytorch

Usage

import torch
from torch import nn

from perfusion_pytorch import Rank1EditModule

to_keys = nn.Linear(768, 320, bias = False)
to_values = nn.Linear(768, 320, bias = False)

wrapped_to_keys = Rank1EditModule(
    to_keys,
    is_key_proj = True
)

wrapped_to_values = Rank1EditModule(
    to_values
)

text_enc = torch.randn(4, 77, 768)                  # regular input
text_enc_with_superclass = torch.randn(4, 77, 768)  # init_input in algorithm 1, for key-locking
concept_indices = torch.randint(0, 77, (4,))        # index where the concept or superclass concept token is in the sequence
key_pad_mask = torch.ones(4, 77).bool()

keys = wrapped_to_keys(
    text_enc,
    concept_indices = concept_indices,
    text_enc_with_superclass = text_enc_with_superclass,
)

values = wrapped_to_values(
    text_enc,
    concept_indices = concept_indices,
    text_enc_with_superclass = text_enc_with_superclass,
)

# after much training ...

wrapped_to_keys.eval()
wrapped_to_values.eval()

keys = wrapped_to_keys(text_enc)

values = wrapped_to_values(text_enc)

The repository also contains an EmbeddingWrapper that makes it easy to train on a new concept (and for eventual inference with multiple concepts).

import torch
from torch import nn

from perfusion_pytorch import EmbeddingWrapper

embed = nn.Embedding(49408, 512) # open clip embedding, somewhere in the module tree of stable diffusion

# wrap it; this automatically creates a new concept for learning, based on the superclass embed string

wrapped_embed = EmbeddingWrapper(
    embed,
    superclass_string = 'dog'
)

# now just pass in your prompts with the superclass id

embeds_with_new_concept, embeds_with_superclass, embed_mask, concept_indices = wrapped_embed([
    'a portrait of dog',
    'dog running through a green field',
    'a man walking his dog'
]) # (3, 77, 512), (3, 77, 512), (3, 77), (3,)

# now pass both embeds through clip text transformer
# the embed_mask needs to be passed to the cross attention as key padding mask
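As a rough illustration of that last comment, the sketch below shows one way a key padding mask can be applied inside a single-head cross attention. The function name `cross_attend`, the toy shapes, and the convention that `True` marks real (non-padded) tokens are all assumptions here; check how your cross attention layers expect the mask before wiring it up.

```python
import torch

def cross_attend(queries, keys, values, key_pad_mask):
    """Single-head cross attention that ignores padded text tokens.

    queries:      (batch, n_queries, dim)
    keys, values: (batch, n_tokens, dim)
    key_pad_mask: (batch, n_tokens) bool, True for real tokens (assumed convention)
    """
    scale = queries.shape[-1] ** -0.5
    sim = torch.einsum('bid,bjd->bij', queries, keys) * scale

    # fill padded positions with the most negative finite value,
    # so that the softmax gives them zero weight
    sim = sim.masked_fill(~key_pad_mask[:, None, :], -torch.finfo(sim.dtype).max)

    attn = sim.softmax(dim=-1)
    return torch.einsum('bij,bjd->bid', attn, values)

# toy example with hypothetical sizes
torch.manual_seed(0)

q = torch.randn(3, 10, 32)   # e.g. image features
k = torch.randn(3, 77, 32)   # e.g. output of the key projection
v = torch.randn(3, 77, 32)   # e.g. output of the value projection

# pretend the three prompts are 5, 7 and 6 tokens long
lengths = torch.tensor([5, 7, 6])
mask = torch.arange(77)[None, :] < lengths[:, None]   # (3, 77)

out = cross_attend(q, k, v, mask)   # (3, 10, 32)
```

With the mask applied, whatever sits in the padded positions of the keys and values has no effect on the attention output.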

If you can identify the CLIP instance within the stable diffusion instance, you can also pass it directly to the OpenClipEmbedWrapper, which returns everything the cross attention layers need on forward.

ex.

from perfusion_pytorch import OpenClipEmbedWrapper

texts = [
    'a portrait of dog',
    'dog running through a green field',
    'a man walking his dog'
]

wrapped_clip_with_new_concept = OpenClipEmbedWrapper(
    stable_diffusion.path.to.clip,
    superclass_string = 'dog'
)

text_enc, superclass_enc, mask, indices = wrapped_clip_with_new_concept(texts)

# (3, 77, 512), (3, 77, 512), (3, 77), (3,)

Todo

  • wire up with SD 1.5, starting with xiao's dreambooth-sd

  • show example in readme for inference with multiple concepts

  • automatically infer where keys and values projection are if not specified for the make_key_value_proj_rank1_edit_modules_ function

  • embedding wrapper should take care of substituting with super class token id and return embedding with super class

  • review multiple concepts - thanks to Yoad

  • offer a function that wires up the cross attention

  • handle multiple concepts in one prompt at inference - summation of the sigmoid term + outputs

    • accept multiple concept indices

  • offer a way to combine separately learned concepts from multiple Rank1EditModule into one for inference

    • offer function for merging Rank1EditModules

  • add the zero-shot masking of concept proposed in paper

  • take care of the function that takes in the dataset and text encoder and precomputes the covariance matrix needed for the rank-1 update

  • instead of having the researcher worry about different learning rates, offer the fractional gradient trick from other paper (to learn the concept embedding)
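For the covariance item in the list above, a minimal sketch of what such a precomputation could look like is given below. It assumes the covariance is the uncentered second moment of the text encodings, as in the memory editing paper, and that padded tokens are excluded. The function name `uncentered_covariance` and the mask convention (`True` for real tokens) are hypothetical and not necessarily what this repository does.

```python
import torch

def uncentered_covariance(text_encodings, mask):
    """Estimate C = E[x x^T] over all non-padded token encodings.

    text_encodings: (batch, seq, dim) outputs of the text encoder
    mask:           (batch, seq) bool, True for real tokens (assumed convention)
    """
    tokens = text_encodings[mask]                 # (n_tokens, dim)
    return tokens.T @ tokens / tokens.shape[0]    # (dim, dim)

# toy example, random tensors standing in for real text encodings
torch.manual_seed(0)

encodings = torch.randn(8, 77, 64)
lengths = torch.randint(1, 78, (8,))
mask = torch.arange(77)[None, :] < lengths[:, None]

C = uncentered_covariance(encodings, mask)        # (64, 64)
```

For a real dataset you would accumulate `tokens.T @ tokens` and the token count across batches, then divide once at the end.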

Citations

@article{Tewel2023KeyLockedRO,
    title   = {Key-Locked Rank One Editing for Text-to-Image Personalization},
    author  = {Yoad Tewel and Rinon Gal and Gal Chechik and Yuval Atzmon},
    journal = {ACM SIGGRAPH 2023 Conference Proceedings},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:258436985}
}

@inproceedings{Meng2022LocatingAE,
    title   = {Locating and Editing Factual Associations in GPT},
    author  = {Kevin Meng and David Bau and Alex Andonian and Yonatan Belinkov},
    booktitle = {Neural Information Processing Systems},
    year    = {2022},
    url     = {https://api.semanticscholar.org/CorpusID:255825985}
}