Skip to content

Commit

Permalink
Avoid TensorFlow overhead by making one step a batch rather than an
Browse files Browse the repository at this point in the history
epoch.

Avoids memory overhead by only combining up to 100 steps into one epoch,
and not changing anything when using only 1 replica (i.e. on CPU).
  • Loading branch information
APJansen committed Feb 28, 2024
1 parent d122a8d commit 9032cdf
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 17 deletions.
38 changes: 36 additions & 2 deletions n3fit/src/n3fit/backends/keras_backend/MetaModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
backend-dependent calls.
"""

import logging
import re

import h5py
Expand All @@ -16,6 +17,8 @@

import n3fit.backends.keras_backend.operations as op

log = logging.getLogger(__name__)

# Check the TF version to check if legacy-mode is needed (TF < 2.2)
tf_version = tf.__version__.split(".")
if int(tf_version[0]) == 2 and int(tf_version[1]) < 2:
Expand Down Expand Up @@ -170,10 +173,36 @@ def perform_fit(self, x=None, y=None, epochs=1, **kwargs):
x_params = self._parse_input(x)
if y is None:
y = self.target_tensors
history = super().fit(x=x_params, y=y, epochs=epochs, **kwargs)

# Avoids Tensorflow overhead that happens at every epoch, by putting multiple steps in an epoch
steps_per_epoch = self.determine_steps_per_epoch(epochs)

for k, v in x_params.items():
x_params[k] = tf.repeat(v, steps_per_epoch, axis=0)
y = [tf.repeat(yi, steps_per_epoch, axis=0) for yi in y]

history = super().fit(
x=x_params, y=y, epochs=epochs // steps_per_epoch, batch_size=1, **kwargs
)
loss_dict = history.history
return loss_dict

def determine_steps_per_epoch(self, epochs):
num_replicas = self.output_shape[0][0]
# in this case we're most likely running on the CPU and this is not worth it
if num_replicas == 1:
return 1

# On the GPU, run with
for divisor in [10, 100]:
if epochs % divisor != 0:
steps_per_epoch = divisor // 10
log.warning(
f"Epochs {epochs} not divisible by {divisor}, using {steps_per_epoch} steps per epoch"
)
return steps_per_epoch
return 100

def predict(self, x=None, **kwargs):
"""Call super().predict with the right input arguments"""
x = self._parse_input(x)
Expand All @@ -199,10 +228,15 @@ def compute_losses(self):
out_names = [f"{i}_loss" for i in self.output_names]
out_names.insert(0, "loss")

inputs = self._parse_input(None)
# get rid of the repetitions by number of epochs made in perform_fit
for k, v in inputs.items():
inputs[k] = v[:1]

# Compile a evaluation function
@tf.function
def losses_fun():
predictions = self(self._parse_input(None))
predictions = self(inputs)
# If we only have one dataset the output changes
if len(out_names) == 2:
predictions = [predictions]
Expand Down
75 changes: 60 additions & 15 deletions n3fit/src/n3fit/backends/keras_backend/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,64 @@
The callbacks defined in this module can be passed to the ``callbacks`` argument
of the ``perform_fit`` method as a list.
For the most typical usage: ``on_epoch_end``,
For the most typical usage: ``on_batch_end``,
they must take as input an epoch number and a log of the partial losses.
Note: the terminology used everywhere refers to a single training step as a single epoch.
It turns out that to avoid tensorflow overhead, it is beneficial to write a step as a
single batch instead. So callbacks must use ``on_batch_end``.
"""

import logging
from time import time

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import TensorBoard, Callback
from tensorflow.keras.callbacks import Callback, TensorBoard

log = logging.getLogger(__name__)


class TimerCallback(Callback):
class CallbackStep(Callback):
"""
Wrapper around the keras Callback that keeps track of how the steps are divided
between epochs and batches.
The callback will call ``on_step_end`` instead of ``on_batch_end``.
"""

def __init__(self):
super().__init__()
self.steps_in_epoch = 0
self.epochs_finished = 0
self.steps_per_epoch = 0 # will be defined in the first epoch
self._previous_logs = {}

def on_epoch_end(self, epoch, logs=None):
if self.steps_per_epoch == 0:
self.steps_per_epoch = self.steps_in_epoch
self.steps_in_epoch = 0
self.epochs_finished += 1

def on_batch_end(self, batch, logs=None):
step_number = self.steps_in_epoch + self.epochs_finished * self.steps_per_epoch
self.on_step_end(step_number, logs)
self.steps_in_epoch += 1

def correct_logs(self, logs: dict) -> dict:
"""
The logs that get computed by default are an average over batches.
This converts it into the logs for the current step.
"""
corrected_logs = {}
for k in logs.keys():
previous_total = self._previous_logs.get(k, 0.0) * self.steps_in_epoch
current_total = logs[k] * (self.steps_in_epoch + 1)
corrected_logs[k] = current_total - previous_total
self._previous_logs = logs
return corrected_logs


class TimerCallback(CallbackStep):
"""Callback to be used during debugging to time the fit"""

def __init__(self, count_range=100):
Expand All @@ -29,8 +73,8 @@ def __init__(self, count_range=100):
self.starting_time = None
self.last_time = 0

def on_epoch_end(self, epoch, logs=None):
""" At the end of every epoch it checks the time """
def on_step_end(self, epoch, logs=None):
"""At the end of every epoch it checks the time"""
new_time = time()
if epoch == 0:
# The first epoch is only useful for starting
Expand All @@ -45,18 +89,18 @@ def on_epoch_end(self, epoch, logs=None):
self.last_time = new_time

def on_train_end(self, logs=None):
""" Print the results """
"""Print the results"""
total_time = time() - self.starting_time
n_times = len(self.all_times)
# Skip the first 100 epochs to avoid fluctuations due to compilations of part of the code
# by epoch 100 all parts of the code have usually been called so it's a good compromise
mean = np.mean(self.all_times[min(110, n_times-1):])
std = np.std(self.all_times[min(110, n_times-1):])
mean = np.mean(self.all_times[min(110, n_times - 1) :])
std = np.std(self.all_times[min(110, n_times - 1) :])
log.info(f"> > Average time per epoch: {mean:.5} +- {std:.5} s")
log.info(f"> > > Total time: {total_time/60:.5} min")


class StoppingCallback(Callback):
class StoppingCallback(CallbackStep):
"""
Given a ``stopping_object``, the callback will monitor the validation chi2
and will stop the training model when the conditions given by ``stopping_object``
Expand All @@ -76,10 +120,11 @@ def __init__(self, stopping_object, log_freq=100):
self.log_freq = log_freq
self.stopping_object = stopping_object

def on_epoch_end(self, epoch, logs=None):
""" Function to be called at the end of every epoch """
def on_step_end(self, epoch, logs=None):
"""Function to be called at the end of every epoch"""
print_stats = ((epoch + 1) % self.log_freq) == 0
# Note that the input logs correspond to the fit before the weights are updated
logs = self.correct_logs(logs)
self.stopping_object.monitor_chi2(logs, epoch, print_stats=print_stats)
if self.stopping_object.stop_here():
self.model.stop_training = True
Expand All @@ -92,7 +137,7 @@ def on_train_end(self, logs=None):
self.stopping_object.make_stop()


class LagrangeCallback(Callback):
class LagrangeCallback(CallbackStep):
"""
Updates the given datasets
with its respective multipliers each ``update_freq`` epochs
Expand All @@ -117,7 +162,7 @@ def __init__(self, datasets, multipliers, update_freq=100):
self.updateable_weights = []

def on_train_begin(self, logs=None):
""" Save an instance of all relevant layers """
"""Save an instance of all relevant layers"""
for layer_name in self.datasets:
layer = self.model.get_layer(layer_name)
self.updateable_weights.append(layer.weights)
Expand All @@ -132,8 +177,8 @@ def _update_weights(self):
for w in ws:
w.assign(w * multiplier)

def on_epoch_end(self, epoch, logs=None):
""" Function to be called at the end of every epoch """
def on_step_end(self, epoch, logs=None):
"""Function to be called at the end of every epoch"""
if (epoch + 1) % self.update_freq == 0:
self._update_weights()

Expand Down

0 comments on commit 9032cdf

Please sign in to comment.