-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #41 from flatironinstitute/glm_class_restructure
Glm class restructure
- Loading branch information
Showing
41 changed files
with
6,020 additions
and
579 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -143,3 +143,6 @@ docs/generated/ | |
|
||
# vscode | ||
.vscode/ | ||
|
||
# nwb cahce | ||
nwb-cache/ |
2 changes: 1 addition & 1 deletion
2
docs/developers_notes/basis_module.md → docs/developers_notes/01-basis_module.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
# The Basis Module | ||
# The `basis` Module | ||
|
||
## Introduction | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
# The `base_class` Module | ||
|
||
## Introduction | ||
|
||
The `base_class` module introduces the `Base` class and abstract classes defining broad model categories. These abstract classes **must** inherit from `Base`. | ||
|
||
The `Base` class is envisioned as the foundational component for any object type (e.g., regression, dimensionality reduction, clustering, observation models, regularizers etc.). In contrast, abstract classes derived from `Base` define overarching object categories (e.g., `base_class.BaseRegressor` is building block for GLMs, GAMS, etc. while `observation_models.Observations` is the building block for the Poisson observations, Gamma observations, ... etc.). | ||
|
||
Designed to be compatible with the `scikit-learn` API, the class structure aims to facilitate access to `scikit-learn`'s robust pipeline and cross-validation modules. This is achieved while leveraging the accelerated computational capabilities of `jax` and `jaxopt` in the backend, which is essential for analyzing extensive neural recordings and fitting large models. | ||
|
||
Below a scheme of how we envision the architecture of the `nemos` models. | ||
|
||
``` | ||
Abstract Class Base | ||
│ | ||
├─ Abstract Subclass BaseRegressor | ||
│ │ | ||
│ └─ Concrete Subclass GLM | ||
│ │ | ||
│ └─ Concrete Subclass RecurrentGLM | ||
│ | ||
├─ Abstract Subclass BaseManifold *(not implemented yet) | ||
│ │ | ||
│ ... | ||
│ | ||
├─ Abstract Subclass Regularizer | ||
│ │ | ||
│ ├─ Concrete Subclass UnRegularized | ||
│ │ | ||
│ ├─ Concrete Subclass Ridge | ||
│ ... | ||
│ | ||
├─ Abstract Subclass Observations | ||
│ │ | ||
│ ├─ Concrete Subclass PoissonObservations | ||
│ │ | ||
│ ├─ Concrete Subclass GammaObservations *(not implemented yet) | ||
│ ... | ||
│ | ||
... | ||
``` | ||
|
||
!!! Example | ||
The current package version includes a concrete class named `nemos.glm.GLM`. This class inherits from `BaseRegressor`, which in turn inherits `Base`, since it falls under the " GLM regression" category. | ||
As any `BaseRegressor`, it **must** implement the `fit`, `score`, `predict`, and `simulate` methods. | ||
|
||
|
||
## The Class `model_base.Base` | ||
|
||
The `Base` class aligns with the `scikit-learn` API for `base.BaseEstimator`. This alignment is achieved by implementing the `get_params` and `set_params` methods, essential for `scikit-learn` compatibility and foundational for all model implementations. Additionally, the class provides auxiliary helper methods to identify available computational devices (such as GPUs and TPUs) and to facilitate data transfer to these devices. | ||
|
||
For a detailed understanding, consult the [`scikit-learn` API Reference](https://scikit-learn.org/stable/modules/classes.html) and [`BaseEstimator`](https://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html). | ||
|
||
!!! Note | ||
We've intentionally omitted the `get_metadata_routing` method. Given its current experimental status and its lack of relevance to the `GLM` class, this method was excluded. Should future needs arise around parameter routing, consider directly inheriting from `sklearn.BaseEstimator`. More information can be found [here](https://scikit-learn.org/stable/metadata_routing.html#metadata-routing). | ||
|
||
### Public methods | ||
|
||
- **`get_params`**: The `get_params` method retrieves parameters set during model instance initialization. Opting for a deep inspection allows the method to assess nested object parameters, resulting in a comprehensive parameter dictionary. | ||
- **`set_params`**: The `set_params` method offers a mechanism to adjust or set an estimator's parameters. It's versatile, accommodating both individual estimators and more complex nested structures like pipelines. Feeding an unrecognized parameter will raise a `ValueError`. | ||
|
||
## The Abstract Class `model_base.BaseRegressor` | ||
|
||
`BaseRegressor` is an abstract class that inherits from `Base`, stipulating the implementation of abstract methods: `fit`, `predict`, `score`, and `simulate`. This ensures seamless assimilation with `scikit-learn` pipelines and cross-validation procedures. | ||
|
||
### Abstract Methods | ||
|
||
For subclasses derived from `BaseRegressor` to function correctly, they must implement the following: | ||
|
||
1. `fit`: Adapt the model using input data `X` and corresponding observations `y`. | ||
2. `predict`: Provide predictions based on the trained model and input data `X`. | ||
3. `score`: Score the accuracy of model predictions using input data `X` against the actual observations `y`. | ||
4. `simulate`: Simulate data based on the trained regression model. | ||
|
||
### Public Methods | ||
|
||
To ensure the consistency and conformity of input data, the `BaseRegressor` introduces two public preprocessing methods: | ||
|
||
1. `preprocess_fit`: Assesses and converts the input for the `fit` method into the desired `jax.ndarray` format. If necessary, this method can initialize model parameters using default values. | ||
2. `preprocess_simulate`: Validates and converts inputs for the `simulate` method. This method confirms the integrity of the feedforward input and, when provided, the initial values for feedback. | ||
|
||
### Auxiliary Methods | ||
|
||
Moreover, `BaseRegressor` incorporates auxiliary methods such as `_convert_to_jnp_ndarray`, `_has_invalid_entry` | ||
and a number of other methods for checking input consistency. | ||
|
||
!!! Tip | ||
Deciding between concrete and abstract methods in a superclass can be nuanced. As a general guideline: any method that's expected in all subclasses and isn't subclass-specific should be concretely implemented in the superclass. Conversely, methods essential for a subclass's expected behavior, but vary based on the subclass, should be abstract in the superclass. For instance, compatibility with the `sklearn.cross_validation` module demands `score`, `fit`, `get_params`, and `set_params` methods. Given their specificity to individual models, `score` and `fit` are abstract in `BaseRegressor`. Conversely, as `get_params` and `set_params` are consistent across model classes, they're inherited from `Base`. This approach typifies our general implementation strategy. However, it's important to note that while these are sound guidelines, exceptions exist based on various factors like future extensibility, clarity, and maintainability. | ||
|
||
|
||
## Contributor Guidelines | ||
|
||
### Implementing Model Subclasses | ||
|
||
When devising a new model subclass based on the `BaseRegressor` abstract class, adhere to the subsequent guidelines: | ||
|
||
- **Must** inherit the `BaseRegressor` abstract superclass. | ||
- **Must** realize the abstract methods: `fit`, `predict`, `score`, and `simulate`. | ||
- **Should not** overwrite the `get_params` and `set_params` methods, inherited from `Base`. | ||
- **May** introduce auxiliary methods such as `_convert_to_jnp_ndarray` for added utility. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# The `observation_models` Module | ||
|
||
## Introduction | ||
|
||
The `observation_models` module provides objects representing the observations of GLM-like models. | ||
|
||
The abstract class `Observations` defines the structure of the subclasses which specify observation types, such as Poisson, Gamma, etc. These objects serve as attributes of the [`nemos.glm.GLM`](../05-glm/#the-concrete-class-glm) class, equipping the GLM with a negative log-likelihood. This is used to define the optimization objective, the deviance which measures model fit quality, and the emission of new observations, for simulating new data. | ||
|
||
## The Abstract class `Observations` | ||
|
||
The abstract class `Observations` is the backbone of any observation model. Any class inheriting `Observations` must reimplement the `negative_log_likelihood`, `sample_generator`, `residual_deviance`, and `estimate_scale` methods. | ||
|
||
### Abstract Methods | ||
|
||
For subclasses derived from `Observations` to function correctly, they must implement the following: | ||
|
||
- **negative_log_likelihood**: Computes the negative-log likelihood of the model up to a normalization constant. This method is usually part of the objective function used to learn GLM parameters. | ||
|
||
- **sample_generator**: Returns the random emission probability function. This typically invokes `jax.random` emission probability, provided some sufficient statistics[^1]. For distributions in the exponential family, the sufficient statistics are the canonical parameter and the scale. In GLMs, the canonical parameter is entirely specified by the model's weights, while the scale is either fixed (i.e., Poisson) or needs to be estimated (i.e., Gamma). | ||
|
||
- **residual_deviance**: Computes the residual deviance based on the model's estimated rates and observations. | ||
|
||
- **estimate_scale**: A method for estimating the scale parameter of the model. | ||
|
||
### Public Methods | ||
|
||
- **pseudo_r2**: Method for computing the pseudo-$R^2$ of the model based on the residual deviance. There is no consensus definition for the pseudo-$R^2$, what we used here is the definition by Cohen at al. 2003[^2]. | ||
|
||
|
||
### Auxiliary Methods | ||
|
||
- **_check_inverse_link_function**: Check that the provided link function is a `Callable` of the `jax` namespace. | ||
|
||
## Concrete `PoissonObservations` class | ||
|
||
The `PoissonObservations` class extends the abstract `Observations` class to provide functionalities specific to the Poisson observation model. It is designed for modeling observed spike counts based on a Poisson distribution with a given rate. | ||
|
||
### Overridden Methods | ||
|
||
- **negative_log_likelihood**: This method computes the Poisson negative log-likelihood of the predicted rates for the observed spike counts. | ||
|
||
- **sample_generator**: Generates random numbers from a Poisson distribution based on the given `predicted_rate`. | ||
|
||
- **residual_deviance**: Calculates the residual deviance for a Poisson model. | ||
|
||
- **estimate_scale**: Assigns a fixed value of 1 to the scale parameter of the Poisson model since Poisson distribution has a fixed scale. | ||
|
||
## Contributor Guidelines | ||
|
||
To implement an observation model class you | ||
|
||
- **Must** inherit from `Observations` | ||
|
||
- **Must** provide a concrete implementation of `negative_log_likelihood`, `sample_generator`, `residual_deviance`, and `estimate_scale`. | ||
|
||
- **Should not** reimplement the `pseudo_r2` method as well as the `_check_inverse_link_function` auxiliary method. | ||
|
||
[^1]: | ||
In statistics, a statistic is sufficient with respect to a statistical model and its associated unknown parameters if "no other statistic that can be calculated from the same sample provides any additional information as to the value of the parameters", adapted from Fisher R. A. | ||
1922. On the mathematical foundations of theoretical statistics. *Philosophical Transactions of the Royal Society of London. Series A, Containing Papers of a Mathematical or Physical Character* 222:309–368. http://doi.org/10.1098/rsta.1922.0009. | ||
[^2]: | ||
Jacob Cohen, Patricia Cohen, Steven G. West, Leona S. Aiken. | ||
*Applied Multiple Regression/Correlation Analysis for the Behavioral Sciences*. | ||
3rd edition. Routledge, 2002. p.502. ISBN 978-0-8058-2223-6. (May 2012) |
Oops, something went wrong.