Improved doc strings #58

Merged
merged 1 commit into from
Jul 31, 2024
44 changes: 26 additions & 18 deletions data_dir/dataloaders.py
@@ -1,22 +1,30 @@
"""
A dataloader class for loading data in batches. Each model in this repository is designed to take a single argument as
input. Hence, the dataloader can handle three different cases.
- The first is simply the value of the time series data. This is used by stacked recurrent models, such as recurrent
neural networks and structured state space models. In this case data should be a jnp.ndarray of shape
(n_samples, n_timesteps, n_features)
- The second case is NCDEs, which requires the sampling time, the coefficients of an interpolation, and the initial
value of the data. In this case, data should be a tuple of length 3, where the first element is a jnp.ndarray of shape
(n_samples, n_timesteps) for the sampling times, the third element is a jnp.ndarray of shape (n_samples, n_features)
for the initial value, and the second element is a tuple of length n_coeffs, where each element is a jnp.ndarray of
shape (n_samples, n_timesteps-1, n_features) for the coefficients of the interpolation.
- The third case are NRDEs and Log-NCDEs, which require the sampling time, the log-signature of the data, and the
initial value of the data. In this case, data should be a tuple of length 3, where the first element is a jnp.ndarray
of shape (n_samples, n_timesteps) for the sampling times, the third element is a jnp.ndarray of shape
(n_samples, n_features) for the initial value, and the second element is a jnp.ndarray of shape
(n_samples, n_intervals, n_logsig_features) for the log-signature of the data over n_intervals.

Additionally, the data can be stored as a numpy array, and each batch converted to a jax numpy array, to save GPU
memory.
This module implements a `Dataloader` class for loading and batching data. It supports three different types of
data inputs, tailored for different types of models used in this repository.

1. **Time Series Data**: Used by models like recurrent neural networks and structured state space models.
- Input data should be a `jnp.ndarray` of shape `(n_samples, n_timesteps, n_features)`.

2. **Neural Controlled Differential Equations (NCDEs)**: Requires sampling times, coefficients of an interpolation,
and the initial value of the data.
- Input data should be a tuple of length 3:
- The first element: `jnp.ndarray` of shape `(n_samples, n_timesteps)` for sampling times.
- The second element: a tuple of length `n_coeffs`, where each element is a `jnp.ndarray` of shape
`(n_samples, n_timesteps-1, n_features)` for interpolation coefficients.
- The third element: `jnp.ndarray` of shape `(n_samples, n_features)` for the initial value.

3. **Neural Rough Differential Equations (NRDEs) and Log-NCDEs**: Requires sampling times, log-signature of the data,
and the initial value of the data.
- Input data should be a tuple of length 3:
- The first element: `jnp.ndarray` of shape `(n_samples, n_timesteps)` for sampling times.
- The second element: `jnp.ndarray` of shape `(n_samples, n_intervals, n_logsig_features)` for log-signature data.
- The third element: `jnp.ndarray` of shape `(n_samples, n_features)` for the initial value.

Additionally, data can be stored as a NumPy array to save GPU memory, with each batch converted to a JAX NumPy array.

Methods:
- `loop(batch_size, *, key)`: Generates data batches indefinitely. Randomly shuffles data for each batch.
- `loop_epoch(batch_size)`: Generates data batches for one epoch (i.e., a full pass through the dataset).
"""

import jax.numpy as jnp
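
The three input cases described in the docstring above can be sketched as a single epoch loop. This is a minimal illustration, not the repository's `Dataloader`: the names `slice_samples`, `count_samples`, and `loop_epoch_sketch` are hypothetical, plain nested lists stand in for `jnp.ndarray` so the sketch has no dependencies, and keeping the final partial batch is an assumption.

```python
from typing import Iterator


def slice_samples(data, start: int, stop: int):
    """Slice every array in `data` along the leading (sample) axis."""
    if isinstance(data, tuple):
        # Cases 2 and 3: (sampling times, coefficients or log-signatures, initial value).
        ts, middle, x0 = data
        if isinstance(middle, tuple):
            # NCDE case: a tuple of n_coeffs coefficient arrays.
            middle = tuple(c[start:stop] for c in middle)
        else:
            # NRDE / Log-NCDE case: a single log-signature array.
            middle = middle[start:stop]
        return ts[start:stop], middle, x0[start:stop]
    # Case 1: raw time series values.
    return data[start:stop]


def count_samples(data) -> int:
    """Number of samples, read from the first array in `data`."""
    return len(data[0]) if isinstance(data, tuple) else len(data)


def loop_epoch_sketch(data, batch_size: int) -> Iterator:
    """Yield consecutive batches covering the dataset exactly once."""
    for start in range(0, count_samples(data), batch_size):
        # Assumption: the last, possibly smaller, batch is kept.
        yield slice_samples(data, start, start + batch_size)
```

The same slicing works unchanged on NumPy or JAX arrays, since only first-axis slicing and `len` are used.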
20 changes: 13 additions & 7 deletions data_dir/datasets.py
@@ -1,11 +1,17 @@
"""
This module contains the Dataset class and functions to generate datasets. Since each model requires different versions
of the data as input, a Dataset object contains three dataloaders:
- raw_dataloaders: dataloaders which return the value of the data at each time point, as is used by recurrent neural
networks and structured state space models.
- coeff_dataloaders: dataloaders which return the coefficients of an interpolation of the data, as is used by NCDEs.
- path_dataloaders: dataloaders which return the log-signature of the data over intervals, as is used by NRDEs and
Log-NCDEs.
This module defines the `Dataset` class and functions for generating datasets. Because each model type requires a
different version of the data as input, a `Dataset` object contains three sets of dataloaders:

- `raw_dataloaders`: Returns the raw time series data, suitable for recurrent neural networks (RNNs) and structured
state space models (SSMs).
- `coeff_dataloaders`: Provides the coefficients of an interpolation of the data, used by Neural Controlled Differential
Equations (NCDEs).
- `path_dataloaders`: Provides the log-signature of the data over intervals, used by Neural Rough Differential Equations
(NRDEs) and Log-NCDEs.

The module also includes utility functions for processing and generating these datasets, ensuring compatibility with
different model requirements.
"""

import os
22 changes: 16 additions & 6 deletions models/LRU.py
@@ -1,12 +1,22 @@
"""
Code modified from https://gist.github.com/Ryu1845/7e78da4baa8925b4de482969befa949d

This module implements the LRU class. The LRU class has the following attributes:
- linear_encoder: The linear encoder applied to the input time series.
- blocks: A list of LRU blocks.
- linear_layer: The final linear layer of the S5 model which outputs the predictions.
- classification: Whether the model is used for classification.
- output_step: If the model is used for regression, how many steps to skip before outputting a prediction.
This module implements the `LRU` (Linear Recurrent Unit) model using JAX and Equinox.

Attributes of the `LRU` class:
- `linear_encoder`: The linear encoder applied to the input time series data.
- `blocks`: A list of `LRUBlock` instances, each containing the LRU layer, normalization, GLU, and dropout.
- `linear_layer`: The final linear layer that outputs the model predictions.
- `classification`: A boolean indicating whether the model is used for classification tasks.
- `output_step`: For regression tasks, specifies how many steps to skip before outputting a prediction.

The module also includes the following classes and functions:
- `GLU`: Implements a Gated Linear Unit for non-linear transformations within the model.
- `LRULayer`: A single LRU layer that applies complex-valued transformations and projections to the input.
- `LRUBlock`: A block consisting of normalization, LRU layer, GLU, and dropout, used as a building block for the `LRU`
model.
- `binary_operator_diag`: A helper function used in the associative scan operation within `LRULayer` to process diagonal
elements.
"""

from typing import List
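
The role of a binary operator in an associative scan over a diagonal linear recurrence can be sketched as follows. This is a scalar illustration under stated assumptions, not the repository's `binary_operator_diag`: the function names are hypothetical, each element is a pair `(a, b)` representing the map `x -> a * x + b`, and the initial state is taken to be zero. In JAX the same operator would typically be applied to arrays of diagonal entries via `jax.lax.associative_scan`.

```python
from typing import List, Tuple

Element = Tuple[complex, complex]


def binary_operator_diag_sketch(elem_i: Element, elem_j: Element) -> Element:
    """Compose x -> a_i * x + b_i followed by x -> a_j * x + b_j."""
    a_i, b_i = elem_i
    a_j, b_j = elem_j
    return a_j * a_i, a_j * b_i + b_j


def sequential_recurrence(lam: complex, inputs: List[complex]) -> List[complex]:
    """Reference implementation: x_k = lam * x_{k-1} + u_k with x_0 = 0."""
    x = 0j
    states = []
    for u in inputs:
        x = lam * x + u
        states.append(x)
    return states


def scan_recurrence(lam: complex, inputs: List[complex]) -> List[complex]:
    """Same recurrence expressed as a running composition of (lam, u_k) pairs."""
    states = []
    acc = None
    for elem in [(lam, u) for u in inputs]:
        acc = elem if acc is None else binary_operator_diag_sketch(acc, elem)
        # The second component of the running composition is the state x_k.
        states.append(acc[1])
    return states
```

Because the operator is associative, the compositions can be grouped in any order, which is what allows a parallel scan to replace the sequential loop.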
40 changes: 23 additions & 17 deletions models/LogNeuralCDEs.py
@@ -1,21 +1,27 @@
"""
This scripts implements the LogNeuralCDE class using Jax and equinox. The model is a NCDE and the output is
approximated during training using the Log-ODE method. The model's attributes are:
- vf: The vector field $f_{\theta}$ of the NCDE.
- data_dim: The number of channels in the time series.
- depth: The depth of the Log-ODE method. Currently implemented only for depth=1 and 2.
- hidden_dim: The dimension of the hidden state $h_t$.
- linear1: The input linear layer for initialising $h_0$.
- linear2: The output linear layer for obtaining predictions from $h_t$.
- pairs: The pairs of basis elements for the terms in the depth-2 log-signature of the path.
- classification: Whether the model is used for classification.
- output_step: If the model is used for regression, how many steps to skip before outputting a prediction.
- intervals: The intervals for the Log-ODE method.
- solver: The solver applied to the ODE produce by the Log-ODE method.
- stepsize_controller: The stepsize controller for the solver.
- dt0: The initial step size for the solver.
- max_steps: The maximum number of steps for the solver.
- lambd: The Lip(2) regularisation parameter.
This module implements the `LogNeuralCDE` class using JAX and Equinox. The model is a
Neural Controlled Differential Equation (NCDE), where the output is approximated during
training using the Log-ODE method.

Attributes of the `LogNeuralCDE` model:
- `vf`: The vector field $f_{\theta}$ of the NCDE.
- `data_dim`: The number of channels in the input time series.
- `depth`: The depth of the Log-ODE method, currently implemented for depth 1 and 2.
- `hidden_dim`: The dimension of the hidden state $h_t$.
- `linear1`: The input linear layer used to initialise the hidden state $h_0$.
- `linear2`: The output linear layer used to obtain predictions from $h_t$.
- `pairs`: The pairs of basis elements for the terms in the depth-2 log-signature of the path.
- `classification`: Boolean indicating if the model is used for classification tasks.
- `output_step`: If the model is used for regression, the number of steps to skip before outputting a prediction.
- `intervals`: The intervals used in the Log-ODE method.
- `solver`: The solver applied to the ODE produced by the Log-ODE method.
- `stepsize_controller`: The stepsize controller for the solver.
- `dt0`: The initial step size for the solver.
- `max_steps`: The maximum number of steps allowed for the solver.
- `lambd`: The Lip(2) regularisation parameter, used to control the smoothness of the vector field.

The class also includes methods for initialising the model and for performing the forward pass, where the dynamics are
solved using the specified ODE solver.
"""

import diffrax
58 changes: 31 additions & 27 deletions models/NeuralCDEs.py
@@ -1,32 +1,36 @@
"""
This script implemented the NeuralCDE and NeuralRDE classes using Jax and equinox. The NeuralCDE class has the
following attributes:
- vf: The vector field $f_{\theta}$ of the NCDE.
- data_dim: The number of channels in the time series.
- hidden_dim: The dimension of the hidden state $h_t$.
- linear1: The input linear layer for initialising $h_0$.
- linear2: The output linear layer for obtaining predictions from $h_t$.
- classification: Whether the model is used for classification.
- output_step: If the model is used for regression, how many steps to skip before outputting a prediction.
- solver: The solver applied to the NCDE.
- stepsize_controller: The stepsize controller for the solver.
- dt0: The initial step size for the solver.
- max_steps: The maximum number of steps for the solver.
This module implements the `NeuralCDE` and `NeuralRDE` classes using JAX and Equinox.

The NeuralRDE class has the following attributes:
- vf: The vector field $\bar{f}_{\theta}$ of the NRDE (except the final linear layer).
- data_dim: The number of channels in the time series.
- logsig_dim: The dimension of the log-signature of the path which will be used as input to the NRDE.
- hidden_dim: The dimension of the hidden state $h_t$.
- mlp_linear: The final linear layer of the vector field.
- linear1: The input linear layer for initialising $h_0$.
- linear2: The output linear layer for obtaining predictions from $h_t$.
- classification: Whether the model is used for classification.
- output_step: If the model is used for regression, how many steps to skip before outputting a prediction.
- solver: The solver applied to the NRDE.
- stepsize_controller: The stepsize controller for the solver.
- dt0: The initial step size for the solver.
- max_steps: The maximum number of steps for the solver.
Attributes of `NeuralCDE`:
- `vf`: The vector field $f_{\theta}$ of the NCDE.
- `data_dim`: Number of channels in the input time series.
- `hidden_dim`: Dimension of the hidden state $h_t$.
- `linear1`: Input linear layer for initializing $h_0$.
- `linear2`: Output linear layer for generating predictions from $h_t$.
- `classification`: Boolean indicating if the model is used for classification.
- `output_step`: For regression tasks, specifies the step interval for outputting predictions.
- `solver`: The solver used to integrate the NCDE.
- `stepsize_controller`: Controls the step size for the solver.
- `dt0`: Initial step size for the solver.
- `max_steps`: Maximum number of steps allowed for the solver.

Attributes of `NeuralRDE`:
- `vf`: The vector field $\bar{f}_{\theta}$ of the NRDE (excluding the final linear layer).
- `data_dim`: Number of channels in the input time series.
- `logsig_dim`: Dimension of the log-signature used as input to the NRDE.
- `hidden_dim`: Dimension of the hidden state $h_t$.
- `mlp_linear`: Final linear layer of the vector field.
- `linear1`: Input linear layer for initializing $h_0$.
- `linear2`: Output linear layer for generating predictions from $h_t$.
- `classification`: Boolean indicating if the model is used for classification.
- `output_step`: For regression tasks, specifies the step interval for outputting predictions.
- `solver`: The solver used to integrate the NRDE.
- `stepsize_controller`: Controls the step size for the solver.
- `dt0`: Initial step size for the solver.
- `max_steps`: Maximum number of steps allowed for the solver.

The module also includes the `VectorField` class, which defines the vector fields used by both
`NeuralCDE` and `NeuralRDE`.
"""

import diffrax
31 changes: 26 additions & 5 deletions models/RNN.py
@@ -1,9 +1,30 @@
"""
This module implements the RNN class and the RNN cell classes. The RNN class has the following attributes:
- cell: The RNN cell used in the RNN.
- output_layer: The linear layer used to obtain the output of the RNN.
- hidden_dim: The dimension of the hidden state $h_t$.
- classification: Whether the model is used for classification.
This module implements the `RNN` class and various RNN cell classes using JAX and Equinox. The `RNN`
class is designed to handle both classification and regression tasks, and can be configured with different
types of RNN cells.

Attributes of the `RNN` class:
- `cell`: The RNN cell used within the RNN, which can be one of several types (e.g., `LinearCell`, `GRUCell`,
`LSTMCell`, `MLPCell`).
- `output_layer`: The linear layer applied to the hidden state to produce the model's output.
- `hidden_dim`: The dimension of the hidden state $h_t$.
- `classification`: A boolean indicating whether the model is used for classification tasks.
- `output_step`: For regression tasks, specifies how many steps to skip before outputting a prediction.

RNN Cell Classes:
- `_AbstractRNNCell`: An abstract base class for all RNN cells, defining the interface for custom RNN cells.
- `LinearCell`: A simple RNN cell that applies a linear transformation to the concatenated input and hidden state.
- `GRUCell`: An implementation of the Gated Recurrent Unit (GRU) cell.
- `LSTMCell`: An implementation of the Long Short-Term Memory (LSTM) cell.
- `MLPCell`: An RNN cell that applies a multi-layer perceptron (MLP) to the concatenated input and hidden state.

Each RNN cell class implements the following methods:
- `__init__`: Initialises the RNN cell with the specified input dimensions and hidden state size.
- `__call__`: Applies the RNN cell to the input and hidden state, returning the updated hidden state.

The `RNN` class also includes:
- A `__call__` method that processes a sequence of inputs, returning either the final output for classification or a
sequence of outputs for regression.
"""

import abc
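
The cell interface and the classification/regression split described in the docstring above can be sketched in plain Python. Everything here is hypothetical: the class and function names, the `(state, x)` argument order, the zero initial state, and the reading of `output_step` as "emit every `output_step`-th hidden state" are assumptions. The output linear layer is omitted, so the sketch returns hidden states rather than predictions.

```python
import abc
from typing import List

Vector = List[float]


class AbstractCellSketch(abc.ABC):
    """Interface: a cell maps (hidden state, input) to the next hidden state."""

    hidden_size: int

    @abc.abstractmethod
    def __call__(self, state: Vector, x: Vector) -> Vector:
        ...


class LinearCellSketch(AbstractCellSketch):
    """Applies a linear map to the concatenated hidden state and input."""

    def __init__(self, weights: List[Vector], bias: Vector):
        self.weights = weights  # shape (hidden_size, hidden_size + input_size)
        self.bias = bias
        self.hidden_size = len(bias)

    def __call__(self, state: Vector, x: Vector) -> Vector:
        concat = state + x  # list concatenation: [h; x]
        return [
            sum(w * c for w, c in zip(row, concat)) + b
            for row, b in zip(self.weights, self.bias)
        ]


def run_rnn_sketch(cell, xs: List[Vector], classification: bool, output_step: int = 1):
    """Run the cell over a sequence and select which hidden states to return."""
    state = [0.0] * cell.hidden_size
    states = []
    for x in xs:
        state = cell(state, x)
        states.append(state)
    if classification:
        # Classification: only the final hidden state is used.
        return states[-1]
    # Regression (assumed reading): keep every output_step-th hidden state.
    return states[output_step - 1 :: output_step]
```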
21 changes: 15 additions & 6 deletions models/S5.py
@@ -1,12 +1,21 @@
"""
S5 implementation modified from: https://github.com/lindermanlab/S5/blob/main/s5/ssm_init.py

This module implements the structured state space model S5. The S5 model has the following attributes:
- linear_encoder: The linear encoder applied to the input time series.
- blocks: A list of S5 blocks.
- linear_layer: The final linear layer of the S5 model which outputs the predictions.
- classification: Whether the model is used for classification.
- output_step: If the model is used for regression, how many steps to skip before outputting a prediction.
This module implements S5 using JAX and Equinox.

Attributes of the S5 model:
- `linear_encoder`: The linear encoder applied to the input time series.
- `blocks`: A list of S5 blocks, each consisting of an S5 layer, normalisation, GLU, and dropout.
- `linear_layer`: The final linear layer that outputs the predictions of the model.
- `classification`: A boolean indicating whether the model is used for classification tasks.
- `output_step`: For regression tasks, specifies how many steps to skip before outputting a prediction.

The module also includes:
- `S5Layer`: Implements the core S5 layer using structured state space models with options for
different discretisation methods and eigenvalue clipping.
- `S5Block`: Combines the S5 layer with batch normalisation, a GLU activation, and dropout.
- Utility functions for initialising and discretising the state space model components,
such as `make_HiPPO`, `make_NPLR_HiPPO`, and `make_DPLR_HiPPO`.
"""

from typing import List
38 changes: 37 additions & 1 deletion models/generate_model.py
@@ -1,5 +1,41 @@
"""
This module contains a function to generate a model based on the model name and the hyperparameters.
This module provides a function to generate a model based on a model name and hyperparameters.
It supports Log-NCDEs, NCDEs, NRDEs, LRU, S5, and several RNN variants.

Function:
- `create_model`: Generates and returns a model instance along with its state (if applicable)
based on the provided model name and hyperparameters.

Parameters for `create_model`:
- `model_name`: A string specifying the model architecture to create. Supported values include
'log_ncde', 'ncde', 'nrde', 'lru', 'S5', 'rnn_linear', 'rnn_gru', 'rnn_lstm', and 'rnn_mlp'.
- `data_dim`: The input data dimension.
- `logsig_dim`: The dimension of the log-signature used in NRDE and Log-NCDE models.
- `logsig_depth`: The depth of the log-signature used in NRDE and Log-NCDE models.
- `intervals`: The intervals used in NRDE and Log-NCDE models.
- `label_dim`: The output label dimension.
- `hidden_dim`: The hidden state dimension for the model.
- `num_blocks`: The number of blocks (layers) in models like LRU or S5.
- `vf_depth`: The depth of the vector field network for CDE models.
- `vf_width`: The width of the vector field network for CDE models.
- `classification`: A boolean indicating whether the task is classification (True) or regression (False).
- `output_step`: The step interval for outputting predictions in sequence models.
- `ssm_dim`: The state-space model dimension for S5 models.
- `ssm_blocks`: The number of SSM blocks in S5 models.
- `solver`: The ODE solver used in CDE models, with a default of `diffrax.Heun()`.
- `stepsize_controller`: The step size controller used in CDE models, with a default of `diffrax.ConstantStepSize()`.
- `dt0`: The initial time step for the solver.
- `max_steps`: The maximum number of steps for the solver.
- `scale`: A scaling factor applied to the vf initialisation in CDE models.
- `lambd`: A regularisation parameter used in Log-NCDE models.
- `key`: A JAX PRNG key for random number generation.

Returns:
- A tuple containing the created model and its state (if applicable).

Raises:
- `ValueError`: If required hyperparameters for the specified model are not provided or if an
unknown model name is passed.
"""

import diffrax
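
The validation behaviour listed under `Raises` can be sketched as a small check that runs before any model is built. This is not the repository's `create_model`: the function name `check_model_request` is hypothetical, and the `REQUIRED` mapping is an assumption inferred from the parameter descriptions in the docstring above (for example, that `ssm_dim` and `ssm_blocks` are needed only for `'S5'`). Only the model names are taken from the docstring.

```python
SUPPORTED_MODELS = (
    "log_ncde", "ncde", "nrde", "lru", "S5",
    "rnn_linear", "rnn_gru", "rnn_lstm", "rnn_mlp",
)

# Assumed mapping from model name to hyperparameters that must not be None.
REQUIRED = {
    "log_ncde": ("logsig_depth", "intervals", "lambd"),
    "nrde": ("logsig_dim", "intervals"),
    "lru": ("num_blocks",),
    "S5": ("num_blocks", "ssm_dim", "ssm_blocks"),
}


def check_model_request(model_name: str, **hyperparameters) -> str:
    """Raise ValueError for unknown names or missing hyperparameters."""
    if model_name not in SUPPORTED_MODELS:
        raise ValueError(f"Unknown model name: {model_name}")
    missing = [
        name
        for name in REQUIRED.get(model_name, ())
        if hyperparameters.get(name) is None
    ]
    if missing:
        raise ValueError(f"{model_name} requires: {', '.join(missing)}")
    return model_name
```

Checking everything up front keeps the error message specific to the missing hyperparameter instead of surfacing later as a failure inside a model constructor.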
17 changes: 14 additions & 3 deletions results/analyse_results.py
@@ -1,7 +1,18 @@
"""
This script is used to analyse the results of the UEA and PPG experiments. It is designed both to determine the
optimal hyperparameters for each model and dataset following a grid search optimisation, and to compare the performance
of the models across different random seeds on fixed hyperparameters.
This script analyses the results of the UEA and PPG experiments. It is designed to determine the
optimal hyperparameters for each model and dataset following a grid search optimisation, and to
compare the performance of models across different random seeds using fixed hyperparameters.

The script performs the following tasks:
- Identifies the best hyperparameters based on validation accuracy for each model and dataset.
- Compares model performance across different random seeds, calculating the mean and standard
deviation of test accuracy.

The script can handle two types of experiments:
1. `hypopt`: Hyperparameter optimisation, where the best configuration is selected based on
validation accuracy.
2. `repeats`: Repeated experiments with fixed hyperparameters, where the results across multiple
runs are aggregated.
"""

import os
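
The two experiment types described in the docstring above reduce to two small computations, sketched here with the standard library. The function names and the dictionary keys `"val_accuracy"` and `"hyperparameters"` are hypothetical, and the use of the population standard deviation (`statistics.pstdev`) is an assumption; the script may use the sample standard deviation instead.

```python
from statistics import mean, pstdev
from typing import Dict, List, Tuple


def best_hyperparameters(runs: List[Dict]) -> Dict:
    """`hypopt`: pick the configuration with the highest validation accuracy."""
    return max(runs, key=lambda run: run["val_accuracy"])["hyperparameters"]


def summarise_repeats(test_accuracies: List[float]) -> Tuple[float, float]:
    """`repeats`: mean and standard deviation of test accuracy across seeds."""
    return mean(test_accuracies), pstdev(test_accuracies)
```

Selecting on validation accuracy and reporting on test accuracy keeps the hyperparameter search from leaking into the reported numbers.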