diff --git a/packages/malloy/src/dialect/postgres/postgres.ts b/packages/malloy/src/dialect/postgres/postgres.ts index 5d6384af2..64506f97d 100644 --- a/packages/malloy/src/dialect/postgres/postgres.ts +++ b/packages/malloy/src/dialect/postgres/postgres.ts @@ -248,16 +248,16 @@ export class PostgresDialect extends Dialect { isNested: boolean, _isArray: boolean ): string { - let ret = `${alias}->>'${fieldName}'`; + let ret = `(${alias}->>'${fieldName}')`; if (isNested) { switch (fieldType) { case 'string': break; case 'number': - ret = `(${ret})::double precision`; + ret = `${ret}::double precision`; break; case 'struct': - ret = `(${ret})::jsonb`; + ret = `${ret}::jsonb`; break; } return ret; diff --git a/packages/malloy/src/lang/ast/expressions/expr-aggregate-function.ts b/packages/malloy/src/lang/ast/expressions/expr-aggregate-function.ts index f9c945c63..a49691fe9 100644 --- a/packages/malloy/src/lang/ast/expressions/expr-aggregate-function.ts +++ b/packages/malloy/src/lang/ast/expressions/expr-aggregate-function.ts @@ -23,6 +23,7 @@ import { AggregateFragment, + AggregateFunctionType, expressionIsAggregate, FieldDef, FieldValueType, @@ -50,7 +51,7 @@ export abstract class ExprAggregateFunction extends ExpressionDef { explicitSource?: boolean; legalChildTypes = [FT.numberT]; constructor( - readonly func: string, + readonly func: AggregateFunctionType, expr?: ExpressionDef, explicitSource?: boolean ) { diff --git a/packages/malloy/src/model/malloy_query.ts b/packages/malloy/src/model/malloy_query.ts index d836ec3a7..f8fc915f4 100644 --- a/packages/malloy/src/model/malloy_query.ts +++ b/packages/malloy/src/model/malloy_query.ts @@ -25,6 +25,7 @@ import {Dialect, DialectFieldList, getDialect} from '../dialect'; import {StandardSQLDialect} from '../dialect/standardsql/standardsql'; import { AggregateFragment, + AggregateFunctionType, CompiledQuery, DialectFragment, Expr, @@ -127,6 +128,21 @@ interface OutputPipelinedSQL { pipelineSQL: string; } +// Track the times we might need a unique key +type UniqueKeyPossibleUse = AggregateFunctionType | 'generic_aggregate'; + +class UniqueKeyUse extends Set { + add_use(k: UniqueKeyPossibleUse | undefined) { + if (k !== undefined) { + return this.add(k); + } + } + + hasAsymetricFunctions(): boolean { + return this.has('sum') || this.has('avg') || this.has('count'); + } +} + class StageWriter { withs: string[] = []; udfs: string[] = []; @@ -307,8 +323,8 @@ class QueryField extends QueryNode { this.fieldDef = fieldDef; } - mayNeedUniqueKey(): boolean { - return false; + uniqueKeyPossibleUse(): UniqueKeyPossibleUse | undefined { + return undefined; } getJoinableParent(): QueryStruct { @@ -750,16 +766,31 @@ class QueryField extends QueryNode { ): string { let func = 'COUNT('; let thing = '1'; - const distinctKeySQL = this.generateDistinctKeyIfNecessary( - resultSet, - context, - expr.structPath - ); - if (distinctKeySQL) { + + let struct = context; + if (expr.structPath) { + struct = this.parent.root().getStructByName(expr.structPath); + } + const joinName = struct.getJoinableParent().getIdentifier(); + const join = resultSet.root().joins.get(joinName); + if (!join) { + throw new Error(`Join ${joinName} not found in result set`); + } + if (!join.leafiest || join.makeUniqueKey) { func = 'COUNT(DISTINCT'; - thing = distinctKeySQL; + thing = struct.getDistinctKey().generateExpression(resultSet); } + // const distinctKeySQL = this.generateDistinctKeyIfNecessary( + // resultSet, + // context, + // expr.structPath + // ); + // if (distinctKeySQL) { + // func = 'COUNT(DISTINCT'; + // thing = distinctKeySQL; + // } + // find the structDef and return the path to the field... if (state.whereSQL) { return `${func} CASE WHEN ${state.whereSQL} THEN ${thing} END)`; @@ -1149,13 +1180,17 @@ class QueryFieldDistinctKey extends QueryAtomicField { const parentKey = this.parent.parent ?.getDistinctKey() .generateExpression(resultSet); - return `CONCAT(${parentKey}, 'x', ${this.parent.dialect.sqlFieldReference( - this.parent.getIdentifier(), - '__row_id', - 'string', - true, - false - )})`; + return this.parent.dialect.concat( + parentKey || '', // shouldn't have to do this... + "'x'", + this.parent.dialect.sqlFieldReference( + this.parent.getIdentifier(), + '__row_id', + 'string', + true, + false + ) + ); } else { // return this.parent.getIdentifier() + "." + "__distinct_key"; return this.parent.dialect.sqlFieldReference( @@ -1499,7 +1534,7 @@ class FieldInstanceResult implements FieldInstance { addStructToJoin( qs: QueryStruct, query: QueryQuery, - mayNeedUniqueKey: boolean, + uniqueKeyPossibleUse: UniqueKeyPossibleUse | undefined, joinStack: string[] ): void { const name = qs.getIdentifier(); @@ -1509,9 +1544,9 @@ class FieldInstanceResult implements FieldInstance { return; } - let join; + let join: JoinInstance | undefined; if ((join = this.root().joins.get(name))) { - join.mayNeedUniqueKey ||= mayNeedUniqueKey; + join.uniqueKeyPossibleUses.add_use(uniqueKeyPossibleUse); return; } @@ -1520,7 +1555,7 @@ class FieldInstanceResult implements FieldInstance { const parentStruct = qs.parent?.getJoinableParent(); if (parentStruct) { // add dependant expressions first... - this.addStructToJoin(parentStruct, query, false, joinStack); + this.addStructToJoin(parentStruct, query, undefined, joinStack); parent = this.root().joins.get(parentStruct.getIdentifier()); } @@ -1542,7 +1577,7 @@ class FieldInstanceResult implements FieldInstance { join = new JoinInstance(qs, name, parent); this.root().joins.set(name, join); } - join.mayNeedUniqueKey ||= mayNeedUniqueKey; + join.uniqueKeyPossibleUses.add_use(uniqueKeyPossibleUse); } findJoins(query: QueryQuery) { @@ -1550,7 +1585,7 @@ class FieldInstanceResult implements FieldInstance { this.addStructToJoin( dim.f.getJoinableParent(), query, - dim.f.mayNeedUniqueKey(), + dim.f.uniqueKeyPossibleUse(), [] ); } @@ -1667,7 +1702,7 @@ class FieldInstanceResultRoot extends FieldInstanceResult { // look at all the fields again in the structs in the query calculateSymmetricAggregates() { - let leafiest; + let leafiest: string | undefined; for (const [name, join] of this.joins) { // first join is by default the const relationship = join.parentRelationship(); @@ -1702,8 +1737,23 @@ class FieldInstanceResultRoot extends FieldInstanceResult { // Nested Unique keys are dependant on the primary key of the parent // and the table. for (const [_name, join] of this.joins) { - // don't need keys on leafiest - if (!join.leafiest && join.mayNeedUniqueKey) { + // in a one_to_many join we need a key to count there may be a failed + // match in a left join. + // users -> { + // group_by: user_id + // aggregate: order_count is orders.count() + if (join.leafiest) { + if ( + join.parent !== null && + join.uniqueKeyPossibleUses.has('count') && + !join.queryStruct.primaryKey() + ) { + join.makeUniqueKey = true; + } + } else if ( + !join.leafiest && + join.uniqueKeyPossibleUses.hasAsymetricFunctions() + ) { let j: JoinInstance | undefined = join; while (j) { if (!j.queryStruct.primaryKey()) { @@ -1721,7 +1771,7 @@ class FieldInstanceResultRoot extends FieldInstanceResult { } class JoinInstance { - mayNeedUniqueKey = false; + uniqueKeyPossibleUses: UniqueKeyUse = new UniqueKeyUse(); makeUniqueKey = false; leafiest = false; joinFilterConditions?: QueryFieldBoolean[]; @@ -2133,7 +2183,7 @@ class QueryQuery extends QueryField { resultStruct: FieldInstanceResult, context: QueryStruct, path: string, - mayNeedUniqueKey: boolean, + uniqueKeyPossibleUse: UniqueKeyPossibleUse | undefined, joinStack: string[] ) { const node = context.getFieldByName(path); @@ -2150,7 +2200,7 @@ class QueryQuery extends QueryField { .addStructToJoin( struct.getJoinableParent(), this, - mayNeedUniqueKey, + uniqueKeyPossibleUse, joinStack ); } @@ -2202,7 +2252,7 @@ class QueryQuery extends QueryField { .addStructToJoin( field.parent.getJoinableParent(), this, - false, + undefined, joinStack ); // this.addDependantPath(resultStruct, field.parent, expr.path, false); @@ -2261,12 +2311,17 @@ class QueryQuery extends QueryField { resultStruct, context, expr.structPath, - true, + expr.function, joinStack ); } else { // we are doing a sum in the root. It may need symetric aggregates - resultStruct.addStructToJoin(context, this, true, joinStack); + resultStruct.addStructToJoin( + context, + this, + expr.function, + joinStack + ); } } this.addDependantExpr(resultStruct, context, expr.e, joinStack); @@ -2276,7 +2331,7 @@ class QueryQuery extends QueryField { resultStruct, context, expr.structPath, - true, + 'generic_aggregate', joinStack ); } @@ -2408,7 +2463,7 @@ class QueryQuery extends QueryField { prepare(_stageWriter: StageWriter | undefined) { if (!this.prepared) { this.expandFields(this.rootResult); - this.rootResult.addStructToJoin(this.parent, this, false, []); + this.rootResult.addStructToJoin(this.parent, this, undefined, []); this.rootResult.findJoins(this); this.rootResult.calculateSymmetricAggregates(); this.prepared = true; diff --git a/packages/malloy/src/model/malloy_types.ts b/packages/malloy/src/model/malloy_types.ts index ac1dab4cc..11d6db2aa 100644 --- a/packages/malloy/src/model/malloy_types.ts +++ b/packages/malloy/src/model/malloy_types.ts @@ -186,9 +186,17 @@ export function isDialectFragment(f: Fragment): f is DialectFragment { return (f as DialectFragment)?.type === 'dialect'; } +export type AggregateFunctionType = + | 'sum' + | 'avg' + | 'count' + | 'count_distinct' + | 'max' + | 'min'; + export interface AggregateFragment { type: 'aggregate'; - function: string; + function: AggregateFunctionType; e: Expr; structPath?: string; } diff --git a/test/src/databases/all/nomodel.spec.ts b/test/src/databases/all/nomodel.spec.ts index a720d32a6..cadf442f9 100644 --- a/test/src/databases/all/nomodel.spec.ts +++ b/test/src/databases/all/nomodel.spec.ts @@ -410,6 +410,66 @@ runtimes.runtimeMap.forEach((runtime, databaseName) => { }); }); + it(`leafy count - ${databaseName}`, async () => { + // in a joined table when the joined is leafiest + // we need to make sure we don't count rows that + // don't match the join. + await expect(` + source: am_states is ${databaseName}.table('malloytest.state_facts') -> { + select: * + where: state ~ r'^(A|M)' + } + + source: states is ${databaseName}.table('malloytest.state_facts') extend { + join_many: am_states on state=am_states.state + } + + run: states -> { + where: state = 'CA' + aggregate: + leafy_count is am_states.count() + root_count is count() + } + `).malloyResultMatches(runtime, { + leafy_count: 0, + root_count: 1, + }); + }); + + it(`leafy nested count - ${databaseName}`, async () => { + // in a joined table when the joined is leafiest + // we need to make sure we don't count rows that + // don't match the join. + await expect(` + source: am_states is ${databaseName}.table('malloytest.state_facts') -> { + group_by: state + where: state ~ r'^(A|M)' + nest: nested_state is { + group_by: state + } + } + + source: states is ${databaseName}.table('malloytest.state_facts') extend { + join_many: am_states on state=am_states.state + } + + run: states -> { + where: state = 'CA' + group_by: + state + am_state is am_states.state + aggregate: + leafy_count is am_states.nested_state.count() + root_count is count() + } + `).malloyResultMatches(runtime, { + leafy_count: 0, + root_count: 1, + state: 'CA', + am_state: null, + }); + }); + it(`basic index - ${databaseName}`, async () => { // Make sure basic indexing works. await expect(`