Skip to content

Commit

Permalink
refactor (backends)(jax)(module.py): renaming nn alias with nnx
Browse files Browse the repository at this point in the history
  • Loading branch information
YushaArif99 committed Sep 18, 2024
1 parent 17657d2 commit 55095e7
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions ivy/functional/backends/jax/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import annotations
import re
import jax
from flax import nnx as nn
import flax.nnx as nnx
import jax.tree_util as tree
import jax.numpy as jnp
import functools
Expand Down Expand Up @@ -370,7 +370,7 @@ def _addindent(s_, numSpaces):
return s


class Module(nn.Module, ModelHelpers):
class Module(nnx.Module, ModelHelpers):
_build_mode = None
_with_partial_v = None
_store_vars = True
Expand Down Expand Up @@ -654,7 +654,7 @@ def __getattr__(self, name):
def _compute_module_dict(self):
self._module_dict = dict()
for key, value in self.__dict__.items():
if isinstance(value, (Module, nn.Module)):
if isinstance(value, (Module, nnx.Module)):
if (
"stateful" in value.__module__
or hasattr(value, "_frontend_module")
Expand All @@ -667,7 +667,7 @@ def _compute_module_dict(self):
def __setattr__(self, name, value):
if name in ["v", "buffers"]:
name = "_" + name
if isinstance(value, (Module, nn.Module)):
if isinstance(value, (Module, nnx.Module)):
_dict = getattr(self, "__dict__", None)
if _dict:
_dict[name] = value
Expand All @@ -677,7 +677,7 @@ def __setattr__(self, name, value):

obj_to_search = (
None
if not isinstance(value, (nn.Module, Module))
if not isinstance(value, (nnx.Module, Module))
else (
self._modules
if hasattr(self, "_modules") and self._modules
Expand Down Expand Up @@ -729,7 +729,7 @@ def __setattr__(self, name, value):
obj_to_search = getattr(self, name)
except AttributeError:
obj_to_search = None
if isinstance(obj_to_search, (nn.Module)):
if isinstance(obj_to_search, (nnx.Module)):
# retrieve all hierarchical submodules
assign_dict, kc = get_assignment_dict()

Expand Down Expand Up @@ -779,7 +779,7 @@ def _find_variables(
return getattr(obj, fn)(
*obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs
)
elif isinstance(obj, nn.Module) and obj is not self:
elif isinstance(obj, nnx.Module) and obj is not self:
return obj.v if trainable else obj.buffers

elif isinstance(obj, (list, tuple)):
Expand Down Expand Up @@ -942,7 +942,7 @@ def __delattr__(self, name):
if hasattr(self, name):
if isinstance(
getattr(self, name),
(Module, nn.Module),
(Module, nnx.Module),
):
super().__delattr__(name)
return
Expand Down

0 comments on commit 55095e7

Please sign in to comment.