diff --git a/README.md b/README.md index 4c303af..69d089b 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,11 @@ Note: The ECE594 project is currently limited in scope to "Aim 1" above. ## Overview -This project investigates whether state-estimation objectives for articulated bodies induce structured neural representations analogous to grid codes in spatial navigation. We train recurrent networks on path integration for a robotic arm with configuration space Q = SO(3) × SO(3), analyze the learned representations, and evaluate their utility for downstream reinforcement learning. +This project investigates whether state-estimation objectives for articulated bodies induce structured neural representations analogous to grid codes in spatial navigation. We train recurrent networks on path integration for a robotic arm and analyze the learned representations. + +**Configuration spaces:** +- **SO(3) × SO(3)** — full 3D rotational joints (6D velocities) +- **SO(2) × SO(2)** — planar arm on the torus (2D velocities). Neuron activations are directly plottable on (θ1, θ2) without dimensionality reduction **Project Stages:** 1. **Body-state estimation** (Team Estimation): Train RNN to perform path integration on joint angular velocities @@ -76,7 +80,7 @@ articulated/ │ │ └── train.py # Training script │ │ │ ├── shared/ # Shared utilities -│ │ └── robot_arm.py # Kinematics on SO(3) × SO(3) +│ │ └── robot_arm.py # Kinematics on SO(3)×SO(3) and SO(2)×SO(2) │ │ │ ├── configs/ # Configuration files │ │ ├── estimation/ # Team Estimation configs @@ -95,21 +99,34 @@ articulated/ ### Team Estimation -**Goal:** Train RNN to perform path integration on SO(3) × SO(3). +**Goal:** Train RNN/LSTM/GRU to perform path integration. Supports SO(3)×SO(3) and SO(2)×SO(2). **Key files:** -- `articulated/estimation/datamodule.py`: Implement trajectory generation (inputs,targets) -- `articulated/estimation/model.py`: Define RNN architecture -- `articulated/estimation/train.py`: Training script +- `articulated/estimation/datamodule.py`: Trajectory data generation (supports `manifold="so2"` / `"so3"`) +- `articulated/estimation/model.py`: RNN/LSTM/GRU architectures with configurable `init_pos_size` +- `scripts/generate_data.py`: Parallel data generation with multiprocessing +- `scripts/train.py`: Training entry point +- `scripts/train.sh`: Convenience shell script +- `scripts/analyze_representation.py`: PCA, t-SNE, and tuning curve / torus heatmap analysis + +**Generate data and train:** +```bash +# SO(2) — planar arm on the torus +python scripts/generate_data.py --manifold so2 --n_train 100000 --n_val 5000 --workers 16 +bash scripts/train.sh gru 0 so2 -**Key TODOs:** -1. Implement proper SO(3) × SO(3) trajectory generation in `_generate_single_trajectory()` -2. Implement proper "place cell" targets on SO(3) × SO(3) -3. Experiment with RNN vs LSTM vs GRU architectures +# SO(3) — full 3D rotational joints +python scripts/generate_data.py --n_train 100000 --n_val 5000 --workers 16 +bash scripts/train.sh gru 0 +``` -**Run training:** +**Analyze representations:** ```bash -python -m articulated.estimation.train --config articulated/configs/estimation/rnn.yaml +# SO(2): direct (θ1, θ2) heatmaps +PYTHONPATH=. python scripts/analyze_representation.py path/to/checkpoint.ckpt --manifold so2 + +# SO(3): PCA + t-SNE + tuning curves +PYTHONPATH=. python scripts/analyze_representation.py path/to/checkpoint.ckpt ``` **Interface with other teams:** diff --git a/articulated/configs/estimation/gru_so2.yaml b/articulated/configs/estimation/gru_so2.yaml new file mode 100644 index 0000000..4a54be8 --- /dev/null +++ b/articulated/configs/estimation/gru_so2.yaml @@ -0,0 +1,39 @@ +# Configuration for GRU state estimation on SO(2) x SO(2) (2D torus) +# Usage: python scripts/train.py --manifold so2 --model_type gru + +seed: 42 + +model: + input_size: 2 # 2 joints x 1 angular velocity component + hidden_size: 256 + output_size: 64 # Number of place cells (32 per joint) + init_pos_size: 4 # (cos θ1, sin θ1, cos θ2, sin θ2) + model_type: "gru" + learning_rate: 1e-3 + weight_decay: 1e-4 + use_init_pos: true + dropout: 0.5 + +data: + batch_size: 64 + seq_length: 100 + n_trajectories_train: 1000 + n_trajectories_val: 100 + n_place_cells: 64 + dt: 0.01 + provide_init_pos: true + manifold: "so2" + +trainer: + max_epochs: 100 + accelerator: "auto" + devices: 1 + log_every_n_steps: 10 + gradient_clip_val: 1.0 + +logging: + wandb: false + project: "articulated-estimation" + save_dir: "logs" + +checkpoint_dir: "checkpoints/estimation" diff --git a/articulated/estimation/datamodule.py b/articulated/estimation/datamodule.py index 1be600e..52d0f4b 100644 --- a/articulated/estimation/datamodule.py +++ b/articulated/estimation/datamodule.py @@ -15,7 +15,7 @@ from torch.utils.data import DataLoader, TensorDataset from tqdm import tqdm -from articulated.shared.robot_arm import RobotArmKinematics +from articulated.shared.robot_arm import RobotArm2DKinematics, RobotArmKinematics class EstimationDataModule(L.LightningDataModule): @@ -41,6 +41,7 @@ def __init__( place_cell_kappa: float = 5.0, data_path: Optional[str] = None, num_workers: int = 4, + manifold: str = "so3", ): """Initialize the data module. @@ -58,18 +59,19 @@ def __init__( velocity_theta: Mean-reversion rate of OU angular velocity process. provide_init_pos: Whether to include initial position place cell activations in the dataset (needed for path integration). - place_cell_kappa: Concentration parameter for von Mises-Fisher kernel - on SO(3). Higher = more peaked. kappa=5 gives ~4-5 effective - cells per joint out of 32. Gaussian RBF fails on SO(3) due to - the bounded geometry; vMF is the natural analogue. + place_cell_kappa: Concentration parameter for von Mises-Fisher kernel. + Higher = more peaked. kappa=5 gives ~4-5 effective cells per joint. data_path: Path to pre-generated .pt data file. If provided, loads data from disk instead of generating. Use scripts/generate_data.py to create these files. num_workers: Number of DataLoader workers. + manifold: Configuration space manifold. "so3" for SO(3)xSO(3) (6D), + "so2" for SO(2)xSO(2) (2D torus). """ super().__init__() self.save_hyperparameters() + assert manifold in ("so3", "so2"), f"Unknown manifold: {manifold}" assert ( n_place_cells % 2 == 0 ), "n_place_cells must be even (split between 2 joints)" @@ -87,8 +89,18 @@ def __init__( self.place_cell_kappa = place_cell_kappa self.data_path = data_path self.num_workers = num_workers - - self.kinematics = RobotArmKinematics() + self.manifold = manifold + + # Velocity dimension and init_pos dimension depend on manifold + self.vel_dim = 2 if manifold == "so2" else 6 + # SO(2): raw angles encoded as (cos θ1, sin θ1, cos θ2, sin θ2) → 4D + # SO(3): place cell activations → output_size + self.init_pos_dim = 4 if manifold == "so2" else n_place_cells + + if manifold == "so2": + self.kinematics = RobotArm2DKinematics() + else: + self.kinematics = RobotArmKinematics() self.train_dataset: Optional[TensorDataset] = None self.val_dataset: Optional[TensorDataset] = None @@ -109,24 +121,35 @@ def _load_from_disk(self) -> None: print(f"Loading data from {self.data_path}...") data = torch.load(self.data_path, weights_only=True) + # Init position key depends on manifold + init_key_train = ( + "train_init_angles" if self.manifold == "so2" else "train_init_pcs" + ) + init_key_val = "val_init_angles" if self.manifold == "so2" else "val_init_pcs" + train_tensors = [data["train_velocities"], data["train_targets"]] - if self.provide_init_pos and "train_init_pcs" in data: - train_tensors.append(data["train_init_pcs"]) + if self.provide_init_pos and init_key_train in data: + train_tensors.append(data[init_key_train]) self.train_dataset = TensorDataset(*train_tensors) val_tensors = [data["val_velocities"], data["val_targets"]] - if self.provide_init_pos and "val_init_pcs" in data: - val_tensors.append(data["val_init_pcs"]) + if self.provide_init_pos and init_key_val in data: + val_tensors.append(data[init_key_val]) self.val_dataset = TensorDataset(*val_tensors) # Restore place cell centers for analysis - if "place_cell_quats1" in data and "place_cell_quats2" in data: - self._centers_R1 = Rotation.from_quat(data["place_cell_quats1"].numpy()) - self._centers_R2 = Rotation.from_quat(data["place_cell_quats2"].numpy()) + if self.manifold == "so2": + if "place_cell_angles1" in data and "place_cell_angles2" in data: + self._centers_angles1 = data["place_cell_angles1"].numpy() + self._centers_angles2 = data["place_cell_angles2"].numpy() + else: + if "place_cell_quats1" in data and "place_cell_quats2" in data: + self._centers_R1 = Rotation.from_quat(data["place_cell_quats1"].numpy()) + self._centers_R2 = Rotation.from_quat(data["place_cell_quats2"].numpy()) n_train = data["train_velocities"].shape[0] n_val = data["val_velocities"].shape[0] - print(f"Loaded {n_train} train + {n_val} val trajectories") + print(f"Loaded {n_train} train + {n_val} val trajectories ({self.manifold})") def _generate_all_data(self) -> None: """Generate all train/val data from scratch.""" @@ -136,7 +159,7 @@ def _generate_all_data(self) -> None: self._initialize_place_cells(rng) # Generate training data - train_velocities, train_targets, train_init_pcs = self._generate_trajectories( + train_velocities, train_targets, train_init = self._generate_trajectories( self.n_trajectories_train, rng ) train_tensors = [ @@ -144,11 +167,11 @@ def _generate_all_data(self) -> None: torch.from_numpy(train_targets).float(), ] if self.provide_init_pos: - train_tensors.append(torch.from_numpy(train_init_pcs).float()) + train_tensors.append(torch.from_numpy(train_init).float()) self.train_dataset = TensorDataset(*train_tensors) # Generate validation data - val_velocities, val_targets, val_init_pcs = self._generate_trajectories( + val_velocities, val_targets, val_init = self._generate_trajectories( self.n_trajectories_val, rng ) val_tensors = [ @@ -156,23 +179,32 @@ def _generate_all_data(self) -> None: torch.from_numpy(val_targets).float(), ] if self.provide_init_pos: - val_tensors.append(torch.from_numpy(val_init_pcs).float()) + val_tensors.append(torch.from_numpy(val_init).float()) self.val_dataset = TensorDataset(*val_tensors) def _initialize_place_cells(self, rng: np.random.Generator) -> None: - """Initialize place cell centers separately for each joint on SO(3). + """Initialize place cell centers separately for each joint. - Samples n_cells_per_joint random rotations independently for each joint. - This avoids the curse of dimensionality: 32 cells on 3D SO(3) is much - more tractable than 64 cells on 6D SO(3) x SO(3). + For SO(3): random rotations on SO(3). + For SO(2): random angles on [0, 2*pi). """ n = self.n_cells_per_joint - self._centers_R1 = Rotation.random(n, random_state=int(rng.integers(0, 2**31))) - self._centers_R2 = Rotation.random(n, random_state=int(rng.integers(0, 2**31))) - # Store as list of pairs for compatibility - self.place_cell_centers = [ - (self._centers_R1[i], self._centers_R2[i]) for i in range(n) - ] + if self.manifold == "so2": + self._centers_angles1 = rng.uniform(0, 2 * np.pi, size=n) + self._centers_angles2 = rng.uniform(0, 2 * np.pi, size=n) + self.place_cell_centers = [ + (self._centers_angles1[i], self._centers_angles2[i]) for i in range(n) + ] + else: + self._centers_R1 = Rotation.random( + n, random_state=int(rng.integers(0, 2**31)) + ) + self._centers_R2 = Rotation.random( + n, random_state=int(rng.integers(0, 2**31)) + ) + self.place_cell_centers = [ + (self._centers_R1[i], self._centers_R2[i]) for i in range(n) + ] def _generate_trajectories( self, n_trajectories: int, rng: np.random.Generator @@ -184,22 +216,22 @@ def _generate_trajectories( rng: Random number generator. Returns: - Tuple of (velocities, place_cell_targets, init_place_cells). - velocities: Shape (n_trajectories, seq_length, 6). + Tuple of (velocities, place_cell_targets, init_pos). + velocities: Shape (n_trajectories, seq_length, vel_dim). targets: Shape (n_trajectories, seq_length, n_place_cells). - init_place_cells: Shape (n_trajectories, n_place_cells). + init_pos: Shape (n_trajectories, init_pos_dim). """ - velocities = np.zeros((n_trajectories, self.seq_length, 6)) + velocities = np.zeros((n_trajectories, self.seq_length, self.vel_dim)) targets = np.zeros((n_trajectories, self.seq_length, self.n_place_cells)) - init_pcs = np.zeros((n_trajectories, self.n_place_cells)) + init_pos = np.zeros((n_trajectories, self.init_pos_dim)) for i in tqdm(range(n_trajectories), desc="Generating trajectories"): - vel, target, init_pc = self._generate_single_trajectory(rng) + vel, target, ip = self._generate_single_trajectory(rng) velocities[i] = vel targets[i] = target - init_pcs[i] = init_pc + init_pos[i] = ip - return velocities, targets, init_pcs + return velocities, targets, init_pos def _generate_single_trajectory( self, rng: np.random.Generator @@ -207,34 +239,47 @@ def _generate_single_trajectory( """Generate a single trajectory. Uses an Ornstein-Uhlenbeck process to generate smooth angular velocities, - integrates on SO(3) x SO(3), and computes place cell activations. + integrates on the configuration manifold, and computes place cell activations. Args: rng: Random number generator. Returns: - Tuple of (velocities, place_cell_targets, init_place_cells). - velocities: Shape (seq_length, 6). + Tuple of (velocities, place_cell_targets, init_pos). + velocities: Shape (seq_length, vel_dim). targets: Shape (seq_length, n_place_cells). - init_place_cells: Shape (n_place_cells,). + init_pos: Shape (init_pos_dim,). """ config = self.kinematics.sample_random_configuration(rng) - # Place cell activation at the starting position - init_pc = self._compute_place_cell_activations(config) + # Initial position encoding + if self.manifold == "so2": + # Raw angles → (cos θ1, sin θ1, cos θ2, sin θ2) + theta1, theta2 = config + init_pos = np.array( + [ + np.cos(theta1), + np.sin(theta1), + np.cos(theta2), + np.sin(theta2), + ] + ) + else: + # Place cell activation at the starting position + init_pos = self._compute_place_cell_activations(config) decay = np.exp(-self.velocity_theta * self.dt) noise_scale = self.velocity_sigma * np.sqrt(1.0 - decay**2) - velocities = np.zeros((self.seq_length, 6)) + velocities = np.zeros((self.seq_length, self.vel_dim)) targets = np.zeros((self.seq_length, self.n_place_cells)) # Initialize angular velocity at zero - omega = np.zeros(6) + omega = np.zeros(self.vel_dim) for t in range(self.seq_length): # OU update for angular velocity - omega = decay * omega + noise_scale * rng.standard_normal(6) + omega = decay * omega + noise_scale * rng.standard_normal(self.vel_dim) velocities[t] = omega # Integrate to get new configuration @@ -243,31 +288,39 @@ def _generate_single_trajectory( # Compute place cell activations targets[t] = self._compute_place_cell_activations(config) - return velocities, targets, init_pc + return velocities, targets, init_pos def _compute_place_cell_activations(self, configuration: tuple) -> np.ndarray: """Compute place cell activations for a single configuration. - Uses von Mises-Fisher kernel with separate softmax per joint. - The vMF kernel exp(kappa * cos(theta)) is the natural analogue of - Gaussian RBF on the rotation group SO(3), and produces properly - peaked distributions unlike Gaussian RBF which fails due to the - bounded geometry of SO(3). + Uses vMF-style kernel with separate softmax per joint. + + For SO(3): geodesic distance on rotation group. + For SO(2): circular distance on the circle. Args: - configuration: Current configuration as (Rotation, Rotation). + configuration: Current configuration. (Rotation, Rotation) for SO(3), + or (float, float) for SO(2). Returns: Place cell activations of shape (n_place_cells,), which is the concatenation of two independent distributions of size n_cells_per_joint. """ - R1_cur, R2_cur = configuration kappa = self.place_cell_kappa - # Geodesic distances to all centers for each joint - d1 = (R1_cur.inv() * self._centers_R1).magnitude() - d2 = (R2_cur.inv() * self._centers_R2).magnitude() + if self.manifold == "so2": + theta1, theta2 = configuration + # Circular distance: min(|delta|, 2pi - |delta|) + TWO_PI = 2.0 * np.pi + delta1 = np.abs(theta1 - self._centers_angles1) + d1 = np.minimum(delta1, TWO_PI - delta1) + delta2 = np.abs(theta2 - self._centers_angles2) + d2 = np.minimum(delta2, TWO_PI - delta2) + else: + R1_cur, R2_cur = configuration + d1 = (R1_cur.inv() * self._centers_R1).magnitude() + d2 = (R2_cur.inv() * self._centers_R2).magnitude() # vMF kernel + softmax per joint (numerically stable) logits1 = kappa * np.cos(d1) diff --git a/articulated/estimation/model.py b/articulated/estimation/model.py index f604441..30faf57 100644 --- a/articulated/estimation/model.py +++ b/articulated/estimation/model.py @@ -180,6 +180,7 @@ def __init__( weight_decay: float = 0.0, use_init_pos: bool = True, dropout: float = 0.0, + init_pos_size: Optional[int] = None, ): """Initialize the Lightning module. @@ -192,6 +193,9 @@ def __init__( weight_decay: Weight decay for optimizer. use_init_pos: Whether to encode initial position into h0. dropout: Dropout rate on RNN output (0.5 recommended per Banino 2018). + init_pos_size: Dimension of initial position input. If None, defaults + to output_size (SO(3) place cell activations). Set to 4 for SO(2) + where init is (cos θ1, sin θ1, cos θ2, sin θ2). """ super().__init__() self.save_hyperparameters() @@ -202,6 +206,7 @@ def __init__( self.weight_decay = weight_decay self.use_init_pos = use_init_pos self.model_type = model_type + self.init_pos_size = init_pos_size if init_pos_size is not None else output_size # Instantiate the appropriate model self.model: Union[RNN, LSTM, GRU] @@ -220,7 +225,7 @@ def __init__( # Encoder for initial position → initial hidden state if use_init_pos: - self.init_encoder = nn.Linear(output_size, hidden_size) + self.init_encoder = nn.Linear(self.init_pos_size, hidden_size) self.loss_fn = nn.KLDivLoss(reduction="batchmean") @@ -234,18 +239,20 @@ def forward( """Forward pass through the model.""" return self.model(x, hidden) - def _encode_init_pos(self, init_pc: torch.Tensor) -> Any: - """Encode initial place cell activations into initial hidden state. + def _encode_init_pos(self, init_pos: torch.Tensor) -> Any: + """Encode initial position into initial hidden state. Args: - init_pc: Initial place cell activations of shape (batch, output_size). + init_pos: Initial position of shape (batch, init_pos_size). + For SO(3): place cell activations (batch, output_size). + For SO(2): (cos θ1, sin θ1, cos θ2, sin θ2) → (batch, 4). Returns: Initial hidden state of shape (1, batch, hidden_size) for RNN/GRU, or tuple of (h0, c0) each (1, batch, hidden_size) for LSTM. """ - # (batch, output_size) → (batch, hidden_size) - h0 = self.init_encoder(init_pc) + # (batch, init_pos_size) → (batch, hidden_size) + h0 = self.init_encoder(init_pos) # RNN expects (num_layers, batch, hidden_size) h0 = h0.unsqueeze(0) diff --git a/articulated/shared/robot_arm.py b/articulated/shared/robot_arm.py index bbe6721..d2470c4 100644 --- a/articulated/shared/robot_arm.py +++ b/articulated/shared/robot_arm.py @@ -8,6 +8,8 @@ Both Team Estimation and Team RL will use these utilities. """ +from typing import Union + import numpy as np from scipy.spatial.transform import Rotation @@ -121,3 +123,82 @@ def sample_random_configuration( R2 = Rotation.random(random_state=rng.integers(0, 2**31)) return (R1, R2) + + +class RobotArm2DKinematics: + """Kinematics utilities for a 2-joint arm on SO(2) x SO(2). + + Each joint is parameterized by a single angle theta in [0, 2*pi). + The configuration is a tuple of two scalar angles (theta1, theta2). + This is the 2D (planar) restriction of the full SO(3) arm. + """ + + def __init__(self, link_lengths: tuple[float, float] = (1.0, 1.0)): + self.link_lengths = link_lengths + + def integrate_velocity( + self, + current_config: tuple[float, float], + angular_velocity: Union[np.ndarray, list], + dt: float, + ) -> tuple[float, float]: + """Integrate angular velocities to update joint angles. + + Args: + current_config: Current angles (theta1, theta2) in [0, 2*pi). + angular_velocity: Angular velocities [omega1, omega2] (2D). + dt: Time step. + + Returns: + Updated angles (theta1, theta2) wrapped to [0, 2*pi). + """ + TWO_PI = 2.0 * np.pi + theta1 = (current_config[0] + angular_velocity[0] * dt) % TWO_PI + theta2 = (current_config[1] + angular_velocity[1] * dt) % TWO_PI + return (float(theta1), float(theta2)) + + def forward_kinematics(self, config: tuple[float, float]) -> np.ndarray: + """Compute end-effector position from joint angles. + + Args: + config: Joint angles (theta1, theta2). + + Returns: + End-effector position in 2D space, shape (2,). + """ + theta1, theta2 = config + L1, L2 = self.link_lengths + x = L1 * np.cos(theta1) + L2 * np.cos(theta1 + theta2) + y = L1 * np.sin(theta1) + L2 * np.sin(theta1 + theta2) + return np.array([x, y]) + + def geodesic_distance( + self, + config1: tuple[float, float], + config2: tuple[float, float], + ) -> float: + """Compute geodesic distance between two configurations on SO(2) x SO(2). + + Uses circular distance per joint: min(|delta|, 2*pi - |delta|). + Combined as sqrt(d1^2 + d2^2). + """ + TWO_PI = 2.0 * np.pi + d1 = abs(config1[0] - config2[0]) + d1 = min(d1, TWO_PI - d1) + d2 = abs(config1[1] - config2[1]) + d2 = min(d2, TWO_PI - d2) + return float(np.sqrt(d1**2 + d2**2)) + + def sample_random_configuration( + self, rng: np.random.Generator | None = None + ) -> tuple[float, float]: + """Sample a random configuration uniformly on SO(2) x SO(2). + + Returns: + Random angles (theta1, theta2) in [0, 2*pi). + """ + if rng is None: + rng = np.random.default_rng() + theta1 = float(rng.uniform(0, 2 * np.pi)) + theta2 = float(rng.uniform(0, 2 * np.pi)) + return (theta1, theta2) diff --git a/assets/results/gru_2d_heatmaps.png b/assets/results/gru_2d_heatmaps.png new file mode 100644 index 0000000..9b308c5 Binary files /dev/null and b/assets/results/gru_2d_heatmaps.png differ diff --git a/assets/results/gru_neuron_pca_heatmap.png b/assets/results/gru_neuron_pca_heatmap.png new file mode 100644 index 0000000..6091ade Binary files /dev/null and b/assets/results/gru_neuron_pca_heatmap.png differ diff --git a/assets/results/gru_neuron_pca_tuning.png b/assets/results/gru_neuron_pca_tuning.png new file mode 100644 index 0000000..6acc1dc Binary files /dev/null and b/assets/results/gru_neuron_pca_tuning.png differ diff --git a/assets/results/gru_place_cell_pca.png b/assets/results/gru_place_cell_pca.png new file mode 100644 index 0000000..f01b4a8 Binary files /dev/null and b/assets/results/gru_place_cell_pca.png differ diff --git a/assets/results/rnn_2d_heatmaps.png b/assets/results/rnn_2d_heatmaps.png new file mode 100644 index 0000000..816d8ab Binary files /dev/null and b/assets/results/rnn_2d_heatmaps.png differ diff --git a/assets/results/rnn_neuron_pca_heatmap.png b/assets/results/rnn_neuron_pca_heatmap.png new file mode 100644 index 0000000..b3d63d9 Binary files /dev/null and b/assets/results/rnn_neuron_pca_heatmap.png differ diff --git a/assets/results/rnn_neuron_pca_tuning.png b/assets/results/rnn_neuron_pca_tuning.png new file mode 100644 index 0000000..986c97f Binary files /dev/null and b/assets/results/rnn_neuron_pca_tuning.png differ diff --git a/assets/results/rnn_place_cell_pca.png b/assets/results/rnn_place_cell_pca.png new file mode 100644 index 0000000..42a457b Binary files /dev/null and b/assets/results/rnn_place_cell_pca.png differ diff --git a/assets/results/tuning_curves.png b/assets/results/tuning_curves.png new file mode 100644 index 0000000..4f325ed Binary files /dev/null and b/assets/results/tuning_curves.png differ diff --git a/discussion.md b/discussion.md index 51a5dfd..4aed385 100644 --- a/discussion.md +++ b/discussion.md @@ -169,9 +169,67 @@ embedding = model.get_embedding(velocities) # (batch, 256) hidden = model.get_hidden_states(velocities) # (batch, seq_len, 256) ``` +## SO(2) × SO(2): Restricting to the Torus + +### Motivation + +The SO(3) analysis above uses PCA and t-SNE to visualize 256-dim hidden states in 2D. This is lossy — a neuron could have clean grid-like structure in the full space that looks like noise after projection. We can't plot neuron activation vs. true configuration because the config space is 6D. + +**Solution:** Restrict the arm to a plane. Each joint rotates in SO(2) (a circle), so the configuration space is SO(2) × SO(2) = a torus, parameterized by two angles (θ1, θ2). This is 2D — we can plot neuron activation directly on the (θ1, θ2) grid with zero information loss. + +### What Changed + +| | SO(3) × SO(3) | SO(2) × SO(2) | +|---|---|---| +| Config space | 6D (two 3D rotations) | 2D (two angles on [0, 2π)) | +| Velocity input | 6D per timestep | 2D per timestep | +| Integration | Exponential map on SO(3) | θ += ω·dt mod 2π | +| Place cell distance | Geodesic on rotation group | Circular: min(\|δθ\|, 2π - \|δθ\|) | +| Init position encoding | Place cell activation (64D) | Raw angles as (cos θ1, sin θ1, cos θ2, sin θ2) → 4D | +| Visualization | PCA/t-SNE (lossy) | Direct (θ1, θ2) heatmaps (lossless) | + +The place cell targets still use the vMF kernel (`κ·cos(d)` + softmax), 32 cells per joint, same as SO(3). The key difference for init position: instead of encoding the starting pose as a 64-dim place cell distribution, we encode it as a 4-dim vector `(cos θ1, sin θ1, cos θ2, sin θ2)` fed through `Linear(4, 256)` → h0. + +### Training Results + +GRU on SO(2)×SO(2): hidden=256, dropout=0.5, 100k training trajectories, cosine annealing LR. Training stopped at epoch 111/200. + +| Epoch | Val Accuracy | Val Loss | +|-------|-------------|----------| +| 0 | 75.7% | 0.0123 | +| 10 | 87.5% | 0.0027 | +| 30 | 94.0% | 0.0006 | +| 60 | 95.5% | 0.0004 | +| 110 | 96.9% | 0.0003 | + +**Best: 97.3% accuracy, loss 0.000311** (epoch 108). + +Compared to SO(3) GRU (97.8% after 200 epochs), SO(2) converges faster to comparable accuracy. This makes sense — the problem is simpler (2D vs 6D manifold, 2D vs 6D velocity input). + +### Representation Analysis + +![GRU SO(2) Representation Analysis](assets/results/gru_so2_representation_analysis.png) + +**PCA:** The hidden representation is low-dimensional — a small number of PCs capture most of the variance. The PC1 vs PC2 projections colored by θ1 and θ2 show smooth gradients, confirming the network encodes joint angles in its top principal components. + +**Torus heatmaps:** The key result. Each heatmap shows a single neuron's mean activation across the (θ1, θ2) configuration space. θ1-selective neurons show vertical stripes (responding to θ1 regardless of θ2), and θ2-selective neurons show horizontal stripes. This is direct evidence that individual neurons develop **angle-selective tuning** — they've learned to represent specific joint angles. + +### Commands + +```bash +# Generate SO(2) data +python scripts/generate_data.py --manifold so2 --n_train 100000 --n_val 5000 --workers 16 + +# Train +bash scripts/train.sh gru 0 so2 + +# Analyze +PYTHONPATH=. python scripts/analyze_representation.py logs/estimation/gru_so2/checkpoints/last.ckpt --manifold so2 +``` + ## Next Steps -- **Try GRU/LSTM** — May produce sharper representations than vanilla RNN -- **Longer training** — Banino et al. trained significantly longer; 200 epochs may not be enough for global structure -- **Regularization sweep** — Vary dropout, weight decay -- **Look for grid-like patterns** — Fourier analysis of hidden states, spatial autocorrelation on SO(3) +- **Longer SO(2) training** — We stopped at 111 epochs; running to 200 may improve further +- **Look for grid-like patterns** — Fourier analysis of hidden states, spatial autocorrelation on the torus +- **Regularization sweep** — Vary dropout, weight decay to see effect on representation structure +- **Compare SO(2) vs SO(3) representations** — Are SO(3) neurons also angle-selective, just harder to visualize? diff --git a/scripts/analyze_representation.py b/scripts/analyze_representation.py index f510665..0fa35ad 100644 --- a/scripts/analyze_representation.py +++ b/scripts/analyze_representation.py @@ -1,9 +1,11 @@ """Analyze learned representations from trained state estimation models. Generates PCA, t-SNE, and tuning curve plots on validation data. +For SO(2) models, generates direct (θ1, θ2) heatmaps — no PCA needed. Usage: python scripts/analyze_representation.py logs/estimation/rnn/checkpoints/last-v3.ckpt + python scripts/analyze_representation.py path/to/checkpoint.ckpt --manifold so2 python scripts/analyze_representation.py path/to/checkpoint.ckpt --n_traj 500 python scripts/analyze_representation.py path/to/checkpoint.ckpt --help """ @@ -31,8 +33,15 @@ def parse_args(): parser.add_argument( "--data_path", type=str, - default="data/estimation/trajectories.pt", - help="Path to trajectories.pt", + default=None, + help="Path to trajectories.pt (auto from manifold if not set)", + ) + parser.add_argument( + "--manifold", + type=str, + default="so3", + choices=["so3", "so2"], + help="Configuration manifold", ) parser.add_argument( "--n_traj", type=int, default=200, help="Number of val trajectories to use" @@ -46,27 +55,20 @@ def parse_args(): return parser.parse_args() -def get_joint_angles(data, n_traj, seq_len): - """Reconstruct true continuous joint rotations by replaying velocity integration. - - This gives the actual rotation at each timestep — not the quantized - nearest-cell-center approximation. - """ +def get_joint_angles_so3(data, n_traj, seq_len): + """Reconstruct true continuous joint rotations by replaying velocity integration.""" from articulated.shared.robot_arm import RobotArmKinematics kin = RobotArmKinematics() - vel = data["val_velocities"][:n_traj].numpy() # (n_traj, seq_len, 6) + vel = data["val_velocities"][:n_traj].numpy() dt = 0.01 - # Place cell centers for reconstructing initial config centers_R1 = Rotation.from_quat(data["place_cell_quats1"].numpy()) centers_R2 = Rotation.from_quat(data["place_cell_quats2"].numpy()) n_per_joint = len(centers_R1) - # Get initial place cell activations to find initial rotation - init_pcs = data["val_init_pcs"][:n_traj].numpy() # (n_traj, n_place_cells) + init_pcs = data["val_init_pcs"][:n_traj].numpy() - # Cell IDs for coloring (from targets) targets = data["val_targets"][:n_traj].numpy() j1_tgt = targets[:, :, :n_per_joint].reshape(-1, n_per_joint) j2_tgt = targets[:, :, n_per_joint:].reshape(-1, n_per_joint) @@ -77,7 +79,6 @@ def get_joint_angles(data, n_traj, seq_len): j2_rotvec = np.zeros((n_traj, seq_len, 3)) for i in range(n_traj): - # Reconstruct initial config from init_pc by picking nearest cell center init_j1 = init_pcs[i, :n_per_joint] init_j2 = init_pcs[i, n_per_joint:] R1_init = centers_R1[int(init_j1.argmax())] @@ -96,6 +97,44 @@ def get_joint_angles(data, n_traj, seq_len): return j1_rotvec, j2_rotvec, j1_cell, j2_cell +def get_joint_angles_so2(data, n_traj, seq_len): + """Reconstruct true joint angles by replaying velocity integration on SO(2).""" + from articulated.shared.robot_arm import RobotArm2DKinematics + + kin = RobotArm2DKinematics() + vel = data["val_velocities"][:n_traj].numpy() + dt = 0.01 + + # Recover initial angles from (cos θ, sin θ) encoding + init_encoded = data["val_init_angles"][:n_traj].numpy() # (n_traj, 4) + + n_per_joint = len(data["place_cell_angles1"]) + targets = data["val_targets"][:n_traj].numpy() + j1_tgt = targets[:, :, :n_per_joint].reshape(-1, n_per_joint) + j2_tgt = targets[:, :, n_per_joint:].reshape(-1, n_per_joint) + j1_cell = j1_tgt.argmax(axis=1) + j2_cell = j2_tgt.argmax(axis=1) + + j1_angles = np.zeros((n_traj, seq_len)) + j2_angles = np.zeros((n_traj, seq_len)) + + for i in range(n_traj): + theta1 = np.arctan2(init_encoded[i, 1], init_encoded[i, 0]) % (2 * np.pi) + theta2 = np.arctan2(init_encoded[i, 3], init_encoded[i, 2]) % (2 * np.pi) + config = (float(theta1), float(theta2)) + + for t in range(seq_len): + omega = vel[i, t] + config = kin.integrate_velocity(config, omega, dt) + j1_angles[i, t] = config[0] + j2_angles[i, t] = config[1] + + j1_angles = j1_angles.reshape(-1) + j2_angles = j2_angles.reshape(-1) + + return j1_angles, j2_angles, j1_cell, j2_cell + + def find_selective_neurons(hidden_flat, angles, top_k=3): """Find neurons most correlated with a joint angle variable.""" correlations = np.array( @@ -104,31 +143,248 @@ def find_selective_neurons(hidden_flat, angles, top_k=3): for i in range(hidden_flat.shape[1]) ] ) - # Replace NaN with 0 (constant neurons) correlations = np.nan_to_num(correlations) top_idx = np.argsort(correlations)[::-1][:top_k] return top_idx, correlations[top_idx] -def main(): - args = parse_args() +def _plot_tuning_curve(ax, activations, angles, title, xlabel, color): + """Plot binned tuning curve with scatter and confidence band.""" + n_bins = 50 + bins = np.linspace(angles.min(), angles.max(), n_bins + 1) + bin_centers = (bins[:-1] + bins[1:]) / 2 + bin_means = np.zeros(n_bins) + bin_stds = np.zeros(n_bins) - print(f"Loading checkpoint: {args.checkpoint}") - model = StateEstimationModel.load_from_checkpoint( - args.checkpoint, map_location="cpu" + for b in range(n_bins): + mask = (angles >= bins[b]) & (angles < bins[b + 1]) + if mask.sum() > 0: + bin_means[b] = activations[mask].mean() + bin_stds[b] = activations[mask].std() + + ax.fill_between( + bin_centers, bin_means - bin_stds, bin_means + bin_stds, alpha=0.2, color=color ) - model.eval() + ax.plot(bin_centers, bin_means, "-", linewidth=2, color=color) - print(f"Loading data: {args.data_path}") - data = torch.load(args.data_path, weights_only=True) + n_show = min(2000, len(angles)) + idx = np.random.default_rng(0).choice(len(angles), n_show, replace=False) + ax.scatter(angles[idx], activations[idx], s=1, alpha=0.08, color="gray") - n_traj = min(args.n_traj, data["val_velocities"].shape[0]) + ax.set_xlabel(xlabel) + ax.set_ylabel("Activation") + ax.set_title(title) + + +def analyze_so2(args, model, data, n_traj, seq_len, hidden_size): + """SO(2)-specific analysis with direct (θ1, θ2) heatmaps.""" + val_vel = data["val_velocities"][:n_traj] + val_init = data["val_init_angles"][:n_traj] + + print(f"Running inference on {n_traj} trajectories...") + with torch.no_grad(): + h0 = model._encode_init_pos(val_init) + _, hidden_states = model(val_vel, hidden=h0) + + hidden_flat = hidden_states.numpy().reshape(-1, hidden_size) + n_points = hidden_flat.shape[0] + print(f"Hidden states: {n_points} points x {hidden_size} dims") + + j1_angles, j2_angles, j1_cell, j2_cell = get_joint_angles_so2(data, n_traj, seq_len) + + # PCA + print("Running PCA...") + pca = PCA(n_components=min(50, hidden_size)) + hidden_pca = pca.fit_transform(hidden_flat) + cumvar = np.cumsum(pca.explained_variance_ratio_) + n95 = int(np.searchsorted(cumvar, 0.95) + 1) + + # Find most selective neurons for θ1 and θ2 + top_j1, corr_j1 = find_selective_neurons(hidden_flat, j1_angles, top_k=3) + top_j2, corr_j2 = find_selective_neurons(hidden_flat, j2_angles, top_k=3) + print(f"Most θ1-selective neurons: {top_j1} (r={corr_j1})") + print(f"Most θ2-selective neurons: {top_j2} (r={corr_j2})") + + # ===== PLOT ===== + model_type = model.hparams.get("model_type", "unknown").upper() + fig = plt.figure(figsize=(20, 24)) + gs = GridSpec(4, 3, figure=fig, hspace=0.45, wspace=0.35) + + fig.suptitle( + f"{model_type} Representation Analysis — SO(2)×SO(2) " + f"(vMF, kappa=5, hidden={hidden_size})", + fontsize=16, + fontweight="bold", + y=0.98, + ) + + # ── Row 1: PCA spectrum + summary ──────────────────────────────────────── + ax = fig.add_subplot(gs[0, 0]) + ax.bar( + range(min(20, len(pca.explained_variance_ratio_))), + pca.explained_variance_ratio_[:20], + color="steelblue", + ) + ax.set_xlabel("PC Index") + ax.set_ylabel("Explained Variance Ratio") + ax.set_title("PCA Spectrum") + + ax = fig.add_subplot(gs[0, 1]) + ax.plot(range(1, min(21, len(cumvar) + 1)), cumvar[:20], "o-", color="steelblue") + ax.axhline(y=0.95, color="red", linestyle="--", alpha=0.7, label="95%") + ax.axvline(x=n95, color="red", linestyle=":", alpha=0.7) + ax.set_xlabel("Number of PCs") + ax.set_ylabel("Cumulative Variance") + ax.set_title(f"Cumulative Variance\n{n95} PCs for 95%") + ax.legend() + + ax = fig.add_subplot(gs[0, 2]) + ax.text( + 0.5, + 0.5, + f"Model: {model_type} (hidden={hidden_size})\n" + f"Manifold: SO(2)×SO(2)\n" + f"Output: {model.output_size} ({model.cells_per_joint}/joint)\n" + f"Dropout: {model.hparams.get('dropout', '?')}\n\n" + f"PCs for 95%: {n95}\n" + f"PC1: {pca.explained_variance_ratio_[0]*100:.1f}%\n" + f"Top-3: {cumvar[2]*100:.1f}%\n\n" + f"θ1-selective: {top_j1}\n" + f" r = [{', '.join(f'{c:.3f}' for c in corr_j1)}]\n" + f"θ2-selective: {top_j2}\n" + f" r = [{', '.join(f'{c:.3f}' for c in corr_j2)}]", + transform=ax.transAxes, + fontsize=10, + verticalalignment="center", + horizontalalignment="center", + fontfamily="monospace", + bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5), + ) + ax.set_title("Summary") + ax.axis("off") + + # ── Row 2: PCA colored by θ1, θ2, and cell ID ─────────────────────────── + n_plot = min(3000, n_points) + plot_idx = np.random.default_rng(0).choice(n_points, n_plot, replace=False) + + ax = fig.add_subplot(gs[1, 0]) + sc = ax.scatter( + hidden_pca[plot_idx, 0], + hidden_pca[plot_idx, 1], + c=j1_angles[plot_idx], + cmap="hsv", + s=2, + alpha=0.5, + ) + plt.colorbar(sc, ax=ax, label="θ1") + ax.set_xlabel("PC1") + ax.set_ylabel("PC2") + ax.set_title("PC1 vs PC2\nColored by θ1") + + ax = fig.add_subplot(gs[1, 1]) + sc = ax.scatter( + hidden_pca[plot_idx, 0], + hidden_pca[plot_idx, 1], + c=j2_angles[plot_idx], + cmap="hsv", + s=2, + alpha=0.5, + ) + plt.colorbar(sc, ax=ax, label="θ2") + ax.set_xlabel("PC1") + ax.set_ylabel("PC2") + ax.set_title("PC1 vs PC2\nColored by θ2") + + ax = fig.add_subplot(gs[1, 2]) + sc = ax.scatter( + hidden_pca[plot_idx, 0], + hidden_pca[plot_idx, 1], + c=j1_cell[plot_idx], + cmap="tab20", + s=2, + alpha=0.5, + ) + plt.colorbar(sc, ax=ax, label="Dominant cell ID") + ax.set_xlabel("PC1") + ax.set_ylabel("PC2") + ax.set_title("PC1 vs PC2\nColored by cell ID") + + # ── Row 3: Direct (θ1, θ2) heatmaps — θ1-selective neurons ───────────── + n_bins = 40 + for i in range(3): + neuron_idx = top_j1[i] + ax = fig.add_subplot(gs[2, i]) + _plot_torus_heatmap( + ax, + j1_angles, + j2_angles, + hidden_flat[:, neuron_idx], + n_bins, + f"Neuron {neuron_idx} — θ1-selective (r={corr_j1[i]:.3f})", + ) + + # ── Row 4: Direct (θ1, θ2) heatmaps — θ2-selective neurons ────────── + for i in range(3): + neuron_idx = top_j2[i] + ax = fig.add_subplot(gs[3, i]) + _plot_torus_heatmap( + ax, + j1_angles, + j2_angles, + hidden_flat[:, neuron_idx], + n_bins, + f"Neuron {neuron_idx} — θ2-selective (r={corr_j2[i]:.3f})", + ) + + # Save + if args.output: + output_path = args.output + else: + output_path = str( + Path(args.checkpoint).parent.parent / "representation_analysis_so2.png" + ) + plt.savefig(output_path, dpi=150, bbox_inches="tight") + print(f"\nSaved to {output_path}") + + +def _plot_torus_heatmap(ax, theta1, theta2, activations, n_bins, title): + """Plot neuron activation as a heatmap on (θ1, θ2) torus.""" + bins = np.linspace(0, 2 * np.pi, n_bins + 1) + heatmap = np.full((n_bins, n_bins), np.nan) + + for i in range(n_bins): + for j in range(n_bins): + mask = ( + (theta1 >= bins[i]) + & (theta1 < bins[i + 1]) + & (theta2 >= bins[j]) + & (theta2 < bins[j + 1]) + ) + if mask.sum() > 0: + heatmap[j, i] = activations[mask].mean() + + im = ax.imshow( + heatmap, + extent=[0, 2 * np.pi, 0, 2 * np.pi], + origin="lower", + aspect="equal", + cmap="viridis", + ) + plt.colorbar(im, ax=ax, label="Mean activation") + ax.set_xlabel("θ1") + ax.set_ylabel("θ2") + ax.set_title(title) + ax.set_xticks([0, np.pi, 2 * np.pi]) + ax.set_xticklabels(["0", "π", "2π"]) + ax.set_yticks([0, np.pi, 2 * np.pi]) + ax.set_yticklabels(["0", "π", "2π"]) + + +def analyze_so3(args, model, data, n_traj, seq_len, hidden_size): + """SO(3)-specific analysis (original code path).""" val_vel = data["val_velocities"][:n_traj] val_init = data["val_init_pcs"][:n_traj] - seq_len = val_vel.shape[1] - hidden_size = model.hidden_size - # Forward pass print(f"Running inference on {n_traj} trajectories...") with torch.no_grad(): h0 = model._encode_init_pos(val_init) @@ -138,8 +394,7 @@ def main(): n_points = hidden_flat.shape[0] print(f"Hidden states: {n_points} points x {hidden_size} dims") - # Joint rotation vectors (3D per joint) - j1_rotvec, j2_rotvec, j1_cell, j2_cell = get_joint_angles(data, n_traj, seq_len) + j1_rotvec, j2_rotvec, j1_cell, j2_cell = get_joint_angles_so3(data, n_traj, seq_len) axis_labels = ["x", "y", "z"] # PCA @@ -149,14 +404,14 @@ def main(): cumvar = np.cumsum(pca.explained_variance_ratio_) n95 = int(np.searchsorted(cumvar, 0.95) + 1) - # t-SNE (subsample for speed) + # t-SNE n_tsne = min(5000, n_points) tsne_idx = np.random.default_rng(42).choice(n_points, n_tsne, replace=False) print(f"Running t-SNE on {n_tsne} points (perplexity={args.tsne_perplexity})...") tsne = TSNE(n_components=2, perplexity=args.tsne_perplexity, random_state=42) hidden_tsne = tsne.fit_transform(hidden_flat[tsne_idx]) - # Find most selective neurons (check all 3 components per joint, pick best) + # Find most selective neurons best_j1_neurons, best_j1_corr, best_j1_comp = [], [], [] for comp in range(3): idx, corr = find_selective_neurons(hidden_flat, j1_rotvec[:, comp], top_k=1) @@ -173,7 +428,7 @@ def main(): print(f"Most J2-selective neurons: {best_j2_neurons} (r={best_j2_corr})") # ===== PLOT ===== - n_plot = min(3000, n_points) # subsample for scatter readability + n_plot = min(3000, n_points) plot_idx = np.random.default_rng(0).choice(n_points, n_plot, replace=False) fig = plt.figure(figsize=(20, 28)) @@ -188,7 +443,7 @@ def main(): y=0.98, ) - # ── Row 1: PCA spectrum ────────────────────────────────────────────────── + # Row 1: PCA spectrum ax = fig.add_subplot(gs[0, 0]) ax.bar( range(min(20, len(pca.explained_variance_ratio_))), @@ -197,7 +452,7 @@ def main(): ) ax.set_xlabel("PC Index") ax.set_ylabel("Explained Variance Ratio") - ax.set_title("(A) PCA Spectrum\nVariance per principal component") + ax.set_title("PCA Spectrum") ax = fig.add_subplot(gs[0, 1]) ax.plot(range(1, min(21, len(cumvar) + 1)), cumvar[:20], "o-", color="steelblue") @@ -205,7 +460,7 @@ def main(): ax.axvline(x=n95, color="red", linestyle=":", alpha=0.7) ax.set_xlabel("Number of PCs") ax.set_ylabel("Cumulative Variance") - ax.set_title(f"(B) Cumulative Variance\n{n95} PCs for 95%") + ax.set_title(f"Cumulative Variance\n{n95} PCs for 95%") ax.legend() ax = fig.add_subplot(gs[0, 2]) @@ -234,7 +489,7 @@ def main(): ax.set_title("Summary") ax.axis("off") - # ── Row 2: PCA colored by J1 rotvec x, y, z ───────────────────────────── + # Row 2: PCA colored by J1 rotvec x, y, z for col, comp in enumerate(range(3)): ax = fig.add_subplot(gs[1, col]) sc = ax.scatter( @@ -248,11 +503,9 @@ def main(): plt.colorbar(sc, ax=ax, label=f"J1 rotvec {axis_labels[comp]}") ax.set_xlabel("PC1") ax.set_ylabel("PC2") - ax.set_title( - f"(C{comp+1}) PC1 vs PC2\nJ1 rotation {axis_labels[comp]}-component" - ) + ax.set_title(f"PC1 vs PC2\nJ1 rotation {axis_labels[comp]}-component") - # ── Row 3: t-SNE colored by J1 x, J2 x, cell ID ──────────────────────── + # Row 3: t-SNE ax = fig.add_subplot(gs[2, 0]) sc = ax.scatter( hidden_tsne[:, 0], @@ -265,7 +518,7 @@ def main(): plt.colorbar(sc, ax=ax, label="J1 rotvec x") ax.set_xlabel("t-SNE 1") ax.set_ylabel("t-SNE 2") - ax.set_title("(D1) t-SNE\nJ1 rotation x-component") + ax.set_title("t-SNE\nJ1 rotation x-component") ax = fig.add_subplot(gs[2, 1]) sc = ax.scatter( @@ -279,7 +532,7 @@ def main(): plt.colorbar(sc, ax=ax, label="J2 rotvec x") ax.set_xlabel("t-SNE 1") ax.set_ylabel("t-SNE 2") - ax.set_title("(D2) t-SNE\nJ2 rotation x-component") + ax.set_title("t-SNE\nJ2 rotation x-component") ax = fig.add_subplot(gs[2, 2]) sc = ax.scatter( @@ -293,9 +546,9 @@ def main(): plt.colorbar(sc, ax=ax, label="Dominant cell ID") ax.set_xlabel("t-SNE 1") ax.set_ylabel("t-SNE 2") - ax.set_title("(D3) t-SNE\nColored by dominant cell ID") + ax.set_title("t-SNE\nColored by dominant cell ID") - # ── Row 4: Tuning — most J1-selective neurons (one per x,y,z) ──────────── + # Row 4: Tuning — most J1-selective neurons for i in range(3): neuron_idx = best_j1_neurons[i] corr = best_j1_corr[i] @@ -310,7 +563,7 @@ def main(): "blue", ) - # ── Row 5: Tuning — most J2-selective neurons (one per x,y,z) ──────────── + # Row 5: Tuning — most J2-selective neurons for i in range(3): neuron_idx = best_j2_neurons[i] corr = best_j2_corr[i] @@ -336,33 +589,35 @@ def main(): print(f"\nSaved to {output_path}") -def _plot_tuning_curve(ax, activations, angles, title, xlabel, color): - """Plot binned tuning curve with scatter and confidence band.""" - n_bins = 50 - bins = np.linspace(angles.min(), angles.max(), n_bins + 1) - bin_centers = (bins[:-1] + bins[1:]) / 2 - bin_means = np.zeros(n_bins) - bin_stds = np.zeros(n_bins) +def main(): + args = parse_args() - for b in range(n_bins): - mask = (angles >= bins[b]) & (angles < bins[b + 1]) - if mask.sum() > 0: - bin_means[b] = activations[mask].mean() - bin_stds[b] = activations[mask].std() + if args.data_path is None: + if args.manifold == "so2": + args.data_path = "data/estimation/trajectories_so2.pt" + else: + args.data_path = "data/estimation/trajectories.pt" - ax.fill_between( - bin_centers, bin_means - bin_stds, bin_means + bin_stds, alpha=0.2, color=color + print(f"Loading checkpoint: {args.checkpoint}") + model = StateEstimationModel.load_from_checkpoint( + args.checkpoint, map_location="cpu" ) - ax.plot(bin_centers, bin_means, "-", linewidth=2, color=color) + model.eval() - # Subsample scatter - n_show = min(2000, len(angles)) - idx = np.random.default_rng(0).choice(len(angles), n_show, replace=False) - ax.scatter(angles[idx], activations[idx], s=1, alpha=0.08, color="gray") + print(f"Loading data: {args.data_path}") + data = torch.load(args.data_path, weights_only=True) - ax.set_xlabel(xlabel) - ax.set_ylabel("Activation") - ax.set_title(title) + n_traj = min(args.n_traj, data["val_velocities"].shape[0]) + seq_len = ( + data["val_velocities"].shape[2] if data["val_velocities"].ndim > 2 else 100 + ) + seq_len = data["val_velocities"][:n_traj].shape[1] + hidden_size = model.hidden_size + + if args.manifold == "so2": + analyze_so2(args, model, data, n_traj, seq_len, hidden_size) + else: + analyze_so3(args, model, data, n_traj, seq_len, hidden_size) if __name__ == "__main__": diff --git a/scripts/generate_data.py b/scripts/generate_data.py index 2470cf7..8bba59b 100644 --- a/scripts/generate_data.py +++ b/scripts/generate_data.py @@ -1,12 +1,13 @@ """Pre-generate trajectory data for state estimation training. -Generates training and validation trajectories on SO(3) x SO(3) and saves -them to a .pt file for fast, reproducible loading during training. +Generates training and validation trajectories on SO(3)xSO(3) or SO(2)xSO(2) +and saves them to a .pt file for fast, reproducible loading during training. Uses multiprocessing for ~10-20x speedup on multi-core machines. Usage: python scripts/generate_data.py + python scripts/generate_data.py --manifold so2 python scripts/generate_data.py --n_train 100000 --n_val 5000 python scripts/generate_data.py --workers 32 python scripts/generate_data.py --help @@ -22,13 +23,15 @@ from scipy.spatial.transform import Rotation from tqdm import tqdm -from articulated.shared.robot_arm import RobotArmKinematics +from articulated.shared.robot_arm import RobotArm2DKinematics, RobotArmKinematics +# ── SO(3) worker functions ─────────────────────────────────────────────────── -def _worker_init( + +def _worker_init_so3( pc_quats1, pc_quats2, place_cell_kappa, seq_length, dt, vel_sigma, vel_theta ): - """Initialize worker process with shared place cell centers.""" + """Initialize worker process with shared place cell centers (SO(3)).""" global _w_centers_R1, _w_centers_R2, _w_kappa, _w_seq_length, _w_dt global _w_vel_sigma, _w_vel_theta, _w_kinematics _w_centers_R1 = Rotation.from_quat(pc_quats1) @@ -41,13 +44,13 @@ def _worker_init( _w_kinematics = RobotArmKinematics() -def _generate_one(seed): - """Generate a single trajectory in a worker process.""" +def _generate_one_so3(seed): + """Generate a single SO(3) trajectory in a worker process.""" rng = np.random.default_rng(seed) config = _w_kinematics.sample_random_configuration(rng) # Initial place cell activation - init_pc = _compute_pc(config) + init_pc = _compute_pc_so3(config) n_total = len(_w_centers_R1) + len(_w_centers_R2) decay = np.exp(-_w_vel_theta * _w_dt) @@ -61,13 +64,13 @@ def _generate_one(seed): omega = decay * omega + noise_scale * rng.standard_normal(6) velocities[t] = omega config = _w_kinematics.integrate_velocity(config, omega, _w_dt) - targets[t] = _compute_pc(config) + targets[t] = _compute_pc_so3(config) return velocities, targets, init_pc -def _compute_pc(config): - """Compute place cell activations using vMF kernel, separate per joint.""" +def _compute_pc_so3(config): + """Compute place cell activations using vMF kernel, separate per joint (SO(3)).""" R1_cur, R2_cur = config kappa = _w_kappa @@ -88,11 +91,93 @@ def _compute_pc(config): return np.concatenate([pc1, pc2]) +# ── SO(2) worker functions ─────────────────────────────────────────────────── + + +def _worker_init_so2( + pc_angles1, pc_angles2, place_cell_kappa, seq_length, dt, vel_sigma, vel_theta +): + """Initialize worker process with shared place cell centers (SO(2)).""" + global _w_centers_a1, _w_centers_a2, _w_kappa, _w_seq_length, _w_dt + global _w_vel_sigma, _w_vel_theta, _w_kinematics + _w_centers_a1 = pc_angles1 + _w_centers_a2 = pc_angles2 + _w_kappa = place_cell_kappa + _w_seq_length = seq_length + _w_dt = dt + _w_vel_sigma = vel_sigma + _w_vel_theta = vel_theta + _w_kinematics = RobotArm2DKinematics() + + +def _generate_one_so2(seed): + """Generate a single SO(2) trajectory in a worker process.""" + rng = np.random.default_rng(seed) + config = _w_kinematics.sample_random_configuration(rng) + + # Initial position as (cos θ1, sin θ1, cos θ2, sin θ2) + theta1, theta2 = config + init_angles = np.array( + [ + np.cos(theta1), + np.sin(theta1), + np.cos(theta2), + np.sin(theta2), + ] + ) + + n_total = len(_w_centers_a1) + len(_w_centers_a2) + decay = np.exp(-_w_vel_theta * _w_dt) + noise_scale = _w_vel_sigma * np.sqrt(1.0 - decay**2) + + velocities = np.zeros((_w_seq_length, 2)) + targets = np.zeros((_w_seq_length, n_total)) + omega = np.zeros(2) + + for t in range(_w_seq_length): + omega = decay * omega + noise_scale * rng.standard_normal(2) + velocities[t] = omega + config = _w_kinematics.integrate_velocity(config, omega, _w_dt) + targets[t] = _compute_pc_so2(config) + + return velocities, targets, init_angles + + +def _compute_pc_so2(config): + """Compute place cell activations using vMF kernel, separate per joint (SO(2)).""" + theta1, theta2 = config + kappa = _w_kappa + TWO_PI = 2.0 * np.pi + + # Circular distance per joint + delta1 = np.abs(theta1 - _w_centers_a1) + d1 = np.minimum(delta1, TWO_PI - delta1) + delta2 = np.abs(theta2 - _w_centers_a2) + d2 = np.minimum(delta2, TWO_PI - delta2) + + # vMF kernel + softmax per joint + logits1 = kappa * np.cos(d1) + logits1 -= logits1.max() + exp1 = np.exp(logits1) + pc1 = exp1 / exp1.sum() + + logits2 = kappa * np.cos(d2) + logits2 -= logits2.max() + exp2 = np.exp(logits2) + pc2 = exp2 / exp2.sum() + + return np.concatenate([pc1, pc2]) + + +# ── Common parallel generation ─────────────────────────────────────────────── + + def generate_parallel( n_trajectories, base_seed, - pc_quats1, - pc_quats2, + manifold, + pc_data1, + pc_data2, place_cell_kappa, seq_length, n_place_cells, @@ -102,17 +187,26 @@ def generate_parallel( n_workers, ): """Generate trajectories using multiprocessing.""" - # Each trajectory gets a unique seed derived from base_seed rng = np.random.default_rng(base_seed) seeds = rng.integers(0, 2**31, size=n_trajectories) - velocities = np.zeros((n_trajectories, seq_length, 6)) + vel_dim = 2 if manifold == "so2" else 6 + init_dim = 4 if manifold == "so2" else n_place_cells + + velocities = np.zeros((n_trajectories, seq_length, vel_dim)) targets = np.zeros((n_trajectories, seq_length, n_place_cells)) - init_pcs = np.zeros((n_trajectories, n_place_cells)) + init_data = np.zeros((n_trajectories, init_dim)) + + if manifold == "so2": + initializer = _worker_init_so2 + worker_fn = _generate_one_so2 + else: + initializer = _worker_init_so3 + worker_fn = _generate_one_so3 initargs = ( - pc_quats1, - pc_quats2, + pc_data1, + pc_data2, place_cell_kappa, seq_length, dt, @@ -120,21 +214,28 @@ def generate_parallel( vel_theta, ) - with Pool(n_workers, initializer=_worker_init, initargs=initargs) as pool: - results = pool.imap(_generate_one, seeds, chunksize=64) - for i, (vel, tgt, ipc) in enumerate( + with Pool(n_workers, initializer=initializer, initargs=initargs) as pool: + results = pool.imap(worker_fn, seeds, chunksize=64) + for i, (vel, tgt, init) in enumerate( tqdm(results, total=n_trajectories, desc="Generating trajectories") ): velocities[i] = vel targets[i] = tgt - init_pcs[i] = ipc + init_data[i] = init - return velocities, targets, init_pcs + return velocities, targets, init_data def parse_args(): parser = argparse.ArgumentParser(description="Pre-generate trajectory data") + parser.add_argument( + "--manifold", + type=str, + default="so3", + choices=["so3", "so2"], + help="Configuration manifold: so3 (6D) or so2 (2D torus)", + ) parser.add_argument( "--n_train", type=int, default=100000, help="Training trajectories" ) @@ -162,8 +263,8 @@ def parse_args(): parser.add_argument( "--output", type=str, - default="data/estimation/trajectories.pt", - help="Output .pt file path", + default=None, + help="Output .pt file path (default: auto based on manifold)", ) return parser.parse_args() @@ -172,27 +273,46 @@ def parse_args(): def main(): args = parse_args() + if args.output is None: + if args.manifold == "so2": + args.output = "data/estimation/trajectories_so2.pt" + else: + args.output = "data/estimation/trajectories.pt" + output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) n_per_joint = args.n_place_cells // 2 - print(f"Generating {args.n_train} train + {args.n_val} val trajectories") + print( + f"Generating {args.n_train} train + {args.n_val} val trajectories ({args.manifold})" + ) print( f" seq_length={args.seq_length}, n_place_cells={args.n_place_cells} ({n_per_joint} per joint)" ) print(f" kappa={args.place_cell_kappa}, seed={args.seed}, workers={args.workers}") print() - # Initialize separate place cells for each joint rng = np.random.default_rng(args.seed) - centers_R1 = Rotation.random(n_per_joint, random_state=int(rng.integers(0, 2**31))) - centers_R2 = Rotation.random(n_per_joint, random_state=int(rng.integers(0, 2**31))) - pc_quats1 = centers_R1.as_quat() - pc_quats2 = centers_R2.as_quat() + + if args.manifold == "so2": + # Random angles on [0, 2*pi) for each joint + pc_data1 = rng.uniform(0, 2 * np.pi, size=n_per_joint) + pc_data2 = rng.uniform(0, 2 * np.pi, size=n_per_joint) + else: + # Random rotations on SO(3) for each joint + centers_R1 = Rotation.random( + n_per_joint, random_state=int(rng.integers(0, 2**31)) + ) + centers_R2 = Rotation.random( + n_per_joint, random_state=int(rng.integers(0, 2**31)) + ) + pc_data1 = centers_R1.as_quat() + pc_data2 = centers_R2.as_quat() common = dict( - pc_quats1=pc_quats1, - pc_quats2=pc_quats2, + manifold=args.manifold, + pc_data1=pc_data1, + pc_data2=pc_data2, place_cell_kappa=args.place_cell_kappa, seq_length=args.seq_length, n_place_cells=args.n_place_cells, @@ -227,13 +347,10 @@ def main(): data = { "train_velocities": torch.from_numpy(train_vel).float(), "train_targets": torch.from_numpy(train_tgt).float(), - "train_init_pcs": torch.from_numpy(train_init).float(), "val_velocities": torch.from_numpy(val_vel).float(), "val_targets": torch.from_numpy(val_tgt).float(), - "val_init_pcs": torch.from_numpy(val_init).float(), - "place_cell_quats1": torch.from_numpy(pc_quats1).float(), - "place_cell_quats2": torch.from_numpy(pc_quats2).float(), "metadata": { + "manifold": args.manifold, "n_train": args.n_train, "n_val": args.n_val, "seq_length": args.seq_length, @@ -245,6 +362,17 @@ def main(): }, } + if args.manifold == "so2": + data["train_init_angles"] = torch.from_numpy(train_init).float() + data["val_init_angles"] = torch.from_numpy(val_init).float() + data["place_cell_angles1"] = torch.from_numpy(pc_data1).float() + data["place_cell_angles2"] = torch.from_numpy(pc_data2).float() + else: + data["train_init_pcs"] = torch.from_numpy(train_init).float() + data["val_init_pcs"] = torch.from_numpy(val_init).float() + data["place_cell_quats1"] = torch.from_numpy(pc_data1).float() + data["place_cell_quats2"] = torch.from_numpy(pc_data2).float() + print(f"\nSaving to {output_path}...") torch.save(data, output_path) diff --git a/scripts/train.py b/scripts/train.py index 008a0c6..0d5b24c 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -64,6 +64,22 @@ def parse_args(): parser.add_argument( "--weight_decay", type=float, default=1e-4, help="L2 regularization" ) + parser.add_argument( + "--input_size", type=int, default=None, help="Input size (auto from manifold)" + ) + parser.add_argument( + "--init_pos_size", + type=int, + default=None, + help="Init position size (auto from manifold)", + ) + parser.add_argument( + "--manifold", + type=str, + default="so3", + choices=["so3", "so2"], + help="Configuration manifold", + ) # Training parser.add_argument("--epochs", type=int, default=100, help="Number of epochs") @@ -108,9 +124,23 @@ def main(): except ValueError: devices = args.devices + # Derive input_size and init_pos_size from manifold if not specified + if args.input_size is None: + input_size = 2 if args.manifold == "so2" else 6 + else: + input_size = args.input_size + + if args.init_pos_size is None: + init_pos_size = 4 if args.manifold == "so2" else args.n_place_cells + else: + init_pos_size = args.init_pos_size + print( f"Model: {args.model_type.upper()}, hidden={args.hidden_size}, dropout={args.dropout}" ) + print( + f"Manifold: {args.manifold} (input_size={input_size}, init_pos_size={init_pos_size})" + ) print( f"Training: {args.epochs} epochs, lr={args.lr}, wd={args.weight_decay}, grad_clip={args.grad_clip}" ) @@ -123,7 +153,7 @@ def main(): print( f"Data: generating {args.n_train} train, {args.n_val} val, seq_len={args.seq_length}" ) - print(f"Logging: {'WandB' if args.wandb else 'TensorBoard + CSV'}") + print(f"Logging: {'WandB' if args.wandb else 'CSV'}") print(f"Output: {output_dir}") print() @@ -139,11 +169,12 @@ def main(): provide_init_pos=True, data_path=args.data_path, num_workers=args.num_workers, + manifold=args.manifold, ) # Model model = StateEstimationModel( - input_size=6, + input_size=input_size, hidden_size=args.hidden_size, output_size=args.n_place_cells, model_type=args.model_type, @@ -151,6 +182,7 @@ def main(): weight_decay=args.weight_decay, use_init_pos=True, dropout=args.dropout, + init_pos_size=init_pos_size, ) # Loggers diff --git a/scripts/train.sh b/scripts/train.sh index b8f9237..3ab6316 100755 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -2,13 +2,14 @@ # Train state estimation model # # Usage: -# bash scripts/train.sh gru 0 # train GRU on GPU 0 +# bash scripts/train.sh gru 0 # train GRU on GPU 0 (SO(3)) +# bash scripts/train.sh gru 0 so2 # train GRU on GPU 0 (SO(2)) # bash scripts/train.sh lstm 0,1 # train LSTM on GPU 0 and 1 # bash scripts/train.sh rnn 2,3,4 # train RNN on GPU 2, 3, 4 # bash scripts/train.sh gru all # train GRU on all GPUs # # Defaults: -# - Data: data/estimation/trajectories.pt +# - Data: data/estimation/trajectories.pt (or trajectories_so2.pt) # - WandB logging enabled # - 200 epochs, batch_size=128, lr=1e-3, dropout=0.5 # - Gradient clipping=1.0, CosineAnnealing LR schedule @@ -18,9 +19,9 @@ set -euo pipefail # ── Args ────────────────────────────────────────────────────────────────────── MODEL_TYPE="${1:-gru}" GPUS="${2:-0}" +MANIFOLD="${3:-so3}" # ── Config (edit these as needed) ───────────────────────────────────────────── -DATA_PATH="data/estimation/trajectories.pt" EPOCHS=200 BATCH_SIZE=128 LR=1e-3 @@ -31,7 +32,19 @@ WEIGHT_DECAY=1e-4 GRAD_CLIP=1.0 SEED=42 WANDB_PROJECT="articulated-estimation" -OUTPUT_DIR="logs/estimation/${MODEL_TYPE}" + +# Manifold-dependent config +if [ "$MANIFOLD" = "so2" ]; then + DATA_PATH="data/estimation/trajectories_so2.pt" + INPUT_SIZE=2 + INIT_POS_SIZE=4 + OUTPUT_DIR="logs/estimation/${MODEL_TYPE}_so2" +else + DATA_PATH="data/estimation/trajectories.pt" + INPUT_SIZE=6 + INIT_POS_SIZE=$N_PLACE_CELLS + OUTPUT_DIR="logs/estimation/${MODEL_TYPE}" +fi # ── Setup ───────────────────────────────────────────────────────────────────── SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" @@ -64,6 +77,7 @@ echo "============================================" echo " State Estimation Training" echo "============================================" echo "Model: ${MODEL_TYPE^^} (hidden=${HIDDEN_SIZE}, dropout=${DROPOUT})" +echo "Manifold: ${MANIFOLD} (input=${INPUT_SIZE}, init_pos=${INIT_POS_SIZE})" echo "Data: ${DATA_PATH}" echo "GPUs: ${GPUS} (devices=${NUM_DEVICES}, strategy=${STRATEGY})" echo "Training: ${EPOCHS} epochs, lr=${LR}, batch=${BATCH_SIZE}" @@ -89,5 +103,8 @@ exec "$PYTHON" "$SCRIPT_DIR/train.py" \ --strategy "$STRATEGY" \ --seed "$SEED" \ --output_dir "$OUTPUT_DIR" \ + --manifold "$MANIFOLD" \ + --input_size "$INPUT_SIZE" \ + --init_pos_size "$INIT_POS_SIZE" \ --wandb \ --project "$WANDB_PROJECT" diff --git a/tests/test_datamodules.py b/tests/test_datamodules.py index e2846fc..29d3808 100644 --- a/tests/test_datamodules.py +++ b/tests/test_datamodules.py @@ -211,3 +211,138 @@ def test_place_cells_initialized(self): assert dm.place_cell_centers is not None assert len(dm.place_cell_centers) == 8 # n_cells_per_joint + + +class TestEstimationDataModuleSO2: + """Tests for EstimationDataModule with SO(2) manifold.""" + + def test_initialization(self): + """Test that SO(2) data module initializes correctly.""" + dm = EstimationDataModule( + batch_size=16, + seq_length=50, + n_trajectories_train=100, + n_trajectories_val=20, + n_place_cells=64, + seed=42, + manifold="so2", + ) + assert dm.manifold == "so2" + assert dm.vel_dim == 2 + assert dm.init_pos_dim == 4 + + def test_dataset_shapes(self): + """Velocity shape should be (*, 2), init_pos shape should be (*, 4).""" + seq_len = 10 + n_train = 5 + n_val = 2 + n_pc = 8 + + dm = EstimationDataModule( + batch_size=4, + seq_length=seq_len, + n_trajectories_train=n_train, + n_trajectories_val=n_val, + n_place_cells=n_pc, + seed=42, + provide_init_pos=True, + manifold="so2", + ) + dm.setup(stage="fit") + + assert dm.train_dataset is not None + train_vel, train_tgt, train_init = dm.train_dataset.tensors + assert train_vel.shape == (n_train, seq_len, 2) + assert train_tgt.shape == (n_train, seq_len, n_pc) + assert train_init.shape == (n_train, 4) + + assert dm.val_dataset is not None + val_vel, val_tgt, val_init = dm.val_dataset.tensors + assert val_vel.shape == (n_val, seq_len, 2) + assert val_tgt.shape == (n_val, seq_len, n_pc) + assert val_init.shape == (n_val, 4) + + def test_init_pos_is_unit_vectors(self): + """Init pos (cos, sin) pairs should have magnitude ~1.""" + dm = EstimationDataModule( + batch_size=4, + seq_length=10, + n_trajectories_train=5, + n_trajectories_val=2, + n_place_cells=8, + seed=42, + provide_init_pos=True, + manifold="so2", + ) + dm.setup(stage="fit") + + assert dm.train_dataset is not None + init_pos = dm.train_dataset.tensors[2].numpy() + # (cos θ1, sin θ1) should have magnitude 1 + mag1 = np.sqrt(init_pos[:, 0] ** 2 + init_pos[:, 1] ** 2) + mag2 = np.sqrt(init_pos[:, 2] ** 2 + init_pos[:, 3] ** 2) + np.testing.assert_allclose(mag1, 1.0, atol=1e-5) + np.testing.assert_allclose(mag2, 1.0, atol=1e-5) + + def test_per_joint_targets_are_distributions(self): + """Each joint's place cell targets should sum to ~1.""" + n_pc = 16 + n_per_joint = n_pc // 2 + dm = EstimationDataModule( + batch_size=4, + seq_length=10, + n_trajectories_train=3, + n_trajectories_val=1, + n_place_cells=n_pc, + seed=42, + manifold="so2", + ) + dm.setup(stage="fit") + + assert dm.train_dataset is not None + targets = dm.train_dataset.tensors[1] + j1_sums = targets[..., :n_per_joint].sum(dim=-1).numpy() + j2_sums = targets[..., n_per_joint:].sum(dim=-1).numpy() + np.testing.assert_allclose(j1_sums, 1.0, atol=1e-5) + np.testing.assert_allclose(j2_sums, 1.0, atol=1e-5) + + def test_targets_not_uniform(self): + """vMF targets should NOT be uniform.""" + n_pc = 16 + n_per_joint = n_pc // 2 + dm = EstimationDataModule( + batch_size=4, + seq_length=10, + n_trajectories_train=3, + n_trajectories_val=1, + n_place_cells=n_pc, + place_cell_kappa=5.0, + seed=42, + manifold="so2", + ) + dm.setup(stage="fit") + + assert dm.train_dataset is not None + targets = dm.train_dataset.tensors[1] + j1 = targets[0, 0, :n_per_joint].numpy() + assert j1.max() > 2.0 / n_per_joint + + def test_place_cells_initialized(self): + """Place cell centers should be angles after setup.""" + dm = EstimationDataModule( + batch_size=4, + seq_length=10, + n_trajectories_train=3, + n_trajectories_val=1, + n_place_cells=16, + seed=42, + manifold="so2", + ) + dm.setup(stage="fit") + + assert dm.place_cell_centers is not None + assert len(dm.place_cell_centers) == 8 + # Each center should be a pair of floats (angles) + theta1, theta2 = dm.place_cell_centers[0] + assert isinstance(theta1, (float, np.floating)) + assert isinstance(theta2, (float, np.floating)) diff --git a/tests/test_models.py b/tests/test_models.py index fba97f7..c7cf301 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -228,6 +228,53 @@ def test_get_hidden_states(self): assert hidden_states.shape == (4, 20, 128) + def test_forward_so2_input(self): + """Test forward pass with SO(2) input_size=2.""" + model = StateEstimationModel( + input_size=2, + hidden_size=128, + output_size=64, + model_type="gru", + use_init_pos=False, + ) + + x = torch.randn(4, 20, 2) + output, hidden_states = model(x) + + assert output.shape == (4, 20, 64) + assert hidden_states.shape == (4, 20, 128) + + def test_forward_so2_with_init_pos(self): + """Test SO(2) model with init_pos_size=4.""" + model = StateEstimationModel( + input_size=2, + hidden_size=128, + output_size=64, + model_type="gru", + use_init_pos=True, + init_pos_size=4, + ) + + x = torch.randn(4, 20, 2) + # SO(2) init pos: (cos θ1, sin θ1, cos θ2, sin θ2) → 4D + init_pos = torch.randn(4, 4) + h0 = model._encode_init_pos(init_pos) + output, hidden_states = model(x, hidden=h0) + + assert output.shape == (4, 20, 64) + assert hidden_states.shape == (4, 20, 128) + + def test_init_pos_size_defaults_to_output_size(self): + """When init_pos_size is None, it should default to output_size.""" + model = StateEstimationModel( + input_size=6, + hidden_size=128, + output_size=64, + model_type="rnn", + use_init_pos=True, + ) + assert model.init_pos_size == 64 + @pytest.mark.usefixtures("patch_reacher_env") class TestRLAgent: diff --git a/tests/test_robot_arm.py b/tests/test_robot_arm.py index 4091a42..1969fa8 100644 --- a/tests/test_robot_arm.py +++ b/tests/test_robot_arm.py @@ -3,7 +3,7 @@ import numpy as np from scipy.spatial.transform import Rotation -from articulated.shared.robot_arm import RobotArmKinematics +from articulated.shared.robot_arm import RobotArm2DKinematics, RobotArmKinematics class TestIntegrateVelocity: @@ -138,3 +138,144 @@ def test_end_effector_bounded(self): pos = kin.forward_kinematics(config) dist = np.linalg.norm(pos) assert dist <= 2.0 + 1e-10 + + +class TestRobotArm2DIntegrateVelocity: + """Tests for RobotArm2DKinematics.integrate_velocity.""" + + def test_zero_velocity_preserves_angles(self): + """Zero angular velocity should not change angles.""" + kin = RobotArm2DKinematics() + config = (1.0, 2.0) + omega = np.zeros(2) + + new_config = kin.integrate_velocity(config, omega, dt=0.01) + + assert abs(new_config[0] - config[0]) < 1e-10 + assert abs(new_config[1] - config[1]) < 1e-10 + + def test_wrapping(self): + """Angles should wrap around [0, 2*pi).""" + kin = RobotArm2DKinematics() + config = (6.0, 0.1) # close to 2*pi + omega = np.array([10.0, -20.0]) # large velocity + + new_config = kin.integrate_velocity(config, omega, dt=0.1) + + assert 0 <= new_config[0] < 2 * np.pi + assert 0 <= new_config[1] < 2 * np.pi + + def test_nonzero_velocity_changes_angles(self): + """Nonzero velocity should change the angles.""" + kin = RobotArm2DKinematics() + config = (0.0, 0.0) + omega = np.array([1.0, 2.0]) + + new_config = kin.integrate_velocity(config, omega, dt=0.1) + + assert abs(new_config[0] - 0.1) < 1e-10 + assert abs(new_config[1] - 0.2) < 1e-10 + + +class TestRobotArm2DGeodesicDistance: + """Tests for RobotArm2DKinematics.geodesic_distance.""" + + def test_distance_to_self_is_zero(self): + kin = RobotArm2DKinematics() + config = (1.5, 3.0) + assert abs(kin.geodesic_distance(config, config)) < 1e-10 + + def test_symmetry(self): + kin = RobotArm2DKinematics() + c1 = (0.5, 1.0) + c2 = (2.0, 4.0) + d12 = kin.geodesic_distance(c1, c2) + d21 = kin.geodesic_distance(c2, c1) + assert abs(d12 - d21) < 1e-10 + + def test_circular_distance(self): + """Distance should use circular metric (shorter arc).""" + kin = RobotArm2DKinematics() + # theta1: 0.1 vs 6.1 → distance is ~0.1+0.083 = ~0.183 (wraps around) + c1 = (0.1, 0.0) + c2 = (2 * np.pi - 0.1, 0.0) + d = kin.geodesic_distance(c1, c2) + # Circular distance should be 0.2, not ~6.08 + assert abs(d - 0.2) < 1e-10 + + def test_nonnegative(self): + kin = RobotArm2DKinematics() + rng = np.random.default_rng(42) + c1 = kin.sample_random_configuration(rng) + c2 = kin.sample_random_configuration(rng) + assert kin.geodesic_distance(c1, c2) >= 0.0 + + def test_triangle_inequality(self): + kin = RobotArm2DKinematics() + rng = np.random.default_rng(42) + c1 = kin.sample_random_configuration(rng) + c2 = kin.sample_random_configuration(rng) + c3 = kin.sample_random_configuration(rng) + d12 = kin.geodesic_distance(c1, c2) + d23 = kin.geodesic_distance(c2, c3) + d13 = kin.geodesic_distance(c1, c3) + assert d13 <= d12 + d23 + 1e-10 + + +class TestRobotArm2DForwardKinematics: + """Tests for RobotArm2DKinematics.forward_kinematics.""" + + def test_zero_angles(self): + """Both joints at zero: end-effector at (L1+L2, 0).""" + kin = RobotArm2DKinematics(link_lengths=(1.0, 1.0)) + pos = kin.forward_kinematics((0.0, 0.0)) + np.testing.assert_allclose(pos, [2.0, 0.0], atol=1e-10) + + def test_output_shape(self): + """Forward kinematics should return a 2D position.""" + kin = RobotArm2DKinematics() + rng = np.random.default_rng(42) + config = kin.sample_random_configuration(rng) + pos = kin.forward_kinematics(config) + assert pos.shape == (2,) + + def test_end_effector_bounded(self): + """End-effector should be within L1 + L2 of origin.""" + kin = RobotArm2DKinematics(link_lengths=(1.0, 1.0)) + rng = np.random.default_rng(42) + for _ in range(10): + config = kin.sample_random_configuration(rng) + pos = kin.forward_kinematics(config) + dist = np.linalg.norm(pos) + assert dist <= 2.0 + 1e-10 + + def test_custom_link_lengths(self): + """Different link lengths with zero angles.""" + kin = RobotArm2DKinematics(link_lengths=(2.0, 3.0)) + pos = kin.forward_kinematics((0.0, 0.0)) + np.testing.assert_allclose(pos, [5.0, 0.0], atol=1e-10) + + def test_right_angle(self): + """First joint at pi/2, second at 0: end-effector at (0, L1+L2).""" + kin = RobotArm2DKinematics(link_lengths=(1.0, 1.0)) + pos = kin.forward_kinematics((np.pi / 2, 0.0)) + np.testing.assert_allclose(pos, [0.0, 2.0], atol=1e-10) + + +class TestRobotArm2DSampleConfig: + """Tests for RobotArm2DKinematics.sample_random_configuration.""" + + def test_range(self): + """Random config should be in [0, 2*pi).""" + kin = RobotArm2DKinematics() + rng = np.random.default_rng(42) + for _ in range(20): + config = kin.sample_random_configuration(rng) + assert 0 <= config[0] < 2 * np.pi + assert 0 <= config[1] < 2 * np.pi + + def test_returns_tuple_of_floats(self): + kin = RobotArm2DKinematics() + config = kin.sample_random_configuration(np.random.default_rng(42)) + assert isinstance(config[0], float) + assert isinstance(config[1], float)