Skip to content

Commit

Permalink
fix (backends)(jax)(module.py): fixing the implementation for train
Browse files Browse the repository at this point in the history
… and `eval` methods
  • Loading branch information
YushaArif99 committed Sep 18, 2024
1 parent 34bc324 commit 17657d2
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions ivy/functional/backends/jax/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,15 +448,24 @@ def register_parameter(self, name: str, value: jax.Array):
def train(self, mode: bool = True):
self._training = mode
for module in self.children():
if isinstance(module, nn.Module) and not hasattr(module, "train"):
if isinstance(module, Module):
module.trainable = mode
continue
module.train(mode)

super().train()
self.trainable = mode
return self

def eval(self):
return self.train(mode=False)
def eval(
self,
):
self._training = False
for module in self.children():
if isinstance(module, Module):
module.trainable = False

super().eval()
self.trainable = False
return self

def call(self, inputs, training=None, mask=None):
raise NotImplementedError(
Expand Down

0 comments on commit 17657d2

Please sign in to comment.