Commit a7fdc8d

thomasZen, jonathanoesterle, and dgonschorek committed
Initial commit

Co-authored-by: jonathanoesterle <jonathan.oesterle.work@gmail.com>
Co-authored-by: dgonschorek <dominic.gonschorek@web.de>
Co-authored-by: thomasZen <thomas.zenkel@gmail.com>
0 parents  commit a7fdc8d

16 files changed: +5619 -0 lines changed

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
name: Publish gcl-classifier
# Adapted from https://packaging.python.org/en/latest/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/

on: push

jobs:
  build:
    name: Build distribution
    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v4
      with:
        persist-credentials: false
    - name: Set up Python
      uses: actions/setup-python@v5
      with:
        python-version: "3.x"
    - name: Install pypa/build
      run: >-
        python3 -m
        pip install
        build
        --user
    - name: Build a binary wheel and a source tarball
      run: python3 -m build
    - name: Store the distribution packages
      uses: actions/upload-artifact@v4
      with:
        name: python-package-distributions
        path: dist/

  publish-to-testpypi:
    name: TestPyPI
    needs:
    - build
    runs-on: ubuntu-latest

    environment:
      name: testpypi
      url: https://test.pypi.org/p/gcl-classifier

    permissions:
      id-token: write  # IMPORTANT: mandatory for trusted publishing

    steps:
    - name: Download all the dists
      uses: actions/download-artifact@v4
      with:
        name: python-package-distributions
        path: dist/
    - name: Publish distribution to TestPyPI
      uses: pypa/gh-action-pypi-publish@release/v1
      with:
        repository-url: https://test.pypi.org/legacy/
        skip-existing: true  # skip if version already exists

  publish-to-pypi:
    name: PyPI
    if: startsWith(github.ref, 'refs/tags/release')  # only publish to PyPI on tag pushes that start with release
    needs:
    - build
    runs-on: ubuntu-latest
    environment:
      name: pypi
      url: https://pypi.org/p/gcl-classifier
    permissions:
      id-token: write  # IMPORTANT: mandatory for trusted publishing

    steps:
    - name: Download all the dists
      uses: actions/download-artifact@v4
      with:
        name: python-package-distributions
        path: dist/
    - name: Publish distribution to PyPI
      uses: pypa/gh-action-pypi-publish@release/v1

  github-release:
    name: Github Release
    needs:
    - publish-to-pypi
    runs-on: ubuntu-latest

    permissions:
      contents: write
      id-token: write

    steps:
    - name: Download all the dists
      uses: actions/download-artifact@v4
      with:
        name: python-package-distributions
        path: dist/
    - name: Sign the dists with Sigstore
      uses: sigstore/gh-action-sigstore-python@v3.0.0
      with:
        inputs: >-
          ./dist/*.tar.gz
          ./dist/*.whl
    - name: Create GitHub Release
      env:
        GITHUB_TOKEN: ${{ github.token }}
      run: >-
        gh release create
        "$GITHUB_REF_NAME"
        --repo "$GITHUB_REPOSITORY"
        --notes ""
    - name: Upload artifact signatures to GitHub Release
      env:
        GITHUB_TOKEN: ${{ github.token }}
      # Upload to GitHub Release using the `gh` CLI.
      # `dist/` contains the built packages, and the
      # sigstore-produced signatures and certificates.
      run: >-
        gh release upload
        "$GITHUB_REF_NAME" dist/**
        --repo "$GITHUB_REPOSITORY"
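
The `publish-to-pypi` job in the workflow above is gated on `startsWith(github.ref, 'refs/tags/release')`. As a rough illustration of which pushes pass that gate, the check can be mimicked in Python. The helper name `would_publish_to_pypi` and the example refs are made up for this sketch; GitHub's `startsWith` is not case sensitive, which the sketch imitates by lower-casing both strings.

```python
def would_publish_to_pypi(ref: str, prefix: str = "refs/tags/release") -> bool:
    """Mimic `startsWith(github.ref, 'refs/tags/release')` (hypothetical helper)."""
    # GitHub's startsWith() ignores case, so compare lower-cased strings.
    return ref.lower().startswith(prefix.lower())


# Example refs (invented): a release tag, a plain version tag, and a branch push.
for ref in ["refs/tags/release-0.1.0", "refs/tags/v0.1.0", "refs/heads/main"]:
    print(ref, would_publish_to_pypi(ref))
```

Under this reading, only tags whose names begin with `release` reach PyPI, while every push still runs the build and TestPyPI jobs because the workflow triggers on `push` without filters.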

.github/workflows/unit-tests.yml

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
name: Unit Tests

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  unit-tests:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [ '3.13' ]

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh
          echo "${HOME}/.local/bin" >> $GITHUB_PATH

      - name: Install package with dependencies
        run: |
          uv sync
          uv pip install pytest

      - name: Run Unit Tests
        run: echo "Skipped"  # uv run pytest tests/ -v

.gitignore

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
.venv
.idea
data
outputs

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Notebooks
.ipynb_checkpoints/

README.md

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
# GCL Classifier

A ganglion cell layer (GCL) classifier trained on functional two-photon calcium imaging recordings of mouse retinas in response to chirp and moving bar stimuli.
The labels and classification are based on Baden, Franke, Berens et al. (2016) "The functional diversity of retinal ganglion cells in the mouse." Nature 529.7586 (2016): 345-350.
The classifier has already been used in two publications:
1. Qiu, et al. "Efficient coding of natural scenes improves neural system identification." PLoS computational biology 19.4 (2023): e1011037.
2. Gonschorek, et al. "Nitric oxide modulates contrast suppression in a subset of mouse retinal ganglion cells." Elife 13 (2025): RP98742.


# Installation

## Quick Start
```bash
pip install gcl_classifier
```

If you additionally want to run the attached notebooks, install the extra dependencies using:
```bash
pip install "gcl_classifier[notebook]"
```

## First Time Setup

On first use, the model will be downloaded from Hugging Face Hub.
This happens automatically and only needs to be done once.
```python
from gcl_classifier import get_model

# Downloads model to ~/.cache/gcl_classifier/
model = get_model()
```

Let's use this model for cell-type classification:
```python
import numpy as np
from gcl_classifier.data import get_data
from gcl_classifier.classifier import extract_features

# "Fake" preprocessed data for two cells
bar_ds_pvalues = np.array([0.04, 0.20])
roi_sizes_um2 = np.array([43.0, 56.5])
chirp_traces = np.random.random((2, 249))
bar_traces = np.random.random((2, 32))

# Load the feature matrices used to transform data into feature space
data = get_data()
chirp_features = data["chirp_feats"]
bar_features = data["bar_feats"]

# Extract the features for the classifier
X, feature_names = extract_features(
    preproc_chirps=chirp_traces,
    preproc_bars=bar_traces,
    bar_ds_pvalues=bar_ds_pvalues,
    roi_size_um2s=roi_sizes_um2,
    chirp_features=chirp_features,
    bar_features=bar_features,
)

# Get predictions and probabilities
predictions = model.predict(X)
predictions_probs = model.predict_proba(X)
```
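
To make the shapes in the example above concrete, here is a small sketch that reproduces the projection-and-concatenation step of `extract_features` with plain NumPy and random stand-in matrices. The trace lengths (249 and 32) come from the example above; the number of feature columns (20 for chirp, 8 for bar) is an assumption made only for this illustration, since the real matrices come from `get_data()`.

```python
import numpy as np

rng = np.random.default_rng(0)

n_cells = 2
n_chirp_feats, n_bar_feats = 20, 8  # hypothetical column counts, not the real ones

chirp_traces = rng.random((n_cells, 249))
bar_traces = rng.random((n_cells, 32))
chirp_features = rng.random((249, n_chirp_feats))  # stand-in for data["chirp_feats"]
bar_features = rng.random((32, n_bar_feats))       # stand-in for data["bar_feats"]
bar_ds_pvalues = np.array([0.04, 0.20])
roi_sizes_um2 = np.array([43.0, 56.5])

# Project each trace onto its feature matrix, then append the two scalar
# features as extra columns (one row per cell).
X = np.concatenate(
    [
        chirp_traces @ chirp_features,
        bar_traces @ bar_features,
        bar_ds_pvalues[:, np.newaxis],
        roi_sizes_um2[:, np.newaxis],
    ],
    axis=-1,
)
print(X.shape)  # (2, 30)
```

Each row of `X` describes one cell: the projected chirp features come first, then the projected bar features, then the direction-selectivity p-value and the ROI size.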

## Cache Location

Models (and training data) are cached at:
- Linux/Mac: `~/.cache/gcl_classifier/`
- Windows: `C:\Users\<username>\.cache\gcl_classifier\`

To clear the cache and re-download the model:
```python
from gcl_classifier.model import load_model
model, model_dict = load_model(force_download=True)
```

You can do the same for the training data, but this is only necessary if you want to re-train the model.
```python
from gcl_classifier.data import load_data
data = load_data(force_download=True)
```

## Offline Usage

After the initial download, the package works offline using the cached model.
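
Before going offline, it can be useful to confirm that the cache is populated. The following sketch uses only the standard library and the cache path documented above; the names of the files inside the cache are not specified in this README, so the sketch simply lists whatever it finds.

```python
from pathlib import Path

# Cache directory as documented above; Path.home() resolves to
# C:\Users\<username> on Windows and to ~ on Linux/Mac.
cache_dir = Path.home() / ".cache" / "gcl_classifier"

if cache_dir.is_dir():
    cached_files = sorted(p.name for p in cache_dir.iterdir())
    print(f"Found {len(cached_files)} cached file(s) in {cache_dir}: {cached_files}")
else:
    cached_files = []
    print(f"No cache at {cache_dir}; the first get_model() call will need network access.")
```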

gcl_classifier/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
from .model import get_model, predict, predict_proba
from .data import get_data

gcl_classifier/classifier.py

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
import pickle

import numpy as np

from gcl_classifier.labels import baden_cluster_id_to_group_id, baden_group_id_to_supergroup, BADEN_CLUSTER_INFO


def classify_cells(preproc_chirps, preproc_bars, bar_ds_pvalues, roi_size_um2s,
                   chirp_features, bar_features, classifier):
    """Extracts features for all cells and returns the classifier's class probabilities."""
    features, feature_names = extract_features(
        preproc_chirps, preproc_bars, bar_ds_pvalues, roi_size_um2s, chirp_features, bar_features)
    probs = classifier.predict_proba(features)
    return probs


def baden16_cluster_probs_to_info(probs):
    """Summarizes the 75 cluster probabilities of a single cell at cluster, group, and supergroup level."""
    # Boolean-mask indexing below requires an array, so accept lists as well
    probs = np.asarray(probs)
    if len(probs) != 75:
        raise ValueError(f"Expected 75 probabilities corresponding to 75 Baden clusters, got {len(probs)}.")

    cluster_id = np.argmax(probs) + 1  # Cluster IDs are 1-indexed
    group_id = baden_cluster_id_to_group_id(cluster_id)
    supergroup = baden_group_id_to_supergroup(group_id)
    prob_cluster = probs[cluster_id - 1]

    group_ids = BADEN_CLUSTER_INFO[:, 2].astype(int)
    supergroups = BADEN_CLUSTER_INFO[:, 3].astype(str)

    # Sum the probabilities of all clusters sharing the winning group / supergroup
    prob_group = np.sum(probs[group_ids == group_id])
    prob_supergroup = np.sum(probs[supergroups == supergroup])
    prob_rgc = np.sum(probs[supergroups != 'dAC'])
    prob_class = (1. - prob_rgc) if supergroup == 'dAC' else prob_rgc

    return cluster_id, group_id, supergroup, prob_cluster, prob_group, prob_supergroup, prob_class


def extract_features(
        preproc_chirps,
        preproc_bars,
        bar_ds_pvalues,
        roi_size_um2s,
        chirp_features,
        bar_features,
) -> tuple[np.ndarray, list[str]]:
    """
    Transforms the preprocessed chirps and bars using the provided chirp/bar features.
    Concatenates the results with the bar_ds_pvalues and roi_sizes, and returns them together with feature names.
    The result can be used as input to the classifier.
    """
    features = np.concatenate([
        np.dot(preproc_chirps, chirp_features),
        np.dot(preproc_bars, bar_features),
        bar_ds_pvalues[:, np.newaxis],
        roi_size_um2s[:, np.newaxis]
    ], axis=-1)

    feature_names = [f'chirp_{i}' for i in range(chirp_features.shape[1])] + \
                    [f'bar_{i}' for i in range(bar_features.shape[1])] + ['bar_ds_pvalue', 'roi_size_um2']

    return features, feature_names


def check_classifier_dict(clf_dict: dict) -> dict:
    """Validates the structure of a classifier dictionary and returns it unchanged."""
    assert isinstance(clf_dict, dict), "Classifier file must contain a dictionary with classifier data."

    # Check keys
    assert 'classifier' in clf_dict, "Classifier dictionary must contain a 'classifier' key."
    assert 'chirp_feats' in clf_dict, "Classifier dictionary must contain a 'chirp_feats' key."
    assert 'bar_feats' in clf_dict, "Classifier dictionary must contain a 'bar_feats' key."
    assert 'feature_names' in clf_dict, "Classifier dictionary must contain a 'feature_names' key."
    assert 'train_x' in clf_dict, "Classifier dictionary must contain a 'train_x' key."
    assert 'train_y' in clf_dict, "Classifier dictionary must contain a 'train_y' key."
    assert 'y_names' in clf_dict, "Classifier dictionary must contain a 'y_names' key."

    # Check values
    assert isinstance(clf_dict['train_x'], np.ndarray), "The 'train_x' key must contain a numpy array."
    assert isinstance(clf_dict['train_y'], np.ndarray), "The 'train_y' key must contain a numpy array."
    assert clf_dict['train_x'].shape[0] == clf_dict['train_y'].size, \
        "The number of samples in 'train_x' and 'train_y' must match."

    for val in np.unique(clf_dict['train_y']):
        assert val in clf_dict['y_names'].keys(), f"Value {val} in 'train_y' not found in 'y_names'."

    # Check if classifier is a valid scikit-learn classifier
    from sklearn.base import is_classifier
    assert is_classifier(clf_dict['classifier']), "The 'classifier' key must contain a valid scikit-learn classifier."

    return clf_dict


def save_classifier_and_data(classifier, chirp_feats, bar_feats, feature_names, train_x, train_y, y_names,
                             classifier_file, **kwargs) -> None:
    """
    Saves the classifier and its metadata to a file.
    """
    clf_dict = {
        'classifier': classifier,
        'chirp_feats': chirp_feats,
        'bar_feats': bar_feats,
        'feature_names': feature_names,
        'train_x': train_x,
        'train_y': train_y,
        'y_names': y_names,
        **kwargs
    }

    check_classifier_dict(clf_dict)

    with open(classifier_file, 'wb') as f:
        pickle.dump(clf_dict, f)
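
The probability roll-up in `baden16_cluster_probs_to_info` may be easier to follow on a toy example. The sketch below uses six clusters instead of 75 and invented group IDs and supergroup labels in place of `BADEN_CLUSTER_INFO` (only the `'dAC'` label is taken from the code above). It also indexes the toy arrays directly rather than calling `baden_cluster_id_to_group_id` and `baden_group_id_to_supergroup`, so it illustrates the aggregation logic and is not a drop-in replacement.

```python
import numpy as np

# Toy stand-in for BADEN_CLUSTER_INFO: six clusters instead of 75.
# Group IDs and the "OFF"/"ON" labels are invented for illustration.
group_ids = np.array([1, 1, 2, 2, 3, 3])
supergroups = np.array(["OFF", "OFF", "ON", "ON", "dAC", "dAC"])

# Class probabilities for one cell, one entry per cluster.
probs = np.array([0.05, 0.10, 0.40, 0.25, 0.15, 0.05])

best = int(np.argmax(probs))  # 0-based index of the most likely cluster
cluster_id = best + 1         # cluster IDs are 1-indexed
group_id = int(group_ids[best])
supergroup = str(supergroups[best])

# Sum probabilities over all clusters that share the winning group / supergroup.
prob_cluster = probs[best]
prob_group = probs[group_ids == group_id].sum()
prob_supergroup = probs[supergroups == supergroup].sum()

# Probability that the cell is an RGC at all (anything that is not 'dAC').
prob_rgc = probs[supergroups != "dAC"].sum()
prob_class = (1.0 - prob_rgc) if supergroup == "dAC" else prob_rgc

print(cluster_id, group_id, supergroup)
print(prob_cluster, prob_group, prob_supergroup, prob_class)
```

The group and supergroup probabilities are never smaller than the cluster probability, because they add up the mass of sibling clusters. This is why a cell can be assigned to a group with high confidence even when no single cluster dominates.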
