From 01b64e199d0c98b4284a05dee1975da5e4fe2378 Mon Sep 17 00:00:00 2001 From: Andras Polgar Date: Thu, 18 Dec 2025 18:12:36 +0100 Subject: [PATCH] Measure reaction delay --- msgs/src/msgs/__init__.py | 7 +++++++ nodes/policy_controller/policy_controller/main.py | 5 ++++- nodes/simulation/simulation/go2_robot.py | 2 ++ nodes/simulation/simulation/main.py | 12 +++++++++++- 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/msgs/src/msgs/__init__.py b/msgs/src/msgs/__init__.py index 7164c68..00ded56 100644 --- a/msgs/src/msgs/__init__.py +++ b/msgs/src/msgs/__init__.py @@ -81,11 +81,18 @@ class Observations(ArrowMessage): joint_positions: np.ndarray joint_velocities: np.ndarray height_scan: np.ndarray + observation_id: int @dataclass class JointCommands(ArrowMessage): + """Joint commands generated from a specific observation. + + The `observation_id` matches the commands to the observations they were generated from. + """ + positions: np.ndarray + observation_id: int class WaypointStatus(ArrowMessage, Enum): diff --git a/nodes/policy_controller/policy_controller/main.py b/nodes/policy_controller/policy_controller/main.py index 846d568..cd7fd7d 100644 --- a/nodes/policy_controller/policy_controller/main.py +++ b/nodes/policy_controller/policy_controller/main.py @@ -55,7 +55,10 @@ def try_step(): ) node.send_output( "joint_commands", - msgs.JointCommands(positions=joint_targets).to_arrow(), + msgs.JointCommands( + positions=joint_targets, + observation_id=last_observations.observation_id, + ).to_arrow(), ) elif last_commands is not None: # Resend last commands diff --git a/nodes/simulation/simulation/go2_robot.py b/nodes/simulation/simulation/go2_robot.py index bcf0ab3..086a3ec 100644 --- a/nodes/simulation/simulation/go2_robot.py +++ b/nodes/simulation/simulation/go2_robot.py @@ -109,6 +109,8 @@ def compute_observations(self) -> msgs.Observations: joint_positions=self.robot.get_joint_positions(), joint_velocities=self.robot.get_joint_velocities(), height_scan=self.height_scan_grid.get_height_data(), + # Generate a random observation id + observation_id=np.random.randint(0, 2**16 - 1), ) def _report_all_hits(self, hit_info): diff --git a/nodes/simulation/simulation/main.py b/nodes/simulation/simulation/main.py index cb2c096..c4218b0 100644 --- a/nodes/simulation/simulation/main.py +++ b/nodes/simulation/simulation/main.py @@ -1,11 +1,12 @@ """TODO: Add docstring.""" +from collections import deque from enum import Enum -import msgs import pyarrow as pa from dora import Node +import msgs from simulation.check_nvidia_driver import check_nvidia_driver from simulation.scene_config import Scene from simulation.simulation_time_output import SimulationTimeOutput @@ -37,9 +38,13 @@ def simulation(): # Publish simulation time at each physics step _simulation_time_output = SimulationTimeOutput(node, runner.world) + # keep only the last 100 observation IDs + observation_id_sequence = deque(maxlen=100) + # Publish observations at each physics step def on_physics_step(dt: float): observations = runner.go2.compute_observations() + observation_id_sequence.append(observations.observation_id) node.send_output("observations", observations.to_arrow()) runner.world.add_physics_callback("observation_output", on_physics_step) @@ -64,6 +69,11 @@ def on_physics_step(dt: float): elif event["id"] == "joint_commands": joint_commands = msgs.JointCommands.from_arrow(event["value"]) + # Compute how many physics steps have passed since the observation was made + observation_delay = len( + observation_id_sequence + ) - observation_id_sequence.index(joint_commands.observation_id) + node.send_output("reaction_frame_delay", pa.array([observation_delay])) runner.go2.set_target_positions(joint_commands.positions) elif event["id"] == "pub_status_tick":