From 545211121d98824ceb4ae3a80ed0c559b60c7566 Mon Sep 17 00:00:00 2001 From: StevenJacobs61 <106992451+StevenJacobs61@users.noreply.github.com> Date: Thu, 12 Feb 2026 14:29:56 +0000 Subject: [PATCH] fix: ensured channel states --- neuracore-dictionary.txt | 1 + neuracore/core/robot.py | 48 +++- neuracore/core/streaming/data_stream.py | 5 + neuracore/data_daemon/bootstrap.py | 70 +---- .../communications_management/data_bridge.py | 270 +++++++++++++++--- .../communications_management/producer.py | 53 +++- neuracore/data_daemon/const.py | 1 + neuracore/data_daemon/helpers.py | 6 + .../data_daemon/lifecycle/daemon_lifecycle.py | 2 +- neuracore/data_daemon/models.py | 7 + .../encoding/video_trace.py | 26 +- neuracore/data_daemon/runner_entry.py | 2 +- .../state_management/state_store.py | 2 +- .../state_management/state_store_sqlite.py | 15 +- .../data_daemon/state_management/tables.py | 1 + .../resumable_file_uploader.py | 6 +- .../test_zmq_sockets.py | 24 +- .../lifecycle/test_daemon_lifecycle.py | 2 +- .../test_management_channel_startup.py | 2 +- .../test_state_store_sqlite.py | 54 ++++ 20 files changed, 459 insertions(+), 138 deletions(-) diff --git a/neuracore-dictionary.txt b/neuracore-dictionary.txt index 293291be..d98f2685 100644 --- a/neuracore-dictionary.txt +++ b/neuracore-dictionary.txt @@ -271,3 +271,4 @@ getpid WRONLY rels huggingface +itemsize \ No newline at end of file diff --git a/neuracore/core/robot.py b/neuracore/core/robot.py index bab77cb5..54f7ec54 100644 --- a/neuracore/core/robot.py +++ b/neuracore/core/robot.py @@ -23,6 +23,7 @@ from neuracore.core.streaming.data_stream import DataStream from neuracore.core.streaming.recording_state_manager import get_recording_state_manager from neuracore.core.utils.robot_mapping import RobotMapping +from neuracore.data_daemon.communications_management.producer import Producer from neuracore.data_daemon.communications_management.producer import ( RecordingContext as DaemonRecordingContext, ) @@ -107,6 +108,7 @@ def __init__( self._auth: Auth = get_auth() self._temp_dir = None self._data_streams: dict[str, DataStream] = dict() + self._daemon_recording_context: DaemonRecordingContext | None = None self.org_id = org_id or get_current_org() @@ -319,8 +321,11 @@ def stop_recording(self, recording_id: str) -> None: if not self.id: raise RobotError("Robot not initialized. 
Call init() first.") - DaemonRecordingContext(recording_id=recording_id).stop_recording() - self._stop_all_streams() + producer_stop_sequence_numbers = self._stop_all_streams() + self._get_daemon_recording_context().stop_recording( + recording_id=recording_id, + producer_stop_sequence_numbers=producer_stop_sequence_numbers, + ) try: response = requests.post( @@ -347,13 +352,20 @@ def stop_recording(self, recording_id: str) -> None: except requests.exceptions.RequestException as e: raise RobotError(f"Failed to stop recording: {str(e)}") - def _stop_all_streams(self) -> None: + def _stop_all_streams(self) -> dict[str, int]: """Stop recording on all data streams for this robot instance.""" + producer_stop_sequence_numbers: dict[str, int] = {} for stream_id, stream in self._data_streams.items(): try: stream.stop_recording() + producer = getattr(stream, "_producer", None) + if isinstance(producer, Producer): + producer_stop_sequence_numbers[producer.producer_id] = ( + producer.get_last_sent_sequence_number() + ) except Exception: logger.exception("Failed to stop data stream %s", stream_id) + return producer_stop_sequence_numbers def is_recording(self) -> bool: """Check if the robot is currently recording data. @@ -640,7 +652,7 @@ def cancel_recording(self, recording_id: str) -> None: if not self.id: raise RobotError("Robot not initialized. Call init() first.") - DaemonRecordingContext(recording_id=recording_id).stop_recording() + self._get_daemon_recording_context().stop_recording(recording_id=recording_id) try: response = requests.post( @@ -663,6 +675,34 @@ def cancel_recording(self, recording_id: str) -> None: except requests.exceptions.RequestException as e: raise RobotError(f"Failed to cancel recording: {str(e)}") + def _get_daemon_recording_context(self) -> DaemonRecordingContext: + """Return a reusable daemon recording context, creating it lazily.""" + if self._daemon_recording_context is None: + self._daemon_recording_context = DaemonRecordingContext() + return self._daemon_recording_context + + def _cleanup_daemon_recording_context(self) -> None: + """Release daemon recording context resources.""" + if self._daemon_recording_context is None: + return + try: + self._daemon_recording_context.close() + except Exception: + logger.exception("Failed to cleanup daemon recording context") + finally: + self._daemon_recording_context = None + + def close(self) -> None: + """Release local resources owned by this Robot instance.""" + self._cleanup_daemon_recording_context() + if self._temp_dir is not None: + self._temp_dir.cleanup() + self._temp_dir = None + + def __del__(self) -> None: + """Best-effort cleanup for daemon recording resources.""" + self.close() + # Global robot registry _robots: dict[RobotInstanceIdentifier, Robot] = {} diff --git a/neuracore/core/streaming/data_stream.py b/neuracore/core/streaming/data_stream.py index 3b041c0e..e4ec2c02 100644 --- a/neuracore/core/streaming/data_stream.py +++ b/neuracore/core/streaming/data_stream.py @@ -111,6 +111,10 @@ def _handle_ensure_producer(self, context: DataRecordingContext) -> None: ): self._producer.set_recording_id(context.recording_id) + # Reopen producer channel state for each new recording in case + # the daemon expired the channel while this producer object was idle. 
+ self._producer.start_producer() + self._producer.open_ring_buffer() self._producer.start_new_trace() def stop_recording(self) -> list[threading.Thread]: @@ -232,6 +236,7 @@ def log(self, metadata: CameraData, frame: np.ndarray) -> None: metadata_dict = metadata.model_dump(mode="json", exclude={"frame"}) metadata_dict["width"] = self.width metadata_dict["height"] = self.height + metadata_dict["frame_nbytes"] = int(frame.size * frame.itemsize) metadata_json = json.dumps(metadata_dict).encode("utf-8") # Pack: [metadata_len (4 bytes)] [metadata_json] [frame_bytes] diff --git a/neuracore/data_daemon/bootstrap.py b/neuracore/data_daemon/bootstrap.py index f5e0ea7b..f7996a42 100644 --- a/neuracore/data_daemon/bootstrap.py +++ b/neuracore/data_daemon/bootstrap.py @@ -1,72 +1,4 @@ -"""Daemon bootstrap and lifecycle management. - -This module provides a clean, modular initialization sequence for the -data daemon. It coordinates the startup of all subsystems in the correct -order across the three execution contexts. - -INITIALIZATION SEQUENCE -======================= - - DaemonBootstrap.start() - │ - ├─[1] Configuration - │ └── ProfileManager → ConfigManager → DaemonConfig - │ - ├─[2] Authentication - │ └── Auth.login(api_key) - Initialize Auth singleton - │ - ├─[3] Event Loops (EventLoopManager) - │ ├── General Loop Thread started - │ ├── Encoder Loop Thread started - │ └── init_emitter(loop=general_loop) - │ - ├─[4] Async Services (on General Loop) - │ ├── aiohttp.ClientSession - │ ├── SqliteStateStore + init_async_store() - │ ├── StateManager (registers event listeners) - │ ├── UploadManager (listens for READY_FOR_UPLOAD) - │ ├── ConnectionManager + start() (monitors API) - │ └── ProgressReporter (listens for PROGRESS_REPORT) - │ - ├─[5] Recording & Encoding (RecordingDiskManager) - │ ├── _TraceFilesystem (path management) - │ ├── _TraceController (trace lifecycle) - │ ├── _EncoderManager (encoder factory) - │ ├── StorageBudget (disk space tracking) - │ ├── _RawBatchWriter → schedule_on_general_loop() - │ └── _BatchEncoderWorker → schedule_on_encoder_loop() - │ - ├─[6] ZMQ Communications - │ └── CommunicationsManager - │ - └─[7] Return DaemonContext - └── Daemon created with context, calls run() - - -MODULE REGISTRY -=============== - -Main Thread: - EventLoopManager - DaemonBootstrap - Manages async loops - CommunicationsManager- DaemonBootstrap - ZMQ sockets - Daemon - runner_entry - Message loop - -General Loop: - Emitter - EventLoopManager - Event coordination - AuthManager - bootstrap_async - API auth (singleton) - SqliteStateStore - bootstrap_async - Trace state persistence - StateManager - bootstrap_async - State coordination - UploadManager - bootstrap_async - Cloud uploads - ConnectionManager - bootstrap_async - API monitoring - ProgressReporter - bootstrap_async - Progress reporting - _RawBatchWriter - RecordingDiskManager - Raw file I/O - -Encoder Loop: - _BatchEncoderWorker - RecordingDiskManager - Video/JSON encoding - VideoTrace - _EncoderManager - H.264 encoding - JsonTrace - _EncoderManager - JSON encoding - -""" +"""Daemon bootstrap and lifecycle management.""" from __future__ import annotations diff --git a/neuracore/data_daemon/communications_management/data_bridge.py b/neuracore/data_daemon/communications_management/data_bridge.py index 46d4714a..ec29ec62 100644 --- a/neuracore/data_daemon/communications_management/data_bridge.py +++ b/neuracore/data_daemon/communications_management/data_bridge.py @@ -22,9 +22,11 @@ DATA_TYPE_FIELD_SIZE, DEFAULT_RING_BUFFER_SIZE, 
HEARTBEAT_TIMEOUT_SECS, + NEVER_OPENED_TIMEOUT_SECS, TRACE_ID_FIELD_SIZE, ) from neuracore.data_daemon.event_emitter import Emitter, get_emitter +from neuracore.data_daemon.helpers import utc_now from neuracore.data_daemon.models import ( CommandType, CompleteMessage, @@ -58,6 +60,13 @@ class ChannelState: last_heartbeat: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) reader: ChannelMessageReader | None = None trace_id: str | None = None + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + last_sequence_number: int = 0 + opened_at: datetime | None = None + + def is_opened(self) -> bool: + """Check if the channel has been opened with a ring buffer.""" + return self.ring_buffer is not None def touch(self) -> None: """Update the last heartbeat time for the channel. @@ -66,10 +75,75 @@ def touch(self) -> None: """ self.last_heartbeat = datetime.now(timezone.utc) + def set_ring_buffer(self, ring_buffer: RingBuffer) -> None: + """Set the ring buffer for the channel. + + This method is called when a new channel is opened from a producer. + It takes a RingBuffer instance as an argument and sets the channel's ring buffer + to it. If the argument is not an instance of RingBuffer, it raises a TypeError. + + The method also creates a ChannelMessageReader + instance to read from the ring buffer + and sets the opened_at timestamp to the current time in UTC. + """ + if not isinstance(ring_buffer, RingBuffer): + raise TypeError("Invalid ring buffer instance provided for new channel.") + self.ring_buffer = ring_buffer + self.reader = ChannelMessageReader(ring_buffer) + self.opened_at = datetime.now(timezone.utc) + + def is_open(self) -> bool: + """Check if the channel is open (i.e. has an initialized ring buffer).""" + return self.ring_buffer is not None + + def has_missed_heartbeat( + self, + now: datetime, + heartbeat_timeout: timedelta | None = None, + ) -> bool: + """Return True when no heartbeat has been seen within timeout.""" + if heartbeat_timeout is None: + heartbeat_timeout = timedelta(seconds=HEARTBEAT_TIMEOUT_SECS) + return now - self.last_heartbeat > heartbeat_timeout + + def is_stale_unopened( + self, + now: datetime, + never_opened_timeout: timedelta | None = None, + ) -> bool: + """Return True when a channel never opened within timeout.""" + if never_opened_timeout is None: + never_opened_timeout = timedelta(seconds=NEVER_OPENED_TIMEOUT_SECS) + return (not self.is_open()) and (now - self.created_at > never_opened_timeout) + + def should_expire( + self, + ) -> bool: + """Return True if channel should be removed from daemon state.""" + now = utc_now() + return self.has_missed_heartbeat(now) or (self.is_stale_unopened(now)) + + def set_trace_id(self, trace_id: str) -> None: + """Set the trace ID for the current channel. + + Args: + trace_id: The trace ID to set for the current channel. + """ + if trace_id != self.trace_id: + self.trace_id = trace_id + CommandHandler = Callable[[ChannelState, MessageEnvelope], None] +@dataclass +class RecordingClosingState: + """Recording-level stop/drain state.""" + + producer_stop_sequence_numbers: dict[str, int] + stop_requested_at: datetime + + class Daemon: """Main neuracore data daemon. 
@@ -94,34 +168,37 @@ def __init__( self.comm = comm_manager or CommunicationsManager() self.recording_disk_manager = recording_disk_manager self.channels: dict[str, ChannelState] = {} + self._closed_producers: set[str] = set() self._recording_traces: dict[str, set[str]] = {} self._trace_recordings: dict[str, str] = {} self._trace_metadata: dict[str, dict[str, str | int | None]] = {} self._closed_recordings: set[str] = set() - self._pending_close_recordings: set[str] = set() + self._pending_close_recordings: dict[str, dict[str, int]] = {} + self._closing_recordings: dict[str, RecordingClosingState] = {} self._command_handlers: dict[CommandType, CommandHandler] = { CommandType.OPEN_RING_BUFFER: self._handle_open_ring_buffer, CommandType.DATA_CHUNK: self._handle_write_data_chunk, CommandType.HEARTBEAT: self._handle_heartbeat, CommandType.TRACE_END: self._handle_end_trace, - CommandType.RECORDING_STOPPED: self._handle_recording_stopped, } self._emitter = get_emitter() - self._emitter.on(Emitter.TRACE_WRITTEN, self.cleanup_stopped_channels) self._running = False + self._emitter.on(Emitter.TRACE_WRITTEN, self.cleanup_channel_on_trace_written) def run(self) -> None: - """Run the daemon main loop. + """Starts the daemon and begins accepting messages from producers. - This starts the consumer socket, and then enters an infinite loop where it: - - Receives ManagementMessages from producers over ZMQ - - Handles messages from producers using the `handle_message` function - - Cleans up expired channels using the `_cleanup_expired_channels` function - - Drains channel messages using the `_drain_channel_messages` function + This function blocks until the daemon is shutdown via Ctrl-C. - The loop will exit on a KeyboardInterrupt (e.g. Ctrl+C), and will then call - `cleanup_daemon` on the communications manager to clean up resources. + It is responsible for: + + - Starting the ZMQ consumer and publisher sockets. + - Receiving and processing management messages from producers. + - Periodically cleaning up expired channels. + - Draining full messages from the ring buffer. 
+ + :return: None """ if self._running: raise RuntimeError("Daemon is already running") @@ -185,23 +262,53 @@ def handle_message(self, message: MessageEnvelope) -> None: cmd = message.command if producer_id is None: + # Stop recording commands are sent without a producer_id / channel if cmd != CommandType.RECORDING_STOPPED: logger.warning("Missing producer_id for command %s", cmd) return - channel = ChannelState(producer_id="recording-context") - else: - existing = self.channels.get(producer_id) - if existing is None: - existing = ChannelState(producer_id=producer_id) - self.channels[producer_id] = existing - logger.info("Created new channel for producer_id=%s", producer_id) - channel = existing + self._handle_recording_stopped(message) + return + + if ( + producer_id in self._closed_producers + and cmd != CommandType.OPEN_RING_BUFFER + ): + logger.warning( + "Ignoring command %s from closed producer_id=%s", + cmd, + producer_id, + ) + return + + if ( + cmd == CommandType.OPEN_RING_BUFFER + and producer_id in self._closed_producers + ): + self._closed_producers.discard(producer_id) + + existing = self.channels.get(producer_id) + if existing is None: + existing = ChannelState(producer_id=producer_id) + self.channels[producer_id] = existing + logger.info("Created new channel for producer_id=%s", producer_id) + channel = existing channel.touch() handler = self._command_handlers.get(cmd) if handler is None: logger.warning("Unknown command %s from producer_id=%s", cmd, producer_id) return + + if message.sequence_number is not None: + if message.sequence_number > channel.last_sequence_number: + channel.last_sequence_number = message.sequence_number + else: + logger.debug( + "Non-monotonic sequence_number=%s for producer_id=%s (last=%s)", + message.sequence_number, + producer_id, + channel.last_sequence_number, + ) try: handler(channel, message) except Exception: @@ -238,8 +345,7 @@ def _handle_open_ring_buffer( payload = message.payload.get(message.command.value, {}) size = payload.get("size", DEFAULT_RING_BUFFER_SIZE) - channel.ring_buffer = RingBuffer(size=size) - channel.reader = ChannelMessageReader(channel.ring_buffer) + channel.set_ring_buffer(RingBuffer(size)) logger.info( "Opened ring buffer (size=%d) for producer_id=%s", size, @@ -250,6 +356,7 @@ def _handle_open_ring_buffer( def _drain_channel_messages(self) -> None: """Poll all channels for completed messages and handle them.""" for channel in self.channels.values(): + # guard against uninitialised channels if channel.reader is None or channel.ring_buffer is None: continue # Loop to receive full message @@ -445,7 +552,14 @@ def _handle_write_data_chunk( ) return - channel.trace_id = data_chunk.trace_id + trace_id = data_chunk.trace_id + if channel.trace_id != trace_id and channel.trace_id is not None: + logger.warning( + "DATA_CHUNK trace_id=%s does not match channel trace_id=%s", + data_chunk.trace_id, + channel.trace_id, + ) + channel.set_trace_id(trace_id) if recording_id in self._closed_recordings: logger.warning( @@ -455,7 +569,40 @@ def _handle_write_data_chunk( ) return - trace_id = data_chunk.trace_id + closing_state = self._closing_recordings.get(recording_id) + if closing_state is not None and closing_state.producer_stop_sequence_numbers: + cutoff_sequence_number = closing_state.producer_stop_sequence_numbers.get( + channel.producer_id + ) + if cutoff_sequence_number is None: + logger.warning( + "Dropping data from producer_id=%s while " + "recording_id=%s is closing " + "(missing stop sequence number)", + 
channel.producer_id, + recording_id, + ) + return + if message.sequence_number is None: + logger.warning( + "Dropping data for producer_id=%s recording_id=%s " + "without sequence_number " + "while recording is closing", + channel.producer_id, + recording_id, + ) + return + if message.sequence_number > cutoff_sequence_number: + logger.warning( + "Dropping post-stop data for producer_id=%s recording_id=%s " + "(sequence_number=%s, cutoff_sequence_number=%s)", + channel.producer_id, + recording_id, + message.sequence_number, + cutoff_sequence_number, + ) + return + if recording_id: self._register_trace(recording_id, trace_id) self._register_trace_metadata( @@ -566,9 +713,7 @@ def _handle_end_trace( self._remove_trace(str(recording_id), str(trace_id)) - def _handle_recording_stopped( - self, _: ChannelState, message: MessageEnvelope - ) -> None: + def _handle_recording_stopped(self, message: MessageEnvelope) -> None: """Handle a RECORDING_STOPPED message from a producer. This function is called when a producer sends a RECORDING_STOPPED message to @@ -587,21 +732,73 @@ def _handle_recording_stopped( message.producer_id, ) return - self._pending_close_recordings.add(str(recording_id)) + producer_stop_sequence_numbers_raw = payload.get( + "producer_stop_sequence_numbers", {} + ) + producer_stop_sequence_numbers: dict[str, int] = {} + if isinstance(producer_stop_sequence_numbers_raw, dict): + for ( + producer_id, + sequence_number, + ) in producer_stop_sequence_numbers_raw.items(): + try: + producer_stop_sequence_numbers[str(producer_id)] = int( + sequence_number + ) + except (TypeError, ValueError): + logger.warning( + "Ignoring invalid stop sequence number for producer_id=%s: %r", + producer_id, + sequence_number, + ) + else: + logger.warning( + "recording_stopped.producer_stop_sequence_numbers must be a dict" + ) + + self._pending_close_recordings[str(recording_id)] = ( + producer_stop_sequence_numbers + ) self._emitter.emit(Emitter.STOP_RECORDING, recording_id) def _finalize_pending_closes(self) -> None: - """Move pending close recordings to closed set. + """Transition pending close recordings into closing state. - Called at the start of each main loop iteration to ensure any - DATA_CHUNK messages that arrived before RECORDING_STOPPED (but - were interleaved in ZMQ) get processed first. + Also finalize recordings that are done draining. 
""" + now = utc_now() if self._pending_close_recordings: - self._closed_recordings.update(self._pending_close_recordings) + for ( + recording_id, + producer_stop_sequence_numbers, + ) in self._pending_close_recordings.items(): + self._closing_recordings[recording_id] = RecordingClosingState( + producer_stop_sequence_numbers=producer_stop_sequence_numbers, + stop_requested_at=now, + ) self._pending_close_recordings.clear() - def cleanup_stopped_channels( + close_timeout = timedelta(seconds=10) + to_close: list[str] = [] + for recording_id, closing_state in self._closing_recordings.items(): + traces = self._recording_traces.get(recording_id, set()) + if not traces: + to_close.append(recording_id) + continue + if now - closing_state.stop_requested_at >= close_timeout: + logger.warning( + "Force-closing recording_id=%s after timeout " + "with %d trace(s) still open", + recording_id, + len(traces), + ) + to_close.append(recording_id) + + for recording_id in to_close: + self._closing_recordings.pop(recording_id, None) + self._closed_recordings.add(recording_id) + + def cleanup_channel_on_trace_written( self, trace_id: str, _: str | None = None, @@ -632,13 +829,10 @@ def cleanup_stopped_channels( def _cleanup_expired_channels(self) -> None: """Remove channels whose heartbeat has not been seen within the timeout.""" - now = datetime.now(timezone.utc) - timeout = timedelta(seconds=HEARTBEAT_TIMEOUT_SECS) - to_remove = [ producer_id for producer_id, state in self.channels.items() - if now - state.last_heartbeat > timeout + if state.should_expire() ] for producer_id in to_remove: @@ -663,5 +857,5 @@ def _cleanup_expired_channels(self) -> None: }, ), ) - # Here is where you would also clean up any shared memory segments. del self.channels[producer_id] + self._closed_producers.add(producer_id) diff --git a/neuracore/data_daemon/communications_management/producer.py b/neuracore/data_daemon/communications_management/producer.py index 4f910722..2af46001 100644 --- a/neuracore/data_daemon/communications_management/producer.py +++ b/neuracore/data_daemon/communications_management/producer.py @@ -23,7 +23,7 @@ class RecordingContext: def __init__( self, - recording_id: str, + recording_id: str | None = None, comm_manager: CommunicationsManager | None = None, ) -> None: """Initialize the recording context.""" @@ -38,14 +38,51 @@ def __init__( "Data cannot be captured without a running daemon." 
) - def stop_recording(self) -> None: + def set_recording_id(self, recording_id: str | None) -> None: + """Set or clear the recording identifier for this context.""" + self.recording_id = recording_id + + def stop_recording( + self, + recording_id: str | None = None, + producer_stop_sequence_numbers: dict[str, int] | None = None, + ) -> None: """Send a recording-stopped control message.""" + effective_recording_id = recording_id or self.recording_id + if not effective_recording_id: + raise ValueError("recording_id is required to stop a recording.") + + recording_stopped_payload: dict[str, object] = { + "recording_id": effective_recording_id + } + if producer_stop_sequence_numbers: + recording_stopped_payload["producer_stop_sequence_numbers"] = ( + producer_stop_sequence_numbers + ) self._send( CommandType.RECORDING_STOPPED, - {"recording_stopped": {"recording_id": self.recording_id}}, + {"recording_stopped": recording_stopped_payload}, ) + self.recording_id = effective_recording_id + + def close(self) -> None: + """Close sockets and cleanup context resources owned by this instance.""" + if self.socket is not None: + self.socket.close(0) + self.socket = None + self._comm.cleanup_producer() def _send(self, command: CommandType, payload: dict | None = None) -> None: + """Send a management message to the daemon. + + Args: + command: The CommandType to send to the daemon. + payload: A dictionary containing any additional data required by the daemon + to process the message. + + Returns: + None + """ envelope = MessageEnvelope( producer_id=None, command=command, @@ -77,6 +114,8 @@ def __init__( self._heartbeat_thread: threading.Thread | None = None self._send_queue: queue.Queue[MessageEnvelope | None] = queue.Queue() self._sender_thread: threading.Thread | None = None + self._next_sequence_number = 1 + self._last_sent_sequence_number = 0 if self.socket is None: raise RuntimeError( @@ -222,13 +261,21 @@ def _send(self, command: CommandType, payload: dict | None = None) -> None: Returns: None """ + sequence_number = self._next_sequence_number + self._next_sequence_number += 1 + self._last_sent_sequence_number = sequence_number envelope = MessageEnvelope( producer_id=self.producer_id, command=command, payload=payload or {}, + sequence_number=sequence_number, ) self._send_queue.put(envelope) + def get_last_sent_sequence_number(self) -> int: + """Return the most recent message sequence number sent by this producer.""" + return self._last_sent_sequence_number + def has_consumer(self) -> bool: """Check if the producer has a consumer. 
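The producer-side changes above give each Producer message a monotonically increasing sequence_number, expose the last one sent via get_last_sent_sequence_number(), and let RecordingContext.stop_recording() forward a per-producer map of stop points to the daemon. Below is a simplified restatement of the cutoff rule the daemon then applies to DATA_CHUNK messages while the recording is closing; the helper function and producer ids are hypothetical, written to match the checks in _handle_write_data_chunk.

def should_drop_chunk(
    producer_id: str,
    sequence_number: int | None,
    producer_stop_sequence_numbers: dict[str, int],
) -> bool:
    # Unknown producer or missing sequence number is dropped defensively;
    # anything numbered after the recorded stop point is post-stop data.
    cutoff = producer_stop_sequence_numbers.get(producer_id)
    if cutoff is None or sequence_number is None:
        return True
    return sequence_number > cutoff

stop_numbers = {"camera-front": 42}
assert should_drop_chunk("camera-front", 43, stop_numbers)      # sent after stop
assert not should_drop_chunk("camera-front", 42, stop_numbers)  # in flight at stop
assert should_drop_chunk("camera-side", 5, stop_numbers)        # no cutoff recorded
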
diff --git a/neuracore/data_daemon/const.py b/neuracore/data_daemon/const.py index 2875aa85..5adb85bb 100644 --- a/neuracore/data_daemon/const.py +++ b/neuracore/data_daemon/const.py @@ -6,6 +6,7 @@ from pathlib import Path HEARTBEAT_TIMEOUT_SECS = 10 +NEVER_OPENED_TIMEOUT_SECS = 20 API_URL = os.getenv("NEURACORE_API_URL", "https://api.neuracore.app/api") TRACE_ID_FIELD_SIZE = 36 # bytes allocated for the trace_id string in chunk headers diff --git a/neuracore/data_daemon/helpers.py b/neuracore/data_daemon/helpers.py index 738bf501..956e2230 100644 --- a/neuracore/data_daemon/helpers.py +++ b/neuracore/data_daemon/helpers.py @@ -1,6 +1,7 @@ """Helper functions for the data daemon.""" import os +from datetime import datetime, timezone from pathlib import Path @@ -50,3 +51,8 @@ def get_daemon_recordings_root_path() -> Path: str(default_root), ) ) + + +def utc_now() -> datetime: + """Return the current time as a Unix timestamp or datetime object.""" + return datetime.now(timezone.utc) diff --git a/neuracore/data_daemon/lifecycle/daemon_lifecycle.py b/neuracore/data_daemon/lifecycle/daemon_lifecycle.py index 21a7a9c6..624a3bbd 100644 --- a/neuracore/data_daemon/lifecycle/daemon_lifecycle.py +++ b/neuracore/data_daemon/lifecycle/daemon_lifecycle.py @@ -310,7 +310,7 @@ async def reconcile_state_with_filesystem( store: StateStore, recordings_root: Path ) -> None: """Sync stored traces with disk contents, cleaning orphans and flagging gaps.""" - traces = store.list_traces() + traces = await store.list_traces() trace_paths = {Path(str(trace.path)) for trace in traces} for trace in traces: diff --git a/neuracore/data_daemon/models.py b/neuracore/data_daemon/models.py index 42574323..25d71c0f 100644 --- a/neuracore/data_daemon/models.py +++ b/neuracore/data_daemon/models.py @@ -285,6 +285,7 @@ class MessageEnvelope: producer_id: str | None command: CommandType payload: dict = field(default_factory=dict) + sequence_number: int | None = None @classmethod def from_dict(cls, data: dict) -> "MessageEnvelope": @@ -303,6 +304,11 @@ def from_dict(cls, data: dict) -> "MessageEnvelope": producer_id=str(producer_id) if producer_id is not None else None, command=CommandType(data["command"]), payload=dict(data.get("payload") or {}), + sequence_number=( + int(data["sequence_number"]) + if data.get("sequence_number") is not None + else None + ), ) @classmethod @@ -330,6 +336,7 @@ def to_bytes(self) -> bytes: "producer_id": self.producer_id, "command": self.command.value, "payload": self.payload, + "sequence_number": self.sequence_number, }).encode("utf-8") diff --git a/neuracore/data_daemon/recording_encoding_disk_manager/encoding/video_trace.py b/neuracore/data_daemon/recording_encoding_disk_manager/encoding/video_trace.py index 8ad18f08..103b88a1 100644 --- a/neuracore/data_daemon/recording_encoding_disk_manager/encoding/video_trace.py +++ b/neuracore/data_daemon/recording_encoding_disk_manager/encoding/video_trace.py @@ -112,7 +112,31 @@ def _try_handle_combined_packet(self, payload: bytes) -> bool: self._handle_metadata(parsed) - frame_bytes = payload[4 + metadata_len :] + frame_start = 4 + metadata_len + frame_nbytes: int | None = None + if isinstance(parsed, dict): + frame_nbytes_raw = parsed.get("frame_nbytes") + if isinstance(frame_nbytes_raw, int) and frame_nbytes_raw >= 0: + frame_nbytes = frame_nbytes_raw + + if frame_nbytes is None: + frame_bytes = payload[frame_start:] + else: + frame_end = frame_start + frame_nbytes + if frame_end > len(payload): + raise ValueError( + "Combined packet shorter than 
declared frame_nbytes: " + f"frame_start={frame_start} frame_nbytes={frame_nbytes} " + f"payload_len={len(payload)}" + ) + if frame_end != len(payload): + raise ValueError( + "Combined packet has trailing bytes after frame payload: " + f"declared_frame_nbytes={frame_nbytes} " + f"trailing_bytes={len(payload) - frame_end}" + ) + frame_bytes = payload[frame_start:frame_end] + if len(frame_bytes) > 0: self._handle_frame_bytes(frame_bytes) diff --git a/neuracore/data_daemon/runner_entry.py b/neuracore/data_daemon/runner_entry.py index 8ccb4475..389323ad 100644 --- a/neuracore/data_daemon/runner_entry.py +++ b/neuracore/data_daemon/runner_entry.py @@ -40,7 +40,7 @@ def main() -> None: db_path = get_daemon_db_path() try: - bootstrap = DaemonBootstrap() + bootstrap = DaemonBootstrap(db_path=db_path) context = bootstrap.start() if context is None: diff --git a/neuracore/data_daemon/state_management/state_store.py b/neuracore/data_daemon/state_management/state_store.py index 22566d82..a60ba949 100644 --- a/neuracore/data_daemon/state_management/state_store.py +++ b/neuracore/data_daemon/state_management/state_store.py @@ -28,7 +28,7 @@ async def find_traces_by_recording_id(self, recording_id: str) -> list[TraceReco """Return all traces for a given recording ID.""" ... - def list_traces(self) -> list[TraceRecord]: + async def list_traces(self) -> list[TraceRecord]: """Return all trace records.""" ... diff --git a/neuracore/data_daemon/state_management/state_store_sqlite.py b/neuracore/data_daemon/state_management/state_store_sqlite.py index d153195d..0e90cb81 100644 --- a/neuracore/data_daemon/state_management/state_store_sqlite.py +++ b/neuracore/data_daemon/state_management/state_store_sqlite.py @@ -153,16 +153,17 @@ async def find_traces_by_recording_id(self, recording_id: str) -> list[TraceReco ) return [TraceRecord.from_row(dict(row)) for row in rows] - def list_traces(self) -> list[TraceRecord]: + async def list_traces(self) -> list[TraceRecord]: """Return all trace records.""" - with self._engine.begin() as conn: - rows = conn.execute(select(traces)).mappings().all() + async with self._engine.begin() as conn: + rows = (await conn.execute(select(traces))).mappings().all() return [TraceRecord.from_row(dict(row)) for row in rows] async def update_status( self, trace_id: str, status: TraceStatus, + *, error_message: str | None = None, ) -> bool: """Update the status and optional error message for a trace. 
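For context on the frame_nbytes handling added in video_trace.py above, a sketch of the combined-packet layout produced in data_stream.py — [metadata_len (4 bytes)][metadata_json][frame_bytes] — and the bounds checks the decoder now performs. The "<I" little-endian length prefix is an assumption of this sketch; the patch does not show how the 4-byte length is encoded.

import json
import struct

import numpy as np

frame = np.zeros((4, 4, 3), dtype=np.uint8)
metadata = {"width": 4, "height": 4, "frame_nbytes": int(frame.size * frame.itemsize)}
metadata_json = json.dumps(metadata).encode("utf-8")
payload = struct.pack("<I", len(metadata_json)) + metadata_json + frame.tobytes()

# Decoder side, mirroring _try_handle_combined_packet: reject payloads that are
# shorter than the declared frame or that carry trailing bytes after it.
(metadata_len,) = struct.unpack_from("<I", payload, 0)
parsed = json.loads(payload[4 : 4 + metadata_len])
frame_start = 4 + metadata_len
frame_end = frame_start + parsed["frame_nbytes"]
assert frame_end == len(payload)
frame_bytes = payload[frame_start:frame_end]
assert len(frame_bytes) == frame.nbytes
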
@@ -431,7 +432,13 @@ async def upsert_trace_bytes( index_elements=["trace_id"], set_={ "bytes_written": bytes_written, - "total_bytes": bytes_written, + "total_bytes": case( + ( + traces.c.total_bytes.is_(None), + bytes_written, + ), + else_=traces.c.total_bytes, + ), "last_updated": now, "status": case( ( diff --git a/neuracore/data_daemon/state_management/tables.py b/neuracore/data_daemon/state_management/tables.py index 9c0e7ed0..9f350931 100644 --- a/neuracore/data_daemon/state_management/tables.py +++ b/neuracore/data_daemon/state_management/tables.py @@ -80,3 +80,4 @@ Index("idx_traces_trace_id", traces.c.trace_id) Index("idx_traces_status", traces.c.status) +Index("idx_traces_next_retry_at", traces.c.next_retry_at) diff --git a/neuracore/data_daemon/upload_management/resumable_file_uploader.py b/neuracore/data_daemon/upload_management/resumable_file_uploader.py index e2abf2dd..c543f8f4 100644 --- a/neuracore/data_daemon/upload_management/resumable_file_uploader.py +++ b/neuracore/data_daemon/upload_management/resumable_file_uploader.py @@ -148,7 +148,7 @@ async def upload(self) -> tuple[bool, int, str | None]: FileNotFoundError: If the local file does not exist. """ logger.info( - f"Starting upload for {self._recording_id}/{self._filepath}: " + f"Starting upload for {self._recording_id} at{self._filepath}: " f"{self._bytes_uploaded} bytes already uploaded" ) @@ -181,13 +181,13 @@ async def upload(self) -> tuple[bool, int, str | None]: ) return (False, self._bytes_uploaded, checksum_error) logger.info( - f"Upload complete for {self._recording_id}/{self._filepath}: " + f"Upload complete for {self._recording_id} at {self._filepath}: " f"{self._total_bytes} bytes" ) return (True, self._bytes_uploaded, None) else: logger.warning( - f"Upload failed for {self._recording_id}/{self._filepath} " + f"Upload failed for {self._recording_id} at {self._filepath} " f"at offset {self._bytes_uploaded}/{self._total_bytes}: {error_message}" ) return (False, self._bytes_uploaded, error_message) diff --git a/tests/unit/data_daemon/communications_management/test_zmq_sockets.py b/tests/unit/data_daemon/communications_management/test_zmq_sockets.py index a7cac7c3..c5db1114 100644 --- a/tests/unit/data_daemon/communications_management/test_zmq_sockets.py +++ b/tests/unit/data_daemon/communications_management/test_zmq_sockets.py @@ -218,16 +218,6 @@ def test_zmq_commands_and_message_flow(daemon_runtime) -> None: payload = json.dumps({"seq": 1}).encode("utf-8") active_trace_id = producer.trace_id - producer.send_data( - payload, - data_type=DataType.CUSTOM_1D, - data_type_name="custom", - robot_instance=1, - robot_id="robot-1", - dataset_id="dataset-1", - ) - producer.end_trace() - trace_written: list[int] = [] def on_trace_written(trace_id: str, _: str, bytes_written: int) -> None: @@ -236,6 +226,18 @@ def on_trace_written(trace_id: str, _: str, bytes_written: int) -> None: get_emitter().on(Emitter.TRACE_WRITTEN, on_trace_written) try: + producer.send_data( + payload, + data_type=DataType.CUSTOM_1D, + data_type_name="custom", + robot_instance=1, + robot_id="robot-1", + dataset_id="dataset-1", + ) + assert _wait_for( + lambda: active_trace_id in daemon._trace_recordings, timeout=0.5 + ) + producer.end_trace() assert _wait_for(lambda: trace_written, timeout=1) finally: get_emitter().remove_listener(Emitter.TRACE_WRITTEN, on_trace_written) @@ -430,7 +432,7 @@ def on_stop_recording(rec_id: str) -> None: assert recording_id not in daemon._pending_close_recordings assert recording_id not in 
daemon._closed_recordings - daemon._handle_recording_stopped(None, msg) + daemon._handle_recording_stopped(msg) assert recording_id in daemon._pending_close_recordings assert recording_id not in daemon._closed_recordings diff --git a/tests/unit/data_daemon/lifecycle/test_daemon_lifecycle.py b/tests/unit/data_daemon/lifecycle/test_daemon_lifecycle.py index de905808..7ac14c0d 100644 --- a/tests/unit/data_daemon/lifecycle/test_daemon_lifecycle.py +++ b/tests/unit/data_daemon/lifecycle/test_daemon_lifecycle.py @@ -33,7 +33,7 @@ class _InMemoryStore: def __init__(self, traces: list[TraceRecord]) -> None: self._traces = {trace.trace_id: trace for trace in traces} - def list_traces(self) -> list[TraceRecord]: + async def list_traces(self) -> list[TraceRecord]: return list(self._traces.values()) async def record_error( diff --git a/tests/unit/data_daemon/lifecycle/test_management_channel_startup.py b/tests/unit/data_daemon/lifecycle/test_management_channel_startup.py index d2c919a5..a79b495d 100644 --- a/tests/unit/data_daemon/lifecycle/test_management_channel_startup.py +++ b/tests/unit/data_daemon/lifecycle/test_management_channel_startup.py @@ -25,7 +25,7 @@ class _Context: class _DummyStore: - def list_traces(self): + async def list_traces(self): return [] diff --git a/tests/unit/data_daemon/state_management/test_state_store_sqlite.py b/tests/unit/data_daemon/state_management/test_state_store_sqlite.py index 55052963..21027d0d 100644 --- a/tests/unit/data_daemon/state_management/test_state_store_sqlite.py +++ b/tests/unit/data_daemon/state_management/test_state_store_sqlite.py @@ -116,6 +116,60 @@ async def test_upsert_trace_bytes_inserts_row(store: SqliteStateStore) -> None: assert row["bytes_uploaded"] == 0 +@pytest.mark.asyncio +async def test_upsert_trace_bytes_preserves_existing_total_bytes( + store: SqliteStateStore, +) -> None: + await store.upsert_trace_metadata( + trace_id="trace-bytes-preserve-total", + recording_id="rec-bytes-preserve-total", + data_type=PRIMARY_DATA_TYPE, + data_type_name="primary", + path="/tmp/trace-bytes-preserve-total.bin", + total_bytes=256, + robot_instance=ROBOT_INSTANCE, + ) + + trace = await store.upsert_trace_bytes( + trace_id="trace-bytes-preserve-total", + recording_id="rec-bytes-preserve-total", + bytes_written=64, + ) + + assert trace.status == TraceStatus.WRITTEN + row = await _get_trace_row(store, "trace-bytes-preserve-total") + assert row is not None + assert row["bytes_written"] == 64 + assert row["total_bytes"] == 256 + + +@pytest.mark.asyncio +async def test_upsert_trace_bytes_backfills_missing_total_bytes( + store: SqliteStateStore, +) -> None: + await store.upsert_trace_metadata( + trace_id="trace-bytes-backfill-total", + recording_id="rec-bytes-backfill-total", + data_type=PRIMARY_DATA_TYPE, + data_type_name="primary", + path="/tmp/trace-bytes-backfill-total.bin", + total_bytes=None, + robot_instance=ROBOT_INSTANCE, + ) + + trace = await store.upsert_trace_bytes( + trace_id="trace-bytes-backfill-total", + recording_id="rec-bytes-backfill-total", + bytes_written=96, + ) + + assert trace.status == TraceStatus.WRITTEN + row = await _get_trace_row(store, "trace-bytes-backfill-total") + assert row is not None + assert row["bytes_written"] == 96 + assert row["total_bytes"] == 96 + + @pytest.mark.asyncio async def test_update_bytes_uploaded_sets_value(store: SqliteStateStore) -> None: await store.upsert_trace_metadata(