From 3e76bf0a9fa17b666495f82b443bf16aca57672b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Taylor?= Date: Wed, 6 Dec 2023 17:45:55 +0100 Subject: [PATCH] bugfix: use the original expression and not the alias (#14704) Signed-off-by: Andres Taylor --- .../operators/aggregation_pushing.go | 11 +-- .../planbuilder/testdata/aggr_cases.json | 70 +++++++++++++++++++ 2 files changed, 77 insertions(+), 4 deletions(-) diff --git a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go index e50483ce8d2..b0fdf683121 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go +++ b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go @@ -433,19 +433,22 @@ func splitGroupingToLeftAndRight(ctx *plancontext.PlanningContext, rootAggr *Agg var groupingJCs []JoinColumn for _, groupBy := range rootAggr.Grouping { - deps := ctx.SemTable.RecursiveDeps(groupBy.Inner) - expr := groupBy.Inner + expr, err := rootAggr.QP.GetSimplifiedExpr(ctx, groupBy.Inner) + if err != nil { + panic(err) + } + deps := ctx.SemTable.RecursiveDeps(expr) switch { case deps.IsSolvedBy(lhs.tableID): lhs.addGrouping(ctx, groupBy) groupingJCs = append(groupingJCs, JoinColumn{ - Original: groupBy.Inner, + Original: expr, LHSExprs: []BindVarExpr{{Expr: expr}}, }) case deps.IsSolvedBy(rhs.tableID): rhs.addGrouping(ctx, groupBy) groupingJCs = append(groupingJCs, JoinColumn{ - Original: groupBy.Inner, + Original: expr, RHSExpr: expr, }) case deps.IsSolvedBy(lhs.tableID.Merge(rhs.tableID)): diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index 2254baa36a6..d1e9c42c1dd 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -6372,5 +6372,75 @@ "user.user" ] } + }, + { + "comment": "Group by aliases on both sides of a join", + "query": "select count(*), cast(user.foo as datetime) as f1, cast(music.foo as datetime) as f2 from user join music group by f1, f2", + "plan": { + "QueryType": "SELECT", + "Original": "select count(*), cast(user.foo as datetime) as f1, cast(music.foo as datetime) as f2 from user join music group by f1, f2", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "sum_count_star(0) AS count(*)", + "GroupBy": "(1|3), (2|4)", + "ResultColumns": 3, + "Inputs": [ + { + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "(1|3) ASC, (2|4) ASC", + "Inputs": [ + { + "OperatorType": "Projection", + "Expressions": [ + "count(*) * count(*) as count(*)", + ":2 as f1", + ":3 as f2", + ":4 as weight_string(cast(`user`.foo as datetime))", + ":5 as weight_string(cast(music.foo as datetime))" + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,R:0,L:1,R:1,L:2,R:2", + "TableName": "`user`_music", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*), cast(`user`.foo as datetime) as f1, weight_string(cast(`user`.foo as datetime)) from `user` where 1 != 1 group by f1, weight_string(cast(`user`.foo as datetime))", + "Query": "select count(*), cast(`user`.foo as datetime) as f1, weight_string(cast(`user`.foo as datetime)) from `user` group by f1, weight_string(cast(`user`.foo as datetime))", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*), cast(music.foo as datetime) as f2, weight_string(cast(music.foo as datetime)) from music where 1 != 1 group by f2, weight_string(cast(music.foo as datetime))", + "Query": "select count(*), cast(music.foo as datetime) as f2, weight_string(cast(music.foo as datetime)) from music group by f2, weight_string(cast(music.foo as datetime))", + "Table": "music" + } + ] + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "user.music", + "user.user" + ] + } } ]