diff --git a/CHANGELOG.md b/CHANGELOG.md index 9cb8ed80..15e71b49 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ This changelog documents the changes between release versions. - Fix bug with operator lookup when filtering on nested fields ([#82](https://github.com/hasura/ndc-mongodb/pull/82)) - Rework query plans for requests with variable sets to allow use of indexes ([#83](https://github.com/hasura/ndc-mongodb/pull/83)) - Fix: error when requesting query plan if MongoDB is target of a remote join ([#83](https://github.com/hasura/ndc-mongodb/pull/83)) +- Fix: count aggregates return 0 instead of null if no rows match ([#85](https://github.com/hasura/ndc-mongodb/pull/85)) - Breaking change: remote joins no longer work in MongoDB v5 ([#83](https://github.com/hasura/ndc-mongodb/pull/83)) - Add configuration option to opt into "relaxed" mode for Extended JSON outputs ([#84](https://github.com/hasura/ndc-mongodb/pull/84)) diff --git a/crates/mongodb-agent-common/src/aggregation_function.rs b/crates/mongodb-agent-common/src/aggregation_function.rs index c22fdc0e..bc1cc264 100644 --- a/crates/mongodb-agent-common/src/aggregation_function.rs +++ b/crates/mongodb-agent-common/src/aggregation_function.rs @@ -31,4 +31,14 @@ impl AggregationFunction { aggregate_function: s.to_owned(), }) } + + pub fn is_count(self) -> bool { + match self { + A::Avg => false, + A::Count => true, + A::Min => false, + A::Max => false, + A::Sum => false, + } + } } diff --git a/crates/mongodb-agent-common/src/query/foreach.rs b/crates/mongodb-agent-common/src/query/foreach.rs index e11b7d2e..217019a8 100644 --- a/crates/mongodb-agent-common/src/query/foreach.rs +++ b/crates/mongodb-agent-common/src/query/foreach.rs @@ -240,10 +240,17 @@ mod tests { } }, { "$replaceWith": { "aggregates": { - "count": { "$getField": { - "field": "result", - "input": { "$first": { "$getField": { "$literal": "count" } } } - } }, + "count": { + "$ifNull": [ + { + "$getField": { + "field": "result", + "input": { "$first": { "$getField": { "$literal": "count" } } } + } + }, + 0, + ] + }, }, "rows": "$__ROWS__", } }, @@ -344,10 +351,17 @@ mod tests { } }, { "$replaceWith": { "aggregates": { - "count": { "$getField": { - "field": "result", - "input": { "$first": { "$getField": { "$literal": "count" } } } - } }, + "count": { + "$ifNull": [ + { + "$getField": { + "field": "result", + "input": { "$first": { "$getField": { "$literal": "count" } } } + } + }, + 0, + ] + }, }, } }, ] diff --git a/crates/mongodb-agent-common/src/query/mod.rs b/crates/mongodb-agent-common/src/query/mod.rs index 2a4f82b3..5c4e5dca 100644 --- a/crates/mongodb-agent-common/src/query/mod.rs +++ b/crates/mongodb-agent-common/src/query/mod.rs @@ -131,10 +131,17 @@ mod tests { "field": "result", "input": { "$first": { "$getField": { "$literal": "avg" } } }, } }, - "count": { "$getField": { - "field": "result", - "input": { "$first": { "$getField": { "$literal": "count" } } }, - } }, + "count": { + "$ifNull": [ + { + "$getField": { + "field": "result", + "input": { "$first": { "$getField": { "$literal": "count" } } }, + } + }, + 0, + ] + }, }, }, }, diff --git a/crates/mongodb-agent-common/src/query/pipeline.rs b/crates/mongodb-agent-common/src/query/pipeline.rs index ca82df78..745a608c 100644 --- a/crates/mongodb-agent-common/src/query/pipeline.rs +++ b/crates/mongodb-agent-common/src/query/pipeline.rs @@ -169,22 +169,28 @@ fn facet_pipelines_for_query( let aggregate_selections: bson::Document = aggregates .iter() .flatten() - .map(|(key, _aggregate)| { + .map(|(key, aggregate)| { // The facet result for each aggregate is an array containing a single document which // has a field called `result`. This code selects each facet result by name, and pulls // out the `result` value. - ( - // TODO: Is there a way we can prevent potential code injection in the use of `key` - // here? - key.clone(), + let value_expr = doc! { + "$getField": { + "field": RESULT_FIELD, // evaluates to the value of this field + "input": { "$first": get_field(key) }, // field is accessed from this document + }, + }; + + // Matching SQL semantics, if a **count** aggregation does not match any rows we want + // to return zero. Other aggregations should return null. + let value_expr = if is_count(aggregate) { doc! { - "$getField": { - "field": RESULT_FIELD, // evaluates to the value of this field - "input": { "$first": get_field(key) }, // field is accessed from this document - }, + "$ifNull": [value_expr, 0], } - .into(), - ) + } else { + value_expr + }; + + (key.clone(), value_expr.into()) }) .collect(); @@ -209,6 +215,14 @@ fn facet_pipelines_for_query( Ok((facet_pipelines, selection)) } +fn is_count(aggregate: &Aggregate) -> bool { + match aggregate { + Aggregate::ColumnCount { .. } => true, + Aggregate::StarCount { .. } => true, + Aggregate::SingleColumn { function, .. } => function.is_count(), + } +} + fn pipeline_for_aggregate( aggregate: Aggregate, limit: Option, @@ -240,20 +254,7 @@ fn pipeline_for_aggregate( bson::doc! { &column: { "$exists": true, "$ne": null } }, )), limit.map(Stage::Limit), - Some(Stage::Group { - key_expression: field_ref(&column), - accumulators: [(RESULT_FIELD.to_string(), Accumulator::Count)].into(), - }), - Some(Stage::Group { - key_expression: Bson::Null, - // Sums field values from the `result` field of the previous stage, and writes - // a new field which is also called `result`. - accumulators: [( - RESULT_FIELD.to_string(), - Accumulator::Sum(field_ref(RESULT_FIELD)), - )] - .into(), - }), + Some(Stage::Count(RESULT_FIELD.to_string())), ] .into_iter() .flatten(), diff --git a/crates/mongodb-agent-common/src/query/relations.rs b/crates/mongodb-agent-common/src/query/relations.rs index 22a162b0..bcbee0dc 100644 --- a/crates/mongodb-agent-common/src/query/relations.rs +++ b/crates/mongodb-agent-common/src/query/relations.rs @@ -636,10 +636,15 @@ mod tests { "$replaceWith": { "aggregates": { "aggregate_count": { - "$getField": { - "field": "result", - "input": { "$first": { "$getField": { "$literal": "aggregate_count" } } }, - }, + "$ifNull": [ + { + "$getField": { + "field": "result", + "input": { "$first": { "$getField": { "$literal": "aggregate_count" } } }, + }, + }, + 0, + ] }, }, },