diff --git a/ntpd/src/metrics/exporter.rs b/ntpd/src/metrics/exporter.rs index 2d0c5214b..e09d3670b 100644 --- a/ntpd/src/metrics/exporter.rs +++ b/ntpd/src/metrics/exporter.rs @@ -1,6 +1,6 @@ use libc::{ECONNABORTED, EMFILE, ENFILE, ENOBUFS, ENOMEM}; use timestamped_socket::interface::ChangeDetector; -use tokio::io::AsyncWriteExt; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpListener; use tracing::{debug, error, trace, warn}; @@ -211,9 +211,43 @@ async fn run(options: NtpMetricsExporterOptions) -> Result<(), Box std::io::Result<()> { + // Wait until a request was sent, dropping the bytes read when this scope ends + // to ensure we don't accidentally use them afterwards + { + // Receive all data until the header was fully received, or until max buf size + let mut buf = [0u8; 2048]; + let mut bytes_read = 0; + loop { + bytes_read += stream.read(&mut buf[bytes_read..]).await?; + + // The headers end with two CRLFs in a row + if buf[0..bytes_read].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + + // Headers should easily fit within the buffer + // If we have not found the end yet, we are not going to + if bytes_read >= buf.len() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Request too long", + )); + } + } + + // We only respond to GET requests + if !buf[0..bytes_read].starts_with(b"GET ") { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Expected GET request", + )); + } + } + + // Send the response let mut buf = String::with_capacity(4 * 1024); match handler(&mut buf, observation_socket_path).await { Ok(()) => { @@ -261,6 +295,8 @@ fn format_response(buf: &mut String, state: &ObservableState) -> std::fmt::Resul #[cfg(test)] mod tests { + use std::io::Cursor; + use super::*; const BINARY: &str = "/usr/bin/ntp-metrics-exporter"; @@ -274,4 +310,24 @@ mod tests { let options = NtpMetricsExporterOptions::try_parse_from(arguments).unwrap(); assert_eq!(options.config.unwrap().as_path(), config); } + + #[tokio::test] + async fn deny_non_get_request() { + let mut example = b"POST / HTTP/1.1\r\n\r\n".to_vec(); + let mut cursor = Cursor::new(&mut example); + let res = handle_connection(&mut cursor, Path::new("/tmp/ntpd-rs.sock")).await; + let err = res.unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput); + assert_eq!(err.to_string(), "Expected GET request"); + } + + #[tokio::test] + async fn does_not_accept_large_requests() { + let mut example = [1u8; 4096].to_vec(); + let mut cursor = Cursor::new(&mut example); + let res = handle_connection(&mut cursor, Path::new("/tmp/ntpd-rs.sock")).await; + let err = res.unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput); + assert_eq!(err.to_string(), "Request too long"); + } }