Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 29 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ Note: The ECE594 project is currently limited in scope to "Aim 1" above.

## Overview

This project investigates whether state-estimation objectives for articulated bodies induce structured neural representations analogous to grid codes in spatial navigation. We train recurrent networks on path integration for a robotic arm with configuration space Q = SO(3) × SO(3), analyze the learned representations, and evaluate their utility for downstream reinforcement learning.
This project investigates whether state-estimation objectives for articulated bodies induce structured neural representations analogous to grid codes in spatial navigation. We train recurrent networks on path integration for a robotic arm and analyze the learned representations.

**Configuration spaces:**
- **SO(3) × SO(3)** — full 3D rotational joints (6D velocities)
- **SO(2) × SO(2)** — planar arm on the torus (2D velocities). Neuron activations are directly plottable on (θ1, θ2) without dimensionality reduction

**Project Stages:**
1. **Body-state estimation** (Team Estimation): Train RNN to perform path integration on joint angular velocities
Expand Down Expand Up @@ -76,7 +80,7 @@ articulated/
│ │ └── train.py # Training script
│ │
│ ├── shared/ # Shared utilities
│ │ └── robot_arm.py # Kinematics on SO(3) × SO(3)
│ │ └── robot_arm.py # Kinematics on SO(3)×SO(3) and SO(2)×SO(2)
│ │
│ ├── configs/ # Configuration files
│ │ ├── estimation/ # Team Estimation configs
Expand All @@ -95,21 +99,34 @@ articulated/

### Team Estimation

**Goal:** Train RNN to perform path integration on SO(3) × SO(3).
**Goal:** Train RNN/LSTM/GRU to perform path integration. Supports SO(3)×SO(3) and SO(2)×SO(2).

**Key files:**
- `articulated/estimation/datamodule.py`: Implement trajectory generation (inputs,targets)
- `articulated/estimation/model.py`: Define RNN architecture
- `articulated/estimation/train.py`: Training script
- `articulated/estimation/datamodule.py`: Trajectory data generation (supports `manifold="so2"` / `"so3"`)
- `articulated/estimation/model.py`: RNN/LSTM/GRU architectures with configurable `init_pos_size`
- `scripts/generate_data.py`: Parallel data generation with multiprocessing
- `scripts/train.py`: Training entry point
- `scripts/train.sh`: Convenience shell script
- `scripts/analyze_representation.py`: PCA, t-SNE, and tuning curve / torus heatmap analysis

**Generate data and train:**
```bash
# SO(2) — planar arm on the torus
python scripts/generate_data.py --manifold so2 --n_train 100000 --n_val 5000 --workers 16
bash scripts/train.sh gru 0 so2

**Key TODOs:**
1. Implement proper SO(3) × SO(3) trajectory generation in `_generate_single_trajectory()`
2. Implement proper "place cell" targets on SO(3) × SO(3)
3. Experiment with RNN vs LSTM vs GRU architectures
# SO(3) — full 3D rotational joints
python scripts/generate_data.py --n_train 100000 --n_val 5000 --workers 16
bash scripts/train.sh gru 0
```

**Run training:**
**Analyze representations:**
```bash
python -m articulated.estimation.train --config articulated/configs/estimation/rnn.yaml
# SO(2): direct (θ1, θ2) heatmaps
PYTHONPATH=. python scripts/analyze_representation.py path/to/checkpoint.ckpt --manifold so2

# SO(3): PCA + t-SNE + tuning curves
PYTHONPATH=. python scripts/analyze_representation.py path/to/checkpoint.ckpt
```

**Interface with other teams:**
Expand Down
39 changes: 39 additions & 0 deletions articulated/configs/estimation/gru_so2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Configuration for GRU state estimation on SO(2) x SO(2) (2D torus)
# Usage: python scripts/train.py --manifold so2 --model_type gru

seed: 42

model:
input_size: 2 # 2 joints x 1 angular velocity component
hidden_size: 256
output_size: 64 # Number of place cells (32 per joint)
init_pos_size: 4 # (cos θ1, sin θ1, cos θ2, sin θ2)
model_type: "gru"
learning_rate: 1e-3
weight_decay: 1e-4
use_init_pos: true
dropout: 0.5

data:
batch_size: 64
seq_length: 100
n_trajectories_train: 1000
n_trajectories_val: 100
n_place_cells: 64
dt: 0.01
provide_init_pos: true
manifold: "so2"

trainer:
max_epochs: 100
accelerator: "auto"
devices: 1
log_every_n_steps: 10
gradient_clip_val: 1.0

logging:
wandb: false
project: "articulated-estimation"
save_dir: "logs"

checkpoint_dir: "checkpoints/estimation"
Loading