diff --git a/packages/malloy/src/dialect/duckdb/dialect_functions.ts b/packages/malloy/src/dialect/duckdb/dialect_functions.ts index ae1be02fb..639bc6d8c 100644 --- a/packages/malloy/src/dialect/duckdb/dialect_functions.ts +++ b/packages/malloy/src/dialect/duckdb/dialect_functions.ts @@ -11,6 +11,13 @@ import { OverloadedDefinitionBlueprint, } from '../functions/util'; +const array_first: DefinitionBlueprint = { + takes: {'value': {array: {generic: 'T'}}}, + generic: ['T', ['any']], + returns: {generic: 'T'}, + impl: {sql: 'list_extract(${value}, 1)'}, +}; + const dayname: DefinitionBlueprint = { takes: {'date_value': ['date', 'timestamp']}, returns: 'string', @@ -74,6 +81,7 @@ const string_agg_distinct: OverloadedDefinitionBlueprint = { }; export const DUCKDB_DIALECT_FUNCTIONS: DefinitionBlueprintMap = { + array_first, count_approx, dayname, to_timestamp, diff --git a/packages/malloy/src/dialect/duckdb/duckdb.ts b/packages/malloy/src/dialect/duckdb/duckdb.ts index d58b1e231..0dca7a9b2 100644 --- a/packages/malloy/src/dialect/duckdb/duckdb.ts +++ b/packages/malloy/src/dialect/duckdb/duckdb.ts @@ -359,6 +359,68 @@ export class DuckDBDialect extends PostgresBase { getDialectFunctions(): {[name: string]: DialectFunctionOverloadDef[]} { return expandBlueprintMap(DUCKDB_DIALECT_FUNCTIONS); + // { + // ...expandBlueprintMap(DUCKDB_DIALECT_FUNCTIONS), + // 'test_foo': [ + // { + // returnType: { + // type: 'generic', + // generic: 'T', + // expressionType: 'scalar', + // evalSpace: 'constant', + // }, + // params: [ + // { + // name: 'foo', + // allowedTypes: [ + // { + // type: 'generic', + // generic: 'T', + // expressionType: 'scalar', + // evalSpace: 'constant', + // }, + // ], + // isVariadic: false, + // }, + // ], + // genericTypes: [ + // {name: 'T', acceptibleTypes: [{type: 'string'}, {type: 'number'}]}, + // ], + // e: {node: 'function_parameter', name: 'foo'}, + // between: undefined, + // }, + // ], + // 'test_first': [ + // { + // returnType: { + // type: 'generic', + // generic: 'T', + // expressionType: 'scalar', + // evalSpace: 'constant', + // }, + // params: [ + // { + // name: 'foo', + // allowedTypes: [ + // { + // type: 'array', + // elementTypeDef: { + // type: 'generic', + // generic: 'T', + // }, + // expressionType: 'scalar', + // evalSpace: 'constant', + // }, + // ], + // isVariadic: false, + // }, + // ], + // genericTypes: [{name: 'T', acceptibleTypes: [{type: 'any'}]}], + // e: {node: 'function_parameter', name: 'foo'}, + // between: undefined, + // }, + // ], + // }; } malloyTypeToSQLType(malloyType: AtomicTypeDef): string { diff --git a/packages/malloy/src/dialect/functions/util.ts b/packages/malloy/src/dialect/functions/util.ts index 697889c2e..9c15dcb45 100644 --- a/packages/malloy/src/dialect/functions/util.ts +++ b/packages/malloy/src/dialect/functions/util.ts @@ -21,7 +21,6 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -import {emptyCompositeFieldUsage} from '../../model/composite_source_utils'; import { FunctionParameterDef, TypeDesc, @@ -32,13 +31,21 @@ import { TD, mkFieldDef, FieldDef, + FunctionReturnTypeDesc, + FunctionParamTypeDef, + ExpressionType, + EvalSpace, + FunctionReturnTypeDef, + TypedDef, + FunctionGenericTypeDef, } from '../../model/malloy_types'; import {SQLExprElement} from '../../model/utils'; export interface DialectFunctionOverloadDef { // The expression type here is the MINIMUM return type - returnType: TypeDesc; + returnType: FunctionReturnTypeDesc; params: FunctionParameterDef[]; + genericTypes?: {name: string; acceptibleTypes: FunctionGenericTypeDef[]}[]; e: Expr; needsWindowOrderBy?: boolean; isSymmetric?: boolean; @@ -90,21 +97,27 @@ export function sql( return ret; } -export function constant(type: FunctionParamTypeDesc): FunctionParamTypeDesc { +export function constant< + T extends {expressionType: ExpressionType | undefined}, +>(type: T): T & TypeDescExtras { return { ...type, evalSpace: 'constant', }; } -export function output(type: FunctionParamTypeDesc): FunctionParamTypeDesc { +export function output( + type: T +): T & TypeDescExtras { return { ...type, evalSpace: 'output', }; } -export function literal(type: FunctionParamTypeDesc): FunctionParamTypeDesc { +export function literal( + type: T +): T & TypeDescExtras { return { ...type, evalSpace: 'literal', @@ -143,46 +156,48 @@ export function makeParam( return {param: param(name, ...allowedTypes), arg: arg(name)}; } -export function maxScalar(type: LeafExpressionType): FunctionParamTypeDesc { - return {type, expressionType: 'scalar', evalSpace: 'input'}; +export function maxScalar(type: T): T & TypeDescExtras { + return {...type, expressionType: 'scalar', evalSpace: 'input'}; } -export function maxAggregate(type: LeafExpressionType): FunctionParamTypeDesc { - return {type, expressionType: 'aggregate', evalSpace: 'input'}; +export function maxAggregate(type: T): T & TypeDescExtras { + return {...type, expressionType: 'aggregate', evalSpace: 'input'}; } -export function anyExprType(type: LeafExpressionType): FunctionParamTypeDesc { - return {type, expressionType: undefined, evalSpace: 'input'}; +export function anyExprType(type: T): T & TypeDescExtras { + return {...type, expressionType: undefined, evalSpace: 'input'}; } -function anyExprTypeBP( - type: TypeDescBlueprint, - generic: {name: string; type: LeafExpressionType} | undefined -): FunctionParamTypeDesc { - const typeDesc = expandReturnTypeBlueprint(type, generic); - return {...typeDesc, expressionType: undefined, evalSpace: 'input'}; -} +// function anyExprTypeBP(type: T): T & TypeDescExtras { +// const typeDesc = expandReturnTypeBlueprint(type); +// return {...typeDesc, expressionType: undefined, evalSpace: 'input'}; +// } export function maxUngroupedAggregate( - type: LeafExpressionType + type: FunctionParamTypeDef ): FunctionParamTypeDesc { - return {type, expressionType: 'ungrouped_aggregate', evalSpace: 'input'}; + return {...type, expressionType: 'ungrouped_aggregate', evalSpace: 'input'}; } -export function maxAnalytic(type: LeafExpressionType): FunctionParamTypeDesc { - return {type, expressionType: 'aggregate_analytic', evalSpace: 'input'}; +type TypeDescExtras = { + expressionType: ExpressionType | undefined; + evalSpace: EvalSpace; +}; + +export function maxAnalytic(type: T): T & TypeDescExtras { + return {...type, expressionType: 'aggregate_analytic', evalSpace: 'input'}; } -export function minScalar(type: LeafExpressionType): FunctionParamTypeDesc { - return {type, expressionType: 'scalar', evalSpace: 'input'}; +export function minScalar(type: T): T & TypeDescExtras { + return {...type, expressionType: 'scalar', evalSpace: 'input'}; } -export function minAggregate(type: LeafExpressionType): FunctionParamTypeDesc { - return {type, expressionType: 'aggregate', evalSpace: 'input'}; +export function minAggregate(type: T): T & TypeDescExtras { + return {...type, expressionType: 'aggregate', evalSpace: 'input'}; } -export function minAnalytic(type: LeafExpressionType): FunctionParamTypeDesc { - return {type, expressionType: 'scalar_analytic', evalSpace: 'input'}; +export function minAnalytic(type: T): T & TypeDescExtras { + return {...type, expressionType: 'scalar_analytic', evalSpace: 'input'}; } export function overload( @@ -212,24 +227,28 @@ export function overload( } export interface ArrayBlueprint { - array: TypeDescElementBlueprint; + array: TypeDescElementBlueprintOrNamedGeneric; } +export type TypeDescElementBlueprintOrNamedGeneric = + | TypeDescElementBlueprint + | NamedGeneric; export interface RecordBlueprint { - record: Record; + record: Record; } +export type LeafPlusType = LeafExpressionType | 'any'; export type TypeDescElementBlueprint = - | LeafExpressionType + | LeafPlusType | ArrayBlueprint | RecordBlueprint; +export type NamedGeneric = {generic: string}; export type TypeDescBlueprint = - | TypeDescElementBlueprint - | {generic: string} - | {literal: LeafExpressionType | {generic: string}} - | {constant: LeafExpressionType | {generic: string}} - | {dimension: LeafExpressionType | {generic: string}} - | {measure: LeafExpressionType | {generic: string}} - | {calculation: LeafExpressionType | {generic: string}}; + | TypeDescElementBlueprintOrNamedGeneric + | {literal: TypeDescElementBlueprintOrNamedGeneric} + | {constant: TypeDescElementBlueprintOrNamedGeneric} + | {dimension: TypeDescElementBlueprintOrNamedGeneric} + | {measure: TypeDescElementBlueprintOrNamedGeneric} + | {calculation: TypeDescElementBlueprintOrNamedGeneric}; type ParamTypeBlueprint = | TypeDescBlueprint @@ -239,7 +258,7 @@ type ParamTypeBlueprint = export interface SignatureBlueprint { // today only one generic is allowed, but if we need more // we could change this to `{[name: string]: ExpressionValueType[]}` - generic?: [string, LeafExpressionType[]]; + generic?: [string, TypeDescElementBlueprintOrNamedGeneric[]]; takes: {[name: string]: ParamTypeBlueprint}; returns: TypeDescBlueprint; supportsOrderBy?: boolean | 'only_default'; @@ -290,86 +309,104 @@ export type OverrideMap = { [name: string]: ImplementationBlueprint | OverloadedImplementationBlueprint; }; -function removeGeneric( - type: LeafExpressionType | {generic: string}, - generic: {name: string; type: LeafExpressionType} | undefined -) { - if (typeof type === 'string') { - return type; - } - if (type.generic !== generic?.name) { - throw new Error(`Cannot expand generic name ${type.generic}`); +function expandTypeDescElementBlueprint( + blueprint: TypeDescElementBlueprintOrNamedGeneric, + allowAny: false +): FunctionReturnTypeDef; +function expandTypeDescElementBlueprint( + blueprint: TypeDescElementBlueprintOrNamedGeneric, + allowAny?: true +): FunctionParamTypeDef; +function expandTypeDescElementBlueprint( + blueprint: TypeDescElementBlueprintOrNamedGeneric, + allowAny: true, + allowGenerics: false +): FunctionGenericTypeDef; +function expandTypeDescElementBlueprint( + blueprint: TypeDescElementBlueprintOrNamedGeneric, + allowAny = true, + allowGenerics = true +): FunctionParamTypeDef | FunctionReturnTypeDef | TypedDef { + if (!allowAny && blueprint === 'any') { + throw new Error('Return type cannot include any'); } - return generic.type; -} - -function expandReturnTypeBlueprint( - blueprint: TypeDescBlueprint, - generic: {name: string; type: LeafExpressionType} | undefined -): TypeDesc { - let base: FunctionParamTypeDesc; if (typeof blueprint === 'string') { - base = minScalar(blueprint); + return {type: blueprint}; } else if ('array' in blueprint) { - const innerType = expandReturnTypeBlueprint(blueprint.array, generic); - const {expressionType, evalSpace} = innerType; - if (TD.isAtomic(innerType)) { - if (innerType.type !== 'record') { - base = { - type: 'array', - elementTypeDef: innerType, - expressionType, - evalSpace, - }; - } else { - base = { - type: 'array', - elementTypeDef: {type: 'record_element'}, - fields: innerType.fields, - expressionType, - evalSpace, - }; - } - } else { - // mtoy todo fix by doing "exapndElementBlueprint" ... - throw new Error( - `TypeDescElementBlueprint should never allow ${blueprint.array}` - ); + const innerType = allowAny + ? expandTypeDescElementBlueprint(blueprint.array, true) + : expandTypeDescElementBlueprint(blueprint.array, false); + if (innerType.type === 'record') { + return { + type: 'array', + elementTypeDef: {type: 'record_element'}, + fields: innerType.fields, + }; } + return { + type: 'array', + elementTypeDef: innerType, + }; } else if ('record' in blueprint) { const fields: FieldDef[] = []; for (const [fieldName, fieldBlueprint] of Object.entries( blueprint.record )) { - const fieldDesc = expandReturnTypeBlueprint(fieldBlueprint, generic); + const fieldDesc = allowAny + ? expandTypeDescElementBlueprint(fieldBlueprint, true) + : expandTypeDescElementBlueprint(fieldBlueprint, false); if (TD.isAtomic(fieldDesc)) { fields.push(mkFieldDef(fieldDesc, fieldName)); } } - base = { + return { type: 'record', fields, - evalSpace: 'input', - expressionType: 'scalar', }; } else if ('generic' in blueprint) { - base = minScalar(removeGeneric(blueprint, generic)); + if (!allowGenerics) { + throw new Error('Cannot use generic'); + } + return {type: 'generic', generic: blueprint.generic}; + } + throw new Error('Cannot figure out type'); +} + +function expandReturnTypeBlueprint( + blueprint: TypeDescBlueprint +): FunctionReturnTypeDesc { + if (blueprint === 'any') { + throw new Error('Cannot return any type'); + } + if (typeof blueprint === 'string') { + return minScalar({type: blueprint}); + } else if ('array' in blueprint) { + return anyExprType(expandTypeDescElementBlueprint(blueprint, false)); + } else if ('record' in blueprint) { + return anyExprType(expandTypeDescElementBlueprint(blueprint, false)); + } else if ('generic' in blueprint) { + return minScalar(expandTypeDescElementBlueprint(blueprint, false)); } else if ('literal' in blueprint) { - base = literal(minScalar(removeGeneric(blueprint.literal, generic))); + return literal( + minScalar(expandTypeDescElementBlueprint(blueprint.literal, false)) + ); } else if ('constant' in blueprint) { - base = constant(minScalar(removeGeneric(blueprint.constant, generic))); + return constant( + minScalar(expandTypeDescElementBlueprint(blueprint.constant, false)) + ); } else if ('dimension' in blueprint) { - base = minScalar(removeGeneric(blueprint.dimension, generic)); + return minScalar( + expandTypeDescElementBlueprint(blueprint.dimension, false) + ); } else if ('measure' in blueprint) { - base = minAggregate(removeGeneric(blueprint.measure, generic)); + return minAggregate( + expandTypeDescElementBlueprint(blueprint.measure, false) + ); } else { - base = minAnalytic(removeGeneric(blueprint.calculation, generic)); + return minAnalytic( + expandTypeDescElementBlueprint(blueprint.calculation, false) + ); } - return { - ...base, - compositeFieldUsage: emptyCompositeFieldUsage(), - expressionType: base.expressionType ?? 'scalar', - }; } function isTypeDescBlueprint( @@ -403,37 +440,35 @@ function extractParamTypeBlueprints( } function expandParamTypeBlueprint( - blueprint: TypeDescBlueprint, - generic: {name: string; type: LeafExpressionType} | undefined + blueprint: TypeDescBlueprint ): FunctionParamTypeDesc { if (typeof blueprint === 'string') { - return anyExprType(blueprint); + return anyExprType({type: blueprint}); } else if ('generic' in blueprint) { - return anyExprType(removeGeneric(blueprint, generic)); + return anyExprType(expandTypeDescElementBlueprint(blueprint)); } else if ('literal' in blueprint) { - return literal(maxScalar(removeGeneric(blueprint.literal, generic))); + return literal( + maxScalar(expandTypeDescElementBlueprint(blueprint.literal)) + ); } else if ('constant' in blueprint) { - return constant(maxScalar(removeGeneric(blueprint.constant, generic))); + return constant( + maxScalar(expandTypeDescElementBlueprint(blueprint.constant)) + ); } else if ('dimension' in blueprint) { - return maxScalar(removeGeneric(blueprint.dimension, generic)); + return maxScalar(expandTypeDescElementBlueprint(blueprint.dimension)); } else if ('measure' in blueprint) { - return maxAggregate(removeGeneric(blueprint.measure, generic)); + return maxAggregate(expandTypeDescElementBlueprint(blueprint.measure)); } else if ('array' in blueprint) { - return anyExprTypeBP(blueprint, generic); + return anyExprType(expandTypeDescElementBlueprint(blueprint, false)); } else if ('record' in blueprint) { - return anyExprTypeBP(blueprint, generic); + return anyExprType(expandTypeDescElementBlueprint(blueprint, false)); } else { - return maxAnalytic(removeGeneric(blueprint.calculation, generic)); + return maxAnalytic(expandTypeDescElementBlueprint(blueprint.calculation)); } } -function expandParamTypeBlueprints( - blueprints: TypeDescBlueprint[], - generic: {name: string; type: LeafExpressionType} | undefined -) { - return blueprints.map(blueprint => - expandParamTypeBlueprint(blueprint, generic) - ); +function expandParamTypeBlueprints(blueprints: TypeDescBlueprint[]) { + return blueprints.map(blueprint => expandParamTypeBlueprint(blueprint)); } function isVariadicParamBlueprint(blueprint: ParamTypeBlueprint): boolean { @@ -442,26 +477,23 @@ function isVariadicParamBlueprint(blueprint: ParamTypeBlueprint): boolean { function expandParamBlueprint( name: string, - blueprint: ParamTypeBlueprint, - generic: {name: string; type: LeafExpressionType} | undefined + blueprint: ParamTypeBlueprint ): FunctionParameterDef { return { name, allowedTypes: expandParamTypeBlueprints( - extractParamTypeBlueprints(blueprint), - generic + extractParamTypeBlueprints(blueprint) ), isVariadic: isVariadicParamBlueprint(blueprint), }; } -function expandParamsBlueprints( - blueprints: {[name: string]: ParamTypeBlueprint}, - generic: {name: string; type: LeafExpressionType} | undefined -) { +function expandParamsBlueprints(blueprints: { + [name: string]: ParamTypeBlueprint; +}) { const paramsArray = Object.entries(blueprints); return paramsArray.map(blueprint => - expandParamBlueprint(blueprint[0], blueprint[1], generic) + expandParamBlueprint(blueprint[0], blueprint[1]) ); } @@ -562,32 +594,34 @@ function expandImplBlueprint(blueprint: DefinitionBlueprint): { }; } -function expandOneBlueprint( - blueprint: DefinitionBlueprint, - generic?: {name: string; type: LeafExpressionType} +function expandGenericDefinitions( + blueprint: [string, TypeDescElementBlueprintOrNamedGeneric[]] | undefined +): {name: string; acceptibleTypes: FunctionGenericTypeDef[]}[] | undefined { + if (blueprint === undefined) return undefined; + return [ + { + name: blueprint[0], + acceptibleTypes: blueprint[1].map(t => + expandTypeDescElementBlueprint(t, true, false) + ), + }, + ]; +} + +function expandBlueprint( + blueprint: DefinitionBlueprint ): DialectFunctionOverloadDef { return { - returnType: expandReturnTypeBlueprint(blueprint.returns, generic), - params: expandParamsBlueprints(blueprint.takes, generic), + returnType: expandReturnTypeBlueprint(blueprint.returns), + params: expandParamsBlueprints(blueprint.takes), isSymmetric: blueprint.isSymmetric, supportsOrderBy: blueprint.supportsOrderBy, supportsLimit: blueprint.supportsLimit, + genericTypes: expandGenericDefinitions(blueprint.generic), ...expandImplBlueprint(blueprint), }; } -function expandBlueprint( - blueprint: DefinitionBlueprint -): DialectFunctionOverloadDef[] { - if (blueprint.generic !== undefined) { - const name = blueprint.generic[0]; - return blueprint.generic[1].map(type => - expandOneBlueprint(blueprint, {name, type}) - ); - } - return [expandOneBlueprint(blueprint)]; -} - function isDefinitionBlueprint( blueprint: DefinitionBlueprint | OverloadedDefinitionBlueprint ): blueprint is DefinitionBlueprint { @@ -604,7 +638,7 @@ function expandOverloadedBlueprint( blueprint: DefinitionBlueprint | OverloadedDefinitionBlueprint ): DialectFunctionOverloadDef[] { if (isDefinitionBlueprint(blueprint)) { - return expandBlueprint(blueprint); + return [expandBlueprint(blueprint)]; } else { return Object.values(blueprint).flatMap(overload => expandBlueprint(overload) @@ -624,7 +658,7 @@ function expandImplementationBlueprint( base: DefinitionBlueprint, impl: ImplementationBlueprint ): DialectFunctionOverloadDef[] { - return expandBlueprint({...base, impl}); + return [expandBlueprint({...base, impl})]; } function expandOverloadedOverrideBlueprint( diff --git a/packages/malloy/src/lang/ast/expressions/expr-func.ts b/packages/malloy/src/lang/ast/expressions/expr-func.ts index 8edfa54ff..fafbc2f38 100644 --- a/packages/malloy/src/lang/ast/expressions/expr-func.ts +++ b/packages/malloy/src/lang/ast/expressions/expr-func.ts @@ -33,12 +33,25 @@ import { ExpressionValueTypeDef, FunctionCallNode, FunctionDef, + FunctionGenericNonAnyTypeDef, + FunctionGenericTypeDef, FunctionOverloadDef, FunctionParameterDef, + FunctionParamFieldDef, + FunctionParamTypeDef, + FunctionReturnTypeDef, + FunctionReturnTypeDesc, + isAtomic, isAtomicFieldType, + isAtomicXYZ, isExpressionTypeLEQ, + isRepeatedRecordFunctionParam, + isScalarArray, maxOfExpressionTypes, mergeEvalSpaces, + RecordFunctionParamTypeDef, + RecordFunctionReturnTypeDef, + RecordTypeDef, TD, } from '../../../model/malloy_types'; import {errorFor} from '../ast-utils'; @@ -199,8 +212,13 @@ export class ExprFunc extends ExpressionDef { .join(', ')})` ); } - const {overload, expressionTypeErrors, evalSpaceErrors, nullabilityErrors} = - result; + const { + overload, + expressionTypeErrors, + evalSpaceErrors, + nullabilityErrors, + returnType, + } = result; // Report errors for expression type mismatch for (const error of expressionTypeErrors) { const adjustedIndex = error.argIndex - (implicitExpr ? 1 : 0); @@ -245,7 +263,7 @@ export class ExprFunc extends ExpressionDef { } const type = overload.returnType; const expressionType = maxOfExpressionTypes([ - type.expressionType, + type.expressionType ?? 'scalar', ...argExprs.map(e => e.expressionType), ]); if ( @@ -410,12 +428,6 @@ export class ExprFunc extends ExpressionDef { funcCall = composeSQLExpr(expr); } } - if (type.type === 'any') { - return this.loggedErrorExpr( - 'function-returns-any', - `Invalid return type ${type.type} for function '${this.name}'` - ); - } const maxEvalSpace = mergeEvalSpaces(...argExprs.map(e => e.evalSpace)); // If the merged eval space of all args is constant, the result is constant. // If the expression is scalar, then the eval space is that merged eval space. @@ -431,7 +443,8 @@ export class ExprFunc extends ExpressionDef { // TODO consider if I can use `computedExprValue` here... // seems like the rules for the evalSpace is a bit different from normal though return { - ...TDU.atomicDef(type), + // TODO need to handle this??? + ...(isAtomic(returnType) ? TDU.atomicDef(returnType) : returnType), expressionType, value: funcCall, evalSpace, @@ -461,6 +474,12 @@ type NullabilityError = { param: FunctionParameterDef; }; +type ReturnTypeError = { + // TODO + code: string; + data: string; +}; + function findOverload( func: FunctionDef, args: ExprValue[] @@ -470,9 +489,13 @@ function findOverload( expressionTypeErrors: ExpressionTypeError[]; evalSpaceErrors: EvalSpaceError[]; nullabilityErrors: NullabilityError[]; + returnType: ExpressionValueTypeDef; + returnTypeError?: ReturnTypeError; } | undefined { for (const overload of func.overloads) { + // Map from generic name to selected type + const genericsSelected = new Map(); let paramIndex = 0; let ok = true; let matchedVariadic = false; @@ -489,14 +512,14 @@ function findOverload( const argOk = param.allowedTypes.some(paramT => { // Check whether types match (allowing for nullability errors, expression type errors, // eval space errors, and unknown types due to prior errors in args) - const dataTypeMatch = - TD.eq(paramT, arg) || - paramT.type === 'any' || - // TODO We should consider whether `nulls` should always be allowed. It probably - // does not make sense to limit function calls to not allow nulls, since have - // so little control over nullability. - arg.type === 'null' || - arg.type === 'error'; + const {dataTypeMatch, genericsSet} = isDataTypeMatch( + overload.genericTypes ?? [], + arg, + paramT + ); + for (const genericSet of genericsSet) { + genericsSelected.set(genericSet.name, genericSet.type); + } // Check expression type errors if (paramT.expressionType) { const expressionTypeMatch = isExpressionTypeLEQ( @@ -560,12 +583,20 @@ function findOverload( ) { continue; } + const resolveReturnType = resolveGenerics( + overload.returnType, + genericsSelected + ); + const returnType = resolveReturnType.returnType ?? {type: 'number'}; if (ok) { return { overload, expressionTypeErrors, evalSpaceErrors, nullabilityErrors, + returnTypeError: resolveReturnType.error, + // TODO don't be bad!!! + returnType: returnType as ExpressionValueTypeDef, }; } } @@ -601,3 +632,176 @@ function parseSQLInterpolation(template: string): InterpolationPart[] { } return parts; } + +type OneSetGeneric = {name: string; type: FunctionGenericNonAnyTypeDef}; + +function isDataTypeMatch( + genericTypes: {name: string; acceptibleTypes: FunctionGenericTypeDef[]}[], + arg: ExpressionValueTypeDef, + paramT: FunctionGenericTypeDef | FunctionParamTypeDef +): { + dataTypeMatch: boolean; + genericsSet: OneSetGeneric[]; +} { + if ( + TD.eq(paramT, arg) || + paramT.type === 'any' || + // TODO We should consider whether `nulls` should always be allowed. It probably + // does not make sense to limit function calls to not allow nulls, since have + // so little control over nullability. + arg.type === 'null' || + arg.type === 'error' + ) { + return {dataTypeMatch: true, genericsSet: []}; + } + if (paramT.type === 'array' && arg.type === 'array') { + if (isScalarArray(arg)) { + if (!isRepeatedRecordFunctionParam(paramT)) { + return isDataTypeMatch( + genericTypes, + arg.elementTypeDef, + paramT.elementTypeDef + ); + } else { + return {dataTypeMatch: false, genericsSet: []}; + } + } else if (isRepeatedRecordFunctionParam(paramT)) { + const fakeParamRecord: RecordFunctionParamTypeDef = { + type: 'record', + fields: paramT.fields, + }; + const fakeArgRecord: RecordTypeDef = { + type: 'record', + fields: arg.fields, + }; + return isDataTypeMatch(genericTypes, fakeArgRecord, fakeParamRecord); + } else { + return {dataTypeMatch: false, genericsSet: []}; + } + } else if (paramT.type === 'record' && arg.type === 'record') { + const genericsSet: OneSetGeneric[] = []; + const paramFieldsByName = new Map(); + for (const field of paramT.fields) { + paramFieldsByName.set(field.as ?? field.name, field); + } + for (const field of arg.fields) { + const match = paramFieldsByName.get(field.as ?? field.name); + if (match === undefined) { + return {dataTypeMatch: false, genericsSet: []}; + } + const result = isDataTypeMatch(genericTypes, field, match); + genericsSet.push(...result.genericsSet); + } + return {dataTypeMatch: true, genericsSet}; + } else if (paramT.type === 'generic') { + const allowedTypes = + genericTypes.find(t => t.name === paramT.generic)?.acceptibleTypes ?? []; + for (const type of allowedTypes) { + const result = isDataTypeMatch(genericTypes, arg, type); + if (result.dataTypeMatch) { + if (!isAtomic(arg)) { + continue; + } + const newGenericSet: OneSetGeneric = { + name: paramT.generic, + type: arg, + }; + return { + dataTypeMatch: true, + genericsSet: [...result.genericsSet, newGenericSet], + }; + } + } + } + return {dataTypeMatch: false, genericsSet: []}; +} + +function resolveGenerics( + returnType: + | FunctionReturnTypeDesc + | Exclude, + genericsSelected: Map +): + | {error: undefined; returnType: FunctionGenericNonAnyTypeDef} + | {error: {code: string; data: string}; returnType: undefined} { + switch (returnType.type) { + case 'array': { + if ('fields' in returnType) { + const fields = returnType.fields.map(f => { + const type = resolveGenerics(f, genericsSelected); + return { + ...f, + ...type, + }; + }); + return { + error: undefined, + returnType: { + type: 'array', + elementTypeDef: returnType.elementTypeDef, + fields, + }, + }; + } + const resolve = resolveGenerics( + returnType.elementTypeDef, + genericsSelected + ); + if (resolve.error) { + return resolve; + } + const elementTypeDef = resolve.returnType; + if (elementTypeDef.type === 'record') { + // TODO if this happens, construct the repeated record + return { + error: { + code: 'invalid-resolved-type-for-array', + data: 'Invalid resolved type for array; cannot be record', + }, + returnType: undefined, + }; + } + if (!isAtomicXYZ(elementTypeDef)) { + return { + error: { + code: 'invalid-resolved-type-for-array', + data: 'Invalid resolved type for array; cannot be non-atomic', + }, + returnType: undefined, + }; + } + return { + error: undefined, + returnType: {type: 'array', elementTypeDef}, + }; + } + case 'record': { + const fields = returnType.fields.map(f => { + const type = resolveGenerics(f, genericsSelected); + return { + ...f, + ...type, + }; + }); + return {error: undefined, returnType: {type: 'record', fields}}; + } + case 'generic': { + const resolved = genericsSelected.get(returnType.generic); + if (resolved === undefined) { + return { + error: { + code: 'generic-not-resolved', + data: `Generic ${returnType.generic} in return type could not be resolved`, + }, + returnType: undefined, + }; + } + return { + error: undefined, + returnType: resolved, + }; + } + default: + return {error: undefined, returnType}; + } +} diff --git a/packages/malloy/src/lang/ast/types/dialect-name-space.ts b/packages/malloy/src/lang/ast/types/dialect-name-space.ts index 91a47a6cf..0223ed6f8 100644 --- a/packages/malloy/src/lang/ast/types/dialect-name-space.ts +++ b/packages/malloy/src/lang/ast/types/dialect-name-space.ts @@ -29,6 +29,7 @@ export class DialectNameSpace implements NameSpace { supportsOrderBy: overload.supportsOrderBy, supportsLimit: overload.supportsLimit, isSymmetric: overload.isSymmetric, + genericTypes: overload.genericTypes, dialect: { [dialect.name]: { e: overload.e, diff --git a/packages/malloy/src/lang/ast/types/global-name-space.ts b/packages/malloy/src/lang/ast/types/global-name-space.ts index a0d3bd651..66b1450f4 100644 --- a/packages/malloy/src/lang/ast/types/global-name-space.ts +++ b/packages/malloy/src/lang/ast/types/global-name-space.ts @@ -104,6 +104,7 @@ export function getDialectFunctions(): Map { dialect: {}, supportsOrderBy: baseOverload.supportsOrderBy, supportsLimit: baseOverload.supportsLimit, + genericTypes: baseOverload.genericTypes, isSymmetric: baseOverload.isSymmetric, }; for (const dialect of dialects) { diff --git a/packages/malloy/src/model/malloy_types.ts b/packages/malloy/src/model/malloy_types.ts index da97f8cc9..9ea4934bf 100644 --- a/packages/malloy/src/model/malloy_types.ts +++ b/packages/malloy/src/model/malloy_types.ts @@ -786,6 +786,30 @@ export interface RepeatedRecordDef export type ArrayTypeDef = ScalarArrayTypeDef | RepeatedRecordTypeDef; export type ArrayDef = ScalarArrayDef | RepeatedRecordDef; +// function isRepeatedRecordXYZ( +// paramT: XYZTypeDef +// ): paramT is RepeatedRecordXYZTypeDef { +// return ( +// paramT.type === 'array' && paramT.elementTypeDef.type === 'record_element' +// ); +// } + +// function isScalarArrayXYZ( +// paramT: XYZTypeDef +// ): paramT is ScalarArrayXYZTypeDef { +// return ( +// paramT.type === 'array' && paramT.elementTypeDef.type !== 'record_element' +// ); +// } + +export function isRepeatedRecordFunctionParam( + paramT: FunctionParamTypeDef +): paramT is RepeatedRecordFunctionParamTypeDef { + return ( + paramT.type === 'array' && paramT.elementTypeDef.type === 'record_element' + ); +} + export function isRepeatedRecord( fd: FieldDef | QueryFieldDef | StructDef | AtomicTypeDef ): fd is RepeatedRecordTypeDef { @@ -1200,7 +1224,6 @@ export type NonAtomicType = | 'turtle' // do NOT have the full type info, just noting the type | 'null' | 'duration' - | 'any' | 'regular expression'; export interface NonAtomicTypeDef { type: NonAtomicType; @@ -1221,11 +1244,100 @@ export type TypeInfo = { export type TypeDesc = ExpressionValueTypeDef & TypeInfo; -export type FunctionParamType = ExpressionValueTypeDef | {type: 'any'}; -export type FunctionParamTypeDesc = FunctionParamType & { +export type FunctionParamTypeDef = XYZTypeDef; +export type FunctionParamTypeDesc = FunctionParamTypeDef & { expressionType: ExpressionType | undefined; evalSpace: EvalSpace; }; +// +export interface ScalarArrayXYZTypeDef { + type: 'array'; + elementTypeDef: Exclude, RecordXYZTypeDef>; +} + +// TODO? +export type FieldTypeDef = LeafAtomicTypeDef; +// | Join +// | Turtle; + +// export type XYZTypeDef = FieldTypeDef | XYZ; +export type XYZTypeDef = + | AtomicTypeDef + | NonAtomicTypeDef + | ScalarArrayXYZTypeDef + | RecordXYZTypeDef + | RepeatedRecordXYZTypeDef + | XYZ; + +export interface RecordXYZTypeDef { + type: 'record'; + fields: XYZFieldDef[]; +} + +export type XYZFieldDef = FieldDef | (XYZ & FieldBase); + +export interface RepeatedRecordXYZTypeDef { + type: 'array'; + elementTypeDef: RecordElementTypeDef; + fields: XYZFieldDef[]; +} + +// + +type ReturnTypeExtensions = GenericTypeDef; + +export type ScalarArrayFunctionReturnTypeDef = + ScalarArrayXYZTypeDef; + +export type FunctionReturnFieldDef = XYZFieldDef; + +export type RecordFunctionReturnTypeDef = + RecordXYZTypeDef; + +export type RepeatedRecordFunctionReturnTypeDef = + RepeatedRecordXYZTypeDef; + +type ParamTypeExtensions = GenericTypeDef | AnyTypeDef; + +export type ScalarArrayFunctionParamTypeDef = + ScalarArrayXYZTypeDef; + +export type FunctionParamFieldDef = XYZFieldDef; + +export type RecordFunctionParamTypeDef = RecordXYZTypeDef; + +export type RepeatedRecordFunctionParamTypeDef = + RepeatedRecordXYZTypeDef; + +type GenericTypeExtensions = AnyTypeDef; + +export type ScalarArrayFunctionGenericTypeDef = + ScalarArrayXYZTypeDef; + +export type FunctionGenericFieldDef = XYZFieldDef; + +export type RecordFunctionGenericTypeDef = + RecordXYZTypeDef; + +export type RepeatedRecordFunctionGenericTypeDef = + RepeatedRecordXYZTypeDef; + +export interface GenericTypeDef { + type: 'generic'; + generic: string; +} + +export interface AnyTypeDef { + type: 'any'; +} + +export type TypeDescExtensions = { + expressionType: ExpressionType | undefined; + evalSpace: EvalSpace; +}; + +export type FunctionReturnTypeDef = XYZTypeDef; +export type FunctionReturnTypeDesc = FunctionReturnTypeDef & TypeDescExtensions; export type EvalSpace = 'constant' | 'input' | 'output' | 'literal'; @@ -1255,13 +1367,19 @@ export interface FunctionParameterDef { isVariadic: boolean; } +// TODO name? +export type FunctionGenericNonAnyTypeDef = XYZTypeDef; + +export type FunctionGenericTypeDef = XYZTypeDef; + export interface FunctionOverloadDef { // The expression type here is the MINIMUM return type - returnType: TypeDesc; + returnType: FunctionReturnTypeDesc; isSymmetric?: boolean; params: FunctionParameterDef[]; supportsOrderBy?: boolean | 'only_default'; supportsLimit?: boolean; + genericTypes?: {name: string; acceptibleTypes: FunctionGenericTypeDef[]}[]; dialect: { [dialect: string]: { e: Expr; @@ -1442,7 +1560,16 @@ export function isTurtle(def: TypedDef): def is TurtleDef { return def.type === 'turtle'; } -export function isAtomic(def: TypedDef): def is AtomicTypeDef { +export function isAtomicXYZ< + T extends object, + XYZ extends TypedDef | ExpressionValueTypeDef | XYZTypeDef, +>(def: XYZTypeDef): def is AtomicTypeDef { + return 'type' in def && isAtomicFieldType(def.type); +} + +export function isAtomic( + def: TypedDef | ExpressionValueTypeDef +): def is AtomicTypeDef { return isAtomicFieldType(def.type); } @@ -1507,7 +1634,12 @@ export interface PrepareResultOptions { materializedTablePrefix?: string; } -type UTD = AtomicTypeDef | FunctionParamTypeDesc | undefined; +type UTD = + | AtomicTypeDef + | TypedDef + | FunctionParamTypeDef + | FunctionReturnTypeDef + | undefined; /** * A set of utilities for asking questions TypeDef/TypeDesc * (which is OK because TypeDesc is an extension of a TypeDef) @@ -1561,9 +1693,9 @@ export const TD = { ) { return TD.eq(x.elementTypeDef, y.elementTypeDef); } - return checkFields(x, y); + return TD.isAtomic(x) && TD.isAtomic(y) && checkFields(x, y); } else if (x.type === 'record' && y.type === 'record') { - return checkFields(x, y); + return TD.isAtomic(x) && TD.isAtomic(y) && checkFields(x, y); } if (x.type === 'sql native' && y.type === 'sql native') { return x.rawType !== undefined && x.rawType === y.rawType; diff --git a/test/src/databases/all/functions.spec.ts b/test/src/databases/all/functions.spec.ts index 64de3f27f..7958cf79b 100644 --- a/test/src/databases/all/functions.spec.ts +++ b/test/src/databases/all/functions.spec.ts @@ -27,7 +27,8 @@ import {booleanResult, brokenIn, databasesFromEnvironmentOr} from '../../util'; import '../../util/db-jest-matchers'; import * as malloy from '@malloydata/malloy'; -const runtimes = new RuntimeList(databasesFromEnvironmentOr(allDatabases)); +const runtimes = new RuntimeList(databasesFromEnvironmentOr(['duckdb'])); +// const runtimes = new RuntimeList(databasesFromEnvironmentOr(allDatabases)); function modelText(databaseName: string) { return ` @@ -1312,13 +1313,19 @@ expressionModels.forEach((x, databaseName) => { describe('dialect functions', () => { describe('duckdb', () => { - const duckdb = it.when(databaseName === 'duckdb'); - duckdb('to_timestamp', async () => { + const isDuckdb = databaseName === 'duckdb'; + it.when(isDuckdb)('to_timestamp', async () => { await funcTest( 'to_timestamp(1725555835) = @2024-09-05 17:03:55', booleanResult(true, databaseName) ); }); + it.when(isDuckdb)('test_foo', async () => { + await funcTest('test_foo(5)', 5); + }); + it.when(isDuckdb)('array_first', async () => { + await funcTest('array_first(array_first([[5]]))', 5); + }); }); describe('trino', () => {