Skip to content

Commit 225fc70

Browse files
authored
Add support for AVG on sharded queries (#14419)
1 parent 1a9119d commit 225fc70

File tree

8 files changed

+482
-107
lines changed

8 files changed

+482
-107
lines changed

go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ func TestAggregateTypes(t *testing.T) {
7373
mcmp.AssertMatches("select val1 as a, count(*) from aggr_test group by a order by a", `[[VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)] [VARCHAR("d") INT64(1)] [VARCHAR("e") INT64(2)]]`)
7474
mcmp.AssertMatches("select val1 as a, count(*) from aggr_test group by a order by 2, a", `[[VARCHAR("b") INT64(1)] [VARCHAR("d") INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("c") INT64(2)] [VARCHAR("e") INT64(2)]]`)
7575
mcmp.AssertMatches("select sum(val1) from aggr_test", `[[FLOAT64(0)]]`)
76+
mcmp.AssertMatches("select avg(val1) from aggr_test", `[[FLOAT64(0)]]`)
7677
}
7778

7879
func TestGroupBy(t *testing.T) {
@@ -172,6 +173,13 @@ func TestAggrOnJoin(t *testing.T) {
172173

173174
mcmp.AssertMatches("select a.val1 from aggr_test a join t3 t on a.val2 = t.id7 group by a.val1 having count(*) = 4",
174175
`[[VARCHAR("a")]]`)
176+
177+
mcmp.AssertMatches(`select avg(a1.val2), avg(a2.val2) from aggr_test a1 join aggr_test a2 on a1.val2 = a2.id join t3 t on a2.val2 = t.id7`,
178+
"[[DECIMAL(1.5000) DECIMAL(1.0000)]]")
179+
180+
mcmp.AssertMatches(`select a1.val1, avg(a1.val2) from aggr_test a1 join aggr_test a2 on a1.val2 = a2.id join t3 t on a2.val2 = t.id7 group by a1.val1`,
181+
`[[VARCHAR("a") DECIMAL(1.0000)] [VARCHAR("b") DECIMAL(1.0000)] [VARCHAR("c") DECIMAL(3.0000)]]`)
182+
175183
}
176184

177185
func TestNotEqualFilterOnScatter(t *testing.T) {
@@ -314,22 +322,26 @@ func TestAggOnTopOfLimit(t *testing.T) {
314322
for _, workload := range []string{"oltp", "olap"} {
315323
t.Run(workload, func(t *testing.T) {
316324
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload))
317-
mcmp.AssertMatches(" select count(*) from (select id, val1 from aggr_test where val2 < 4 limit 2) as x", "[[INT64(2)]]")
318-
mcmp.AssertMatches(" select count(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2)]]")
319-
mcmp.AssertMatches(" select count(*) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(2)]]")
320-
mcmp.AssertMatches(" select count(val1) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(1)]]")
321-
mcmp.AssertMatches(" select count(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0)]]")
322-
mcmp.AssertMatches(" select val1, count(*) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(1)]]`)
323-
mcmp.AssertMatchesNoOrder(" select val1, count(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)]]`)
325+
mcmp.AssertMatches("select count(*) from (select id, val1 from aggr_test where val2 < 4 limit 2) as x", "[[INT64(2)]]")
326+
mcmp.AssertMatches("select count(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2)]]")
327+
mcmp.AssertMatches("select count(*) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(2)]]")
328+
mcmp.AssertMatches("select count(val1) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(1)]]")
329+
mcmp.AssertMatches("select count(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0)]]")
330+
mcmp.AssertMatches("select avg(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[NULL]]")
331+
mcmp.AssertMatches("select val1, count(*) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(1)]]`)
332+
mcmp.AssertMatchesNoOrder("select val1, count(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)]]`)
333+
mcmp.AssertMatchesNoOrder("select val1, avg(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL DECIMAL(2.0000)] [VARCHAR("a") DECIMAL(3.5000)] [VARCHAR("b") DECIMAL(1.0000)] [VARCHAR("c") DECIMAL(3.5000)]]`)
324334

325335
// mysql returns FLOAT64(0), vitess returns DECIMAL(0)
326-
mcmp.AssertMatchesNoCompare(" select count(*), sum(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) FLOAT64(0)]]", "[[INT64(2) FLOAT64(0)]]")
327-
mcmp.AssertMatches(" select count(val1), sum(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) DECIMAL(7)]]")
328-
mcmp.AssertMatches(" select count(*), sum(id) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(2) DECIMAL(14)]]")
329-
mcmp.AssertMatches(" select count(val1), sum(id) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(1) DECIMAL(14)]]")
330-
mcmp.AssertMatches(" select count(val2), sum(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0) NULL]]")
331-
mcmp.AssertMatches(" select val1, count(*), sum(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1) DECIMAL(7)] [VARCHAR("a") INT64(1) DECIMAL(2)]]`)
332-
mcmp.AssertMatchesNoOrder(" select val1, count(val2), sum(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL INT64(1) DECIMAL(2)] [VARCHAR("a") INT64(2) DECIMAL(7)] [VARCHAR("b") INT64(1) DECIMAL(1)] [VARCHAR("c") INT64(2) DECIMAL(7)]]`)
336+
mcmp.AssertMatches("select count(*), sum(val1), avg(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) FLOAT64(0) FLOAT64(0)]]")
337+
mcmp.AssertMatches("select count(val1), sum(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) DECIMAL(7)]]")
338+
mcmp.AssertMatches("select count(val1), sum(id), avg(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) DECIMAL(7) DECIMAL(3.5000)]]")
339+
mcmp.AssertMatches("select count(*), sum(id) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(2) DECIMAL(14)]]")
340+
mcmp.AssertMatches("select count(val1), sum(id) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(1) DECIMAL(14)]]")
341+
mcmp.AssertMatches("select count(val2), sum(val2) from (select id, val2 from aggr_test where val2 is null limit 2) as x", "[[INT64(0) NULL]]")
342+
mcmp.AssertMatches("select val1, count(*), sum(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 limit 2) as x group by val1", `[[NULL INT64(1) DECIMAL(7)] [VARCHAR("a") INT64(1) DECIMAL(2)]]`)
343+
mcmp.AssertMatchesNoOrder("select val1, count(val2), sum(val2), avg(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1",
344+
`[[NULL INT64(1) DECIMAL(2) DECIMAL(2.0000)] [VARCHAR("a") INT64(2) DECIMAL(7) DECIMAL(3.5000)] [VARCHAR("b") INT64(1) DECIMAL(1) DECIMAL(1.0000)] [VARCHAR("c") INT64(2) DECIMAL(7) DECIMAL(3.5000)]]`)
333345
})
334346
}
335347
}
@@ -343,6 +355,8 @@ func TestEmptyTableAggr(t *testing.T) {
343355
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", workload))
344356
mcmp.AssertMatches(" select count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
345357
mcmp.AssertMatches(" select count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
358+
mcmp.AssertMatches(" select count(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
359+
mcmp.AssertMatches(" select avg(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[NULL]]")
346360
mcmp.AssertMatches(" select t1.`name`, count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
347361
mcmp.AssertMatches(" select t1.`name`, count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
348362
})
@@ -355,8 +369,10 @@ func TestEmptyTableAggr(t *testing.T) {
355369
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", workload))
356370
mcmp.AssertMatches(" select count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
357371
mcmp.AssertMatches(" select count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
358-
mcmp.AssertMatches(" select t1.`name`, count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
372+
mcmp.AssertMatches(" select count(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
373+
mcmp.AssertMatches(" select avg(t1.value) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[NULL]]")
359374
mcmp.AssertMatches(" select t1.`name`, count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
375+
mcmp.AssertMatches(" select t1.`name`, count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
360376
})
361377
}
362378

@@ -398,6 +414,8 @@ func TestAggregateLeftJoin(t *testing.T) {
398414
mcmp.AssertMatches("SELECT count(*) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[INT64(2)]]`)
399415
mcmp.AssertMatches("SELECT sum(t1.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(1)]]`)
400416
mcmp.AssertMatches("SELECT sum(t2.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(1)]]`)
417+
mcmp.AssertMatches("SELECT avg(t1.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(0.5000)]]`)
418+
mcmp.AssertMatches("SELECT avg(t2.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(1.0000)]]`)
401419
mcmp.AssertMatches("SELECT count(*) FROM t2 LEFT JOIN t1 ON t1.t1_id = t2.id WHERE IFNULL(t1.name, 'NOTSET') = 'r'", `[[INT64(1)]]`)
402420
}
403421

@@ -426,6 +444,7 @@ func TestScalarAggregate(t *testing.T) {
426444

427445
mcmp.Exec("insert into aggr_test(id, val1, val2) values(1,'a',1), (2,'A',1), (3,'b',1), (4,'c',3), (5,'c',4)")
428446
mcmp.AssertMatches("select count(distinct val1) from aggr_test", `[[INT64(3)]]`)
447+
mcmp.AssertMatches("select avg(val1) from aggr_test", `[[FLOAT64(0)]]`)
429448
}
430449

431450
func TestAggregationRandomOnAnAggregatedValue(t *testing.T) {
@@ -478,6 +497,7 @@ func TestComplexAggregation(t *testing.T) {
478497
mcmp.Exec(`SELECT 1+COUNT(t1_id) FROM t1`)
479498
mcmp.Exec(`SELECT COUNT(t1_id)+1 FROM t1`)
480499
mcmp.Exec(`SELECT COUNT(t1_id)+MAX(shardkey) FROM t1`)
500+
mcmp.Exec(`SELECT COUNT(t1_id)+MAX(shardkey)+AVG(t1_id) FROM t1`)
481501
mcmp.Exec(`SELECT shardkey, MIN(t1_id)+MAX(t1_id) FROM t1 GROUP BY shardkey`)
482502
mcmp.Exec(`SELECT shardkey + MIN(t1_id)+MAX(t1_id) FROM t1 GROUP BY shardkey`)
483503
mcmp.Exec(`SELECT name+COUNT(t1_id)+1 FROM t1 GROUP BY name`)

go/vt/vtgate/engine/opcode/constants.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ const (
7474
AggregateAnyValue
7575
AggregateCountStar
7676
AggregateGroupConcat
77+
AggregateAvg
7778
_NumOfOpCodes // This line must be last of the opcodes!
7879
)
7980

@@ -85,6 +86,7 @@ var (
8586
AggregateCountStar: sqltypes.Int64,
8687
AggregateSumDistinct: sqltypes.Decimal,
8788
AggregateSum: sqltypes.Decimal,
89+
AggregateAvg: sqltypes.Decimal,
8890
AggregateGtid: sqltypes.VarChar,
8991
}
9092
)
@@ -96,6 +98,7 @@ var SupportedAggregates = map[string]AggregateOpcode{
9698
"sum": AggregateSum,
9799
"min": AggregateMin,
98100
"max": AggregateMax,
101+
"avg": AggregateAvg,
99102
// These functions don't exist in mysql, but are used
100103
// to display the plan.
101104
"count_distinct": AggregateCountDistinct,
@@ -117,6 +120,7 @@ var AggregateName = map[AggregateOpcode]string{
117120
AggregateCountStar: "count_star",
118121
AggregateGroupConcat: "group_concat",
119122
AggregateAnyValue: "any_value",
123+
AggregateAvg: "avg",
120124
}
121125

122126
func (code AggregateOpcode) String() string {
@@ -148,7 +152,7 @@ func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type {
148152
return sqltypes.Text
149153
case AggregateMax, AggregateMin, AggregateAnyValue:
150154
return typ
151-
case AggregateSumDistinct, AggregateSum:
155+
case AggregateSumDistinct, AggregateSum, AggregateAvg:
152156
if typ == sqltypes.Unknown {
153157
return sqltypes.Unknown
154158
}

go/vt/vtgate/planbuilder/operators/aggregation_pushing.go

Lines changed: 91 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,22 +34,33 @@ func tryPushAggregator(ctx *plancontext.PlanningContext, aggregator *Aggregator)
3434
if aggregator.Pushed {
3535
return aggregator, rewrite.SameTree, nil
3636
}
37+
38+
// this rewrite is always valid, and we should do it whenever possible
39+
if route, ok := aggregator.Source.(*Route); ok && (route.IsSingleShard() || overlappingUniqueVindex(ctx, aggregator.Grouping)) {
40+
return rewrite.Swap(aggregator, route, "push down aggregation under route - remove original")
41+
}
42+
43+
// other rewrites require us to have reached this phase before we can consider them
44+
if !reachedPhase(ctx, delegateAggregation) {
45+
return aggregator, rewrite.SameTree, nil
46+
}
47+
48+
// if we have not yet been able to push this aggregation down,
49+
// we need to turn AVG into SUM/COUNT to support this over a sharded keyspace
50+
if needAvgBreaking(aggregator.Aggregations) {
51+
return splitAvgAggregations(ctx, aggregator)
52+
}
53+
3754
switch src := aggregator.Source.(type) {
3855
case *Route:
3956
// if we have a single sharded route, we can push it down
4057
output, applyResult, err = pushAggregationThroughRoute(ctx, aggregator, src)
4158
case *ApplyJoin:
42-
if reachedPhase(ctx, delegateAggregation) {
43-
output, applyResult, err = pushAggregationThroughJoin(ctx, aggregator, src)
44-
}
59+
output, applyResult, err = pushAggregationThroughJoin(ctx, aggregator, src)
4560
case *Filter:
46-
if reachedPhase(ctx, delegateAggregation) {
47-
output, applyResult, err = pushAggregationThroughFilter(ctx, aggregator, src)
48-
}
61+
output, applyResult, err = pushAggregationThroughFilter(ctx, aggregator, src)
4962
case *SubQueryContainer:
50-
if reachedPhase(ctx, delegateAggregation) {
51-
output, applyResult, err = pushAggregationThroughSubquery(ctx, aggregator, src)
52-
}
63+
output, applyResult, err = pushAggregationThroughSubquery(ctx, aggregator, src)
5364
default:
5465
return aggregator, rewrite.SameTree, nil
5566
}
@@ -135,15 +146,6 @@ func pushAggregationThroughRoute(
135146
aggregator *Aggregator,
136147
route *Route,
137148
) (ops.Operator, *rewrite.ApplyResult, error) {
138-
// If the route is single-shard, or we are grouping by sharding keys, we can just push down the aggregation
139-
if route.IsSingleShard() || overlappingUniqueVindex(ctx, aggregator.Grouping) {
140-
return rewrite.Swap(aggregator, route, "push down aggregation under route - remove original")
141-
}
142-
143-
if !reachedPhase(ctx, delegateAggregation) {
144-
return nil, nil, nil
145-
}
146-
147149
// Create a new aggregator to be placed below the route.
148150
aggrBelowRoute := aggregator.SplitAggregatorBelowRoute(route.Inputs())
149151
aggrBelowRoute.Aggregations = nil
@@ -806,3 +808,74 @@ func initColReUse(size int) []int {
806808
}
807809

808810
// extractExpr returns the expression wrapped by the given aliased expression.
func extractExpr(expr *sqlparser.AliasedExpr) sqlparser.Expr {
	return expr.Expr
}
811+
812+
func needAvgBreaking(aggrs []Aggr) bool {
813+
for _, aggr := range aggrs {
814+
if aggr.OpCode == opcode.AggregateAvg {
815+
return true
816+
}
817+
}
818+
return false
819+
}
820+
821+
// splitAvgAggregations takes an aggregator that has AVG aggregations in it and splits
822+
// these into sum/count expressions that can be spread out to shards
823+
func splitAvgAggregations(ctx *plancontext.PlanningContext, aggr *Aggregator) (ops.Operator, *rewrite.ApplyResult, error) {
824+
proj := newAliasedProjection(aggr)
825+
826+
var columns []*sqlparser.AliasedExpr
827+
var aggregations []Aggr
828+
829+
for offset, col := range aggr.Columns {
830+
avg, ok := col.Expr.(*sqlparser.Avg)
831+
if !ok {
832+
proj.addColumnWithoutPushing(ctx, col, false /* addToGroupBy */)
833+
continue
834+
}
835+
836+
if avg.Distinct {
837+
panic(vterrors.VT12001("AVG(distinct <>)"))
838+
}
839+
840+
// We have an AVG that we need to split
841+
sumExpr := &sqlparser.Sum{Arg: avg.Arg}
842+
countExpr := &sqlparser.Count{Args: []sqlparser.Expr{avg.Arg}}
843+
calcExpr := &sqlparser.BinaryExpr{
844+
Operator: sqlparser.DivOp,
845+
Left: sumExpr,
846+
Right: countExpr,
847+
}
848+
849+
outputColumn := aeWrap(col.Expr)
850+
outputColumn.As = sqlparser.NewIdentifierCI(col.ColumnName())
851+
_, err := proj.addUnexploredExpr(sqlparser.CloneRefOfAliasedExpr(col), calcExpr)
852+
if err != nil {
853+
return nil, nil, err
854+
}
855+
col.Expr = sumExpr
856+
found := false
857+
for aggrOffset, aggregation := range aggr.Aggregations {
858+
if offset == aggregation.ColOffset {
859+
// We have found the AVG column. We'll change it to SUM, and then we add a COUNT as well
860+
aggr.Aggregations[aggrOffset].OpCode = opcode.AggregateSum
861+
862+
countExprAlias := aeWrap(countExpr)
863+
countAggr := NewAggr(opcode.AggregateCount, countExpr, countExprAlias, sqlparser.String(countExpr))
864+
countAggr.ColOffset = len(aggr.Columns) + len(columns)
865+
aggregations = append(aggregations, countAggr)
866+
columns = append(columns, countExprAlias)
867+
found = true
868+
break // no need to search the remaining aggregations
869+
}
870+
}
871+
if !found {
872+
// if we get here, it's because we didn't find the aggregation. Something is wrong
873+
panic(vterrors.VT13001("no aggregation pointing to this column was found"))
874+
}
875+
}
876+
877+
aggr.Columns = append(aggr.Columns, columns...)
878+
aggr.Aggregations = append(aggr.Aggregations, aggregations...)
879+
880+
return proj, rewrite.NewTree("split avg aggregation", proj), nil
881+
}

0 commit comments

Comments
 (0)