Commit
Add implementation of EEGThoughtDecoder for training & evaluation (#1)

* Add EEG data loader
* Add the Transformer architecture
* Add the Graph Neural Network for spatial relationships
* Add the Mixture of Experts model for specialization
* Add the Agentic model for dynamic adaptation
* Combine the Transformer, GNN, MoE and Agentic model
* Add training script
* Add evaluation script
* Fix typing related bugs
* Bumping version from 0.1.0 to 0.1.1
1 parent 14ab889 · commit 503517c

Showing 21 changed files with 916 additions and 1 deletion.
@@ -0,0 +1,18 @@
from thought_decoder.logging import logger
from thought_decoder.models.agentic.policy import AgenticModel
from thought_decoder.data.data_loader import EEGDataLoader
from thought_decoder.models.transformer.encoder import TransformerEncoder
from thought_decoder.models.gnn.graph_nn import GNN
from thought_decoder.models.moe.mixture_of_experts import MixtureOfExperts
from thought_decoder.models.decoder import EEGThoughtDecoder


__all__ = [
    'AgenticModel',
    'EEGDataLoader',
    'EEGThoughtDecoder',
    'GNN',
    'MixtureOfExperts',
    'TransformerEncoder',
    'logger',
]
@@ -0,0 +1,4 @@
from thought_decoder.data.data_loader import EEGDataLoader


__all__ = ['EEGDataLoader']
@@ -0,0 +1,61 @@
"""Data Loader module.

This module handles the loading and pre-processing of EEG data for training and evaluation.
"""

from collections.abc import Iterator
from pathlib import Path

import jax
import jax.numpy as jnp


class EEGDataLoader:
    """EEG Data Loader.

    Loads and pre-processes EEG data for training and evaluation.
    """

    def __init__(self, data_dir: Path | str, batch_size: int = 32, shuffle: bool = True) -> None:
        """Initialize the EEGDataLoader.

        Args:
            data_dir (Path | str): The path to the data directory.
            batch_size (int): The batch size for the data loader.
                Defaults to 32.
            shuffle (bool): Whether to shuffle the data.
                Defaults to True.
        """
        self.data_dir = Path(data_dir)
        self.batch_size = batch_size
        self.shuffle = shuffle

        # TODO(victor-iyi): You might wanna get the data from a remote source.
        self._load_data()

    def get_batches(self) -> Iterator[tuple[jax.Array, jax.Array]]:
        """Get an iterator over the data batches.

        Yields:
            tuple[Array, Array]: A tuple of input and target arrays.
        """
        dataset_size = self.inputs.shape[0]
        indices = jnp.arange(dataset_size)

        if self.shuffle:
            indices = jax.random.permutation(jax.random.PRNGKey(0), indices)

        for start_idx in range(0, dataset_size, self.batch_size):
            end_idx = min(start_idx + self.batch_size, dataset_size)
            batch_indices = indices[start_idx:end_idx]
            yield self.inputs[batch_indices], self.targets[batch_indices]

    def _load_data(self) -> None:
        """Load the EEG data."""
        self.inputs = jnp.load(self.data_dir / 'inputs.npy')
        self.targets = jnp.load(self.data_dir / 'targets.npy')
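
For context, a minimal usage sketch of the loader (not part of this commit). It assumes a hypothetical data/train/ directory containing the inputs.npy and targets.npy files that _load_data expects:

from thought_decoder.data import EEGDataLoader

# Hypothetical directory; any folder holding inputs.npy and targets.npy works.
loader = EEGDataLoader(data_dir='data/train/', batch_size=32, shuffle=True)

for inputs, targets in loader.get_batches():
    # Each batch is a pair of jax.Array objects sliced along the first axis.
    print(inputs.shape, targets.shape)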
@@ -0,0 +1,67 @@
"""Evaluation script.

Handles the evaluation loop for the `EEGThoughtDecoder` model.
"""

# pylint: disable=not-callable
# mypy: disable-error-code="assignment,no-untyped-call"

import jax.numpy as jnp
import optax
from flax.training import checkpoints, train_state

from thought_decoder.logging import logger
from thought_decoder.data import EEGDataLoader
from thought_decoder.models import EEGThoughtDecoder
from thought_decoder.models.utils import AgenticParams, GNNParams, MixtureOfExpertsParams, TransformerParams


def evaluate() -> None:
    """Evaluate the model."""
    # Data loader.
    data_loader = EEGDataLoader(data_dir='data/test/', batch_size=32, shuffle=False)

    # Model initialization.
    transformer_params = TransformerParams(
        num_layers=4,
        model_dim=128,
        num_heads=8,
        diff=512,
        input_vocab_size=1_000,
        maximum_position_encoding=1_000,
        dropout_rate=0.1,
    )
    gnn_params = GNNParams(
        num_layers=2,
        hidden_dim=128,
        adjacency_matrix=jnp.ones((64, 64)),  # Placeholder adjacency matrix.
    )
    moe_params = MixtureOfExpertsParams(num_experts=4, expert_output_dim=128)
    agentic_params = AgenticParams(action_dim=10)

    model = EEGThoughtDecoder(
        transformer_params=transformer_params,
        gnn_params=gnn_params,
        moe_params=moe_params,
        agentic_params=agentic_params,
    )

    # Restore checkpoint.
    params = checkpoints.restore_checkpoint('checkpoints/', target=None)
    # TrainState.create calls tx.init, so an optimizer is required even for evaluation;
    # optax.identity() is a no-op gradient transformation.
    state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optax.identity())

    total_correct, total_samples = 0, 0

    for batch in data_loader.get_batches():
        logits = state.apply_fn({'params': state.params}, batch[0], training=False)

        predictions = jnp.argmax(logits, axis=-1)
        total_correct += jnp.sum(predictions == batch[1])
        total_samples += batch[1].shape[0]

    accuracy = total_correct / total_samples
    logger.info(f'Accuracy: {accuracy:.4f}')


if __name__ == '__main__':
    evaluate()
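
For reference, a hedged sketch of how a checkpoint compatible with the restore call above could be written after training (not part of this diff). The training script is added in this commit but not shown here, so `state` below is only a stand-in for whatever TrainState that loop produces:

from flax.training import checkpoints

# Hypothetical: `state` is the TrainState built by the training loop.
checkpoints.save_checkpoint(ckpt_dir='checkpoints/', target=state.params, step=0, overwrite=True)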
@@ -0,0 +1,52 @@
"""Logging utilities.

This module provides a set of utilities for logging messages to the console.
"""

import logging

from rich.console import Console
from rich.logging import RichHandler


def set_logger(name: str = 'thought_decoder', log_level: int = logging.INFO) -> logging.Logger:
    """Set up a logger with a specified name and log level.

    Args:
        name (str): The name of the logger.
        log_level (int): The log level for this logger.

    Returns:
        logging.Logger: The logger object.
    """
    # Create a logger with a specified name.
    _logger = logging.getLogger(name)

    # Set the log level for this logger.
    _logger.setLevel(log_level)

    # Create a RichHandler for console output.
    console_handler = RichHandler(
        console=Console(),
        rich_tracebacks=True,
        log_time_format='[%X]',
        # tracebacks_show_locals=True,
        keywords=[name],
    )

    # Create a formatter.
    formatter = logging.Formatter('%(name)-8s %(message)s')
    console_handler.setFormatter(formatter)

    _logger.addHandler(console_handler)

    # Prevent logging from propagating to the root logger.
    _logger.propagate = False

    return _logger


# Global library logger.
logger: logging.Logger = set_logger()
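
A short usage sketch (not part of this commit): the module-level logger can be imported anywhere in the package, and set_logger can create additional named loggers when needed.

import logging

from thought_decoder.logging import logger, set_logger

logger.info('Loaded %d EEG samples', 1_024)  # Rendered by the RichHandler configured above.
debug_logger = set_logger(name='thought_decoder.debug', log_level=logging.DEBUG)
debug_logger.debug('Verbose diagnostics go here.')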
@@ -0,0 +1,19 @@
from thought_decoder.models.agentic.policy import AgenticModel
from thought_decoder.models.transformer.encoder import TransformerEncoder
from thought_decoder.models.gnn.graph_nn import GNN
from thought_decoder.models.moe.mixture_of_experts import MixtureOfExperts
from thought_decoder.models.decoder import EEGThoughtDecoder
from thought_decoder.models.utils import AgenticParams, GNNParams, MixtureOfExpertsParams, TransformerParams


__all__ = [
    'AgenticModel',
    'AgenticParams',
    'EEGThoughtDecoder',
    'GNN',
    'GNNParams',
    'MixtureOfExperts',
    'MixtureOfExpertsParams',
    'TransformerEncoder',
    'TransformerParams',
]
@@ -0,0 +1,4 @@
from thought_decoder.models.agentic.policy import AgenticModel


__all__ = ['AgenticModel']
@@ -0,0 +1,53 @@
"""Agentic policy for the agent in the environment.

The agentic model supports dynamic adaptation by learning a policy with Reinforcement Learning (RL) algorithms.
"""

from jax import Array
from flax import linen as nn


class PolicyNetwork(nn.Module):
    """Policy Network Module."""

    action_dim: int

    @nn.compact  # Required: submodules are defined inline in __call__.
    def __call__(self, x: Array) -> Array:
        """Apply the policy network.

        Args:
            x (Array): The input tensor of shape (batch_size, input_dim).

        Returns:
            Array: The output tensor of shape (batch_size, action_dim).
        """
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        logits = nn.Dense(features=self.action_dim)(x)
        return logits  # Unnormalized action logits.


class AgenticModel(nn.Module):
    """Agentic Model Module."""

    action_dim: int

    def setup(self) -> None:
        """Set up the module."""
        self.policy_network = PolicyNetwork(action_dim=self.action_dim)

    def __call__(self, x: Array) -> Array:
        """Apply the module.

        Args:
            x (Array): The input tensor of shape (batch_size, input_dim).

        Returns:
            Array: The output tensor of shape (batch_size, action_dim).
        """
        logits = self.policy_network(x)
        action_probs = nn.softmax(logits)
        return action_probs
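
For orientation, a minimal sketch of initializing and applying the agentic model on its own (not part of this commit); the input width of 128 is only an assumption, chosen to match the hidden size used elsewhere in this change:

import jax
import jax.numpy as jnp

from thought_decoder.models.agentic import AgenticModel

model = AgenticModel(action_dim=10)
dummy_input = jnp.ones((4, 128))  # (batch_size, input_dim); 128 is illustrative.

variables = model.init(jax.random.PRNGKey(0), dummy_input)
action_probs = model.apply(variables, dummy_input)  # Shape (4, 10); each row sums to 1 after softmax.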
@@ -0,0 +1,54 @@
"""This module combines Transformer, GNN, Mixture of Experts and Agentic Model into one architecture."""

# pylint: disable=attribute-defined-outside-init
# from collections.abc import Mapping

from jax import Array
from flax import linen as nn

from thought_decoder.models import AgenticModel, GNN, MixtureOfExperts, TransformerEncoder
from thought_decoder.models.utils import AgenticParams, GNNParams, MixtureOfExpertsParams, TransformerParams


class EEGThoughtDecoder(nn.Module):
    """EEG Thought Decoder Module."""

    transformer_params: TransformerParams  # Mapping[str, int | float]
    gnn_params: GNNParams  # Mapping[str, int | Array]
    moe_params: MixtureOfExpertsParams  # Mapping[str, int]
    agentic_params: AgenticParams  # Mapping[str, int]

    def setup(self) -> None:
        """Set up the module."""
        self.transformer = TransformerEncoder(**self.transformer_params)
        self.gnn = GNN(**self.gnn_params)
        self.moe = MixtureOfExperts(**self.moe_params)
        self.agent = AgenticModel(**self.agentic_params)
        self.classifier = nn.Dense(features=self.agentic_params['action_dim'])

    def __call__(self, x: Array, training: bool) -> Array:
        """Apply the module.

        Args:
            x (Array): The input tensor of shape (batch_size, seq_len, input_dim).
            training (bool): Whether the model is training or not.

        Returns:
            Array: The output tensor of shape (batch_size, action_dim).
        """
        # Transformer Encoder for temporal dynamics.
        x = self.transformer(x, training)

        # GNN for spatial relationships.
        x = self.gnn(x)

        # Mixture of Experts for specialization.
        x = self.moe(x)

        # Agentic Model for dynamic adaptation.
        action_probs = self.agent(x)

        # Classifier.
        logits = self.classifier(action_probs)
        return logits
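
A construction-and-initialization sketch for the combined model (not part of this commit), reusing the hyperparameters from the evaluation script above. The dummy input shape and integer dtype are assumptions, since the TransformerEncoder's expected input layout lives in a module not shown in this diff:

import jax
import jax.numpy as jnp

from thought_decoder.models import EEGThoughtDecoder
from thought_decoder.models.utils import AgenticParams, GNNParams, MixtureOfExpertsParams, TransformerParams

model = EEGThoughtDecoder(
    transformer_params=TransformerParams(
        num_layers=4, model_dim=128, num_heads=8, diff=512,
        input_vocab_size=1_000, maximum_position_encoding=1_000, dropout_rate=0.1,
    ),
    gnn_params=GNNParams(num_layers=2, hidden_dim=128, adjacency_matrix=jnp.ones((64, 64))),
    moe_params=MixtureOfExpertsParams(num_experts=4, expert_output_dim=128),
    agentic_params=AgenticParams(action_dim=10),
)

# Hypothetical batch of token-indexed EEG windows: (batch_size, seq_len).
dummy_batch = jnp.zeros((2, 64), dtype=jnp.int32)
variables = model.init(jax.random.PRNGKey(0), dummy_batch, training=False)
logits = model.apply(variables, dummy_batch, training=False)  # (batch_size, action_dim)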
@@ -0,0 +1,4 @@
from thought_decoder.models.gnn.graph_nn import GraphConvolutionLayer, GNN


__all__ = ['GraphConvolutionLayer', 'GNN']