Skip to content

Commit

Permalink
Miss-counting one_to_many... (#1585)
Browse files Browse the repository at this point in the history
* Interm Checkin

* Skip the leafy test for now

* add more complete tests, write the fix

* Wow, concat in postgres sucks.

* weird order of precedence with jsonb

* A better fix for the problem.
  • Loading branch information
lloydtabb authored Jan 18, 2024
1 parent 449c6c0 commit a9267b4
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 38 deletions.
6 changes: 3 additions & 3 deletions packages/malloy/src/dialect/postgres/postgres.ts
Original file line number Diff line number Diff line change
Expand Up @@ -248,16 +248,16 @@ export class PostgresDialect extends Dialect {
isNested: boolean,
_isArray: boolean
): string {
let ret = `${alias}->>'${fieldName}'`;
let ret = `(${alias}->>'${fieldName}')`;
if (isNested) {
switch (fieldType) {
case 'string':
break;
case 'number':
ret = `(${ret})::double precision`;
ret = `${ret}::double precision`;
break;
case 'struct':
ret = `(${ret})::jsonb`;
ret = `${ret}::jsonb`;
break;
}
return ret;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import {
AggregateFragment,
AggregateFunctionType,
expressionIsAggregate,
FieldDef,
FieldValueType,
Expand Down Expand Up @@ -50,7 +51,7 @@ export abstract class ExprAggregateFunction extends ExpressionDef {
explicitSource?: boolean;
legalChildTypes = [FT.numberT];
constructor(
readonly func: string,
readonly func: AggregateFunctionType,
expr?: ExpressionDef,
explicitSource?: boolean
) {
Expand Down
121 changes: 88 additions & 33 deletions packages/malloy/src/model/malloy_query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {Dialect, DialectFieldList, getDialect} from '../dialect';
import {StandardSQLDialect} from '../dialect/standardsql/standardsql';
import {
AggregateFragment,
AggregateFunctionType,
CompiledQuery,
DialectFragment,
Expr,
Expand Down Expand Up @@ -127,6 +128,21 @@ interface OutputPipelinedSQL {
pipelineSQL: string;
}

// Track the times we might need a unique key
type UniqueKeyPossibleUse = AggregateFunctionType | 'generic_aggregate';

class UniqueKeyUse extends Set<UniqueKeyPossibleUse> {
add_use(k: UniqueKeyPossibleUse | undefined) {
if (k !== undefined) {
return this.add(k);
}
}

hasAsymetricFunctions(): boolean {
return this.has('sum') || this.has('avg') || this.has('count');
}
}

class StageWriter {
withs: string[] = [];
udfs: string[] = [];
Expand Down Expand Up @@ -307,8 +323,8 @@ class QueryField extends QueryNode {
this.fieldDef = fieldDef;
}

mayNeedUniqueKey(): boolean {
return false;
uniqueKeyPossibleUse(): UniqueKeyPossibleUse | undefined {
return undefined;
}

getJoinableParent(): QueryStruct {
Expand Down Expand Up @@ -750,16 +766,31 @@ class QueryField extends QueryNode {
): string {
let func = 'COUNT(';
let thing = '1';
const distinctKeySQL = this.generateDistinctKeyIfNecessary(
resultSet,
context,
expr.structPath
);
if (distinctKeySQL) {

let struct = context;
if (expr.structPath) {
struct = this.parent.root().getStructByName(expr.structPath);
}
const joinName = struct.getJoinableParent().getIdentifier();
const join = resultSet.root().joins.get(joinName);
if (!join) {
throw new Error(`Join ${joinName} not found in result set`);
}
if (!join.leafiest || join.makeUniqueKey) {
func = 'COUNT(DISTINCT';
thing = distinctKeySQL;
thing = struct.getDistinctKey().generateExpression(resultSet);
}

// const distinctKeySQL = this.generateDistinctKeyIfNecessary(
// resultSet,
// context,
// expr.structPath
// );
// if (distinctKeySQL) {
// func = 'COUNT(DISTINCT';
// thing = distinctKeySQL;
// }

// find the structDef and return the path to the field...
if (state.whereSQL) {
return `${func} CASE WHEN ${state.whereSQL} THEN ${thing} END)`;
Expand Down Expand Up @@ -1149,13 +1180,17 @@ class QueryFieldDistinctKey extends QueryAtomicField {
const parentKey = this.parent.parent
?.getDistinctKey()
.generateExpression(resultSet);
return `CONCAT(${parentKey}, 'x', ${this.parent.dialect.sqlFieldReference(
this.parent.getIdentifier(),
'__row_id',
'string',
true,
false
)})`;
return this.parent.dialect.concat(
parentKey || '', // shouldn't have to do this...
"'x'",
this.parent.dialect.sqlFieldReference(
this.parent.getIdentifier(),
'__row_id',
'string',
true,
false
)
);
} else {
// return this.parent.getIdentifier() + "." + "__distinct_key";
return this.parent.dialect.sqlFieldReference(
Expand Down Expand Up @@ -1499,7 +1534,7 @@ class FieldInstanceResult implements FieldInstance {
addStructToJoin(
qs: QueryStruct,
query: QueryQuery,
mayNeedUniqueKey: boolean,
uniqueKeyPossibleUse: UniqueKeyPossibleUse | undefined,
joinStack: string[]
): void {
const name = qs.getIdentifier();
Expand All @@ -1509,9 +1544,9 @@ class FieldInstanceResult implements FieldInstance {
return;
}

let join;
let join: JoinInstance | undefined;
if ((join = this.root().joins.get(name))) {
join.mayNeedUniqueKey ||= mayNeedUniqueKey;
join.uniqueKeyPossibleUses.add_use(uniqueKeyPossibleUse);
return;
}

Expand All @@ -1520,7 +1555,7 @@ class FieldInstanceResult implements FieldInstance {
const parentStruct = qs.parent?.getJoinableParent();
if (parentStruct) {
// add dependant expressions first...
this.addStructToJoin(parentStruct, query, false, joinStack);
this.addStructToJoin(parentStruct, query, undefined, joinStack);
parent = this.root().joins.get(parentStruct.getIdentifier());
}

Expand All @@ -1542,15 +1577,15 @@ class FieldInstanceResult implements FieldInstance {
join = new JoinInstance(qs, name, parent);
this.root().joins.set(name, join);
}
join.mayNeedUniqueKey ||= mayNeedUniqueKey;
join.uniqueKeyPossibleUses.add_use(uniqueKeyPossibleUse);
}

findJoins(query: QueryQuery) {
for (const dim of this.fields()) {
this.addStructToJoin(
dim.f.getJoinableParent(),
query,
dim.f.mayNeedUniqueKey(),
dim.f.uniqueKeyPossibleUse(),
[]
);
}
Expand Down Expand Up @@ -1667,7 +1702,7 @@ class FieldInstanceResultRoot extends FieldInstanceResult {
// look at all the fields again in the structs in the query

calculateSymmetricAggregates() {
let leafiest;
let leafiest: string | undefined;
for (const [name, join] of this.joins) {
// first join is by default the
const relationship = join.parentRelationship();
Expand Down Expand Up @@ -1702,8 +1737,23 @@ class FieldInstanceResultRoot extends FieldInstanceResult {
// Nested Unique keys are dependant on the primary key of the parent
// and the table.
for (const [_name, join] of this.joins) {
// don't need keys on leafiest
if (!join.leafiest && join.mayNeedUniqueKey) {
// in a one_to_many join we need a key to count there may be a failed
// match in a left join.
// users -> {
// group_by: user_id
// aggregate: order_count is orders.count()
if (join.leafiest) {
if (
join.parent !== null &&
join.uniqueKeyPossibleUses.has('count') &&
!join.queryStruct.primaryKey()
) {
join.makeUniqueKey = true;
}
} else if (
!join.leafiest &&
join.uniqueKeyPossibleUses.hasAsymetricFunctions()
) {
let j: JoinInstance | undefined = join;
while (j) {
if (!j.queryStruct.primaryKey()) {
Expand All @@ -1721,7 +1771,7 @@ class FieldInstanceResultRoot extends FieldInstanceResult {
}

class JoinInstance {
mayNeedUniqueKey = false;
uniqueKeyPossibleUses: UniqueKeyUse = new UniqueKeyUse();
makeUniqueKey = false;
leafiest = false;
joinFilterConditions?: QueryFieldBoolean[];
Expand Down Expand Up @@ -2133,7 +2183,7 @@ class QueryQuery extends QueryField {
resultStruct: FieldInstanceResult,
context: QueryStruct,
path: string,
mayNeedUniqueKey: boolean,
uniqueKeyPossibleUse: UniqueKeyPossibleUse | undefined,
joinStack: string[]
) {
const node = context.getFieldByName(path);
Expand All @@ -2150,7 +2200,7 @@ class QueryQuery extends QueryField {
.addStructToJoin(
struct.getJoinableParent(),
this,
mayNeedUniqueKey,
uniqueKeyPossibleUse,
joinStack
);
}
Expand Down Expand Up @@ -2202,7 +2252,7 @@ class QueryQuery extends QueryField {
.addStructToJoin(
field.parent.getJoinableParent(),
this,
false,
undefined,
joinStack
);
// this.addDependantPath(resultStruct, field.parent, expr.path, false);
Expand Down Expand Up @@ -2261,12 +2311,17 @@ class QueryQuery extends QueryField {
resultStruct,
context,
expr.structPath,
true,
expr.function,
joinStack
);
} else {
// we are doing a sum in the root. It may need symetric aggregates
resultStruct.addStructToJoin(context, this, true, joinStack);
resultStruct.addStructToJoin(
context,
this,
expr.function,
joinStack
);
}
}
this.addDependantExpr(resultStruct, context, expr.e, joinStack);
Expand All @@ -2276,7 +2331,7 @@ class QueryQuery extends QueryField {
resultStruct,
context,
expr.structPath,
true,
'generic_aggregate',
joinStack
);
}
Expand Down Expand Up @@ -2408,7 +2463,7 @@ class QueryQuery extends QueryField {
prepare(_stageWriter: StageWriter | undefined) {
if (!this.prepared) {
this.expandFields(this.rootResult);
this.rootResult.addStructToJoin(this.parent, this, false, []);
this.rootResult.addStructToJoin(this.parent, this, undefined, []);
this.rootResult.findJoins(this);
this.rootResult.calculateSymmetricAggregates();
this.prepared = true;
Expand Down
10 changes: 9 additions & 1 deletion packages/malloy/src/model/malloy_types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,17 @@ export function isDialectFragment(f: Fragment): f is DialectFragment {
return (f as DialectFragment)?.type === 'dialect';
}

export type AggregateFunctionType =
| 'sum'
| 'avg'
| 'count'
| 'count_distinct'
| 'max'
| 'min';

export interface AggregateFragment {
type: 'aggregate';
function: string;
function: AggregateFunctionType;
e: Expr;
structPath?: string;
}
Expand Down
60 changes: 60 additions & 0 deletions test/src/databases/all/nomodel.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,66 @@ runtimes.runtimeMap.forEach((runtime, databaseName) => {
});
});

it(`leafy count - ${databaseName}`, async () => {
// in a joined table when the joined is leafiest
// we need to make sure we don't count rows that
// don't match the join.
await expect(`
source: am_states is ${databaseName}.table('malloytest.state_facts') -> {
select: *
where: state ~ r'^(A|M)'
}
source: states is ${databaseName}.table('malloytest.state_facts') extend {
join_many: am_states on state=am_states.state
}
run: states -> {
where: state = 'CA'
aggregate:
leafy_count is am_states.count()
root_count is count()
}
`).malloyResultMatches(runtime, {
leafy_count: 0,
root_count: 1,
});
});

it(`leafy nested count - ${databaseName}`, async () => {
// in a joined table when the joined is leafiest
// we need to make sure we don't count rows that
// don't match the join.
await expect(`
source: am_states is ${databaseName}.table('malloytest.state_facts') -> {
group_by: state
where: state ~ r'^(A|M)'
nest: nested_state is {
group_by: state
}
}
source: states is ${databaseName}.table('malloytest.state_facts') extend {
join_many: am_states on state=am_states.state
}
run: states -> {
where: state = 'CA'
group_by:
state
am_state is am_states.state
aggregate:
leafy_count is am_states.nested_state.count()
root_count is count()
}
`).malloyResultMatches(runtime, {
leafy_count: 0,
root_count: 1,
state: 'CA',
am_state: null,
});
});

it(`basic index - ${databaseName}`, async () => {
// Make sure basic indexing works.
await expect(`
Expand Down

0 comments on commit a9267b4

Please sign in to comment.