Skip to content

Commit 70764a1

Browse files
authored
ISSUE-1088: Fix array_agg wildcard behavior (#1093)
Co-authored-by: Andrew Repp <arepp@cloudflare.com>
1 parent 398a810 commit 70764a1

File tree

4 files changed

+111
-36
lines changed

4 files changed

+111
-36
lines changed

src/ast/mod.rs

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,10 @@ pub enum Expr {
380380
right: Box<Expr>,
381381
},
382382
/// CompositeAccess (postgres) eg: SELECT (information_schema._pg_expandarray(array['i','i'])).n
383-
CompositeAccess { expr: Box<Expr>, key: Ident },
383+
CompositeAccess {
384+
expr: Box<Expr>,
385+
key: Ident,
386+
},
384387
/// `IS FALSE` operator
385388
IsFalse(Box<Expr>),
386389
/// `IS NOT FALSE` operator
@@ -474,7 +477,10 @@ pub enum Expr {
474477
right: Box<Expr>,
475478
},
476479
/// Unary operation e.g. `NOT foo`
477-
UnaryOp { op: UnaryOperator, expr: Box<Expr> },
480+
UnaryOp {
481+
op: UnaryOperator,
482+
expr: Box<Expr>,
483+
},
478484
/// CONVERT a value to a different data type or character encoding. e.g. `CONVERT(foo USING utf8mb4)`
479485
Convert {
480486
/// The expression to convert
@@ -545,7 +551,10 @@ pub enum Expr {
545551
/// ```sql
546552
/// POSITION(<expr> in <expr>)
547553
/// ```
548-
Position { expr: Box<Expr>, r#in: Box<Expr> },
554+
Position {
555+
expr: Box<Expr>,
556+
r#in: Box<Expr>,
557+
},
549558
/// ```sql
550559
/// SUBSTRING(<expr> [FROM <expr>] [FOR <expr>])
551560
/// ```
@@ -589,20 +598,32 @@ pub enum Expr {
589598
/// A literal value, such as string, number, date or NULL
590599
Value(Value),
591600
/// <https://dev.mysql.com/doc/refman/8.0/en/charset-introducer.html>
592-
IntroducedString { introducer: String, value: Value },
601+
IntroducedString {
602+
introducer: String,
603+
value: Value,
604+
},
593605
/// A constant of form `<data_type> 'value'`.
594606
/// This can represent ANSI SQL `DATE`, `TIME`, and `TIMESTAMP` literals (such as `DATE '2020-01-01'`),
595607
/// as well as constants of other types (a non-standard PostgreSQL extension).
596-
TypedString { data_type: DataType, value: String },
608+
TypedString {
609+
data_type: DataType,
610+
value: String,
611+
},
597612
/// Access a map-like object by field (e.g. `column['field']` or `column[4]`
598613
/// Note that depending on the dialect, struct like accesses may be
599614
/// parsed as [`ArrayIndex`](Self::ArrayIndex) or [`MapAccess`](Self::MapAccess)
600615
/// <https://clickhouse.com/docs/en/sql-reference/data-types/map/>
601-
MapAccess { column: Box<Expr>, keys: Vec<Expr> },
616+
MapAccess {
617+
column: Box<Expr>,
618+
keys: Vec<Expr>,
619+
},
602620
/// Scalar function call e.g. `LEFT(foo, 5)`
603621
Function(Function),
604622
/// Aggregate function with filter
605-
AggregateExpressionWithFilter { expr: Box<Expr>, filter: Box<Expr> },
623+
AggregateExpressionWithFilter {
624+
expr: Box<Expr>,
625+
filter: Box<Expr>,
626+
},
606627
/// `CASE [<operand>] WHEN <condition> THEN <result> ... [ELSE <result>] END`
607628
///
608629
/// Note we only recognize a complete single expression as `<condition>`,
@@ -616,7 +637,10 @@ pub enum Expr {
616637
},
617638
/// An exists expression `[ NOT ] EXISTS(SELECT ...)`, used in expressions like
618639
/// `WHERE [ NOT ] EXISTS (SELECT ...)`.
619-
Exists { subquery: Box<Query>, negated: bool },
640+
Exists {
641+
subquery: Box<Query>,
642+
negated: bool,
643+
},
620644
/// A parenthesized subquery `(SELECT ...)`, used in expression like
621645
/// `SELECT (subquery) AS x` or `WHERE (subquery) = x`
622646
Subquery(Box<Query>),
@@ -653,9 +677,15 @@ pub enum Expr {
653677
/// 1 AS A
654678
/// ```
655679
/// [1]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#struct_type
656-
Named { expr: Box<Expr>, name: Ident },
680+
Named {
681+
expr: Box<Expr>,
682+
name: Ident,
683+
},
657684
/// An array index expression e.g. `(ARRAY[1, 2])[1]` or `(current_schemas(FALSE))[1]`
658-
ArrayIndex { obj: Box<Expr>, indexes: Vec<Expr> },
685+
ArrayIndex {
686+
obj: Box<Expr>,
687+
indexes: Vec<Expr>,
688+
},
659689
/// An array expression e.g. `ARRAY[1, 2]`
660690
Array(Array),
661691
/// An interval expression e.g. `INTERVAL '1' YEAR`
@@ -678,6 +708,10 @@ pub enum Expr {
678708
/// `<search modifier>`
679709
opt_search_modifier: Option<SearchModifier>,
680710
},
711+
Wildcard,
712+
/// Qualified wildcard, e.g. `alias.*` or `schema.table.*`.
713+
/// (Same caveats apply to `QualifiedWildcard` as to `Wildcard`.)
714+
QualifiedWildcard(ObjectName),
681715
}
682716

683717
impl fmt::Display for CastFormat {
@@ -704,6 +738,8 @@ impl fmt::Display for Expr {
704738
}
705739
Ok(())
706740
}
741+
Expr::Wildcard => f.write_str("*"),
742+
Expr::QualifiedWildcard(prefix) => write!(f, "{}.*", prefix),
707743
Expr::CompoundIdentifier(s) => write!(f, "{}", display_separated(s, ".")),
708744
Expr::IsTrue(ast) => write!(f, "{ast} IS TRUE"),
709745
Expr::IsNotTrue(ast) => write!(f, "{ast} IS NOT TRUE"),
@@ -4188,6 +4224,16 @@ pub enum FunctionArgExpr {
41884224
Wildcard,
41894225
}
41904226

4227+
impl From<Expr> for FunctionArgExpr {
4228+
fn from(wildcard_expr: Expr) -> Self {
4229+
match wildcard_expr {
4230+
Expr::QualifiedWildcard(prefix) => Self::QualifiedWildcard(prefix),
4231+
Expr::Wildcard => Self::Wildcard,
4232+
expr => Self::Expr(expr),
4233+
}
4234+
}
4235+
}
4236+
41914237
impl fmt::Display for FunctionArgExpr {
41924238
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
41934239
match self {

src/parser/mod.rs

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -161,16 +161,6 @@ pub enum WildcardExpr {
161161
Wildcard,
162162
}
163163

164-
impl From<WildcardExpr> for FunctionArgExpr {
165-
fn from(wildcard_expr: WildcardExpr) -> Self {
166-
match wildcard_expr {
167-
WildcardExpr::Expr(expr) => Self::Expr(expr),
168-
WildcardExpr::QualifiedWildcard(prefix) => Self::QualifiedWildcard(prefix),
169-
WildcardExpr::Wildcard => Self::Wildcard,
170-
}
171-
}
172-
}
173-
174164
impl From<TokenizerError> for ParserError {
175165
fn from(e: TokenizerError) -> Self {
176166
ParserError::TokenizerError(e.to_string())
@@ -734,7 +724,7 @@ impl<'a> Parser<'a> {
734724
}
735725

736726
/// Parse a new expression including wildcard & qualified wildcard
737-
pub fn parse_wildcard_expr(&mut self) -> Result<WildcardExpr, ParserError> {
727+
pub fn parse_wildcard_expr(&mut self) -> Result<Expr, ParserError> {
738728
let index = self.index;
739729

740730
let next_token = self.next_token();
@@ -756,7 +746,7 @@ impl<'a> Parser<'a> {
756746
id_parts.push(Ident::with_quote('\'', s))
757747
}
758748
Token::Mul => {
759-
return Ok(WildcardExpr::QualifiedWildcard(ObjectName(id_parts)));
749+
return Ok(Expr::QualifiedWildcard(ObjectName(id_parts)));
760750
}
761751
_ => {
762752
return self
@@ -767,13 +757,13 @@ impl<'a> Parser<'a> {
767757
}
768758
}
769759
Token::Mul => {
770-
return Ok(WildcardExpr::Wildcard);
760+
return Ok(Expr::Wildcard);
771761
}
772762
_ => (),
773763
};
774764

775765
self.index = index;
776-
self.parse_expr().map(WildcardExpr::Expr)
766+
self.parse_expr()
777767
}
778768

779769
/// Parse a new expression
@@ -969,10 +959,22 @@ impl<'a> Parser<'a> {
969959
_ => match self.peek_token().token {
970960
Token::LParen | Token::Period => {
971961
let mut id_parts: Vec<Ident> = vec![w.to_ident()];
962+
let mut ends_with_wildcard = false;
972963
while self.consume_token(&Token::Period) {
973964
let next_token = self.next_token();
974965
match next_token.token {
975966
Token::Word(w) => id_parts.push(w.to_ident()),
967+
Token::Mul => {
968+
// Postgres explicitly allows funcnm(tablenm.*) and the
969+
// function array_agg traverses this control flow
970+
if dialect_of!(self is PostgreSqlDialect) {
971+
ends_with_wildcard = true;
972+
break;
973+
} else {
974+
return self
975+
.expected("an identifier after '.'", next_token);
976+
}
977+
}
976978
Token::SingleQuotedString(s) => {
977979
id_parts.push(Ident::with_quote('\'', s))
978980
}
@@ -983,7 +985,9 @@ impl<'a> Parser<'a> {
983985
}
984986
}
985987

986-
if self.consume_token(&Token::LParen) {
988+
if ends_with_wildcard {
989+
Ok(Expr::QualifiedWildcard(ObjectName(id_parts)))
990+
} else if self.consume_token(&Token::LParen) {
987991
self.prev_token();
988992
self.parse_function(ObjectName(id_parts))
989993
} else {
@@ -8051,9 +8055,9 @@ impl<'a> Parser<'a> {
80518055
let subquery = self.parse_query()?;
80528056
self.expect_token(&Token::RParen)?;
80538057
return Ok((
8054-
vec![FunctionArg::Unnamed(FunctionArgExpr::from(
8055-
WildcardExpr::Expr(Expr::Subquery(Box::new(subquery))),
8056-
))],
8058+
vec![FunctionArg::Unnamed(FunctionArgExpr::from(Expr::Subquery(
8059+
Box::new(subquery),
8060+
)))],
80578061
vec![],
80588062
));
80598063
}
@@ -8072,7 +8076,14 @@ impl<'a> Parser<'a> {
80728076
/// Parse a comma-delimited list of projections after SELECT
80738077
pub fn parse_select_item(&mut self) -> Result<SelectItem, ParserError> {
80748078
match self.parse_wildcard_expr()? {
8075-
WildcardExpr::Expr(expr) => {
8079+
Expr::QualifiedWildcard(prefix) => Ok(SelectItem::QualifiedWildcard(
8080+
prefix,
8081+
self.parse_wildcard_additional_options()?,
8082+
)),
8083+
Expr::Wildcard => Ok(SelectItem::Wildcard(
8084+
self.parse_wildcard_additional_options()?,
8085+
)),
8086+
expr => {
80768087
let expr: Expr = if self.dialect.supports_filter_during_aggregation()
80778088
&& self.parse_keyword(Keyword::FILTER)
80788089
{
@@ -8097,13 +8108,6 @@ impl<'a> Parser<'a> {
80978108
None => SelectItem::UnnamedExpr(expr),
80988109
})
80998110
}
8100-
WildcardExpr::QualifiedWildcard(prefix) => Ok(SelectItem::QualifiedWildcard(
8101-
prefix,
8102-
self.parse_wildcard_additional_options()?,
8103-
)),
8104-
WildcardExpr::Wildcard => Ok(SelectItem::Wildcard(
8105-
self.parse_wildcard_additional_options()?,
8106-
)),
81078111
}
81088112
}
81098113

tests/sqlparser_common.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2403,6 +2403,12 @@ fn parse_array_agg_func() {
24032403
] {
24042404
supported_dialects.verified_stmt(sql);
24052405
}
2406+
2407+
// follows special-case array_agg code path. fails in everything except postgres
2408+
let wc_sql = "SELECT ARRAY_AGG(sections_tbl.*) AS sections FROM sections_tbl";
2409+
all_dialects_but_pg()
2410+
.parse_sql_statements(wc_sql)
2411+
.expect_err("should have failed");
24062412
}
24072413

24082414
#[test]

tests/sqlparser_postgres.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3800,3 +3800,22 @@ fn test_simple_insert_with_quoted_alias() {
38003800
}
38013801
)
38023802
}
3803+
3804+
#[test]
3805+
fn parse_array_agg() {
3806+
// follows general function with wildcard code path
3807+
let sql = r#"SELECT GREATEST(sections_tbl.*) AS sections FROM sections_tbl"#;
3808+
pg().verified_stmt(sql);
3809+
3810+
// follows special-case array_agg code path
3811+
let sql2 = "SELECT ARRAY_AGG(sections_tbl.*) AS sections FROM sections_tbl";
3812+
pg().verified_stmt(sql2);
3813+
3814+
// handles multi-part identifier with general code path
3815+
let sql3 = "SELECT GREATEST(my_schema.sections_tbl.*) AS sections FROM sections_tbl";
3816+
pg().verified_stmt(sql3);
3817+
3818+
// handles multi-part identifier with array_agg code path
3819+
let sql4 = "SELECT ARRAY_AGG(my_schema.sections_tbl.*) AS sections FROM sections_tbl";
3820+
pg().verified_stmt(sql4);
3821+
}

0 commit comments

Comments
 (0)