diff --git a/pyproject.toml b/pyproject.toml index e5b1b1a..4f03aad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "eeg-thought-decoder" -version = "0.1.0" +version = "0.1.1" description = "Decoding Human Thought from EEG Signals" license = "MIT" @@ -35,6 +35,12 @@ python = "^3.11" # Render rich text, progress bars, syntax highlighting and more to the terminal rich = "^13.8.1" +# Differentiate, compile, and transform NumPy code. +jax = "^0.4.33" +# A neural network library for JAX designed for flexibility. +flax = "^0.9.0" +# A gradient processing and optimization library in JAX. +optax = "^0.2.3" [tool.poetry.group.dev.dependencies] diff --git a/src/thought_decoder/__init__.py b/src/thought_decoder/__init__.py index e69de29..24ebeeb 100644 --- a/src/thought_decoder/__init__.py +++ b/src/thought_decoder/__init__.py @@ -0,0 +1,18 @@ +from thought_decoder.logging import logger +from thought_decoder.models.agentic.policy import AgenticModel +from thought_decoder.data.data_loader import EEGDataLoader +from thought_decoder.models.transformer.encoder import TransformerEncoder +from thought_decoder.models.gnn.graph_nn import GNN +from thought_decoder.models.moe.mixture_of_experts import MixtureOfExperts +from thought_decoder.models.decoder import EEGThoughtDecoder + + +__all__ = [ + 'AgenticModel', + 'EEGDataLoader', + 'EEGThoughtDecoder', + 'GNN', + 'MixtureOfExperts', + 'TransformerEncoder', + 'logger', +] diff --git a/src/thought_decoder/data/__init__.py b/src/thought_decoder/data/__init__.py new file mode 100644 index 0000000..ada5720 --- /dev/null +++ b/src/thought_decoder/data/__init__.py @@ -0,0 +1,4 @@ +from thought_decoder.data.data_loader import EEGDataLoader + + +__all__ = ['EEGDataLoader'] diff --git a/src/thought_decoder/data/data_loader.py b/src/thought_decoder/data/data_loader.py new file mode 100644 index 0000000..748bc14 --- /dev/null +++ b/src/thought_decoder/data/data_loader.py @@ -0,0 +1,61 @@ +"""Data Loader module. + +This module handles the loading and pre-processing of EEG data for training and evaluation. + +""" + +from collections.abc import Iterator + +from pathlib import Path + +import jax +import jax.numpy as jnp + + +class EEGDataLoader: + """EEG Data Loader. + + Loads and pre-processes EEG data for training and evaluation. + + """ + + def __init__(self, data_dir: Path | str, batch_size: int = 32, shuffle: bool = True) -> None: + """Initialize the EEGDataLoader. + + Args: + data_dir (Path | str): The path to the data directory. + batch_size (int): The batch size for the data loader. + Defaults to 32. + shuffle (bool): Whether to shuffle the data. + Defaults to True. + + """ + self.data_dir = Path(data_dir) + self.batch_size = batch_size + self.shuffle = shuffle + + # TODO(victor-iyi): You might wanna get the data from a remote source. + self._load_data() + + def get_batches(self) -> Iterator[tuple[jax.Array, jax.Array]]: + """Get an iterator over the data batches. + + Yields: + tuple[Array, Array]: A tuple of input and target arrays. + + """ + dataset_size = self.inputs.shape[0] + indices = jnp.arange(dataset_size) + + if self.shuffle: + indices = jax.random.permutation(jax.random.PRNGKey(0), indices) + + for start_idx in range(0, dataset_size, self.batch_size): + end_idx = min(start_idx + self.batch_size, dataset_size) + batch_indices = indices[start_idx:end_idx] + yield self.inputs[batch_indices], self.targets[batch_indices] + + def _load_data(self) -> None: + """Load the EEG data.""" + self.inputs = jnp.load(self.data_dir / 'inputs.npy') + self.targets = jnp.load(self.data_dir / 'targets.npy') diff --git a/src/thought_decoder/evaluate.py b/src/thought_decoder/evaluate.py new file mode 100644 index 0000000..598c226 --- /dev/null +++ b/src/thought_decoder/evaluate.py @@ -0,0 +1,67 @@ +"""Evaluation script. + +Handles the evaluation loop for the `EEGThoughtDecoder` model. + +""" + +# pylint: disable=not-callable +# mypy: disable-error-code="assignment,no-untyped-call" + +import jax.numpy as jnp +from flax.training import checkpoints, train_state + +from thought_decoder.logging import logger +from thought_decoder.data import EEGDataLoader +from thought_decoder.models import EEGThoughtDecoder +from thought_decoder.models.utils import AgenticParams, GNNParams, MixtureOfExpertsParams, TransformerParams + + +def evaluate() -> None: + """Evaluate the model.""" + # Data loader. + data_loader = EEGDataLoader(data_dir='data/test/', batch_size=32, shuffle=False) + + # Model intialization. + transformer_params = TransformerParams( + num_layers=4, + model_dim=128, + num_heads=8, + diff=512, + input_vocab_size=1_000, + maximum_position_encoding=1_000, + dropout_rate=0.1, + ) + gnn_params = GNNParams( + num_layers=2, + hidden_dim=128, + adjacency_matrix=jnp.ones((64, 64)), # Placeholder adjacency matrix. + ) + moe_params = MixtureOfExpertsParams(num_experts=4, expert_output_dim=128) + agentic_params = AgenticParams(action_dim=10) + + model = EEGThoughtDecoder( + transformer_params=transformer_params, + gnn_params=gnn_params, + moe_params=moe_params, + agentic_params=agentic_params, + ) + + # Restore checkpoint. + params = checkpoints.restore_checkpoint('checkpoints/', target=None) + state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=None) + + total_correct, total_samples = 0, 0 + + for batch in data_loader.get_batches(): + logits = state.apply_fn({'params': state.params}, batch[0], training=False) + + predictions = jnp.argmax(logits, axis=-1) + total_correct += jnp.sum(predictions == batch[1]) + total_samples += batch[1].shape[0] + + accuracy = total_correct / total_samples + logger.info(f'Accuracy: {accuracy:.4f}') + + +if __name__ == '__main__': + evaluate() diff --git a/src/thought_decoder/logging.py b/src/thought_decoder/logging.py new file mode 100644 index 0000000..8d324d1 --- /dev/null +++ b/src/thought_decoder/logging.py @@ -0,0 +1,52 @@ +"""Logging utilities. + +This module provides a set of utilities for logging messages to the console. + +""" + +import logging + +from rich.console import Console +from rich.logging import RichHandler + + +def set_logger(name: str = 'thought_decoder', log_level: int = logging.INFO) -> logging.Logger: + """Set up a logger with a specified name and log level. + + Args: + name (str): The name of the logger. + log_level (int): The log level for this logger. + + Returns: + logging.Logger: The logger object. + + """ + # Create a logger with a specified name. + _logger = logging.getLogger(name) + + # Set the log level for this logger. + _logger.setLevel(log_level) + + # Create a RichHandler for console output. + console_handler = RichHandler( + console=Console(), + rich_tracebacks=True, + log_time_format='[%X]', + # tracebacks_show_locals=True, + keywords=[name], + ) + + # Create a formatter. + formatter = logging.Formatter('%(name)-8s %(message)s') + console_handler.setFormatter(formatter) + + _logger.addHandler(console_handler) + + # Prevent logging from propagating to the root logger. + _logger.propagate = False + + return _logger + + +# Global library logger. +logger: logging.Logger = set_logger() diff --git a/src/thought_decoder/models/__init__.py b/src/thought_decoder/models/__init__.py new file mode 100644 index 0000000..7e34b78 --- /dev/null +++ b/src/thought_decoder/models/__init__.py @@ -0,0 +1,19 @@ +from thought_decoder.models.agentic.policy import AgenticModel +from thought_decoder.models.transformer.encoder import TransformerEncoder +from thought_decoder.models.gnn.graph_nn import GNN +from thought_decoder.models.moe.mixture_of_experts import MixtureOfExperts +from thought_decoder.models.decoder import EEGThoughtDecoder +from thought_decoder.models.utils import AgenticParams, GNNParams, MixtureOfExpertsParams, TransformerParams + + +__all__ = [ + 'AgenticModel', + 'AgenticParams', + 'EEGThoughtDecoder', + 'GNN', + 'GNNParams', + 'MixtureOfExperts', + 'MixtureOfExpertsParams', + 'TransformerEncoder', + 'TransformerParams', +] diff --git a/src/thought_decoder/models/agentic/__init__.py b/src/thought_decoder/models/agentic/__init__.py new file mode 100644 index 0000000..aa136cb --- /dev/null +++ b/src/thought_decoder/models/agentic/__init__.py @@ -0,0 +1,4 @@ +from thought_decoder.models.agentic.policy import AgenticModel + + +__all__ = ['AgenticModel'] diff --git a/src/thought_decoder/models/agentic/policy.py b/src/thought_decoder/models/agentic/policy.py new file mode 100644 index 0000000..8682947 --- /dev/null +++ b/src/thought_decoder/models/agentic/policy.py @@ -0,0 +1,53 @@ +"""Agentic policy for the agent in the environment. + +The agentic model helps learning for dynamic adaptation using Reinforcement Learning (RL) algorithms. + +""" + +from jax import Array +from flax import linen as nn + + +class PolicyNetwork(nn.Module): + """Policy Network Module.""" + + action_dim: int + + def __call__(self, x: Array) -> Array: + """Apply the policy network. + + Args: + x (Array): The input tensor of shape (batch_size, input_dim). + + Returns: + Array: The output tensor of shape (batch_size, action_dim). + + """ + x = nn.Dense(features=128)(x) + x = nn.relu(x) + logits = nn.Dense(features=self.action_dim)(x) + return logits # actions probabilities. + + +class AgenticModel(nn.Module): + """Agentic Model Module.""" + + action_dim: int + + def setup(self) -> None: + """Set up the module.""" + self.policy_network = PolicyNetwork(action_dim=self.action_dim) + + def __call__(self, x: Array) -> Array: + """Apply the module. + + Args: + x (Array): The input tensor of shape (batch_size, input_dim). + + Returns: + Array: The output tensor of shape (batch_size, action_dim). + + """ + logits = self.policy_network(x) + action_probs = nn.softmax(logits) + return action_probs diff --git a/src/thought_decoder/models/decoder.py b/src/thought_decoder/models/decoder.py new file mode 100644 index 0000000..b0b230f --- /dev/null +++ b/src/thought_decoder/models/decoder.py @@ -0,0 +1,54 @@ +"""This module combines Transformer, GNN, Mixture of Experts and Agentic Model into one architecture.""" + +# pylint: disable=attribute-defined-outside-init +# from collections.abc import Mapping + +from jax import Array +from flax import linen as nn + +from thought_decoder.models import AgenticModel, GNN, MixtureOfExperts, TransformerEncoder +from thought_decoder.models.utils import AgenticParams, GNNParams, MixtureOfExpertsParams, TransformerParams + + +class EEGThoughtDecoder(nn.Module): + """EEG Thought Decoder Module.""" + + transformer_params: TransformerParams # Mapping[str, int | float] + gnn_params: GNNParams # Mapping[str, int | Array] + moe_params: MixtureOfExpertsParams # Mapping[str, int] + agentic_params: AgenticParams # Mapping[str, int] + + def setup(self) -> None: + """Set up the module.""" + self.transformer = TransformerEncoder(**self.transformer_params) + self.gnn = GNN(**self.gnn_params) + self.moe = MixtureOfExperts(**self.moe_params) + self.agent = AgenticModel(**self.agentic_params) + self.classifier = nn.Dense(features=self.agentic_params['action_dim']) + + def __call__(self, x: Array, training: bool) -> Array: + """Apply the module. + + Args: + x (Array): The input tensor of shape (batch_size, seq_len, input_dim). + training (bool): Whether the model is training or not. + + Returns: + Array: The output tensor of shape (batch_size, action_dim). + + """ + # Transformer Encoder for temporal dynamics. + x = self.transformer(x, training) + + # GNN for spatial relationships. + x = self.gnn(x) + + # Mixture of Experts for specialization. + x = self.moe(x) + + # Agentic Model for dynamic adaptation. + action_probs = self.agent(x) + + # Classifier. + logits = self.classifier(action_probs) + return logits diff --git a/src/thought_decoder/models/gnn/__init__.py b/src/thought_decoder/models/gnn/__init__.py new file mode 100644 index 0000000..9564ae9 --- /dev/null +++ b/src/thought_decoder/models/gnn/__init__.py @@ -0,0 +1,4 @@ +from thought_decoder.models.gnn.graph_nn import GraphConvolutionLayer, GNN + + +__all__ = ['GraphConvolutionLayer', 'GNN'] diff --git a/src/thought_decoder/models/gnn/graph_nn.py b/src/thought_decoder/models/gnn/graph_nn.py new file mode 100644 index 0000000..3fa1cfc --- /dev/null +++ b/src/thought_decoder/models/gnn/graph_nn.py @@ -0,0 +1,98 @@ +"""Graph Neural Network (GNN) model for capturing spatial relationships in EEG data.""" + +# pylint: disable=attribute-defined-outside-init +import jax.numpy as jnp +from jax import Array +from flax import linen as nn + + +class GraphConvolutionLayer(nn.Module): + """Graph Convolution Layer.""" + + output_dim: int + adjacency_matrix: Array + + def __call__(self, x: Array) -> Array: + """Apply Graph convolution. + + Args: + x (Array): The input tensor of shape (batch_size, num_nodes, input_dim). + + Returns: + Array: The output tensor of shape (batch_size, num_nodes, output_dim). + + """ + # Compute the normalized Laplacian matrix. + laplacian = self._compute_normalized_laplacian(self.adjacency_matrix) + + # Compute the graph convolution. + return jnp.matmul(laplacian, x) @ self.param( + 'weights', nn.initializers.xavier_uniform(), (x.shape[-1], self.output_dim) + ) + + def _compute_normalized_laplacian(self, adjacency_matrix: Array) -> Array: + """Compute the normalized Laplacian matrix. + + Args: + adjacency_matrix (Array): The adjacency matrix of shape (num_nodes, num_nodes). + + Returns: + Array: The normalized Laplacian matrix. + + """ + # Compute the degree matrix. + # degree_matrix = jnp.sum(adjacency_matrix, axis=0) + # degree_matrix = jnp.where(degree_matrix > 0, 1.0 / jnp.sqrt(degree_matrix), 0.0) + # + # # Compute the normalized Laplacian matrix. + # normalized_laplacian = jnp.eye(adjacency_matrix.shape[0]) - degree_matrix * adjacency_matrix * degree_matrix + # + # return normalized_laplacian + + # Compute the degree matrix. + d = jnp.sum(adjacency_matrix, axis=-1) + + # Compute the inverse square root of the degree matrix. + d_inv_sqrt = jnp.power(d, -0.5) + d_inv_sqrt = jnp.diag(d_inv_sqrt) + + # Compute the normalized Laplacian matrix. + laplacian = jnp.eye(adjacency_matrix.shape[0]) - d_inv_sqrt @ adjacency_matrix @ d_inv_sqrt + + # Return the normalized Laplacian matrix. + return laplacian + + +class GNN(nn.Module): + """Graph Neural Network (GNN) model.""" + + num_layers: int + hidden_dim: int + adjacency_matrix: Array + + def setup(self) -> None: + """Set up the module.""" + # Initialize the Graph Convolution Layers. + self.gcn_layers = [ + GraphConvolutionLayer(output_dim=self.hidden_dim, adjacency_matrix=self.adjacency_matrix) + for _ in range(self.num_layers) + ] + self.relu = nn.relu + + def __call__(self, x: Array) -> Array: + """Apply the GNN model. + + Args: + x (Array): The input tensor of shape (batch_size, num_nodes, input_dim). + + Returns: + Array: The output tensor of shape (batch_size, num_nodes, hidden_dim). + + """ + # Apply the Graph Convolution Layers. + for gcn_layer in self.gcn_layers: + x = gcn_layer(x) + x = self.relu(x) + + # Return the output tensor. + return x diff --git a/src/thought_decoder/models/moe/__init__.py b/src/thought_decoder/models/moe/__init__.py new file mode 100644 index 0000000..55b32f8 --- /dev/null +++ b/src/thought_decoder/models/moe/__init__.py @@ -0,0 +1,4 @@ +from thought_decoder.models.moe.mixture_of_experts import MixtureOfExperts + + +__all__ = ['MixtureOfExperts'] diff --git a/src/thought_decoder/models/moe/mixture_of_experts.py b/src/thought_decoder/models/moe/mixture_of_experts.py new file mode 100644 index 0000000..49de4ce --- /dev/null +++ b/src/thought_decoder/models/moe/mixture_of_experts.py @@ -0,0 +1,74 @@ +"""Mixture of Expert model. + +This module implements the Mixture of Experts model for specialization. + +""" + +# pylint: disable=attribute-defined-outside-init +import jax.numpy as jnp +from jax import Array +from flax import linen as nn + + +class Expert(nn.Module): + """Expert Module.""" + + output_dim: int + + def __call__(self, x: Array) -> Array: + """Apply the expert network. + + Args: + x (Array): The input tensor of shape (batch_size, seq_len, model_dim). + + Returns: + Array: The output tensor of shape (batch_size, seq_len, output_dim). + + """ + x = nn.Dense(features=self.output_dim)(x) + x = nn.relu(x) + return x + + +class MixtureOfExperts(nn.Module): + """Mixture of Experts Module.""" + + num_experts: int + expert_output_dim: int + + def setup(self) -> None: + """Set up the module.""" + self.experts = [Expert(output_dim=self.expert_output_dim) for _ in range(self.num_experts)] + self.gating_network = nn.Dense(features=self.num_experts) + + def __call__(self, x: Array) -> Array: + """Apply the module. + + Args: + x (Array): The input tensor of shape (batch_size, seq_len, model_dim). + + Returns: + Array: The output tensor of shape (batch_size, seq_len, model_dim). + + """ + # # Compute the gate values. + # gate_values = nn.softmax(self.gating_network(x), axis=-1) + # + # # Compute the expert outputs. + # expert_outputs = [expert(x) for expert in self.experts] + # + # # Compute the mixture of expert output. + # mixture_of_experts_output = sum( + # gate_values[:, :, i, None] * expert_output for i, expert_output in enumerate(expert_outputs) + # ) + # + # return mixture_of_experts_output + gating_logits = self.gating_network(x) + gating_weights = nn.softmax(gating_logits, axis=-1) # Shape: (batch_size, num_experts) + + # Shape: (batch_size, num_experts, seq_len, expert_output_dim) + expert_outputs = jnp.stack([expert(x) for expert in self.experts], axis=1) + + gated_output = jnp.einsum('be,bed->bd', gating_weights, expert_outputs) + + return gated_output diff --git a/src/thought_decoder/models/transformer/__init__.py b/src/thought_decoder/models/transformer/__init__.py new file mode 100644 index 0000000..a67fdcd --- /dev/null +++ b/src/thought_decoder/models/transformer/__init__.py @@ -0,0 +1,6 @@ +from thought_decoder.models.transformer.attention import MultiHeadSelfAttention +from thought_decoder.models.transformer.pos_encoding import PositionalEncoding +from thought_decoder.models.transformer.encoder import TransformerEncoderLayer, TransformerEncoder + + +__all__ = ['MultiHeadSelfAttention', 'PositionalEncoding', 'TransformerEncoderLayer', 'TransformerEncoder'] diff --git a/src/thought_decoder/models/transformer/attention.py b/src/thought_decoder/models/transformer/attention.py new file mode 100644 index 0000000..9d87b71 --- /dev/null +++ b/src/thought_decoder/models/transformer/attention.py @@ -0,0 +1,86 @@ +"""Attention mechanism for the Transformer model.""" + +# pylint: disable=attribute-defined-outside-init +import jax.numpy as jnp +from jax import Array +from flax import linen as nn + + +class MultiHeadSelfAttention(nn.Module): + """Multi-Head Self Attention Mechanism.""" + + model_dim: int + num_heads: int + + def setup(self) -> None: + """Set up the module.""" + assert self.model_dim % self.num_heads == 0, 'model_dim must be divisible by num_heads' + + self.depth = self.model_dim // self.num_heads + + self.wq = nn.Dense(features=self.model_dim) # Query + self.wk = nn.Dense(features=self.model_dim) # Key + self.wv = nn.Dense(features=self.model_dim) # Value + self.dense = nn.Dense(features=self.model_dim) # Final dense layer + + def __call__(self, x: Array) -> Array: + """Apply the module. + + Args: + x (Array): The input tensor of shape (batch_size, seq_len, model_dim). + + Returns: + Array: The output tensor of shape (batch_size, seq_len, model_dim). + + """ + batch_size = x.shape[0] + + q = self.wq(x) + k = self.wk(x) + v = self.wv(x) + + scaled_attention = self._scaled_dot_product_attention(q, k, v) + scaled_attention = scaled_attention.transpose(0, 2, 1, 3).reshape(batch_size, -1, self.model_dim) + + return self.dense(scaled_attention) + + def _split_heads(self, x: Array, batch_size: int) -> Array: + """Split the heads. + + Args: + x (Array): The input tensor of shape (batch_size, seq_len, model_dim). + batch_size (int): The batch size. + + Returns: + Array: The reshaped tensor of shape (batch_size, num_heads, seq_len, depth). + + """ + return x.reshape(batch_size, -1, self.num_heads, self.depth).transpose(0, 2, 1, 3) + + def _scaled_dot_product_attention(self, q: Array, k: Array, v: Array) -> Array: + """Scaled Dot Product Attention. + + Args: + q (Array): The query tensor of shape (batch_size, num_heads, seq_len_q, depth). + k (Array): The key tensor of shape (batch_size, num_heads, seq_len_k, depth). + v (Array): The value tensor of shape (batch_size, num_heads, seq_len_v, depth). + + Returns: + Array: The output tensor of shape (batch_size, num_heads, seq_len_q, depth). + + """ + # Compute the dot product. + # matmul_qk = jnp.matmul(q, k, transpose_b=True) + matmul_qk = jnp.matmul(q, k.transpose(0, 1, 3, 2)) + + # Scale the dot product. + dk = jnp.array(k.shape[-1], dtype=jnp.float32) + scaled_attention_logits = matmul_qk / jnp.sqrt(dk) + + # Softmax on the last axis. + attention_weights = nn.softmax(scaled_attention_logits, axis=-1) + + # Compute the output. + output = jnp.matmul(attention_weights, v) + + return output diff --git a/src/thought_decoder/models/transformer/encoder.py b/src/thought_decoder/models/transformer/encoder.py new file mode 100644 index 0000000..e956c44 --- /dev/null +++ b/src/thought_decoder/models/transformer/encoder.py @@ -0,0 +1,109 @@ +"""Transofrmers encoder layer & model.""" + +# pylint: disable=attribute-defined-outside-init +from jax import Array +from flax import linen as nn + +from thought_decoder.models.transformer.attention import MultiHeadSelfAttention +from thought_decoder.models.transformer.pos_encoding import PositionalEncoding + + +class TransformerEncoderLayer(nn.Module): + """Transformer Encoder Layer.""" + + model_dim: int + num_heads: int + diff: int + dropout_rate: float = 0.1 + + def setup(self) -> None: + """Set up the module.""" + # Multi-Head Self Attention and Feed Forward Network. + self.mha = MultiHeadSelfAttention(model_dim=self.model_dim, num_heads=self.num_heads) + self.ffn = nn.Sequential([nn.Dense(features=self.diff), nn.relu, nn.Dense(features=self.model_dim)]) + + # Layer Normalization and Dropout. + self.layernorm1 = nn.LayerNorm() + self.layernorm2 = nn.LayerNorm() + self.dropout = nn.Dropout(self.dropout_rate) + + def __call__(self, x: Array, training: bool) -> Array: + """Apply the module. + + Args: + x (Array): The input tensor of shape (batch_size, seq_len, model_dim). + training (bool): Whether the model is training or not. + + Returns: + Array: The output tensor of shape (batch_size, seq_len, model_dim). + + """ + # Multi-Head Self Attention + attn_output = self.mha(x) + attn_output = self.dropout(attn_output, deterministic=not training) + out1 = self.layernorm1(x + attn_output) + + # Feed Forward Network. + ffn_output = self.ffn(out1) + ffn_output = self.dropout(ffn_output, deterministic=not training) + out2: Array = self.layernorm2(out1 + ffn_output) + + # Return the output. + return out2 + + +class TransformerEncoder(nn.Module): + """Transformer Encoder with multiple Transformer Encoder Layers.""" + + num_layers: int + model_dim: int + num_heads: int + diff: int + input_vocab_size: int + maximum_position_encoding: int + dropout_rate: float = 0.1 + + def setup(self) -> None: + """Set up the module.""" + # Initialize the embedding & positional encoding layer. + self.embedding = nn.Embed(num_embeddings=self.input_vocab_size, features=self.model_dim) + self.pos_encoding = PositionalEncoding(model_dim=self.model_dim) + + # Intialize the Transformer Encoder Layers. + self.encoder_layers = [ + TransformerEncoderLayer( + model_dim=self.model_dim, num_heads=self.num_heads, diff=self.diff, dropout_rate=self.droptout_rate + ) + for _ in range(self.num_layers) + ] + + # Dropout layer. + self.dropout = nn.Dropout(self.dropout_rate) + + def __call__(self, x: Array, training: bool) -> Array: + """Apply the module. + + Args: + x (Array): The input tensor of shape (batch_size, seq_len). + training (bool): Whether the model is training or not. + + Returns: + Array: The output tensor of shape (batch_size, seq_len, model_dim). + + """ + # Get the batch size and sequence length. + # batch_size, seq_len = x.shape + + # Add positional encoding to the input tensor. + x = self.embedding(x) + x = self.pos_encoding(x) + + # Apply the dropout layer. + x = self.dropout(x, deterministic=not training) + + # Apply the Transformer Encoder Layers. + for encoder_layer in self.encoder_layers: + x = encoder_layer(x, training) + + # Return the output tensor [batch_size, seq_len, model_dim] + return x diff --git a/src/thought_decoder/models/transformer/pos_encoding.py b/src/thought_decoder/models/transformer/pos_encoding.py new file mode 100644 index 0000000..82eda57 --- /dev/null +++ b/src/thought_decoder/models/transformer/pos_encoding.py @@ -0,0 +1,38 @@ +"""Positional encoding layer for Transformer model.""" + +import jax.numpy as jnp +from jax import Array +from flax import linen as nn + + +class PositionalEncoding(nn.Module): + """Positional Encoding Layer.""" + + model_dim: int + + def __call__(self, x: Array) -> Array: + """Add positional encoding to the input tensor. + + Args: + x (Array): The input tensor of shape (batch_size, seq_len, model_dim). + + Returns: + Array: The tensor with positional encoding added. + + """ + # pos_enc = jnp.arange(x.shape[1])[:, jnp.newaxis] / jnp.power( + # 10000, 2 * jnp.arange(self.model_dim)[jnp.newaxis, :] / self.model_dim + # ) + # pos_enc = jnp.where(jnp.arange(self.model_dim) % 2 == 0, jnp.sin(pos_enc), jnp.cos(pos_enc)) + # return x + pos_enc + seq_len = x.shape[1] + + position = jnp.arange(seq_len)[:, jnp.newaxis] + div_term = jnp.exp(jnp.arange(0, self.model_dim, 2) * -(jnp.log(10000.0) / self.model_dim)) + + pe = jnp.zeros((seq_len, self.model_dim)) + pe = pe.at[:, 0::2].set(jnp.sin(position * div_term)) + pe = pe.at[:, 1::2].set(jnp.cos(position * div_term)) + pe = pe[jnp.newaxis, ...] + + return x + pe diff --git a/src/thought_decoder/models/utils.py b/src/thought_decoder/models/utils.py new file mode 100644 index 0000000..aa2d05a --- /dev/null +++ b/src/thought_decoder/models/utils.py @@ -0,0 +1,37 @@ +"""Utility functions for the project.""" + +from typing import TypedDict +from jax import Array + + +class TransformerParams(TypedDict, total=False): + """Hyperparameters for the Transformer model.""" + + num_layers: int + model_dim: int + num_heads: int + diff: int + input_vocab_size: int + maximum_position_encoding: int + dropout_rate: float + + +class GNNParams(TypedDict): + """Hyperparameters for the GNN model.""" + + num_layers: int + hidden_dim: int + adjacency_matrix: Array + + +class MixtureOfExpertsParams(TypedDict): + """Hyperparameters for the Mixture of Experts model.""" + + num_experts: int + expert_output_dim: int + + +class AgenticParams(TypedDict): + """Hyperparameters for the Agentic model.""" + + action_dim: int diff --git a/src/thought_decoder/train.py b/src/thought_decoder/train.py new file mode 100644 index 0000000..ffdce68 --- /dev/null +++ b/src/thought_decoder/train.py @@ -0,0 +1,110 @@ +"""Training script. + +Handles the training loop for the `EEGThoughtDecoder` model. + +""" + +# mypy: disable-error-code="no-untyped-call" +import time +from collections.abc import Mapping +from typing import Any + +import jax +import jax.numpy as jnp +import optax +from flax import linen as nn +from flax.training import checkpoints, train_state + +from thought_decoder.logging import logger +from thought_decoder.types import KeyArray, Array, InputShape +from thought_decoder.data import EEGDataLoader +from thought_decoder.models import EEGThoughtDecoder +from thought_decoder.models.utils import AgenticParams, GNNParams, MixtureOfExpertsParams, TransformerParams + + +def create_train_state( + rng: KeyArray, model: nn.Module, learning_rate: float, input_shape: InputShape +) -> train_state.TrainState: + """Create the initial training state.""" + params = model.init(rng, jnp.ones(input_shape), training=True) + optimizer = optax.adam(learning_rate) + state: train_state.TrainState = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer) + return state + + +@jax.jit +def train_step(state: train_state.TrainState, batch: tuple[Array, Array]) -> tuple[train_state.TrainState, float]: + """Train for a single step. + + Args: + state (train_state.TrainState): The training state. + batch (tuple[Array, Array]): The batch of inputs and targets. + + Returns: + tuple[train_state.TrainState, float]: The updated training state and the loss. + + """ + + def loss_fn(params: Mapping[str, Any]) -> tuple[Array, float]: + logits = state.apply_fn({'params': params}, batch[0], training=True) + loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=batch[1]).mean() + return loss, logits + + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (loss, _), grad = grad_fn(state.params) + state = state.apply_gradients(grads=grad) + return state, loss + + +def train() -> None: + """Training loop.""" + # Hyperparameters. + num_epochs = 10 + batch_size = 32 + learning_rate = 1e-3 + input_shape = (-1, 100, 64) # placeholder input shape. + + # Initialize the data loader. + data_loader = EEGDataLoader(data_dir='data/', batch_size=batch_size) + + # Model intialization. + rng = jax.random.PRNGKey(0) + transformer_params = TransformerParams( + num_layers=4, + model_dim=128, + num_heads=8, + diff=512, + input_vocab_size=1_000, + maximum_position_encoding=1_000, + dropout_rate=0.1, + ) + gnn_params = GNNParams( + num_layers=2, + hidden_dim=128, + adjacency_matrix=jnp.ones((64, 64)), # Placeholder adjacency matrix. + ) + moe_params = MixtureOfExpertsParams(num_experts=4, expert_output_dim=128) + agentic_params = AgenticParams(action_dim=10) + + model = EEGThoughtDecoder( + transformer_params=transformer_params, + gnn_params=gnn_params, + moe_params=moe_params, + agentic_params=agentic_params, + ) + + state = create_train_state(rng, model, learning_rate, input_shape) + + for epoch in range(num_epochs): + start_time = time.time() + for batch in data_loader.get_batches(): + state, loss = train_step(state, batch) + epoch_time = time.time() - start_time + logger.info(f'Epoch: {epoch + 1}, Loss: {loss:.4f}, Time: {epoch_time:.2f} sec.') + + # Save checkpoints. + checkpoints.save_checkpoint(ckpt_dir='checkpoints/', target=state.params, step=epoch) + + +if __name__ == '__main__': + train() diff --git a/src/thought_decoder/types.py b/src/thought_decoder/types.py new file mode 100644 index 0000000..a6774a5 --- /dev/null +++ b/src/thought_decoder/types.py @@ -0,0 +1,11 @@ +"""Type aliases for JAX arrays and batches.""" + +from typing import Annotated, TypeAlias + +from jax import Array + + +KeyArray: TypeAlias = Annotated[Array, 'PRNGKey'] +Vector: TypeAlias = Annotated[Array, 'Vector'] +Batch: TypeAlias = tuple[Array, Array] +InputShape: TypeAlias = Annotated[tuple[int, int, int], 'batch_size, seq_len, input_dim']