def _require(predicate, requirement):
    """Return an attrs validator that raises ``ValueError`` when *predicate* fails.

    attrs does not inspect what a validator returns, so a lambda that merely
    evaluates to ``False`` accepts every value; the validator must raise itself.

    Args:
        predicate: one-argument callable returning a truthy value for
            acceptable inputs.
        requirement: human-readable description of the constraint, used in
            the error message.
    """

    def _validate(instance, attribute, value):
        if not predicate(value):
            raise ValueError(f'{attribute.name} {requirement}, got {value!r}')

    return _validate


@attr.s(eq=False, repr=False)
class EncoderBlock(nn.Module):
    """Residual block: a shortcut path plus a scaled four-convolution branch.

    Attributes:
        n_in: number of input channels (>= 1).
        n_out: number of output channels (>= 1 and a multiple of 4); the
            branch's hidden width is ``n_out // 4``.
        n_layers: value used to scale the residual branch by
            ``1 / n_layers ** 2`` (>= 1).
        device: forwarded to every ``Conv2d`` in the block.
        requires_grad: forwarded to every ``Conv2d`` in the block.

    Raises:
        ValueError: if ``n_in``, ``n_out`` or ``n_layers`` violates the
            constraints above.
    """

    n_in: int = attr.ib(validator=_require(lambda x: x >= 1, 'must be >= 1'))
    n_out: int = attr.ib(
        validator=_require(lambda x: x >= 1 and x % 4 == 0, 'must be a positive multiple of 4')
    )
    n_layers: int = attr.ib(validator=_require(lambda x: x >= 1, 'must be >= 1'))

    device: torch.device = attr.ib(default=None)
    requires_grad: bool = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        """Build the shortcut and residual paths once attrs has set the fields."""
        super().__init__()
        self.n_hid = self.n_out // 4
        self.post_gain = 1 / (self.n_layers ** 2)

        # NOTE(review): Conv2d comes from dall_e.utils; its third positional
        # argument is presumably the kernel width -- confirm there.
        make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)

        # A convolution is only needed on the shortcut when the channel count changes.
        self.id_path = make_conv(self.n_in, self.n_out, 1) if self.n_in != self.n_out else nn.Identity()
        self.res_path = nn.Sequential(
            OrderedDict([
                ('relu_1', nn.ReLU()),
                ('conv_1', make_conv(self.n_in, self.n_hid, 3)),
                ('relu_2', nn.ReLU()),
                ('conv_2', make_conv(self.n_hid, self.n_hid, 3)),
                ('relu_3', nn.ReLU()),
                ('conv_3', make_conv(self.n_hid, self.n_hid, 3)),
                ('relu_4', nn.ReLU()),
                ('conv_4', make_conv(self.n_hid, self.n_out, 1)),
            ])
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``id_path(x) + post_gain * res_path(x)``."""
        return self.id_path(x) + self.post_gain * self.res_path(x)
# Build the encoder trunk: input conv -> group_1..group_N -> output head.
#
# Group g (0-based) works at width (2 ** g) * n_hid.  Only the first block of
# groups after the first widens the channels (from the previous group's
# width); every other block keeps the width.  Every group except the last ends
# with a 2x max-pool.  Layers are created in network order so parameter
# initialisation consumes the RNG in the same sequence as the hand-unrolled
# version did.
layers = OrderedDict()
layers['input'] = make_conv(self.input_channels, 1 * self.n_hid, 7)

last_group = self.group_count - 1
for g in range(self.group_count):
    width = (2 ** g) * self.n_hid
    # The first group receives n_hid channels straight from the input conv.
    prev_width = width // 2 if g > 0 else width

    stage = OrderedDict(
        (f'block_{b + 1}', make_blk(prev_width if b == 0 else width, width))
        for b in blk_range
    )
    if g < last_group:
        # The final group must not downsample.
        stage['pool'] = nn.MaxPool2d(kernel_size=2)
    layers[f'group_{g + 1}'] = nn.Sequential(stage)

layers['output'] = nn.Sequential(OrderedDict([
    ('relu', nn.ReLU()),
    # Input width equals the last group's width (8 * n_hid when group_count == 4).
    ('conv', make_conv((2 ** last_group) * self.n_hid, self.vocab_size, 1, use_float16=False)),
]))

self.blocks = nn.Sequential(layers)
validator=lambda i, a, x: x >= 64) - n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1) - input_channels: int = attr.ib(default=3, validator=lambda