From 4ba7cd7ae2f06819eb7dac93e46d346704d22ed7 Mon Sep 17 00:00:00 2001 From: jung235 Date: Sun, 23 Jun 2024 23:36:33 +0900 Subject: [PATCH] fix --- pydiffuser/models/core/base.py | 1 - pydiffuser/utils/jitted.py | 15 +++++++-------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/pydiffuser/models/core/base.py b/pydiffuser/models/core/base.py index cfba9b4..9967c12 100644 --- a/pydiffuser/models/core/base.py +++ b/pydiffuser/models/core/base.py @@ -86,7 +86,6 @@ def pre_generate(self, *generate_args) -> None: ) jax.config.update("jax_platform_name", "cpu") # TODO - jax.config.update("jax_traceback_filtering", "off") jax.config.update("jax_enable_x64", self.precision_x64) if self.precision_x64: logger.debug( diff --git a/pydiffuser/utils/jitted.py b/pydiffuser/utils/jitted.py index 2908d0c..fc69cdb 100644 --- a/pydiffuser/utils/jitted.py +++ b/pydiffuser/utils/jitted.py @@ -46,16 +46,15 @@ def get_noise( try: noise = generator(size=size) # type: ignore[call-arg] except Exception as exc: - if isinstance(exc, TypeError): - module = inspect.getmodule(generator) - if module is None: - noise = generator(size) - elif "jax" in module.__name__: - raise NotImplementedError( - "Random number generator via JAX is unsupported" - ) from exc + if "rand" in generator.__name__: + noise = generator(size) + elif "jax" in inspect.getmodule(generator).__name__: # type: ignore[union-attr] + raise NotImplementedError( + "Random number generator via JAX is unsupported" + ) from exc else: raise RuntimeError(f"{exc}") from exc + noise = jnp.array(noise) if shape is not None: noise = noise.reshape(shape)