145 changes: 145 additions & 0 deletions prediction-gaussiannb/jax/GaussianNbJAX.md
# Gaussian Naive Bayes JAX Implementation

A JAX-based Gaussian Naive Bayes classifier for EEG movement classification.

## Files

- **`gaussiannb_jax.py`**: Core model implementation with save/load utilities
- **`gaussiannb_jax_train.ipynb`**: Full training pipeline notebook
- **`inference_gaussiannb.py`**: Inference script template
- **`GaussianNbJAX.md`**: This file

## Quick Start

### 1. Update Data Path

In `gaussiannb_jax_train.ipynb`, replace:
```python
BASE_DIR = 'path/to/your/data'
```

with your actual data directory containing movement-labeled subdirectories (backward, forward, landing, left, right, takeoff), each with CSV files.

### 2. Run Training Notebook

```bash
jupyter notebook gaussiannb_jax_train.ipynb
```

Execute all cells to:
- Load EEG data
- Engineer features (variance filtering, statistics, transforms)
- Train GaussianNB with uniform class priors
- Evaluate on test set
- Export model to `gaussiannb_jax_model.pkl`

### 3. Use Trained Model

Load and use the model:
```python
from gaussiannb_jax import load_model

model, state = load_model('gaussiannb_jax_model.pkl')
predictions = model.predict(X_test)
probabilities = model.predict_proba(X_test)
```

## Model Details

### GaussianNBJAX

**Equations**:

For each class $c$ and feature $j$:
- Mean: $\mu_{cj} = \frac{1}{n_c} \sum_{i \in c} X_{ij}$
- Variance (floored by the smoothing term $\lambda$ = `VAR_SMOOTHING`): $\sigma_{cj}^2 = \max\!\left(\frac{1}{n_c} \sum_{i \in c} X_{ij}^2 - \mu_{cj}^2,\ \lambda\right)$

Unnormalized log posterior for sample $x$ (the argmax over classes gives the prediction):
$$\log P(c \mid x) = \sum_j \left[ -\frac{(x_j - \mu_{cj})^2}{2\sigma_{cj}^2} - \frac{1}{2}\log(2\pi\sigma_{cj}^2) \right] + \log P(c) + \text{const}$$
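
These equations map directly onto broadcasted array operations. Below is a minimal NumPy sketch of the per-class log posterior with toy shapes (names and dimensions here are illustrative, not taken from the training pipeline):

```python
import numpy as np

# Toy dimensions: 4 samples, 3 features, 2 classes.
rng = np.random.default_rng(0)
X = rng.normal(size=(4, 3))
mu = rng.normal(size=(2, 3))          # per-class means (theta_)
var = rng.uniform(0.5, 2.0, (2, 3))   # per-class variances (var_)
log_prior = np.log(np.array([0.5, 0.5]))

# log P(x|c) + log P(c), broadcast over (samples, classes, features).
diff = X[:, None, :] - mu[None, :, :]
quad = -0.5 * np.sum(diff**2 / var[None, :, :], axis=2)
const = -0.5 * np.sum(np.log(2.0 * np.pi * var), axis=1)
log_post = quad + const[None, :] + log_prior[None, :]
print(log_post.shape)  # (4, 2)
```

The JAX implementation in `algorithm_gnb_jax.py` computes the same quantity, JIT-compiled.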

### Feature Engineering

1. **Variance Filter**: Removes near-constant raw channels (threshold: $10^{-3}$)
2. **Derived Stats**: Per-row mean, std, absolute mean, energy
3. **Nonlinear Transforms**: Absolute values, log1p of absolute values
4. **Standardization**: Final z-score normalization

Input: 32 raw channels → Output: 55 engineered features
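
A hedged sketch of this pipeline (the threshold comes from the list above; the function name, step ordering, and concatenation layout are assumptions, since the authoritative version lives in the notebook):

```python
import numpy as np

def engineer_features(X_raw, var_threshold=1e-3, eps=1e-8):
    """Illustrative feature pipeline; the notebook's version is authoritative."""
    # 1. Variance filter: drop near-constant raw channels.
    X = X_raw[:, X_raw.var(axis=0) > var_threshold]
    # 2. Per-row derived statistics: mean, std, absolute mean, energy.
    stats = np.column_stack([
        X.mean(axis=1),
        X.std(axis=1),
        np.abs(X).mean(axis=1),
        np.sum(X ** 2, axis=1),
    ])
    # 3. Nonlinear transforms: absolute values, log1p of absolute values.
    feats = np.hstack([X, np.abs(X), np.log1p(np.abs(X)), stats])
    # 4. Z-score standardization (in practice, fit on the train split only).
    return (feats - feats.mean(axis=0)) / (feats.std(axis=0) + eps)
```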

### Training Config

- **VAR_SMOOTHING**: $10^{-9}$ (regularization)
- **UNIFORM_PRIORS**: True (equal class priors)
- **TRAIN_RATIO**: 0.8 (80% train, 20% test)
- **RNG_SEED**: 42
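
Note that `GaussianNBJAX.fit` estimates empirical priors from class counts, so the uniform-priors setting has to be applied separately; presumably the notebook overrides the fitted priors along these lines (a sketch, not the notebook's exact code):

```python
import numpy as np

model.fit(X_train, y_train)
if UNIFORM_PRIORS:  # config flag from the notebook
    model.class_prior_ = np.full(
        model.n_classes_, 1.0 / model.n_classes_, dtype=np.float32)
```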

## Performance

- **Accuracy**: ~0.29 (6-class balanced classification)
- **Best Recall**: Landing (0.64)
- **Best Precision**: Forward (0.59)

## Data Format

Expected CSV structure: Tab-delimited numeric values (no header).

Directory layout:
```
path/to/your/data/
├── backward/
│   └── *.csv
├── forward/
│   └── *.csv
├── landing/
│   └── *.csv
├── left/
│   └── *.csv
├── right/
│   └── *.csv
└── takeoff/
    └── *.csv
```
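
A minimal loader for this layout might look as follows (the function name and file iteration order are illustrative; the notebook's actual loader may differ):

```python
import os
import numpy as np
import pandas as pd

MOVEMENTS = ['backward', 'forward', 'landing', 'left', 'right', 'takeoff']

def load_dataset(base_dir):
    X_parts, y_parts = [], []
    for label, movement in enumerate(MOVEMENTS):
        movement_dir = os.path.join(base_dir, movement)
        for fname in sorted(os.listdir(movement_dir)):
            if not fname.endswith('.csv'):
                continue
            # Tab-delimited numeric values, no header row.
            df = pd.read_csv(os.path.join(movement_dir, fname),
                             sep='\t', header=None)
            X_parts.append(df.to_numpy(dtype=np.float32))
            y_parts.append(np.full(len(df), label, dtype=np.int32))
    return np.vstack(X_parts), np.concatenate(y_parts)
```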

## API Reference

### `GaussianNBJAX`

```python
class GaussianNBJAX:
def __init__(self, var_smoothing=1e-9):
"""Initialize model."""

def fit(self, X, y):
"""Fit model (X: n_samples×n_features, y: labels 0..C-1)."""

def predict(self, X) -> np.ndarray:
"""Predict class labels."""

def predict_proba(self, X) -> np.ndarray:
"""Predict class probabilities."""

def predict_log_proba(self, X) -> np.ndarray:
"""Predict log-probabilities."""
```

### Utilities

```python
def save_model(path: str, model: GaussianNBJAX, meta: dict):
"""Save model to pickle."""

def load_model(path: str) -> tuple[GaussianNBJAX, dict]:
"""Load model and metadata from pickle."""
```
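
For completeness, a fit-and-save example to mirror the load snippet in Quick Start (the metadata keys are illustrative; `meta` can hold any picklable dictionary):

```python
from gaussiannb_jax import GaussianNBJAX, save_model

model = GaussianNBJAX(var_smoothing=1e-9)
model.fit(X_train, y_train)
acc = float((model.predict(X_test) == y_test).mean())
save_model('gaussiannb_jax_model.pkl', model, meta={'accuracy': acc})
```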

## Next Steps

- **Improve Accuracy**: Add windowed Welch bandpower features (delta/theta/alpha/beta/gamma); see the sketch after this list
- **Temporal Features**: Aggregate statistics over fixed time windows
- **Class Balancing**: Use weighted priors or resampling
- **Dimensionality Reduction**: Apply PCA before classification
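
Of these, Welch bandpower is the most concrete. A hedged sketch using `scipy.signal.welch` (scipy would be a new dependency; the sampling rate and band edges are assumptions, so substitute your headset's actual rate):

```python
import numpy as np
from scipy.signal import welch

BANDS = {'delta': (1, 4), 'theta': (4, 8), 'alpha': (8, 13),
         'beta': (13, 30), 'gamma': (30, 45)}

def bandpowers(channel, fs=250.0):
    """Integrated Welch PSD per EEG band for a 1-D channel signal."""
    freqs, psd = welch(channel, fs=fs, nperseg=min(len(channel), 256))
    out = []
    for lo, hi in BANDS.values():
        band = (freqs >= lo) & (freqs < hi)
        out.append(np.trapz(psd[band], freqs[band]))
    return np.array(out)
```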

## Requirements

- numpy >= 1.24
- jax >= 0.4.20
- pandas (for data loading)
- pickle (standard library)


206 changes: 206 additions & 0 deletions prediction-gaussiannb/jax/algorithm_gnb_jax.py
"""
Gaussian Naive Bayes classifier implementation using JAX.

This module provides a JAX-based GaussianNB implementation for EEG movement
classification. The model is trained offline and can be saved/loaded for
inference on new data.

Usage:
# Training
model = GaussianNBJAX(var_smoothing=1e-9)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

# Saving/Loading
save_model('path/to/model.pkl', model, metadata_dict)
model, meta = load_model('path/to/model.pkl')
"""

from typing import Any, Dict, Tuple
import pickle

import numpy as np
import jax
import jax.numpy as jnp


class GaussianNBJAX:
"""Gaussian Naive Bayes classifier using JAX for efficient computation.

Attributes:
var_smoothing: Regularization parameter for variance estimates.
class_prior_: Prior probability of each class.
theta_: Mean feature vectors per class.
var_: Variance estimates per feature per class.
n_classes_: Number of classes.
n_features_: Number of features.
"""

def __init__(self, var_smoothing: float = 1e-9):
"""Initialize GaussianNBJAX.

Args:
var_smoothing: Smoothing value added to variance estimates to prevent
singularity (default: 1e-9).
"""
self.var_smoothing = float(var_smoothing)
self.class_prior_ = None
self.theta_ = None
self.var_ = None
self.n_classes_ = None
self.n_features_ = None

def fit(self, X: np.ndarray, y: np.ndarray):
"""Fit the Gaussian Naive Bayes model.

Args:
X: Training feature matrix (n_samples, n_features).
y: Training labels (n_samples,), assumed 0..C-1.

Returns:
self
"""
X_j = jnp.asarray(X, dtype=jnp.float32)
y_j = jnp.asarray(y, dtype=jnp.int32)

n_samples = X_j.shape[0]
n_features = X_j.shape[1]

num_classes = int(jnp.max(y_j) + 1)
counts = jnp.bincount(y_j, length=num_classes)
counts_f = jnp.maximum(counts.astype(jnp.float32), 1.0)

        def sums_for_class(c):
            # Zero out rows not in class c, then sum over samples; masking
            # (rather than boolean indexing) keeps array shapes static,
            # which vmap/jit require.
            mask = (y_j == c)
            masked = jnp.where(mask[:, None], X_j, 0.0)
            return jnp.sum(masked, axis=0)

        def sums2_for_class(c):
            # Same masking trick for the per-class sums of squares.
            mask = (y_j == c)
            masked2 = jnp.where(mask[:, None], X_j * X_j, 0.0)
            return jnp.sum(masked2, axis=0)

classes = jnp.arange(num_classes)
sums = jax.vmap(sums_for_class)(classes)
sums2 = jax.vmap(sums2_for_class)(classes)

means = sums / counts_f[:, None]
vars_ = (sums2 / counts_f[:, None]) - jnp.square(means)
vars_ = jnp.maximum(vars_, self.var_smoothing)

priors = counts.astype(jnp.float32) / float(n_samples)

self.theta_ = np.asarray(means, dtype=np.float32)
self.var_ = np.asarray(vars_, dtype=np.float32)
self.class_prior_ = np.asarray(priors, dtype=np.float32)
self.n_classes_ = int(num_classes)
self.n_features_ = int(n_features)
return self

@staticmethod
@jax.jit
def _predict_log_proba_jit(X: jnp.ndarray, mu: jnp.ndarray, var: jnp.ndarray,
log_prior: jnp.ndarray) -> jnp.ndarray:
"""JAX-compiled log-probability computation (internal).

Args:
X: Test feature matrix (n_samples, n_features).
mu: Class means (n_classes, n_features).
var: Class variances (n_classes, n_features).
log_prior: Log class priors (n_classes,).

Returns:
Log-probabilities (n_samples, n_classes).
"""
        # Per-class Gaussian normalization: -0.5 * sum_j log(2*pi*var_cj).
        const_term = -0.5 * jnp.sum(jnp.log(2.0 * jnp.pi * var), axis=1)
        # Broadcast to (n_samples, n_classes, n_features); sum over features.
        diff = X[:, None, :] - mu[None, :, :]
        quad = -0.5 * jnp.sum((diff * diff) / (var[None, :, :]), axis=2)
        log_lik = quad + const_term[None, :]
        return log_lik + log_prior[None, :]

def predict_log_proba(self, X: np.ndarray) -> np.ndarray:
"""Compute log-probabilities for samples.

Args:
X: Feature matrix (n_samples, n_features).

Returns:
Log-probabilities (n_samples, n_classes).
"""
X_j = jnp.asarray(X, dtype=jnp.float32)
mu = jnp.asarray(self.theta_, dtype=jnp.float32)
var = jnp.asarray(self.var_, dtype=jnp.float32)
log_prior = jnp.log(jnp.asarray(self.class_prior_, dtype=jnp.float32) + 1e-12)
out = self._predict_log_proba_jit(X_j, mu, var, log_prior)
return np.asarray(out)

def predict_proba(self, X: np.ndarray) -> np.ndarray:
"""Compute class probabilities for samples.

Args:
X: Feature matrix (n_samples, n_features).

Returns:
Class probabilities (n_samples, n_classes), row-wise normalized.
"""
        logp = self.predict_log_proba(X)
        # Stable softmax: subtract the row max before exponentiating.
        m = np.max(logp, axis=1, keepdims=True)
        p = np.exp(logp - m)
        p /= np.sum(p, axis=1, keepdims=True)
        return p

def predict(self, X: np.ndarray) -> np.ndarray:
"""Predict class labels for samples.

Args:
X: Feature matrix (n_samples, n_features).

Returns:
Predicted class labels (n_samples,).
"""
logp = self.predict_log_proba(X)
return np.argmax(logp, axis=1).astype(np.int32)


def save_model(path: str, model: GaussianNBJAX, meta: Dict[str, Any]) -> None:
"""Save trained model to disk.

Args:
path: Output file path (e.g., 'path/to/model.pkl').
model: Fitted GaussianNBJAX instance.
meta: Optional metadata dictionary (e.g., accuracy, class names).
"""
payload: Dict[str, Any] = {
'model': {
'var_smoothing': float(model.var_smoothing),
'class_prior_': model.class_prior_,
'theta_': model.theta_,
'var_': model.var_,
'n_classes_': model.n_classes_,
'n_features_': model.n_features_,
},
'meta': meta or {},
'format': 'GaussianNBJAX-pickle-v1'
}
with open(path, 'wb') as f:
pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_model(path: str) -> Tuple[GaussianNBJAX, Dict[str, Any]]:
"""Load trained model from disk.

Args:
path: Path to pickle file (e.g., 'path/to/model.pkl').

Returns:
Tuple of (fitted GaussianNBJAX, metadata dictionary).
"""
with open(path, 'rb') as f:
payload = pickle.load(f)
m = GaussianNBJAX(var_smoothing=payload['model'].get('var_smoothing', 1e-9))
m.class_prior_ = np.array(payload['model']['class_prior_'], dtype=np.float32)
m.theta_ = np.array(payload['model']['theta_'], dtype=np.float32)
m.var_ = np.array(payload['model']['var_'], dtype=np.float32)
m.n_classes_ = int(payload['model']['n_classes_'])
m.n_features_ = int(payload['model']['n_features_'])
return m, payload.get('meta', {})