diff --git a/README.md b/README.md index 367cfb1..c78acfa 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,24 @@ # STAC: Spiking Transformer for Conversational AI [![DOI](https://zenodo.org/badge/907152074.svg)](https://doi.org/10.5281/zenodo.14545340) +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) ## Overview -STAC (Spiking Transformer Augmenting Cognition) converts pretrained transformer LLMs (e.g., DistilGPT-2, SmolLM2-1.7B-Instruct) into energy-efficient Spiking Neural Networks (SNNs) **while preserving coherent multi-turn conversational ability**. +STAC (Spiking Transformer Augmenting Cognition) is a research framework with two distinct approaches: + +- **STAC V1**: Complete end-to-end training pipeline with learnable AdEx neurons (see `stac-v1/`) +- **STAC V2**: Experimental conversion framework that transforms pretrained transformer LLMs (DistilGPT-2, SmolLM2-1.7B-Instruct) into Spiking Neural Networks (SNNs) for *potential* energy savings **while retaining multi-turn conversational ability in simulation** + +> ⚠️ **Important**: This repository currently runs *software-level* SNN simulations only. No metrics have been collected on physical neuromorphic hardware yet. Energy savings figures are theoretical projections based on spike-count analysis, not measured hardware data. 
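The spike-count projection behind those energy figures can be made concrete with a back-of-envelope sketch. Everything below is illustrative: the per-operation energy constants are 45 nm CMOS estimates commonly quoted in the SNN literature, and the function name is an assumption, not part of this repository:

```python
# Illustrative projection: accumulate-only SNN ops vs. multiply-accumulate ANN ops.
# Energy constants are commonly cited 45 nm CMOS estimates, not measured hardware data.
E_MAC = 4.6e-12  # joules per multiply-accumulate (ANN)
E_AC = 0.9e-12   # joules per accumulate (one SNN synaptic event)

def projected_energy_ratio(ann_macs: int, snn_spikes: int, fan_out: int) -> float:
    """Return the projected ANN/SNN energy ratio from operation counts alone."""
    ann_energy = ann_macs * E_MAC
    snn_energy = snn_spikes * fan_out * E_AC  # each spike triggers fan_out accumulates
    return ann_energy / snn_energy

# Example: 1e9 MACs against 5e7 spikes with average fan-out 4
ratio = projected_energy_ratio(10**9, 5 * 10**7, 4)  # ≈ 25.6
```

With these made-up operation counts the projection works out to roughly 26×; on physical neuromorphic hardware, memory traffic and idle power typically erode such op-count ratios, which is why the README treats them as projections only.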
## Key Features -✅ **End-to-end ANN→SNN conversion** with SpikingJelly integration -✅ **Multi-turn conversation support** with KV-cache and position ID handling -✅ **Comprehensive test suite** validating coherence, energy, and compatibility -✅ **Production-ready pipeline** with TorchScript export capabilities -✅ **Energy efficiency** targeting 3-4× reduction in power consumption +✔️ **Proof-of-concept ANN→SNN conversion** using SpikingJelly +✔️ **Multi-turn context retention** via a Temporal Spike Processor +✔️ **Extensive software tests** for position IDs, KV-cache, and spike-rate sanity +➖ **Hardware power profiling** — *planned, not implemented* +➖ **Full operator coverage & optimisation** — *work in progress* ## Quick Start @@ -27,11 +33,12 @@ python run_conversion.py --model_name distilgpt2 --timesteps 8 --simplified python snn_multi_turn_conversation_test.py --mode snn --turns 3 --timesteps 8 # 4. Run comprehensive validation -python test_conversational_snn.py --test_all --timesteps 8 +python test_conversational_snn.py --model_name distilgpt2 --test_all --timesteps 8 ``` ## Core Components +### STAC V2 (Current) | Component | Purpose | |-----------|---------| | `smollm2_converter.py` | Specialized converter with `TemporalSpikeProcessor` | @@ -41,34 +48,58 @@ python test_conversational_snn.py --test_all --timesteps 8 | `test_conversational_snn.py` | Comprehensive test suite (1K+ lines) | | `snn_multi_turn_conversation_test.py` | Simple conversation smoke test | -## Implementation Status +### STAC V1 (Original Research) +| Component | Purpose | +|-----------|---------| +| `stac-v1/stacv1.ipynb` | Complete end-to-end training pipeline with learnable AdEx neurons | +| `stac-v1/README.md` | V1 documentation and research contributions | -All **Phase 1-4** objectives are complete: +## Implementation Status -- ✅ **Core Infrastructure**: SpikingJelly integration, GELU→ReLU, quantization -- ✅ **Temporal Dynamics**: Stateful LIF neurons, timestep calibration -- ✅ 
**Conversation Context**: Position IDs, KV-cache, attention masks -- ✅ **Production Readiness**: TorchScript export, energy benchmarking +### STAC V2 (Current) +**Completed (prototype level)** +- ✅ Core conversion flow (GELU→ReLU, quantization, ann2snn) +- ✅ Temporal dynamics & KV-cache handling in PyTorch +- ✅ Spike-count telemetry hooks and unit tests + +**Pending / In Progress** +- ⏳ Hardware benchmarking on Loihi-2 / Akida +- ⏳ Expanded operator support (e.g., rotary embeddings, flash-attention variants) +- ⏳ Integration with SCANUE multi-agent alignment layer +- ⏳ Robust CLI/UX and documentation polish + +### STAC V1 (Complete) +**Completed (research prototype)** +- ✅ End-to-end training pipeline with learnable AdEx neurons +- ✅ Hyperdimensional Memory Module (HEMM) integration +- ✅ Surrogate gradient training on WikiText-2 +- ✅ L1 spike regularization for energy efficiency +- ✅ Comprehensive validation suite ## Documentation +### STAC V2 (Current) - 🔄 [Conversion Workflow](docs/conversion_workflow.md) - Step-by-step conversion guide - 📚 [API Reference](docs/api_reference.md) - Function and class documentation - 🖥️ [Hardware Requirements](docs/hardware_requirements.md) - System specifications +### STAC V1 (Original Research) +- 📖 [STAC V1 Documentation](stac-v1/README.md) - End-to-end training pipeline documentation +- 🧠 [STAC V1 Implementation](stac-v1/stacv1.ipynb) - Complete Jupyter notebook with learnable AdEx neurons + ## Testing & Validation The repository includes extensive testing for multi-turn conversational correctness: ```bash # Test specific components -python test_conversational_snn.py --test_position_boundaries -python test_conversational_snn.py --test_attention_mask -python test_conversational_snn.py --test_multi_turn -python test_conversational_snn.py --test_energy +python test_conversational_snn.py --model_name distilgpt2 --test_position_boundaries +python test_conversational_snn.py --model_name distilgpt2 --test_attention_mask +python 
test_conversational_snn.py --model_name distilgpt2 --test_multi_turn +python test_conversational_snn.py --model_name distilgpt2 --test_energy # Run all tests -python test_conversational_snn.py --test_all +python test_conversational_snn.py --model_name distilgpt2 --test_all ``` ## License diff --git a/convert.py b/convert.py index 4b1fe33..a1aecea 100644 --- a/convert.py +++ b/convert.py @@ -11,10 +11,13 @@ import torch from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig import json -import spikingjelly +import spikingjelly # type: ignore from packaging.version import parse import importlib.metadata import logging +from typing import Dict, Any, List, Tuple, Optional, Union +import os +from tqdm import tqdm # Configure logging logging.basicConfig( @@ -27,18 +30,7 @@ ) logger = logging.getLogger("generic_converter") -# Direct SpikingJelly imports -from spikingjelly_compat import get_quantizer -Quantizer = get_quantizer() -from spikingjelly.activation_based.conversion import Converter -from spikingjelly.activation_based.layer import LayerNorm as SpikeLN -# Assuming SpikeAttention is from 'layer'. If it's 'ann2snn' in SJ 0.0.0.14, this might need adjustment. 
-from spikingjelly.activation_based.layer import SpikeAttention -from spikingjelly.activation_based import surrogate - -import os -from tqdm import tqdm - +# Check SpikingJelly version min_version = '0.0.0.0.14' current_version = importlib.metadata.version('spikingjelly') if parse(current_version) < parse(min_version): @@ -47,11 +39,48 @@ f'Please upgrade SpikingJelly: pip install "spikingjelly[cuda]>=0.0.0.0.14" --pre' ) -def parse_args(): +# Direct SpikingJelly imports +from spikingjelly_compat import get_quantizer +Quantizer = get_quantizer() + +# Import SpikingJelly components with error handling +try: + from spikingjelly.activation_based.ann2snn import Converter # type: ignore +except ImportError: + logger.warning("Could not import Converter from spikingjelly.activation_based.ann2snn") + Converter = None + +try: + from spikingjelly.activation_based.layer import LayerNorm as SpikeLN # type: ignore +except ImportError: + logger.warning("Could not import LayerNorm from spikingjelly.activation_based.layer") + SpikeLN = None + +try: + from spikingjelly.activation_based.layer import SpikeAttention # type: ignore +except ImportError: + logger.warning("Could not import SpikeAttention from spikingjelly.activation_based.layer") + SpikeAttention = None + +try: + from spikingjelly.activation_based import surrogate # type: ignore +except ImportError: + logger.warning("Could not import surrogate from spikingjelly.activation_based") + surrogate = None + +# NOTE: ------------------------------------------------------------------- +# This conversion script is **experimental** and provided as a research +# prototype only. It has been validated in software simulation but has not +# yet been profiled on real neuromorphic hardware. Energy-saving figures in +# the README and paper are projections based on spike-count telemetry, not +# measured watt-hour data. Use at your own risk. 
+# --------------------------------------------------------------------------- + +def parse_args() -> argparse.Namespace: """Parse command line arguments.""" parser = argparse.ArgumentParser(description='Convert LLM to SNN') - parser.add_argument('--model_name', type=str, default='gpt2-medium', - help='The model to convert (default: gpt2-medium)') + parser.add_argument('--model_name', type=str, default='distilgpt2', + help='The model to convert (default: distilgpt2). Supported: distilgpt2, SmolLM2-1.7B-Instruct') parser.add_argument('--output_dir', type=str, default='./snn_model', help='Directory to save the converted model') parser.add_argument('--num_samples', type=int, default=10, @@ -68,7 +97,7 @@ def parse_args(): help='Use simplified conversion approach without relying on complex SpikingJelly features') return parser.parse_args() -def create_calibration_data(tokenizer, num_samples=10, max_length=128): +def create_calibration_data(tokenizer: AutoTokenizer, num_samples: int = 10, max_length: int = 128) -> Dict[str, torch.Tensor]: """Create simple calibration data for SNN conversion.""" logger.info(f"Creating {num_samples} calibration samples...") prompts = [ @@ -104,7 +133,7 @@ def create_calibration_data(tokenizer, num_samples=10, max_length=128): return inputs -def convert_model_to_spiking(model, calibration_data, timesteps=64, device='cpu'): +def convert_model_to_spiking(model: torch.nn.Module, calibration_data: Dict[str, torch.Tensor], timesteps: int = 64, device: str = 'cpu') -> torch.nn.Module: """Convert model to SNN using SpikeZIP-TF method.""" logger.info("Running SpikeZIP-TF conversion...") @@ -117,12 +146,15 @@ def convert_model_to_spiking(model, calibration_data, timesteps=64, device='cpu' # Step 2: Insert quantizers for 8-bit precision logger.info("Inserting 8-bit quantizers...") - quantizer = Quantizer(n_bits_w=8, n_bits_a=8) - model = quantizer(model) + if Quantizer is not None: + quantizer = Quantizer(n_bits_w=8, n_bits_a=8) + model = 
quantizer(model) + else: + logger.warning("Quantizer not available, skipping quantization step") # Step 3: Prepare calibration dataloader format logger.info("Preparing calibration data...") - calib_data_list = [] + calib_data_list: List[Tuple[Dict[str, torch.Tensor], None]] = [] # Create a simple dataloader-like structure (data, target) # Since our calibration data only needs input_ids and attention_mask @@ -135,19 +167,15 @@ def convert_model_to_spiking(model, calibration_data, timesteps=64, device='cpu' calib_data_list.append((sample, None)) # Check if Converter is available + if Converter is None: + logger.error("Converter not available, falling back to simplified conversion") + return simplified_conversion(model, timesteps) + try: # Step 4: SpikeZIP-TF conversion logger.info(f"Converting to SNN with {timesteps} timesteps...") - # Import converter here to ensure it's defined - try: - from spikingjelly.activation_based.ann2snn import Converter as SpikeConverter - except ImportError: - logger.error("Error: Converter not found in spikingjelly.activation_based.ann2snn") - logger.info("Falling back to simplified conversion") - return simplified_conversion(model, timesteps) - - snn_converter = SpikeConverter( + snn_converter = Converter( mode="max", dataloader=calib_data_list, T=timesteps, @@ -160,36 +188,38 @@ def convert_model_to_spiking(model, calibration_data, timesteps=64, device='cpu' snn_model = snn_converter(model) # Step 5: Replace non-spiking operations with spike-compatible versions - logger.info("Replacing LayerNorm with spike-compatible version...") - for name, module in snn_model.named_modules(): - if isinstance(module, torch.nn.LayerNorm): - parent_name = ".".join(name.split(".")[:-1]) - child_name = name.split(".")[-1] - - if parent_name: - parent = snn_model.get_submodule(parent_name) - setattr(parent, child_name, SpikeLN(module.normalized_shape)) - else: - setattr(snn_model, child_name, SpikeLN(module.normalized_shape)) + if SpikeLN is not None: + 
logger.info("Replacing LayerNorm with spike-compatible version...") + for name, module in snn_model.named_modules(): + if isinstance(module, torch.nn.LayerNorm): + parent_name = ".".join(name.split(".")[:-1]) + child_name = name.split(".")[-1] + + if parent_name: + parent = snn_model.get_submodule(parent_name) + setattr(parent, child_name, SpikeLN(module.normalized_shape)) + else: + setattr(snn_model, child_name, SpikeLN(module.normalized_shape)) # Step 6: Replace self-attention with spike-compatible version # Note: This is model-dependent, so we need to adapt to the model architecture - logger.info("Checking model for attention blocks to convert...") - if hasattr(snn_model, 'transformer') and hasattr(snn_model.transformer, 'h'): - logger.info("Converting attention blocks to SpikeAttention...") - for block in snn_model.transformer.h: - if hasattr(block, 'attn'): - # GPT-2 style architecture - hidden_size = snn_model.config.hidden_size - num_heads = snn_model.config.num_attention_heads - - block.attn = SpikeAttention( - embed_dim=hidden_size, - num_heads=num_heads, - T=timesteps, - causal=True # Enforce autoregressive masking - ) - logger.info(f"Replaced attention with SpikeAttention ({num_heads} heads)") + if SpikeAttention is not None: + logger.info("Checking model for attention blocks to convert...") + if hasattr(snn_model, 'transformer') and hasattr(snn_model.transformer, 'h'): + logger.info("Converting attention blocks to SpikeAttention...") + for block in snn_model.transformer.h: + if hasattr(block, 'attn') and hasattr(snn_model, 'config'): + # GPT-2 style architecture + hidden_size = snn_model.config.hidden_size + num_heads = snn_model.config.num_attention_heads + + block.attn = SpikeAttention( + embed_dim=hidden_size, + num_heads=num_heads, + T=timesteps, + causal=True # Enforce autoregressive masking + ) + logger.info(f"Replaced attention with SpikeAttention ({num_heads} heads)") logger.info("SNN conversion complete!") return snn_model @@ -204,7 +234,7 @@ 
def convert_model_to_spiking(model, calibration_data, timesteps=64, device='cpu' logger.info("Falling back to simplified conversion...") return simplified_conversion(model, timesteps) -def simplified_conversion(model, timesteps=64): +def simplified_conversion(model: torch.nn.Module, timesteps: int = 64) -> torch.nn.Module: """ Perform a simplified conversion to SNN without relying on advanced SpikingJelly features. This is a fallback when full conversion can't be performed due to compatibility issues. @@ -219,7 +249,7 @@ def simplified_conversion(model, timesteps=64): logger.info("Replaced GELU with ReLU") # 2. Add SNN-specific attributes - model.T = timesteps # Store timesteps in the model + setattr(model, 'T', timesteps) # Store timesteps in the model # 3. Implement exact threshold matching for SpikeZIP-TF equivalence logger.info("Implementing exact threshold matching...") @@ -254,7 +284,7 @@ def simplified_conversion(model, timesteps=64): def snn_forward(self, *args, **kwargs): """Wrapped forward method to simulate SNN behavior.""" # Extract the timesteps parameter if provided - T = kwargs.pop('T', self.T) if hasattr(self, 'T') else timesteps + T = kwargs.pop('T', getattr(self, 'T', timesteps)) # Call the original forward method outputs = self._original_forward(*args, **kwargs) @@ -267,12 +297,13 @@ def snn_forward(self, *args, **kwargs): return outputs # Apply the wrapped forward method - model.forward = snn_forward.__get__(model) + import types + model.forward = types.MethodType(snn_forward, model) logger.info("Applied simplified SNN conversion") return model -def save_snn_model(model, path): +def save_snn_model(model: torch.nn.Module, path: str) -> bool: """ Save the SNN model in a way that's easier to load later. Instead of saving the entire model object, we save the state_dict and metadata separately. 
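The forward-wrapping step above (rebinding `snn_forward` onto the model instance with `types.MethodType`) can be isolated into a minimal, repository-independent sketch; the `Model` class and its return values here are invented for illustration:

```python
import types

class Model:
    """Stand-in for the converted network; only the binding pattern matters."""
    def forward(self, x):
        return x + 1

model = Model()
model._original_forward = model.forward  # keep a handle to the originally bound method
model.T = 8

def snn_forward(self, *args, **kwargs):
    """Wrapper that strips the simulation-only T argument before delegating."""
    T = kwargs.pop("T", getattr(self, "T", 16))
    outputs = self._original_forward(*args, **kwargs)
    return outputs, T

# MethodType binds `self`, so later calls look like ordinary method calls.
model.forward = types.MethodType(snn_forward, model)
```

`types.MethodType(fn, obj)` and `fn.__get__(obj)` both produce a bound method; the switch to `MethodType` in this diff simply states that intent explicitly.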
@@ -282,9 +313,9 @@ def save_snn_model(model, path): # Create a dictionary with metadata and state dict snn_data = { "state_dict": model.state_dict(), - "config": model.config if hasattr(model, 'config') else None, + "config": getattr(model, 'config', None), "model_type": type(model).__name__, - "T": model.T if hasattr(model, 'T') else 16, + "T": getattr(model, 'T', 16), "simplified": True } @@ -294,7 +325,7 @@ def save_snn_model(model, path): logger.info(f"Saved model state and metadata to {path}") return True -def main(): +def main() -> int: """Main conversion pipeline.""" args = parse_args() device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -388,4 +419,4 @@ def main(): return 1 if __name__ == "__main__": - main() \ No newline at end of file + exit(main()) \ No newline at end of file diff --git a/docs/api_reference.md b/docs/api_reference.md index 3c5bfbd..16a95b1 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -1,228 +1,340 @@ -# STAC API Reference +# API Reference -This document provides reference documentation for the key components of the STAC multi-turn conversational SNN pipeline. - -## Core Components +## Core Classes ### TemporalSpikeProcessor -The central component for processing SNN models with multi-turn conversational capabilities. - -**Location**: `smollm2_converter.py` +Main class for multi-turn conversational SNN processing. ```python -from smollm2_converter import TemporalSpikeProcessor - -processor = TemporalSpikeProcessor( - snn_model, # The converted SNN model - T=16, # Number of timesteps - max_context_length=512 # Maximum context length -) +class TemporalSpikeProcessor(nn.Module): + def __init__(self, snn_model, T=16, max_context_length=512): + """ + Initialize the temporal spike processor. 
+ + Args: + snn_model: The converted SNN model + T: Number of timesteps for spike processing + max_context_length: Maximum sequence length + """ ``` -#### Key Methods +#### Methods -**`forward(input_ids, attention_mask=None, use_cache=True, batch_ids=None)`** -- Process inputs through the SNN with KV-cache support -- Returns: `_CompatOutput` object with `.logits` and `.past_key_values` +##### `forward(input_ids, attention_mask=None, use_cache=True, **kwargs)` +Process input through the SNN with temporal dynamics. -**`reset_cache(batch_id=None)`** -- Reset KV cache for specific batch or all batches -- Args: `batch_id` (optional) - specific conversation to reset +**Parameters:** +- `input_ids` (torch.Tensor): Input token IDs +- `attention_mask` (torch.Tensor, optional): Attention mask +- `use_cache` (bool): Whether to use KV cache for multi-turn +- `**kwargs`: Additional model arguments -**`get_position_ids()`** -- Return current position IDs tensor for validation -- Returns: `torch.Tensor` of position IDs +**Returns:** +- Model output with logits and optional past key values -**`_create_position_ids(input_shape, past_length=0)`** -- Internal method for HuggingFace-compatible position ID creation -- Handles clamping to `max_position_embeddings` +##### `reset_cache(batch_id=None)` +Reset the KV cache for new conversations. -### Spike-Compatible Layers +**Parameters:** +- `batch_id` (int, optional): Specific batch to reset -#### SpikeLayerNorm +##### `get_position_ids()` +Get current position IDs for the conversation. -Spiking-compatible layer normalization replacement. +**Returns:** +- Dictionary with position ID information -```python -from smollm2_converter import SpikeLayerNorm +### SpikeAttention + +Spiking-compatible attention mechanism. 
-layer_norm = SpikeLayerNorm( - normalized_shape, # Shape to normalize over - eps=1e-5 # Epsilon for numerical stability -) +```python +class SpikeAttention(nn.Module): + def __init__(self, embed_dim, num_heads, T=16, causal=True): + """ + Initialize spike-based attention. + + Args: + embed_dim: Embedding dimension + num_heads: Number of attention heads + T: Number of timesteps + causal: Whether to use causal attention + """ ``` -#### SpikeAttention +### SpikeLayerNorm -Spiking-compatible self-attention implementation. +Spiking-compatible layer normalization. ```python -from smollm2_converter import SpikeAttention - -attention = SpikeAttention( - embed_dim=768, # Embedding dimension - num_heads=12, # Number of attention heads - T=16, # Timesteps for spike processing - causal=True # Enable causal masking -) +class SpikeLayerNorm(nn.Module): + def __init__(self, normalized_shape, eps=1e-5): + """ + Initialize spike-compatible layer normalization. + + Args: + normalized_shape: Input shape to normalize + eps: Small constant for numerical stability + """ ``` -#### SpikeSoftmax +## Conversion Functions -Spiking-compatible softmax using spike rates. +### `replace_gelu_with_relu(model)` -```python -from smollm2_converter import SpikeSoftmax +Replace GELU activations with ReLU for SNN compatibility. -softmax = SpikeSoftmax( - T=16, # Temporal windows - dim=-1 # Dimension to apply softmax -) -``` +**Parameters:** +- `model` (torch.nn.Module): Model to modify -## Conversion Functions +**Returns:** +- Modified model with ReLU activations -### simplified_conversion(model, timesteps=32) +### `simplified_conversion(model, timesteps=32)` -Fast conversion method for testing and development. +Perform simplified ANN→SNN conversion. 
-**Location**: `smollm2_converter.py` +**Parameters:** +- `model` (torch.nn.Module): Source model +- `timesteps` (int): Number of SNN timesteps -```python -from smollm2_converter import simplified_conversion +**Returns:** +- Converted SNN model -snn_model = simplified_conversion(model, timesteps=16) -``` +### `replace_layernorm_with_spikelayernorm(model)` -**Features**: -- GELU→ReLU replacement -- Threshold scaling for SpikeZIP-TF equivalence -- Wrapped forward method for SNN behavior simulation +Replace LayerNorm with SpikeLayerNorm. -### Full SpikingJelly Conversion +**Parameters:** +- `model` (torch.nn.Module): Model to modify -Complete conversion using SpikingJelly's Converter with calibration. +**Returns:** +- Modified model with spike-compatible normalization -**Location**: `convert.py`, `smollm2_converter.py` +### `replace_attention_with_spikeattention(model)` -```python -from convert import convert_model_to_spiking - -snn_model = convert_model_to_spiking( - model, - calibration_data, - timesteps=64, - device='cuda' -) -``` +Replace standard attention with SpikeAttention. -## Compatibility Layer +**Parameters:** +- `model` (torch.nn.Module): Model to modify -### spikingjelly_compat.py +**Returns:** +- Modified model with spike-compatible attention -Cross-version compatibility for SpikingJelly components. +### `apply_surrogate_gradients(model, alpha=4.0)` -```python -from spikingjelly_compat import get_quantizer, get_converter, get_neuron +Apply surrogate gradient functions for SNN training. 
-Quantizer = get_quantizer() # Get version-appropriate Quantizer -Converter = get_converter() # Get version-appropriate Converter -LIFNode = get_neuron() # Get LIF neuron implementation -``` +**Parameters:** +- `model` (torch.nn.Module): SNN model +- `alpha` (float): Surrogate gradient scaling factor -## Testing Framework +**Returns:** +- Model with surrogate gradients -### test_conversational_snn.py +### `calibrate_timesteps(model, original_T, target_T)` -Comprehensive test suite for multi-turn validation. +Calibrate spike timing for different timestep counts. -**Key Test Functions**: +**Parameters:** +- `model` (torch.nn.Module): SNN model +- `original_T` (int): Original timestep count +- `target_T` (int): Target timestep count -```python -# Test position ID boundaries -python test_conversational_snn.py --test_position_boundaries +**Returns:** +- Calibrated model -# Test attention mask continuity -python test_conversational_snn.py --test_attention_mask +### `save_snn_model(model, tokenizer, path)` -# Test multi-turn coherence -python test_conversational_snn.py --test_multi_turn +Save the converted SNN model with metadata. -# Test energy consumption -python test_conversational_snn.py --test_energy +**Parameters:** +- `model` (torch.nn.Module): SNN model to save +- `tokenizer`: Associated tokenizer +- `path` (str): Save path -# Run all tests -python test_conversational_snn.py --test_all -``` +**Returns:** +- Success status -### snn_multi_turn_conversation_test.py +## Utility Functions -Simple conversation smoke test. +### `create_calibration_data(tokenizer, num_samples=10, max_length=128)` -```python -from snn_multi_turn_conversation_test import run_multi_turn_chat - -conversation = run_multi_turn_chat( - turns=3, - timesteps=8, - device_str="cuda", - temperature=1.0, - top_k=20, - mode="snn" # or "baseline" -) -``` +Create calibration data for SNN conversion. 
+ +**Parameters:** +- `tokenizer`: HuggingFace tokenizer +- `num_samples` (int): Number of calibration samples +- `max_length` (int): Maximum sequence length + +**Returns:** +- Dictionary with calibration data + +## Testing Functions + +### `test_position_id_boundaries(model, tokenizer, args)` + +Test position ID handling at sequence boundaries. + +**Parameters:** +- `model`: SNN model to test +- `tokenizer`: Associated tokenizer +- `args`: Test configuration + +**Returns:** +- Test results + +### `test_attention_mask_continuity(model, tokenizer, args)` + +Test attention mask continuity across conversation turns. + +**Parameters:** +- `model`: SNN model to test +- `tokenizer`: Associated tokenizer +- `args`: Test configuration + +**Returns:** +- Test results + +### `test_multi_turn_coherence(model, tokenizer, args)` -## CLI Entry Points +Test multi-turn conversation coherence. -### run_conversion.py +**Parameters:** +- `model`: SNN model to test +- `tokenizer`: Associated tokenizer +- `args`: Test configuration -Main conversion script with comprehensive options. +**Returns:** +- Test results +### `simulate_conversation(model, tokenizer, turns=3, device="cpu")` + +Simulate a multi-turn conversation for testing. + +**Parameters:** +- `model`: SNN model +- `tokenizer`: Associated tokenizer +- `turns` (int): Number of conversation turns +- `device` (str): Computing device + +**Returns:** +- Conversation results + +## Command Line Interface + +### `run_conversion.py` + +Main CLI tool for model conversion. 
+
+**Usage:**
+```bash
+python run_conversion.py [OPTIONS]
+```
-### smollm2_converter.py
+
+**Options:**
+- `--model_name`: Model to convert (distilgpt2, SmolLM2-1.7B-Instruct)
+- `--output_dir`: Output directory
+- `--timesteps`: Number of SNN timesteps
+- `--simplified`: Use simplified conversion
+- `--verify`: Run post-conversion verification
+
+### `test_conversational_snn.py`
-Specialized converter for SmolLM2 models.
+Testing and validation tool.
+
+**Usage:**
 ```bash
-python smollm2_converter.py \
-    --model_name HuggingFaceTB/SmolLM2-1.7B-Instruct \
-    --timesteps 32 \
-    --max_context_length 2048
+python test_conversational_snn.py [OPTIONS]
+```
+
+**Options:**
+- `--test_all`: Run all tests
+- `--test_position_boundaries`: Test position ID boundaries
+- `--test_attention_mask`: Test attention mask continuity
+- `--test_multi_turn`: Test multi-turn capabilities
+- `--test_energy`: Test energy consumption
+
+## Configuration
+
+### Model Parameters
+
+**Supported Models:**
+- `distilgpt2`: DistilGPT-2 (82M parameters)
+- `SmolLM2-1.7B-Instruct`: SmolLM2 1.7B Instruct (1.7B parameters)
+
+**Conversion Parameters:**
+- `timesteps`: 8-64 (recommended: 16)
+- `max_context_length`: 512-2048 (recommended: 512)
+- `surrogate_function`: atan, sigmoid, stbif_plus
+
+### Hardware Configuration
+
+**GPU Memory Requirements:**
+- DistilGPT-2: 4-8 GB
+- SmolLM2-1.7B-Instruct: 20 GB
+
+**CPU Requirements:**
+- Multi-core processor recommended
+- 16-32 GB RAM
+
+## Error Handling
+
+### Common Exceptions
+
+**ImportError**: SpikingJelly version compatibility
+```bash
+# Ensure SpikingJelly >= 0.0.0.0.14
+pip install "spikingjelly[cuda]" -U --pre
+```
-## Output Formats
+
+**CUDA Out of Memory**: Insufficient GPU memory
+```python
+# Reduce batch size or use CPU
+device = 'cpu'
+```
-### _CompatOutput
+
+**Position ID Errors**: Sequence length exceeds model limits
+```python
+# Reduce max_context_length
+max_context_length = 512
+```
-Custom output object that supports both HuggingFace and tensor-style access.
+
+## Examples
+
+### Basic Conversion
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from smollm2_converter import (TemporalSpikeProcessor, simplified_conversion,
+                               simulate_conversation)
+
+# Load model
+model = AutoModelForCausalLM.from_pretrained("distilgpt2")
+tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
+
+# Convert to SNN
+snn_model = simplified_conversion(model, timesteps=16)
-outputs = model(input_ids)
+
+# Wrap with temporal processor
+processor = TemporalSpikeProcessor(snn_model, T=16)
-# HuggingFace style
-logits = outputs.logits
-past_kv = outputs.past_key_values
+
+# Test conversation
+result = simulate_conversation(processor, tokenizer, turns=3)
-# Tensor style
-logits = outputs[0]
-next_token_logits = outputs[0, -1, :]
+```
-## Error Handling
+
+### Advanced Usage
+```python
+# Full pipeline conversion
+from convert import convert_model_to_spiking, create_calibration_data
+
+# Create calibration data
+calib_data = create_calibration_data(tokenizer, num_samples=10)
-Common issues and solutions:
+
+# Convert with calibration
+snn_model = convert_model_to_spiking(model, calib_data, timesteps=32)
-**Memory Issues**: Use `--simplified` flag or reduce `--timesteps`
-**SpikingJelly Compatibility**: Check version with `spikingjelly_compat.py`
-**Position ID Overflow**: Automatic clamping in `TemporalSpikeProcessor`
-**KV Cache Growth**: Automatic truncation at `max_context_length`
+
+# Apply surrogate gradients
+snn_model = apply_surrogate_gradients(snn_model, alpha=4.0)
-For implementation details, see [Project State Overview](PROJECT_STATE_OVERVIEW.md).
\ No newline at end of file +# Save model +save_snn_model(snn_model, tokenizer, "./my_snn_model") +``` \ No newline at end of file diff --git a/docs/conversion_workflow.md b/docs/conversion_workflow.md index 983ea59..3838afa 100644 --- a/docs/conversion_workflow.md +++ b/docs/conversion_workflow.md @@ -1,140 +1,161 @@ -# STAC: Conversion Workflow - -This document describes the complete workflow for converting pretrained transformer LLMs to Spiking Neural Networks (SNNs) with multi-turn conversational capabilities. +# Conversion Workflow ## Overview -STAC converts transformer models (DistilGPT-2, SmolLM2-1.7B-Instruct) to energy-efficient spiking neural networks while preserving coherent dialog across conversation turns. All implementation phases are **complete** and validated. - -## Conversion Pipeline Architecture +The STAC framework provides two main conversion approaches: +1. **Simplified Conversion**: Fast, basic ANN→SNN transformation +2. **Full Pipeline**: Comprehensive conversion with quantization and calibration -``` -Input Model (HuggingFace) → SpikingJelly Conversion → TemporalSpikeProcessor → Multi-Turn SNN -``` +## Conversion Process -## Implementation Status: ✅ COMPLETE +### Step 1: Model Loading +```python +from transformers import AutoModelForCausalLM, AutoTokenizer -### ✅ Phase 1: Core Infrastructure +# Load pretrained model +model = AutoModelForCausalLM.from_pretrained("distilgpt2") +tokenizer = AutoTokenizer.from_pretrained("distilgpt2") +``` -**Status: COMPLETE** +### Step 2: Architecture Conversion +The conversion process involves three main transformations: -1. **SpikingJelly Integration** - - ✅ Cross-version compatibility layer (`spikingjelly_compat.py`) - - ✅ Unified Quantizer/Converter imports - - ✅ Stable conversion pipeline with fallbacks +1. **Activation Replacement**: GELU → ReLU +2. **Normalization Replacement**: LayerNorm → SpikeLayerNorm +3. **Attention Replacement**: Standard Attention → SpikeAttention -2. 
**Base Conversion** - - ✅ GELU→ReLU activation replacement - - ✅ `simplified_conversion()` for fast testing - - ✅ Full SpikingJelly integration with calibration +### Step 3: Temporal Wrapper +```python +from smollm2_converter import TemporalSpikeProcessor -### ✅ Phase 2: Temporal Dynamics +# Wrap with multi-turn capability +snn_model = TemporalSpikeProcessor(converted_model, T=16) +``` -**Status: COMPLETE** +## Conversion Modes -1. **Neuron State Management** - - ✅ Stateful LIF neurons with `functional.reset_net()` - - ✅ Membrane potential reset between tokens - - ✅ `TemporalSpikeProcessor` wrapper +### Simplified Mode +**Purpose**: Fast testing and development +**Time**: 2-15 minutes +**Features**: +- Basic layer replacement +- No quantization +- Minimal calibration -2. **Timestep Calibration** - - ✅ Configurable timesteps (T=8-64) - - ✅ Threshold scaling with `calibrate_timesteps()` - - ✅ Logit magnitude restoration +```bash +python run_conversion.py --model_name distilgpt2 --simplified --timesteps 8 +``` -### ✅ Phase 3: Conversation Context +### Full Pipeline Mode +**Purpose**: Production-ready conversion +**Time**: 1-3 hours +**Features**: +- 8-bit quantization +- Extensive calibration +- Threshold optimization -**Status: COMPLETE** +```bash +python run_conversion.py --model_name SmolLM2-1.7B-Instruct --timesteps 16 +``` -1. **Position ID Management** - - ✅ HuggingFace-compatible position ID generation - - ✅ Clamping to `max_position_embeddings` - - ✅ Continuous tracking across conversation turns +## Supported Models -2. **KV Cache Implementation** - - ✅ Global and per-conversation cache support - - ✅ Automatic cache growth and truncation - - ✅ Batch-aware cache management +### Currently Supported +- **DistilGPT-2**: Lightweight GPT-2 variant +- **SmolLM2-1.7B-Instruct**: Instruction-tuned language model -3. 
**Attention Mechanism** - - ✅ Dynamic attention mask growth - - ✅ Causal masking for autoregressive generation - - ✅ Context length management +### Model Requirements +- Must be causal language models +- Transformer architecture +- HuggingFace compatible -### ✅ Phase 4: Testing and Optimization +## Conversion Parameters -**Status: COMPLETE** +### Key Parameters +- `--timesteps`: Number of SNN timesteps (8-64) +- `--simplified`: Use simplified conversion +- `--model_name`: Source model identifier +- `--output_dir`: Output directory -1. **Multi-turn Testing** - - ✅ Comprehensive test suite (`test_conversational_snn.py`) - - ✅ Factual recall validation with keyword matching - - ✅ Position ID boundary testing - - ✅ Attention mask continuity validation +### Advanced Parameters +- `--surrogate_function`: Surrogate gradient function +- `--use_sparse`: Enable sparse tensor optimization +- `--verify`: Run post-conversion verification -2. **Energy Benchmarking** - - ✅ Spike counting and energy estimation - - ✅ Wall-clock timing measurements - - ✅ Mixed-precision compatibility testing +## Multi-Turn Capability -## Quick Start Commands +### TemporalSpikeProcessor Features +- **KV Cache Management**: Maintains context across turns +- **Position ID Handling**: Manages sequence positions +- **Batch Processing**: Supports multiple conversations -### 1. Fast Conversion (Recommended for Testing) +### Usage Example +```python +processor = TemporalSpikeProcessor(snn_model, T=16, max_context_length=512) -```bash -# Convert DistilGPT-2 with simplified pipeline -python run_conversion.py --model_name distilgpt2 --timesteps 8 --simplified - -# Test the converted model -python snn_multi_turn_conversation_test.py --mode snn --turns 3 --timesteps 8 +# Multi-turn conversation +for turn in conversation_turns: + output = processor(input_ids, use_cache=True) + # Process output... ``` -### 2. 
Full SpikingJelly Conversion - -```bash -# Convert with full calibration (requires more memory) -python run_conversion.py --model_name distilgpt2 --timesteps 16 --num_samples 10 - -# Convert SmolLM2 (requires ~20GB VRAM) -python smollm2_converter.py --model_name HuggingFaceTB/SmolLM2-1.7B-Instruct --timesteps 32 -``` +## Validation and Testing -### 3. Comprehensive Testing +### Automatic Validation +The conversion process includes built-in validation: +- Position ID boundary testing +- Attention mask continuity +- Multi-turn coherence verification +- Spike rate analysis +### Manual Testing ```bash -# Run all validation tests -python test_conversational_snn.py --test_all --timesteps 8 +# Run comprehensive tests +python test_conversational_snn.py --test_all --timesteps 16 # Test specific components -python test_conversational_snn.py --test_position_boundaries -python test_conversational_snn.py --test_attention_mask python test_conversational_snn.py --test_multi_turn -python test_conversational_snn.py --test_energy ``` -## Key Components +## Output Format -| File | Purpose | -|------|---------| -| `run_conversion.py` | Main CLI entry point for conversions | -| `smollm2_converter.py` | Specialized converter with `TemporalSpikeProcessor` | -| `convert.py` | Generic conversion utilities | -| `spikingjelly_compat.py` | Cross-version compatibility layer | +### Saved Model Structure +``` +output_dir/ +├── snn_model.pt # Converted SNN model +├── tokenizer/ # Tokenizer files +├── config.json # Model configuration +└── conversion_log.txt # Conversion details +``` -## Validation Checklist +### Model Metadata +The saved model includes: +- Original model information +- Conversion parameters +- Timestep configuration +- Simplified/full mode flag -Before deploying a converted model, ensure all tests pass: +## Troubleshooting -- ✅ Position IDs stay within bounds -- ✅ Attention masks grow correctly across turns -- ✅ KV cache maintains conversation history -- ✅ Multi-turn coherence 
with factual recall -- ✅ Energy consumption within expected range -- ✅ TorchScript export compatibility +### Common Issues +1. **Memory Errors**: Reduce batch size or use CPU +2. **Conversion Failures**: Try simplified mode first +3. **Import Errors**: Verify SpikingJelly version >= 0.0.0.0.14 -## Troubleshooting +### Performance Tips +1. Start with simplified mode for testing +2. Use smaller timesteps (8-16) for faster conversion +3. Ensure adequate GPU memory for large models + +## Future Enhancements -**Memory Issues**: Use `--simplified` flag or reduce `--timesteps` -**Conversion Failures**: Check SpikingJelly version compatibility -**Generation Quality**: Adjust temperature and top-k in generation scripts +### Planned Features +- Additional model architectures +- Hardware-specific optimizations +- Automated hyperparameter tuning +- Real-time conversion monitoring -For detailed implementation status, see [Project State Overview](PROJECT_STATE_OVERVIEW.md). \ No newline at end of file +### Research Directions +- Improved spike encoding methods +- Advanced calibration techniques +- Multi-modal SNN support \ No newline at end of file diff --git a/docs/hardware_requirements.md b/docs/hardware_requirements.md index cd45a99..abbd47a 100644 --- a/docs/hardware_requirements.md +++ b/docs/hardware_requirements.md @@ -1,126 +1,81 @@ # Hardware Requirements -This document outlines the hardware requirements for running the STAC conversion framework and testing multi-turn conversational SNN models. 
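Before starting a conversion, it can help to check the visible GPU against the VRAM figures in this document. A small PyTorch helper sketch (hypothetical, not part of the repository):

```python
import torch

def gpu_meets_requirement(min_vram_gb: float) -> bool:
    # True only when a CUDA device with at least min_vram_gb of total
    # memory is visible; False on CPU-only machines.
    if not torch.cuda.is_available():
        return False
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    return total_gb >= min_vram_gb
```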
+## System Requirements -## Conversion Requirements +### Minimum Requirements +- **RAM**: 16 GB (32 GB recommended) +- **GPU**: NVIDIA GPU with 8GB+ VRAM (RTX 3080/4080/H100) +- **Storage**: 20 GB free space +- **CPU**: Multi-core processor (Intel i7/AMD Ryzen 7+) -### Fast Conversion (Simplified Pipeline) +### Recommended Requirements +- **RAM**: 32 GB +- **GPU**: NVIDIA GPU with 20GB+ VRAM (RTX 4090/H100) +- **Storage**: 50 GB free space (SSD recommended) +- **CPU**: High-end multi-core processor -**Recommended for testing and development** +## Model-Specific Requirements -- **CPU**: 4+ cores, 8GB RAM -- **GPU**: Optional (CPU conversion works well) -- **Models**: DistilGPT-2, GPT-2 small/medium -- **Time**: ~2-5 minutes +### DistilGPT-2 +- **Conversion Time**: 2-5 minutes +- **VRAM**: 4-8 GB +- **Model Size**: ~500 MB -```bash -python run_conversion.py --model_name distilgpt2 --timesteps 8 --simplified -``` - -### Full SpikingJelly Conversion - -**For production-quality models with calibration** +### SmolLM2-1.7B-Instruct +- **Conversion Time**: 1-3 hours +- **VRAM**: 20 GB +- **Model Size**: ~3.5 GB -- **CPU**: 8+ cores, 16GB+ RAM -- **GPU**: 8GB+ VRAM (NVIDIA GTX 1070 or better) -- **Models**: DistilGPT-2, GPT-2 variants -- **Time**: ~10-30 minutes +## Supported Hardware -### Large Model Conversion (SmolLM2-1.7B) +### Current Support +- ✅ **Software Simulation**: Full support on CPU/GPU +- ✅ **NVIDIA GPUs**: CUDA 11.8+ or 12.1+ +- ✅ **PyTorch**: 2.0.0 - 2.5.x -**For state-of-the-art conversational models** +### Planned Support (Future Work) +- ⏳ **Intel Loihi-2**: Neuromorphic hardware deployment +- ⏳ **BrainChip Akida**: Edge neuromorphic processing +- ⏳ **SpiNNaker**: Large-scale spiking neural network platform -- **CPU**: 16+ cores, 32GB+ RAM -- **GPU**: 20GB+ VRAM (NVIDIA RTX 3090/4090, A100) -- **Models**: SmolLM2-1.7B-Instruct, Llama-2-7B -- **Time**: ~1-3 hours +## Installation Notes +### CUDA Installation ```bash -python smollm2_converter.py 
--model_name HuggingFaceTB/SmolLM2-1.7B-Instruct --timesteps 32 -``` - -## Inference Requirements - -### CPU Inference - -- **Minimum**: 4 cores, 8GB RAM -- **Recommended**: 8+ cores, 16GB RAM -- **Performance**: ~1-5 tokens/second for DistilGPT-2 - -### GPU Inference - -- **Minimum**: 4GB VRAM (NVIDIA GTX 1050 Ti) -- **Recommended**: 8GB+ VRAM (NVIDIA RTX 3070+) -- **Performance**: ~10-50 tokens/second depending on model size - -## Testing Requirements - -### Comprehensive Test Suite - -Running `test_conversational_snn.py --test_all`: - -- **CPU**: 8+ cores recommended -- **RAM**: 16GB+ for large context tests -- **Time**: ~10-30 minutes depending on model size +# For CUDA 11.8 +pip install torch==2.3.0+cu118 -f https://download.pytorch.org/whl/torch_stable.html -### Memory Usage by Model +# For CUDA 12.1 +pip install torch==2.3.0+cu121 -f https://download.pytorch.org/whl/torch_stable.html -| Model | Conversion RAM | Inference RAM | GPU VRAM | -|-------|----------------|---------------|----------| -| DistilGPT-2 | 4GB | 2GB | 2GB | -| GPT-2 Medium | 8GB | 4GB | 4GB | -| SmolLM2-1.7B | 20GB | 8GB | 12GB | - -## Platform Compatibility - -### Operating Systems - -- ✅ **Windows 10/11** (tested) -- ✅ **Linux** (Ubuntu 20.04+, CentOS 8+) -- ✅ **macOS** (Intel/Apple Silicon) - -### Python Environment - -- **Python**: 3.8-3.11 -- **PyTorch**: 2.0+ -- **SpikingJelly**: 0.0.0.0.14+ -- **Transformers**: 4.20+ - -## Cloud Deployment - -### Recommended Cloud Instances - -| Provider | Instance Type | vCPUs | RAM | GPU | Use Case | -|----------|---------------|-------|-----|-----|----------| -| **AWS** | g4dn.xlarge | 4 | 16GB | T4 (16GB) | Development | -| **AWS** | p3.2xlarge | 8 | 61GB | V100 (16GB) | Production | -| **GCP** | n1-standard-8 | 8 | 30GB | T4 (16GB) | Development | -| **Azure** | Standard_NC6s_v3 | 6 | 112GB | V100 (16GB) | Production | - -### Cost Optimization - -- Use **CPU-only instances** for simplified conversion and testing -- Use **spot instances** for 
batch conversion jobs -- Use **preemptible VMs** on GCP for cost savings - -## Performance Benchmarks +# For CPU only +pip install torch==2.3.0+cpu -f https://download.pytorch.org/whl/torch_stable.html +``` -### Conversion Speed (DistilGPT-2) +### SpikingJelly Installation +```bash +# Latest pre-release version required +pip install spikingjelly[cuda] -U --pre +``` -- **CPU (simplified)**: ~2 minutes -- **GPU (simplified)**: ~1 minute -- **GPU (full calibration)**: ~15 minutes +## Performance Expectations -### Inference Speed (Multi-turn conversation) +### Conversion Performance +- **Simplified Mode**: 2-15 minutes +- **Full Pipeline**: 1-3 hours (with quantization and calibration) -- **CPU**: ~2-5 tokens/second -- **GPU (T4)**: ~15-25 tokens/second -- **GPU (V100)**: ~30-50 tokens/second +### Memory Usage +- **Peak VRAM**: 20GB (SmolLM2-1.7B-Instruct) +- **System RAM**: 16-32GB during conversion ## Troubleshooting -**Out of Memory**: Use `--simplified` flag or reduce `--timesteps` -**Slow Conversion**: Enable GPU acceleration or use cloud instances -**CUDA Issues**: Ensure PyTorch CUDA version matches your driver +### Common Issues +1. **CUDA Out of Memory**: Reduce batch size or use CPU fallback +2. **SpikingJelly Version**: Ensure version >= 0.0.0.0.14 +3. **PyTorch Compatibility**: Use PyTorch 2.0.0 - 2.5.x range -For detailed setup instructions, see [Conversion Workflow](conversion_workflow.md). \ No newline at end of file +### Performance Optimization +1. Use SSD storage for faster I/O +2. Close unnecessary applications during conversion +3. Use simplified mode for initial testing \ No newline at end of file diff --git a/run_conversion.py b/run_conversion.py index f4e7ad8..be912b8 100644 --- a/run_conversion.py +++ b/run_conversion.py @@ -2,6 +2,8 @@ """ STAC: SpikeTrain And Convert - Conversion Runner Script Runs the conversion pipeline for transforming an LLM into a Spiking Neural Network. 
+ +NOTE: This CLI wrapper is experimental; see README for current limitations. """ import argparse import os @@ -46,8 +48,8 @@ def parse_args(): parser = argparse.ArgumentParser(description='Run SNN conversion pipeline') - parser.add_argument('--model_name', type=str, default='TinyLlama/TinyLlama-1.1B-Chat-v1.0', - help='Pretrained model to convert (default is TinyLlama which is small enough to run quickly)') + parser.add_argument('--model_name', type=str, default='distilgpt2', + help='Pretrained model to convert (default is distilgpt2 for fast testing). Supported: distilgpt2, SmolLM2-1.7B-Instruct') parser.add_argument('--output_dir', type=str, default='./snn_converted_model', help='Directory to save the converted model') parser.add_argument('--timesteps', type=int, default=16, @@ -76,7 +78,7 @@ def parse_args(): def run_component_tests(): """Run basic functionality tests to ensure all components work.""" logger.info("=== Running Component Tests ===") - cmd = ["python", "test_snn_components.py", "--test_all"] + cmd = ["python", "test_conversational_snn.py", "--test_all"] start_time = time.time() result = subprocess.run(cmd, capture_output=True, text=True) diff --git a/smollm2_converter.py b/smollm2_converter.py index fef642f..f5a0ec4 100644 --- a/smollm2_converter.py +++ b/smollm2_converter.py @@ -7,6 +7,13 @@ SmolLM2 Converter: Convert SmolLM2-1.7B-Instruct to a Spiking Neural Network Specialized script for creating a conversational spiking language model. + +# NOTE: ------------------------------------------------------------------- +# This specialized SmolLM2 conversion pipeline is a **work in progress**. +# While the TemporalSpikeProcessor enables multi-turn state retention in +# software, true hardware-level validation (e.g., Intel Loihi-2) is still +# pending. Expect API changes and incomplete operator coverage. 
+# --------------------------------------------------------------------------- """ import argparse import torch diff --git a/stac-v1/README.md b/stac-v1/README.md new file mode 100644 index 0000000..043b437 --- /dev/null +++ b/stac-v1/README.md @@ -0,0 +1,81 @@ +# STAC V1: End-to-End Training Pipeline + +## Overview + +STAC V1 represents the **original research approach** - a complete end-to-end training pipeline for spiking transformers. This version established the foundational concepts that were later adapted for the conversion-based approach in STAC V2. + +## Key Differences: V1 vs V2 + +| Aspect | STAC V1 | STAC V2 | +|--------|---------|---------| +| **Approach** | End-to-end training from scratch | ANN→SNN conversion | +| **Architecture** | Learnable AdEx neurons | Converted transformer layers | +| **Memory** | Hyperdimensional Memory Module (HEMM) | Temporal Spike Processor (TSP) | +| **Training** | Surrogate gradient training | Pre-trained model conversion | +| **Scope** | Single-turn processing | Multi-turn conversations | +| **Status** | Complete research prototype | Experimental conversion framework | + +## STAC V1 Contributions + +### 🧠 **Neuromorphic Architecture** +- **Learnable AdEx Neurons**: Adaptive exponential neurons with biologically plausible parameters +- **Surrogate Gradient Training**: Successful training of spiking transformers using surrogate gradients +- **L1 Spike Regularization**: Energy-efficient spike patterns + +### 🧩 **Memory Integration** +- **Hyperdimensional Memory Module (HEMM)**: 1024-dimensional memory projection +- **Spike Pooling**: Temporal aggregation of spike trains +- **Memory Bias**: Context-aware processing + +### 📊 **Validation Suite** +- **Comprehensive Testing**: Position ID boundaries, attention masks, spike rates +- **Energy Analysis**: Theoretical energy savings projections +- **Quality Metrics**: Perplexity and coherence measurements + +## Implementation Details + +### Model Architecture +```python +# Key 
components in stacv1.ipynb: +# - AdEx neurons with learnable parameters (τ_m=20.0, τ_w=144.0, etc.) +# - HEMM with 1024-dim projection matrix +# - L1 regularization for energy efficiency +# - Surrogate gradient training on WikiText-2 +``` + +### Training Process +1. **Data Loading**: WikiText-2 raw dataset +2. **Model Initialization**: Learnable AdEx parameters +3. **Forward Pass**: Spike accumulation and memory integration +4. **Loss Computation**: Cross-entropy + L1 spike penalty +5. **Backward Pass**: Surrogate gradient updates + +## Research Impact + +STAC V1 demonstrated several key innovations: +- ✅ **Surrogate gradient training** of a spiking transformer end to end +- ✅ **Learnable neuromorphic dynamics** with AdEx neurons +- ✅ **Hyperdimensional memory integration** in spiking networks +- ✅ **Energy-efficient spike regularization** techniques + +## Usage + +```bash +# Open the Jupyter notebook +jupyter notebook stac-v1/stacv1.ipynb + +# Or view in VS Code +code stac-v1/stacv1.ipynb +``` + +## Evolution to STAC V2 + +STAC V2 evolved from V1 by: +1. **Shifting to a conversion-based approach** for practical deployment +2. **Extending to multi-turn conversations** with the Temporal Spike Processor +3. **Focusing on hardware compatibility** for neuromorphic deployment +4. **Maintaining V1's energy-efficiency principles** in the conversion framework + +--- + +**Note**: STAC V1 is a **complete research prototype** that has been validated and documented. STAC V2 builds upon these foundations with a different methodological approach focused on practical deployment. \ No newline at end of file
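The surrogate-gradient trick that V1's training relies on can be illustrated with a common sigmoid-surrogate formulation. This is a generic sketch, not necessarily the exact function used in `stacv1.ipynb`:

```python
import torch

class SigmoidSurrogateSpike(torch.autograd.Function):
    """Heaviside spike in the forward pass, sigmoid-shaped gradient in
    the backward pass -- one common surrogate-gradient formulation."""

    @staticmethod
    def forward(ctx, v, alpha):
        ctx.save_for_backward(v)
        ctx.alpha = alpha
        # Hard, non-differentiable 0/1 spikes
        return (v > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        (v,) = ctx.saved_tensors
        sig = torch.sigmoid(ctx.alpha * v)
        # d/dv sigmoid(alpha * v) = alpha * sig * (1 - sig)
        return grad_output * ctx.alpha * sig * (1 - sig), None

v = torch.tensor([-1.0, 0.5], requires_grad=True)   # membrane potentials
spikes = SigmoidSurrogateSpike.apply(v, 4.0)        # binary spikes
spikes.sum().backward()                             # gradients still flow
```

The L1 spike penalty from step 4 of the training process would then simply be a term like `l1_lambda * spikes.abs().sum()` added to the cross-entropy loss.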