Commit

fix: fixed var scaler passed to func calls
AshishKumar4 committed Sep 5, 2024
1 parent cd0e112 commit fb16bcc
Showing 4 changed files with 441 additions and 315 deletions.
714 changes: 420 additions & 294 deletions evaluate.ipynb

Large diffs are not rendered by default.

36 changes: 18 additions & 18 deletions flaxdiff/models/common.py
@@ -108,13 +108,13 @@ def __call__(self, x):
class TimeProjection(nn.Module):
features:int
activation:Callable=jax.nn.gelu
-kernel_init:Callable=partial(kernel_init, 1.0)
+kernel_init:Callable=kernel_init(1.0)

@nn.compact
def __call__(self, x):
-x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init())(x)
+x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init)(x)
x = self.activation(x)
-x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init())(x)
+x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init)(x)
x = self.activation(x)
return x

@@ -123,7 +123,7 @@ class SeparableConv(nn.Module):
kernel_size:tuple=(3, 3)
strides:tuple=(1, 1)
use_bias:bool=False
-kernel_init:Callable=partial(kernel_init, 1.0)
+kernel_init:Callable=kernel_init(1.0)
padding:str="SAME"
dtype: Optional[Dtype] = None
precision: PrecisionLike = None
@@ -133,15 +133,15 @@ def __call__(self, x):
in_features = x.shape[-1]
depthwise = nn.Conv(
features=in_features, kernel_size=self.kernel_size,
-strides=self.strides, kernel_init=self.kernel_init(),
+strides=self.strides, kernel_init=self.kernel_init,
feature_group_count=in_features, use_bias=self.use_bias,
padding=self.padding,
dtype=self.dtype,
precision=self.precision
)(x)
pointwise = nn.Conv(
features=self.features, kernel_size=(1, 1),
-strides=(1, 1), kernel_init=self.kernel_init(),
+strides=(1, 1), kernel_init=self.kernel_init,
use_bias=self.use_bias,
dtype=self.dtype,
precision=self.precision
@@ -153,7 +153,7 @@ class ConvLayer(nn.Module):
features:int
kernel_size:tuple=(3, 3)
strides:tuple=(1, 1)
-kernel_init:Callable=partial(kernel_init, 1.0)
+kernel_init:Callable=kernel_init(1.0)
dtype: Optional[Dtype] = None
precision: PrecisionLike = None

@@ -164,7 +164,7 @@ def setup(self):
features=self.features,
kernel_size=self.kernel_size,
strides=self.strides,
-kernel_init=self.kernel_init(),
+kernel_init=self.kernel_init,
dtype=self.dtype,
precision=self.precision
)
@@ -183,7 +183,7 @@ def setup(self):
features=self.features,
kernel_size=self.kernel_size,
strides=self.strides,
-kernel_init=self.kernel_init(),
+kernel_init=self.kernel_init,
dtype=self.dtype,
precision=self.precision
)
@@ -192,7 +192,7 @@ def setup(self):
features=self.features,
kernel_size=self.kernel_size,
strides=self.strides,
-kernel_init=self.kernel_init(),
+kernel_init=self.kernel_init,
dtype=self.dtype,
precision=self.precision
)
@@ -206,7 +206,7 @@ class Upsample(nn.Module):
activation:Callable=jax.nn.swish
dtype: Optional[Dtype] = None
precision: PrecisionLike = None
-kernel_init:Callable=partial(kernel_init, 1.0)
+kernel_init:Callable=kernel_init(1.0)

@nn.compact
def __call__(self, x, residual=None):
@@ -221,7 +221,7 @@ def __call__(self, x, residual=None):
strides=(1, 1),
dtype=self.dtype,
precision=self.precision,
-kernel_init=self.kernel_init()
+kernel_init=self.kernel_init
)(out)
if residual is not None:
out = jnp.concatenate([out, residual], axis=-1)
@@ -233,7 +233,7 @@ class Downsample(nn.Module):
activation:Callable=jax.nn.swish
dtype: Optional[Dtype] = None
precision: PrecisionLike = None
-kernel_init:Callable=partial(kernel_init, 1.0)
+kernel_init:Callable=kernel_init(1.0)

@nn.compact
def __call__(self, x, residual=None):
@@ -244,7 +244,7 @@ def __call__(self, x, residual=None):
strides=(2, 2),
dtype=self.dtype,
precision=self.precision,
-kernel_init=self.kernel_init()
+kernel_init=self.kernel_init
)(x)
if residual is not None:
if residual.shape[1] > out.shape[1]:
@@ -269,7 +269,7 @@ class ResidualBlock(nn.Module):
direction:str=None
res:int=2
norm_groups:int=8
-kernel_init:Callable=partial(kernel_init, 1.0)
+kernel_init:Callable=kernel_init(1.0)
dtype: Optional[Dtype] = None
precision: PrecisionLike = None
named_norms:bool=False
@@ -296,7 +296,7 @@ def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_fe
features=self.features,
kernel_size=self.kernel_size,
strides=self.strides,
-kernel_init=self.kernel_init(),
+kernel_init=self.kernel_init,
name="conv1",
dtype=self.dtype,
precision=self.precision
@@ -321,7 +321,7 @@ def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_fe
features=self.features,
kernel_size=self.kernel_size,
strides=self.strides,
-kernel_init=self.kernel_init(),
+kernel_init=self.kernel_init,
name="conv2",
dtype=self.dtype,
precision=self.precision
@@ -333,7 +333,7 @@ def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_fe
features=self.features,
kernel_size=(1, 1),
strides=1,
-kernel_init=self.kernel_init(),
+kernel_init=self.kernel_init,
name="residual_conv",
dtype=self.dtype,
precision=self.precision
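Note: the common.py change is mechanical but easy to read backwards, so here is a minimal sketch of the two patterns, under the assumption that kernel_init(scale) is a helper returning a Flax initializer (e.g. variance_scaling); the DemoConv* module names and the helper body below are illustrative, not the exact flaxdiff code. Before this commit the module field defaulted to a zero-argument factory that was called at every use site; after it, the field holds the initializer itself and is forwarded untouched, so a variance scale chosen by the caller actually reaches the underlying nn.Conv/nn.DenseGeneral calls, which appears to be what the commit message refers to.

# Illustrative sketch only -- the kernel_init helper body and DemoConv* names
# are assumptions, not copied from flaxdiff.
from functools import partial
from typing import Callable

import jax
import jax.numpy as jnp
from flax import linen as nn


def kernel_init(scale: float) -> Callable:
    # Assumed shape of flaxdiff's helper: scale -> Flax initializer.
    return nn.initializers.variance_scaling(scale, "fan_avg", "truncated_normal")


class DemoConvBefore(nn.Module):
    # Old pattern: the field defaults to a factory and is *called* at use sites.
    # Passing kernel_init(0.3) here breaks, because __call__ then invokes the
    # initializer itself with no arguments.
    features: int = 64
    kernel_init: Callable = partial(kernel_init, 1.0)

    @nn.compact
    def __call__(self, x):
        return nn.Conv(self.features, (3, 3), kernel_init=self.kernel_init())(x)


class DemoConvAfter(nn.Module):
    # New pattern: the field holds the initializer itself and is passed as-is,
    # so a caller-chosen scale (e.g. kernel_init(0.3)) reaches nn.Conv directly.
    features: int = 64
    kernel_init: Callable = kernel_init(1.0)

    @nn.compact
    def __call__(self, x):
        return nn.Conv(self.features, (3, 3), kernel_init=self.kernel_init)(x)


x = jnp.ones((1, 8, 8, 3))
params = DemoConvAfter(kernel_init=kernel_init(0.3)).init(jax.random.PRNGKey(0), x)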
4 changes: 2 additions & 2 deletions flaxdiff/models/simple_vit.py
@@ -113,7 +113,7 @@ def __call__(self, x, temb, textcontext=None):
# Middle block
x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.fforce_fp32_for_softmax,
+use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
only_pure_attention=False,
kernel_init=self.kernel_init())(x)

@@ -124,7 +124,7 @@ def __call__(self, x, temb, textcontext=None):
dtype=self.dtype, precision=self.precision)(skip)
x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.fforce_fp32_for_softmax,
+use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
only_pure_attention=False,
kernel_init=self.kernel_init())(skip)

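Note: the simple_vit.py change only corrects the misspelled attribute reference self.fforce_fp32_for_softmax to self.force_fp32_for_softmax. For context, a force_fp32_for_softmax-style flag usually means "compute the attention softmax in float32 even under mixed precision"; the sketch below shows that common trick as an assumption about the flag's intent, not code taken from flaxdiff's TransformerBlock.

# Sketch of the usual meaning of a force_fp32_for_softmax flag (assumption,
# not flaxdiff's implementation): keep attention logits in float32 through
# the softmax for numerical stability, then cast back to the compute dtype.
import jax
import jax.numpy as jnp


def attention_weights(q, k, force_fp32_for_softmax: bool = True):
    # q, k: (batch, heads, seq, dim_head); often bfloat16 under mixed precision.
    logits = jnp.einsum("bhqd,bhkd->bhqk", q, k) * (q.shape[-1] ** -0.5)
    if force_fp32_for_softmax:
        weights = jax.nn.softmax(logits.astype(jnp.float32), axis=-1)
        return weights.astype(q.dtype)
    return jax.nn.softmax(logits, axis=-1)


q = k = jnp.ones((1, 4, 16, 32), dtype=jnp.bfloat16)
w = attention_weights(q, k)  # softmax in float32, result back in bfloat16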
2 changes: 1 addition & 1 deletion setup.py
@@ -11,7 +11,7 @@
setup(
name='flaxdiff',
packages=find_packages(),
-version='0.1.24',
+version='0.1.25',
description='A versatile and easy to understand Diffusion library',
long_description=open('README.md').read(),
long_description_content_type='text/markdown',
