Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jung235 committed Jun 23, 2024
1 parent 1b4000b commit 4ba7cd7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
1 change: 0 additions & 1 deletion pydiffuser/models/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ def pre_generate(self, *generate_args) -> None:
)

jax.config.update("jax_platform_name", "cpu") # TODO
jax.config.update("jax_traceback_filtering", "off")
jax.config.update("jax_enable_x64", self.precision_x64)
if self.precision_x64:
logger.debug(
Expand Down
15 changes: 7 additions & 8 deletions pydiffuser/utils/jitted.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,15 @@ def get_noise(
try:
noise = generator(size=size) # type: ignore[call-arg]
except Exception as exc:
if isinstance(exc, TypeError):
module = inspect.getmodule(generator)
if module is None:
noise = generator(size)
elif "jax" in module.__name__:
raise NotImplementedError(
"Random number generator via JAX is unsupported"
) from exc
if "rand" in generator.__name__:
noise = generator(size)
elif "jax" in inspect.getmodule(generator).__name__: # type: ignore[union-attr]
raise NotImplementedError(
"Random number generator via JAX is unsupported"
) from exc
else:
raise RuntimeError(f"{exc}") from exc

noise = jnp.array(noise)
if shape is not None:
noise = noise.reshape(shape)
Expand Down

0 comments on commit 4ba7cd7

Please sign in to comment.