Skip to content

Commit

Permalink
Abstract determination of steps per epoch in function
Browse files Browse the repository at this point in the history
  • Loading branch information
APJansen committed Oct 31, 2023
1 parent 2611420 commit e97ec1c
Showing 1 changed file with 22 additions and 13 deletions.
35 changes: 22 additions & 13 deletions n3fit/src/n3fit/backends/keras_backend/MetaModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,25 +169,34 @@ def perform_fit(self, x=None, y=None, epochs=1, **kwargs):
y = self.target_tensors

# Avoids Tensorflow overhead that happens at every epoch, by putting multiple steps in an epoch
num_repeats = 100
num_replicas = self.output_shape[-1] # This only matters on GPU, with multiple replicas
if epochs % 10 != 0:
num_repeats = 1
if num_replicas > 1:
log.warning("Epochs not divisible by 10, using 1 epoch per step")
elif epochs % 100 != 0:
num_repeats = 10
if num_replicas > 1:
log.warning("Epochs not divisible by 100, using 1 epoch per step")
steps_per_epoch = self.determine_steps_per_epoch(epochs)

for k, v in x_params.items():
x_params[k] = tf.repeat(v, num_repeats, axis=0)
y = [tf.repeat(yi, num_repeats, axis=0) for yi in y]
x_params[k] = tf.repeat(v, steps_per_epoch, axis=0)
y = [tf.repeat(yi, steps_per_epoch, axis=0) for yi in y]

history = super().fit(x=x_params, y=y, epochs=epochs // num_repeats, batch_size=1, **kwargs)
history = super().fit(
x=x_params, y=y, epochs=epochs // steps_per_epoch, batch_size=1, **kwargs
)
loss_dict = history.history
return loss_dict

def determine_steps_per_epoch(self, epochs):
num_replicas = self.output_shape[0][0]
# in this case we're most likely running on the CPU and this is not worth it
if num_replicas == 1:
return 1

# On the GPU, run with
for divisor in [10, 100]:
if epochs % divisor != 0:
steps_per_epoch = divisor // 10
log.warning(
f"Epochs {epochs} not divisible by {divisor}, using {steps_per_epoch} steps per epoch"
)
return steps_per_epoch
return 100

def predict(self, x=None, **kwargs):
"""Call super().predict with the right input arguments"""
x = self._parse_input(x)
Expand Down

0 comments on commit e97ec1c

Please sign in to comment.