feat: added some more descriptions and code
AshishKumar4 committed Jul 6, 2024
1 parent 4f31818 commit 129fba4
Showing 4 changed files with 3,284 additions and 54 deletions.
4 changes: 2 additions & 2 deletions differential equation tests.ipynb
@@ -65,7 +65,7 @@
 }
 ],
 "source": [
-"# Simple diffusion equation for constant speed\n",
+"# Simple differential equation for constant speed\n",
 "# dX = c dt + X0\n",
 "def dX(dt: float, X0: float, c: float = 1, **params) -> float:\n",
 "    return c * dt + X0\n",
@@ -104,7 +104,7 @@
 }
 ],
 "source": [
-"# Simple diffusion equation for increasing speed with constant rate (constant acceleration)\n",
+"# Simple differential equation for increasing speed with constant rate (constant acceleration)\n",
 "# dX = X0 + t * dt\n",
 "def dX(dt: float, X0: float, t: float, c: float = 1, **params) -> float:\n",
 "    return X0 + t * c * dt \n",
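Also outside the commit: this constant-acceleration form integrates the velocity v(t) = c * t, so the exact answer is X0 + c * t^2 / 2, and the Euler estimate approaches it as dt shrinks. A sketch under the same illustrative assumptions as above:

# Hypothetical check, not from the notebook: Euler-integrate dX/dt = c * t
# from t = 0 to t = 1; the exact solution is X(1) = X0 + 0.5 * c.
def integrate(X0: float, dt: float, n_steps: int, c: float = 1.0) -> float:
    X, t = X0, 0.0
    for _ in range(n_steps):
        X = dX(dt, X, t, c)  # X_{k+1} = X_k + t_k * c * dt
        t += dt
    return X

print(integrate(0.0, 0.001, 1000))  # ~0.4995, vs the exact 0.5
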
3,314 changes: 3,274 additions & 40 deletions example notebooks/ddpm flax.ipynb

Large diffs are not rendered by default.
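
The notebook's diff is not shown, so nothing below is taken from it. For orientation only, this is the textbook DDPM forward (noising) step that a "ddpm flax" notebook revolves around; the schedule values and function names are generic assumptions, not flaxdiff's API:

import jax
import jax.numpy as jnp

# q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
T = 1000
betas = jnp.linspace(1e-4, 0.02, T)       # common linear beta schedule
alpha_bars = jnp.cumprod(1.0 - betas)     # cumulative products of alphas

def q_sample(x0, t, rng):
    # Sample x_t given x_0 for a batch of integer timesteps t (shape [N]).
    noise = jax.random.normal(rng, x0.shape)
    a_bar = alpha_bars[t].reshape(-1, 1, 1, 1)  # broadcast over NHWC images
    return jnp.sqrt(a_bar) * x0 + jnp.sqrt(1.0 - a_bar) * noise
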

1 change: 0 additions & 1 deletion flaxdiff/models/simple_unet.py
@@ -135,7 +135,6 @@ def __call__(self, x):
         )(depthwise)
         return pointwise
 
-
 class ConvLayer(nn.Module):
     conv_type:str
     features:int
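For context, not taken from this commit: the depthwise/pointwise pair visible in the hunk is the standard depthwise-separable convolution. A minimal Flax sketch of that pattern (module name and fields are assumptions, not flaxdiff's ConvLayer):

import flax.linen as nn

class SeparableConv(nn.Module):
    # Illustrative depthwise-separable conv, not flaxdiff's implementation.
    features: int
    kernel_size: tuple = (3, 3)

    @nn.compact
    def __call__(self, x):
        in_features = x.shape[-1]
        # Depthwise: one filter per input channel via feature_group_count.
        depthwise = nn.Conv(features=in_features, kernel_size=self.kernel_size,
                            feature_group_count=in_features)(x)
        # Pointwise: a 1x1 conv that mixes channels up to `features`.
        pointwise = nn.Conv(features=self.features, kernel_size=(1, 1))(depthwise)
        return pointwise
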
19 changes: 8 additions & 11 deletions flaxdiff/schedulers/common.py
@@ -13,9 +13,16 @@ def __init__(self, timesteps,
         self.dtype = dtype
         self.clip_min = clip_min
         self.clip_max = clip_max
+        if type(timesteps) == int and timesteps > 1:
+            timestep_generator = lambda rng, batch_size, max_timesteps = timesteps: jax.random.randint(rng, (batch_size,), 0, max_timesteps)
+        else:
+            timestep_generator = lambda rng, batch_size, max_timesteps = timesteps: jax.random.uniform(rng, (batch_size,), minval=0, maxval=max_timesteps)
+        self.timestep_generator = timestep_generator
 
     def generate_timesteps(self, batch_size, state:RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
-        raise NotImplementedError
+        state, rng = state.get_random_key()
+        timesteps = self.timestep_generator(rng, batch_size, self.max_timesteps)
+        return timesteps, state
 
     def get_weights(self, steps):
         raise NotImplementedError
@@ -65,16 +72,6 @@ def __init__(self, timesteps, sigma_min=0.002, sigma_max=80.0, sigma_data=1, *ar
         self.sigma_min = sigma_min
         self.sigma_max = sigma_max
         self.sigma_data = sigma_data
-        if type(timesteps) == int and timesteps > 1:
-            timestep_generator = lambda rng, batch_size, max_timesteps = timesteps: jax.random.randint(rng, (batch_size,), 0, max_timesteps)
-        else:
-            timestep_generator = lambda rng, batch_size, max_timesteps = timesteps: jax.random.uniform(rng, (batch_size,), minval=0, maxval=max_timesteps)
-        self.timestep_generator = timestep_generator
-
-    def generate_timesteps(self, batch_size, state:RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
-        state, rng = state.get_random_key()
-        timesteps = self.timestep_generator(rng, batch_size, self.max_timesteps)
-        return timesteps, state
 
     def get_weights(self, steps, shape=(-1, 1, 1, 1)):
         sigma = self.get_sigmas(steps)
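The net effect of this refactor: timestep sampling now lives once in the base scheduler, with discrete schedules (integer timestep counts) drawing random integers and continuous ones drawing uniform floats. Below is a self-contained sketch of that behavior; RandomMarkovState is stubbed, since its definition is not part of the diff:

import jax

class RandomMarkovState:  # stub standing in for flaxdiff's class
    def __init__(self, rng):
        self.rng = rng
    def get_random_key(self):
        rng, key = jax.random.split(self.rng)
        return RandomMarkovState(rng), key

def make_timestep_generator(timesteps):
    # Mirrors the diff: int timesteps -> discrete, otherwise continuous.
    if isinstance(timesteps, int) and timesteps > 1:
        return lambda rng, n: jax.random.randint(rng, (n,), 0, timesteps)
    return lambda rng, n: jax.random.uniform(rng, (n,), minval=0, maxval=timesteps)

state = RandomMarkovState(jax.random.PRNGKey(0))
state, rng = state.get_random_key()
print(make_timestep_generator(1000)(rng, 4))  # four ints in [0, 1000)
state, rng = state.get_random_key()
print(make_timestep_generator(80.0)(rng, 4))  # four floats in [0.0, 80.0)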
