Skip to content

Commit

Permalink
Added support for minimum_should_match in quickwit es API
Browse files Browse the repository at this point in the history
Fixing REST compatibility tests.

Closes #4828
  • Loading branch information
fulmicoton committed Oct 15, 2024
1 parent f5da576 commit d7a2c5b
Show file tree
Hide file tree
Showing 12 changed files with 479 additions and 63 deletions.
193 changes: 188 additions & 5 deletions quickwit/quickwit-query/src/elastic_query_dsl/bool_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,81 @@ pub struct BoolQuery {
filter: Vec<ElasticQueryDslInner>,
#[serde(default)]
pub boost: Option<NotNaNf32>,
#[serde(default)]
pub minimum_should_match: Option<MinimumShouldMatch>,
}

/// Raw `minimum_should_match` value of a bool query, as received in the request body.
///
/// The value is deserialized untagged: it is either a string or an integer.
/// Use `resolve` to turn it into a concrete requirement for a given number of
/// `should` clauses.
#[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
#[serde(untagged)]
pub enum MinimumShouldMatch {
/// String form. `resolve` currently only accepts a percentage such as `"35%"`.
Str(String),
/// Integer form. `resolve` reads a non-negative value as the number of required
/// should clauses, and a negative value as the number of should clauses that are
/// allowed to be missing.
Int(isize),
}

impl MinimumShouldMatch {
fn resolve(&self, num_should_clauses: usize) -> anyhow::Result<MinimumShouldMatchResolved> {
match self {
MinimumShouldMatch::Str(minimum_should_match_dsl) => {
let Some(percentage) = parse_percentage(minimum_should_match_dsl) else {
anyhow::bail!(
"Unsupported minimum should match dsl {}. quickwit currently only \
supports the format '35%'",
minimum_should_match_dsl
);
};
let min_should_match = num_should_clauses * percentage as usize / 100;
Ok(MinimumShouldMatchResolved::Min(min_should_match))
}
MinimumShouldMatch::Int(neg_num_missing_should_clauses)
if *neg_num_missing_should_clauses < 0 =>
{
let num_missing_should_clauses = -neg_num_missing_should_clauses as usize;
if num_missing_should_clauses >= num_should_clauses {
Ok(MinimumShouldMatchResolved::Unspecified)
} else {
Ok(MinimumShouldMatchResolved::Min(
num_should_clauses - num_missing_should_clauses,
))
}
}
MinimumShouldMatch::Int(num_required_should_clauses) => {
let num_required_should_clauses: usize = *num_required_should_clauses as usize;
if num_required_should_clauses > num_should_clauses {
Ok(MinimumShouldMatchResolved::NoMatch)
} else {
Ok(MinimumShouldMatchResolved::Min(
num_required_should_clauses as usize,
))
}
}
}
}
}

/// Result of resolving a `MinimumShouldMatch` against a concrete number of should clauses.
#[derive(Deserialize, Debug, Copy, Clone, Eq, PartialEq)]
enum MinimumShouldMatchResolved {
/// No minimum is set; `convert_to_query_ast` maps this to `None`.
Unspecified,
/// Minimum number of should clauses, forwarded as the `minimum_should_match` of the
/// query AST.
Min(usize),
/// More should clauses are required than exist; `convert_to_query_ast` returns
/// `QueryAst::MatchNone` in this case.
NoMatch,
}

/// Parses a percentage written as an unsigned integer followed by `%` (e.g. `"35%"`).
///
/// Returns `None` when the `%` suffix is missing, when the part before it does not
/// parse as a `u32`, or when the value is greater than 100.
fn parse_percentage(s: &str) -> Option<u32> {
    s.strip_suffix('%')
        .and_then(|digits| digits.parse::<u32>().ok())
        .filter(|&percentage| percentage <= 100)
}

impl BoolQuery {
    /// Resolves the optional `minimum_should_match` setting against the number of
    /// `should` clauses of this query.
    ///
    /// Returns `Unspecified` when the setting is absent, and propagates any error
    /// returned by `MinimumShouldMatch::resolve`.
    fn resolve_minimum_should_match(&self) -> anyhow::Result<MinimumShouldMatchResolved> {
        match self.minimum_should_match.as_ref() {
            Some(setting) => setting.resolve(self.should.len()),
            None => Ok(MinimumShouldMatchResolved::Unspecified),
        }
    }
}

impl BoolQuery {
Expand All @@ -57,6 +132,7 @@ impl BoolQuery {
should: children,
filter: Vec::new(),
boost: None,
minimum_should_match: None,
}
}
}
Expand All @@ -70,11 +146,25 @@ fn convert_vec(query_dsls: Vec<ElasticQueryDslInner>) -> anyhow::Result<Vec<Quer

impl ConvertibleToQueryAst for BoolQuery {
fn convert_to_query_ast(self) -> anyhow::Result<QueryAst> {
let minimum_should_match_resolved = self.resolve_minimum_should_match()?;
let must = convert_vec(self.must)?;
let must_not = convert_vec(self.must_not)?;
let should = convert_vec(self.should)?;
let filter = convert_vec(self.filter)?;

let minimum_should_match_opt = match minimum_should_match_resolved {
MinimumShouldMatchResolved::Unspecified => None,
MinimumShouldMatchResolved::Min(minimum_should_match) => Some(minimum_should_match),
MinimumShouldMatchResolved::NoMatch => {
return Ok(QueryAst::MatchNone);
}
};
let bool_query_ast = query_ast::BoolQuery {
must: convert_vec(self.must)?,
must_not: convert_vec(self.must_not)?,
should: convert_vec(self.should)?,
filter: convert_vec(self.filter)?,
must,
must_not,
should,
filter,
minimum_should_match: minimum_should_match_opt,
};
Ok(bool_query_ast.into())
}
Expand All @@ -88,8 +178,13 @@ impl From<BoolQuery> for ElasticQueryDslInner {

#[cfg(test)]
mod tests {
use crate::elastic_query_dsl::bool_query::BoolQuery;
use super::parse_percentage;
use crate::elastic_query_dsl::bool_query::{
BoolQuery, MinimumShouldMatch, MinimumShouldMatchResolved,
};
use crate::elastic_query_dsl::term_query::term_query_from_field_value;
use crate::elastic_query_dsl::ConvertibleToQueryAst;
use crate::query_ast::QueryAst;

#[test]
fn test_dsl_bool_query_deserialize_simple() {
Expand All @@ -111,6 +206,7 @@ mod tests {
should: Vec::new(),
filter: Vec::new(),
boost: None,
minimum_should_match: None
}
);
}
Expand All @@ -130,6 +226,7 @@ mod tests {
should: Vec::new(),
filter: vec![term_query_from_field_value("product_id", "2").into(),],
boost: None,
minimum_should_match: None,
}
);
}
Expand All @@ -152,7 +249,93 @@ mod tests {
should: Vec::new(),
filter: Vec::new(),
boost: None,
minimum_should_match: None,
}
);
}

#[test]
fn test_dsl_bool_query_deserialize_minimum_should_match() {
    // A negative integer must survive deserialization unchanged.
    let bool_query_json = r#"{
"must": [
{ "term": {"product_id": {"value": "1" }} },
{ "term": {"product_id": {"value": "2" }} }
],
"minimum_should_match": -2
}"#;
    let parsed: super::BoolQuery = serde_json::from_str(bool_query_json).unwrap();
    assert_eq!(
        parsed.minimum_should_match,
        Some(MinimumShouldMatch::Int(-2))
    );
}

#[test]
fn test_dsl_query_with_minimum_should_match() {
    let parsed: BoolQuery = serde_json::from_str(
        r#"{
"should": [
{ "term": {"product_id": {"value": "1" }} },
{ "term": {"product_id": {"value": "2" }} },
{ "term": {"product_id": {"value": "3" }} }
],
"minimum_should_match": 2
}"#,
    )
    .unwrap();

    // The DSL struct keeps the raw setting next to the three should clauses.
    assert_eq!(parsed.should.len(), 3);
    assert_eq!(
        parsed.minimum_should_match,
        Some(super::MinimumShouldMatch::Int(2))
    );

    // Conversion must yield a bool query AST carrying the resolved minimum.
    let converted = match parsed.convert_to_query_ast().unwrap() {
        QueryAst::Bool(bool_ast) => bool_ast,
        _ => panic!(),
    };
    assert_eq!(converted.should.len(), 3);
    assert_eq!(converted.minimum_should_match, Some(2));
}

#[test]
fn test_parse_percentage() {
    // (input, expected) pairs covering the bounds, a missing suffix and malformed digits.
    let cases: [(&str, Option<u32>); 7] = [
        ("10%", Some(10)),
        ("101%", None),
        ("0%", Some(0)),
        ("100%", Some(100)),
        ("-20%", None),
        ("20", None),
        ("20a%", None),
    ];
    for (input, expected) in cases {
        assert_eq!(parse_percentage(input), expected);
    }
}

#[test]
fn test_resolve_minimum_should_match() {
    // Percentage form: 30% of 10 should clauses.
    let thirty_percent = MinimumShouldMatch::Str("30%".to_string());
    assert_eq!(
        thirty_percent.resolve(10).unwrap(),
        MinimumShouldMatchResolved::Min(3)
    );

    // not supported yet
    for unsupported_dsl in ["-30%", "-30!"] {
        let setting = MinimumShouldMatch::Str(unsupported_dsl.to_string());
        assert!(setting.resolve(10).is_err());
    }

    // Integer form, always resolved against 11 should clauses.
    let int_cases = [
        (10, MinimumShouldMatchResolved::Min(10)),
        (-10, MinimumShouldMatchResolved::Min(1)),
        (-12, MinimumShouldMatchResolved::Unspecified),
        (12, MinimumShouldMatchResolved::NoMatch),
    ];
    for (value, expected) in int_cases {
        assert_eq!(MinimumShouldMatch::Int(value).resolve(11).unwrap(), expected);
    }
}
}
3 changes: 3 additions & 0 deletions quickwit/quickwit-query/src/query_ast/bool_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ pub struct BoolQuery {
pub should: Vec<QueryAst>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub filter: Vec<QueryAst>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub minimum_should_match: Option<usize>,
}

impl From<BoolQuery> for QueryAst {
Expand All @@ -65,6 +67,7 @@ impl BuildTantivyAst for BoolQuery {
with_validation: bool,
) -> Result<TantivyQueryAst, InvalidQuery> {
let mut boolean_query = super::tantivy_query_ast::TantivyBoolQuery::default();
boolean_query.minimum_should_match = self.minimum_should_match;
for must in &self.must {
let must_leaf = must.build_tantivy_ast_call(
schema,
Expand Down
2 changes: 2 additions & 0 deletions quickwit/quickwit-query/src/query_ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ impl QueryAst {
must_not,
should,
filter,
minimum_should_match,
}) => {
let must = parse_user_query_in_asts(must, default_search_fields)?;
let must_not = parse_user_query_in_asts(must_not, default_search_fields)?;
Expand All @@ -92,6 +93,7 @@ impl QueryAst {
must_not,
should,
filter,
minimum_should_match,
}
.into())
}
Expand Down
Loading

0 comments on commit d7a2c5b

Please sign in to comment.