diff --git a/CHANGELOG.md b/CHANGELOG.md index e3e97707..711df1f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ This changelog documents the changes between release versions. ### Fixed - Selecting nested fields with names that begin with a dollar sign ([#108](https://github.com/hasura/ndc-mongodb/pull/108)) +- Sorting by fields with names that begin with a dollar sign ([#109](https://github.com/hasura/ndc-mongodb/pull/109)) ### Changed diff --git a/crates/mongodb-agent-common/src/mongodb/mod.rs b/crates/mongodb-agent-common/src/mongodb/mod.rs index 8931d5db..d1a7c8c4 100644 --- a/crates/mongodb-agent-common/src/mongodb/mod.rs +++ b/crates/mongodb-agent-common/src/mongodb/mod.rs @@ -4,6 +4,7 @@ mod database; mod pipeline; pub mod sanitize; mod selection; +mod sort_document; mod stage; #[cfg(test)] @@ -11,7 +12,7 @@ pub mod test_helpers; pub use self::{ accumulator::Accumulator, collection::CollectionTrait, database::DatabaseTrait, - pipeline::Pipeline, selection::Selection, stage::Stage, + pipeline::Pipeline, selection::Selection, sort_document::SortDocument, stage::Stage, }; // MockCollectionTrait is generated by automock when the test flag is active. diff --git a/crates/mongodb-agent-common/src/mongodb/sanitize.rs b/crates/mongodb-agent-common/src/mongodb/sanitize.rs index b5f3f84b..b7027205 100644 --- a/crates/mongodb-agent-common/src/mongodb/sanitize.rs +++ b/crates/mongodb-agent-common/src/mongodb/sanitize.rs @@ -35,7 +35,7 @@ pub fn is_name_safe(name: &str) -> bool { /// Given a collection or field name, returns Ok if the name is safe, or Err if it contains /// characters that MongoDB will interpret specially. /// -/// TODO: MDB-159, MBD-160 remove this function in favor of ColumnRef which is infallible +/// TODO: ENG-973 remove this function in favor of ColumnRef which is infallible pub fn safe_name(name: &str) -> Result, MongoAgentError> { if name.starts_with('$') || name.contains('.') { Err(MongoAgentError::BadQuery(anyhow!("cannot execute query that includes the name, \"{name}\", because it includes characters that MongoDB interperets specially"))) @@ -56,7 +56,7 @@ const ESCAPE_CHAR_ESCAPE_SEQUENCE: u32 = 0xff; /// MongoDB variable names allow a limited set of ASCII characters, or any non-ASCII character. /// See https://www.mongodb.com/docs/manual/reference/aggregation-variables/ -fn escape_invalid_variable_chars(input: &str) -> String { +pub fn escape_invalid_variable_chars(input: &str) -> String { let mut encoded = String::new(); for char in input.chars() { match char { diff --git a/crates/mongodb-agent-common/src/mongodb/sort_document.rs b/crates/mongodb-agent-common/src/mongodb/sort_document.rs new file mode 100644 index 00000000..37756cb2 --- /dev/null +++ b/crates/mongodb-agent-common/src/mongodb/sort_document.rs @@ -0,0 +1,14 @@ +use mongodb::bson; +use serde::{Deserialize, Serialize}; + +/// Wraps a BSON document that represents a set of sort criteria. A SortDocument value is intended +/// to be used as the argument to a $sort pipeline stage. +#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] +#[serde(transparent)] +pub struct SortDocument(pub bson::Document); + +impl SortDocument { + pub fn from_doc(doc: bson::Document) -> Self { + SortDocument(doc) + } +} diff --git a/crates/mongodb-agent-common/src/mongodb/stage.rs b/crates/mongodb-agent-common/src/mongodb/stage.rs index 9845f922..87dc51bb 100644 --- a/crates/mongodb-agent-common/src/mongodb/stage.rs +++ b/crates/mongodb-agent-common/src/mongodb/stage.rs @@ -3,7 +3,7 @@ use std::collections::BTreeMap; use mongodb::bson; use serde::Serialize; -use super::{accumulator::Accumulator, pipeline::Pipeline, Selection}; +use super::{accumulator::Accumulator, pipeline::Pipeline, Selection, SortDocument}; /// Aggergation Pipeline Stage. This is a work-in-progress - we are adding enum variants to match /// MongoDB pipeline stage types as we need them in this app. For documentation on all stage types @@ -11,6 +11,13 @@ use super::{accumulator::Accumulator, pipeline::Pipeline, Selection}; /// https://www.mongodb.com/docs/manual/reference/operator/aggregation-pipeline/#std-label-aggregation-pipeline-operator-reference #[derive(Clone, Debug, PartialEq, Serialize)] pub enum Stage { + /// Adds new fields to documents. $addFields outputs documents that contain all existing fields + /// from the input documents and newly added fields. + /// + /// See https://www.mongodb.com/docs/manual/reference/operator/aggregation/addFields/ + #[serde(rename = "$addFields")] + AddFields(bson::Document), + /// Returns literal documents from input expressions. /// /// See https://www.mongodb.com/docs/manual/reference/operator/aggregation/documents/#mongodb-pipeline-pipe.-documents @@ -35,7 +42,7 @@ pub enum Stage { /// /// See https://www.mongodb.com/docs/manual/reference/operator/aggregation/sort/#mongodb-pipeline-pipe.-sort #[serde(rename = "$sort")] - Sort(bson::Document), + Sort(SortDocument), /// Passes the first n documents unmodified to the pipeline where n is the specified limit. For /// each input document, outputs either one document (for the first n documents) or zero diff --git a/crates/mongodb-agent-common/src/query/column_ref.rs b/crates/mongodb-agent-common/src/query/column_ref.rs index d474f1d8..eefacf2d 100644 --- a/crates/mongodb-agent-common/src/query/column_ref.rs +++ b/crates/mongodb-agent-common/src/query/column_ref.rs @@ -53,7 +53,7 @@ impl<'a> ColumnRef<'a> { from_comparison_target(column) } - /// TODO: This will hopefully become infallible once MDB-150 & MDB-151 are implemented. + /// TODO: This will hopefully become infallible once ENG-1011 & ENG-1010 are implemented. pub fn from_order_by_target(target: &OrderByTarget) -> Result, MongoAgentError> { from_order_by_target(target) } @@ -138,30 +138,33 @@ fn from_comparison_target(column: &ComparisonTarget) -> ColumnRef<'_> { fn from_order_by_target(target: &OrderByTarget) -> Result, MongoAgentError> { match target { - // We exclude `path` (the relationship path) from the resulting ColumnRef because MongoDB - // field references are not relationship-aware. Traversing relationship references is - // handled upstream. OrderByTarget::Column { - name, field_path, .. + name, + field_path, + path, } => { - let name_and_path = once(name.as_ref() as &str).chain( - field_path - .iter() - .flatten() - .map(|field_name| field_name.as_ref() as &str), - ); + let name_and_path = path + .iter() + .map(|n| n.as_str()) + .chain([name.as_str()]) + .chain( + field_path + .iter() + .flatten() + .map(|field_name| field_name.as_str()), + ); // The None case won't come up if the input to [from_target_helper] has at least // one element, and we know it does because we start the iterable with `name` Ok(from_path(None, name_and_path).unwrap()) } OrderByTarget::SingleColumnAggregate { .. } => { - // TODO: MDB-150 + // TODO: ENG-1011 Err(MongoAgentError::NotImplemented( "ordering by single column aggregate".into(), )) } OrderByTarget::StarCountAggregate { .. } => { - // TODO: MDB-151 + // TODO: ENG-1010 Err(MongoAgentError::NotImplemented( "ordering by star count aggregate".into(), )) @@ -352,7 +355,8 @@ mod tests { scope: Scope::Root, }; let actual = ColumnRef::from_comparison_target(&target); - let expected = ColumnRef::ExpressionStringShorthand("$$scope_root.field.prop1.prop2".into()); + let expected = + ColumnRef::ExpressionStringShorthand("$$scope_root.field.prop1.prop2".into()); assert_eq!(actual, expected); Ok(()) } diff --git a/crates/mongodb-agent-common/src/query/make_sort.rs b/crates/mongodb-agent-common/src/query/make_sort.rs index ead5ceb4..e2de1d35 100644 --- a/crates/mongodb-agent-common/src/query/make_sort.rs +++ b/crates/mongodb-agent-common/src/query/make_sort.rs @@ -1,65 +1,176 @@ -use itertools::Itertools as _; -use mongodb::bson::{bson, Document}; +use std::{collections::BTreeMap, iter::once}; + +use itertools::join; +use mongodb::bson::bson; use ndc_models::OrderDirection; use crate::{ interface_types::MongoAgentError, mongo_query_plan::{OrderBy, OrderByTarget}, - mongodb::sanitize::safe_name, + mongodb::{sanitize::escape_invalid_variable_chars, SortDocument, Stage}, }; -pub fn make_sort(order_by: &OrderBy) -> Result { +use super::column_ref::ColumnRef; + +/// In a [SortDocument] there is no way to reference field names that need to be escaped, such as +/// names that begin with dollar signs. To sort on such fields we need to insert an $addFields +/// stage _before_ the $sort stage to map safe aliases. +type RequiredAliases<'a> = BTreeMap>; + +type Result = std::result::Result; + +pub fn make_sort_stages(order_by: &OrderBy) -> Result> { + let (sort_document, required_aliases) = make_sort(order_by)?; + let mut stages = vec![]; + + if !required_aliases.is_empty() { + let fields = required_aliases + .into_iter() + .map(|(alias, expression)| (alias, expression.into_aggregate_expression())) + .collect(); + let stage = Stage::AddFields(fields); + stages.push(stage); + } + + let sort_stage = Stage::Sort(sort_document); + stages.push(sort_stage); + + Ok(stages) +} + +fn make_sort(order_by: &OrderBy) -> Result<(SortDocument, RequiredAliases<'_>)> { let OrderBy { elements } = order_by; - elements - .clone() + let keys_directions_expressions: BTreeMap>)> = + elements + .iter() + .map(|obe| { + let col_ref = ColumnRef::from_order_by_target(&obe.target)?; + let (key, required_alias) = match col_ref { + ColumnRef::MatchKey(key) => (key.to_string(), None), + ref_expr => (safe_alias(&obe.target)?, Some(ref_expr)), + }; + Ok((key, (obe.order_direction, required_alias))) + }) + .collect::>>()?; + + let sort_document = keys_directions_expressions .iter() - .map(|obe| { - let direction = match obe.clone().order_direction { + .map(|(key, (direction, _))| { + let direction_bson = match direction { OrderDirection::Asc => bson!(1), OrderDirection::Desc => bson!(-1), }; - match &obe.target { - OrderByTarget::Column { - name, - field_path, - path, - } => Ok(( - column_ref_with_path(name, field_path.as_deref(), path)?, - direction, - )), - OrderByTarget::SingleColumnAggregate { - column: _, - function: _, - path: _, - result_type: _, - } => - // TODO: MDB-150 - { - Err(MongoAgentError::NotImplemented( - "ordering by single column aggregate".into(), - )) - } - OrderByTarget::StarCountAggregate { path: _ } => Err( - // TODO: MDB-151 - MongoAgentError::NotImplemented("ordering by star count aggregate".into()), - ), - } + (key.clone(), direction_bson) }) - .collect() + .collect(); + + let required_aliases = keys_directions_expressions + .into_iter() + .flat_map(|(key, (_, expr))| expr.map(|e| (key, e))) + .collect(); + + Ok((SortDocument(sort_document), required_aliases)) } -// TODO: MDB-159 Replace use of [safe_name] with [ColumnRef]. -fn column_ref_with_path( - name: &ndc_models::FieldName, - field_path: Option<&[ndc_models::FieldName]>, - relation_path: &[ndc_models::RelationshipName], -) -> Result { - relation_path - .iter() - .map(|n| n.as_str()) - .chain(std::iter::once(name.as_str())) - .chain(field_path.into_iter().flatten().map(|n| n.as_str())) - .map(safe_name) - .process_results(|mut iter| iter.join(".")) +fn safe_alias(target: &OrderByTarget) -> Result { + match target { + ndc_query_plan::OrderByTarget::Column { + name, + field_path, + path, + } => { + let name_and_path = once("__sort_key_") + .chain(path.iter().map(|n| n.as_str())) + .chain([name.as_str()]) + .chain( + field_path + .iter() + .flatten() + .map(|field_name| field_name.as_str()), + ); + let combine_all_elements_into_one_name = join(name_and_path, "_"); + Ok(escape_invalid_variable_chars( + &combine_all_elements_into_one_name, + )) + } + ndc_query_plan::OrderByTarget::SingleColumnAggregate { .. } => { + // TODO: ENG-1011 + Err(MongoAgentError::NotImplemented( + "ordering by single column aggregate".into(), + )) + } + ndc_query_plan::OrderByTarget::StarCountAggregate { .. } => { + // TODO: ENG-1010 + Err(MongoAgentError::NotImplemented( + "ordering by star count aggregate".into(), + )) + } + } +} + +#[cfg(test)] +mod tests { + use mongodb::bson::doc; + use ndc_models::{FieldName, OrderDirection}; + use ndc_query_plan::OrderByElement; + use pretty_assertions::assert_eq; + + use crate::{mongo_query_plan::OrderBy, mongodb::SortDocument, query::column_ref::ColumnRef}; + + use super::make_sort; + + #[test] + fn escapes_field_names() -> anyhow::Result<()> { + let order_by = OrderBy { + elements: vec![OrderByElement { + order_direction: OrderDirection::Asc, + target: ndc_query_plan::OrderByTarget::Column { + name: "$schema".into(), + field_path: Default::default(), + path: Default::default(), + }, + }], + }; + let path: [FieldName; 1] = ["$schema".into()]; + + let actual = make_sort(&order_by)?; + let expected_sort_doc = SortDocument(doc! { + "__sort_key__·24schema": 1 + }); + let expected_aliases = [( + "__sort_key__·24schema".into(), + ColumnRef::from_field_path(path.iter()), + )] + .into(); + assert_eq!(actual, (expected_sort_doc, expected_aliases)); + Ok(()) + } + + #[test] + fn escapes_nested_field_names() -> anyhow::Result<()> { + let order_by = OrderBy { + elements: vec![OrderByElement { + order_direction: OrderDirection::Asc, + target: ndc_query_plan::OrderByTarget::Column { + name: "configuration".into(), + field_path: Some(vec!["$schema".into()]), + path: Default::default(), + }, + }], + }; + let path: [FieldName; 2] = ["configuration".into(), "$schema".into()]; + + let actual = make_sort(&order_by)?; + let expected_sort_doc = SortDocument(doc! { + "__sort_key__configuration_·24schema": 1 + }); + let expected_aliases = [( + "__sort_key__configuration_·24schema".into(), + ColumnRef::from_field_path(path.iter()), + )] + .into(); + assert_eq!(actual, (expected_sort_doc, expected_aliases)); + Ok(()) + } } diff --git a/crates/mongodb-agent-common/src/query/mod.rs b/crates/mongodb-agent-common/src/query/mod.rs index da61f225..3353b572 100644 --- a/crates/mongodb-agent-common/src/query/mod.rs +++ b/crates/mongodb-agent-common/src/query/mod.rs @@ -18,7 +18,7 @@ use ndc_models::{QueryRequest, QueryResponse}; use self::execute_query_request::execute_query_request; pub use self::{ make_selector::make_selector, - make_sort::make_sort, + make_sort::make_sort_stages, pipeline::{is_response_faceted, pipeline_for_non_foreach, pipeline_for_query_request}, query_target::QueryTarget, response::QueryResponseError, diff --git a/crates/mongodb-agent-common/src/query/pipeline.rs b/crates/mongodb-agent-common/src/query/pipeline.rs index a7fb3868..4d72bf26 100644 --- a/crates/mongodb-agent-common/src/query/pipeline.rs +++ b/crates/mongodb-agent-common/src/query/pipeline.rs @@ -1,5 +1,6 @@ use std::collections::BTreeMap; +use itertools::Itertools; use mongodb::bson::{self, doc, Bson}; use tracing::instrument; @@ -13,7 +14,8 @@ use crate::{ use super::{ constants::{RESULT_FIELD, ROWS_FIELD}, foreach::pipeline_for_foreach, - make_selector, make_sort, + make_selector, + make_sort::make_sort_stages, native_query::pipeline_for_native_query, query_level::QueryLevel, relations::pipeline_for_relations, @@ -70,16 +72,17 @@ pub fn pipeline_for_non_foreach( .map(make_selector) .transpose()? .map(Stage::Match); - let sort_stage: Option = order_by + let sort_stages: Vec = order_by .iter() - .map(|o| Ok(Stage::Sort(make_sort(o)?)) as Result<_, MongoAgentError>) - .next() - .transpose()?; + .map(make_sort_stages) + .flatten_ok() + .collect::, _>>()?; let skip_stage = offset.map(Stage::Skip); - [match_stage, sort_stage, skip_stage] + match_stage .into_iter() - .flatten() + .chain(sort_stages) + .chain(skip_stage) .for_each(|stage| pipeline.push(stage)); // `diverging_stages` includes either a $facet stage if the query includes aggregates, or the diff --git a/crates/mongodb-agent-common/src/query/relations.rs b/crates/mongodb-agent-common/src/query/relations.rs index 39edbdc6..f909627f 100644 --- a/crates/mongodb-agent-common/src/query/relations.rs +++ b/crates/mongodb-agent-common/src/query/relations.rs @@ -85,7 +85,7 @@ fn make_lookup_stage( } } -// TODO: MDB-160 Replace uses of [safe_name] with [ColumnRef]. +// TODO: ENG-973 Replace uses of [safe_name] with [ColumnRef]. fn single_column_mapping_lookup( from: ndc_models::CollectionName, source_selector: &ndc_models::FieldName,