diff --git a/MarcDecoderWin/lib/callbacks.py b/MarcDecoderWin/lib/callbacks.py index bc50122..b16367a 100755 --- a/MarcDecoderWin/lib/callbacks.py +++ b/MarcDecoderWin/lib/callbacks.py @@ -13,10 +13,10 @@ def get_filename(params: Dict[str, Any]) -> str: if params["save_best_only"]: - filename = "weights.keras" + filename = "weights.h5" return filename - filename = f"weights-{{epoch}}-{{{params['monitor']}}}.keras" + filename = f"weights-{{epoch}}-{{{params['monitor']}}}.h5" return filename @@ -26,6 +26,7 @@ def get_callbacks(config: Config) -> Tuple[Callback]: checkpoint_filename = get_filename(params["checkpoint"]) model_checkpoint = ModelCheckpoint( filepath=str(config.checkpoint_dir / checkpoint_filename), + save_weights_only=True, **params["checkpoint"], ) diff --git a/MarcDecoderWin/lib/util.py b/MarcDecoderWin/lib/util.py index 5bbfd52..3aa0d03 100755 --- a/MarcDecoderWin/lib/util.py +++ b/MarcDecoderWin/lib/util.py @@ -1,5 +1,5 @@ import itertools -import random +import numpy as np from qrennd.configs import Config from qrennd.layouts import Layout @@ -12,8 +12,34 @@ from .sequences import Sequence +def _pad_and_concat(arrays: tuple[np.ndarray, ...]) -> np.ndarray: + """Pad arrays along non-batch dimensions before concatenation. + + Each array is assumed to have the same rank and its first dimension is the + batch dimension. The remaining dimensions are padded with zeros so that all + arrays share a common shape, allowing them to be concatenated along the + batch axis without shape mismatches. + """ + + if not arrays: + return np.array([]) + + ndims = arrays[0].ndim + # Compute the target size for every non-batch dimension + max_shape = [max(arr.shape[i] for arr in arrays) for i in range(1, ndims)] + + padded = [] + for arr in arrays: + pad_width = [(0, 0)] + pad_width += [ + (0, max_dim - arr.shape[i + 1]) for i, max_dim in enumerate(max_shape) + ] + padded.append(np.pad(arr, pad_width, mode="constant")) + + return np.concatenate(padded, axis=0) + + def load_datasets(config: Config, layout: Layout, dataset_name: str): - batch_size = config.train["batch_size"] experiment_name = config.dataset["folder_format_name"] input_names = config.dataset["input_names"] @@ -35,15 +61,14 @@ def load_datasets(config: Config, layout: Layout, dataset_name: str): to_inputs(dataset, proj_matrix, input_names=input_names) for dataset in datasets ] - # Process for keras.model input - input = [to_model_input(*arrs, data_type=data_type) for arrs in processed] - # - # sequences = (Sequence(*tensors, batch_size) for tensors in input) - # sequences = ((b for b in sequence) for sequence in sequences) - # sequences_flattened = itertools.chain.from_iterable(sequences) - # sequences_flattened = list(sequences_flattened) - - return input + # Process for keras.model input and concatenate for tf.data + model_inputs = [to_model_input(*arrs, data_type=data_type) for arrs in processed] + rec_inputs, eval_inputs, log_errors = zip(*model_inputs) + rec_inputs = _pad_and_concat(rec_inputs) + eval_inputs = _pad_and_concat(eval_inputs) + log_errors = np.concatenate(log_errors, axis=0) + + return rec_inputs, eval_inputs, log_errors def load_datasets_backup( diff --git a/MarcDecoderWin/train.py b/MarcDecoderWin/train.py index de79ce0..709af48 100755 --- a/MarcDecoderWin/train.py +++ b/MarcDecoderWin/train.py @@ -1,5 +1,4 @@ import os -import gc import pathlib import random import numpy as np @@ -13,7 +12,6 @@ from lib.util import load_datasets from lib.callbacks import get_callbacks -from lib.sequences import Sequence # from tensorflow.compat.v1 import ConfigProto # from tensorflow.compat.v1 import InteractiveSession @@ -77,25 +75,37 @@ train_data = load_datasets(config=config, layout=layout, dataset_name="train") print("completed") -# this is for model.fit to know that the num_rounds coordinate is not fixed +# build tf.data pipeline batch_size = config.train["batch_size"] -tensor1, tensor2 = train_data[0], train_data[-1] -seq1, seq2 = Sequence(*tensor1, batch_size), Sequence(*tensor2, batch_size) -first_batch, second_batch = seq1[0], seq2[0] -def infinite_gen(inputs): - while True: - random.shuffle(train_data) - sequences = (Sequence(*tensors, batch_size) for tensors in inputs) - # this is for model.fit to know that the num_rounds coordinate is not fixed - yield first_batch - yield second_batch +def make_dataset(rec_input, eval_input, labels, training=False): + """Construct a ``tf.data.Dataset`` from numpy arrays. - for k, sequence in enumerate(sequences): - # cannot do 'yield from sequence' because it has no end! - for i in range(sequence._num_batches): - yield sequence[i] + Using ``from_tensor_slices`` removes Python overhead from the input + pipeline, enabling TensorFlow to better overlap input processing with GPU + execution. When ``training`` is ``True`` the dataset is shuffled and + repeated to provide an infinite stream of data. + """ + + dataset = tf.data.Dataset.from_tensor_slices( + ({"rec_input": rec_input, "eval_input": eval_input}, labels) + ) + dataset = dataset.cache() + if training: + dataset = dataset.shuffle(len(labels)).repeat() + dataset = dataset.batch(batch_size) + + # Prefetch to GPU if available to hide host-to-device transfer latency + gpus = tf.config.list_physical_devices("GPU") + if gpus: + dataset = dataset.apply( + tf.data.experimental.copy_to_device("/GPU:0") + ).prefetch(tf.data.AUTOTUNE) + else: + dataset = dataset.prefetch(tf.data.AUTOTUNE) + + return dataset # load model @@ -125,27 +135,20 @@ def infinite_gen(inputs): # train model -train = config.dataset["train"] -val = config.dataset["val"] -batch_size = config.train["batch_size"] +train_rec, train_eval, train_labels = train_data +val_rec, val_eval, val_labels = val_data + +train_ds = make_dataset(train_rec, train_eval, train_labels, training=True) +val_ds = make_dataset(val_rec, val_eval, val_labels) + history = model.fit( - infinite_gen(train_data), - validation_data=infinite_gen(val_data), - # batch_size=config.train["batch_size"], + train_ds, + validation_data=val_ds, epochs=config.train["epochs"], callbacks=callbacks, - # shuffle=True, verbose=1, - steps_per_epoch=train["shots"] - * len(train["rounds"]) - * len(train["states"]) - // batch_size - + 2, # +2 is for model.fit to know that the num_rounds coordinate is not fixed - validation_steps=val["shots"] - * len(val["rounds"]) - * len(val["states"]) - // batch_size - + 2, # +2 is for model.fit to know that the num_rounds coordinate is not fixed + steps_per_epoch=train_rec.shape[0] // batch_size, + validation_steps=val_rec.shape[0] // batch_size, ) -model.save(config.checkpoint_dir / "final_weights.keras") +model.save_weights(config.checkpoint_dir / "final_weights.h5")