diff --git a/ivy/functional/backends/jax/module.py b/ivy/functional/backends/jax/module.py index 31d63f416574..ae46a56dc2fb 100644 --- a/ivy/functional/backends/jax/module.py +++ b/ivy/functional/backends/jax/module.py @@ -2,7 +2,7 @@ from __future__ import annotations import re import jax -from flax import nnx as nn +import flax.nnx as nnx import jax.tree_util as tree import jax.numpy as jnp import functools @@ -370,7 +370,7 @@ def _addindent(s_, numSpaces): return s -class Module(nn.Module, ModelHelpers): +class Module(nnx.Module, ModelHelpers): _build_mode = None _with_partial_v = None _store_vars = True @@ -654,7 +654,7 @@ def __getattr__(self, name): def _compute_module_dict(self): self._module_dict = dict() for key, value in self.__dict__.items(): - if isinstance(value, (Module, nn.Module)): + if isinstance(value, (Module, nnx.Module)): if ( "stateful" in value.__module__ or hasattr(value, "_frontend_module") @@ -667,7 +667,7 @@ def _compute_module_dict(self): def __setattr__(self, name, value): if name in ["v", "buffers"]: name = "_" + name - if isinstance(value, (Module, nn.Module)): + if isinstance(value, (Module, nnx.Module)): _dict = getattr(self, "__dict__", None) if _dict: _dict[name] = value @@ -677,7 +677,7 @@ def __setattr__(self, name, value): obj_to_search = ( None - if not isinstance(value, (nn.Module, Module)) + if not isinstance(value, (nnx.Module, Module)) else ( self._modules if hasattr(self, "_modules") and self._modules @@ -729,7 +729,7 @@ def __setattr__(self, name, value): obj_to_search = getattr(self, name) except AttributeError: obj_to_search = None - if isinstance(obj_to_search, (nn.Module)): + if isinstance(obj_to_search, (nnx.Module)): # retrieve all hierarchical submodules assign_dict, kc = get_assignment_dict() @@ -779,7 +779,7 @@ def _find_variables( return getattr(obj, fn)( *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs ) - elif isinstance(obj, nn.Module) and obj is not self: + elif isinstance(obj, nnx.Module) and obj is not self: return obj.v if trainable else obj.buffers elif isinstance(obj, (list, tuple)): @@ -942,7 +942,7 @@ def __delattr__(self, name): if hasattr(self, name): if isinstance( getattr(self, name), - (Module, nn.Module), + (Module, nnx.Module), ): super().__delattr__(name) return