feat: completed adding DDIM, DDPM and Euler samplers + training
AshishKumar4 committed Jul 8, 2024
1 parent 69490a1 commit ea41d03
Showing 4 changed files with 15,773 additions and 402 deletions.
1,452 changes: 1,096 additions & 356 deletions Diffusion flax linen.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# ![](images/logo.jpeg "FlaxDiff")

## A Versatile and Easy-to-Understand Diffusion Library
## A Versatile and Simple Diffusion Library

In recent years, diffusion and score-based multi-step models have revolutionized the generative AI domain. However, the latest research in this field has become highly math-intensive, making it challenging to understand how state-of-the-art diffusion models work and generate such impressive images. Replicating this research in code can be daunting.

14,714 changes: 14,674 additions & 40 deletions example notebooks/simple diffusion flax.ipynb

Large diffs are not rendered by default.

7 changes: 2 additions & 5 deletions flaxdiff/models/simple_unet.py
Original file line number Diff line number Diff line change
@@ -69,14 +69,11 @@ def __call__(self, x):

class TimeEmbedding(nn.Module):
features:int
max_timesteps:int=10000
max_positions:int=10000

def setup(self):
# self.embeddings = nn.Embed(
# num_embeddings=max_timesteps, features=out_features
# )
half_dim = self.features // 2
emb = jnp.log(self.max_timesteps) / (half_dim - 1)
emb = jnp.log(self.max_positions) / (half_dim - 1)
emb = jnp.exp(-emb * jnp.arange(half_dim, dtype=jnp.float32))
self.embeddings = emb

Expand Down
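The hunk above only shows how `TimeEmbedding` precomputes the geometric frequency ladder in `setup`. A minimal standalone sketch of the full sinusoidal time embedding this feeds into, in plain JAX; the function name, the batched handling of `t`, and the final sin/cos concatenation are assumptions, since the diff does not show the module's `__call__`:

```python
import jax.numpy as jnp

def sinusoidal_time_embedding(t, features, max_positions=10000):
    # Transformer-style sinusoidal embedding of integer timesteps t.
    half_dim = features // 2
    # Geometric progression of frequencies from 1 down to 1/max_positions,
    # mirroring the setup() computation in the diff.
    emb = jnp.log(max_positions) / (half_dim - 1)
    freqs = jnp.exp(-emb * jnp.arange(half_dim, dtype=jnp.float32))
    # Outer product of timesteps and frequencies, then sin/cos halves.
    args = t.astype(jnp.float32)[:, None] * freqs[None, :]
    return jnp.concatenate([jnp.sin(args), jnp.cos(args)], axis=-1)

emb = sinusoidal_time_embedding(jnp.array([0, 10, 999]), features=64)
```

The result has shape `(batch, features)`; at `t = 0` the sine half is all zeros and the cosine half all ones, which is a quick sanity check on the frequency setup.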
