Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

emit an $addFields stage before $sort with safe aliases if necessary #109

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ This changelog documents the changes between release versions.
### Fixed

- Selecting nested fields with names that begin with a dollar sign ([#108](https://github.com/hasura/ndc-mongodb/pull/108))
- Sorting by fields with names that begin with a dollar sign ([#109](https://github.com/hasura/ndc-mongodb/pull/109))

### Changed

Expand Down
3 changes: 2 additions & 1 deletion crates/mongodb-agent-common/src/mongodb/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ mod database;
mod pipeline;
pub mod sanitize;
mod selection;
mod sort_document;
mod stage;

#[cfg(test)]
pub mod test_helpers;

pub use self::{
accumulator::Accumulator, collection::CollectionTrait, database::DatabaseTrait,
pipeline::Pipeline, selection::Selection, stage::Stage,
pipeline::Pipeline, selection::Selection, sort_document::SortDocument, stage::Stage,
};

// MockCollectionTrait is generated by automock when the test flag is active.
Expand Down
4 changes: 2 additions & 2 deletions crates/mongodb-agent-common/src/mongodb/sanitize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub fn is_name_safe(name: &str) -> bool {
/// Given a collection or field name, returns Ok if the name is safe, or Err if it contains
/// characters that MongoDB will interpret specially.
///
/// TODO: MDB-159, MBD-160 remove this function in favor of ColumnRef which is infallible
/// TODO: ENG-973 remove this function in favor of ColumnRef which is infallible
pub fn safe_name(name: &str) -> Result<Cow<str>, MongoAgentError> {
if name.starts_with('$') || name.contains('.') {
Err(MongoAgentError::BadQuery(anyhow!("cannot execute query that includes the name, \"{name}\", because it includes characters that MongoDB interperets specially")))
Expand All @@ -56,7 +56,7 @@ const ESCAPE_CHAR_ESCAPE_SEQUENCE: u32 = 0xff;

/// MongoDB variable names allow a limited set of ASCII characters, or any non-ASCII character.
/// See https://www.mongodb.com/docs/manual/reference/aggregation-variables/
fn escape_invalid_variable_chars(input: &str) -> String {
pub fn escape_invalid_variable_chars(input: &str) -> String {
let mut encoded = String::new();
for char in input.chars() {
match char {
Expand Down
14 changes: 14 additions & 0 deletions crates/mongodb-agent-common/src/mongodb/sort_document.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
use mongodb::bson;
use serde::{Deserialize, Serialize};

/// Wraps a BSON document that represents a set of sort criteria. A SortDocument value is intended
/// to be used as the argument to a $sort pipeline stage.
///
/// Because of `#[serde(transparent)]` the wrapper serializes and deserializes exactly as the
/// inner document does, with no extra level of nesting.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(transparent)]
pub struct SortDocument(pub bson::Document);

impl SortDocument {
pub fn from_doc(doc: bson::Document) -> Self {
SortDocument(doc)
}
}
11 changes: 9 additions & 2 deletions crates/mongodb-agent-common/src/mongodb/stage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@ use std::collections::BTreeMap;
use mongodb::bson;
use serde::Serialize;

use super::{accumulator::Accumulator, pipeline::Pipeline, Selection};
use super::{accumulator::Accumulator, pipeline::Pipeline, Selection, SortDocument};

/// Aggregation Pipeline Stage. This is a work-in-progress - we are adding enum variants to match
/// MongoDB pipeline stage types as we need them in this app. For documentation on all stage types
/// see,
/// https://www.mongodb.com/docs/manual/reference/operator/aggregation-pipeline/#std-label-aggregation-pipeline-operator-reference
#[derive(Clone, Debug, PartialEq, Serialize)]
pub enum Stage {
/// Adds new fields to documents. $addFields outputs documents that contain all existing fields
/// from the input documents and newly added fields.
///
/// See https://www.mongodb.com/docs/manual/reference/operator/aggregation/addFields/
#[serde(rename = "$addFields")]
AddFields(bson::Document),

/// Returns literal documents from input expressions.
///
/// See https://www.mongodb.com/docs/manual/reference/operator/aggregation/documents/#mongodb-pipeline-pipe.-documents
Expand All @@ -35,7 +42,7 @@ pub enum Stage {
///
/// See https://www.mongodb.com/docs/manual/reference/operator/aggregation/sort/#mongodb-pipeline-pipe.-sort
#[serde(rename = "$sort")]
Sort(bson::Document),
Sort(SortDocument),

/// Passes the first n documents unmodified to the pipeline where n is the specified limit. For
/// each input document, outputs either one document (for the first n documents) or zero
Expand Down
32 changes: 18 additions & 14 deletions crates/mongodb-agent-common/src/query/column_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl<'a> ColumnRef<'a> {
from_comparison_target(column)
}

/// TODO: This will hopefully become infallible once MDB-150 & MDB-151 are implemented.
/// TODO: This will hopefully become infallible once ENG-1011 & ENG-1010 are implemented.
pub fn from_order_by_target(target: &OrderByTarget) -> Result<ColumnRef<'_>, MongoAgentError> {
from_order_by_target(target)
}
Expand Down Expand Up @@ -138,30 +138,33 @@ fn from_comparison_target(column: &ComparisonTarget) -> ColumnRef<'_> {

fn from_order_by_target(target: &OrderByTarget) -> Result<ColumnRef<'_>, MongoAgentError> {
match target {
// We exclude `path` (the relationship path) from the resulting ColumnRef because MongoDB
// field references are not relationship-aware. Traversing relationship references is
// handled upstream.
OrderByTarget::Column {
name, field_path, ..
name,
field_path,
path,
} => {
let name_and_path = once(name.as_ref() as &str).chain(
field_path
.iter()
.flatten()
.map(|field_name| field_name.as_ref() as &str),
);
let name_and_path = path
.iter()
.map(|n| n.as_str())
.chain([name.as_str()])
.chain(
field_path
.iter()
.flatten()
.map(|field_name| field_name.as_str()),
);
// The None case won't come up if the input to [from_path] has at least one element, and we
// know it does because the iterable always includes `name`
Ok(from_path(None, name_and_path).unwrap())
}
OrderByTarget::SingleColumnAggregate { .. } => {
// TODO: MDB-150
// TODO: ENG-1011
Err(MongoAgentError::NotImplemented(
"ordering by single column aggregate".into(),
))
}
OrderByTarget::StarCountAggregate { .. } => {
// TODO: MDB-151
// TODO: ENG-1010
Err(MongoAgentError::NotImplemented(
"ordering by star count aggregate".into(),
))
Expand Down Expand Up @@ -352,7 +355,8 @@ mod tests {
scope: Scope::Root,
};
let actual = ColumnRef::from_comparison_target(&target);
let expected = ColumnRef::ExpressionStringShorthand("$$scope_root.field.prop1.prop2".into());
let expected =
ColumnRef::ExpressionStringShorthand("$$scope_root.field.prop1.prop2".into());
assert_eq!(actual, expected);
Ok(())
}
Expand Down
207 changes: 159 additions & 48 deletions crates/mongodb-agent-common/src/query/make_sort.rs
Original file line number Diff line number Diff line change
@@ -1,65 +1,176 @@
use itertools::Itertools as _;
use mongodb::bson::{bson, Document};
use std::{collections::BTreeMap, iter::once};

use itertools::join;
use mongodb::bson::bson;
use ndc_models::OrderDirection;

use crate::{
interface_types::MongoAgentError,
mongo_query_plan::{OrderBy, OrderByTarget},
mongodb::sanitize::safe_name,
mongodb::{sanitize::escape_invalid_variable_chars, SortDocument, Stage},
};

pub fn make_sort(order_by: &OrderBy) -> Result<Document, MongoAgentError> {
use super::column_ref::ColumnRef;

/// In a [SortDocument] there is no way to reference field names that need to be escaped, such as
/// names that begin with dollar signs. To sort on such fields we need to insert an $addFields
/// stage _before_ the $sort stage to map safe aliases.
type RequiredAliases<'a> = BTreeMap<String, ColumnRef<'a>>;

type Result<T> = std::result::Result<T, MongoAgentError>;

pub fn make_sort_stages(order_by: &OrderBy) -> Result<Vec<Stage>> {
let (sort_document, required_aliases) = make_sort(order_by)?;
let mut stages = vec![];

if !required_aliases.is_empty() {
let fields = required_aliases
.into_iter()
.map(|(alias, expression)| (alias, expression.into_aggregate_expression()))
.collect();
let stage = Stage::AddFields(fields);
stages.push(stage);
}

let sort_stage = Stage::Sort(sort_document);
stages.push(sort_stage);

Ok(stages)
}

fn make_sort(order_by: &OrderBy) -> Result<(SortDocument, RequiredAliases<'_>)> {
let OrderBy { elements } = order_by;

elements
.clone()
let keys_directions_expressions: BTreeMap<String, (OrderDirection, Option<ColumnRef<'_>>)> =
elements
.iter()
.map(|obe| {
let col_ref = ColumnRef::from_order_by_target(&obe.target)?;
let (key, required_alias) = match col_ref {
ColumnRef::MatchKey(key) => (key.to_string(), None),
ref_expr => (safe_alias(&obe.target)?, Some(ref_expr)),
};
Ok((key, (obe.order_direction, required_alias)))
})
.collect::<Result<BTreeMap<_, _>>>()?;

let sort_document = keys_directions_expressions
.iter()
.map(|obe| {
let direction = match obe.clone().order_direction {
.map(|(key, (direction, _))| {
let direction_bson = match direction {
OrderDirection::Asc => bson!(1),
OrderDirection::Desc => bson!(-1),
};
match &obe.target {
OrderByTarget::Column {
name,
field_path,
path,
} => Ok((
column_ref_with_path(name, field_path.as_deref(), path)?,
direction,
)),
OrderByTarget::SingleColumnAggregate {
column: _,
function: _,
path: _,
result_type: _,
} =>
// TODO: MDB-150
{
Err(MongoAgentError::NotImplemented(
"ordering by single column aggregate".into(),
))
}
OrderByTarget::StarCountAggregate { path: _ } => Err(
// TODO: MDB-151
MongoAgentError::NotImplemented("ordering by star count aggregate".into()),
),
}
(key.clone(), direction_bson)
})
.collect()
.collect();

let required_aliases = keys_directions_expressions
.into_iter()
.flat_map(|(key, (_, expr))| expr.map(|e| (key, e)))
.collect();

Ok((SortDocument(sort_document), required_aliases))
}

// TODO: MDB-159 Replace use of [safe_name] with [ColumnRef].
fn column_ref_with_path(
name: &ndc_models::FieldName,
field_path: Option<&[ndc_models::FieldName]>,
relation_path: &[ndc_models::RelationshipName],
) -> Result<String, MongoAgentError> {
relation_path
.iter()
.map(|n| n.as_str())
.chain(std::iter::once(name.as_str()))
.chain(field_path.into_iter().flatten().map(|n| n.as_str()))
.map(safe_name)
.process_results(|mut iter| iter.join("."))
/// Derives an alias for an order-by target that can be used as a key in a sort document.
///
/// For column targets the alias is made from the prefix `__sort_key_`, the relationship path,
/// the column name, and any nested field path, all joined with underscores and then passed
/// through [escape_invalid_variable_chars]. Aggregate targets are not supported yet and produce
/// a `NotImplemented` error.
fn safe_alias(target: &OrderByTarget) -> Result<String> {
    let (name, field_path, path) = match target {
        ndc_query_plan::OrderByTarget::Column {
            name,
            field_path,
            path,
        } => (name, field_path, path),
        ndc_query_plan::OrderByTarget::SingleColumnAggregate { .. } => {
            // TODO: ENG-1011
            return Err(MongoAgentError::NotImplemented(
                "ordering by single column aggregate".into(),
            ));
        }
        ndc_query_plan::OrderByTarget::StarCountAggregate { .. } => {
            // TODO: ENG-1010
            return Err(MongoAgentError::NotImplemented(
                "ordering by star count aggregate".into(),
            ));
        }
    };

    let relationship_segments = path.iter().map(|segment| segment.as_str());
    let nested_segments = field_path
        .iter()
        .flatten()
        .map(|segment| segment.as_str());
    let all_segments = once("__sort_key_")
        .chain(relationship_segments)
        .chain(once(name.as_str()))
        .chain(nested_segments);

    let unescaped_alias = join(all_segments, "_");
    Ok(escape_invalid_variable_chars(&unescaped_alias))
}

// Unit tests for [make_sort]. Each test checks both halves of the returned tuple: the sort
// document, and the alias-to-column mapping that a preceding $addFields stage must define.
#[cfg(test)]
mod tests {
use mongodb::bson::doc;
use ndc_models::{FieldName, OrderDirection};
use ndc_query_plan::OrderByElement;
use pretty_assertions::assert_eq;

use crate::{mongo_query_plan::OrderBy, mongodb::SortDocument, query::column_ref::ColumnRef};

use super::make_sort;

// Sorting on a top-level field whose name begins with a dollar sign: the sort document is
// expected to use a generated alias as its key, and the alias map is expected to point that
// alias at the original field.
#[test]
fn escapes_field_names() -> anyhow::Result<()> {
let order_by = OrderBy {
elements: vec![OrderByElement {
order_direction: OrderDirection::Asc,
target: ndc_query_plan::OrderByTarget::Column {
name: "$schema".into(),
field_path: Default::default(),
path: Default::default(),
},
}],
};
// The column reference that the alias is expected to resolve to.
let path: [FieldName; 1] = ["$schema".into()];

let actual = make_sort(&order_by)?;
// In the expected key the `$` of `$schema` appears in escaped form as `·24`.
let expected_sort_doc = SortDocument(doc! {
"__sort_key__·24schema": 1
});
let expected_aliases = [(
"__sort_key__·24schema".into(),
ColumnRef::from_field_path(path.iter()),
)]
.into();
assert_eq!(actual, (expected_sort_doc, expected_aliases));
Ok(())
}

// Same as above, but the dollar-prefixed name is a nested field reached through `field_path`.
// The expected alias includes the parent column name followed by the escaped nested name.
#[test]
fn escapes_nested_field_names() -> anyhow::Result<()> {
let order_by = OrderBy {
elements: vec![OrderByElement {
order_direction: OrderDirection::Asc,
target: ndc_query_plan::OrderByTarget::Column {
name: "configuration".into(),
field_path: Some(vec!["$schema".into()]),
path: Default::default(),
},
}],
};
// Column name followed by the nested field name.
let path: [FieldName; 2] = ["configuration".into(), "$schema".into()];

let actual = make_sort(&order_by)?;
let expected_sort_doc = SortDocument(doc! {
"__sort_key__configuration_·24schema": 1
});
let expected_aliases = [(
"__sort_key__configuration_·24schema".into(),
ColumnRef::from_field_path(path.iter()),
)]
.into();
assert_eq!(actual, (expected_sort_doc, expected_aliases));
Ok(())
}
}
2 changes: 1 addition & 1 deletion crates/mongodb-agent-common/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use ndc_models::{QueryRequest, QueryResponse};
use self::execute_query_request::execute_query_request;
pub use self::{
make_selector::make_selector,
make_sort::make_sort,
make_sort::make_sort_stages,
pipeline::{is_response_faceted, pipeline_for_non_foreach, pipeline_for_query_request},
query_target::QueryTarget,
response::QueryResponseError,
Expand Down
Loading
Loading