Skip to content

Commit

Permalink
Add implementation of EEGThoughtDecoder for training & evaluation (#1)
Browse files Browse the repository at this point in the history
* Add EEG data loader

* Add the Transformer architecture

* Add the Graph Neural Network for spatial relationships

* Add the Mixture of Experts model for specialization

* Add the Agentic model for dynamic adaptation

* Combine the Transformer, GNN, MoE and Agentic model

* Add training script

* Add evaluation script

* Fix typing related bugs

* Bumping version from 0.1.0 to 0.1.1
  • Loading branch information
victor-iyi authored Sep 19, 2024
1 parent 14ab889 commit 503517c
Show file tree
Hide file tree
Showing 21 changed files with 916 additions and 1 deletion.
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "eeg-thought-decoder"
version = "0.1.0"
version = "0.1.1"
description = "Decoding Human Thought from EEG Signals"
license = "MIT"

Expand Down Expand Up @@ -35,6 +35,12 @@ python = "^3.11"

# Render rich text, progress bars, syntax highlighting and more to the terminal
rich = "^13.8.1"
# Differentiate, compile, and transform NumPy code.
jax = "^0.4.33"
# A neural network library for JAX designed for flexibility.
flax = "^0.9.0"
# A gradient processing and optimization library in JAX.
optax = "^0.2.3"


[tool.poetry.group.dev.dependencies]
Expand Down
18 changes: 18 additions & 0 deletions src/thought_decoder/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from thought_decoder.logging import logger
from thought_decoder.models.agentic.policy import AgenticModel
from thought_decoder.data.data_loader import EEGDataLoader
from thought_decoder.models.transformer.encoder import TransformerEncoder
from thought_decoder.models.gnn.graph_nn import GNN
from thought_decoder.models.moe.mixture_of_experts import MixtureOfExperts
from thought_decoder.models.decoder import EEGThoughtDecoder


__all__ = [
'AgenticModel',
'EEGDataLoader',
'EEGThoughtDecoder',
'GNN',
'MixtureOfExperts',
'TransformerEncoder',
'logger',
]
4 changes: 4 additions & 0 deletions src/thought_decoder/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from thought_decoder.data.data_loader import EEGDataLoader


__all__ = ['EEGDataLoader']
61 changes: 61 additions & 0 deletions src/thought_decoder/data/data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Data Loader module.
This module handles the loading and pre-processing of EEG data for training and evaluation.
"""

from collections.abc import Iterator

from pathlib import Path

import jax
import jax.numpy as jnp


class EEGDataLoader:
"""EEG Data Loader.
Loads and pre-processes EEG data for training and evaluation.
"""

def __init__(self, data_dir: Path | str, batch_size: int = 32, shuffle: bool = True) -> None:
"""Initialize the EEGDataLoader.
Args:
data_dir (Path | str): The path to the data directory.
batch_size (int): The batch size for the data loader.
Defaults to 32.
shuffle (bool): Whether to shuffle the data.
Defaults to True.
"""
self.data_dir = Path(data_dir)
self.batch_size = batch_size
self.shuffle = shuffle

# TODO(victor-iyi): You might wanna get the data from a remote source.
self._load_data()

def get_batches(self) -> Iterator[tuple[jax.Array, jax.Array]]:
"""Get an iterator over the data batches.
Yields:
tuple[Array, Array]: A tuple of input and target arrays.
"""
dataset_size = self.inputs.shape[0]
indices = jnp.arange(dataset_size)

if self.shuffle:
indices = jax.random.permutation(jax.random.PRNGKey(0), indices)

for start_idx in range(0, dataset_size, self.batch_size):
end_idx = min(start_idx + self.batch_size, dataset_size)
batch_indices = indices[start_idx:end_idx]
yield self.inputs[batch_indices], self.targets[batch_indices]

def _load_data(self) -> None:
"""Load the EEG data."""
self.inputs = jnp.load(self.data_dir / 'inputs.npy')
self.targets = jnp.load(self.data_dir / 'targets.npy')
67 changes: 67 additions & 0 deletions src/thought_decoder/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Evaluation script.
Handles the evaluation loop for the `EEGThoughtDecoder` model.
"""

# pylint: disable=not-callable
# mypy: disable-error-code="assignment,no-untyped-call"

import jax.numpy as jnp
from flax.training import checkpoints, train_state

from thought_decoder.logging import logger
from thought_decoder.data import EEGDataLoader
from thought_decoder.models import EEGThoughtDecoder
from thought_decoder.models.utils import AgenticParams, GNNParams, MixtureOfExpertsParams, TransformerParams


def evaluate() -> None:
"""Evaluate the model."""
# Data loader.
data_loader = EEGDataLoader(data_dir='data/test/', batch_size=32, shuffle=False)

# Model intialization.
transformer_params = TransformerParams(
num_layers=4,
model_dim=128,
num_heads=8,
diff=512,
input_vocab_size=1_000,
maximum_position_encoding=1_000,
dropout_rate=0.1,
)
gnn_params = GNNParams(
num_layers=2,
hidden_dim=128,
adjacency_matrix=jnp.ones((64, 64)), # Placeholder adjacency matrix.
)
moe_params = MixtureOfExpertsParams(num_experts=4, expert_output_dim=128)
agentic_params = AgenticParams(action_dim=10)

model = EEGThoughtDecoder(
transformer_params=transformer_params,
gnn_params=gnn_params,
moe_params=moe_params,
agentic_params=agentic_params,
)

# Restore checkpoint.
params = checkpoints.restore_checkpoint('checkpoints/', target=None)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=None)

total_correct, total_samples = 0, 0

for batch in data_loader.get_batches():
logits = state.apply_fn({'params': state.params}, batch[0], training=False)

predictions = jnp.argmax(logits, axis=-1)
total_correct += jnp.sum(predictions == batch[1])
total_samples += batch[1].shape[0]

accuracy = total_correct / total_samples
logger.info(f'Accuracy: {accuracy:.4f}')


if __name__ == '__main__':
evaluate()
52 changes: 52 additions & 0 deletions src/thought_decoder/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Logging utilities.
This module provides a set of utilities for logging messages to the console.
"""

import logging

from rich.console import Console
from rich.logging import RichHandler


def set_logger(name: str = 'thought_decoder', log_level: int = logging.INFO) -> logging.Logger:
"""Set up a logger with a specified name and log level.
Args:
name (str): The name of the logger.
log_level (int): The log level for this logger.
Returns:
logging.Logger: The logger object.
"""
# Create a logger with a specified name.
_logger = logging.getLogger(name)

# Set the log level for this logger.
_logger.setLevel(log_level)

# Create a RichHandler for console output.
console_handler = RichHandler(
console=Console(),
rich_tracebacks=True,
log_time_format='[%X]',
# tracebacks_show_locals=True,
keywords=[name],
)

# Create a formatter.
formatter = logging.Formatter('%(name)-8s %(message)s')
console_handler.setFormatter(formatter)

_logger.addHandler(console_handler)

# Prevent logging from propagating to the root logger.
_logger.propagate = False

return _logger


# Global library logger.
logger: logging.Logger = set_logger()
19 changes: 19 additions & 0 deletions src/thought_decoder/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from thought_decoder.models.agentic.policy import AgenticModel
from thought_decoder.models.transformer.encoder import TransformerEncoder
from thought_decoder.models.gnn.graph_nn import GNN
from thought_decoder.models.moe.mixture_of_experts import MixtureOfExperts
from thought_decoder.models.decoder import EEGThoughtDecoder
from thought_decoder.models.utils import AgenticParams, GNNParams, MixtureOfExpertsParams, TransformerParams


__all__ = [
'AgenticModel',
'AgenticParams',
'EEGThoughtDecoder',
'GNN',
'GNNParams',
'MixtureOfExperts',
'MixtureOfExpertsParams',
'TransformerEncoder',
'TransformerParams',
]
4 changes: 4 additions & 0 deletions src/thought_decoder/models/agentic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from thought_decoder.models.agentic.policy import AgenticModel


__all__ = ['AgenticModel']
53 changes: 53 additions & 0 deletions src/thought_decoder/models/agentic/policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Agentic policy for the agent in the environment.
The agentic model helps learning for dynamic adaptation using Reinforcement Learning (RL) algorithms.
"""

from jax import Array
from flax import linen as nn


class PolicyNetwork(nn.Module):
"""Policy Network Module."""

action_dim: int

def __call__(self, x: Array) -> Array:
"""Apply the policy network.
Args:
x (Array): The input tensor of shape (batch_size, input_dim).
Returns:
Array: The output tensor of shape (batch_size, action_dim).
"""
x = nn.Dense(features=128)(x)
x = nn.relu(x)
logits = nn.Dense(features=self.action_dim)(x)
return logits # actions probabilities.


class AgenticModel(nn.Module):
"""Agentic Model Module."""

action_dim: int

def setup(self) -> None:
"""Set up the module."""
self.policy_network = PolicyNetwork(action_dim=self.action_dim)

def __call__(self, x: Array) -> Array:
"""Apply the module.
Args:
x (Array): The input tensor of shape (batch_size, input_dim).
Returns:
Array: The output tensor of shape (batch_size, action_dim).
"""
logits = self.policy_network(x)
action_probs = nn.softmax(logits)
return action_probs
54 changes: 54 additions & 0 deletions src/thought_decoder/models/decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""This module combines Transformer, GNN, Mixture of Experts and Agentic Model into one architecture."""

# pylint: disable=attribute-defined-outside-init
# from collections.abc import Mapping

from jax import Array
from flax import linen as nn

from thought_decoder.models import AgenticModel, GNN, MixtureOfExperts, TransformerEncoder
from thought_decoder.models.utils import AgenticParams, GNNParams, MixtureOfExpertsParams, TransformerParams


class EEGThoughtDecoder(nn.Module):
"""EEG Thought Decoder Module."""

transformer_params: TransformerParams # Mapping[str, int | float]
gnn_params: GNNParams # Mapping[str, int | Array]
moe_params: MixtureOfExpertsParams # Mapping[str, int]
agentic_params: AgenticParams # Mapping[str, int]

def setup(self) -> None:
"""Set up the module."""
self.transformer = TransformerEncoder(**self.transformer_params)
self.gnn = GNN(**self.gnn_params)
self.moe = MixtureOfExperts(**self.moe_params)
self.agent = AgenticModel(**self.agentic_params)
self.classifier = nn.Dense(features=self.agentic_params['action_dim'])

def __call__(self, x: Array, training: bool) -> Array:
"""Apply the module.
Args:
x (Array): The input tensor of shape (batch_size, seq_len, input_dim).
training (bool): Whether the model is training or not.
Returns:
Array: The output tensor of shape (batch_size, action_dim).
"""
# Transformer Encoder for temporal dynamics.
x = self.transformer(x, training)

# GNN for spatial relationships.
x = self.gnn(x)

# Mixture of Experts for specialization.
x = self.moe(x)

# Agentic Model for dynamic adaptation.
action_probs = self.agent(x)

# Classifier.
logits = self.classifier(action_probs)
return logits
4 changes: 4 additions & 0 deletions src/thought_decoder/models/gnn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from thought_decoder.models.gnn.graph_nn import GraphConvolutionLayer, GNN


__all__ = ['GraphConvolutionLayer', 'GNN']
Loading

0 comments on commit 503517c

Please sign in to comment.