Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions msgs/src/msgs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,18 @@ class Observations(ArrowMessage):
joint_positions: np.ndarray
joint_velocities: np.ndarray
height_scan: np.ndarray
observation_id: int


@dataclass
class JointCommands(ArrowMessage):
"""Joint commands generated from a specific observation.

The `observation_id` matches the commands to the observations they were generated from.
"""

positions: np.ndarray
observation_id: int


class WaypointStatus(ArrowMessage, Enum):
Expand Down
5 changes: 4 additions & 1 deletion nodes/policy_controller/policy_controller/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ def try_step():
)
node.send_output(
"joint_commands",
msgs.JointCommands(positions=joint_targets).to_arrow(),
msgs.JointCommands(
positions=joint_targets,
observation_id=last_observations.observation_id,
).to_arrow(),
)
elif last_commands is not None:
# Resend last commands
Expand Down
2 changes: 2 additions & 0 deletions nodes/simulation/simulation/go2_robot.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def compute_observations(self) -> msgs.Observations:
joint_positions=self.robot.get_joint_positions(),
joint_velocities=self.robot.get_joint_velocities(),
height_scan=self.height_scan_grid.get_height_data(),
# Generate a random observation id
observation_id=np.random.randint(0, 2**16 - 1),
)

def _report_all_hits(self, hit_info):
Expand Down
12 changes: 11 additions & 1 deletion nodes/simulation/simulation/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""TODO: Add docstring."""

from collections import deque
from enum import Enum

import msgs
import pyarrow as pa
from dora import Node

import msgs
from simulation.check_nvidia_driver import check_nvidia_driver
from simulation.scene_config import Scene
from simulation.simulation_time_output import SimulationTimeOutput
Expand Down Expand Up @@ -37,9 +38,13 @@ def simulation():
# Publish simulation time at each physics step
_simulation_time_output = SimulationTimeOutput(node, runner.world)

# keep only the last 100 observation IDs
observation_id_sequence = deque(maxlen=100)

# Publish observations at each physics step
def on_physics_step(dt: float):
observations = runner.go2.compute_observations()
observation_id_sequence.append(observations.observation_id)
node.send_output("observations", observations.to_arrow())

runner.world.add_physics_callback("observation_output", on_physics_step)
Expand All @@ -64,6 +69,11 @@ def on_physics_step(dt: float):

elif event["id"] == "joint_commands":
joint_commands = msgs.JointCommands.from_arrow(event["value"])
# Compute how many physics steps have passed since the observation was made
observation_delay = len(
observation_id_sequence
) - observation_id_sequence.index(joint_commands.observation_id)
node.send_output("reaction_frame_delay", pa.array([observation_delay]))
runner.go2.set_target_positions(joint_commands.positions)

elif event["id"] == "pub_status_tick":
Expand Down