Skip to content

Commit

Permalink
emit an $addFields stage before $sort with safe aliases if necessary (#…
Browse files Browse the repository at this point in the history
…109)

* emit an $addFields stage before $sort with safe aliases if necessary

* update changelog
  • Loading branch information
hallettj authored Oct 1, 2024
1 parent 8ab0ab2 commit 6f264f3
Show file tree
Hide file tree
Showing 10 changed files with 217 additions and 76 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ This changelog documents the changes between release versions.
### Fixed

- Selecting nested fields with names that begin with a dollar sign ([#108](https://github.com/hasura/ndc-mongodb/pull/108))
- Sorting by fields with names that begin with a dollar sign ([#109](https://github.com/hasura/ndc-mongodb/pull/109))

### Changed

Expand Down
3 changes: 2 additions & 1 deletion crates/mongodb-agent-common/src/mongodb/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ mod database;
mod pipeline;
pub mod sanitize;
mod selection;
mod sort_document;
mod stage;

#[cfg(test)]
pub mod test_helpers;

pub use self::{
accumulator::Accumulator, collection::CollectionTrait, database::DatabaseTrait,
pipeline::Pipeline, selection::Selection, stage::Stage,
pipeline::Pipeline, selection::Selection, sort_document::SortDocument, stage::Stage,
};

// MockCollectionTrait is generated by automock when the test flag is active.
Expand Down
4 changes: 2 additions & 2 deletions crates/mongodb-agent-common/src/mongodb/sanitize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub fn is_name_safe(name: &str) -> bool {
/// Given a collection or field name, returns Ok if the name is safe, or Err if it contains
/// characters that MongoDB will interpret specially.
///
/// TODO: MDB-159, MBD-160 remove this function in favor of ColumnRef which is infallible
/// TODO: ENG-973 remove this function in favor of ColumnRef which is infallible
pub fn safe_name(name: &str) -> Result<Cow<str>, MongoAgentError> {
if name.starts_with('$') || name.contains('.') {
Err(MongoAgentError::BadQuery(anyhow!("cannot execute query that includes the name, \"{name}\", because it includes characters that MongoDB interperets specially")))
Expand All @@ -56,7 +56,7 @@ const ESCAPE_CHAR_ESCAPE_SEQUENCE: u32 = 0xff;

/// MongoDB variable names allow a limited set of ASCII characters, or any non-ASCII character.
/// See https://www.mongodb.com/docs/manual/reference/aggregation-variables/
fn escape_invalid_variable_chars(input: &str) -> String {
pub fn escape_invalid_variable_chars(input: &str) -> String {
let mut encoded = String::new();
for char in input.chars() {
match char {
Expand Down
14 changes: 14 additions & 0 deletions crates/mongodb-agent-common/src/mongodb/sort_document.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
use mongodb::bson;
use serde::{Deserialize, Serialize};

/// Wraps a BSON document that represents a set of sort criteria. A SortDocument value is intended
/// to be used as the argument to a $sort pipeline stage.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(transparent)]
pub struct SortDocument(pub bson::Document);

impl SortDocument {
pub fn from_doc(doc: bson::Document) -> Self {
SortDocument(doc)
}
}
11 changes: 9 additions & 2 deletions crates/mongodb-agent-common/src/mongodb/stage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@ use std::collections::BTreeMap;
use mongodb::bson;
use serde::Serialize;

use super::{accumulator::Accumulator, pipeline::Pipeline, Selection};
use super::{accumulator::Accumulator, pipeline::Pipeline, Selection, SortDocument};

/// Aggergation Pipeline Stage. This is a work-in-progress - we are adding enum variants to match
/// MongoDB pipeline stage types as we need them in this app. For documentation on all stage types
/// see,
/// https://www.mongodb.com/docs/manual/reference/operator/aggregation-pipeline/#std-label-aggregation-pipeline-operator-reference
#[derive(Clone, Debug, PartialEq, Serialize)]
pub enum Stage {
/// Adds new fields to documents. $addFields outputs documents that contain all existing fields
/// from the input documents and newly added fields.
///
/// See https://www.mongodb.com/docs/manual/reference/operator/aggregation/addFields/
#[serde(rename = "$addFields")]
AddFields(bson::Document),

/// Returns literal documents from input expressions.
///
/// See https://www.mongodb.com/docs/manual/reference/operator/aggregation/documents/#mongodb-pipeline-pipe.-documents
Expand All @@ -35,7 +42,7 @@ pub enum Stage {
///
/// See https://www.mongodb.com/docs/manual/reference/operator/aggregation/sort/#mongodb-pipeline-pipe.-sort
#[serde(rename = "$sort")]
Sort(bson::Document),
Sort(SortDocument),

/// Passes the first n documents unmodified to the pipeline where n is the specified limit. For
/// each input document, outputs either one document (for the first n documents) or zero
Expand Down
32 changes: 18 additions & 14 deletions crates/mongodb-agent-common/src/query/column_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl<'a> ColumnRef<'a> {
from_comparison_target(column)
}

/// TODO: This will hopefully become infallible once MDB-150 & MDB-151 are implemented.
/// TODO: This will hopefully become infallible once ENG-1011 & ENG-1010 are implemented.
pub fn from_order_by_target(target: &OrderByTarget) -> Result<ColumnRef<'_>, MongoAgentError> {
from_order_by_target(target)
}
Expand Down Expand Up @@ -138,30 +138,33 @@ fn from_comparison_target(column: &ComparisonTarget) -> ColumnRef<'_> {

fn from_order_by_target(target: &OrderByTarget) -> Result<ColumnRef<'_>, MongoAgentError> {
match target {
// We exclude `path` (the relationship path) from the resulting ColumnRef because MongoDB
// field references are not relationship-aware. Traversing relationship references is
// handled upstream.
OrderByTarget::Column {
name, field_path, ..
name,
field_path,
path,
} => {
let name_and_path = once(name.as_ref() as &str).chain(
field_path
.iter()
.flatten()
.map(|field_name| field_name.as_ref() as &str),
);
let name_and_path = path
.iter()
.map(|n| n.as_str())
.chain([name.as_str()])
.chain(
field_path
.iter()
.flatten()
.map(|field_name| field_name.as_str()),
);
// The None case won't come up if the input to [from_target_helper] has at least
// one element, and we know it does because we start the iterable with `name`
Ok(from_path(None, name_and_path).unwrap())
}
OrderByTarget::SingleColumnAggregate { .. } => {
// TODO: MDB-150
// TODO: ENG-1011
Err(MongoAgentError::NotImplemented(
"ordering by single column aggregate".into(),
))
}
OrderByTarget::StarCountAggregate { .. } => {
// TODO: MDB-151
// TODO: ENG-1010
Err(MongoAgentError::NotImplemented(
"ordering by star count aggregate".into(),
))
Expand Down Expand Up @@ -352,7 +355,8 @@ mod tests {
scope: Scope::Root,
};
let actual = ColumnRef::from_comparison_target(&target);
let expected = ColumnRef::ExpressionStringShorthand("$$scope_root.field.prop1.prop2".into());
let expected =
ColumnRef::ExpressionStringShorthand("$$scope_root.field.prop1.prop2".into());
assert_eq!(actual, expected);
Ok(())
}
Expand Down
207 changes: 159 additions & 48 deletions crates/mongodb-agent-common/src/query/make_sort.rs
Original file line number Diff line number Diff line change
@@ -1,65 +1,176 @@
use itertools::Itertools as _;
use mongodb::bson::{bson, Document};
use std::{collections::BTreeMap, iter::once};

use itertools::join;
use mongodb::bson::bson;
use ndc_models::OrderDirection;

use crate::{
interface_types::MongoAgentError,
mongo_query_plan::{OrderBy, OrderByTarget},
mongodb::sanitize::safe_name,
mongodb::{sanitize::escape_invalid_variable_chars, SortDocument, Stage},
};

pub fn make_sort(order_by: &OrderBy) -> Result<Document, MongoAgentError> {
use super::column_ref::ColumnRef;

/// In a [SortDocument] there is no way to reference field names that need to be escaped, such as
/// names that begin with dollar signs. To sort on such fields we need to insert an $addFields
/// stage _before_ the $sort stage to map safe aliases.
type RequiredAliases<'a> = BTreeMap<String, ColumnRef<'a>>;

type Result<T> = std::result::Result<T, MongoAgentError>;

pub fn make_sort_stages(order_by: &OrderBy) -> Result<Vec<Stage>> {
let (sort_document, required_aliases) = make_sort(order_by)?;
let mut stages = vec![];

if !required_aliases.is_empty() {
let fields = required_aliases
.into_iter()
.map(|(alias, expression)| (alias, expression.into_aggregate_expression()))
.collect();
let stage = Stage::AddFields(fields);
stages.push(stage);
}

let sort_stage = Stage::Sort(sort_document);
stages.push(sort_stage);

Ok(stages)
}

fn make_sort(order_by: &OrderBy) -> Result<(SortDocument, RequiredAliases<'_>)> {
let OrderBy { elements } = order_by;

elements
.clone()
let keys_directions_expressions: BTreeMap<String, (OrderDirection, Option<ColumnRef<'_>>)> =
elements
.iter()
.map(|obe| {
let col_ref = ColumnRef::from_order_by_target(&obe.target)?;
let (key, required_alias) = match col_ref {
ColumnRef::MatchKey(key) => (key.to_string(), None),
ref_expr => (safe_alias(&obe.target)?, Some(ref_expr)),
};
Ok((key, (obe.order_direction, required_alias)))
})
.collect::<Result<BTreeMap<_, _>>>()?;

let sort_document = keys_directions_expressions
.iter()
.map(|obe| {
let direction = match obe.clone().order_direction {
.map(|(key, (direction, _))| {
let direction_bson = match direction {
OrderDirection::Asc => bson!(1),
OrderDirection::Desc => bson!(-1),
};
match &obe.target {
OrderByTarget::Column {
name,
field_path,
path,
} => Ok((
column_ref_with_path(name, field_path.as_deref(), path)?,
direction,
)),
OrderByTarget::SingleColumnAggregate {
column: _,
function: _,
path: _,
result_type: _,
} =>
// TODO: MDB-150
{
Err(MongoAgentError::NotImplemented(
"ordering by single column aggregate".into(),
))
}
OrderByTarget::StarCountAggregate { path: _ } => Err(
// TODO: MDB-151
MongoAgentError::NotImplemented("ordering by star count aggregate".into()),
),
}
(key.clone(), direction_bson)
})
.collect()
.collect();

let required_aliases = keys_directions_expressions
.into_iter()
.flat_map(|(key, (_, expr))| expr.map(|e| (key, e)))
.collect();

Ok((SortDocument(sort_document), required_aliases))
}

// TODO: MDB-159 Replace use of [safe_name] with [ColumnRef].
fn column_ref_with_path(
name: &ndc_models::FieldName,
field_path: Option<&[ndc_models::FieldName]>,
relation_path: &[ndc_models::RelationshipName],
) -> Result<String, MongoAgentError> {
relation_path
.iter()
.map(|n| n.as_str())
.chain(std::iter::once(name.as_str()))
.chain(field_path.into_iter().flatten().map(|n| n.as_str()))
.map(safe_name)
.process_results(|mut iter| iter.join("."))
fn safe_alias(target: &OrderByTarget) -> Result<String> {
match target {
ndc_query_plan::OrderByTarget::Column {
name,
field_path,
path,
} => {
let name_and_path = once("__sort_key_")
.chain(path.iter().map(|n| n.as_str()))
.chain([name.as_str()])
.chain(
field_path
.iter()
.flatten()
.map(|field_name| field_name.as_str()),
);
let combine_all_elements_into_one_name = join(name_and_path, "_");
Ok(escape_invalid_variable_chars(
&combine_all_elements_into_one_name,
))
}
ndc_query_plan::OrderByTarget::SingleColumnAggregate { .. } => {
// TODO: ENG-1011
Err(MongoAgentError::NotImplemented(
"ordering by single column aggregate".into(),
))
}
ndc_query_plan::OrderByTarget::StarCountAggregate { .. } => {
// TODO: ENG-1010
Err(MongoAgentError::NotImplemented(
"ordering by star count aggregate".into(),
))
}
}
}

#[cfg(test)]
mod tests {
use mongodb::bson::doc;
use ndc_models::{FieldName, OrderDirection};
use ndc_query_plan::OrderByElement;
use pretty_assertions::assert_eq;

use crate::{mongo_query_plan::OrderBy, mongodb::SortDocument, query::column_ref::ColumnRef};

use super::make_sort;

#[test]
fn escapes_field_names() -> anyhow::Result<()> {
let order_by = OrderBy {
elements: vec![OrderByElement {
order_direction: OrderDirection::Asc,
target: ndc_query_plan::OrderByTarget::Column {
name: "$schema".into(),
field_path: Default::default(),
path: Default::default(),
},
}],
};
let path: [FieldName; 1] = ["$schema".into()];

let actual = make_sort(&order_by)?;
let expected_sort_doc = SortDocument(doc! {
"__sort_key__·24schema": 1
});
let expected_aliases = [(
"__sort_key__·24schema".into(),
ColumnRef::from_field_path(path.iter()),
)]
.into();
assert_eq!(actual, (expected_sort_doc, expected_aliases));
Ok(())
}

#[test]
fn escapes_nested_field_names() -> anyhow::Result<()> {
let order_by = OrderBy {
elements: vec![OrderByElement {
order_direction: OrderDirection::Asc,
target: ndc_query_plan::OrderByTarget::Column {
name: "configuration".into(),
field_path: Some(vec!["$schema".into()]),
path: Default::default(),
},
}],
};
let path: [FieldName; 2] = ["configuration".into(), "$schema".into()];

let actual = make_sort(&order_by)?;
let expected_sort_doc = SortDocument(doc! {
"__sort_key__configuration_·24schema": 1
});
let expected_aliases = [(
"__sort_key__configuration_·24schema".into(),
ColumnRef::from_field_path(path.iter()),
)]
.into();
assert_eq!(actual, (expected_sort_doc, expected_aliases));
Ok(())
}
}
2 changes: 1 addition & 1 deletion crates/mongodb-agent-common/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use ndc_models::{QueryRequest, QueryResponse};
use self::execute_query_request::execute_query_request;
pub use self::{
make_selector::make_selector,
make_sort::make_sort,
make_sort::make_sort_stages,
pipeline::{is_response_faceted, pipeline_for_non_foreach, pipeline_for_query_request},
query_target::QueryTarget,
response::QueryResponseError,
Expand Down
Loading

0 comments on commit 6f264f3

Please sign in to comment.