
Commit 4798f53

feat: reverted to what used to work
AshishKumar4 committed Aug 7, 2024
1 parent: 85606d9 · commit: 4798f53
Showing 7 changed files with 1,828 additions and 1,046 deletions.
1,216 changes: 318 additions & 898 deletions in evaluate.ipynb

Large diffs are not rendered by default.

61 changes: 1 addition & 60 deletions in flaxdiff/models/attention.py
@@ -162,65 +162,6 @@ def __call__(self, x, context=None):
        proj = proj.reshape(orig_x_shape)
        return proj

-class BasicTransformerBlock(nn.Module):
-    # Has self and cross attention
-    query_dim: int
-    heads: int = 4
-    dim_head: int = 64
-    dtype: Optional[Dtype] = None
-    precision: PrecisionLike = None
-    use_bias: bool = True
-    kernel_init: Callable = lambda : kernel_init(1.0)
-    use_flash_attention:bool = False
-    use_cross_only:bool = False
-
-    def setup(self):
-        if self.use_flash_attention:
-            attenBlock = EfficientAttention
-        else:
-            attenBlock = NormalAttention
-
-        self.attention1 = attenBlock(
-            query_dim=self.query_dim,
-            heads=self.heads,
-            dim_head=self.dim_head,
-            name=f'Attention1',
-            precision=self.precision,
-            use_bias=self.use_bias,
-            dtype=self.dtype,
-            kernel_init=self.kernel_init
-        )
-        self.attention2 = attenBlock(
-            query_dim=self.query_dim,
-            heads=self.heads,
-            dim_head=self.dim_head,
-            name=f'Attention2',
-            precision=self.precision,
-            use_bias=self.use_bias,
-            dtype=self.dtype,
-            kernel_init=self.kernel_init
-        )
-
-        self.ff = FlaxFeedForward(dim=self.query_dim)
-        self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-        self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-        self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-
-    @nn.compact
-    def __call__(self, hidden_states, context=None):
-        # self attention
-        if not self.use_cross_only:
-            print("Using self attention")
-            hidden_states = hidden_states + self.attention1(self.norm1(hidden_states))
-
-        # cross attention
-        hidden_states = hidden_states + self.attention2(self.norm2(hidden_states), context)
-
-        # feed forward
-        hidden_states = hidden_states + self.ff(self.norm3(hidden_states))
-
-        return hidden_states
-
class FlaxGEGLU(nn.Module):
    r"""
    Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
@@ -330,7 +271,7 @@ def setup(self):
    @nn.compact
    def __call__(self, hidden_states, context=None):
        if self.only_pure_attention:
-            return self.attention2(self.norm2(hidden_states), context)
+            return self.attention2(hidden_states, context)

        # self attention
        if not self.use_cross_only:
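The hunks above delete BasicTransformerBlock and drop the pre-normalization on the only_pure_attention path. For readers unfamiliar with the pattern, below is a minimal, self-contained sketch of the pre-norm residual block that BasicTransformerBlock implemented: normalize, attend, add the residual, first for self attention, then for cross attention against a context sequence, then a feed-forward. This is not flaxdiff code; the class name, the shapes, and the use of nn.MultiHeadDotProductAttention in place of the repository's NormalAttention/EfficientAttention are illustrative assumptions.

import jax
import jax.numpy as jnp
import flax.linen as nn

class PreNormCrossAttnBlock(nn.Module):
    features: int
    heads: int = 4

    @nn.compact
    def __call__(self, x, context):
        # self attention with a pre-norm and a residual connection
        h = nn.RMSNorm(epsilon=1e-5)(x)
        x = x + nn.MultiHeadDotProductAttention(num_heads=self.heads)(h, h)
        # cross attention: queries come from x, keys/values from the context sequence
        h = nn.RMSNorm(epsilon=1e-5)(x)
        x = x + nn.MultiHeadDotProductAttention(num_heads=self.heads)(h, context)
        # feed-forward with its own pre-norm and residual
        h = nn.RMSNorm(epsilon=1e-5)(x)
        x = x + nn.Dense(self.features)(nn.gelu(nn.Dense(self.features * 4)(h)))
        return x

x = jnp.ones((2, 64, 128))        # (batch, tokens, features), assumed shapes
context = jnp.ones((2, 77, 128))  # e.g. a text-encoder output
block = PreNormCrossAttnBlock(features=128)
params = block.init(jax.random.PRNGKey(0), x, context)
print(block.apply(params, x, context).shape)  # (2, 64, 128)

The deleted class followed the same structure, with RMSNorm applied before each sub-layer and the residual added after it.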
8 changes: 4 additions & 4 deletions in flaxdiff/models/common.py
@@ -270,8 +270,8 @@ class ResidualBlock(nn.Module):
    @nn.compact
    def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
        residual = x
-        # out = nn.GroupNorm(self.norm_groups)(x)
-        out = nn.RMSNorm()(x)
+        out = nn.GroupNorm(self.norm_groups)(x)
+        # out = nn.RMSNorm()(x)
        out = self.activation(out)

        out = ConvLayer(
@@ -295,8 +295,8 @@ def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
        # out = out * (1 + scale) + shift
        out = out + temb

-        # out = nn.GroupNorm(self.norm_groups)(out)
-        out = nn.RMSNorm()(out)
+        out = nn.GroupNorm(self.norm_groups)(out)
+        # out = nn.RMSNorm()(out)
        out = self.activation(out)

        out = ConvLayer(
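Both common.py hunks make the same swap inside ResidualBlock: nn.RMSNorm() is commented out and nn.GroupNorm(self.norm_groups) is restored. As a rough illustration of what changes numerically, the sketch below (not from the repository; the NHWC shape and num_groups=8 are assumptions) applies both Flax layers to a dummy feature map. GroupNorm subtracts the mean and divides by the standard deviation computed over each group of channels together with the spatial axes, while RMSNorm only rescales each position by the root mean square of its feature vector.

import jax
import jax.numpy as jnp
import flax.linen as nn

# Dummy NHWC feature map (assumed shape)
x = jax.random.normal(jax.random.PRNGKey(0), (1, 16, 16, 64))

group_norm = nn.GroupNorm(num_groups=8)  # what this commit reverts to
rms_norm = nn.RMSNorm(epsilon=1e-5)      # what it reverts away from

gn_out = group_norm.apply(group_norm.init(jax.random.PRNGKey(1), x), x)
rms_out = rms_norm.apply(rms_norm.init(jax.random.PRNGKey(1), x), x)

print(gn_out.shape, rms_out.shape)  # both (1, 16, 16, 64); only the normalization statistics differ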
6 changes: 3 additions & 3 deletions in flaxdiff/models/simple_unet.py
@@ -65,7 +65,7 @@ def __call__(self, x, temb, textcontext):
            if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block
                x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
                    dim_head=dim_in // attention_config['heads'],
-                   use_flash_attention=attention_config.get("flash_attention", True),
+                   use_flash_attention=attention_config.get("flash_attention", False),
                    use_projection=attention_config.get("use_projection", False),
                    use_self_and_cross=attention_config.get("use_self_and_cross", True),
                    precision=attention_config.get("precision", self.precision),
@@ -103,7 +103,7 @@ def __call__(self, x, temb, textcontext):
            if middle_attention is not None and j == self.num_middle_res_blocks - 1: # Apply attention only on the last block
                x = TransformerBlock(heads=middle_attention['heads'], dtype=middle_attention.get('dtype', jnp.float32),
                    dim_head=middle_dim_out // middle_attention['heads'],
-                   use_flash_attention=middle_attention.get("flash_attention", True),
+                   use_flash_attention=middle_attention.get("flash_attention", False),
                    use_linear_attention=False,
                    use_projection=middle_attention.get("use_projection", False),
                    use_self_and_cross=False,
@@ -146,7 +146,7 @@ def __call__(self, x, temb, textcontext):
            if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block
                x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
                    dim_head=dim_out // attention_config['heads'],
-                   use_flash_attention=attention_config.get("flash_attention", True),
+                   use_flash_attention=attention_config.get("flash_attention", False),
                    use_projection=attention_config.get("use_projection", False),
                    use_self_and_cross=attention_config.get("use_self_and_cross", True),
                    precision=attention_config.get("precision", self.precision),
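All three simple_unet.py hunks flip the same default: if an attention_config does not set "flash_attention", TransformerBlock now receives use_flash_attention=False instead of True. A small sketch of that lookup, using a hypothetical config dict (the values shown are not from the repository):

# Hypothetical attention config; only keys that appear in the diff are used here.
attention_config = {
    "heads": 8,
    # "flash_attention": True,  # flash attention must now be requested explicitly
}

# Mirrors the .get(...) calls in simple_unet.py after this commit:
use_flash_attention = attention_config.get("flash_attention", False)  # default was True before the revert
use_projection = attention_config.get("use_projection", False)
use_self_and_cross = attention_config.get("use_self_and_cross", True)

print(use_flash_attention, use_projection, use_self_and_cross)  # False False True

Configs that explicitly set "flash_attention": True keep using the flash-attention path; only the implicit default changes.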
(Diffs for the remaining three changed files did not load and are not shown.)
