Skip to content

Commit

Permalink
rebase + small changes
Browse files Browse the repository at this point in the history
  • Loading branch information
scarlehoff committed Jun 6, 2024
1 parent cfd680a commit 83f252a
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions n3fit/src/n3fit/backends/keras_backend/MetaModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def perform_fit(self, x=None, y=None, epochs=1, **kwargs):
y = self.target_tensors

# Avoids Tensorflow overhead that happens at every epoch, by putting multiple steps in an epoch
steps_per_epoch = self.determine_steps_per_epoch(epochs)
steps_per_epoch = self._determine_steps_per_epoch(epochs)

for k, v in x_params.items():
x_params[k] = tf.repeat(v, steps_per_epoch, axis=0)
Expand All @@ -178,20 +178,19 @@ def perform_fit(self, x=None, y=None, epochs=1, **kwargs):
loss_dict = history.history
return loss_dict

def determine_steps_per_epoch(self, epochs):
num_replicas = self.output_shape[0][0]
# in this case we're most likely running on the CPU and this is not worth it
if num_replicas == 1:
def _determine_steps_per_epoch(self, epochs):
"""Determine how many steps to run in every epoch.
When running a single replica (CPU) or when the number of epochs is < 100, default to 1.
Otherwise run 100 steps per epoch.
If the number of epochs requested is not divisible by 100 there will be a number
of extra training epochs being run equal to max_epochs % 100 in the worst case.
"""
num_replicas = self.output_shape[0]
if num_replicas == 1 or epochs < 100:
return 1

# On the GPU, run with the largest of 1, 10 or 100 steps per epoch that evenly divides the number of epochs
for divisor in [10, 100]:
if epochs % divisor != 0:
steps_per_epoch = divisor // 10
log.warning(
f"Epochs {epochs} not divisible by {divisor}, using {steps_per_epoch} steps per epoch"
)
return steps_per_epoch
return 100

def predict(self, x=None, **kwargs):
Expand Down

0 comments on commit 83f252a

Please sign in to comment.