From 6a5e208c9588ff3ae2f3ebecb5692dfb097d2885 Mon Sep 17 00:00:00 2001 From: Jesse Hallett Date: Mon, 24 Jun 2024 15:59:55 -0700 Subject: [PATCH] rework queries with variable sets so they use indexes (#83) * create indexes in mongodb fixtures * capture expected types of variables * map request variables to $documents stage, replace $facet with $lookup * test variable name escaping function * tests for query_variable_name * use escaping in `variable` function to make it infallible * replace variable map lookups with mongodb variable references * some test updates, delegate to variable function * fix make_selector * run `db.aggregate` if query request has variable sets * update response serialization for change in foreach response shape * update one of the foreach unit tests * update some stale comments * handle responses with aggregates, update tests * handle aggregate responses without rows * add test for binary comparison bug that I incidentally fixed * skip remote relationship integration tests in mongodb 5 * update changelog * note breaking change in changelog * change aggregate target in explain to match target in query --- CHANGELOG.md | 3 + Cargo.lock | 1 + .../src/tests/remote_relationship.rs | 22 + .../proptest-regressions/mongodb/sanitize.txt | 7 + .../query/query_variable_name.txt | 7 + crates/mongodb-agent-common/src/explain.rs | 7 +- .../src/interface_types/mongo_agent_error.rs | 5 - .../src/mongo_query_plan/mod.rs | 1 + .../src/mongodb/sanitize.rs | 117 ++++- .../mongodb-agent-common/src/mongodb/stage.rs | 6 + .../src/procedure/interpolated_command.rs | 34 +- .../mongodb-agent-common/src/procedure/mod.rs | 5 + .../src/query/arguments.rs | 36 +- .../src/query/execute_query_request.rs | 10 +- .../mongodb-agent-common/src/query/foreach.rs | 449 +++++++++++------- .../src/query/make_selector.rs | 153 +++--- crates/mongodb-agent-common/src/query/mod.rs | 1 + .../src/query/native_query.rs | 30 +- .../src/query/pipeline.rs | 21 +- 
.../src/query/query_variable_name.rs | 94 ++++ .../src/query/relations.rs | 16 +- .../src/query/response.rs | 33 +- crates/ndc-query-plan/src/lib.rs | 3 +- .../src/plan_for_query_request/mod.rs | 42 +- .../plan_test_helpers/mod.rs | 2 +- .../query_plan_state.rs | 61 ++- .../src/plan_for_query_request/tests.rs | 6 + crates/ndc-query-plan/src/query_plan.rs | 20 +- crates/ndc-query-plan/src/vec_set.rs | 80 ++++ crates/test-helpers/Cargo.toml | 1 + crates/test-helpers/src/arb_plan_type.rs | 27 ++ crates/test-helpers/src/lib.rs | 2 + fixtures/mongodb/chinook/chinook-import.sh | 2 + fixtures/mongodb/chinook/indexes.js | 20 + fixtures/mongodb/sample_import.sh | 1 + fixtures/mongodb/sample_mflix/indexes.js | 3 + 36 files changed, 933 insertions(+), 395 deletions(-) create mode 100644 crates/mongodb-agent-common/proptest-regressions/mongodb/sanitize.txt create mode 100644 crates/mongodb-agent-common/proptest-regressions/query/query_variable_name.txt create mode 100644 crates/mongodb-agent-common/src/query/query_variable_name.rs create mode 100644 crates/ndc-query-plan/src/vec_set.rs create mode 100644 crates/test-helpers/src/arb_plan_type.rs create mode 100644 fixtures/mongodb/chinook/indexes.js create mode 100644 fixtures/mongodb/sample_mflix/indexes.js diff --git a/CHANGELOG.md b/CHANGELOG.md index ba16f2df..b1382da4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,9 @@ This changelog documents the changes between release versions. 
## [Unreleased] - Fix bug with operator lookup when filtering on nested fields ([#82](https://github.com/hasura/ndc-mongodb/pull/82)) +- Rework query plans for requests with variable sets to allow use of indexes ([#83](https://github.com/hasura/ndc-mongodb/pull/83)) +- Fix: error when requesting query plan if MongoDB is target of a remote join ([#83](https://github.com/hasura/ndc-mongodb/pull/83)) +- Breaking change: remote joins no longer work in MongoDB v5 ([#83](https://github.com/hasura/ndc-mongodb/pull/83)) ## [0.1.0] - 2024-06-13 diff --git a/Cargo.lock b/Cargo.lock index 6759f32a..573a2132 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3184,6 +3184,7 @@ dependencies = [ "mongodb", "mongodb-support", "ndc-models", + "ndc-query-plan", "ndc-test-helpers", "proptest", ] diff --git a/crates/integration-tests/src/tests/remote_relationship.rs b/crates/integration-tests/src/tests/remote_relationship.rs index c5558d2e..c4a99608 100644 --- a/crates/integration-tests/src/tests/remote_relationship.rs +++ b/crates/integration-tests/src/tests/remote_relationship.rs @@ -5,6 +5,17 @@ use serde_json::json; #[tokio::test] async fn provides_source_and_target_for_remote_relationship() -> anyhow::Result<()> { + // Skip this test in MongoDB 5 because the example fails there. We're getting an error: + // + // > Kind: Command failed: Error code 5491300 (Location5491300): $documents' is not allowed in user requests, labels: {} + // + // This means that remote joins are not working in MongoDB 5 + if let Ok(image) = std::env::var("MONGODB_IMAGE") { + if image == "mongo:5" { + return Ok(()); + } + } + assert_yaml_snapshot!( graphql_query( r#" @@ -29,6 +40,17 @@ async fn provides_source_and_target_for_remote_relationship() -> anyhow::Result< #[tokio::test] async fn handles_request_with_single_variable_set() -> anyhow::Result<()> { + // Skip this test in MongoDB 5 because the example fails there. 
We're getting an error: + // + // > Kind: Command failed: Error code 5491300 (Location5491300): $documents' is not allowed in user requests, labels: {} + // + // This means that remote joins are not working in MongoDB 5 + if let Ok(image) = std::env::var("MONGODB_IMAGE") { + if image == "mongo:5" { + return Ok(()); + } + } + assert_yaml_snapshot!( run_connector_query( query_request() diff --git a/crates/mongodb-agent-common/proptest-regressions/mongodb/sanitize.txt b/crates/mongodb-agent-common/proptest-regressions/mongodb/sanitize.txt new file mode 100644 index 00000000..af838b34 --- /dev/null +++ b/crates/mongodb-agent-common/proptest-regressions/mongodb/sanitize.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 2357e8c9d6e3a68dfeff6f95a955a86d866c87c8d2a33afb9846fe8e1006402a # shrinks to input = "·" diff --git a/crates/mongodb-agent-common/proptest-regressions/query/query_variable_name.txt b/crates/mongodb-agent-common/proptest-regressions/query/query_variable_name.txt new file mode 100644 index 00000000..1aaebc12 --- /dev/null +++ b/crates/mongodb-agent-common/proptest-regressions/query/query_variable_name.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. 
+cc fdd2dffdde1f114a438c67d891387aaca81b3df2676213ff17171208feb290ba # shrinks to variable_name = "", (type_a, type_b) = (Scalar(Bson(Double)), Scalar(Bson(Decimal))) diff --git a/crates/mongodb-agent-common/src/explain.rs b/crates/mongodb-agent-common/src/explain.rs index 738b3a73..8c924f76 100644 --- a/crates/mongodb-agent-common/src/explain.rs +++ b/crates/mongodb-agent-common/src/explain.rs @@ -22,9 +22,10 @@ pub async fn explain_query( let pipeline = query::pipeline_for_query_request(config, &query_plan)?; let pipeline_bson = to_bson(&pipeline)?; - let aggregate_target = match QueryTarget::for_request(config, &query_plan).input_collection() { - Some(collection_name) => Bson::String(collection_name.to_owned()), - None => Bson::Int32(1), + let target = QueryTarget::for_request(config, &query_plan); + let aggregate_target = match (target.input_collection(), query_plan.has_variables()) { + (Some(collection_name), false) => Bson::String(collection_name.to_owned()), + _ => Bson::Int32(1), }; let query_command = doc! 
{ diff --git a/crates/mongodb-agent-common/src/interface_types/mongo_agent_error.rs b/crates/mongodb-agent-common/src/interface_types/mongo_agent_error.rs index b725e129..40b1dff1 100644 --- a/crates/mongodb-agent-common/src/interface_types/mongo_agent_error.rs +++ b/crates/mongodb-agent-common/src/interface_types/mongo_agent_error.rs @@ -26,7 +26,6 @@ pub enum MongoAgentError { Serialization(serde_json::Error), UnknownAggregationFunction(String), UnspecifiedRelation(String), - VariableNotDefined(String), AdHoc(#[from] anyhow::Error), } @@ -88,10 +87,6 @@ impl MongoAgentError { StatusCode::BAD_REQUEST, ErrorResponse::new(&format!("Query referenced a relationship, \"{relation}\", but did not include relation metadata in `table_relationships`")) ), - VariableNotDefined(variable_name) => ( - StatusCode::BAD_REQUEST, - ErrorResponse::new(&format!("Query referenced a variable, \"{variable_name}\", but it is not defined by the query request")) - ), AdHoc(err) => (StatusCode::INTERNAL_SERVER_ERROR, ErrorResponse::new(&err)), } } diff --git a/crates/mongodb-agent-common/src/mongo_query_plan/mod.rs b/crates/mongodb-agent-common/src/mongo_query_plan/mod.rs index 6fdc4e8f..b9a7a881 100644 --- a/crates/mongodb-agent-common/src/mongo_query_plan/mod.rs +++ b/crates/mongodb-agent-common/src/mongo_query_plan/mod.rs @@ -110,3 +110,4 @@ pub type QueryPlan = ndc_query_plan::QueryPlan; pub type Relationship = ndc_query_plan::Relationship; pub type Relationships = ndc_query_plan::Relationships; pub type Type = ndc_query_plan::Type; +pub type VariableTypes = ndc_query_plan::VariableTypes; diff --git a/crates/mongodb-agent-common/src/mongodb/sanitize.rs b/crates/mongodb-agent-common/src/mongodb/sanitize.rs index 5ac11794..b5f3f84b 100644 --- a/crates/mongodb-agent-common/src/mongodb/sanitize.rs +++ b/crates/mongodb-agent-common/src/mongodb/sanitize.rs @@ -2,8 +2,6 @@ use std::borrow::Cow; use anyhow::anyhow; use mongodb::bson::{doc, Document}; -use once_cell::sync::Lazy; -use 
regex::Regex; use crate::interface_types::MongoAgentError; @@ -15,28 +13,21 @@ pub fn get_field(name: &str) -> Document { doc! { "$getField": { "$literal": name } } } -/// Returns its input prefixed with "v_" if it is a valid MongoDB variable name. Valid names may -/// include the ASCII characters [_a-zA-Z0-9] or any non-ASCII characters. The exclusion of special -/// characters like `$` and `.` avoids potential code injection. -/// -/// We add the "v_" prefix because variable names may not begin with an underscore, but in some -/// cases, like when using relation-mapped column names as variable names, we want to be able to -/// use names like "_id". -/// -/// TODO: Instead of producing an error we could use an escaping scheme to unambiguously map -/// invalid characters to safe ones. -pub fn variable(name: &str) -> Result { - static VALID_EXPRESSION: Lazy = - Lazy::new(|| Regex::new(r"^[_a-zA-Z0-9\P{ascii}]+$").unwrap()); - if VALID_EXPRESSION.is_match(name) { - Ok(format!("v_{name}")) +/// Given a name returns a valid variable name for use in MongoDB aggregation expressions. Outputs +/// are guaranteed to be distinct for distinct inputs. Consistently returns the same output for the +/// same input string. +pub fn variable(name: &str) -> String { + let name_with_valid_initial = if name.chars().next().unwrap_or('!').is_ascii_lowercase() { + Cow::Borrowed(name) } else { - Err(MongoAgentError::InvalidVariableName(name.to_owned())) - } + Cow::Owned(format!("v_{name}")) + }; + escape_invalid_variable_chars(&name_with_valid_initial) } /// Returns false if the name contains characters that MongoDB will interpret specially, such as an -/// initial dollar sign, or dots. +/// initial dollar sign, or dots. This indicates whether a name is safe for field references +/// - variable names are more strict. 
pub fn is_name_safe(name: &str) -> bool { !(name.starts_with('$') || name.contains('.')) } @@ -52,3 +43,89 @@ pub fn safe_name(name: &str) -> Result, MongoAgentError> { Ok(Cow::Borrowed(name)) } } + +// The escape character must be a valid character in MongoDB variable names, but must not appear in +// lower-case hex strings. A non-ASCII character works if we specifically map it to a two-character +// hex escape sequence (see [ESCAPE_CHAR_ESCAPE_SEQUENCE]). Another option would be to use an +// allowed ASCII character such as 'x'. +const ESCAPE_CHAR: char = '·'; + +/// We want all escape sequences to be two-character hex strings so this must be a value that does +/// not represent an ASCII character, and that is <= 0xff. +const ESCAPE_CHAR_ESCAPE_SEQUENCE: u32 = 0xff; + +/// MongoDB variable names allow a limited set of ASCII characters, or any non-ASCII character. +/// See https://www.mongodb.com/docs/manual/reference/aggregation-variables/ +fn escape_invalid_variable_chars(input: &str) -> String { + let mut encoded = String::new(); + for char in input.chars() { + match char { + ESCAPE_CHAR => push_encoded_char(&mut encoded, ESCAPE_CHAR_ESCAPE_SEQUENCE), + 'a'..='z' | 'A'..='Z' | '0'..='9' | '_' => encoded.push(char), + char if char as u32 <= 127 => push_encoded_char(&mut encoded, char as u32), + char => encoded.push(char), + } + } + encoded +} + +/// Escape invalid characters using the escape character followed by a two-character hex sequence +/// that gives the character's ASCII codepoint +fn push_encoded_char(encoded: &mut String, char: u32) { + encoded.push(ESCAPE_CHAR); + let zero_pad = if char < 0x10 { "0" } else { "" }; + encoded.push_str(&format!("{zero_pad}{char:x}")); +} + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + + use super::{escape_invalid_variable_chars, ESCAPE_CHAR, ESCAPE_CHAR_ESCAPE_SEQUENCE}; + + proptest! { + // Escaped strings must be consistent and distinct. A round-trip test demonstrates this. 
+ #[test] + fn escaping_variable_chars_roundtrips(input: String) { + let encoded = escape_invalid_variable_chars(&input); + let decoded = unescape_invalid_variable_chars(&encoded); + prop_assert_eq!(decoded, input, "encoded string: {}", encoded) + } + } + + proptest! { + #[test] + fn escaped_variable_names_are_valid(input: String) { + let encoded = escape_invalid_variable_chars(&input); + prop_assert!( + encoded.chars().all(|char| + char as u32 > 127 || + char.is_ascii_alphanumeric() || + char == '_' + ), + "encoded string contains only valid characters\nencoded string: {}", + encoded + ) + } + } + + fn unescape_invalid_variable_chars(input: &str) -> String { + let mut decoded = String::new(); + let mut chars = input.chars(); + while let Some(char) = chars.next() { + if char == ESCAPE_CHAR { + let escape_sequence = [chars.next().unwrap(), chars.next().unwrap()]; + let code_point = + u32::from_str_radix(&escape_sequence.iter().collect::(), 16).unwrap(); + if code_point == ESCAPE_CHAR_ESCAPE_SEQUENCE { + decoded.push(ESCAPE_CHAR) + } else { + decoded.push(char::from_u32(code_point).unwrap()) + } + } else { + decoded.push(char) + } + } + decoded + } +} diff --git a/crates/mongodb-agent-common/src/mongodb/stage.rs b/crates/mongodb-agent-common/src/mongodb/stage.rs index addb6fe3..9845f922 100644 --- a/crates/mongodb-agent-common/src/mongodb/stage.rs +++ b/crates/mongodb-agent-common/src/mongodb/stage.rs @@ -11,6 +11,12 @@ use super::{accumulator::Accumulator, pipeline::Pipeline, Selection}; /// https://www.mongodb.com/docs/manual/reference/operator/aggregation-pipeline/#std-label-aggregation-pipeline-operator-reference #[derive(Clone, Debug, PartialEq, Serialize)] pub enum Stage { + /// Returns literal documents from input expressions. 
+ /// + /// See https://www.mongodb.com/docs/manual/reference/operator/aggregation/documents/#mongodb-pipeline-pipe.-documents + #[serde(rename = "$documents")] + Documents(Vec), + /// Filters the document stream to allow only matching documents to pass unmodified into the /// next pipeline stage. [`$match`] uses standard MongoDB queries. For each input document, /// outputs either one document (a match) or zero documents (no match). diff --git a/crates/mongodb-agent-common/src/procedure/interpolated_command.rs b/crates/mongodb-agent-common/src/procedure/interpolated_command.rs index 59d8b488..b3e555c4 100644 --- a/crates/mongodb-agent-common/src/procedure/interpolated_command.rs +++ b/crates/mongodb-agent-common/src/procedure/interpolated_command.rs @@ -138,6 +138,7 @@ mod tests { use configuration::{native_mutation::NativeMutation, MongoScalarType}; use mongodb::bson::doc; use mongodb_support::BsonScalarType as S; + use ndc_models::Argument; use pretty_assertions::assert_eq; use serde_json::json; @@ -175,8 +176,13 @@ mod tests { }; let input_arguments = [ - ("id".to_owned(), json!(1001)), - ("name".to_owned(), json!("Regina Spektor")), + ("id".to_owned(), Argument::Literal { value: json!(1001) }), + ( + "name".to_owned(), + Argument::Literal { + value: json!("Regina Spektor"), + }, + ), ] .into_iter() .collect(); @@ -232,10 +238,12 @@ mod tests { let input_arguments = [( "documents".to_owned(), - json!([ - { "ArtistId": 1001, "Name": "Regina Spektor" } , - { "ArtistId": 1002, "Name": "Ok Go" } , - ]), + Argument::Literal { + value: json!([ + { "ArtistId": 1001, "Name": "Regina Spektor" } , + { "ArtistId": 1002, "Name": "Ok Go" } , + ]), + }, )] .into_iter() .collect(); @@ -289,8 +297,18 @@ mod tests { }; let input_arguments = [ - ("prefix".to_owned(), json!("current")), - ("basename".to_owned(), json!("some-coll")), + ( + "prefix".to_owned(), + Argument::Literal { + value: json!("current"), + }, + ), + ( + "basename".to_owned(), + Argument::Literal { + value: 
json!("some-coll"), + }, + ), ] .into_iter() .collect(); diff --git a/crates/mongodb-agent-common/src/procedure/mod.rs b/crates/mongodb-agent-common/src/procedure/mod.rs index 841f670a..42ec794e 100644 --- a/crates/mongodb-agent-common/src/procedure/mod.rs +++ b/crates/mongodb-agent-common/src/procedure/mod.rs @@ -7,6 +7,7 @@ use std::collections::BTreeMap; use configuration::native_mutation::NativeMutation; use mongodb::options::SelectionCriteria; use mongodb::{bson, Database}; +use ndc_models::Argument; use crate::mongo_query_plan::Type; use crate::query::arguments::resolve_arguments; @@ -61,6 +62,10 @@ fn interpolate( arguments: BTreeMap, command: &bson::Document, ) -> Result { + let arguments = arguments + .into_iter() + .map(|(name, value)| (name, Argument::Literal { value })) + .collect(); let bson_arguments = resolve_arguments(parameters, arguments)?; interpolated_command(command, &bson_arguments) } diff --git a/crates/mongodb-agent-common/src/query/arguments.rs b/crates/mongodb-agent-common/src/query/arguments.rs index be1d8066..f5889b02 100644 --- a/crates/mongodb-agent-common/src/query/arguments.rs +++ b/crates/mongodb-agent-common/src/query/arguments.rs @@ -3,12 +3,15 @@ use std::collections::BTreeMap; use indent::indent_all_by; use itertools::Itertools as _; use mongodb::bson::Bson; -use serde_json::Value; +use ndc_models::Argument; use thiserror::Error; use crate::mongo_query_plan::Type; -use super::serialization::{json_to_bson, JsonToBsonError}; +use super::{ + query_variable_name::query_variable_name, + serialization::{json_to_bson, JsonToBsonError}, +}; #[derive(Debug, Error)] pub enum ArgumentError { @@ -28,11 +31,11 @@ pub enum ArgumentError { /// map to declared parameters (no excess arguments). 
pub fn resolve_arguments( parameters: &BTreeMap, - mut arguments: BTreeMap, + mut arguments: BTreeMap, ) -> Result, ArgumentError> { validate_no_excess_arguments(parameters, &arguments)?; - let (arguments, missing): (Vec<(String, Value, &Type)>, Vec) = parameters + let (arguments, missing): (Vec<(String, Argument, &Type)>, Vec) = parameters .iter() .map(|(name, parameter_type)| { if let Some((name, argument)) = arguments.remove_entry(name) { @@ -48,12 +51,12 @@ pub fn resolve_arguments( let (resolved, errors): (BTreeMap, BTreeMap) = arguments .into_iter() - .map( - |(name, argument, parameter_type)| match json_to_bson(parameter_type, argument) { + .map(|(name, argument, parameter_type)| { + match argument_to_mongodb_expression(&argument, parameter_type) { Ok(bson) => Ok((name, bson)), Err(err) => Err((name, err)), - }, - ) + } + }) .partition_result(); if !errors.is_empty() { return Err(ArgumentError::Invalid(errors)); @@ -62,9 +65,22 @@ pub fn resolve_arguments( Ok(resolved) } -pub fn validate_no_excess_arguments( +fn argument_to_mongodb_expression( + argument: &Argument, + parameter_type: &Type, +) -> Result { + match argument { + Argument::Variable { name } => { + let mongodb_var_name = query_variable_name(name, parameter_type); + Ok(format!("$${mongodb_var_name}").into()) + } + Argument::Literal { value } => json_to_bson(parameter_type, value.clone()), + } +} + +pub fn validate_no_excess_arguments( parameters: &BTreeMap, - arguments: &BTreeMap, + arguments: &BTreeMap, ) -> Result<(), ArgumentError> { let excess: Vec = arguments .iter() diff --git a/crates/mongodb-agent-common/src/query/execute_query_request.rs b/crates/mongodb-agent-common/src/query/execute_query_request.rs index 7bbed719..9ff5c55b 100644 --- a/crates/mongodb-agent-common/src/query/execute_query_request.rs +++ b/crates/mongodb-agent-common/src/query/execute_query_request.rs @@ -57,8 +57,12 @@ async fn execute_query_pipeline( // The target of a query request might be a collection, or it might be 
a native query. In the // latter case there is no collection to perform the aggregation against. So instead of sending // the MongoDB API call `db..aggregate` we instead call `db.aggregate`. - let documents = match target.input_collection() { - Some(collection_name) => { + // + // If the query request includes variable sets then instead of specifying the target collection + // up front that is deferred until the `$lookup` stage of the aggregation pipeline. That is + // another case where we call `db.aggregate` instead of `db..aggregate`. + let documents = match (target.input_collection(), query_plan.has_variables()) { + (Some(collection_name), false) => { let collection = database.collection(collection_name); collect_response_documents( collection @@ -71,7 +75,7 @@ async fn execute_query_pipeline( ) .await } - None => { + _ => { collect_response_documents( database .aggregate(pipeline, None) diff --git a/crates/mongodb-agent-common/src/query/foreach.rs b/crates/mongodb-agent-common/src/query/foreach.rs index cf5e429e..e11b7d2e 100644 --- a/crates/mongodb-agent-common/src/query/foreach.rs +++ b/crates/mongodb-agent-common/src/query/foreach.rs @@ -1,58 +1,118 @@ -use mongodb::bson::{doc, Bson}; +use anyhow::anyhow; +use configuration::MongoScalarType; +use itertools::Itertools as _; +use mongodb::bson::{self, doc, Bson}; use ndc_query_plan::VariableSet; use super::pipeline::pipeline_for_non_foreach; use super::query_level::QueryLevel; -use crate::mongo_query_plan::{MongoConfiguration, QueryPlan}; +use super::query_variable_name::query_variable_name; +use super::serialization::json_to_bson; +use super::QueryTarget; +use crate::mongo_query_plan::{MongoConfiguration, QueryPlan, Type, VariableTypes}; use crate::mongodb::Selection; use crate::{ interface_types::MongoAgentError, mongodb::{Pipeline, Stage}, }; -const FACET_FIELD: &str = "__FACET__"; +type Result = std::result::Result; -/// Produces a complete MongoDB pipeline for a foreach query. 
-/// -/// For symmetry with [`super::execute_query_request::pipeline_for_query`] and -/// [`pipeline_for_non_foreach`] this function returns a pipeline paired with a value that -/// indicates whether the response requires post-processing in the agent. +/// Produces a complete MongoDB pipeline for a query request that includes variable sets. pub fn pipeline_for_foreach( - variable_sets: &[VariableSet], + request_variable_sets: &[VariableSet], config: &MongoConfiguration, query_request: &QueryPlan, -) -> Result { - let pipelines: Vec<(String, Pipeline)> = variable_sets +) -> Result { + let target = QueryTarget::for_request(config, query_request); + + let variable_sets = + variable_sets_to_bson(request_variable_sets, &query_request.variable_types)?; + + let variable_names = variable_sets .iter() - .enumerate() - .map(|(index, variables)| { - let pipeline = - pipeline_for_non_foreach(config, Some(variables), query_request, QueryLevel::Top)?; - Ok((facet_name(index), pipeline)) - }) - .collect::>()?; + .flat_map(|variable_set| variable_set.keys()); + let bindings: bson::Document = variable_names + .map(|name| (name.to_owned(), format!("${name}").into())) + .collect(); + + let variable_sets_stage = Stage::Documents(variable_sets); - let selection = Selection(doc! { - "row_sets": pipelines.iter().map(|(key, _)| - Bson::String(format!("${key}")), - ).collect::>() - }); + let query_pipeline = pipeline_for_non_foreach(config, query_request, QueryLevel::Top)?; - let queries = pipelines.into_iter().collect(); + let lookup_stage = Stage::Lookup { + from: target.input_collection().map(ToString::to_string), + local_field: None, + foreign_field: None, + r#let: Some(bindings), + pipeline: Some(query_pipeline), + r#as: "query".to_string(), + }; + + let selection = if query_request.query.has_aggregates() && query_request.query.has_fields() { + doc! 
{ + "aggregates": { "$getField": { "input": { "$first": "$query" }, "field": "aggregates" } }, + "rows": { "$getField": { "input": { "$first": "$query" }, "field": "rows" } }, + } + } else if query_request.query.has_aggregates() { + doc! { + "aggregates": { "$getField": { "input": { "$first": "$query" }, "field": "aggregates" } }, + } + } else { + doc! { + "rows": "$query" + } + }; + let selection_stage = Stage::ReplaceWith(Selection(selection)); Ok(Pipeline { - stages: vec![Stage::Facet(queries), Stage::ReplaceWith(selection)], + stages: vec![variable_sets_stage, lookup_stage, selection_stage], }) } -fn facet_name(index: usize) -> String { - format!("{FACET_FIELD}_{index}") +fn variable_sets_to_bson( + variable_sets: &[VariableSet], + variable_types: &VariableTypes, +) -> Result> { + variable_sets + .iter() + .map(|variable_set| { + variable_set + .iter() + .flat_map(|(variable_name, value)| { + let types = variable_types.get(variable_name); + variable_to_bson(variable_name, value, types.iter().copied().flatten()) + .collect_vec() + }) + .try_collect() + }) + .try_collect() +} + +/// It may be necessary to include a request variable in the MongoDB pipeline multiple times if it +/// requires different BSON serializations. 
+fn variable_to_bson<'a>( + name: &'a str, + value: &'a serde_json::Value, + variable_types: impl IntoIterator> + 'a, +) -> impl Iterator> + 'a { + variable_types.into_iter().map(|t| { + let resolved_type = match t { + None => &Type::Scalar(MongoScalarType::ExtendedJSON), + Some(t) => t, + }; + let variable_name = query_variable_name(name, resolved_type); + let bson_value = json_to_bson(resolved_type, value.clone()) + .map_err(|e| MongoAgentError::BadQuery(anyhow!(e)))?; + Ok((variable_name, bson_value)) + }) } #[cfg(test)] mod tests { use configuration::Configuration; - use mongodb::bson::{bson, Bson}; + use itertools::Itertools as _; + use mongodb::bson::{bson, doc}; use ndc_test_helpers::{ binop, collection, field, named_type, object_type, query, query_request, query_response, row_set, star_count_aggregate, target, variable, @@ -62,7 +122,7 @@ mod tests { use crate::{ mongo_query_plan::MongoConfiguration, - mongodb::test_helpers::mock_collection_aggregate_response_for_pipeline, + mongodb::test_helpers::mock_aggregate_response_for_pipeline, query::execute_query_request::execute_query_request, }; @@ -80,31 +140,32 @@ mod tests { let expected_pipeline = bson!([ { - "$facet": { - "__FACET___0": [ - { "$match": { "artistId": { "$eq": 1 } } }, + "$documents": [ + { "artistId_int": 1 }, + { "artistId_int": 2 }, + ], + }, + { + "$lookup": { + "from": "tracks", + "let": { + "artistId_int": "$artistId_int", + }, + "as": "query", + "pipeline": [ + { "$match": { "$expr": { "$eq": ["$artistId", "$$artistId_int"] } } }, { "$replaceWith": { "albumId": { "$ifNull": ["$albumId", null] }, "title": { "$ifNull": ["$title", null] } } }, ], - "__FACET___1": [ - { "$match": { "artistId": { "$eq": 2 } } }, - { "$replaceWith": { - "albumId": { "$ifNull": ["$albumId", null] }, - "title": { "$ifNull": ["$title", null] } - } }, - ] }, }, { "$replaceWith": { - "row_sets": [ - "$__FACET___0", - "$__FACET___1", - ] - }, - } + "rows": "$query", + } + }, ]); let expected_response = 
query_response() @@ -121,21 +182,18 @@ mod tests { ]) .build(); - let db = mock_collection_aggregate_response_for_pipeline( - "tracks", + let db = mock_aggregate_response_for_pipeline( expected_pipeline, - bson!([{ - "row_sets": [ - [ - { "albumId": 1, "title": "For Those About To Rock We Salute You" }, - { "albumId": 4, "title": "Let There Be Rock" } - ], - [ - { "albumId": 2, "title": "Balls to the Wall" }, - { "albumId": 3, "title": "Restless and Wild" } - ], - ], - }]), + bson!([ + { "rows": [ + { "albumId": 1, "title": "For Those About To Rock We Salute You" }, + { "albumId": 4, "title": "Let There Be Rock" } + ] }, + { "rows": [ + { "albumId": 2, "title": "Balls to the Wall" }, + { "albumId": 3, "title": "Restless and Wild" } + ] }, + ]), ); let result = execute_query_request(db, &music_config(), query_request).await?; @@ -159,28 +217,20 @@ mod tests { let expected_pipeline = bson!([ { - "$facet": { - "__FACET___0": [ - { "$match": { "artistId": {"$eq": 1 }}}, - { "$facet": { - "__ROWS__": [{ "$replaceWith": { - "albumId": { "$ifNull": ["$albumId", null] }, - "title": { "$ifNull": ["$title", null] } - }}], - "count": [{ "$count": "result" }], - } }, - { "$replaceWith": { - "aggregates": { - "count": { "$getField": { - "field": "result", - "input": { "$first": { "$getField": { "$literal": "count" } } } - } }, - }, - "rows": "$__ROWS__", - } }, - ], - "__FACET___1": [ - { "$match": { "artistId": {"$eq": 2 }}}, + "$documents": [ + { "artistId_int": 1 }, + { "artistId_int": 2 }, + ] + }, + { + "$lookup": { + "from": "tracks", + "let": { + "artistId_int": "$artistId_int" + }, + "as": "query", + "pipeline": [ + { "$match": { "$expr": { "$eq": ["$artistId", "$$artistId_int"] } }}, { "$facet": { "__ROWS__": [{ "$replaceWith": { "albumId": { "$ifNull": ["$albumId", null] }, @@ -198,16 +248,14 @@ mod tests { "rows": "$__ROWS__", } }, ] - }, + } }, { "$replaceWith": { - "row_sets": [ - "$__FACET___0", - "$__FACET___1", - ] - }, - } + "aggregates": { "$getField": { 
"input": { "$first": "$query" }, "field": "aggregates" } }, + "rows": { "$getField": { "input": { "$first": "$query" }, "field": "rows" } }, + } + }, ]); let expected_response = query_response() @@ -232,31 +280,105 @@ mod tests { ) .build(); - let db = mock_collection_aggregate_response_for_pipeline( - "tracks", + let db = mock_aggregate_response_for_pipeline( expected_pipeline, - bson!([{ - "row_sets": [ - { - "aggregates": { - "count": 2, - }, - "rows": [ - { "albumId": 1, "title": "For Those About To Rock We Salute You" }, - { "albumId": 4, "title": "Let There Be Rock" }, - ] + bson!([ + { + "aggregates": { + "count": 2, }, - { - "aggregates": { - "count": 2, - }, - "rows": [ - { "albumId": 2, "title": "Balls to the Wall" }, - { "albumId": 3, "title": "Restless and Wild" }, - ] + "rows": [ + { "albumId": 1, "title": "For Those About To Rock We Salute You" }, + { "albumId": 4, "title": "Let There Be Rock" }, + ] + }, + { + "aggregates": { + "count": 2, }, + "rows": [ + { "albumId": 2, "title": "Balls to the Wall" }, + { "albumId": 3, "title": "Restless and Wild" }, + ] + }, + ]), + ); + + let result = execute_query_request(db, &music_config(), query_request).await?; + assert_eq!(expected_response, result); + + Ok(()) + } + + #[tokio::test] + async fn executes_query_with_variables_and_aggregates_and_no_rows() -> Result<(), anyhow::Error> + { + let query_request = query_request() + .collection("tracks") + .query( + query() + .aggregates([star_count_aggregate!("count")]) + .predicate(binop("_eq", target!("artistId"), variable!(artistId))), + ) + .variables([[("artistId", 1)], [("artistId", 2)]]) + .into(); + + let expected_pipeline = bson!([ + { + "$documents": [ + { "artistId_int": 1 }, + { "artistId_int": 2 }, ] - }]), + }, + { + "$lookup": { + "from": "tracks", + "let": { + "artistId_int": "$artistId_int" + }, + "as": "query", + "pipeline": [ + { "$match": { "$expr": { "$eq": ["$artistId", "$$artistId_int"] } }}, + { "$facet": { + "count": [{ "$count": "result" 
}], + } }, + { "$replaceWith": { + "aggregates": { + "count": { "$getField": { + "field": "result", + "input": { "$first": { "$getField": { "$literal": "count" } } } + } }, + }, + } }, + ] + } + }, + { + "$replaceWith": { + "aggregates": { "$getField": { "input": { "$first": "$query" }, "field": "aggregates" } }, + } + }, + ]); + + let expected_response = query_response() + .row_set(row_set().aggregates([("count", json!({ "$numberInt": "2" }))])) + .row_set(row_set().aggregates([("count", json!({ "$numberInt": "2" }))])) + .build(); + + let db = mock_aggregate_response_for_pipeline( + expected_pipeline, + bson!([ + { + "aggregates": { + "count": 2, + }, + }, + { + "aggregates": { + "count": 2, + }, + }, + ]), ); let result = execute_query_request(db, &music_config(), query_request).await?; @@ -277,51 +399,37 @@ mod tests { ) .into(); - fn facet(artist_id: i32) -> Bson { - bson!([ - { "$match": { "artistId": {"$eq": artist_id } } }, - { "$replaceWith": { - "albumId": { "$ifNull": ["$albumId", null] }, - "title": { "$ifNull": ["$title", null] } - } }, - ]) - } - let expected_pipeline = bson!([ { - "$facet": { - "__FACET___0": facet(1), - "__FACET___1": facet(2), - "__FACET___2": facet(3), - "__FACET___3": facet(4), - "__FACET___4": facet(5), - "__FACET___5": facet(6), - "__FACET___6": facet(7), - "__FACET___7": facet(8), - "__FACET___8": facet(9), - "__FACET___9": facet(10), - "__FACET___10": facet(11), - "__FACET___11": facet(12), - }, + "$documents": (1..=12).map(|artist_id| doc! 
{ "artistId_int": artist_id }).collect_vec(), }, { - "$replaceWith": { - "row_sets": [ - "$__FACET___0", - "$__FACET___1", - "$__FACET___2", - "$__FACET___3", - "$__FACET___4", - "$__FACET___5", - "$__FACET___6", - "$__FACET___7", - "$__FACET___8", - "$__FACET___9", - "$__FACET___10", - "$__FACET___11", + "$lookup": { + "from": "tracks", + "let": { + "artistId_int": "$artistId_int" + }, + "as": "query", + "pipeline": [ + { + "$match": { + "$expr": { "$eq": ["$artistId", "$$artistId_int"] } + } + }, + { + "$replaceWith": { + "albumId": { "$ifNull": ["$albumId", null] }, + "title": { "$ifNull": ["$title", null] } + } + }, ] - }, - } + } + }, + { + "$replaceWith": { + "rows": "$query" + } + }, ]); let expected_response = query_response() @@ -347,30 +455,27 @@ mod tests { .empty_row_set() .build(); - let db = mock_collection_aggregate_response_for_pipeline( - "tracks", + let db = mock_aggregate_response_for_pipeline( expected_pipeline, - bson!([{ - "row_sets": [ - [ - { "albumId": 1, "title": "For Those About To Rock We Salute You" }, - { "albumId": 4, "title": "Let There Be Rock" } - ], - [], - [ - { "albumId": 2, "title": "Balls to the Wall" }, - { "albumId": 3, "title": "Restless and Wild" } - ], - [], - [], - [], - [], - [], - [], - [], - [], - ], - }]), + bson!([ + { "rows": [ + { "albumId": 1, "title": "For Those About To Rock We Salute You" }, + { "albumId": 4, "title": "Let There Be Rock" } + ] }, + { "rows": [] }, + { "rows": [ + { "albumId": 2, "title": "Balls to the Wall" }, + { "albumId": 3, "title": "Restless and Wild" } + ] }, + { "rows": [] }, + { "rows": [] }, + { "rows": [] }, + { "rows": [] }, + { "rows": [] }, + { "rows": [] }, + { "rows": [] }, + { "rows": [] }, + ]), ); let result = execute_query_request(db, &music_config(), query_request).await?; diff --git a/crates/mongodb-agent-common/src/query/make_selector.rs b/crates/mongodb-agent-common/src/query/make_selector.rs index 8cda7c46..ea2bf197 100644 --- 
a/crates/mongodb-agent-common/src/query/make_selector.rs +++ b/crates/mongodb-agent-common/src/query/make_selector.rs @@ -1,5 +1,3 @@ -use std::collections::BTreeMap; - use anyhow::anyhow; use mongodb::bson::{self, doc, Document}; use ndc_models::UnaryComparisonOperator; @@ -11,7 +9,7 @@ use crate::{ query::column_ref::{column_expression, ColumnRef}, }; -use super::serialization::json_to_bson; +use super::{query_variable_name::query_variable_name, serialization::json_to_bson}; pub type Result = std::result::Result; @@ -21,16 +19,13 @@ fn bson_from_scalar_value(value: &serde_json::Value, value_type: &Type) -> Resul json_to_bson(value_type, value.clone()).map_err(|e| MongoAgentError::BadQuery(anyhow!(e))) } -pub fn make_selector( - variables: Option<&BTreeMap>, - expr: &Expression, -) -> Result { +pub fn make_selector(expr: &Expression) -> Result { match expr { Expression::And { expressions } => { let sub_exps: Vec = expressions .clone() .iter() - .map(|e| make_selector(variables, e)) + .map(make_selector) .collect::>()?; Ok(doc! {"$and": sub_exps}) } @@ -38,20 +33,18 @@ pub fn make_selector( let sub_exps: Vec = expressions .clone() .iter() - .map(|e| make_selector(variables, e)) + .map(make_selector) .collect::>()?; Ok(doc! {"$or": sub_exps}) } - Expression::Not { expression } => { - Ok(doc! { "$nor": [make_selector(variables, expression)?]}) - } + Expression::Not { expression } => Ok(doc! { "$nor": [make_selector(expression)?]}), Expression::Exists { in_collection, predicate, } => Ok(match in_collection { ExistsInCollection::Related { relationship } => match predicate { Some(predicate) => doc! { - relationship: { "$elemMatch": make_selector(variables, predicate)? } + relationship: { "$elemMatch": make_selector(predicate)? } }, None => doc! 
{ format!("{relationship}.0"): { "$exists": true } }, }, @@ -67,7 +60,7 @@ pub fn make_selector( column, operator, value, - } => make_binary_comparison_selector(variables, column, operator, value), + } => make_binary_comparison_selector(column, operator, value), Expression::UnaryComparisonOperator { column, operator } => match operator { UnaryComparisonOperator::IsNull => { let match_doc = match ColumnRef::from_comparison_target(column) { @@ -90,7 +83,6 @@ pub fn make_selector( } fn make_binary_comparison_selector( - variables: Option<&BTreeMap>, target_column: &ComparisonTarget, operator: &ComparisonFunction, value: &ComparisonValue, @@ -117,9 +109,9 @@ fn make_binary_comparison_selector( let comparison_value = bson_from_scalar_value(value, value_type)?; let match_doc = match ColumnRef::from_comparison_target(target_column) { ColumnRef::MatchKey(key) => operator.mongodb_match_query(key, comparison_value), - ColumnRef::Expression(expr) => { - operator.mongodb_aggregation_expression(expr, comparison_value) - } + ColumnRef::Expression(expr) => doc! { + "$expr": operator.mongodb_aggregation_expression(expr, comparison_value) + }, }; traverse_relationship_path(target_column.relationship_path(), match_doc) } @@ -127,13 +119,12 @@ fn make_binary_comparison_selector( name, variable_type, } => { - let comparison_value = - variable_to_mongo_expression(variables, name, variable_type).map(Into::into)?; - let match_doc = match ColumnRef::from_comparison_target(target_column) { - ColumnRef::MatchKey(key) => operator.mongodb_match_query(key, comparison_value), - ColumnRef::Expression(expr) => { - operator.mongodb_aggregation_expression(expr, comparison_value) - } + let comparison_value = variable_to_mongo_expression(name, variable_type); + let match_doc = doc! 
{ + "$expr": operator.mongodb_aggregation_expression( + column_expression(target_column), + comparison_value + ) }; traverse_relationship_path(target_column.relationship_path(), match_doc) } @@ -157,16 +148,9 @@ fn traverse_relationship_path(path: &[String], mut expression: Document) -> Docu expression } -fn variable_to_mongo_expression( - variables: Option<&BTreeMap>, - variable: &str, - value_type: &Type, -) -> Result { - let value = variables - .and_then(|vars| vars.get(variable)) - .ok_or_else(|| MongoAgentError::VariableNotDefined(variable.to_owned()))?; - - bson_from_scalar_value(value, value_type) +fn variable_to_mongo_expression(variable: &str, value_type: &Type) -> bson::Bson { + let mongodb_var_name = query_variable_name(variable, value_type); + format!("$${mongodb_var_name}").into() } #[cfg(test)] @@ -175,7 +159,7 @@ mod tests { use mongodb::bson::{self, bson, doc}; use mongodb_support::BsonScalarType; use ndc_models::UnaryComparisonOperator; - use ndc_query_plan::plan_for_query_request; + use ndc_query_plan::{plan_for_query_request, Scope}; use ndc_test_helpers::{ binop, column_value, path_element, query, query_request, relation_field, root, target, value, @@ -194,22 +178,19 @@ mod tests { #[test] fn compares_fields_of_related_documents_using_elem_match_in_binary_comparison( ) -> anyhow::Result<()> { - let selector = make_selector( - None, - &Expression::BinaryComparisonOperator { - column: ComparisonTarget::Column { - name: "Name".to_owned(), - field_path: None, - field_type: Type::Scalar(MongoScalarType::Bson(BsonScalarType::String)), - path: vec!["Albums".into(), "Tracks".into()], - }, - operator: ComparisonFunction::Equal, - value: ComparisonValue::Scalar { - value: "Helter Skelter".into(), - value_type: Type::Scalar(MongoScalarType::Bson(BsonScalarType::String)), - }, + let selector = make_selector(&Expression::BinaryComparisonOperator { + column: ComparisonTarget::Column { + name: "Name".to_owned(), + field_path: None, + field_type: 
Type::Scalar(MongoScalarType::Bson(BsonScalarType::String)), + path: vec!["Albums".into(), "Tracks".into()], }, - )?; + operator: ComparisonFunction::Equal, + value: ComparisonValue::Scalar { + value: "Helter Skelter".into(), + value_type: Type::Scalar(MongoScalarType::Bson(BsonScalarType::String)), + }, + })?; let expected = doc! { "Albums": { @@ -230,18 +211,15 @@ mod tests { #[test] fn compares_fields_of_related_documents_using_elem_match_in_unary_comparison( ) -> anyhow::Result<()> { - let selector = make_selector( - None, - &Expression::UnaryComparisonOperator { - column: ComparisonTarget::Column { - name: "Name".to_owned(), - field_path: None, - field_type: Type::Scalar(MongoScalarType::Bson(BsonScalarType::String)), - path: vec!["Albums".into(), "Tracks".into()], - }, - operator: UnaryComparisonOperator::IsNull, + let selector = make_selector(&Expression::UnaryComparisonOperator { + column: ComparisonTarget::Column { + name: "Name".to_owned(), + field_path: None, + field_type: Type::Scalar(MongoScalarType::Bson(BsonScalarType::String)), + path: vec!["Albums".into(), "Tracks".into()], }, - )?; + operator: UnaryComparisonOperator::IsNull, + })?; let expected = doc! 
{ "Albums": { @@ -261,26 +239,23 @@ mod tests { #[test] fn compares_two_columns() -> anyhow::Result<()> { - let selector = make_selector( - None, - &Expression::BinaryComparisonOperator { + let selector = make_selector(&Expression::BinaryComparisonOperator { + column: ComparisonTarget::Column { + name: "Name".to_owned(), + field_path: None, + field_type: Type::Scalar(MongoScalarType::Bson(BsonScalarType::String)), + path: Default::default(), + }, + operator: ComparisonFunction::Equal, + value: ComparisonValue::Column { column: ComparisonTarget::Column { - name: "Name".to_owned(), + name: "Title".to_owned(), field_path: None, field_type: Type::Scalar(MongoScalarType::Bson(BsonScalarType::String)), path: Default::default(), }, - operator: ComparisonFunction::Equal, - value: ComparisonValue::Column { - column: ComparisonTarget::Column { - name: "Title".to_owned(), - field_path: None, - field_type: Type::Scalar(MongoScalarType::Bson(BsonScalarType::String)), - path: Default::default(), - }, - }, }, - )?; + })?; let expected = doc! { "$expr": { @@ -292,6 +267,32 @@ mod tests { Ok(()) } + #[test] + fn compares_root_collection_column_to_scalar() -> anyhow::Result<()> { + let selector = make_selector(&Expression::BinaryComparisonOperator { + column: ComparisonTarget::ColumnInScope { + name: "Name".to_owned(), + field_path: None, + field_type: Type::Scalar(MongoScalarType::Bson(BsonScalarType::String)), + scope: Scope::Named("scope_0".to_string()), + }, + operator: ComparisonFunction::Equal, + value: ComparisonValue::Scalar { + value: "Lady Gaga".into(), + value_type: Type::Scalar(MongoScalarType::Bson(BsonScalarType::String)), + }, + })?; + + let expected = doc! 
{ + "$expr": { + "$eq": ["$$scope_0.Name", "Lady Gaga"] + } + }; + + assert_eq!(selector, expected); + Ok(()) + } + #[test] fn root_column_reference_refereces_column_of_nearest_query() -> anyhow::Result<()> { let request = query_request() diff --git a/crates/mongodb-agent-common/src/query/mod.rs b/crates/mongodb-agent-common/src/query/mod.rs index 2f574656..2a4f82b3 100644 --- a/crates/mongodb-agent-common/src/query/mod.rs +++ b/crates/mongodb-agent-common/src/query/mod.rs @@ -9,6 +9,7 @@ mod native_query; mod pipeline; mod query_level; mod query_target; +mod query_variable_name; mod relations; pub mod response; pub mod serialization; diff --git a/crates/mongodb-agent-common/src/query/native_query.rs b/crates/mongodb-agent-common/src/query/native_query.rs index 0df1fbf6..56ffc4dc 100644 --- a/crates/mongodb-agent-common/src/query/native_query.rs +++ b/crates/mongodb-agent-common/src/query/native_query.rs @@ -3,7 +3,6 @@ use std::collections::BTreeMap; use configuration::native_query::NativeQuery; use itertools::Itertools as _; use ndc_models::Argument; -use ndc_query_plan::VariableSet; use crate::{ interface_types::MongoAgentError, @@ -18,7 +17,6 @@ use super::{arguments::resolve_arguments, query_target::QueryTarget}; /// an empty pipeline if the query request target is not a native query pub fn pipeline_for_native_query( config: &MongoConfiguration, - variables: Option<&VariableSet>, query_request: &QueryPlan, ) -> Result { match QueryTarget::for_request(config, query_request) { @@ -27,26 +25,15 @@ pub fn pipeline_for_native_query( native_query, arguments, .. 
- } => make_pipeline(variables, native_query, arguments), + } => make_pipeline(native_query, arguments), } } fn make_pipeline( - variables: Option<&VariableSet>, native_query: &NativeQuery, arguments: &BTreeMap, ) -> Result { - let expressions = arguments - .iter() - .map(|(name, argument)| { - Ok(( - name.to_owned(), - argument_to_mongodb_expression(argument, variables)?, - )) as Result<_, MongoAgentError> - }) - .try_collect()?; - - let bson_arguments = resolve_arguments(&native_query.arguments, expressions) + let bson_arguments = resolve_arguments(&native_query.arguments, arguments.clone()) .map_err(ProcedureError::UnresolvableArguments)?; // Replace argument placeholders with resolved expressions, convert document list to @@ -61,19 +48,6 @@ fn make_pipeline( Ok(Pipeline::new(stages)) } -fn argument_to_mongodb_expression( - argument: &Argument, - variables: Option<&VariableSet>, -) -> Result { - match argument { - Argument::Variable { name } => variables - .and_then(|vs| vs.get(name)) - .ok_or_else(|| MongoAgentError::VariableNotDefined(name.to_owned())) - .cloned(), - Argument::Literal { value } => Ok(value.clone()), - } -} - #[cfg(test)] mod tests { use configuration::{ diff --git a/crates/mongodb-agent-common/src/query/pipeline.rs b/crates/mongodb-agent-common/src/query/pipeline.rs index 03e280f3..ca82df78 100644 --- a/crates/mongodb-agent-common/src/query/pipeline.rs +++ b/crates/mongodb-agent-common/src/query/pipeline.rs @@ -1,7 +1,6 @@ use std::collections::BTreeMap; use mongodb::bson::{self, doc, Bson}; -use ndc_query_plan::VariableSet; use tracing::instrument; use crate::{ @@ -31,9 +30,6 @@ pub fn is_response_faceted(query: &Query) -> bool { } /// Shared logic to produce a MongoDB aggregation pipeline for a query request. -/// -/// Returns a pipeline paired with a value that indicates whether the response requires -/// post-processing in the agent. 
#[instrument(name = "Build Query Pipeline" skip_all, fields(internal.visibility = "user"))] pub fn pipeline_for_query_request( config: &MongoConfiguration, @@ -42,18 +38,15 @@ pub fn pipeline_for_query_request( if let Some(variable_sets) = &query_plan.variables { pipeline_for_foreach(variable_sets, config, query_plan) } else { - pipeline_for_non_foreach(config, None, query_plan, QueryLevel::Top) + pipeline_for_non_foreach(config, query_plan, QueryLevel::Top) } } -/// Produces a pipeline for a non-foreach query request, or for one variant of a foreach query -/// request. -/// -/// Returns a pipeline paired with a value that indicates whether the response requires -/// post-processing in the agent. +/// Produces a pipeline for a query request that does not include variable sets, or produces +/// a sub-pipeline to be used inside of a larger pipeline for a query request that does include +/// variable sets. pub fn pipeline_for_non_foreach( config: &MongoConfiguration, - variables: Option<&VariableSet>, query_plan: &QueryPlan, query_level: QueryLevel, ) -> Result { @@ -67,14 +60,14 @@ pub fn pipeline_for_non_foreach( let mut pipeline = Pipeline::empty(); // If this is a native query then we start with the native query's pipeline - pipeline.append(pipeline_for_native_query(config, variables, query_plan)?); + pipeline.append(pipeline_for_native_query(config, query_plan)?); // Stages common to aggregate and row queries. - pipeline.append(pipeline_for_relations(config, variables, query_plan)?); + pipeline.append(pipeline_for_relations(config, query_plan)?); let match_stage = predicate .as_ref() - .map(|expression| make_selector(variables, expression)) + .map(make_selector) .transpose()? 
.map(Stage::Match); let sort_stage: Option = order_by diff --git a/crates/mongodb-agent-common/src/query/query_variable_name.rs b/crates/mongodb-agent-common/src/query/query_variable_name.rs new file mode 100644 index 00000000..1778a700 --- /dev/null +++ b/crates/mongodb-agent-common/src/query/query_variable_name.rs @@ -0,0 +1,94 @@ +use std::borrow::Cow; + +use configuration::MongoScalarType; + +use crate::{ + mongo_query_plan::{ObjectType, Type}, + mongodb::sanitize::variable, +}; + +/// Maps a variable name and type from a [ndc_models::QueryRequest] `variables` map to a variable +/// name for use in a MongoDB aggregation pipeline. The type is incorporated into the produced name +/// because it is possible the same request variable may be used in different type contexts, which +/// may require different BSON conversions for the different contexts. +/// +/// This function has some important requirements: +/// +/// - reproducibility: the same input name and type must always produce the same output name +/// - distinct outputs: inputs with different types (or names) must produce different output names +/// - It must produce a valid MongoDB variable name (see https://www.mongodb.com/docs/manual/reference/aggregation-variables/) +pub fn query_variable_name(name: &str, variable_type: &Type) -> String { + variable(&format!("{}_{}", name, type_name(variable_type))) +} + +fn type_name(input_type: &Type) -> Cow<'static, str> { + match input_type { + Type::Scalar(MongoScalarType::Bson(t)) => t.bson_name().into(), + Type::Scalar(MongoScalarType::ExtendedJSON) => "unknown".into(), + Type::Object(obj) => object_type_name(obj).into(), + Type::ArrayOf(t) => format!("[{}]", type_name(t)).into(), + Type::Nullable(t) => format!("nullable({})", type_name(t)).into(), + } +} + +fn object_type_name(obj: &ObjectType) -> String { + let mut output = "{".to_string(); + for (key, t) in &obj.fields { + output.push_str(&format!("{key}:{}", type_name(t))); + } + output.push('}'); + output +} + 
+#[cfg(test)] +mod tests { +    use once_cell::sync::Lazy; +    use proptest::prelude::*; +    use regex::Regex; +    use test_helpers::arb_plan_type; + +    use super::query_variable_name; + +    proptest! { +        #[test] +        fn variable_names_are_reproducible(variable_name: String, variable_type in arb_plan_type()) { +            let a = query_variable_name(&variable_name, &variable_type); +            let b = query_variable_name(&variable_name, &variable_type); +            prop_assert_eq!(a, b) +        } +    } + +    proptest! { +        #[test] +        fn variable_names_are_distinct_when_input_names_are_distinct( +            (name_a, name_b) in (any::(), any::()).prop_filter("names are equal", |(a, b)| a != b), +            variable_type in arb_plan_type() +        ) { +            let a = query_variable_name(&name_a, &variable_type); +            let b = query_variable_name(&name_b, &variable_type); +            prop_assert_ne!(a, b) +        } +    } + +    proptest! { +        #[test] +        fn variable_names_are_distinct_when_types_are_distinct( +            variable_name: String, +            (type_a, type_b) in (arb_plan_type(), arb_plan_type()).prop_filter("types are equal", |(a, b)| a != b) +        ) { +            let a = query_variable_name(&variable_name, &type_a); +            let b = query_variable_name(&variable_name, &type_b); +            prop_assert_ne!(a, b) +        } +    } + +    proptest!
{ + #[test] + fn variable_names_are_valid_for_mongodb_expressions(variable_name: String, variable_type in arb_plan_type()) { + static VALID_NAME: Lazy = + Lazy::new(|| Regex::new(r"^[a-z\P{ascii}][_a-zA-Z0-9\P{ascii}]*$").unwrap()); + let name = query_variable_name(&variable_name, &variable_type); + prop_assert!(VALID_NAME.is_match(&name)) + } + } +} diff --git a/crates/mongodb-agent-common/src/query/relations.rs b/crates/mongodb-agent-common/src/query/relations.rs index c700a653..22a162b0 100644 --- a/crates/mongodb-agent-common/src/query/relations.rs +++ b/crates/mongodb-agent-common/src/query/relations.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use itertools::Itertools as _; use mongodb::bson::{doc, Bson, Document}; -use ndc_query_plan::{Scope, VariableSet}; +use ndc_query_plan::Scope; use crate::mongo_query_plan::{MongoConfiguration, Query, QueryPlan}; use crate::mongodb::sanitize::safe_name; @@ -22,7 +22,6 @@ type Result = std::result::Result; /// each sub-query in the plan. pub fn pipeline_for_relations( config: &MongoConfiguration, - variables: Option<&VariableSet>, query_plan: &QueryPlan, ) -> Result { let QueryPlan { query, .. } = query_plan; @@ -40,7 +39,6 @@ pub fn pipeline_for_relations( // Recursively build pipeline according to relation query let lookup_pipeline = pipeline_for_non_foreach( config, - variables, &QueryPlan { query: relationship.query.clone(), collection: relationship.target_collection.clone(), @@ -125,7 +123,7 @@ fn multiple_column_mapping_lookup( .keys() .map(|local_field| { Ok(( - variable(local_field)?, + variable(local_field), Bson::String(format!("${}", safe_name(local_field)?.into_owned())), )) }) @@ -145,7 +143,7 @@ fn multiple_column_mapping_lookup( .into_iter() .map(|(local_field, remote_field)| { Ok(doc! { "$eq": [ - format!("$${}", variable(local_field)?), + format!("$${}", variable(local_field)), format!("${}", safe_name(remote_field)?) 
] }) }) @@ -400,16 +398,16 @@ mod tests { "$lookup": { "from": "students", "let": { - "v_year": "$year", - "v_title": "$title", + "year": "$year", + "title": "$title", "scope_root": "$$ROOT", }, "pipeline": [ { "$match": { "$expr": { "$and": [ - { "$eq": ["$$v_title", "$class_title"] }, - { "$eq": ["$$v_year", "$year"] }, + { "$eq": ["$$title", "$class_title"] }, + { "$eq": ["$$year", "$year"] }, ], } }, }, diff --git a/crates/mongodb-agent-common/src/query/response.rs b/crates/mongodb-agent-common/src/query/response.rs index 3149b7b1..850813ca 100644 --- a/crates/mongodb-agent-common/src/query/response.rs +++ b/crates/mongodb-agent-common/src/query/response.rs @@ -39,18 +39,6 @@ pub enum QueryResponseError { type Result = std::result::Result; -// These structs describe possible shapes of data returned by MongoDB query plans - -#[derive(Debug, Deserialize)] -struct ResponseForVariableSetsRowsOnly { - row_sets: Vec>, -} - -#[derive(Debug, Deserialize)] -struct ResponseForVariableSetsAggregates { - row_sets: Vec, -} - #[derive(Debug, Deserialize)] struct BsonRowSet { #[serde(default)] @@ -66,27 +54,14 @@ pub fn serialize_query_response( ) -> Result { let collection_name = &query_plan.collection; - // If the query request specified variable sets then we should have gotten a single document - // from MongoDB with fields for multiple sets of results - one for each set of variables. 
- let row_sets = if query_plan.has_variables() && query_plan.query.has_aggregates() { - let responses: ResponseForVariableSetsAggregates = - parse_single_document(response_documents)?; - responses - .row_sets + let row_sets = if query_plan.has_variables() { + response_documents .into_iter() - .map(|row_set| { + .map(|document| { + let row_set = bson::from_document(document)?; serialize_row_set_with_aggregates(&[collection_name], &query_plan.query, row_set) }) .try_collect() - } else if query_plan.variables.is_some() { - let responses: ResponseForVariableSetsRowsOnly = parse_single_document(response_documents)?; - responses - .row_sets - .into_iter() - .map(|row_set| { - serialize_row_set_rows_only(&[collection_name], &query_plan.query, row_set) - }) - .try_collect() } else if query_plan.query.has_aggregates() { let row_set = parse_single_document(response_documents)?; Ok(vec![serialize_row_set_with_aggregates( diff --git a/crates/ndc-query-plan/src/lib.rs b/crates/ndc-query-plan/src/lib.rs index 7ce74bd1..1bfb5e3a 100644 --- a/crates/ndc-query-plan/src/lib.rs +++ b/crates/ndc-query-plan/src/lib.rs @@ -1,6 +1,7 @@ mod plan_for_query_request; mod query_plan; mod type_system; +pub mod vec_set; pub use plan_for_query_request::{ plan_for_query_request, @@ -12,6 +13,6 @@ pub use query_plan::{ Aggregate, AggregateFunctionDefinition, ComparisonOperatorDefinition, ComparisonTarget, ComparisonValue, ConnectorTypes, ExistsInCollection, Expression, Field, NestedArray, NestedField, NestedObject, OrderBy, OrderByElement, OrderByTarget, Query, QueryPlan, - Relationship, Relationships, Scope, VariableSet, + Relationship, Relationships, Scope, VariableSet, VariableTypes, }; pub use type_system::{inline_object_types, ObjectType, Type}; diff --git a/crates/ndc-query-plan/src/plan_for_query_request/mod.rs b/crates/ndc-query-plan/src/plan_for_query_request/mod.rs index 766a7a89..f628123c 100644 --- a/crates/ndc-query-plan/src/plan_for_query_request/mod.rs +++ 
b/crates/ndc-query-plan/src/plan_for_query_request/mod.rs @@ -17,6 +17,7 @@ use indexmap::IndexMap; use itertools::Itertools; use ndc::{ExistsInCollection, QueryRequest}; use ndc_models as ndc; +use query_plan_state::QueryPlanInfo; use self::{ helpers::{find_object_field, find_object_field_path, lookup_relationship}, @@ -42,14 +43,38 @@ pub fn plan_for_query_request( )?; query.scope = Some(Scope::Root); - let unrelated_collections = plan_state.into_unrelated_collections(); + let QueryPlanInfo { + unrelated_joins, + variable_types, + } = plan_state.into_query_plan_info(); + + // If there are variables that don't have corresponding entries in the variable_types map that + // means that those variables were not observed in the query. Filter them out because we don't + // need them, and we don't want users to have to deal with variables with unknown types. + let variables = request.variables.map(|variable_sets| { + variable_sets + .into_iter() + .map(|variable_set| { + variable_set + .into_iter() + .filter(|(var_name, _)| { + variable_types + .get(var_name) + .map(|types| !types.is_empty()) + .unwrap_or(false) + }) + .collect() + }) + .collect() + }); Ok(QueryPlan { collection: request.collection, arguments: request.arguments, query, - variables: request.variables, - unrelated_collections, + variables, + variable_types, + unrelated_collections: unrelated_joins, }) } @@ -559,10 +584,13 @@ fn plan_for_comparison_value( value, value_type: expected_type, }), - ndc::ComparisonValue::Variable { name } => Ok(plan::ComparisonValue::Variable { - name, - variable_type: expected_type, - }), + ndc::ComparisonValue::Variable { name } => { + plan_state.register_variable_use(&name, expected_type.clone()); + Ok(plan::ComparisonValue::Variable { + name, + variable_type: expected_type, + }) + } } } diff --git a/crates/ndc-query-plan/src/plan_for_query_request/plan_test_helpers/mod.rs b/crates/ndc-query-plan/src/plan_for_query_request/plan_test_helpers/mod.rs index 45da89fe..31cee380 
100644 --- a/crates/ndc-query-plan/src/plan_for_query_request/plan_test_helpers/mod.rs +++ b/crates/ndc-query-plan/src/plan_for_query_request/plan_test_helpers/mod.rs @@ -122,7 +122,7 @@ impl NamedEnum for ComparisonOperator { } } -#[derive(Clone, Copy, Debug, PartialEq, Sequence)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Sequence)] pub enum ScalarType { Bool, Date, diff --git a/crates/ndc-query-plan/src/plan_for_query_request/query_plan_state.rs b/crates/ndc-query-plan/src/plan_for_query_request/query_plan_state.rs index 5ea76bb0..e5a4c78c 100644 --- a/crates/ndc-query-plan/src/plan_for_query_request/query_plan_state.rs +++ b/crates/ndc-query-plan/src/plan_for_query_request/query_plan_state.rs @@ -9,8 +9,9 @@ use ndc_models as ndc; use crate::{ plan_for_query_request::helpers::lookup_relationship, - query_plan::{Scope, UnrelatedJoin}, - Query, QueryContext, QueryPlanError, Relationship, + query_plan::{Scope, UnrelatedJoin, VariableTypes}, + vec_set::VecSet, + ConnectorTypes, Query, QueryContext, QueryPlanError, Relationship, Type, }; use super::unify_relationship_references::unify_relationship_references; @@ -32,6 +33,7 @@ pub struct QueryPlanState<'a, T: QueryContext> { unrelated_joins: Rc>>>, relationship_name_counter: Rc>, scope_name_counter: Rc>, + variable_types: Rc>>, } impl QueryPlanState<'_, T> { @@ -47,6 +49,7 @@ impl QueryPlanState<'_, T> { unrelated_joins: Rc::new(RefCell::new(Default::default())), relationship_name_counter: Rc::new(Cell::new(0)), scope_name_counter: Rc::new(Cell::new(0)), + variable_types: Rc::new(RefCell::new(Default::default())), } } @@ -62,6 +65,7 @@ impl QueryPlanState<'_, T> { unrelated_joins: self.unrelated_joins.clone(), relationship_name_counter: self.relationship_name_counter.clone(), scope_name_counter: self.scope_name_counter.clone(), + variable_types: self.variable_types.clone(), } } @@ -81,6 +85,13 @@ impl QueryPlanState<'_, T> { let ndc_relationship = lookup_relationship(self.collection_relationships, 
&ndc_relationship_name)?; + for argument in arguments.values() { + if let RelationshipArgument::Variable { name } = argument { + // TODO: Is there a way to infer a type here? + self.register_variable_use_of_unknown_type(name) + } + } + let relationship = Relationship { column_mapping: ndc_relationship.column_mapping.clone(), relationship_type: ndc_relationship.relationship_type, @@ -141,6 +152,36 @@ impl QueryPlanState<'_, T> { key } + /// It's important to call this for every use of a variable encountered when building + /// a [crate::QueryPlan] so we can capture types for each variable. + pub fn register_variable_use( + &mut self, + variable_name: &str, + expected_type: Type, + ) { + self.register_variable_use_helper(variable_name, Some(expected_type)) + } + + pub fn register_variable_use_of_unknown_type(&mut self, variable_name: &str) { + self.register_variable_use_helper(variable_name, None) + } + + fn register_variable_use_helper( + &mut self, + variable_name: &str, + expected_type: Option>, + ) { + let mut type_map = self.variable_types.borrow_mut(); + match type_map.get_mut(variable_name) { + None => { + type_map.insert(variable_name.to_string(), VecSet::singleton(expected_type)); + } + Some(entry) => { + entry.insert(expected_type); + } + } + } + /// Use this for subquery plans to get the relationships for each sub-query pub fn into_relationships(self) -> BTreeMap> { self.relationships @@ -150,9 +191,12 @@ impl QueryPlanState<'_, T> { self.scope } - /// Use this with the top-level plan to get unrelated joins. 
- pub fn into_unrelated_collections(self) -> BTreeMap> { - self.unrelated_joins.take() + /// Use this with the top-level plan to get unrelated joins and variable types + pub fn into_query_plan_info(self) -> QueryPlanInfo { + QueryPlanInfo { + unrelated_joins: self.unrelated_joins.take(), + variable_types: self.variable_types.take(), + } } fn unique_relationship_name(&mut self, name: impl std::fmt::Display) -> String { @@ -167,3 +211,10 @@ impl QueryPlanState<'_, T> { format!("scope_{count}") } } + +/// Data extracted from [QueryPlanState] for use in building top-level [crate::QueryPlan] +#[derive(Debug)] +pub struct QueryPlanInfo { + pub unrelated_joins: BTreeMap>, + pub variable_types: VariableTypes, +} diff --git a/crates/ndc-query-plan/src/plan_for_query_request/tests.rs b/crates/ndc-query-plan/src/plan_for_query_request/tests.rs index a9e40b39..82472f1b 100644 --- a/crates/ndc-query-plan/src/plan_for_query_request/tests.rs +++ b/crates/ndc-query-plan/src/plan_for_query_request/tests.rs @@ -90,6 +90,7 @@ fn translates_query_request_relationships() -> Result<(), anyhow::Error> { collection: "schools".to_owned(), arguments: Default::default(), variables: None, + variable_types: Default::default(), unrelated_collections: Default::default(), query: Query { predicate: Some(Expression::And { @@ -498,6 +499,7 @@ fn translates_root_column_references() -> Result<(), anyhow::Error> { .into(), arguments: Default::default(), variables: Default::default(), + variable_types: Default::default(), }; assert_eq!(query_plan, expected); @@ -546,6 +548,7 @@ fn translates_aggregate_selections() -> Result<(), anyhow::Error> { }, arguments: Default::default(), variables: Default::default(), + variable_types: Default::default(), unrelated_collections: Default::default(), }; @@ -731,6 +734,7 @@ fn translates_relationships_in_fields_predicates_and_orderings() -> Result<(), a }, arguments: Default::default(), variables: Default::default(), + variable_types: Default::default(), 
unrelated_collections: Default::default(), }; @@ -840,6 +844,7 @@ fn translates_nested_fields() -> Result<(), anyhow::Error> { }, arguments: Default::default(), variables: Default::default(), + variable_types: Default::default(), unrelated_collections: Default::default(), }; @@ -934,6 +939,7 @@ fn translates_predicate_referencing_field_of_related_collection() -> anyhow::Res }, arguments: Default::default(), variables: Default::default(), + variable_types: Default::default(), unrelated_collections: Default::default(), }; diff --git a/crates/ndc-query-plan/src/query_plan.rs b/crates/ndc-query-plan/src/query_plan.rs index 750fc4f5..49200ff6 100644 --- a/crates/ndc-query-plan/src/query_plan.rs +++ b/crates/ndc-query-plan/src/query_plan.rs @@ -7,22 +7,33 @@ use ndc_models::{ Argument, OrderDirection, RelationshipArgument, RelationshipType, UnaryComparisonOperator, }; -use crate::Type; +use crate::{vec_set::VecSet, Type}; pub trait ConnectorTypes { - type ScalarType: Clone + Debug + PartialEq; + type ScalarType: Clone + Debug + PartialEq + Eq; type AggregateFunction: Clone + Debug + PartialEq; type ComparisonOperator: Clone + Debug + PartialEq; } #[derive(Derivative)] -#[derivative(Clone(bound = ""), Debug(bound = ""), PartialEq(bound = ""))] +#[derivative( + Clone(bound = ""), + Debug(bound = ""), + PartialEq(bound = "T::ScalarType: PartialEq") +)] pub struct QueryPlan { pub collection: String, pub query: Query, pub arguments: BTreeMap, pub variables: Option>, + /// Types for values from the `variables` map as inferred by usages in the query request. It is + /// possible for the same variable to be used in multiple contexts with different types. This + /// map provides sets of all observed types. + /// + /// The observed type may be `None` if the type of a variable use could not be inferred. 
+ pub variable_types: VariableTypes, + // TODO: type for unrelated collection pub unrelated_collections: BTreeMap>, } @@ -33,8 +44,9 @@ impl QueryPlan { } } -pub type VariableSet = BTreeMap; pub type Relationships = BTreeMap>; +pub type VariableSet = BTreeMap; +pub type VariableTypes = BTreeMap>>>; #[derive(Derivative)] #[derivative( diff --git a/crates/ndc-query-plan/src/vec_set.rs b/crates/ndc-query-plan/src/vec_set.rs new file mode 100644 index 00000000..b7a28640 --- /dev/null +++ b/crates/ndc-query-plan/src/vec_set.rs @@ -0,0 +1,80 @@ +/// Set implementation that only requires an [Eq] implementation on its value type +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct VecSet { + items: Vec, +} + +impl VecSet { + pub fn new() -> Self { + VecSet { items: Vec::new() } + } + + pub fn singleton(value: T) -> Self { + VecSet { items: vec![value] } + } + + /// If the value does not exist in the set, inserts it and returns `true`. If the value does + /// exist returns `false`, and leaves the set unchanged. 
+ pub fn insert(&mut self, value: T) -> bool + where + T: Eq, + { + if self.items.iter().any(|v| *v == value) { + false + } else { + self.items.push(value); + true + } + } + + pub fn is_empty(&self) -> bool { + self.items.is_empty() + } + + pub fn iter(&self) -> std::slice::Iter<'_, T> { + self.items.iter() + } +} + +impl FromIterator for VecSet { + fn from_iter>(iter: I) -> Self { + VecSet { + items: Vec::from_iter(iter), + } + } +} + +impl From<[T; N]> for VecSet { + fn from(value: [T; N]) -> Self { + VecSet { + items: value.into(), + } + } +} + +impl IntoIterator for VecSet { + type Item = T; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.items.into_iter() + } +} + +impl<'a, T> IntoIterator for &'a VecSet { + type Item = &'a T; + type IntoIter = std::slice::Iter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + self.items.iter() + } +} + +impl<'a, T> IntoIterator for &'a mut VecSet { + type Item = &'a mut T; + type IntoIter = std::slice::IterMut<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + self.items.iter_mut() + } +} diff --git a/crates/test-helpers/Cargo.toml b/crates/test-helpers/Cargo.toml index 744d22ce..3e22d819 100644 --- a/crates/test-helpers/Cargo.toml +++ b/crates/test-helpers/Cargo.toml @@ -6,6 +6,7 @@ version.workspace = true [dependencies] configuration = { path = "../configuration" } mongodb-support = { path = "../mongodb-support" } +ndc-query-plan = { path = "../ndc-query-plan" } ndc-test-helpers = { path = "../ndc-test-helpers" } enum-iterator = "^2.0.0" diff --git a/crates/test-helpers/src/arb_plan_type.rs b/crates/test-helpers/src/arb_plan_type.rs new file mode 100644 index 00000000..b878557a --- /dev/null +++ b/crates/test-helpers/src/arb_plan_type.rs @@ -0,0 +1,27 @@ +use configuration::MongoScalarType; +use ndc_query_plan::{ObjectType, Type}; +use proptest::{collection::btree_map, prelude::*}; + +use crate::arb_type::arb_bson_scalar_type; + +pub fn arb_plan_type() -> impl Strategy> { + let 
leaf = arb_plan_scalar_type().prop_map(Type::Scalar); + leaf.prop_recursive(3, 10, 10, |inner| { + prop_oneof![ + inner.clone().prop_map(|t| Type::ArrayOf(Box::new(t))), + inner.clone().prop_map(|t| Type::Nullable(Box::new(t))), + ( + any::>(), + btree_map(any::(), inner, 1..=10) + ) + .prop_map(|(name, fields)| Type::Object(ObjectType { name, fields })) + ] + }) +} + +fn arb_plan_scalar_type() -> impl Strategy { + prop_oneof![ + arb_bson_scalar_type().prop_map(MongoScalarType::Bson), + Just(MongoScalarType::ExtendedJSON) + ] +} diff --git a/crates/test-helpers/src/lib.rs b/crates/test-helpers/src/lib.rs index 751ce2d2..be884004 100644 --- a/crates/test-helpers/src/lib.rs +++ b/crates/test-helpers/src/lib.rs @@ -1,5 +1,7 @@ pub mod arb_bson; +mod arb_plan_type; pub mod arb_type; pub use arb_bson::{arb_bson, arb_bson_with_options, ArbBsonOptions}; +pub use arb_plan_type::arb_plan_type; pub use arb_type::arb_type; diff --git a/fixtures/mongodb/chinook/chinook-import.sh b/fixtures/mongodb/chinook/chinook-import.sh index 66f4aa09..32fbd7d5 100755 --- a/fixtures/mongodb/chinook/chinook-import.sh +++ b/fixtures/mongodb/chinook/chinook-import.sh @@ -41,4 +41,6 @@ importCollection "Playlist" importCollection "PlaylistTrack" importCollection "Track" +$MONGO_SH "$DATABASE_NAME" "$FIXTURES/indexes.js" + echo "✅ Sample Chinook data imported..." 
diff --git a/fixtures/mongodb/chinook/indexes.js b/fixtures/mongodb/chinook/indexes.js new file mode 100644 index 00000000..2727a1ed --- /dev/null +++ b/fixtures/mongodb/chinook/indexes.js @@ -0,0 +1,20 @@ +db.Album.createIndex({ AlbumId: 1 }) +db.Album.createIndex({ ArtistId: 1 }) +db.Artist.createIndex({ ArtistId: 1 }) +db.Customer.createIndex({ CustomerId: 1 }) +db.Customer.createIndex({ SupportRepId: 1 }) +db.Employee.createIndex({ EmployeeId: 1 }) +db.Employee.createIndex({ ReportsTo: 1 }) +db.Genre.createIndex({ GenreId: 1 }) +db.Invoice.createIndex({ CustomerId: 1 }) +db.Invoice.createIndex({ InvoiceId: 1 }) +db.InvoiceLine.createIndex({ InvoiceId: 1 }) +db.InvoiceLine.createIndex({ TrackId: 1 }) +db.MediaType.createIndex({ MediaTypeId: 1 }) +db.Playlist.createIndex({ PlaylistId: 1 }) +db.PlaylistTrack.createIndex({ PlaylistId: 1 }) +db.PlaylistTrack.createIndex({ TrackId: 1 }) +db.Track.createIndex({ AlbumId: 1 }) +db.Track.createIndex({ GenreId: 1 }) +db.Track.createIndex({ MediaTypeId: 1 }) +db.Track.createIndex({ TrackId: 1 }) diff --git a/fixtures/mongodb/sample_import.sh b/fixtures/mongodb/sample_import.sh index aa7d2c91..21340366 100755 --- a/fixtures/mongodb/sample_import.sh +++ b/fixtures/mongodb/sample_import.sh @@ -32,6 +32,7 @@ mongoimport --db sample_mflix --collection movies --file "$FIXTURES"/sample_mfli mongoimport --db sample_mflix --collection sessions --file "$FIXTURES"/sample_mflix/sessions.json mongoimport --db sample_mflix --collection theaters --file "$FIXTURES"/sample_mflix/theaters.json mongoimport --db sample_mflix --collection users --file "$FIXTURES"/sample_mflix/users.json +$MONGO_SH sample_mflix "$FIXTURES/sample_mflix/indexes.js" echo "✅ Mflix sample data imported..." 
# chinook diff --git a/fixtures/mongodb/sample_mflix/indexes.js b/fixtures/mongodb/sample_mflix/indexes.js new file mode 100644 index 00000000..1fb4807c --- /dev/null +++ b/fixtures/mongodb/sample_mflix/indexes.js @@ -0,0 +1,3 @@ +db.comments.createIndex({ movie_id: 1 }) +db.comments.createIndex({ email: 1 }) +db.users.createIndex({ email: 1 })