Skip to content

Commit

Permalink
is_expendable argument reduces memory usage for command line script
Browse files Browse the repository at this point in the history
  • Loading branch information
kuprel committed Jun 30, 2022
1 parent 3837710 commit 1e18ba0
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 31 deletions.
15 changes: 8 additions & 7 deletions image_from_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
parser.add_argument('--text', type=str, default='alien life')
parser.add_argument('--seed', type=int, default=7)
parser.add_argument('--image_path', type=str, default='generated')
parser.add_argument('--sample_token_count', type=int, default=256) # for debugging
parser.add_argument('--token_count', type=int, default=256) # for debugging


def ascii_from_image(image: Image.Image, size: int) -> str:
Expand All @@ -42,20 +42,21 @@ def generate_image(
text: str,
seed: int,
image_path: str,
sample_token_count: int
token_count: int
):
is_expendable = True
if is_torch:
image_generator = MinDalleTorch(is_mega, sample_token_count)
image_tokens = image_generator.generate_image_tokens(text, seed)
image_generator = MinDalleTorch(is_mega, is_expendable, token_count)

if sample_token_count < image_generator.config['image_length']:
if token_count < image_generator.config['image_length']:
image_tokens = image_generator.generate_image_tokens(text, seed)
print('image tokens', list(image_tokens.to('cpu').detach().numpy()))
return
else:
image = image_generator.generate_image(text, seed)

else:
image_generator = MinDalleFlax(is_mega)
image_generator = MinDalleFlax(is_mega, is_expendable=True)
image = image_generator.generate_image(text, seed)

save_image(image, image_path)
Expand All @@ -71,5 +72,5 @@ def generate_image(
text=args.text,
seed=args.seed,
image_path=args.image_path,
sample_token_count=args.sample_token_count
token_count=args.token_count
)
12 changes: 8 additions & 4 deletions min_dalle/min_dalle.py → min_dalle/min_dalle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .load_params import load_vqgan_torch_params, load_dalle_bart_flax_params
from .models.vqgan_detokenizer import VQGanDetokenizer

class MinDalle:
class MinDalleBase:
def __init__(self, is_mega: bool):
self.is_mega = is_mega
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
Expand All @@ -25,11 +25,15 @@ def __init__(self, is_mega: bool):
merges = f.read().split("\n")[1:-1]

self.model_params = load_dalle_bart_flax_params(model_path)

self.tokenizer = TextTokenizer(vocab, merges)


def init_detokenizer(self):
print("initializing VQGanDetokenizer")
params = load_vqgan_torch_params('./pretrained/vqgan')
self.detokenizer = VQGanDetokenizer()
vqgan_params = load_vqgan_torch_params('./pretrained/vqgan')
self.detokenizer.load_state_dict(vqgan_params)
self.detokenizer.load_state_dict(params)
del params


def tokenize_text(self, text: str) -> numpy.ndarray:
Expand Down
33 changes: 26 additions & 7 deletions min_dalle/min_dalle_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,25 @@
from PIL import Image
import torch

from .min_dalle import MinDalle
from .min_dalle_base import MinDalleBase
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax


class MinDalleFlax(MinDalle):
def __init__(self, is_mega: bool):
class MinDalleFlax(MinDalleBase):
def __init__(self, is_mega: bool, is_expendable: bool = False):
super().__init__(is_mega)
self.is_expendable = is_expendable
print("initializing MinDalleFlax")
if not is_expendable:
self.init_encoder()
self.init_decoder()
self.init_detokenizer()

print("loading encoder")
self.encoder = DalleBartEncoderFlax(

def init_encoder(self):
print("initializing DalleBartEncoderFlax")
self.encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
attention_head_count = self.config['encoder_attention_heads'],
embed_count = self.config['d_model'],
glu_embed_count = self.config['encoder_ffn_dim'],
Expand All @@ -23,7 +30,9 @@ def __init__(self, is_mega: bool):
layer_count = self.config['encoder_layers']
).bind({'params': self.model_params.pop('encoder')})

print("loading decoder")

def init_decoder(self):
print("initializing DalleBartDecoderFlax")
self.decoder = DalleBartDecoderFlax(
image_token_count = self.config['image_length'],
text_token_count = self.config['max_text_length'],
Expand All @@ -39,20 +48,30 @@ def __init__(self, is_mega: bool):
def generate_image(self, text: str, seed: int) -> Image.Image:
text_tokens = self.tokenize_text(text)

if self.is_expendable: self.init_encoder()
print("encoding text tokens")
encoder_state = self.encoder(text_tokens)
if self.is_expendable: del self.encoder

if self.is_expendable:
self.init_decoder()
params = self.model_params.pop('decoder')
else:
params = self.model_params['decoder']
print("sampling image tokens")
image_tokens = self.decoder.sample_image_tokens(
text_tokens,
encoder_state,
jax.random.PRNGKey(seed),
self.model_params['decoder']
params
)
if self.is_expendable: del self.decoder

image_tokens = torch.tensor(numpy.array(image_tokens))

if self.is_expendable: self.init_detokenizer()
print("detokenizing image")
image = self.detokenizer.forward(image_tokens).to(torch.uint8)
if self.is_expendable: del self.detokenizer
image = Image.fromarray(image.to('cpu').detach().numpy())
return image
52 changes: 39 additions & 13 deletions min_dalle/min_dalle_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,30 @@
torch.set_num_threads(os.cpu_count())

from .load_params import convert_dalle_bart_torch_from_flax_params
from .min_dalle import MinDalle
from .min_dalle_base import MinDalleBase
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch


class MinDalleTorch(MinDalle):
def __init__(self, is_mega: bool, sample_token_count: int = 256):
class MinDalleTorch(MinDalleBase):
def __init__(
self,
is_mega: bool,
is_expendable: bool = False,
token_count: int = 256
):
super().__init__(is_mega)
self.is_expendable = is_expendable
self.token_count = token_count
print("initializing MinDalleTorch")
if not is_expendable:
self.init_encoder()
self.init_decoder()
self.init_detokenizer()

print("loading encoder")

def init_encoder(self):
print("initializing DalleBartEncoderTorch")
self.encoder = DalleBartEncoderTorch(
layer_count = self.config['encoder_layers'],
embed_count = self.config['d_model'],
Expand All @@ -28,18 +41,22 @@ def __init__(self, is_mega: bool, sample_token_count: int = 256):
text_token_count = self.config['max_text_length'],
glu_embed_count = self.config['encoder_ffn_dim']
)
encoder_params = convert_dalle_bart_torch_from_flax_params(
params = convert_dalle_bart_torch_from_flax_params(
self.model_params.pop('encoder'),
layer_count=self.config['encoder_layers'],
is_encoder=True
)
self.encoder.load_state_dict(encoder_params, strict=False)
self.encoder.load_state_dict(params, strict=False)
if torch.cuda.is_available(): self.encoder = self.encoder.cuda()
del params


print("loading decoder")
def init_decoder(self):
print("initializing DalleBartDecoderTorch")
self.decoder = DalleBartDecoderTorch(
image_vocab_size = self.config['image_vocab_size'],
image_token_count = self.config['image_length'],
sample_token_count = sample_token_count,
sample_token_count = self.token_count,
embed_count = self.config['d_model'],
attention_head_count = self.config['decoder_attention_heads'],
glu_embed_count = self.config['decoder_ffn_dim'],
Expand All @@ -48,36 +65,45 @@ def __init__(self, is_mega: bool, sample_token_count: int = 256):
start_token = self.config['decoder_start_token_id'],
is_verbose = True
)
decoder_params = convert_dalle_bart_torch_from_flax_params(
params = convert_dalle_bart_torch_from_flax_params(
self.model_params.pop('decoder'),
layer_count=self.config['decoder_layers'],
is_encoder=False
)
self.decoder.load_state_dict(decoder_params, strict=False)
self.decoder.load_state_dict(params, strict=False)
if torch.cuda.is_available(): self.decoder = self.decoder.cuda()
del params


def init_detokenizer(self):
super().init_detokenizer()
if torch.cuda.is_available():
self.encoder = self.encoder.cuda()
self.decoder = self.decoder.cuda()
self.detokenizer = self.detokenizer.cuda()


def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
text_tokens = self.tokenize_text(text)
text_tokens = torch.tensor(text_tokens).to(torch.long)
if torch.cuda.is_available(): text_tokens = text_tokens.cuda()

if self.is_expendable: self.init_encoder()
print("encoding text tokens")
encoder_state = self.encoder.forward(text_tokens)
if self.is_expendable: del self.encoder

if self.is_expendable: self.init_decoder()
print("sampling image tokens")
torch.manual_seed(seed)
image_tokens = self.decoder.forward(text_tokens, encoder_state)
if self.is_expendable: del self.decoder
return image_tokens


def generate_image(self, text: str, seed: int) -> Image.Image:
image_tokens = self.generate_image_tokens(text, seed)
if self.is_expendable: self.init_detokenizer()
print("detokenizing image")
image = self.detokenizer.forward(image_tokens).to(torch.uint8)
if self.is_expendable: del self.detokenizer
image = Image.fromarray(image.to('cpu').detach().numpy())
return image

0 comments on commit 1e18ba0

Please sign in to comment.