Skip to content

Commit

Permalink
fix: fixed a basic mistake
Browse files Browse the repository at this point in the history
  • Loading branch information
AshishKumar4 committed Sep 8, 2024
1 parent 4ddb53c commit 6b9b4a4
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions flaxdiff/models/simple_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __call__(self, x, temb, textcontext=None):

# Patch embedding
x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
dtype=self.dtype, precision=self.precision)(x)
dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init)(x)
num_patches = x.shape[1]

context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
Expand All @@ -98,7 +98,7 @@ def __call__(self, x, temb, textcontext=None):
# print(f'Shape of x after time embedding: {x.shape}')

# Add positional encoding
x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.emb_features, kernel_init=self.kernel_init)(x)
x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.emb_features)(x)

# print(f'Shape of x after positional encoding: {x.shape}')

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
setup(
name='flaxdiff',
packages=find_packages(),
version='0.1.26',
version='0.1.27',
description='A versatile and easy to understand Diffusion library',
long_description=open('README.md').read(),
long_description_content_type='text/markdown',
Expand Down

0 comments on commit 6b9b4a4

Please sign in to comment.