diff --git a/data_dir/dataloaders.py b/data_dir/dataloaders.py index 8007c2b..a57cae5 100644 --- a/data_dir/dataloaders.py +++ b/data_dir/dataloaders.py @@ -1,22 +1,30 @@ """ -A dataloader class for loading data in batches. Each model in this repository is designed to take a single argument as -input. Hence, the dataloader can handle three different cases. -- The first is simply the value of the time series data. This is used by stacked recurrent models, such as recurrent -neural networks and structured state space models. In this case data should be a jnp.ndarray of shape -(n_samples, n_timesteps, n_features) -- The second case is NCDEs, which requires the sampling time, the coefficients of an interpolation, and the initial -value of the data. In this case, data should be a tuple of length 3, where the first element is a jnp.ndarray of shape -(n_samples, n_timesteps) for the sampling times, the third element is a jnp.ndarray of shape (n_samples, n_features) -for the initial value, and the second element is a tuple of length n_coeffs, where each element is a jnp.ndarray of -shape (n_samples, n_timesteps-1, n_features) for the coefficients of the interpolation. -- The third case are NRDEs and Log-NCDEs, which require the sampling time, the log-signature of the data, and the -initial value of the data. In this case, data should be a tuple of length 3, where the first element is a jnp.ndarray -of shape (n_samples, n_timesteps) for the sampling times, the third element is a jnp.ndarray of shape -(n_samples, n_features) for the initial value, and the second element is a jnp.ndarray of shape -(n_samples, n_intervals, n_logsig_features) for the log-signature of the data over n_intervals. - -Additionally, the data can be stored as a numpy array, and each batch converted to a jax numpy array, to save GPU -memory. +This module implements a `Dataloader` class for loading and batching data. It supports three different types of +data inputs, tailored for different types of models used in this repository. + +1. **Time Series Data**: Used by models like recurrent neural networks and structured state space models. + - Input data should be a `jnp.ndarray` of shape `(n_samples, n_timesteps, n_features)`. + +2. **Neural Controlled Differential Equations (NCDEs)**: Requires sampling times, coefficients of an interpolation, + and the initial value of the data. + - Input data should be a tuple of length 3: + - The first element: `jnp.ndarray` of shape `(n_samples, n_timesteps)` for sampling times. + - The second element: a tuple of length `n_coeffs`, where each element is a `jnp.ndarray` of shape + `(n_samples, n_timesteps-1, n_features)` for interpolation coefficients. + - The third element: `jnp.ndarray` of shape `(n_samples, n_features)` for the initial value. + +3. **Neural Rough Differential Equations (NRDEs) and Log-NCDEs**: Requires sampling times, log-signature of the data, + and the initial value of the data. + - Input data should be a tuple of length 3: + - The first element: `jnp.ndarray` of shape `(n_samples, n_timesteps)` for sampling times. + - The second element: `jnp.ndarray` of shape `(n_samples, n_intervals, n_logsig_features)` for log-signature data. + - The third element: `jnp.ndarray` of shape `(n_samples, n_features)` for the initial value. + +Additionally, data can be stored as a NumPy array to save GPU memory, with each batch converted to a JAX NumPy array. + +Methods: +- `loop(batch_size, *, key)`: Generates data batches indefinitely. Randomly shuffles data for each batch. +- `loop_epoch(batch_size)`: Generates data batches for one epoch (i.e., a full pass through the dataset). """ import jax.numpy as jnp diff --git a/data_dir/datasets.py b/data_dir/datasets.py index a4fd34d..9757743 100644 --- a/data_dir/datasets.py +++ b/data_dir/datasets.py @@ -1,11 +1,17 @@ """ -This module contains the Dataset class and functions to generate datasets. Since each model requires different versions -of the data as input, a Dataset object contains three dataloaders: -- raw_dataloaders: dataloaders which return the value of the data at each time point, as is used by recurrent neural -networks and structured state space models. -- coeff_dataloaders: dataloaders which return the coefficients of an interpolation of the data, as is used by NCDEs. -- path_dataloaders: dataloaders which return the log-signature of the data over intervals, as is used by NRDEs and -Log-NCDEs. +This module defines the `Dataset` class and functions for generating datasets tailored to different model types. +A `Dataset` object in this module contains three different dataloaders, each providing a specific version of the data +required by different models: + +- `raw_dataloaders`: Returns the raw time series data, suitable for recurrent neural networks (RNNs) and structured + state space models (SSMs). +- `coeff_dataloaders`: Provides the coefficients of an interpolation of the data, used by Neural Controlled Differential + Equations (NCDEs). +- `path_dataloaders`: Provides the log-signature of the data over intervals, used by Neural Rough Differential Equations + (NRDEs) and Log-NCDEs. + +The module also includes utility functions for processing and generating these datasets, ensuring compatibility with +different model requirements. """ import os diff --git a/models/LRU.py b/models/LRU.py index 75f4d93..4bd2531 100644 --- a/models/LRU.py +++ b/models/LRU.py @@ -1,12 +1,22 @@ """ Code modified from https://gist.github.com/Ryu1845/7e78da4baa8925b4de482969befa949d -This module implements the LRU class. The LRU class has the following attributes: -- linear_encoder: The linear encoder applied to the input time series. -- blocks: A list of LRU blocks. -- linear_layer: The final linear layer of the S5 model which outputs the predictions. -- classification: Whether the model is used for classification. -- output_step: If the model is used for regression, how many steps to skip before outputting a prediction. +This module implements the `LRU` class, a model architecture using JAX and Equinox. + +Attributes of the `LRU` class: +- `linear_encoder`: The linear encoder applied to the input time series data. +- `blocks`: A list of `LRUBlock` instances, each containing the LRU layer, normalization, GLU, and dropout. +- `linear_layer`: The final linear layer that outputs the model predictions. +- `classification`: A boolean indicating whether the model is used for classification tasks. +- `output_step`: For regression tasks, specifies how many steps to skip before outputting a prediction. + +The module also includes the following classes and functions: +- `GLU`: Implements a Gated Linear Unit for non-linear transformations within the model. +- `LRULayer`: A single LRU layer that applies complex-valued transformations and projections to the input. +- `LRUBlock`: A block consisting of normalization, LRU layer, GLU, and dropout, used as a building block for the `LRU` + model. +- `binary_operator_diag`: A helper function used in the associative scan operation within `LRULayer` to process diagonal + elements. """ from typing import List diff --git a/models/LogNeuralCDEs.py b/models/LogNeuralCDEs.py index c5468a8..70da781 100644 --- a/models/LogNeuralCDEs.py +++ b/models/LogNeuralCDEs.py @@ -1,21 +1,27 @@ """ -This scripts implements the LogNeuralCDE class using Jax and equinox. The model is a NCDE and the output is -approximated during training using the Log-ODE method. The model's attributes are: -- vf: The vector field $f_{\theta}$ of the NCDE. -- data_dim: The number of channels in the time series. -- depth: The depth of the Log-ODE method. Currently implemented only for depth=1 and 2. -- hidden_dim: The dimension of the hidden state $h_t$. -- linear1: The input linear layer for initialising $h_0$. -- linear2: The output linear layer for obtaining predictions from $h_t$. -- pairs: The pairs of basis elements for the terms in the depth-2 log-signature of the path. -- classification: Whether the model is used for classification. -- output_step: If the model is used for regression, how many steps to skip before outputting a prediction. -- intervals: The intervals for the Log-ODE method. -- solver: The solver applied to the ODE produce by the Log-ODE method. -- stepsize_controller: The stepsize controller for the solver. -- dt0: The initial step size for the solver. -- max_steps: The maximum number of steps for the solver. -- lambd: The Lip(2) regularisation parameter. +This module implements the `LogNeuralCDE` class using JAX and Equinox. The model is a +Neural Controlled Differential Equation (NCDE), where the output is approximated during +training using the Log-ODE method. + +Attributes of the `LogNeuralCDE` model: +- `vf`: The vector field $f_{\theta}$ of the NCDE. +- `data_dim`: The number of channels in the input time series. +- `depth`: The depth of the Log-ODE method, currently implemented for depth 1 and 2. +- `hidden_dim`: The dimension of the hidden state $h_t$. +- `linear1`: The input linear layer used to initialise the hidden state $h_0$. +- `linear2`: The output linear layer used to obtain predictions from $h_t$. +- `pairs`: The pairs of basis elements for the terms in the depth-2 log-signature of the path. +- `classification`: Boolean indicating if the model is used for classification tasks. +- `output_step`: If the model is used for regression, the number of steps to skip before outputting a prediction. +- `intervals`: The intervals used in the Log-ODE method. +- `solver`: The solver applied to the ODE produced by the Log-ODE method. +- `stepsize_controller`: The stepsize controller for the solver. +- `dt0`: The initial step size for the solver. +- `max_steps`: The maximum number of steps allowed for the solver. +- `lambd`: The Lip(2) regularisation parameter, used to control the smoothness of the vector field. + +The class also includes methods for initialising the model and for performing the forward pass, where the dynamics are +solved using the specified ODE solver. """ import diffrax diff --git a/models/NeuralCDEs.py b/models/NeuralCDEs.py index ec62b00..1ab971f 100644 --- a/models/NeuralCDEs.py +++ b/models/NeuralCDEs.py @@ -1,32 +1,36 @@ """ -This script implemented the NeuralCDE and NeuralRDE classes using Jax and equinox. The NeuralCDE class has the -following attributes: -- vf: The vector field $f_{\theta}$ of the NCDE. -- data_dim: The number of channels in the time series. -- hidden_dim: The dimension of the hidden state $h_t$. -- linear1: The input linear layer for initialising $h_0$. -- linear2: The output linear layer for obtaining predictions from $h_t$. -- classification: Whether the model is used for classification. -- output_step: If the model is used for regression, how many steps to skip before outputting a prediction. -- solver: The solver applied to the NCDE. -- stepsize_controller: The stepsize controller for the solver. -- dt0: The initial step size for the solver. -- max_steps: The maximum number of steps for the solver. +This module implements the `NeuralCDE` and `NeuralRDE` classes using JAX and Equinox. -The NeuralRDE class has the following attributes: -- vf: The vector field $\bar{f}_{\theta}$ of the NRDE (except the final linear layer). -- data_dim: The number of channels in the time series. -- logsig_dim: The dimension of the log-signature of the path which will be used as input to the NRDE. -- hidden_dim: The dimension of the hidden state $h_t$. -- mlp_linear: The final linear layer of the vector field. -- linear1: The input linear layer for initialising $h_0$. -- linear2: The output linear layer for obtaining predictions from $h_t$. -- classification: Whether the model is used for classification. -- output_step: If the model is used for regression, how many steps to skip before outputting a prediction. -- solver: The solver applied to the NRDE. -- stepsize_controller: The stepsize controller for the solver. -- dt0: The initial step size for the solver. -- max_steps: The maximum number of steps for the solver. +Attributes of `NeuralCDE`: +- `vf`: The vector field $f_{\theta}$ of the NCDE. +- `data_dim`: Number of channels in the input time series. +- `hidden_dim`: Dimension of the hidden state $h_t$. +- `linear1`: Input linear layer for initializing $h_0$. +- `linear2`: Output linear layer for generating predictions from $h_t$. +- `classification`: Boolean indicating if the model is used for classification. +- `output_step`: For regression tasks, specifies the step interval for outputting predictions. +- `solver`: The solver used to integrate the NCDE. +- `stepsize_controller`: Controls the step size for the solver. +- `dt0`: Initial step size for the solver. +- `max_steps`: Maximum number of steps allowed for the solver. + +Attributes of `NeuralRDE`: +- `vf`: The vector field $\bar{f}_{\theta}$ of the NRDE (excluding the final linear layer). +- `data_dim`: Number of channels in the input time series. +- `logsig_dim`: Dimension of the log-signature used as input to the NRDE. +- `hidden_dim`: Dimension of the hidden state $h_t$. +- `mlp_linear`: Final linear layer of the vector field. +- `linear1`: Input linear layer for initializing $h_0$. +- `linear2`: Output linear layer for generating predictions from $h_t$. +- `classification`: Boolean indicating if the model is used for classification. +- `output_step`: For regression tasks, specifies the step interval for outputting predictions. +- `solver`: The solver used to integrate the NRDE. +- `stepsize_controller`: Controls the step size for the solver. +- `dt0`: Initial step size for the solver. +- `max_steps`: Maximum number of steps allowed for the solver. + +The module also includes the `VectorField` class, which defines the vector fields used by both +`NeuralCDE` and `NeuralRDE`. """ import diffrax diff --git a/models/RNN.py b/models/RNN.py index 3b72af2..338f7db 100644 --- a/models/RNN.py +++ b/models/RNN.py @@ -1,9 +1,30 @@ """ -This module implements the RNN class and the RNN cell classes. The RNN class has the following attributes: -- cell: The RNN cell used in the RNN. -- output_layer: The linear layer used to obtain the output of the RNN. -- hidden_dim: The dimension of the hidden state $h_t$. -- classification: Whether the model is used for classification. +This module implements the `RNN` class and various RNN cell classes using JAX and Equinox. The `RNN` +class is designed to handle both classification and regression tasks, and can be configured with different +types of RNN cells. + +Attributes of the `RNN` class: +- `cell`: The RNN cell used within the RNN, which can be one of several types (e.g., `LinearCell`, `GRUCell`, + `LSTMCell`, `MLPCell`). +- `output_layer`: The linear layer applied to the hidden state to produce the model's output. +- `hidden_dim`: The dimension of the hidden state $h_t$. +- `classification`: A boolean indicating whether the model is used for classification tasks. +- `output_step`: For regression tasks, specifies how many steps to skip before outputting a prediction. + +RNN Cell Classes: +- `_AbstractRNNCell`: An abstract base class for all RNN cells, defining the interface for custom RNN cells. +- `LinearCell`: A simple RNN cell that applies a linear transformation to the concatenated input and hidden state. +- `GRUCell`: An implementation of the Gated Recurrent Unit (GRU) cell. +- `LSTMCell`: An implementation of the Long Short-Term Memory (LSTM) cell. +- `MLPCell`: An RNN cell that applies a multi-layer perceptron (MLP) to the concatenated input and hidden state. + +Each RNN cell class implements the following methods: +- `__init__`: Initialises the RNN cell with the specified input dimensions and hidden state size. +- `__call__`: Applies the RNN cell to the input and hidden state, returning the updated hidden state. + +The `RNN` class also includes: +- A `__call__` method that processes a sequence of inputs, returning either the final output for classification or a +sequence of outputs for regression. """ import abc diff --git a/models/S5.py b/models/S5.py index 56aa9b0..b391b5c 100644 --- a/models/S5.py +++ b/models/S5.py @@ -1,12 +1,21 @@ """ S5 implementation modified from: https://github.com/lindermanlab/S5/blob/main/s5/ssm_init.py -This module implements the structured state space model S5. The S5 model has the following attributes: -- linear_encoder: The linear encoder applied to the input time series. -- blocks: A list of S5 blocks. -- linear_layer: The final linear layer of the S5 model which outputs the predictions. -- classification: Whether the model is used for classification. -- output_step: If the model is used for regression, how many steps to skip before outputting a prediction. +This module implements S5 using JAX and Equinox. + +Attributes of the S5 model: +- `linear_encoder`: The linear encoder applied to the input time series. +- `blocks`: A list of S5 blocks, each consisting of an S5 layer, normalisation, GLU, and dropout. +- `linear_layer`: The final linear layer that outputs the predictions of the model. +- `classification`: A boolean indicating whether the model is used for classification tasks. +- `output_step`: For regression tasks, specifies how many steps to skip before outputting a prediction. + +The module also includes: +- `S5Layer`: Implements the core S5 layer using structured state space models with options for + different discretisation methods and eigenvalue clipping. +- `S5Block`: Combines the S5 layer with batch normalisation, a GLU activation, and dropout. +- Utility functions for initialising and discretising the state space model components, + such as `make_HiPPO`, `make_NPLR_HiPPO`, and `make_DPLR_HiPPO`. """ from typing import List diff --git a/models/generate_model.py b/models/generate_model.py index 2b2a5a6..e41a3d3 100644 --- a/models/generate_model.py +++ b/models/generate_model.py @@ -1,5 +1,41 @@ """ -This module contains a function to generate a model based on the model name and the hyperparameters. +This module provides a function to generate a model based on a model name and hyperparameters. +It supports various types of models, including Neural CDEs, RNNs, and the S5 model. + +Function: +- `create_model`: Generates and returns a model instance along with its state (if applicable) + based on the provided model name and hyperparameters. + +Parameters for `create_model`: +- `model_name`: A string specifying the model architecture to create. Supported values include + 'log_ncde', 'ncde', 'nrde', 'lru', 'S5', 'rnn_linear', 'rnn_gru', 'rnn_lstm', and 'rnn_mlp'. +- `data_dim`: The input data dimension. +- `logsig_dim`: The dimension of the log-signature used in NRDE and Log-NCDE models. +- `logsig_depth`: The depth of the log-signature used in NRDE and Log-NCDE models. +- `intervals`: The intervals used in NRDE and Log-NCDE models. +- `label_dim`: The output label dimension. +- `hidden_dim`: The hidden state dimension for the model. +- `num_blocks`: The number of blocks (layers) in models like LRU or S5. +- `vf_depth`: The depth of the vector field network for CDE models. +- `vf_width`: The width of the vector field network for CDE models. +- `classification`: A boolean indicating whether the task is classification (True) or regression (False). +- `output_step`: The step interval for outputting predictions in sequence models. +- `ssm_dim`: The state-space model dimension for S5 models. +- `ssm_blocks`: The number of SSM blocks in S5 models. +- `solver`: The ODE solver used in CDE models, with a default of `diffrax.Heun()`. +- `stepsize_controller`: The step size controller used in CDE models, with a default of `diffrax.ConstantStepSize()`. +- `dt0`: The initial time step for the solver. +- `max_steps`: The maximum number of steps for the solver. +- `scale`: A scaling factor applied to the vf initialisation in CDE models. +- `lambd`: A regularisation parameter used in Log-NCDE models. +- `key`: A JAX PRNG key for random number generation. + +Returns: +- A tuple containing the created model and its state (if applicable). + +Raises: +- `ValueError`: If required hyperparameters for the specified model are not provided or if an + unknown model name is passed. """ import diffrax diff --git a/results/analyse_results.py b/results/analyse_results.py index 31fffba..27dfc6e 100644 --- a/results/analyse_results.py +++ b/results/analyse_results.py @@ -1,7 +1,18 @@ """ -This script is used to analyse the results of the UEA and PPG experiments. It is designed both to determine the -optimal hyperparameters for each model and dataset following a grid search optimisation, and to compare the performance -of the models across different random seeds on fixed hyperparameters. +This script analyses the results of the UEA and PPG experiments. It is designed to determine the +optimal hyperparameters for each model and dataset following a grid search optimisation, and to +compare the performance of models across different random seeds using fixed hyperparameters. + +The script performs the following tasks: +- Identifies the best hyperparameters based on validation accuracy for each model and dataset. +- Compares model performance across different random seeds, calculating the mean and standard + deviation of test accuracy. + +The script can handle two types of experiments: +1. `hypopt`: Hyperparameter optimisation, where the best configuration is selected based on + validation accuracy. +2. `repeats`: Repeated experiments with fixed hyperparameters, where the results across multiple + runs are aggregated. """ import os diff --git a/run_experiment.py b/run_experiment.py index a57c1f1..4e5dfbf 100644 --- a/run_experiment.py +++ b/run_experiment.py @@ -1,7 +1,22 @@ """ -This script loads a JSON file containing the hyperparameters for each model and dataset, and uses -create_dataset_model_and_train from train.py to train the models on the datasets using the hyperparameters. The results -are saved in the output directory specified in the JSON file. +This script loads hyperparameters from JSON files and trains models on specified datasets using +the `create_dataset_model_and_train` function from `train.py` or its PyTorch equivalent. The results +are saved in the output directories defined in the JSON files. + +The `run_experiments` function iterates over model names and dataset names, loading configuration +files from a specified folder, and then calls the appropriate training function based on the +framework (PyTorch or JAX). + +Arguments for `run_experiments`: +- `model_names`: List of model architectures to use. +- `dataset_names`: List of datasets to train on. +- `experiment_folder`: Directory containing JSON configuration files. +- `pytorch_experiments`: Boolean indicating whether to use PyTorch (True) or JAX (False). + +The script also provides a command-line interface (CLI) for specifying whether to run PyTorch experiments. + +Usage: +- Use the `--pytorch_experiments` flag to run experiments with PyTorch; otherwise, JAX is used by default. """ import argparse diff --git a/torch_experiments/jax_dataset.py b/torch_experiments/jax_dataset.py index bf965a2..4188e97 100644 --- a/torch_experiments/jax_dataset.py +++ b/torch_experiments/jax_dataset.py @@ -1,3 +1,17 @@ +""" +This module defines a custom PyTorch `Dataset` class for loading and processing time series data +from different benchmarks (UEA, toy, PPG) which have been preprocessed and saved as Jax numpy arrays. +The dataset can be pre-split into training, validation, and test sets, or dynamically split based on provided indexes. + +Classes: +- `Dataset`: A PyTorch dataset class that handles loading data and labels from pickle files of jax numpy arrays, + optional inclusion of time as a feature, and splitting of data into train/val/test sets. + +Methods: +- `__len__`: Returns the length of the dataset. +- `__getitem__`: Retrieves a data-label pair at the specified index. +""" + import os import pickle diff --git a/torch_experiments/s6_recurrence.py b/torch_experiments/s6_recurrence.py index 9b861a7..0506cbd 100644 --- a/torch_experiments/s6_recurrence.py +++ b/torch_experiments/s6_recurrence.py @@ -1,4 +1,16 @@ -# Copyright (c) 2023, Tri Dao, Albert Gu. +""" +This module defines the `S6Layer` class, a PyTorch implementation of a sequence model layer using +the S6 architecture. + +Classes: +- `S6Layer`: Implements the S6 layer with options for model dimensions, state dimensions, and + initialisation strategies. + +Functions: +- `forward`: Applies the S6 transformation to the input sequence, returning the processed sequence. + +Copyright (c) 2023, Tri Dao, Albert Gu. +""" import math diff --git a/torch_experiments/train.py b/torch_experiments/train.py index a03147a..88ea05c 100644 --- a/torch_experiments/train.py +++ b/torch_experiments/train.py @@ -1,3 +1,36 @@ +""" +This module defines classes and functions for creating and training Mamba and S6 using PyTorch. +The main function, `create_dataset_model_and_train`, is designed to initialise the dataset, construct the model, and +execute the training process. + +The function `create_dataset_model_and_train` takes the following arguments: + +- `seed`: An integer representing the random seed for reproducibility. +- `data_dir`: The directory where the dataset is stored. +- `output_parent_dir`: The parent directory where the training outputs will be saved. +- `model_name`: A string specifying the model architecture to use ('mamba' or 'S6'). +- `metric`: The evaluation metric to use during training, either 'accuracy' for classification or 'mse' for regression. +- `batch_size`: The number of samples per batch during training. +- `dataset_name`: The name of the dataset to load and use for training. +- `n_samples`: The total number of samples in the dataset. +- `output_step`: For regression tasks, defines the interval for outputting predictions. +- `use_presplit`: A boolean indicating whether to use a pre-split dataset. +- `include_time`: A boolean that determines whether to include time as a feature in the dataset. +- `num_steps`: The total number of steps for training the model. +- `print_steps`: The interval of steps after which to print training progress and metrics. +- `lr`: The learning rate for the optimiser. +- `model_args`: A dictionary containing additional arguments and hyperparameters for model customisation. + +Classes defined in this module: + +- `GLU`: Implements a Gated Linear Unit (GLU) layer, which applies a linear transformation followed by a gated + activation. +- `MambaBlock`: A block that consists of normalisation, a Mamba or S6 layer, a GLU layer, and dropout. It serves as a + basic building block for the Mamba model. +- `Mamba`: A sequence model that stacks multiple MambaBlock layers and includes an encoder and decoder for input/output + transformation. +""" + import os import shutil import time diff --git a/train.py b/train.py index 74a7bd5..7d7cd54 100644 --- a/train.py +++ b/train.py @@ -1,24 +1,38 @@ """ -This module defines functions which take hyperparameters and produce datasets and models, as well as train them. -The function create_dataset_model_and_train takes as argument: -- seed: a random seed -- data_dir: the directory where the data is stored -- use_presplit: a boolean indicating whether to use a pre-split dataset -- dataset_name: the name of the dataset -- output_step: if a regression dataset, how many steps to skip before outputting a prediction. -- metric: the metric to use for evaluation. Currently implemented `mse' or 'accuracy'. -- include_time: whether to include time as a channel in the time series. -- T: Scale time to [0, T]. -- model_name: the name of the model to use. -- stepsize: the size of the intervals for the Log-ODE method. -- logsig_depth: the depth of the Log-ODE method. Currently implemented only for depth=1 and 2. -- model_args: a dictionary of additional arguments for the model. -- num_steps: the number of steps to train the model. -- print_steps: how often to print the loss. -- lr: the learning rate. -- lr_scheduler: the learning rate scheduler. -- batch_size: the batch size. -- output_parent_dir: the parent directory where the output is stored. +This module defines functions for creating datasets, building models, and training them using JAX +and Equinox. The main function, `create_dataset_model_and_train`, is designed to initialise the +dataset, construct the model, and execute the training process. + +The function `create_dataset_model_and_train` takes the following arguments: + +- `seed`: A random seed for reproducibility. +- `data_dir`: The directory where the dataset is stored. +- `use_presplit`: A boolean indicating whether to use a pre-split dataset. +- `dataset_name`: The name of the dataset to load and use for training. +- `output_step`: For regression tasks, the number of steps to skip before outputting a prediction. +- `metric`: The metric to use for evaluation. Supported values are `'mse'` for regression and `'accuracy'` for + classification. +- `include_time`: A boolean indicating whether to include time as a channel in the time series data. +- `T`: The maximum time value to scale time data to [0, T]. +- `model_name`: The name of the model architecture to use. +- `stepsize`: The size of the intervals for the Log-ODE method. +- `logsig_depth`: The depth of the Log-ODE method. Currently implemented for depths 1 and 2. +- `model_args`: A dictionary of additional arguments to customise the model. +- `num_steps`: The number of steps to train the model. +- `print_steps`: How often to print the loss during training. +- `lr`: The learning rate for the optimiser. +- `lr_scheduler`: The learning rate scheduler function. +- `batch_size`: The number of samples per batch during training. +- `output_parent_dir`: The parent directory where the training outputs will be saved. + +The module also includes the following key functions: + +- `calc_output`: Computes the model output, handling stateful and nondeterministic models with JAX's `vmap` for + batching. +- `classification_loss`: Computes the loss for classification tasks, including optional regularisation. +- `regression_loss`: Computes the loss for regression tasks, including optional regularisation. +- `make_step`: Performs a single optimisation step, updating model parameters based on the computed gradients. +- `train_model`: Handles the training loop, managing metrics, early stopping, and saving progress at regular intervals. """ import os