Avoid TensorFlow overhead by making one step a batch rather than an epoch.

Avoids memory overhead by combining at most 100 steps into one epoch, and
changes nothing when using only 1 replica (i.e. on CPU).
APJansen committed Mar 6, 2024
1 parent 0a5fc61 commit bb366aa
Showing 3 changed files with 88 additions and 12 deletions.
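In essence, the commit repeats the training inputs along the batch axis so that many optimizer steps run inside a single Keras epoch. A minimal, self-contained sketch of the idea (not the n3fit code itself; the model, the input shapes and the steps_per_epoch value below are placeholders):

import numpy as np
import tensorflow as tf

steps_per_epoch = 100  # assumed value, as chosen by determine_steps_per_epoch below
total_steps = 1000     # hypothetical number of training steps ("epochs" in n3fit terms)

x = np.ones((1, 8), dtype="float32")   # placeholder single-batch input
y = np.zeros((1, 1), dtype="float32")  # placeholder target

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# Repeat the single batch along axis 0 so one Keras epoch contains many steps
x_rep = tf.repeat(x, steps_per_epoch, axis=0)
y_rep = tf.repeat(y, steps_per_epoch, axis=0)

# 1000 training steps now run as 10 Keras epochs of 100 batches each,
# so the per-epoch TensorFlow overhead is paid 10 times rather than 1000.
model.fit(x=x_rep, y=y_rep, epochs=total_steps // steps_per_epoch, batch_size=1, verbose=0)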
36 changes: 34 additions & 2 deletions n3fit/src/n3fit/backends/keras_backend/MetaModel.py
@@ -5,6 +5,7 @@
backend-dependent calls.
"""

import logging
import re
import shutil

@@ -164,10 +165,36 @@ def perform_fit(self, x=None, y=None, epochs=1, **kwargs):
x_params = self._parse_input(x)
if y is None:
y = self.target_tensors
history = super().fit(x=x_params, y=y, epochs=epochs, **kwargs)

# Avoid the TensorFlow overhead incurred at every epoch by putting multiple steps into one epoch
steps_per_epoch = self.determine_steps_per_epoch(epochs)

for k, v in x_params.items():
x_params[k] = tf.repeat(v, steps_per_epoch, axis=0)
y = [tf.repeat(yi, steps_per_epoch, axis=0) for yi in y]

history = super().fit(
x=x_params, y=y, epochs=epochs // steps_per_epoch, batch_size=1, **kwargs
)
loss_dict = history.history
return loss_dict

def determine_steps_per_epoch(self, epochs):
num_replicas = self.output_shape[0][0]
# in this case we're most likely running on the CPU and this is not worth it
if num_replicas == 1:
return 1

# On the GPU, use as many steps per epoch as possible (up to 100), provided it divides the total number of epochs
for divisor in [10, 100]:
if epochs % divisor != 0:
steps_per_epoch = divisor // 10
log.warning(
f"Epochs {epochs} not divisible by {divisor}, using {steps_per_epoch} steps per epoch"
)
return steps_per_epoch
return 100
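# Illustrative note (not part of the commit): with more than one replica, the loop above
# picks the largest of 1, 10 or 100 steps per epoch that divides the total number of epochs,
# e.g. epochs=17000 -> 100, epochs=1230 -> 10 (with a warning), epochs=1234 -> 1 (with a warning).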

def predict(self, x=None, **kwargs):
"""Call super().predict with the right input arguments"""
x = self._parse_input(x)
@@ -193,10 +220,15 @@ def compute_losses(self):
out_names = [f"{i}_loss" for i in self.output_names]
out_names.insert(0, "loss")

inputs = self._parse_input(None)
# undo the repetition over steps per epoch made in perform_fit
for k, v in inputs.items():
inputs[k] = v[:1]
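# Illustrative note (not part of the commit): if perform_fit repeated each input grid
# 100 times along axis 0, taking v[:1] here recovers the original single-step input.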

# Compile an evaluation function
@tf.function
def losses_fun():
predictions = self(self._parse_input(None))
predictions = self(inputs)
# If we only have one dataset the output changes
if len(out_names) == 2:
predictions = [predictions]
58 changes: 51 additions & 7 deletions n3fit/src/n3fit/backends/keras_backend/callbacks.py
@@ -4,8 +4,12 @@
The callbacks defined in this module can be passed to the ``callbacks`` argument
of the ``perform_fit`` method as a list.
For the most typical usage: ``on_epoch_end``,
For the most typical usage: ``on_batch_end``,
they must take as input an epoch number and a log of the partial losses.
Note: the terminology used throughout refers to a single training step as a single epoch.
It turns out that, to avoid TensorFlow overhead, it is beneficial to run each step as a
single batch instead, so callbacks must use ``on_batch_end``.
"""

import logging
@@ -18,7 +22,46 @@
log = logging.getLogger(__name__)


class TimerCallback(Callback):
class CallbackStep(Callback):
"""
Wrapper around the keras Callback that keeps track of how the steps are divided
between epochs and batches.
Subclasses should implement ``on_step_end``, which this wrapper calls from ``on_batch_end`` with the global step number.
"""

def __init__(self):
super().__init__()
self.steps_in_epoch = 0
self.epochs_finished = 0
self.steps_per_epoch = 0 # will be defined in the first epoch
self._previous_logs = {}

def on_epoch_end(self, epoch, logs=None):
if self.steps_per_epoch == 0:
self.steps_per_epoch = self.steps_in_epoch
self.steps_in_epoch = 0
self.epochs_finished += 1

def on_batch_end(self, batch, logs=None):
step_number = self.steps_in_epoch + self.epochs_finished * self.steps_per_epoch
self.on_step_end(step_number, logs)
self.steps_in_epoch += 1
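# Example (not part of the commit): with steps_per_epoch=100, the 6th batch of the 4th
# Keras epoch (steps_in_epoch=5, epochs_finished=3) maps to global step 305, so
# on_step_end receives the same monotonic counter that the rest of n3fit calls "epoch".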

def correct_logs(self, logs: dict) -> dict:
"""
The logs that get computed by default are a running average over the batches in an epoch.
This converts them into the logs for the current step only.
"""
corrected_logs = {}
for k in logs.keys():
previous_total = self._previous_logs.get(k, 0.0) * self.steps_in_epoch
current_total = logs[k] * (self.steps_in_epoch + 1)
corrected_logs[k] = current_total - previous_total
self._previous_logs = logs
return corrected_logs
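# Worked example (not part of the commit): within an epoch Keras reports running averages.
# If the first two steps have losses 4.0 and 2.0, the reported logs are 4.0 and then 3.0;
# at the second step (steps_in_epoch=1, _previous_logs={"loss": 4.0}) correct_logs returns
# 3.0 * 2 - 4.0 * 1 = 2.0, recovering the loss of the current step alone.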


class TimerCallback(CallbackStep):
"""Callback to be used during debugging to time the fit"""

def __init__(self, count_range=100):
@@ -30,7 +73,7 @@ def __init__(self, count_range=100):
self.starting_time = None
self.last_time = 0

def on_epoch_end(self, epoch, logs=None):
def on_step_end(self, epoch, logs=None):
"""At the end of every epoch it checks the time"""
new_time = time()
if epoch == 0:
@@ -57,7 +100,7 @@ def on_train_end(self, logs=None):
log.info(f"> > > Total time: {total_time/60:.5} min")


class StoppingCallback(Callback):
class StoppingCallback(CallbackStep):
"""
Given a ``stopping_object``, the callback will monitor the validation chi2
and will stop the training model when the conditions given by ``stopping_object``
@@ -77,14 +120,15 @@ def __init__(self, stopping_object, log_freq=100):
self.log_freq = log_freq
self.stopping_object = stopping_object

def on_epoch_end(self, epoch, logs=None):
def on_step_end(self, epoch, logs=None):
"""Function to be called at the end of every epoch
Every ``log_freq`` number of epochs, the ``monitor_chi2`` method of the ``stopping_object``
will be called and the validation loss (broken down by experiment) will be logged.
For the training model only the total loss is logged during the training.
"""
print_stats = ((epoch + 1) % self.log_freq) == 0
# Note that the input logs correspond to the fit before the weights are updated
logs = self.correct_logs(logs)
self.stopping_object.monitor_chi2(logs, epoch, print_stats=print_stats)
if self.stopping_object.stop_here():
self.model.stop_training = True
@@ -97,7 +141,7 @@ def on_train_end(self, logs=None):
self.stopping_object.make_stop()


class LagrangeCallback(Callback):
class LagrangeCallback(CallbackStep):
"""
Updates the given datasets
with its respective multipliers each ``update_freq`` epochs
@@ -137,7 +181,7 @@ def _update_weights(self):
for w in ws:
w.assign(w * multiplier)

def on_epoch_end(self, epoch, logs=None):
def on_step_end(self, epoch, logs=None):
"""Function to be called at the end of every epoch"""
if (epoch + 1) % self.update_freq == 0:
self._update_weights()
6 changes: 3 additions & 3 deletions n3fit/src/n3fit/stopping.py
@@ -496,12 +496,12 @@ def print_current_stats(self, epoch, fitstate):
"""
epoch_index = epoch + 1
vl_chi2 = fitstate.total_vl_chi2()
total_str = f"""Epoch {epoch_index}/{self.total_epochs}: loss: {fitstate.tr_loss:.7f}
Validation loss after training step: {vl_chi2:.7f}.
Validation chi2s: """
total_str = f"Epoch {epoch_index}/{self.total_epochs}: loss: {fitstate.tr_loss:.7f}"
total_str += f"\nValidation loss after training step: {vl_chi2:.7f}."

# The partial chi2 makes no sense for more than one replica at once:
if self._n_replicas == 1:
total_str += "\nValidation chi2s: "
partial_vl_chi2 = fitstate.total_partial_vl_chi2()
partials = []
for experiment, chi2 in partial_vl_chi2.items():
