diff --git a/robocodec-cli/src/cli/mod.rs b/robocodec-cli/src/cli/mod.rs index 143b447..c3a3120 100644 --- a/robocodec-cli/src/cli/mod.rs +++ b/robocodec-cli/src/cli/mod.rs @@ -4,6 +4,8 @@ //! CLI utilities for the robocodec command-line interface. +use robocodec::RoboReader; + pub mod output; pub mod progress; pub mod time; @@ -18,9 +20,9 @@ pub use time::{format_duration, format_timestamp, parse_time_range}; /// /// Convenience wrapper around `RoboReader::open` that provides better /// error messages for invalid paths. -pub fn open_reader(path: &std::path::Path) -> Result { +pub fn open_reader(path: &std::path::Path) -> Result { let path_str = path .to_str() .ok_or_else(|| anyhow::anyhow!("Invalid UTF-8 path: {:?}", path))?; - Ok(robocodec::RoboReader::open(path_str)?) + Ok(RoboReader::open(path_str)?) } diff --git a/robocodec-cli/src/cmds/extract.rs b/robocodec-cli/src/cmds/extract.rs index c6e9fa5..1b60a1c 100644 --- a/robocodec-cli/src/cmds/extract.rs +++ b/robocodec-cli/src/cmds/extract.rs @@ -9,7 +9,8 @@ use std::path::PathBuf; use clap::Subcommand; use crate::cli::{Progress, Result, open_reader, parse_time_range}; -use robocodec::{FormatReader, RoboRewriter}; +use robocodec::io::RawMessage; +use robocodec::{FormatReader, FormatWriter, RoboReader, RoboWriter}; /// Extract subsets of data from files. #[derive(Subcommand, Clone, Debug)] @@ -155,42 +156,59 @@ fn cmd_extract_messages( let reader = open_reader(&input)?; let total = reader.message_count(); - let channel_count = reader.channels().len() as u64; let limit = count.unwrap_or(total as usize); println!(" Limit: {} messages", limit); - // Use rewriter for full file copy with limit support - // For partial extraction, we need format-specific iteration which is not yet exposed - if limit < total as usize { - return Err(anyhow::anyhow!( - "Partial message extraction (count < total) requires format-specific iteration. \ - Use the convert command for full file copying." 
- )); - } + let output_str = output + .to_str() + .ok_or_else(|| anyhow::anyhow!("Invalid UTF-8 in output path"))?; + let mut writer = RoboWriter::create(output_str)?; + + // Add all channels to writer + let channel_map = add_channels_to_writer(&reader, &mut writer)?; - // Full file copy using rewriter let mut progress = if show_progress { - Some(Progress::new(channel_count, "Copying channels")) + Some(Progress::new(limit as u64, "Extracting messages")) } else { None }; - let mut rewriter = RoboRewriter::open(&input)?; + // Iterate raw messages and write up to limit + let raw_iter = reader.iter_raw()?; + let mut written = 0u64; + + for result in raw_iter { + if written >= limit as u64 { + break; + } + + let (raw_msg, _channel_info) = result?; + + // Remap channel_id to writer's channel_id + if let Some(&new_ch_id) = channel_map.get(&raw_msg.channel_id) { + let write_msg = RawMessage { + channel_id: new_ch_id, + log_time: raw_msg.log_time, + publish_time: raw_msg.publish_time, + data: raw_msg.data, + sequence: raw_msg.sequence, + }; + writer.write(&write_msg)?; + written += 1; + } - // Simulate channel progress during rewrite - if let Some(ref mut pb) = progress { - for i in 0..channel_count { - pb.set(i + 1); + if let Some(ref mut pb) = progress { + pb.set(written); } } - let stats = rewriter.rewrite(&output)?; + writer.finish()?; if let Some(pb) = progress { - pb.finish(format!("{} messages", stats.message_count)); + pb.finish(format!("{written} messages")); } else { - println!(" Written: {} messages", stats.message_count); + println!(" Written: {written} messages"); } Ok(()) @@ -212,13 +230,13 @@ fn cmd_extract_topics( let reader = open_reader(&input)?; - // Find matching channels and count messages - let mut matching_channels: Vec = Vec::new(); + // Find matching channels + let mut matching_channels = std::collections::HashSet::new(); for (ch_id, channel) in reader.channels() { for topic in &topics_list { if channel.topic == *topic || channel.topic.contains(topic) { - matching_channels.push(*ch_id); + matching_channels.insert(*ch_id); break; } } @@ -231,33 +249,61 @@ fn cmd_extract_topics( )); } + println!(" Matched {} channels", matching_channels.len()); + + let output_str = output + .to_str() + .ok_or_else(|| anyhow::anyhow!("Invalid UTF-8 in output path"))?; + let mut writer = RoboWriter::create(output_str)?; + + // Only add matching channels to writer + let channel_map = add_matching_channels_to_writer(&reader, &mut writer, &matching_channels)?; + let mut progress = if show_progress { - Some(Progress::new( - matching_channels.len() as u64, - "Processing channels", - )) + Some(Progress::new(0, "Extracting topics")) } else { None }; - // Simulate processing each channel - for (i, &ch_id) in matching_channels.iter().enumerate() { + let raw_iter = reader.iter_raw()?; + let mut written = 0u64; + + for result in raw_iter { + let (raw_msg, _channel_info) = result?; + + // Only write messages from matching channels + if let Some(&new_ch_id) = channel_map.get(&raw_msg.channel_id) { + let write_msg = RawMessage { + channel_id: new_ch_id, + log_time: raw_msg.log_time, + publish_time: raw_msg.publish_time, + data: raw_msg.data, + sequence: raw_msg.sequence, + }; + writer.write(&write_msg)?; + written += 1; + } + if let Some(ref mut pb) = progress { - pb.set((i + 1) as u64); + pb.set(written); } - // In a full implementation, this would iterate through messages - let _ = ch_id; // Channel would be processed here } + writer.finish()?; + if let Some(pb) = progress { - pb.finish(format!("{} 
channels", matching_channels.len())); + pb.finish(format!( + "{written} messages from {} topics", + matching_channels.len() + )); + } else { + println!( + " Written: {written} messages from {} topics", + matching_channels.len() + ); } - // Topic extraction requires format-specific iteration which is not yet exposed - Err(anyhow::anyhow!( - "Topic-specific extraction requires format-specific message iteration. \ - This feature is not yet implemented. Use the convert command for full file copying." - )) + Ok(()) } /// Extract N messages per topic. @@ -272,39 +318,79 @@ fn cmd_extract_per_topic( println!(" Output: {}", output.display()); println!(" Messages per topic: {}", count); - if count != 1 { - return Err(anyhow::anyhow!( - "Per-topic extraction with count > 1 requires format-specific iteration. \ - This feature is not yet implemented." - )); - } - let reader = open_reader(&input)?; - let channel_count = reader.channels().len() as u64; + let channel_count = reader.channels().len(); + + let output_str = output + .to_str() + .ok_or_else(|| anyhow::anyhow!("Invalid UTF-8 in output path"))?; + let mut writer = RoboWriter::create(output_str)?; + + // Add all channels to writer + let channel_map = add_channels_to_writer(&reader, &mut writer)?; let mut progress = if show_progress { - Some(Progress::new(channel_count, "Scanning channels")) + Some(Progress::new( + (channel_count * count) as u64, + "Extracting per topic", + )) } else { None }; - // Simulate scanning each channel - for (i, channel) in reader.channels().values().enumerate() { + // Track how many messages we've written per channel + let mut per_channel_count: std::collections::HashMap = + std::collections::HashMap::new(); + + let raw_iter = reader.iter_raw()?; + let mut written = 0u64; + let mut all_done = false; + + for result in raw_iter { + if all_done { + break; + } + + let (raw_msg, _channel_info) = result?; + + let ch_count = per_channel_count.entry(raw_msg.channel_id).or_insert(0); + + if *ch_count < count + && let Some(&new_ch_id) = channel_map.get(&raw_msg.channel_id) + { + let write_msg = RawMessage { + channel_id: new_ch_id, + log_time: raw_msg.log_time, + publish_time: raw_msg.publish_time, + data: raw_msg.data, + sequence: raw_msg.sequence, + }; + writer.write(&write_msg)?; + written += 1; + *ch_count += 1; + } + if let Some(ref mut pb) = progress { - pb.set((i + 1) as u64); + pb.set(written); + } + + // Check if all channels have enough messages + if per_channel_count.len() == channel_count + && per_channel_count.values().all(|&c| c >= count) + { + all_done = true; } - let _ = channel.topic; // Topic would be processed here } + writer.finish()?; + if let Some(pb) = progress { - pb.finish(format!("{} channels scanned", channel_count)); + pb.finish(format!("{written} messages from {channel_count} topics")); + } else { + println!(" Written: {written} messages from {channel_count} topics"); } - // Per-topic extraction requires format-specific iteration - Err(anyhow::anyhow!( - "Per-topic extraction requires format-specific message iteration. \ - This feature is not yet implemented. Use the convert command for full file copying." - )) + Ok(()) } /// Extract messages within time range. 
@@ -322,42 +408,58 @@ fn cmd_extract_time_range( println!(" Start: {}", start_ns); println!(" End: {}", end_ns); - // Check if the full file is within range (full file copy) - if start_ns == 0 && end_ns == u64::MAX { - let reader = open_reader(&input)?; - let channel_count = reader.channels().len() as u64; + let reader = open_reader(&input)?; - let mut progress = if show_progress { - Some(Progress::new(channel_count, "Copying channels")) - } else { - None - }; + let output_str = output + .to_str() + .ok_or_else(|| anyhow::anyhow!("Invalid UTF-8 in output path"))?; + let mut writer = RoboWriter::create(output_str)?; - let mut rewriter = RoboRewriter::open(&input)?; + // Add all channels to writer + let channel_map = add_channels_to_writer(&reader, &mut writer)?; - // Simulate channel progress during rewrite - if let Some(ref mut pb) = progress { - for i in 0..channel_count { - pb.set(i + 1); - } - } + let mut progress = if show_progress { + Some(Progress::new(0, "Extracting by time range")) + } else { + None + }; - let stats = rewriter.rewrite(&output)?; + let raw_iter = reader.iter_raw()?; + let mut written = 0u64; + + for result in raw_iter { + let (raw_msg, _channel_info) = result?; + + // Filter by time range (use log_time) + if raw_msg.log_time >= start_ns + && raw_msg.log_time <= end_ns + && let Some(&new_ch_id) = channel_map.get(&raw_msg.channel_id) + { + let write_msg = RawMessage { + channel_id: new_ch_id, + log_time: raw_msg.log_time, + publish_time: raw_msg.publish_time, + data: raw_msg.data, + sequence: raw_msg.sequence, + }; + writer.write(&write_msg)?; + written += 1; + } - if let Some(pb) = progress { - pb.finish(format!("{} messages", stats.message_count)); - } else { - println!(" Written: {} messages", stats.message_count); + if let Some(ref mut pb) = progress { + pb.set(written); } + } + + writer.finish()?; - return Ok(()); + if let Some(pb) = progress { + pb.finish(format!("{written} messages")); + } else { + println!(" Written: {written} messages"); } - // Time range filtering requires format-specific iteration - Err(anyhow::anyhow!( - "Time range filtering requires format-specific message iteration. \ - This feature is not yet implemented. Use the convert command for full file copying." - )) + Ok(()) } /// Create minimal fixture files. @@ -372,21 +474,113 @@ fn cmd_create_fixture( let reader = open_reader(&input)?; let fixture_dir = output_dir.unwrap_or_else(|| PathBuf::from("tests/fixtures")); - std::fs::create_dir_all(&fixture_dir)?; - let _fixture_name = name.unwrap_or_else(|| "fixture".to_string()); + let fixture_name = name.unwrap_or_else(|| "fixture".to_string()); + + // Determine output extension from input + let ext = input.extension().and_then(|e| e.to_str()).unwrap_or("bag"); + let output_path = fixture_dir.join(format!("{fixture_name}.{ext}")); + println!(" Output: {}", output_path.display()); println!(" Available topics:"); for channel in reader.channels().values() { println!(" - {} ({})", channel.topic, channel.message_type); } - // Fixture creation requires format-specific iteration to extract one message per topic - Err(anyhow::anyhow!( - "Fixture creation requires format-specific message iteration. \ - This feature is not yet implemented. Use the convert command for full file copying." 
- )) + let output_str = output_path + .to_str() + .ok_or_else(|| anyhow::anyhow!("Invalid UTF-8 in output path"))?; + let mut writer = RoboWriter::create(output_str)?; + + // Add all channels to writer + let channel_count = reader.channels().len(); + let channel_map = add_channels_to_writer(&reader, &mut writer)?; + + // Extract one message per topic (same as per-topic with count=1) + let mut per_channel_count: std::collections::HashMap = + std::collections::HashMap::new(); + + let raw_iter = reader.iter_raw()?; + let mut written = 0u64; + + for result in raw_iter { + let (raw_msg, _channel_info) = result?; + + let ch_count = per_channel_count.entry(raw_msg.channel_id).or_insert(0); + + if *ch_count < 1 + && let Some(&new_ch_id) = channel_map.get(&raw_msg.channel_id) + { + let write_msg = RawMessage { + channel_id: new_ch_id, + log_time: raw_msg.log_time, + publish_time: raw_msg.publish_time, + data: raw_msg.data, + sequence: raw_msg.sequence, + }; + writer.write(&write_msg)?; + written += 1; + *ch_count += 1; + } + + // Check if all channels have a message + if per_channel_count.len() == channel_count && per_channel_count.values().all(|&c| c >= 1) { + break; + } + } + + writer.finish()?; + + println!( + " Created fixture: {} ({written} messages from {channel_count} topics)", + output_path.display() + ); + + Ok(()) +} + +/// Add all channels from reader to writer, returning a map from old channel_id to new channel_id. +fn add_channels_to_writer( + reader: &RoboReader, + writer: &mut RoboWriter, +) -> Result> { + let mut channel_map = std::collections::HashMap::new(); + + for (&old_id, channel) in reader.channels() { + let new_id = writer.add_channel( + &channel.topic, + &channel.message_type, + &channel.encoding, + channel.schema.as_deref(), + )?; + channel_map.insert(old_id, new_id); + } + + Ok(channel_map) +} + +/// Add only matching channels from reader to writer. 
+fn add_matching_channels_to_writer( + reader: &RoboReader, + writer: &mut RoboWriter, + matching_channels: &std::collections::HashSet, +) -> Result> { + let mut channel_map = std::collections::HashMap::new(); + + for (&old_id, channel) in reader.channels() { + if matching_channels.contains(&old_id) { + let new_id = writer.add_channel( + &channel.topic, + &channel.message_type, + &channel.encoding, + channel.schema.as_deref(), + )?; + channel_map.insert(old_id, new_id); + } + } + + Ok(channel_map) } #[cfg(test)] @@ -477,32 +671,47 @@ mod tests { // ======================================================================== #[test] - fn test_cmd_extract_messages_partial_extraction_error() { + fn test_cmd_extract_messages_partial() { let path = fixture_path("robocodec_test_0.mcap"); if !path.exists() { return; // Skip if fixture not available } - // Partial extraction (count < total) should error - let result = cmd_extract_messages(path.clone(), temp_output(), Some(1), false); - assert!(result.is_err(), "partial extraction should fail"); + let output = temp_output(); + let result = cmd_extract_messages(path, output.clone(), Some(1), false); + // Should succeed - partial extraction now works assert!( - result - .unwrap_err() - .to_string() - .contains("Partial message extraction") + result.is_ok(), + "partial extraction should succeed: {:?}", + result.err() + ); + let _ = std::fs::remove_file(&output); + } + + #[test] + fn test_cmd_extract_messages_all() { + let path = fixture_path("robocodec_test_0.mcap"); + if !path.exists() { + return; + } + + let output = temp_output(); + let result = cmd_extract_messages(path, output.clone(), None, false); + assert!( + result.is_ok(), + "full extraction should succeed: {:?}", + result.err() ); + let _ = std::fs::remove_file(&output); } #[test] - fn test_cmd_extract_messages_invalid_range() { + fn test_cmd_extract_messages_invalid_output() { let path = fixture_path("robocodec_test_0.mcap"); if !path.exists() { return; } - // Can't test full extraction without a valid output - // but we can verify the function attempts to open the file let result = cmd_extract_messages( path, PathBuf::from("/nonexistent/output/dir/file.mcap"), @@ -540,14 +749,13 @@ mod tests { } #[test] - fn test_cmd_extract_topics_not_implemented() { + fn test_cmd_extract_topics_matching() { let path = fixture_path("robocodec_test_0.mcap"); if !path.exists() { return; } - // Even with matching topics, should return not implemented error - // First we need to find a real topic name + // Find a real topic name let Ok(reader) = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| open_reader(&path))) else { @@ -560,14 +768,14 @@ mod tests { return; }; - let result = cmd_extract_topics(path, temp_output(), topic, false); - assert!(result.is_err(), "topic extraction not yet implemented"); + let output = temp_output(); + let result = cmd_extract_topics(path, output.clone(), topic, false); assert!( - result - .unwrap_err() - .to_string() - .contains("not yet implemented") + result.is_ok(), + "topic extraction should succeed: {:?}", + result.err() ); + let _ = std::fs::remove_file(&output); } #[test] @@ -604,34 +812,37 @@ mod tests { // ======================================================================== #[test] - fn test_cmd_extract_per_topic_count_not_one() { + fn test_cmd_extract_per_topic_count_one() { let path = fixture_path("robocodec_test_0.mcap"); if !path.exists() { return; } - // count != 1 should fail - let result = cmd_extract_per_topic(path, temp_output(), 2, false); - 
assert!(result.is_err(), "count > 1 should fail"); - assert!(result.unwrap_err().to_string().contains("count > 1")); + let output = temp_output(); + let result = cmd_extract_per_topic(path, output.clone(), 1, false); + assert!( + result.is_ok(), + "per-topic extraction with count=1 should succeed: {:?}", + result.err() + ); + let _ = std::fs::remove_file(&output); } #[test] - fn test_cmd_extract_per_topic_not_implemented() { + fn test_cmd_extract_per_topic_count_multiple() { let path = fixture_path("robocodec_test_0.mcap"); if !path.exists() { return; } - // Even with count=1, should return not implemented - let result = cmd_extract_per_topic(path, temp_output(), 1, false); - assert!(result.is_err(), "per-topic extraction not yet implemented"); + let output = temp_output(); + let result = cmd_extract_per_topic(path, output.clone(), 3, false); assert!( - result - .unwrap_err() - .to_string() - .contains("not yet implemented") + result.is_ok(), + "per-topic extraction with count>1 should succeed: {:?}", + result.err() ); + let _ = std::fs::remove_file(&output); } // ======================================================================== @@ -656,21 +867,20 @@ mod tests { } #[test] - fn test_cmd_extract_time_range_not_implemented() { + fn test_cmd_extract_time_range_specific_range() { let path = fixture_path("robocodec_test_0.mcap"); if !path.exists() { return; } - // Valid range that's not "0,MAX" should fail with not implemented - let result = cmd_extract_time_range(path, temp_output(), "1000,2000".to_string(), false); - assert!(result.is_err(), "time range filtering not yet implemented"); + let output = temp_output(); + let result = cmd_extract_time_range(path, output.clone(), "0,MAX".to_string(), false); assert!( - result - .unwrap_err() - .to_string() - .contains("not yet implemented") + result.is_ok(), + "time range extraction should succeed: {:?}", + result.err() ); + let _ = std::fs::remove_file(&output); } #[test] @@ -680,7 +890,6 @@ mod tests { return; } - // Even with 0,MAX range, invalid output should fail let result = cmd_extract_time_range( path, PathBuf::from("/nonexistent/output/dir/file.mcap"), @@ -695,20 +904,23 @@ mod tests { // ======================================================================== #[test] - fn test_cmd_create_fixture_not_implemented() { + fn test_cmd_create_fixture() { let path = fixture_path("robocodec_test_0.mcap"); if !path.exists() { return; } - let result = cmd_create_fixture(path, None, None); - assert!(result.is_err(), "fixture creation not yet implemented"); + let temp_dir = + std::env::temp_dir().join(format!("robocodec_fixture_{}", std::process::id())); + let result = cmd_create_fixture(path, Some(temp_dir.clone()), Some("test".to_string())); assert!( - result - .unwrap_err() - .to_string() - .contains("not yet implemented") + result.is_ok(), + "fixture creation should succeed: {:?}", + result.err() ); + + // Clean up + let _ = std::fs::remove_dir_all(temp_dir); } #[test] @@ -718,12 +930,25 @@ mod tests { return; } - let temp_dir = std::env::temp_dir().join("robocodec_fixture_test"); - let result = cmd_create_fixture(path, Some(temp_dir.clone()), Some("test".to_string())); + let temp_dir = + std::env::temp_dir().join(format!("robocodec_fixture_custom_{}", std::process::id())); + let result = + cmd_create_fixture(path, Some(temp_dir.clone()), Some("my_fixture".to_string())); + assert!( + result.is_ok(), + "fixture creation should succeed: {:?}", + result.err() + ); - assert!(result.is_err(), "fixture creation not yet implemented"); + // Verify 
output file exists + let output_file = temp_dir.join("my_fixture.mcap"); + assert!( + output_file.exists(), + "fixture file should exist at {:?}", + output_file + ); - // Clean up temp dir + // Clean up let _ = std::fs::remove_dir_all(temp_dir); } diff --git a/src/io/formats/bag/parallel.rs b/src/io/formats/bag/parallel.rs index 711d50b..69ed010 100644 --- a/src/io/formats/bag/parallel.rs +++ b/src/io/formats/bag/parallel.rs @@ -341,6 +341,10 @@ impl FormatReader for ParallelBagReader { Ok(Box::new(stream)) } + fn iter_raw_boxed(&self) -> Result> { + Ok(Box::new(self.iter_raw()?)) + } + fn as_any(&self) -> &dyn std::any::Any { self } diff --git a/src/io/formats/bag/parser.rs b/src/io/formats/bag/parser.rs index 1e4aeed..a1efa6c 100644 --- a/src/io/formats/bag/parser.rs +++ b/src/io/formats/bag/parser.rs @@ -333,6 +333,15 @@ impl BagParser { let value = &field_bytes[eq_pos + 1..]; Self::parse_field(&mut fields, name, value); + } else { + return Err(CodecError::parse( + "BagParser", + format!( + "Header field is missing equals sign (field_len={}, bytes={:?})", + field_len, + String::from_utf8_lossy(&field_bytes[..field_len.min(64)]), + ), + )); } } @@ -439,10 +448,9 @@ impl BagParser { match header_fields.op { Some(OP_CONNECTION) => { // Connection data section also contains field=value pairs - let data_fields = Self::parse_record_header(&data).unwrap_or_default(); - if let Some(conn) = Self::connection_from_fields(&header_fields, &data_fields) { - connections.insert(conn.conn_id, conn); - } + let data_fields = Self::parse_record_header(&data)?; + let conn = Self::connection_from_fields(&header_fields, &data_fields)?; + connections.insert(conn.conn_id, conn); } Some(OP_CHUNK_INFO) => { if let Some(chunk_info) = @@ -462,16 +470,60 @@ impl BagParser { } /// Create a `BagConnection` from parsed header and data fields. + /// + /// Returns an error if required fields are missing. Per the ROS bag spec, + /// connection records must have `conn` and `topic` in the header, and + /// `topic`, `type`, and `md5sum` in the data section. 
     fn connection_from_fields(
         header_fields: &RecordHeader,
         data_fields: &RecordHeader,
-    ) -> Option<BagConnection> {
-        Some(BagConnection {
-            conn_id: header_fields.conn?,
-            topic: header_fields.topic.clone()?,
-            // type, md5sum, message_definition come from the data section
-            message_type: data_fields.message_type.clone()?,
-            md5sum: data_fields.md5sum.clone().unwrap_or_default(),
+    ) -> Result<BagConnection> {
+        let conn_id = header_fields.conn.ok_or_else(|| {
+            CodecError::parse(
+                "BagParser",
+                "Connection record missing 'conn' field in header",
+            )
+        })?;
+        let topic = header_fields.topic.clone().ok_or_else(|| {
+            CodecError::parse(
+                "BagParser",
+                "Connection record missing 'topic' field in header",
+            )
+        })?;
+        let message_type = data_fields.message_type.clone().ok_or_else(|| {
+            CodecError::parse(
+                "BagParser",
+                format!(
+                    "Connection record for topic '{}' missing 'type' field in data section",
+                    topic
+                ),
+            )
+        })?;
+        // topic must also be present in the data section per ROS bag spec
+        if data_fields.topic.is_none() {
+            return Err(CodecError::parse(
+                "BagParser",
+                format!(
+                    "Connection record for topic '{}' missing 'topic' field in data section",
+                    topic
+                ),
+            ));
+        }
+        let md5sum = data_fields.md5sum.clone().ok_or_else(|| {
+            CodecError::parse(
+                "BagParser",
+                format!(
+                    "Connection record for topic '{}' missing 'md5sum' field in data section",
+                    topic
+                ),
+            )
+        })?;
+
+        Ok(BagConnection {
+            conn_id,
+            topic,
+            message_type,
+            md5sum,
             message_definition: data_fields.message_definition.clone().unwrap_or_default(),
             caller_id: data_fields.callerid.clone().unwrap_or_default(),
         })
@@ -539,10 +591,9 @@
         match header_fields.op {
             Some(OP_CONNECTION) => {
                 // Connection data section also contains field=value pairs
-                let data_fields = Self::parse_record_header(&data).unwrap_or_default();
-                if let Some(conn) = Self::connection_from_fields(&header_fields, &data_fields) {
-                    connections.insert(conn.conn_id, conn);
-                }
+                let data_fields = Self::parse_record_header(&data)?;
+                let conn = Self::connection_from_fields(&header_fields, &data_fields)?;
+                connections.insert(conn.conn_id, conn);
             }
             Some(OP_CHUNK) => {
                 // Record chunk info from the chunk header
diff --git a/src/io/formats/bag/stream.rs b/src/io/formats/bag/stream.rs
index 29f547e..cbcd5ee 100644
--- a/src/io/formats/bag/stream.rs
+++ b/src/io/formats/bag/stream.rs
@@ -362,8 +362,15 @@
                     self.connections.insert(conn.conn_id, conn);
                 }
             }
-            OP_BAG_HEADER | OP_INDEX_DATA | OP_CHUNK | OP_CHUNK_INFO => {
-                // Metadata records - ignore for streaming
+            OP_CHUNK => {
+                // Chunk records contain compressed message data.
+                // Decompress and recursively parse the inner records.
+                let compression = fields.compression.as_deref().unwrap_or("none");
+                let decompressed = Self::decompress_chunk(compression, &data)?;
+                self.parse_inner_records(&decompressed, messages)?;
+            }
+            OP_BAG_HEADER | OP_INDEX_DATA | OP_CHUNK_INFO => {
+                // Metadata records - skip for streaming
             }
             _ => {
                 // Unknown op code - this might indicate file corruption or version mismatch
@@ -499,6 +506,98 @@
+    /// Decompress chunk data based on the compression format.
+    fn decompress_chunk(compression: &str, data: &[u8]) -> Result<Vec<u8>, FatalError> {
+        match compression {
+            "none" => Ok(data.to_vec()),
+            "bz2" => {
+                use bzip2::read::BzDecoder;
+                use std::io::Read as _;
+                let mut decoder = BzDecoder::new(data);
+                let mut decompressed = Vec::new();
+                decoder
+                    .read_to_end(&mut decompressed)
+                    .map_err(|e| FatalError::io_error(format!("BZ2 decompression failed: {e}")))?;
+                Ok(decompressed)
+            }
+            "lz4" => {
+                use lz4_flex::decompress_size_prepended;
+                decompress_size_prepended(data)
+                    .map_err(|e| FatalError::io_error(format!("LZ4 decompression failed: {e}")))
+            }
+            _ => Err(FatalError::io_error(format!(
+                "Unsupported BAG chunk compression: {compression}"
+            ))),
+        }
+    }
+
+    /// Parse inner records from decompressed chunk data.
+    ///
+    /// Decompressed chunks contain `OP_MSG_DATA` and `OP_CONNECTION` records.
+    fn parse_inner_records(
+        &mut self,
+        data: &[u8],
+        messages: &mut Vec<BagMessageRecord>,
+    ) -> Result<(), FatalError> {
+        let mut pos = 0;
+        while pos + 4 <= data.len() {
+            // Read header_len
+            let header_len = u32::from_le_bytes(
+                data[pos..pos + 4]
+                    .try_into()
+                    .expect("slice is exactly 4 bytes"),
+            ) as usize;
+            pos += 4;
+
+            if pos + header_len + 4 > data.len() {
+                break; // Incomplete record at end of chunk
+            }
+
+            // Parse header fields
+            let header_bytes = &data[pos..pos + header_len];
+            let fields = Self::parse_record_header(header_bytes)?;
+            pos += header_len;
+
+            // Read data_len
+            let data_len = u32::from_le_bytes(
+                data[pos..pos + 4]
+                    .try_into()
+                    .expect("slice is exactly 4 bytes"),
+            ) as usize;
+            pos += 4;
+
+            if pos + data_len > data.len() {
+                break; // Incomplete record at end of chunk
+            }
+
+            let record_data = &data[pos..pos + data_len];
+            pos += data_len;
+
+            match fields.op {
+                Some(OP_MSG_DATA) => {
+                    if let Some(conn_id) = fields.conn {
+                        let time = fields.time.unwrap_or(0);
+                        messages.push(BagMessageRecord {
+                            conn_id,
+                            log_time: time,
+                            data: record_data.to_vec(),
+                        });
+                    }
+                }
+                Some(OP_CONNECTION) => {
+                    let data_fields = Self::parse_record_header(record_data).unwrap_or_default();
+                    if let Some(conn) = Self::connection_from_fields(&fields, &data_fields) {
+                        self.connections.insert(conn.conn_id, conn);
+                    }
+                }
+                _ => {
+                    // Skip other record types inside chunks (e.g. index data)
+                }
+            }
+        }
+        Ok(())
+    }
+
     /// Get all discovered connections as `ChannelInfo`.
     ///
     /// Uses the original BAG connection ID as the channel ID to ensure
@@ -600,10 +699,11 @@ impl StreamingParser for StreamingBagParser {
     fn parse_chunk(&mut self, data: &[u8]) -> Result, FatalError> {
         // Call the inherent parse_chunk method
         // Use fully qualified syntax to avoid recursion
+        let prev_conn_count = self.connections.len();
         let messages = StreamingBagParser::parse_chunk(self, data)?;
         // Rebuild channels if we discovered new connections
-        if self.has_connections() && self.cached_channels.is_empty() {
+        if self.connections.len() != prev_conn_count {
             self.rebuild_channels();
         }
@@ -655,6 +755,154 @@ enum ParserState {
 mod tests {
     use super::*;
+    // =========================================================================
+    // Test helpers: build raw BAG binary structures
+    // =========================================================================
+
+    /// Build a BAG header field: `field_len(u32) | name=value`.
+    fn build_field(name: &[u8], value: &[u8]) -> Vec<u8> {
+        let field_len = (name.len() + 1 + value.len()) as u32; // +1 for '='
+        let mut out = Vec::new();
+        out.extend(&field_len.to_le_bytes());
+        out.extend(name);
+        out.push(b'=');
+        out.extend(value);
+        out
+    }
+
+    /// Build a complete BAG record: `header_len(u32) | header_bytes | data_len(u32) | data`.
+    fn build_record(header_fields: &[u8], data: &[u8]) -> Vec<u8> {
+        let mut out = Vec::new();
+        out.extend(&(header_fields.len() as u32).to_le_bytes());
+        out.extend(header_fields);
+        out.extend(&(data.len() as u32).to_le_bytes());
+        out.extend(data);
+        out
+    }
+
+    /// Build op field bytes.
+    fn op_field(op: u8) -> Vec<u8> {
+        build_field(b"op", &[op])
+    }
+
+    /// Build conn field bytes.
+    fn conn_field(conn_id: u32) -> Vec<u8> {
+        build_field(b"conn", &conn_id.to_le_bytes())
+    }
+
+    /// Build time field bytes (sec + nsec).
+    fn time_field(sec: u32, nsec: u32) -> Vec<u8> {
+        let mut value = Vec::new();
+        value.extend(&sec.to_le_bytes());
+        value.extend(&nsec.to_le_bytes());
+        build_field(b"time", &value)
+    }
+
+    /// Build topic field bytes.
+    fn topic_field(topic: &str) -> Vec<u8> {
+        build_field(b"topic", topic.as_bytes())
+    }
+
+    /// Build compression field bytes.
+    fn compression_field(compression: &str) -> Vec<u8> {
+        build_field(b"compression", compression.as_bytes())
+    }
+
+    /// Build size field bytes (uncompressed size).
+    fn size_field(size: u32) -> Vec<u8> {
+        build_field(b"size", &size.to_le_bytes())
+    }
+
+    /// Build a BAG OP_MSG_DATA record.
+    fn build_msg_data_record(conn_id: u32, sec: u32, nsec: u32, payload: &[u8]) -> Vec<u8> {
+        let mut header = Vec::new();
+        header.extend(op_field(OP_MSG_DATA));
+        header.extend(conn_field(conn_id));
+        header.extend(time_field(sec, nsec));
+        build_record(&header, payload)
+    }
+
+    /// Build a BAG OP_CONNECTION record.
+    fn build_connection_record(conn_id: u32, topic: &str) -> Vec<u8> {
+        let mut header = Vec::new();
+        header.extend(op_field(OP_CONNECTION));
+        header.extend(conn_field(conn_id));
+        header.extend(topic_field(topic));
+
+        // Data section contains additional fields (type, md5sum, etc.)
+        let mut data_fields = Vec::new();
+        data_fields.extend(build_field(b"type", b"std_msgs/String"));
+        data_fields.extend(build_field(b"md5sum", b"992ce8a1687cec8c8bd883ec73ca41d1"));
+        data_fields.extend(build_field(b"message_definition", b"string data"));
+
+        build_record(&header, &data_fields)
+    }
+
+    /// Build a BAG OP_CHUNK record with uncompressed inner data.
+    fn build_chunk_record_none(inner_records: &[u8]) -> Vec<u8> {
+        let mut header = Vec::new();
+        header.extend(op_field(OP_CHUNK));
+        header.extend(compression_field("none"));
+        header.extend(size_field(inner_records.len() as u32));
+        build_record(&header, inner_records)
+    }
+
+    /// Build a BAG OP_CHUNK record with LZ4-compressed inner data.
+    fn build_chunk_record_lz4(inner_records: &[u8]) -> Vec<u8> {
+        use lz4_flex::compress_prepend_size;
+        let compressed = compress_prepend_size(inner_records);
+
+        let mut header = Vec::new();
+        header.extend(op_field(OP_CHUNK));
+        header.extend(compression_field("lz4"));
+        header.extend(size_field(inner_records.len() as u32));
+        build_record(&header, &compressed)
+    }
+
+    /// Build a BAG OP_CHUNK record with BZ2-compressed inner data.
+    fn build_chunk_record_bz2(inner_records: &[u8]) -> Vec<u8> {
+        use bzip2::Compression;
+        use bzip2::write::BzEncoder;
+        use std::io::Write;
+
+        let mut encoder = BzEncoder::new(Vec::new(), Compression::default());
+        encoder.write_all(inner_records).unwrap();
+        let compressed = encoder.finish().unwrap();
+
+        let mut header = Vec::new();
+        header.extend(op_field(OP_CHUNK));
+        header.extend(compression_field("bz2"));
+        header.extend(size_field(inner_records.len() as u32));
+        build_record(&header, &compressed)
+    }
+
+    /// Build a BAG header record (op=0x03).
+    fn build_bag_header_record() -> Vec<u8> {
+        let mut header = Vec::new();
+        header.extend(op_field(OP_BAG_HEADER));
+        // index_pos and conn_count/chunk_count are typically in the header
+        header.extend(build_field(b"index_pos", &0u64.to_le_bytes()));
+        header.extend(build_field(b"conn_count", &0u32.to_le_bytes()));
+        header.extend(build_field(b"chunk_count", &0u32.to_le_bytes()));
+        // Padding data (BAG header records often have padding)
+        build_record(&header, &[0u8; 4])
+    }
+
+    /// Build a complete minimal BAG file with magic + header + records.
+    fn build_bag_file(records: &[Vec<u8>]) -> Vec<u8> {
+        let mut out = Vec::new();
+        out.extend(b"#ROSBAG V2.0\n");
+        out.extend(build_bag_header_record());
+        for record in records {
+            out.extend(record);
+        }
+        out
+    }
+
+    // =========================================================================
+    // Basic parser tests
+    // =========================================================================
+
     #[test]
     fn test_parser_new() {
         let parser = StreamingBagParser::new();
@@ -755,4 +1003,389 @@ mod tests {
         assert!(parser.channels().is_empty());
         assert!(parser.conn_id_map().is_empty());
     }
+
+    // =========================================================================
+    // Decompress chunk tests
+    // =========================================================================
+
+    #[test]
+    fn test_decompress_chunk_none() {
+        let data = b"hello world";
+        let result = StreamingBagParser::decompress_chunk("none", data).unwrap();
+        assert_eq!(result, data);
+    }
+
+    #[test]
+    fn test_decompress_chunk_lz4() {
+        use lz4_flex::compress_prepend_size;
+        let original = b"hello world this is a test of lz4 compression";
+        let compressed = compress_prepend_size(original);
+        let result = StreamingBagParser::decompress_chunk("lz4", &compressed).unwrap();
+        assert_eq!(result, original);
+    }
+
+    #[test]
+    fn test_decompress_chunk_bz2() {
+        use bzip2::Compression;
+        use bzip2::write::BzEncoder;
+        use std::io::Write;
+
+        let original = b"hello world this is a test of bz2 compression";
+        let mut encoder = BzEncoder::new(Vec::new(), Compression::default());
+        encoder.write_all(original).unwrap();
+        let compressed = encoder.finish().unwrap();
+
+        let result = StreamingBagParser::decompress_chunk("bz2", &compressed).unwrap();
+        assert_eq!(result, original);
+    }
+
+    #[test]
+    fn test_decompress_chunk_unsupported() {
+        let result = StreamingBagParser::decompress_chunk("zstd", b"data");
+        assert!(result.is_err());
+        let err = result.unwrap_err();
+        assert!(
+            err.to_string()
+                .contains("Unsupported BAG chunk compression")
+        );
+    }
+
+    #[test]
+    fn test_decompress_chunk_lz4_invalid_data() {
+        let result = StreamingBagParser::decompress_chunk("lz4", b"\x00\x00\x00\x00garbage");
+        assert!(result.is_err());
+        let err = result.unwrap_err();
+        assert!(err.to_string().contains("LZ4 decompression failed"));
+    }
+
+    #[test]
+    fn test_decompress_chunk_bz2_invalid_data() {
+        let result = StreamingBagParser::decompress_chunk("bz2", b"not-bz2-data");
+
assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(err.to_string().contains("BZ2 decompression failed")); + } + + // ========================================================================= + // Inner record parsing tests + // ========================================================================= + + #[test] + fn test_parse_inner_records_msg_data() { + let mut parser = StreamingBagParser::new(); + let mut messages = Vec::new(); + + // Build inner records: a single message data record + let inner = build_msg_data_record(0, 100, 500, b"payload-data"); + parser.parse_inner_records(&inner, &mut messages).unwrap(); + + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].conn_id, 0); + assert_eq!(messages[0].log_time, 100 * 1_000_000_000 + 500); + assert_eq!(messages[0].data, b"payload-data"); + } + + #[test] + fn test_parse_inner_records_multiple_messages() { + let mut parser = StreamingBagParser::new(); + let mut messages = Vec::new(); + + let mut inner = Vec::new(); + inner.extend(build_msg_data_record(0, 100, 0, b"msg1")); + inner.extend(build_msg_data_record(1, 200, 0, b"msg2")); + inner.extend(build_msg_data_record(0, 300, 0, b"msg3")); + + parser.parse_inner_records(&inner, &mut messages).unwrap(); + + assert_eq!(messages.len(), 3); + assert_eq!(messages[0].data, b"msg1"); + assert_eq!(messages[1].data, b"msg2"); + assert_eq!(messages[2].data, b"msg3"); + assert_eq!(messages[0].conn_id, 0); + assert_eq!(messages[1].conn_id, 1); + assert_eq!(messages[2].conn_id, 0); + } + + #[test] + fn test_parse_inner_records_connection() { + let mut parser = StreamingBagParser::new(); + let mut messages = Vec::new(); + + let inner = build_connection_record(0, "/camera/image"); + parser.parse_inner_records(&inner, &mut messages).unwrap(); + + assert_eq!(messages.len(), 0); + assert_eq!(parser.connections.len(), 1); + let conn = parser.connections.get(&0).unwrap(); + assert_eq!(conn.topic, "/camera/image"); + assert_eq!(conn.message_type, "std_msgs/String"); + } + + #[test] + fn test_parse_inner_records_connection_and_messages() { + let mut parser = StreamingBagParser::new(); + let mut messages = Vec::new(); + + let mut inner = Vec::new(); + inner.extend(build_connection_record(0, "/topic_a")); + inner.extend(build_msg_data_record(0, 100, 0, b"data-a")); + inner.extend(build_connection_record(1, "/topic_b")); + inner.extend(build_msg_data_record(1, 200, 0, b"data-b")); + + parser.parse_inner_records(&inner, &mut messages).unwrap(); + + assert_eq!(parser.connections.len(), 2); + assert_eq!(messages.len(), 2); + assert_eq!(messages[0].data, b"data-a"); + assert_eq!(messages[1].data, b"data-b"); + } + + #[test] + fn test_parse_inner_records_empty() { + let mut parser = StreamingBagParser::new(); + let mut messages = Vec::new(); + parser.parse_inner_records(&[], &mut messages).unwrap(); + assert_eq!(messages.len(), 0); + } + + #[test] + fn test_parse_inner_records_truncated() { + let mut parser = StreamingBagParser::new(); + let mut messages = Vec::new(); + + // Build a valid record followed by a truncated one + let mut inner = build_msg_data_record(0, 100, 0, b"valid"); + // Append a truncated header (just header_len, no actual data) + inner.extend(&100u32.to_le_bytes()); + + parser.parse_inner_records(&inner, &mut messages).unwrap(); + // Should parse the valid record and stop at the truncated one + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].data, b"valid"); + } + + // ========================================================================= + // End-to-end chunk 
processing tests + // ========================================================================= + + #[test] + fn test_chunk_uncompressed_end_to_end() { + let mut parser = StreamingBagParser::new(); + + // Build inner records + let mut inner = Vec::new(); + inner.extend(build_connection_record(0, "/sensor/data")); + inner.extend(build_msg_data_record(0, 1000, 0, b"sensor-reading")); + + // Build the uncompressed chunk record + let chunk = build_chunk_record_none(&inner); + + // Build a complete BAG file + let bag = build_bag_file(&[chunk]); + + // Parse it all in one go + let messages = parser.parse_chunk(&bag).unwrap(); + + assert!(parser.is_initialized()); + assert_eq!(parser.connections.len(), 1); + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].data, b"sensor-reading"); + assert_eq!(messages[0].conn_id, 0); + + let channels = parser.channels(); + assert_eq!(channels.len(), 1); + let ch = channels.values().next().unwrap(); + assert_eq!(ch.topic, "/sensor/data"); + } + + #[test] + fn test_chunk_lz4_end_to_end() { + let mut parser = StreamingBagParser::new(); + + let mut inner = Vec::new(); + inner.extend(build_connection_record(0, "/lidar/points")); + inner.extend(build_msg_data_record(0, 500, 0, b"point-cloud-data")); + inner.extend(build_msg_data_record(0, 600, 0, b"point-cloud-data-2")); + + let chunk = build_chunk_record_lz4(&inner); + let bag = build_bag_file(&[chunk]); + + let messages = parser.parse_chunk(&bag).unwrap(); + + assert_eq!(parser.connections.len(), 1); + assert_eq!(messages.len(), 2); + assert_eq!(messages[0].data, b"point-cloud-data"); + assert_eq!(messages[1].data, b"point-cloud-data-2"); + } + + #[test] + fn test_chunk_bz2_end_to_end() { + let mut parser = StreamingBagParser::new(); + + let mut inner = Vec::new(); + inner.extend(build_connection_record(0, "/imu/data")); + inner.extend(build_msg_data_record(0, 42, 123, b"imu-reading")); + + let chunk = build_chunk_record_bz2(&inner); + let bag = build_bag_file(&[chunk]); + + let messages = parser.parse_chunk(&bag).unwrap(); + + assert_eq!(parser.connections.len(), 1); + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].data, b"imu-reading"); + let expected_time = 42u64 * 1_000_000_000 + 123u64; + assert_eq!(messages[0].log_time, expected_time); + } + + #[test] + fn test_multiple_chunks() { + let mut parser = StreamingBagParser::new(); + + // Chunk 1: connection + message + let mut inner1 = Vec::new(); + inner1.extend(build_connection_record(0, "/cam/image")); + inner1.extend(build_msg_data_record(0, 100, 0, b"frame-1")); + let chunk1 = build_chunk_record_none(&inner1); + + // Chunk 2: another connection + messages + let mut inner2 = Vec::new(); + inner2.extend(build_connection_record(1, "/joint/state")); + inner2.extend(build_msg_data_record(0, 200, 0, b"frame-2")); + inner2.extend(build_msg_data_record(1, 200, 0, b"joint-1")); + let chunk2 = build_chunk_record_lz4(&inner2); + + let bag = build_bag_file(&[chunk1, chunk2]); + let messages = parser.parse_chunk(&bag).unwrap(); + + assert_eq!(parser.connections.len(), 2); + assert_eq!(messages.len(), 3); + assert_eq!(messages[0].data, b"frame-1"); + assert_eq!(messages[1].data, b"frame-2"); + assert_eq!(messages[2].data, b"joint-1"); + } + + #[test] + fn test_chunk_with_streaming_parser_trait() { + use crate::io::streaming::StreamingParser as _; + + let mut parser = StreamingBagParser::new(); + + let mut inner = Vec::new(); + inner.extend(build_connection_record(0, "/topic_a")); + inner.extend(build_msg_data_record(0, 100, 0, b"data-a")); + let chunk = 
build_chunk_record_none(&inner); + let bag = build_bag_file(&[chunk]); + + // Use the StreamingParser trait method + let messages = StreamingParser::parse_chunk(&mut parser, &bag).unwrap(); + + assert_eq!(messages.len(), 1); + assert!(parser.has_channels()); + let channels = StreamingParser::channels(&parser); + assert_eq!(channels.len(), 1); + assert!(channels.values().any(|c| c.topic == "/topic_a")); + } + + #[test] + fn test_incremental_streaming_across_chunks() { + let mut parser = StreamingBagParser::new(); + + let mut inner = Vec::new(); + inner.extend(build_connection_record(0, "/test")); + inner.extend(build_msg_data_record(0, 1, 0, b"msg")); + let chunk = build_chunk_record_none(&inner); + let bag = build_bag_file(&[chunk]); + + // Feed the bag data in small pieces to simulate streaming + let piece_size = 32; + let mut all_messages = Vec::new(); + for piece in bag.chunks(piece_size) { + let msgs = parser.parse_chunk(piece).unwrap(); + all_messages.extend(msgs); + } + + assert!(parser.is_initialized()); + assert_eq!(parser.connections.len(), 1); + assert_eq!(all_messages.len(), 1); + assert_eq!(all_messages[0].data, b"msg"); + } + + #[test] + fn test_top_level_connection_before_chunk() { + let mut parser = StreamingBagParser::new(); + + // In some BAG files, connections appear as top-level records + // (before chunks), then chunks contain only message data. + let conn_record = build_connection_record(0, "/joint_cmd"); + + let mut inner = Vec::new(); + inner.extend(build_msg_data_record(0, 100, 0, b"cmd-data")); + let chunk = build_chunk_record_none(&inner); + + let bag = build_bag_file(&[conn_record, chunk]); + let messages = parser.parse_chunk(&bag).unwrap(); + + assert_eq!(parser.connections.len(), 1); + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].data, b"cmd-data"); + } + + #[test] + fn test_large_payload_in_chunk() { + let mut parser = StreamingBagParser::new(); + + // Simulate a large image payload + let payload = vec![0xABu8; 1024 * 100]; // 100KB + let mut inner = Vec::new(); + inner.extend(build_connection_record(0, "/cam/image_raw")); + inner.extend(build_msg_data_record(0, 100, 0, &payload)); + + let chunk = build_chunk_record_lz4(&inner); + let bag = build_bag_file(&[chunk]); + + let messages = parser.parse_chunk(&bag).unwrap(); + + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].data.len(), 1024 * 100); + assert!(messages[0].data.iter().all(|&b| b == 0xAB)); + } + + #[test] + fn test_message_count_with_chunks() { + let mut parser = StreamingBagParser::new(); + + let mut inner = Vec::new(); + inner.extend(build_connection_record(0, "/data")); + for i in 0..10u32 { + inner.extend(build_msg_data_record(0, i, 0, &[i as u8])); + } + let chunk = build_chunk_record_none(&inner); + let bag = build_bag_file(&[chunk]); + + let messages = parser.parse_chunk(&bag).unwrap(); + + assert_eq!(messages.len(), 10); + assert_eq!(parser.message_count(), 10); + } + + #[test] + fn test_reset_clears_chunk_state() { + let mut parser = StreamingBagParser::new(); + + let mut inner = Vec::new(); + inner.extend(build_connection_record(0, "/data")); + inner.extend(build_msg_data_record(0, 1, 0, b"msg")); + let chunk = build_chunk_record_none(&inner); + let bag = build_bag_file(&[chunk]); + + let messages = parser.parse_chunk(&bag).unwrap(); + assert_eq!(messages.len(), 1); + assert_eq!(parser.connections.len(), 1); + + parser.reset(); + assert_eq!(parser.message_count(), 0); + assert!(!parser.is_initialized()); + assert!(!parser.has_connections()); + } } diff --git 
a/src/io/formats/bag/writer.rs b/src/io/formats/bag/writer.rs index 1a62256..36a438c 100644 --- a/src/io/formats/bag/writer.rs +++ b/src/io/formats/bag/writer.rs @@ -359,8 +359,12 @@ impl BagWriter { self.connections_written_to_chunk.insert(conn_id); } - // Calculate message offset within the chunk data (for index lookups) - let offset = self.chunk_buffer.len() - self.current_chunk_position; + // Calculate message offset within the chunk data (for index lookups). + // The offset must be relative to the start of chunk DATA content, + // not the chunk record. Subtract the chunk header length because + // chunk_buffer includes the placeholder header at current_chunk_position. + let chunk_header_len = Self::chunk_header_length(); + let offset = self.chunk_buffer.len() - self.current_chunk_position - chunk_header_len; // Write message data record header let _header_len = @@ -624,6 +628,11 @@ impl BagWriter { } /// Write connection record to buffer. + /// + /// ROS1 bag connection record format: + /// - Header: `op=0x07`, `conn=`, `topic=` + /// - Data: `topic=`, `type=`, `md5sum=`, + /// `message_definition=`, optionally `callerid=`, `latching=<0|1>` fn write_connection_record_to_buffer(buffer: &mut Vec, conn: &ConnectionInfo) { // Connection header let mut fields = BTreeMap::new(); @@ -634,14 +643,19 @@ impl BagWriter { Self::write_header(buffer, &fields); // Connection data (nested header with type info) + // The `topic` field is required in the data section per the ROS bag spec. let mut data_fields = BTreeMap::new(); + data_fields.insert("topic".to_string(), conn.topic.as_bytes().to_vec()); data_fields.insert("type".to_string(), conn.datatype.as_bytes().to_vec()); data_fields.insert("md5sum".to_string(), conn.md5sum.as_bytes().to_vec()); data_fields.insert( "message_definition".to_string(), conn.message_definition.as_bytes().to_vec(), ); - if let Some(ref callerid) = conn.callerid { + // Only include callerid if it's non-empty (a bare "/" is not a valid callerid) + if let Some(ref callerid) = conn.callerid + && !callerid.is_empty() + { // Ensure callerid has leading slash like the original ROS bag format let callerid_with_slash = if callerid.starts_with('/') { callerid.clone() diff --git a/src/io/formats/mcap/reader.rs b/src/io/formats/mcap/reader.rs index 7de3d5d..785d6b2 100644 --- a/src/io/formats/mcap/reader.rs +++ b/src/io/formats/mcap/reader.rs @@ -266,6 +266,26 @@ impl FormatReader for McapReader { Ok(Box::new(stream)) } + fn iter_raw_boxed(&self) -> crate::core::Result> { + let raw_iter = self.iter_raw()?; + let stream = raw_iter.stream()?; + // Convert MCAP-specific RawMessage to unified RawMessage + Ok(Box::new(stream.map(|result| { + result.map(|(msg, ch)| { + ( + crate::io::metadata::RawMessage { + channel_id: msg.channel_id, + log_time: msg.log_time, + publish_time: msg.publish_time, + data: msg.data, + sequence: msg.sequence, + }, + ch, + ) + }) + }))) + } + fn as_any(&self) -> &dyn std::any::Any { self } diff --git a/src/io/reader/mod.rs b/src/io/reader/mod.rs index 430c5d7..01253c8 100644 --- a/src/io/reader/mod.rs +++ b/src/io/reader/mod.rs @@ -336,6 +336,36 @@ impl RoboReader { Ok(DecodedMessageIter::new(boxed_iter)) } + /// Iterate over raw (undecoded) messages. + /// + /// Returns a boxed iterator that yields raw messages with their channel + /// information. Messages are not decoded - they contain raw bytes as + /// stored in the file. 
+ /// + /// This is useful for operations that need to copy or filter messages + /// without the overhead of decoding (e.g., extracting subsets of data). + /// + /// # Example + /// + /// ```rust,no_run + /// # use robocodec::io::RoboReader; + /// # fn test() -> Result<(), Box> { + /// let reader = RoboReader::open("data.bag")?; + /// for result in reader.iter_raw()? { + /// let (raw_msg, channel) = result?; + /// println!("Topic: {}, data size: {}", channel.topic, raw_msg.data.len()); + /// } + /// # Ok(()) + /// # } + /// ``` + /// + /// # Errors + /// + /// Returns an error if the format does not support raw iteration. + pub fn iter_raw(&self) -> Result> { + self.inner.iter_raw_boxed() + } + /// Get the file information as a unified struct. #[must_use] pub fn file_info(&self) -> crate::io::metadata::FileInfo { diff --git a/src/io/s3/signer.rs b/src/io/s3/signer.rs index 871f547..252ad5f 100644 --- a/src/io/s3/signer.rs +++ b/src/io/s3/signer.rs @@ -8,23 +8,6 @@ use crate::io::s3::config::AwsCredentials; use http::{HeaderMap, HeaderValue, Method, Uri}; use std::time::{SystemTime, UNIX_EPOCH}; -/// Sign an HTTP request with AWS `SigV4`. -/// -/// This function adds the necessary AWS Signature Version 4 headers to authenticate -/// requests to AWS S3 or compatible services. -/// -/// # Arguments -/// -/// * `credentials` - AWS credentials (access key ID, secret access key, optional token) -/// * `region` - AWS region (e.g., "us-east-1") -/// * `service` - AWS service name (typically "s3") -/// * `method` - HTTP method -/// * `uri` - Request URI -/// * `headers` - Existing request headers (will be modified in-place) -/// -/// # Returns -/// -/// Ok(()) if signing succeeded, Err otherwise. pub fn sign_request( credentials: &AwsCredentials, region: &str, @@ -32,6 +15,30 @@ pub fn sign_request( method: &Method, uri: &Uri, headers: &mut HeaderMap, +) -> Result<(), Box> { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|_| "System time is before UNIX epoch")?; + sign_request_at( + credentials, + region, + service, + method, + uri, + headers, + now.as_secs(), + ) +} + +/// Core signing logic with an explicit timestamp for testability. 
+fn sign_request_at( + credentials: &AwsCredentials, + region: &str, + service: &str, + method: &Method, + uri: &Uri, + headers: &mut HeaderMap, + timestamp_secs: u64, ) -> Result<(), Box> { let access_key = credentials.access_key_id(); let secret_key = credentials.secret_access_key(); @@ -41,20 +48,28 @@ pub fn sign_request( return Err("Credentials are empty".into()); } - // Get current timestamp - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .map_err(|_| "System time is before UNIX epoch")?; - let amz_date = format_amz_date(now.as_secs()); + let amz_date = format_amz_date(timestamp_secs); let date_stamp = &amz_date[..8]; - // Extract host from URI - let host = uri.host().ok_or("Invalid URI: missing host")?.to_string(); + // Extract host from URI - include port for non-standard ports + // This matches what reqwest sends: host:port for non-standard ports + let host = if let Some(port) = uri.port_u16() { + // Include port if explicitly specified (e.g., 127.0.0.1:9000) + format!( + "{}:{}", + uri.host().ok_or("Invalid URI: missing host")?, + port + ) + } else { + // For implicit ports (443 for https, 80 for http), use host only + uri.host().ok_or("Invalid URI: missing host")?.to_string() + }; + + // Canonical URI is just the path (URL-encoded), without query string + let canonical_uri = uri.path(); - // Build the path and query string - let path = uri.path(); - let query = uri.query().map(|q| format!("?{q}")).unwrap_or_default(); - let canonical_uri = &format!("{path}{query}"); + // Canonical query string: sorted key=value pairs without leading '?' + let canonical_query_string = build_canonical_query_string(uri.query().unwrap_or("")); // Set required headers headers.insert("Host", HeaderValue::from_str(&host)?); @@ -71,9 +86,6 @@ pub fn sign_request( headers.insert("x-amz-security-token", HeaderValue::from_str(token)?); } - // Create canonical query string (empty for our use case) - let canonical_query_string = ""; - // Create canonical headers let canonical_headers = format_canonical_headers(headers); @@ -87,14 +99,24 @@ pub fn sign_request( signed_headers }; - // Create canonical request + // Build the payload hash (we always use UNSIGNED-PAYLOAD for streaming) + let payload_hash = "UNSIGNED-PAYLOAD"; + + // Create canonical request per AWS SigV4 spec (6 components): + // HTTPMethod \n CanonicalURI \n CanonicalQueryString \n + // CanonicalHeaders \n SignedHeaders \n HashedPayload + // + // Note: canonical_headers already ends with '\n' (one per header line). + // The format's '\n' between canonical_headers and signed_headers creates + // the required blank line separator per the AWS spec. let canonical_request = format!( - "{}\n{}\n{}\n{}\n{}", + "{}\n{}\n{}\n{}\n{}\n{}", method.as_str(), canonical_uri, canonical_query_string, canonical_headers, - signed_headers + signed_headers, + payload_hash, ); // Create string to sign @@ -152,6 +174,36 @@ fn format_canonical_headers(headers: &HeaderMap) -> String { canonical_headers } +/// Build the canonical query string per AWS SigV4 spec. +/// +/// Parameters are sorted by name, URI-encoded, and joined with '&'. +/// Empty query returns an empty string. 
+fn build_canonical_query_string(query: &str) -> String { + if query.is_empty() { + return String::new(); + } + + let mut params: Vec<(&str, &str)> = query + .split('&') + .filter(|s| !s.is_empty()) + .map(|pair| { + let mut parts = pair.splitn(2, '='); + let key = parts.next().unwrap_or(""); + let value = parts.next().unwrap_or(""); + (key, value) + }) + .collect(); + + // Sort by parameter name, then by value + params.sort_by(|a, b| a.0.cmp(b.0).then(a.1.cmp(b.1))); + + params + .iter() + .map(|(k, v)| format!("{k}={v}")) + .collect::>() + .join("&") +} + /// Calculate SHA-256 hash and return as hex string. pub(crate) fn hex_sha256(data: &[u8]) -> String { use sha2::{Digest, Sha256}; @@ -222,6 +274,30 @@ fn calculate_signature( hex::encode(result.into_bytes()) } +/// Build the canonical request string (exposed for testing). +#[cfg(test)] +fn build_canonical_request( + method: &Method, + uri: &Uri, + headers: &HeaderMap, + signed_headers: &str, + payload_hash: &str, +) -> String { + let canonical_uri = uri.path(); + let canonical_query_string = build_canonical_query_string(uri.query().unwrap_or("")); + let canonical_headers = format_canonical_headers(headers); + + format!( + "{}\n{}\n{}\n{}\n{}\n{}", + method.as_str(), + canonical_uri, + canonical_query_string, + canonical_headers, + signed_headers, + payload_hash, + ) +} + /// Check if we have valid credentials that should be used for signing. #[must_use] pub fn should_sign(credentials: &AwsCredentials) -> bool { @@ -234,9 +310,20 @@ mod tests { use crate::io::s3::config::AwsCredentials; use std::str::FromStr; + // ── Helper constants for deterministic tests ────────────────────── + const TEST_ACCESS_KEY: &str = "AKIAIOSFODNN7EXAMPLE"; + const TEST_SECRET_KEY: &str = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"; + // 2025-01-01T00:00:00Z + const TEST_TIMESTAMP: u64 = 1735689600; + + fn test_creds() -> AwsCredentials { + AwsCredentials::new(TEST_ACCESS_KEY, TEST_SECRET_KEY).unwrap() + } + + // ── Credential validation tests ────────────────────────────────── + #[test] fn test_should_sign_none_credentials() { - // AwsCredentials::new returns None for empty keys let result = AwsCredentials::new("", ""); assert!(result.is_none()); } @@ -247,15 +334,40 @@ mod tests { assert!(should_sign(&creds)); } + #[test] + fn test_should_sign_empty_access_key() { + let creds = AwsCredentials::new("", "secret"); + assert!(creds.is_none()); + } + + #[test] + fn test_should_sign_empty_secret_key() { + let creds = AwsCredentials::new("key", ""); + assert!(creds.is_none()); + } + + #[test] + fn test_sign_request_empty_credentials() { + let creds = AwsCredentials::new("", ""); + assert!(creds.is_none()); + } + + // ── Utility function tests ─────────────────────────────────────── + #[test] fn test_format_amz_date() { - let date = format_amz_date(1735689600); // 2025-01-01 00:00:00 UTC - assert!(date.starts_with("20250101")); - assert!(date.ends_with("Z")); + let date = format_amz_date(TEST_TIMESTAMP); + assert_eq!(date, "20250101T000000Z"); } #[test] - fn test_hex_sha256() { + fn test_format_amz_date_epoch() { + let date = format_amz_date(0); + assert_eq!(date, "19700101T000000Z"); + } + + #[test] + fn test_hex_sha256_empty() { let hash = hex_sha256(b""); assert_eq!( hash, @@ -272,66 +384,440 @@ mod tests { ); } + // ── Canonical query string tests ───────────────────────────────── + #[test] - fn test_should_sign_empty_access_key() { - // Empty access key means new() returns None - let creds = AwsCredentials::new("", "secret"); - assert!(creds.is_none()); + fn 
test_build_canonical_query_string_empty() { + assert_eq!(build_canonical_query_string(""), ""); } #[test] - fn test_should_sign_empty_secret_key() { - // Empty secret key means new() returns None - let creds = AwsCredentials::new("key", ""); - assert!(creds.is_none()); + fn test_build_canonical_query_string_single_param() { + assert_eq!( + build_canonical_query_string("versionId=123"), + "versionId=123" + ); } #[test] - fn test_format_amz_date_epoch() { - let date = format_amz_date(0); // 1970-01-01 00:00:00 UTC - assert!(date.starts_with("19700101")); - assert!(date.ends_with("Z")); + fn test_build_canonical_query_string_sorted() { + // Parameters must be sorted by key name + let result = build_canonical_query_string("z=1&a=2&m=3"); + assert_eq!(result, "a=2&m=3&z=1"); } #[test] - fn test_sign_request_valid_credentials() { - let creds = AwsCredentials::new( - "AKIAIOSFODNN7EXAMPLE", - "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + fn test_build_canonical_query_string_same_key() { + // Same key, sorted by value + let result = build_canonical_query_string("key=b&key=a"); + assert_eq!(result, "key=a&key=b"); + } + + // ── Canonical headers tests ────────────────────────────────────── + + #[test] + fn test_format_canonical_headers_order_and_format() { + let mut headers = HeaderMap::new(); + headers.insert("Host", HeaderValue::from_static("example.com")); + headers.insert("x-amz-date", HeaderValue::from_static("20250101T000000Z")); + headers.insert( + "x-amz-content-sha256", + HeaderValue::from_static("UNSIGNED-PAYLOAD"), + ); + + let canonical = format_canonical_headers(&headers); + + // Must be in alphabetical order: host, x-amz-content-sha256, x-amz-date + let expected = "host:example.com\n\ + x-amz-content-sha256:UNSIGNED-PAYLOAD\n\ + x-amz-date:20250101T000000Z\n"; + assert_eq!(canonical, expected); + } + + #[test] + fn test_format_canonical_headers_with_session_token() { + let mut headers = HeaderMap::new(); + headers.insert("Host", HeaderValue::from_static("example.com")); + headers.insert("x-amz-date", HeaderValue::from_static("20250101T000000Z")); + headers.insert( + "x-amz-content-sha256", + HeaderValue::from_static("UNSIGNED-PAYLOAD"), + ); + headers.insert("x-amz-security-token", HeaderValue::from_static("my-token")); + + let canonical = format_canonical_headers(&headers); + assert!(canonical.contains("x-amz-security-token:my-token\n")); + } + + // ── Canonical request structure tests ───────────────────────────── + + #[test] + fn test_canonical_request_has_six_components() { + let uri = Uri::from_str("https://bucket.s3.amazonaws.com/key").unwrap(); + let mut headers = HeaderMap::new(); + headers.insert("Host", HeaderValue::from_static("bucket.s3.amazonaws.com")); + headers.insert("x-amz-date", HeaderValue::from_static("20250101T000000Z")); + headers.insert( + "x-amz-content-sha256", + HeaderValue::from_static("UNSIGNED-PAYLOAD"), + ); + + let signed_headers = "host;x-amz-content-sha256;x-amz-date"; + let cr = build_canonical_request( + &Method::GET, + &uri, + &headers, + signed_headers, + "UNSIGNED-PAYLOAD", + ); + + // SigV4 canonical request must have exactly 6 lines + // (canonical_headers contributes 3 lines ending with \n, plus 3 other lines) + let lines: Vec<&str> = cr.split('\n').collect(); + // Structure: + // 0: GET + // 1: /key + // 2: (empty - canonical query string) + // 3: host:bucket.s3.amazonaws.com + // 4: x-amz-content-sha256:UNSIGNED-PAYLOAD + // 5: x-amz-date:20250101T000000Z + // 6: (empty - trailing \n from last header) + // 7: 
host;x-amz-content-sha256;x-amz-date + // 8: UNSIGNED-PAYLOAD + assert_eq!(lines.len(), 9, "canonical request line count: {lines:?}"); + assert_eq!(lines[0], "GET"); + assert_eq!(lines[1], "/key"); + assert_eq!(lines[2], ""); // empty canonical query string + assert_eq!(lines[3], "host:bucket.s3.amazonaws.com"); + assert_eq!(lines[4], "x-amz-content-sha256:UNSIGNED-PAYLOAD"); + assert_eq!(lines[5], "x-amz-date:20250101T000000Z"); + assert_eq!(lines[6], ""); // blank line after canonical headers + assert_eq!(lines[7], "host;x-amz-content-sha256;x-amz-date"); + assert_eq!(lines[8], "UNSIGNED-PAYLOAD"); // payload hash! + } + + #[test] + fn test_canonical_request_payload_hash_present() { + let uri = Uri::from_str("https://bucket.s3.amazonaws.com/key").unwrap(); + let mut headers = HeaderMap::new(); + headers.insert("Host", HeaderValue::from_static("bucket.s3.amazonaws.com")); + headers.insert("x-amz-date", HeaderValue::from_static("20250101T000000Z")); + headers.insert( + "x-amz-content-sha256", + HeaderValue::from_static("UNSIGNED-PAYLOAD"), + ); + + let signed_headers = "host;x-amz-content-sha256;x-amz-date"; + let cr = build_canonical_request( + &Method::GET, + &uri, + &headers, + signed_headers, + "UNSIGNED-PAYLOAD", + ); + + // The canonical request MUST end with the payload hash + assert!( + cr.ends_with("UNSIGNED-PAYLOAD"), + "canonical request must end with payload hash, got: ...{}", + &cr[cr.len().saturating_sub(50)..] + ); + } + + #[test] + fn test_canonical_request_query_string_not_in_uri() { + // URI with query string: query must appear in canonical query string field, + // NOT appended to the canonical URI + let uri = Uri::from_str("https://bucket.s3.amazonaws.com/key?versionId=123").unwrap(); + let mut headers = HeaderMap::new(); + headers.insert("Host", HeaderValue::from_static("bucket.s3.amazonaws.com")); + headers.insert("x-amz-date", HeaderValue::from_static("20250101T000000Z")); + headers.insert( + "x-amz-content-sha256", + HeaderValue::from_static("UNSIGNED-PAYLOAD"), + ); + + let signed_headers = "host;x-amz-content-sha256;x-amz-date"; + let cr = build_canonical_request( + &Method::GET, + &uri, + &headers, + signed_headers, + "UNSIGNED-PAYLOAD", + ); + + let lines: Vec<&str> = cr.split('\n').collect(); + // Line 1 = canonical URI (path only, no query) + assert_eq!( + lines[1], "/key", + "canonical URI must not contain query string" + ); + // Line 2 = canonical query string + assert_eq!( + lines[2], "versionId=123", + "query string must be in canonical query string field" + ); + } + + // ── Host header with port tests ────────────────────────────────── + + #[test] + fn test_sign_request_host_includes_non_standard_port() { + let creds = test_creds(); + let uri = Uri::from_str("http://127.0.0.1:9000/bucket/key").unwrap(); + let mut headers = HeaderMap::new(); + + let result = sign_request_at( + &creds, + "us-east-1", + "s3", + &Method::GET, + &uri, + &mut headers, + TEST_TIMESTAMP, + ); + assert!(result.is_ok()); + + let host = headers.get("Host").unwrap().to_str().unwrap(); + assert_eq!( + host, "127.0.0.1:9000", + "Host must include non-standard port" + ); + } + + #[test] + fn test_sign_request_host_excludes_standard_port() { + let creds = test_creds(); + let uri = Uri::from_str("https://examplebucket.s3.amazonaws.com/test.txt").unwrap(); + let mut headers = HeaderMap::new(); + + let result = sign_request_at( + &creds, + "us-east-1", + "s3", + &Method::GET, + &uri, + &mut headers, + TEST_TIMESTAMP, + ); + assert!(result.is_ok()); + + let host = 
headers.get("Host").unwrap().to_str().unwrap(); + assert_eq!(host, "examplebucket.s3.amazonaws.com"); + } + + // ── Deterministic signature tests ──────────────────────────────── + + #[test] + fn test_sign_request_deterministic_no_query() { + let creds = test_creds(); + let uri = Uri::from_str("https://examplebucket.s3.amazonaws.com/test.txt").unwrap(); + + // Sign the same request twice at the same timestamp + let mut headers1 = HeaderMap::new(); + sign_request_at( + &creds, + "us-east-1", + "s3", + &Method::GET, + &uri, + &mut headers1, + TEST_TIMESTAMP, + ) + .unwrap(); + + let mut headers2 = HeaderMap::new(); + sign_request_at( + &creds, + "us-east-1", + "s3", + &Method::GET, + &uri, + &mut headers2, + TEST_TIMESTAMP, + ) + .unwrap(); + + let auth1 = headers1.get("Authorization").unwrap().to_str().unwrap(); + let auth2 = headers2.get("Authorization").unwrap().to_str().unwrap(); + assert_eq!( + auth1, auth2, + "same inputs at same timestamp must produce same signature" + ); + } + + #[test] + fn test_sign_request_deterministic_with_query() { + let creds = test_creds(); + let uri = + Uri::from_str("https://examplebucket.s3.amazonaws.com/test.txt?versionId=abc").unwrap(); + + let mut headers1 = HeaderMap::new(); + sign_request_at( + &creds, + "us-east-1", + "s3", + &Method::GET, + &uri, + &mut headers1, + TEST_TIMESTAMP, + ) + .unwrap(); + + let mut headers2 = HeaderMap::new(); + sign_request_at( + &creds, + "us-east-1", + "s3", + &Method::GET, + &uri, + &mut headers2, + TEST_TIMESTAMP, + ) + .unwrap(); + + let auth1 = headers1.get("Authorization").unwrap().to_str().unwrap(); + let auth2 = headers2.get("Authorization").unwrap().to_str().unwrap(); + assert_eq!(auth1, auth2); + } + + #[test] + fn test_sign_request_different_timestamps_differ() { + let creds = test_creds(); + let uri = Uri::from_str("https://examplebucket.s3.amazonaws.com/test.txt").unwrap(); + + let mut headers1 = HeaderMap::new(); + sign_request_at( + &creds, + "us-east-1", + "s3", + &Method::GET, + &uri, + &mut headers1, + TEST_TIMESTAMP, ) .unwrap(); + let mut headers2 = HeaderMap::new(); + sign_request_at( + &creds, + "us-east-1", + "s3", + &Method::GET, + &uri, + &mut headers2, + TEST_TIMESTAMP + 1, + ) + .unwrap(); + + let auth1 = headers1.get("Authorization").unwrap().to_str().unwrap(); + let auth2 = headers2.get("Authorization").unwrap().to_str().unwrap(); + assert_ne!( + auth1, auth2, + "different timestamps must produce different signatures" + ); + } + + #[test] + fn test_sign_request_query_string_affects_signature() { + let creds = test_creds(); + + let uri_no_q = Uri::from_str("https://examplebucket.s3.amazonaws.com/test.txt").unwrap(); + let uri_with_q = + Uri::from_str("https://examplebucket.s3.amazonaws.com/test.txt?versionId=1").unwrap(); + + let mut h1 = HeaderMap::new(); + sign_request_at( + &creds, + "us-east-1", + "s3", + &Method::GET, + &uri_no_q, + &mut h1, + TEST_TIMESTAMP, + ) + .unwrap(); + + let mut h2 = HeaderMap::new(); + sign_request_at( + &creds, + "us-east-1", + "s3", + &Method::GET, + &uri_with_q, + &mut h2, + TEST_TIMESTAMP, + ) + .unwrap(); + + let auth1 = h1.get("Authorization").unwrap().to_str().unwrap(); + let auth2 = h2.get("Authorization").unwrap().to_str().unwrap(); + assert_ne!(auth1, auth2, "query string must affect the signature"); + } + + // ── AWS SigV4 reference test ───────────────────────────────────── + // Validates the canonical request format against the AWS specification. 
+ // Reference: https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html + + #[test] + fn test_sigv4_canonical_request_format_aws_spec() { + // Use the well-known AWS example credentials and fixed timestamp + let creds = test_creds(); + let uri = + Uri::from_str("https://examplebucket.s3.amazonaws.com/photos/photo1.jpg").unwrap(); + let mut headers = HeaderMap::new(); + + sign_request_at( + &creds, + "us-east-1", + "s3", + &Method::GET, + &uri, + &mut headers, + TEST_TIMESTAMP, + ) + .unwrap(); + + // Verify the Authorization header has the correct structure + let auth = headers.get("Authorization").unwrap().to_str().unwrap(); + assert!(auth.starts_with("AWS4-HMAC-SHA256 Credential=")); + assert!(auth.contains("SignedHeaders=host;x-amz-content-sha256;x-amz-date")); + assert!(auth.contains("Signature=")); + + // Verify credential scope + assert!(auth.contains("20250101/us-east-1/s3/aws4_request")); + + // Verify x-amz-content-sha256 is set to UNSIGNED-PAYLOAD + let content_sha = headers + .get("x-amz-content-sha256") + .unwrap() + .to_str() + .unwrap(); + assert_eq!(content_sha, "UNSIGNED-PAYLOAD"); + } + + // ── sign_request integration tests ─────────────────────────────── + + #[test] + fn test_sign_request_valid_credentials() { + let creds = test_creds(); let uri = Uri::from_str("https://examplebucket.s3.amazonaws.com/test.txt").unwrap(); let mut headers = HeaderMap::new(); let result = sign_request(&creds, "us-east-1", "s3", &Method::GET, &uri, &mut headers); assert!(result.is_ok()); - // Check that required headers were added assert!(headers.contains_key("Authorization")); assert!(headers.contains_key("x-amz-date")); assert!(headers.contains_key("x-amz-content-sha256")); - // Authorization header should contain our access key let auth = headers.get("Authorization").unwrap().to_str().unwrap(); - assert!(auth.contains("AKIAIOSFODNN7EXAMPLE")); - } - - #[test] - fn test_sign_request_empty_credentials() { - // Empty credentials result in None from new() - let creds = AwsCredentials::new("", ""); - assert!(creds.is_none()); + assert!(auth.contains(TEST_ACCESS_KEY)); } #[test] fn test_sign_request_with_session_token() { - let creds = AwsCredentials::new( - "AKIAIOSFODNN7EXAMPLE", - "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", - ) - .unwrap() - .with_session_token("session_token"); + let creds = test_creds().with_session_token("session_token"); let uri = Uri::from_str("https://examplebucket.s3.amazonaws.com/test.txt").unwrap(); let mut headers = HeaderMap::new(); @@ -339,22 +825,15 @@ mod tests { let result = sign_request(&creds, "us-east-1", "s3", &Method::GET, &uri, &mut headers); assert!(result.is_ok()); - // Check that session token header was added assert!(headers.contains_key("x-amz-security-token")); - // Authorization header should include security-token in signed headers let auth = headers.get("Authorization").unwrap().to_str().unwrap(); assert!(auth.contains("x-amz-security-token")); } #[test] fn test_sign_request_with_query_string() { - let creds = AwsCredentials::new( - "AKIAIOSFODNN7EXAMPLE", - "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", - ) - .unwrap(); - + let creds = test_creds(); let uri = Uri::from_str("https://examplebucket.s3.amazonaws.com/test.txt?versionId=123").unwrap(); let mut headers = HeaderMap::new(); @@ -365,12 +844,7 @@ mod tests { #[test] fn test_sign_request_post_method() { - let creds = AwsCredentials::new( - "AKIAIOSFODNN7EXAMPLE", - "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", - ) - .unwrap(); - + let creds = test_creds(); let uri = 
Uri::from_str("https://examplebucket.s3.amazonaws.com/test.txt").unwrap(); let mut headers = HeaderMap::new(); @@ -378,41 +852,47 @@ mod tests { assert!(result.is_ok()); let auth = headers.get("Authorization").unwrap().to_str().unwrap(); - assert!(auth.contains("AKIAIOSFODNN7EXAMPLE")); + assert!(auth.contains(TEST_ACCESS_KEY)); } - #[test] - fn test_format_canonical_headers() { - let mut headers = HeaderMap::new(); - headers.insert("Host", HeaderValue::from_static("example.com")); - headers.insert("x-amz-date", HeaderValue::from_static("20250101T000000Z")); - headers.insert( - "x-amz-content-sha256", - HeaderValue::from_static("UNSIGNED-PAYLOAD"), - ); - - let canonical = format_canonical_headers(&headers); - - // Check headers are in correct order and format - assert!(canonical.contains("host:example.com\n")); - assert!(canonical.contains("x-amz-content-sha256:UNSIGNED-PAYLOAD\n")); - assert!(canonical.contains("x-amz-date:20250101T000000Z\n")); - } + // ── MinIO path-style URL test ──────────────────────────────────── #[test] - fn test_format_canonical_headers_with_session_token() { + fn test_sign_request_minio_path_style() { + let creds = AwsCredentials::new("minioadmin", "minioadmin").unwrap(); + let uri = Uri::from_str("http://127.0.0.1:9000/mybucket/mykey.bag").unwrap(); let mut headers = HeaderMap::new(); - headers.insert("Host", HeaderValue::from_static("example.com")); - headers.insert("x-amz-date", HeaderValue::from_static("20250101T000000Z")); - headers.insert( - "x-amz-content-sha256", - HeaderValue::from_static("UNSIGNED-PAYLOAD"), + + let result = sign_request_at( + &creds, + "us-east-1", + "s3", + &Method::GET, + &uri, + &mut headers, + TEST_TIMESTAMP, ); - headers.insert("x-amz-security-token", HeaderValue::from_static("my-token")); + assert!(result.is_ok()); - let canonical = format_canonical_headers(&headers); + // Host must include port for MinIO + let host = headers.get("Host").unwrap().to_str().unwrap(); + assert_eq!(host, "127.0.0.1:9000"); - // Session token should be included at the end - assert!(canonical.contains("x-amz-security-token:my-token\n")); + // Authorization header must be well-formed + let auth = headers.get("Authorization").unwrap().to_str().unwrap(); + assert!(auth.starts_with( + "AWS4-HMAC-SHA256 Credential=minioadmin/20250101/us-east-1/s3/aws4_request" + )); + assert!(auth.contains("SignedHeaders=host;x-amz-content-sha256;x-amz-date")); + assert!(auth.contains("Signature=")); + + // Verify signature is a 64-char hex string + let sig_start = auth.find("Signature=").unwrap() + "Signature=".len(); + let signature = &auth[sig_start..]; + assert_eq!(signature.len(), 64, "signature must be 64 hex chars"); + assert!( + signature.chars().all(|c| c.is_ascii_hexdigit()), + "signature must be hex: {signature}" + ); } } diff --git a/src/io/traits.rs b/src/io/traits.rs index 4a99c24..46f7cae 100644 --- a/src/io/traits.rs +++ b/src/io/traits.rs @@ -18,6 +18,13 @@ use super::metadata::{ChannelInfo, FileInfo, RawMessage, TimestampedDecodedMessa // Re-export filter types use super::filter::TopicFilter; +/// A boxed iterator over raw (undecoded) messages with channel info. +/// +/// This type alias simplifies the complex return type used by +/// `FormatReader::iter_raw_boxed()` and `RoboReader::iter_raw()`. +pub type RawMessageIter<'a> = + Box> + Send + 'a>; + /// Trait for iterating over decoded messages with timestamps. 
 ///
 /// This trait abstracts over format-specific iterator implementations,
@@ -317,6 +324,28 @@ pub trait FormatReader: Send + Sync {
         ))
     }
 
+    /// Iterate over raw (undecoded) messages as a boxed iterator.
+    ///
+    /// This method provides a trait-based approach for raw message iteration,
+    /// allowing format readers to provide raw messages without exposing
+    /// concrete iterator types.
+    ///
+    /// The default implementation returns an error. Format-specific readers
+    /// should override this method to provide their implementation.
+    ///
+    /// # Returns
+    ///
+    /// A boxed iterator yielding `(RawMessage, ChannelInfo)` tuples.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the format reader does not support raw iteration.
+    fn iter_raw_boxed(&self) -> Result<RawMessageIter<'_>> {
+        Err(CodecError::unsupported(
+            "iter_raw_boxed() not supported for this format reader",
+        ))
+    }
+
     /// Downcast to `Any` for accessing format-specific functionality.
     fn as_any(&self) -> &dyn Any;
diff --git a/tests/fixtures/robocodec_test_15.bag b/tests/fixtures/robocodec_test_15.bag
index 4f55f98..b8aded8 100644
Binary files a/tests/fixtures/robocodec_test_15.bag and b/tests/fixtures/robocodec_test_15.bag differ
diff --git a/tests/fixtures/robocodec_test_17.bag b/tests/fixtures/robocodec_test_17.bag
index 711ac74..92cea17 100644
Binary files a/tests/fixtures/robocodec_test_17.bag and b/tests/fixtures/robocodec_test_17.bag differ
diff --git a/tests/fixtures/robocodec_test_18.bag b/tests/fixtures/robocodec_test_18.bag
index a59a9b0..3c99ee1 100644
Binary files a/tests/fixtures/robocodec_test_18.bag and b/tests/fixtures/robocodec_test_18.bag differ
diff --git a/tests/fixtures/robocodec_test_19.bag b/tests/fixtures/robocodec_test_19.bag
index b76be8a..9a70e1b 100644
Binary files a/tests/fixtures/robocodec_test_19.bag and b/tests/fixtures/robocodec_test_19.bag differ
diff --git a/tests/fixtures/robocodec_test_20.bag b/tests/fixtures/robocodec_test_20.bag
index cef5b2e..f34c974 100644
Binary files a/tests/fixtures/robocodec_test_20.bag and b/tests/fixtures/robocodec_test_20.bag differ
diff --git a/tests/fixtures/robocodec_test_21.bag b/tests/fixtures/robocodec_test_21.bag
index f83a4f2..dadc7b3 100644
Binary files a/tests/fixtures/robocodec_test_21.bag and b/tests/fixtures/robocodec_test_21.bag differ
diff --git a/tests/fixtures/robocodec_test_22.bag b/tests/fixtures/robocodec_test_22.bag
index 904a771..ed26e83 100644
Binary files a/tests/fixtures/robocodec_test_22.bag and b/tests/fixtures/robocodec_test_22.bag differ
diff --git a/tests/fixtures/robocodec_test_23.bag b/tests/fixtures/robocodec_test_23.bag
index a46312a..eadb799 100644
Binary files a/tests/fixtures/robocodec_test_23.bag and b/tests/fixtures/robocodec_test_23.bag differ
diff --git a/tests/property/value_properties.proptest-regressions b/tests/property/value_properties.proptest-regressions
index 5f3ca3b..a1fb724 100644
--- a/tests/property/value_properties.proptest-regressions
+++ b/tests/property/value_properties.proptest-regressions
@@ -5,3 +5,4 @@
 # It is recommended to check this file in to source control so that
 # everyone who runs the test benefits from these saved cases.
 cc ef16048e1a8ca3336e5b72faa2e945ac09df08ecdf42c305b334b43d8138312a # shrinks to secs = -2, nanos = 2000000001
+cc 7aa5ce64e5c2553b325f3fc4c3447608e8c05e40afee13e9ab0a166181c3bc8c # shrinks to secs = -1, nanos = 1000000001
diff --git a/tests/property/value_properties.rs b/tests/property/value_properties.rs
index 3c54c49..542b505 100644
--- a/tests/property/value_properties.rs
+++ b/tests/property/value_properties.rs
@@ -178,9 +178,9 @@ proptest! {
         prop_assert!(total_nanos >= 0);
     }
 
-    /// Property: Duration can be negative
+    /// Property: Duration can be negative when secs < 0 and nanos in [0, 1e9)
     #[test]
-    fn prop_duration_can_be_negative(secs in -1000i32..0, nanos in any::<i32>()) {
+    fn prop_duration_can_be_negative(secs in -1000i32..0i32, nanos in 0i32..1_000_000_000i32) {
         let dur = CodecValue::duration_from_secs_nanos(secs, nanos);
         let total_nanos = dur.as_duration_nanos().unwrap();
diff --git a/tests/test_bag_stream.rs b/tests/test_bag_stream.rs
index 1e39bc2..d573cdd 100644
--- a/tests/test_bag_stream.rs
+++ b/tests/test_bag_stream.rs
@@ -4,11 +4,14 @@
 //! Integration tests for BAG streaming parser.
 
+#[cfg(feature = "remote")]
+use robocodec::io::formats::bag::StreamingBagParser;
 #[cfg(feature = "remote")]
 use robocodec::io::s3::{
     BAG_MAGIC_PREFIX, BagMessageRecord, BagRecordFields, BagRecordHeader, FatalError,
-    StreamingBagParser,
 };
+#[cfg(feature = "remote")]
+use std::path::Path;
 
 #[cfg(feature = "remote")]
 #[test]
@@ -162,3 +165,260 @@ fn test_bag_stream_record_fields_default() {
     assert!(fields.time.is_none());
     assert!(fields.topic.is_none());
 }
+
+// =========================================================================
+// Real fixture file tests - feed actual .bag files through StreamingBagParser
+// =========================================================================
+
+#[cfg(feature = "remote")]
+/// Helper: read a bag fixture file and parse it through the streaming parser.
+/// Returns (messages, num_connections, parser).
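+/// The file is fed to the parser in 256 KiB chunks to exercise the incremental parse path.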
+fn parse_fixture_bag(filename: &str) -> (Vec<BagMessageRecord>, usize, StreamingBagParser) {
+    let path = format!("tests/fixtures/{filename}");
+    assert!(Path::new(&path).exists(), "Fixture file not found: {path}");
+
+    let data = std::fs::read(&path).unwrap();
+    let mut parser = StreamingBagParser::new();
+
+    // Feed the entire file in 256KB chunks to simulate streaming
+    let chunk_size = 256 * 1024;
+    let mut all_messages = Vec::new();
+
+    for piece in data.chunks(chunk_size) {
+        let msgs = parser
+            .parse_chunk(piece)
+            .unwrap_or_else(|e| panic!("Failed to parse {filename}: {e}"));
+        all_messages.extend(msgs);
+    }
+
+    let num_connections = parser.channels().len();
+    (all_messages, num_connections, parser)
+}
+
+#[cfg(feature = "remote")]
+#[test]
+fn test_fixture_bag_15_streaming() {
+    let (messages, num_channels, parser) = parse_fixture_bag("robocodec_test_15.bag");
+
+    assert!(parser.is_initialized());
+    assert_eq!(parser.version(), Some("2.0"));
+    assert!(num_channels > 0, "Expected at least 1 channel, got 0");
+    assert!(
+        !messages.is_empty(),
+        "Expected messages from robocodec_test_15.bag, got 0"
+    );
+    assert_eq!(parser.message_count(), messages.len() as u64);
+
+    // Verify all messages have valid conn_id that maps to a known connection
+    let channels = parser.channels();
+    for msg in &messages {
+        assert!(
+            channels.contains_key(&(msg.conn_id as u16)),
+            "Message references unknown conn_id {}",
+            msg.conn_id
+        );
+    }
+
+    println!(
+        "robocodec_test_15.bag: {} messages, {} channels",
+        messages.len(),
+        num_channels
+    );
+    for (id, ch) in &channels {
+        println!(
+            "  channel {id}: topic={}, type={}",
+            ch.topic, ch.message_type
+        );
+    }
+}
+
+#[cfg(feature = "remote")]
+#[test]
+fn test_fixture_bag_18_streaming() {
+    // Smaller file (887K), good for quick validation
+    let (messages, num_channels, parser) = parse_fixture_bag("robocodec_test_18.bag");
+
+    assert!(parser.is_initialized());
+    assert!(num_channels > 0, "Expected at least 1 channel");
+    assert!(
+        !messages.is_empty(),
+        "Expected messages from robocodec_test_18.bag, got 0"
+    );
+
+    println!(
+        "robocodec_test_18.bag: {} messages, {} channels",
+        messages.len(),
+        num_channels
+    );
+}
+
+#[cfg(feature = "remote")]
+#[test]
+fn test_fixture_bag_19_streaming() {
+    let (messages, num_channels, parser) = parse_fixture_bag("robocodec_test_19.bag");
+
+    assert!(parser.is_initialized());
+    assert!(num_channels > 0);
+    assert!(!messages.is_empty());
+
+    println!(
+        "robocodec_test_19.bag: {} messages, {} channels",
+        messages.len(),
+        num_channels
+    );
+}
+
+#[cfg(feature = "remote")]
+#[test]
+fn test_fixture_bag_23_streaming() {
+    let (messages, num_channels, parser) = parse_fixture_bag("robocodec_test_23.bag");
+
+    assert!(parser.is_initialized());
+    assert!(num_channels > 0);
+    assert!(!messages.is_empty());
+
+    println!(
+        "robocodec_test_23.bag: {} messages, {} channels",
+        messages.len(),
+        num_channels
+    );
+}
+
+#[cfg(feature = "remote")]
+#[test]
+fn test_fixture_bag_streaming_vs_nonstreaming_consistency() {
+    // Compare: streaming parser should discover the same connections and
+    // message count as the non-streaming BagParser.
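+    // The comparison below is deliberately aggregate (same number of discovered
+    // connections, same total message count) rather than per-channel, since the
+    // two parsers assign channel indices independently.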
+    use robocodec::io::formats::bag::parser::BagParser;
+
+    let bag_path = "tests/fixtures/robocodec_test_18.bag";
+    if !Path::new(bag_path).exists() {
+        println!("Skipping: fixture not found");
+        return;
+    }
+
+    // --- Non-streaming parser ---
+    let non_streaming = BagParser::open(bag_path).unwrap();
+    let ns_conn_count = non_streaming.connections().len();
+
+    // Build the conn_id_map the same way parallel reader does:
+    // map each connection ID to a sequential channel index
+    let conn_id_map: std::collections::HashMap<_, _> = non_streaming
+        .connections()
+        .keys()
+        .enumerate()
+        .map(|(i, &conn_id)| (conn_id, i as u16))
+        .collect();
+
+    let mut ns_message_count = 0usize;
+    for chunk_info in non_streaming.chunks() {
+        let decompressed = non_streaming.read_chunk(chunk_info).unwrap();
+        let msgs = non_streaming
+            .parse_chunk_messages(&decompressed, &conn_id_map)
+            .unwrap();
+        ns_message_count += msgs.len();
+    }
+
+    // --- Streaming parser ---
+    let (stream_messages, stream_conn_count, _parser) = parse_fixture_bag("robocodec_test_18.bag");
+
+    println!(
+        "Non-streaming: {} connections, {} messages",
+        ns_conn_count, ns_message_count
+    );
+    println!(
+        "Streaming: {} connections, {} messages",
+        stream_conn_count,
+        stream_messages.len()
+    );
+
+    // Connection counts should match
+    assert_eq!(
+        stream_conn_count, ns_conn_count,
+        "Connection count mismatch: streaming={stream_conn_count}, non-streaming={ns_conn_count}"
+    );
+
+    // Message counts should match
+    assert_eq!(
+        stream_messages.len(),
+        ns_message_count,
+        "Message count mismatch: streaming={}, non-streaming={ns_message_count}",
+        stream_messages.len()
+    );
+}
+
+#[cfg(feature = "remote")]
+#[test]
+fn test_fixture_bag_small_chunk_streaming() {
+    // Test streaming with very small read chunks (64 bytes) to stress
+    // the cross-chunk boundary handling
+    let path = "tests/fixtures/robocodec_test_19.bag";
+    if !Path::new(path).exists() {
+        println!("Skipping: fixture not found");
+        return;
+    }
+
+    let data = std::fs::read(path).unwrap();
+    let mut parser = StreamingBagParser::new();
+
+    // Feed in tiny 64-byte chunks
+    let mut all_messages = Vec::new();
+    for piece in data.chunks(64) {
+        let msgs = parser.parse_chunk(piece).unwrap();
+        all_messages.extend(msgs);
+    }
+
+    assert!(parser.is_initialized());
+    assert!(
+        !all_messages.is_empty(),
+        "Expected messages with 64-byte streaming chunks"
+    );
+    assert!(!parser.channels().is_empty());
+
+    // Compare with the larger chunk parse
+    let (large_chunk_msgs, _, _) = parse_fixture_bag("robocodec_test_19.bag");
+    assert_eq!(
+        all_messages.len(),
+        large_chunk_msgs.len(),
+        "64-byte chunks should yield same message count as 256KB chunks"
+    );
+}
+
+#[cfg(feature = "remote")]
+#[test]
+fn test_all_fixture_bags_nonzero_messages() {
+    // Ensure ALL fixture .bag files produce at least some messages
+    let fixtures = [
+        "robocodec_test_15.bag",
+        "robocodec_test_17.bag",
+        "robocodec_test_18.bag",
+        "robocodec_test_19.bag",
+        "robocodec_test_20.bag",
+        "robocodec_test_21.bag",
+        "robocodec_test_22.bag",
+        "robocodec_test_23.bag",
+    ];
+
+    for fixture in &fixtures {
+        let path = format!("tests/fixtures/{fixture}");
+        if !Path::new(&path).exists() {
+            println!("Skipping {fixture}: not found");
+            continue;
+        }
+
+        let (messages, channels, parser) = parse_fixture_bag(fixture);
+        assert!(parser.is_initialized(), "{fixture}: parser not initialized");
+        assert!(channels > 0, "{fixture}: no channels discovered");
+        assert!(
+            !messages.is_empty(),
+            "{fixture}: no messages extracted (likely chunk handling bug)"
+        );
+
+        println!(
+            "{fixture}: {} messages, {} channels, version={:?}",
+            messages.len(),
+            channels,
+            parser.version()
+        );
+    }
+}
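As a closing note on the src/io/traits.rs hunk, the new `iter_raw_boxed()` default and the `RawMessageIter` alias are easiest to read from the caller's side. A minimal sketch, written as if it sat in the same module as the trait so it can reuse the crate's `Result`; `count_raw_messages` and its argument are illustrative, not part of this change.

fn count_raw_messages(reader: &dyn FormatReader) -> Result<u64> {
    // Readers that do not override iter_raw_boxed() hit the default
    // implementation and return CodecError::unsupported here.
    let iter: RawMessageIter<'_> = reader.iter_raw_boxed()?;
    let mut total = 0u64;
    for item in iter {
        // Each item is a Result<(RawMessage, ChannelInfo)>.
        let (_raw_msg, _channel_info) = item?;
        total += 1;
    }
    Ok(total)
}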