Skip to content

Commit

Permalink
disable jit compilation in tf > 2.16
Browse files Browse the repository at this point in the history
  • Loading branch information
scarlehoff committed Jul 25, 2024
1 parent 63c1282 commit e7fd0ca
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion n3fit/src/n3fit/backends/keras_backend/MetaModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
else: # in case of disaster
_to_numpy_or_python_type = lambda ret: {k: i.numpy() for k, i in ret.items()}

JIT_COMPILE = "auto"
if tf.__version__ >= "2.16":
# Starting with TF 2.16, a memory leak in TF https://github.com/tensorflow/tensorflow/issues/64170
# makes jit compilation unusable in GPU. "auto" is the default value.
JIT_COMPILE = False


# Define in this dictionary new optimizers as well as the arguments they accept
# (with default values if needed be)
Expand Down Expand Up @@ -307,7 +313,7 @@ def compile(
target_output = [target_output]
self.target_tensors = target_output

super().compile(optimizer=opt, loss=loss)
super().compile(optimizer=opt, loss=loss, jit_compile=JIT_COMPILE)

def set_masks_to(self, names, val=0.0):
"""Set all mask value to the selected value
Expand Down

0 comments on commit e7fd0ca

Please sign in to comment.