diff --git a/opentelemetry/src/baggage.rs b/opentelemetry/src/baggage.rs index 37ba28e682..35bbd8d2ff 100644 --- a/opentelemetry/src/baggage.rs +++ b/opentelemetry/src/baggage.rs @@ -16,6 +16,7 @@ //! [W3C Baggage]: https://w3c.github.io/baggage use crate::{Context, Key, KeyValue, Value}; use once_cell::sync::Lazy; +use std::collections::hash_map::Entry; use std::collections::{hash_map, HashMap}; use std::fmt; @@ -124,6 +125,8 @@ impl Baggage { /// Same with `insert`, if the name was not present, [`None`] will be returned. /// If the name is present, the old value and metadata will be returned. /// + /// Also checks for [limits](https://w3c.github.io/baggage/#limits). + /// /// # Examples /// /// ``` @@ -146,10 +149,41 @@ impl Baggage { S: Into, { let (key, value, metadata) = (key.into(), value.into(), metadata.into()); - if self.insertable(&key, &value, &metadata) { - self.inner.insert(key, (value, metadata)) - } else { - None + if !key.as_str().is_ascii() { + return None; + } + let entry_content_len = + key_value_metadata_bytes_size(key.as_str(), value.as_str().as_ref(), metadata.as_str()); + if entry_content_len >= MAX_BYTES_FOR_ONE_PAIR { + return None; + } + let entries_count = self.inner.len(); + match self.inner.entry(key) { + Entry::Occupied(mut occupied_entry) => { + let prev_content_len = key_value_metadata_bytes_size( + occupied_entry.key().as_str(), + &occupied_entry.get().0.as_str().as_ref(), + occupied_entry.get().1.as_str(), + ); + let new_content_len = self.kv_content_len + entry_content_len - prev_content_len; + if new_content_len > MAX_LEN_OF_ALL_PAIRS { + return None; + } + self.kv_content_len = new_content_len; + Some(occupied_entry.insert((value, metadata))) + } + Entry::Vacant(vacant_entry) => { + if entries_count == MAX_KEY_VALUE_PAIRS { + return None; + } + let new_content_len = self.kv_content_len + entry_content_len; + if new_content_len > MAX_LEN_OF_ALL_PAIRS { + return None; + } + self.kv_content_len = new_content_len; + vacant_entry.insert((value, metadata)); + None + } } } @@ -169,59 +203,10 @@ impl Baggage { self.inner.is_empty() } - /// Gets an iterator over the baggage items, sorted by name. + /// Gets an iterator over the baggage items, in any order. pub fn iter(&self) -> Iter<'_> { self.into_iter() } - - /// Determine whether the key value pair exceed one of the [limits](https://w3c.github.io/baggage/#limits). - /// If not, update the total length of key values - fn insertable(&mut self, key: &Key, value: &Value, metadata: &BaggageMetadata) -> bool { - if !key.as_str().is_ascii() { - return false; - } - let value = value.as_str(); - if key_value_metadata_bytes_size(key.as_str(), value.as_ref(), metadata.as_str()) - < MAX_BYTES_FOR_ONE_PAIR - { - match self.inner.get(key) { - None => { - // check total length - if self.kv_content_len - + metadata.as_str().len() - + value.len() - + key.as_str().len() - > MAX_LEN_OF_ALL_PAIRS - { - return false; - } - // check number of pairs - if self.inner.len() + 1 > MAX_KEY_VALUE_PAIRS { - return false; - } - self.kv_content_len += - metadata.as_str().len() + value.len() + key.as_str().len() - } - Some((old_value, old_metadata)) => { - let old_value = old_value.as_str(); - if self.kv_content_len - old_metadata.as_str().len() - old_value.len() - + metadata.as_str().len() - + value.len() - > MAX_LEN_OF_ALL_PAIRS - { - return false; - } - self.kv_content_len = - self.kv_content_len - old_metadata.as_str().len() - old_value.len() - + metadata.as_str().len() - + value.len() - } - } - true - } else { - false - } - } } /// Get the number of bytes for one key-value pair @@ -376,13 +361,11 @@ impl BaggageExt for Context { &self, baggage: T, ) -> Self { - let mut merged: Baggage = self - .baggage() - .iter() - .map(|(key, (value, metadata))| { - KeyValueMetadata::new(key.clone(), value.clone(), metadata.clone()) - }) - .collect(); + let old = self.baggage(); + let mut merged = Baggage { + inner: old.inner.clone(), + kv_content_len: old.kv_content_len, + }; for kvm in baggage.into_iter().map(|kv| kv.into()) { merged.insert_with_metadata(kvm.key, kvm.value, kvm.metadata); }