diff --git a/Project.toml b/Project.toml index 0f3b1c15..df892bee 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.13.2" [deps] DBInterface = "a10d1c49-ce27-4219-8d33-6db1a4562965" +DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" @@ -13,6 +14,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] DBInterface = "2.5" +DataAPI = "1.13" LRUCache = "1.3" OrderedCollections = "1.4" PrettyPrinting = "0.3.2, 0.4" diff --git a/docs/src/examples/index.md b/docs/src/examples/index.md index d05400fe..0aceb823 100644 --- a/docs/src/examples/index.md +++ b/docs/src/examples/index.md @@ -346,10 +346,7 @@ the definitions programmatically. !contains(String(c), "source") q = From(:person) |> - Select(Get.(filter(is_not_source_column, person_table.columns))...) - - # q = From(:person) |> - # Select(args = [Get(c) for c in person_table.columns if is_not_source_column(c)]) + Select(args = [Get(c) for c in keys(person_table.columns) if is_not_source_column(c)]) display(q) #=> @@ -447,16 +444,16 @@ however we must ensure that all column names are unique. const visit_occurrence_table = conn.catalog[:visit_occurrence] q = q |> - Select(Get.(person_table.columns)..., - Get.(visit_occurrence_table.columns, over = Get.visit)...) + Select(Get.(keys(person_table.columns))..., + Get.(keys(visit_occurrence_table.columns), over = Get.visit)...) #=> ERROR: FunSQL.DuplicateLabelError: `person_id` is used more than once in: ⋮ =# q = q |> - Select(Get.(person_table.columns)..., - Get.(filter(!in(person_table.columns), visit_occurrence_table.columns), + Select(Get.(keys(person_table.columns))..., + Get.(filter(!in(keys(person_table.columns)), collect(keys(visit_occurrence_table.columns))), over = Get.visit)...) render(conn, q) |> print diff --git a/docs/src/reference/index.md b/docs/src/reference/index.md index 23a073ed..24420ce8 100644 --- a/docs/src/reference/index.md +++ b/docs/src/reference/index.md @@ -34,7 +34,7 @@ Pages = ["connections.jl"] ``` -## `SQLCatalog` and `SQLTable` +## `SQLCatalog`, `SQLTable`, and `SQLColumn` ```@autodocs Modules = [FunSQL] diff --git a/docs/src/test/nodes.md b/docs/src/test/nodes.md index 060c5742..6f789429 100644 --- a/docs/src/test/nodes.md +++ b/docs/src/test/nodes.md @@ -2277,7 +2277,7 @@ for a `CREATE TABLE AS` or `SELECT INTO` statement. with_external_handler((tbl, def)) = println("CREATE TEMP TABLE ", render(ID(tbl.qualifiers, tbl.name)), - " (", join([render(ID(c)) for c in tbl.columns], ", "), ") AS\n", + " (", join([render(ID(c.name)) for (n, c) in tbl.columns], ", "), ") AS\n", render(def), ";\n") q = From(:male) |> @@ -3872,7 +3872,8 @@ and determines node types. ⋮ │ WithContext(over = Resolved(RowType(:person_id => ScalarType(), │ :max_visit_start_date => ScalarType()), - │ over = q9)) + │ over = q9), + │ catalog = SQLCatalog(dialect = SQLDialect(), cache = nothing)) │ end └ @ FunSQL … =# @@ -3896,7 +3897,8 @@ produce. │ q5 = Get.year_of_birth, │ q6 = Linked([q2, q3, q4, q5], 3, over = q1), ⋮ - │ WithContext(over = q33) + │ WithContext(over = q33, + │ catalog = SQLCatalog(dialect = SQLDialect(), cache = nothing)) │ end └ @ FunSQL … =# @@ -3941,7 +3943,8 @@ On the next stage, the query object is converted to a SQL syntax tree. │ ID(:visit_group_1) |> ID(:person_id)), │ left = true) |> │ SELECT(ID(:person_2) |> ID(:person_id), - │ ID(:visit_group_1) |> ID(:max) |> AS(:max_visit_start_date))) + │ ID(:visit_group_1) |> ID(:max) |> AS(:max_visit_start_date)), + │ columns = [SQLColumn(:person_id), SQLColumn(:max_visit_start_date)]) └ @ FunSQL … =# @@ -3976,6 +3979,7 @@ Finally, the SQL tree is serialized into SQL. │ "visit_occurrence_1"."person_id" │ FROM "visit_occurrence" AS "visit_occurrence_1" │ GROUP BY "visit_occurrence_1"."person_id" - │ ) AS "visit_group_1" ON ("person_2"."person_id" = "visit_group_1"."person_id")""") + │ ) AS "visit_group_1" ON ("person_2"."person_id" = "visit_group_1"."person_id")""", + │ columns = [SQLColumn(:person_id), SQLColumn(:max_visit_start_date)]) └ @ FunSQL … =# diff --git a/docs/src/test/other.md b/docs/src/test/other.md index d81945fc..c1bf6611 100644 --- a/docs/src/test/other.md +++ b/docs/src/test/other.md @@ -66,23 +66,25 @@ by name. DBInterface.close!(conn) -## `SQLCatalog` and `SQLTable` +## `SQLCatalog`, `SQLTable`, and `SQLColumn` In FunSQL, tables and table-like entities are represented using `SQLTable` -objects. A collection of `SQLTable` objects is represented as a `SQLCatalog` +objects. Their columns are represented using `SQLColumn` objects. +A collection of `SQLTable` objects is represented as a `SQLCatalog` object. - using FunSQL: SQLCatalog, SQLTable + using FunSQL: SQLCatalog, SQLColumn, SQLTable -A `SQLTable` constructor takes the table name, a vector of column names, -and, optionally, the name of the table schema and other qualifiers. A name -could be provided either as a `Symbol` or as a `String` value. +A `SQLTable` constructor takes the table name, a vector of columns, and, +optionally, the name of the table schema and other qualifiers. A name +could be provided either as a `Symbol` or as a `String` value. A column +can be specified just by its name. location = SQLTable(qualifiers = [:public], name = :location, columns = [:location_id, :address_1, :address_2, :city, :state, :zip]) - #-> SQLTable(:location, qualifiers = [:public], …) + #-> SQLTable(qualifiers = [:public], :location, …) person = SQLTable(name = "person", columns = ["person_id", "year_of_birth", "location_id"]) @@ -90,54 +92,105 @@ could be provided either as a `Symbol` or as a `String` value. The table and the column names could be provided as positional arguments. - vocabulary = SQLTable(:vocabulary, - columns = [:vocabulary_id, :vocabulary_name]) - #-> SQLTable(:vocabulary, …) - concept = SQLTable("concept", "concept_id", "concept_name", "vocabulary_id") #-> SQLTable(:concept, …) +A column may have a custom name for use with FunSQL and the original name +for generating SQL queries. + + vocabulary = SQLTable(:vocabulary, + :id => SQLColumn(:vocabulary_id), + :name => SQLColumn(:vocabulary_name)) + #-> SQLTable(:vocabulary, …) + A `SQLTable` object is displayed as a Julia expression that created the object. display(location) #=> - SQLTable(:location, - qualifiers = [:public], - columns = [:location_id, :address_1, :address_2, :city, :state, :zip]) + SQLTable(qualifiers = [:public], + :location, + SQLColumn(:location_id), + SQLColumn(:address_1), + SQLColumn(:address_2), + SQLColumn(:city), + SQLColumn(:state), + SQLColumn(:zip)) =# - display(person) + display(vocabulary) #=> - SQLTable(:person, columns = [:person_id, :year_of_birth, :location_id]) + SQLTable(:vocabulary, + :id => SQLColumn(:vocabulary_id), + :name => SQLColumn(:vocabulary_name)) =# +A `SQLTable` object behaves like a read-only dictionary. + + person[:person_id] + #-> SQLColumn(:person_id) + + person["person_id"] + #-> SQLColumn(:person_id) + + person[1] + #-> SQLColumn(:person_id) + + person[:visit_occurrence] + #-> ERROR: KeyError: key :visit_occurrence not found + + get(person, :person_id, nothing) + #-> SQLColumn(:person_id) + + get(person, "person_id", nothing) + #-> SQLColumn(:person_id) + + get(person, :visit_occurrence, missing) + #-> missing + + get(() -> missing, person, :visit_occurrence) + #-> missing + + length(person) + #-> 3 + + collect(keys(person)) + #-> [:person_id, :year_of_birth, :location_id] + A `SQLCatalog` constructor takes a collection of `SQLTable` objects, -the target dialect, and the size of the query cache. +the target dialect, and the size of the query cache. Just as columns, +a table may have a custom name for use with FunSQL and the original name +for generating SQL. - catalog = SQLCatalog(tables = [person, location, vocabulary, concept], + catalog = SQLCatalog(tables = [person, location, concept, :concept_vocabulary => vocabulary], dialect = :sqlite, cache = 128) #-> SQLCatalog(…4 tables…, dialect = SQLDialect(:sqlite), cache = 128) display(catalog) #=> - SQLCatalog( - :concept => SQLTable(:concept, - columns = - [:concept_id, :concept_name, :vocabulary_id]), - :location => - SQLTable( - :location, - qualifiers = [:public], - columns = - [:location_id, :address_1, :address_2, :city, :state, :zip]), - :person => SQLTable(:person, - columns = [:person_id, :year_of_birth, :location_id]), - :vocabulary => SQLTable(:vocabulary, - columns = [:vocabulary_id, :vocabulary_name]), - dialect = SQLDialect(:sqlite), - cache = 128) + SQLCatalog(SQLTable(:concept, + SQLColumn(:concept_id), + SQLColumn(:concept_name), + SQLColumn(:vocabulary_id)), + :concept_vocabulary => SQLTable(:vocabulary, + :id => SQLColumn(:vocabulary_id), + :name => SQLColumn( + :vocabulary_name)), + SQLTable(qualifiers = [:public], + :location, + SQLColumn(:location_id), + SQLColumn(:address_1), + SQLColumn(:address_2), + SQLColumn(:city), + SQLColumn(:state), + SQLColumn(:zip)), + SQLTable(:person, + SQLColumn(:person_id), + SQLColumn(:year_of_birth), + SQLColumn(:location_id)), + dialect = SQLDialect(:sqlite), + cache = 128) =# Number of tables in the catalog affects its representation. @@ -191,7 +244,61 @@ The catalog behaves as a read-only `Dict` object. #-> 4 sort(collect(keys(catalog))) - #-> [:concept, :location, :person, :vocabulary] + #-> [:concept, :concept_vocabulary, :location, :person] + +Catalog objects can be assigned arbitrary metadata. + + metadata_catalog = + SQLCatalog(SQLTable(:person, + SQLColumn(:person_id, metadata = (; label = "Person ID")), + SQLColumn(:year_of_birth, metadata = (;)), + metadata = (; caption = "Person", is_view = false)), + metadata = (; model = "OMOP")) + #-> SQLCatalog(…1 table…, dialect = SQLDialect(), metadata = …) + + display(metadata_catalog) + #=> + SQLCatalog(SQLTable(:person, + SQLColumn(:person_id, metadata = [:label => "Person ID"]), + SQLColumn(:year_of_birth), + metadata = [:caption => "Person", :is_view => false]), + dialect = SQLDialect(), + metadata = [:model => "OMOP"]) + =# + +FunSQL metadata supports DataAPI metadata interface. + + using DataAPI + + DataAPI.metadata(metadata_catalog) + #-> Dict("model" => "OMOP") + + DataAPI.metadata(metadata_catalog, style = true) + #-> Dict("model" => ("OMOP", :default)) + + DataAPI.metadata(metadata_catalog, :name, :default) + #-> :default + + DataAPI.metadata(metadata_catalog[:person])["caption"] + #-> "Person" + + DataAPI.metadata(metadata_catalog[:person], :is_view, true) + #-> false + + DataAPI.colmetadata(metadata_catalog[:person])[:person_id]["label"] + #-> "Person ID" + + DataAPI.colmetadata(metadata_catalog[:person], 1, :label) + #-> "Person ID" + + DataAPI.colmetadata(metadata_catalog[:person], :year_of_birth, :label, "") + #-> "" + + DataAPI.metadata(metadata_catalog[:person][:person_id]) + #-> Dict("label" => "Person ID") + + DataAPI.metadata(metadata_catalog[:person][:person_id], :label, "") + #-> "Person ID" ## `SQLDialect` @@ -274,6 +381,15 @@ A completely custom dialect can be specified. String(sql) #-> "SELECT * FROM person" +`SQLString` may carry a vector `columns` describing the output columns of +the query. + + sql = SQLString("SELECT person_id FROM person", columns = [SQLColumn(:person_id)]) + #-> SQLString("SELECT person_id FROM person", columns = […1 column…]) + + display(sql) + #-> SQLString("SELECT person_id FROM person", columns = [SQLColumn(:person_id)]) + When the query has parameters, `SQLString` should include a vector of parameter names in the order they should appear in `DBInterface.execute` call. diff --git a/src/FunSQL.jl b/src/FunSQL.jl index fd2e7e5f..09537994 100644 --- a/src/FunSQL.jl +++ b/src/FunSQL.jl @@ -84,6 +84,7 @@ using OrderedCollections: OrderedDict, OrderedSet using Tables using DBInterface using LRUCache +using DataAPI const SQLLiteralType = Union{Missing, Bool, Number, AbstractString, Dates.AbstractTime} @@ -96,10 +97,10 @@ end include("dissect.jl") include("quote.jl") -include("strings.jl") include("dialects.jl") include("types.jl") include("catalogs.jl") +include("strings.jl") include("clauses.jl") include("nodes.jl") include("connections.jl") diff --git a/src/catalogs.jl b/src/catalogs.jl index 0d1dfeab..a59992f6 100644 --- a/src/catalogs.jl +++ b/src/catalogs.jl @@ -1,57 +1,152 @@ # Database structure. +const SQLMetadata = Base.ImmutableDict{Symbol, Any} + +_metadata(::Nothing) = + SQLMetadata() + +_metadata(dict::SQLMetadata) = + dict + +_metadata(dict::SQLMetadata, kvs...) = + Base.ImmutableDict(dict, kvs...) + +_metadata(other) = + _metadata(SQLMetadata(), pairs(other)...) + +_metadata_style(@nospecialize(val)) = + :default + +_metadata_keys(dict::SQLMetadata) = + Base.Generator(string, keys(dict)) + +_metadata_get(dict::SQLMetadata, key::Union{Symbol, AbstractString}; style::Bool) = + let val = dict[Symbol(key)] + style ? (val, _metadata_style(val)) : val + end + +_metadata_get(dict::SQLMetadata, key::Union{Symbol, AbstractString}, default; style::Bool) = + let val = get(dict, Symbol(key), default) + style ? (val, _metadata_style(val)) : val + end + +""" + SQLColumn(; name, metadata = nothing) + SQLColumn(name; metadata = nothing) + +`SQLColumn` represents a column with the given `name` and optional `metadata`. +""" +struct SQLColumn + name::Symbol + metadata::SQLMetadata + + function SQLColumn(; name::Union{Symbol, AbstractString}, metadata = nothing) + new(Symbol(name), _metadata(metadata)) + end +end + +SQLColumn(name; metadata = nothing) = + SQLColumn(name = name, metadata = metadata) + +Base.show(io::IO, col::SQLColumn) = + print(io, quoteof(col, limit = true)) + +Base.show(io::IO, ::MIME"text/plain", col::SQLColumn) = + pprint(io, col) + +function PrettyPrinting.quoteof(col::SQLColumn; limit::Bool = false) + ex = Expr(:call, nameof(SQLColumn), QuoteNode(col.name)) + if !isempty(col.metadata) + push!(ex.args, Expr(:kw, :metadata, limit ? :… : quoteof(reverse!(collect(col.metadata))))) + end + ex +end + +DataAPI.metadatasupport(::Type{SQLColumn}) = + (read = true, write = false) + +DataAPI.metadata(col::SQLColumn, key::Union{Symbol, AbstractString}; style::Bool = false) = + _metadata_get(col.metadata, key; style) + +DataAPI.metadata(col::SQLColumn, key::Union{Symbol, AbstractString}, default; style::Bool = false) = + _metadata_get(col.metadata, key, default; style) + +DataAPI.metadatakeys(col::SQLColumn) = + _metadata_keys(col.metadata) + """ - SQLTable(; qualifiers = [], name, columns) - SQLTable(name; qualifiers = [], columns) - SQLTable(name, columns...; qualifiers = []) + SQLTable(; qualifiers = [], name, columns, metadata = nothing) + SQLTable(name; qualifiers = [], columns, metadata = nothing) + SQLTable(name, columns...; qualifiers = [], metadata = nothing) The structure of a SQL table or a table-like entity (`TEMP TABLE`, `VIEW`, etc) for use as a reference in assembling SQL queries. -The `SQLTable` constructor expects the table `name`, a vector `columns` of -column names, and, optionally, a vector containing the name of the table schema -and other `qualifiers`. A name can be a `Symbol` or a `String` value. +The `SQLTable` constructor expects the table `name`, an optional vector +containing the table schema and other `qualifiers`, an ordered dictionary +`columns` that maps names to columns, and an optional `metadata`. # Examples ```jldoctest julia> person = SQLTable(qualifiers = ["public"], name = "person", - columns = ["person_id", "year_of_birth"]) -SQLTable(:person, - qualifiers = [:public], - columns = [:person_id, :year_of_birth]) + columns = ["person_id", "year_of_birth"], + metadata = (; is_view = false)) +SQLTable(qualifiers = [:public], + :person, + SQLColumn(:person_id), + SQLColumn(:year_of_birth), + metadata = [:is_view => false]) ``` """ -struct SQLTable +struct SQLTable <: AbstractDict{Symbol, SQLColumn} qualifiers::Vector{Symbol} name::Symbol - columns::Vector{Symbol} - column_set::Set{Symbol} + columns::OrderedDict{Symbol, SQLColumn} + metadata::SQLMetadata function SQLTable(; qualifiers::AbstractVector{<:Union{Symbol, AbstractString}} = Symbol[], name::Union{Symbol, AbstractString}, - columns::AbstractVector{<:Union{Symbol, AbstractString}}) + columns, + metadata = nothing) qualifiers = !isa(qualifiers, Vector{Symbol}) ? Symbol[Symbol(ql) for ql in qualifiers] : qualifiers name = Symbol(name) - columns = - !isa(columns, Vector{Symbol}) ? - Symbol[Symbol(col) for col in columns] : - columns - column_set = Set{Symbol}(columns) - new(qualifiers, name, columns, column_set) + columns = _column_map(columns) + new(qualifiers, name, columns, _metadata(metadata)) end end -SQLTable(name; qualifiers = Symbol[], columns) = - SQLTable(qualifiers = qualifiers, name = name, columns = columns) +SQLTable(name; qualifiers = Symbol[], columns, metadata = nothing) = + SQLTable(qualifiers = qualifiers, name = name, columns = columns, metadata = metadata) + +SQLTable(name, columns...; qualifiers = Symbol[], metadata = nothing) = + SQLTable(qualifiers = qualifiers, name = name, columns = [columns...], metadata = metadata) + +_column_map(columns::OrderedDict{Symbol, SQLColumn}) = + columns -SQLTable(name, columns...; qualifiers = Symbol[]) = - SQLTable(qualifiers = qualifiers, name = name, columns = [columns...]) +_column_map(columns::AbstractVector{Pair{Symbol, SQLColumn}}) = + OrderedDict{Symbol, SQLColumn}(columns) + +_column_map(columns) = + OrderedDict{Symbol, SQLColumn}(Pair{Symbol, SQLColumn}[_column_entry(c) for c in columns]) + +_column_entry(c::Symbol) = + c => SQLColumn(c) + +_column_entry(c::AbstractString) = + _column_entry(Symbol(c)) + +_column_entry(c::SQLColumn) = + c.name => c + +_column_entry((n, c)::Pair{<:Union{Symbol, AbstractString}, SQLColumn}) = + Symbol(n) => c Base.show(io::IO, tbl::SQLTable) = print(io, quoteof(tbl, limit = true)) @@ -61,43 +156,84 @@ Base.show(io::IO, ::MIME"text/plain", tbl::SQLTable) = function PrettyPrinting.quoteof(tbl::SQLTable; limit::Bool = false) ex = Expr(:call, nameof(SQLTable)) - push!(ex.args, quoteof(tbl.name)) if !isempty(tbl.qualifiers) push!(ex.args, Expr(:kw, :qualifiers, quoteof(tbl.qualifiers))) end + push!(ex.args, quoteof(tbl.name)) if !limit - push!(ex.args, Expr(:kw, :columns, tbl.columns)) + for (name, col) in tbl.columns + arg = quoteof(col) + if name !== col.name + arg = Expr(:call, :(=>), QuoteNode(name), arg) + end + push!(ex.args, arg) + end + if !isempty(tbl.metadata) + push!(ex.args, Expr(:kw, :metadata, quoteof(reverse!(collect(tbl.metadata))))) + end else push!(ex.args, :…) end ex end -const default_cache_maxsize = 256 +Base.get(tbl::SQLTable, key::Union{Symbol, AbstractString}, default) = + get(tbl.columns, Symbol(key), default) -_table_map(tables::Dict{Symbol, SQLTable}) = - tables +Base.get(default::Base.Callable, tbl::SQLTable, key::Union{Symbol, AbstractString}) = + get(default, tbl.columns, Symbol(key)) -_table_map(tables::AbstractVector{Pair{Symbol, SQLTable}}) = - Dict{Symbol, SQLTable}(tables) +Base.getindex(tbl::SQLTable, key::Union{Symbol, AbstractString}) = + tbl.columns[Symbol(key)] -_table_map(tables) = - Dict{Symbol, SQLTable}(Pair{Symbol, SQLTable}[_table_entry(t) for t in tables]) +Base.getindex(tbl::SQLTable, key::Integer) = + tbl.columns.vals[key] -_table_entry(t::SQLTable) = - t.name => t +Base.iterate(tbl::SQLTable, state...) = + iterate(tbl.columns, state...) -_table_entry((n, t)::Pair{<:Union{Symbol, AbstractString}, SQLTable}) = - Symbol(n) => t +Base.length(tbl::SQLTable) = + length(tbl.columns) + +DataAPI.metadatasupport(::Type{SQLTable}) = + (read = true, write = false) + +DataAPI.metadata(tbl::SQLTable, key::Union{Symbol, AbstractString}; style::Bool = false) = + _metadata_get(tbl.metadata, key; style) + +DataAPI.metadata(tbl::SQLTable, key::Union{Symbol, AbstractString}, default; style::Bool = false) = + _metadata_get(tbl.metadata, key, default; style) + +DataAPI.metadatakeys(tbl::SQLTable) = + _metadata_keys(tbl.metadata) + +DataAPI.colmetadatasupport(::Type{SQLTable}) = + (read = true, write = false) + +DataAPI.colmetadata(tbl::SQLTable, col::Union{Symbol, Integer}, key::Union{Symbol, AbstractString}; style::Bool = false) = + _metadata_get(tbl[col].metadata, key; style) + +DataAPI.colmetadata(tbl::SQLTable, col::Union{Symbol, Integer}, key::Union{Symbol, AbstractString}, default; style::Bool = false) = + _metadata_get(tbl[col].metadata, key, default; style) + +DataAPI.colmetadatakeys(tbl::SQLTable) = + (k => _metadata_keys(v.metadata) for (k, v) in tbl.columns) + +DataAPI.colmetadatakeys(tbl::SQLTable, col::Union{Symbol, Integer}) = + _metadata_keys(tbl[col].metadata) + +const default_cache_maxsize = 256 """ SQLCatalog(; tables = Dict{Symbol, SQLTable}(), dialect = :default, - cache = $default_cache_maxsize) - SQLCatalog(tables...; dialect = :default, cache = $default_cache_maxsize) + cache = $default_cache_maxsize, + metadata = nothing) + SQLCatalog(tables...; + dialect = :default, cache = $default_cache_maxsize, metadata = nothing) `SQLCatalog` encapsulates available database `tables`, the target SQL `dialect`, -and a `cache` of serialized queries. +a `cache` of serialized queries, and an optional `metadata`. Parameter `tables` is either a dictionary or a vector of [`SQLTable`](@ref) objects, where the vector will be converted to a dictionary with @@ -120,10 +256,11 @@ julia> person = SQLTable(:person, columns = [:person_id, :year_of_birth, :locati julia> location = SQLTable(:location, columns = [:location_id, :state]); julia> catalog = SQLCatalog(person, location, dialect = :postgresql) -SQLCatalog(:location => SQLTable(:location, columns = [:location_id, :state]), - :person => - SQLTable(:person, - columns = [:person_id, :year_of_birth, :location_id]), +SQLCatalog(SQLTable(:location, SQLColumn(:location_id), SQLColumn(:state)), + SQLTable(:person, + SQLColumn(:person_id), + SQLColumn(:year_of_birth), + SQLColumn(:location_id)), dialect = SQLDialect(:postgresql)) ``` """ @@ -131,23 +268,44 @@ struct SQLCatalog <: AbstractDict{Symbol, SQLTable} tables::Dict{Symbol, SQLTable} dialect::SQLDialect cache::Any # Union{AbstractDict{SQLNode, SQLString}, Nothing} + metadata::SQLMetadata - function SQLCatalog(; tables = Dict{Symbol, SQLTable}(), dialect = :default, cache = default_cache_maxsize) + function SQLCatalog(; tables = Dict{Symbol, SQLTable}(), dialect = :default, cache = default_cache_maxsize, metadata = nothing) table_map = _table_map(tables) if cache isa Number cache = LRU{SQLNode, SQLString}(maxsize = cache) end - new(table_map, dialect, cache) + new(table_map, dialect, cache, _metadata(metadata)) end end -SQLCatalog(tables...; dialect = :default, cache = default_cache_maxsize) = - SQLCatalog(tables = tables, dialect = dialect, cache = cache) +SQLCatalog(tables...; dialect = :default, cache = default_cache_maxsize, metadata = nothing) = + SQLCatalog(tables = tables, dialect = dialect, cache = cache, metadata = metadata) + +_table_map(tables::Dict{Symbol, SQLTable}) = + tables + +_table_map(tables::AbstractVector{Pair{Symbol, SQLTable}}) = + Dict{Symbol, SQLTable}(tables) + +_table_map(tables) = + Dict{Symbol, SQLTable}(Pair{Symbol, SQLTable}[_table_entry(t) for t in tables]) + +_table_entry(t::SQLTable) = + t.name => t + +_table_entry((n, t)::Pair{<:Union{Symbol, AbstractString}, SQLTable}) = + Symbol(n) => t function PrettyPrinting.quoteof(c::SQLCatalog) ex = Expr(:call, nameof(SQLCatalog)) for name in sort!(collect(keys(c.tables))) - push!(ex.args, Expr(:call, :(=>), QuoteNode(name), quoteof(c.tables[name]))) + tbl = c.tables[name] + arg = quoteof(tbl) + if name !== tbl.name + arg = Expr(:call, :(=>), QuoteNode(name), arg) + end + push!(ex.args, arg) end push!(ex.args, Expr(:kw, :dialect, quoteof(c.dialect))) cache = c.cache @@ -160,6 +318,9 @@ function PrettyPrinting.quoteof(c::SQLCatalog) else push!(ex.args, Expr(:kw, :cache, Expr(:call, typeof(cache)))) end + if !isempty(c.metadata) + push!(ex.args, Expr(:kw, :metadata, quoteof(reverse!(collect(c.metadata))))) + end ex end @@ -182,6 +343,9 @@ function Base.show(io::IO, c::SQLCatalog) else print(io, ", cache = ", typeof(cache), "()") end + if !isempty(c.metadata) + print(io, ", metadata = …") + end print(io, ')') nothing end @@ -204,3 +368,14 @@ Base.iterate(c::SQLCatalog, state...) = Base.length(c::SQLCatalog) = length(c.tables) +DataAPI.metadatasupport(::Type{SQLCatalog}) = + (read = true, write = false) + +DataAPI.metadata(c::SQLCatalog, key::Union{Symbol, AbstractString}; style::Bool = false) = + _metadata_get(c.metadata, key; style) + +DataAPI.metadata(c::SQLCatalog, key::Union{Symbol, AbstractString}, default; style::Bool = false) = + _metadata_get(c.metadata, key, default; style) + +DataAPI.metadatakeys(c::SQLCatalog) = + _metadata_keys(c.metadata) diff --git a/src/clauses/internal.jl b/src/clauses/internal.jl index 063f5211..4e88cb7c 100644 --- a/src/clauses/internal.jl +++ b/src/clauses/internal.jl @@ -5,9 +5,10 @@ mutable struct WithContextClause <: AbstractSQLClause over::SQLClause dialect::SQLDialect + columns::Union{Vector{SQLColumn}, Nothing} - WithContextClause(; over, dialect) = - new(over, dialect) + WithContextClause(; over, dialect, columns = nothing) = + new(over, dialect, columns) end WITH_CONTEXT(args...; kws...) = @@ -21,5 +22,8 @@ function PrettyPrinting.quoteof(c::WithContextClause, ctx::QuoteContext) if c.dialect !== default_dialect push!(ex.args, Expr(:kw, :dialect, quoteof(c.dialect))) end + if c.columns !== nothing + push!(ex.args, Expr(:kw, :columns, Expr(:vect, Any[quoteof(col) for col in c.columns]...))) + end ex end diff --git a/src/link.jl b/src/link.jl index 06ce0522..882bbac9 100644 --- a/src/link.jl +++ b/src/link.jl @@ -1,21 +1,21 @@ # Find select lists. struct LinkContext - dialect::SQLDialect + catalog::SQLCatalog defs::Vector{SQLNode} refs::Vector{SQLNode} cte_refs::Base.ImmutableDict{Tuple{Symbol, Int}, Vector{SQLNode}} knot_refs::Union{Vector{SQLNode}, Nothing} - LinkContext(dialect) = - new(dialect, + LinkContext(catalog) = + new(catalog, SQLNode[], SQLNode[], Base.ImmutableDict{Tuple{Symbol, Int}, Vector{SQLNode}}(), nothing) LinkContext(ctx::LinkContext; refs = ctx.refs, cte_refs = ctx.cte_refs, knot_refs = ctx.knot_refs) = - new(ctx.dialect, + new(ctx.catalog, ctx.defs, refs, cte_refs, @@ -23,8 +23,8 @@ struct LinkContext end function link(n::SQLNode) - @dissect(n, WithContext(over = over, dialect = dialect)) || throw(ILLFormedError()) - ctx = LinkContext(dialect) + @dissect(n, WithContext(over = over, catalog = catalog)) || throw(ILLFormedError()) + ctx = LinkContext(catalog) t = row_type(over) refs = SQLNode[] for (f, ft) in t.fields @@ -33,7 +33,7 @@ function link(n::SQLNode) end end over′ = Linked(refs, over = link(dismantle(over, ctx), ctx, refs)) - WithContext(over = over′, dialect = dialect, defs = ctx.defs) + WithContext(over = over′, catalog = catalog, defs = ctx.defs) end function dismantle(n::SQLNode, ctx) diff --git a/src/nodes/internal.jl b/src/nodes/internal.jl index 00b61b84..df0340d4 100644 --- a/src/nodes/internal.jl +++ b/src/nodes/internal.jl @@ -3,12 +3,11 @@ # Preserve context between rendering passes. mutable struct WithContextNode <: AbstractSQLNode over::SQLNode - dialect::SQLDialect - tables::Dict{Symbol, SQLTable} + catalog::SQLCatalog defs::Vector{SQLNode} - WithContextNode(; over, dialect = default_dialect, tables = Dict{Symbol, SQLTable}(), defs = SQLNode[]) = - new(over, dialect, tables, defs) + WithContextNode(; over, catalog = SQLCatalog(), defs = SQLNode[]) = + new(over, catalog, defs) end WithContext(args...; kws...) = @@ -19,12 +18,7 @@ dissect(scr::Symbol, ::typeof(WithContext), pats::Vector{Any}) = function PrettyPrinting.quoteof(n::WithContextNode, ctx::QuoteContext) ex = Expr(:call, nameof(WithContext), Expr(:kw, :over, quoteof(n.over, ctx))) - if n.dialect != default_dialect - push!(ex.args, Expr(:kw, :dialect, quoteof(n.dialect))) - end - if !isempty(n.tables) - push!(ex.args, Expr(:kw, :tables, quoteof(n.tables))) - end + push!(ex.args, Expr(:kw, :catalog, quoteof(n.catalog))) if !isempty(n.defs) push!(ex.args, Expr(:kw, :defs, Expr(:vect, Any[quoteof(def, ctx) for def in n.defs]...))) end diff --git a/src/render.jl b/src/render.jl index cfdba5fc..e20b5327 100644 --- a/src/render.jl +++ b/src/render.jl @@ -56,7 +56,7 @@ function render(catalog::SQLCatalog, n::SQLNode) return sql end end - n = WithContext(over = n, dialect = catalog.dialect, tables = catalog.tables) + n = WithContext(over = n, catalog = catalog) n = resolve(n) @debug "FunSQL.resolve\n" * sprint(pprint, n) _group = Symbol("FunSQL.resolve") n = link(n) diff --git a/src/resolve.jl b/src/resolve.jl index 81b76d92..ce9f5d0f 100644 --- a/src/resolve.jl +++ b/src/resolve.jl @@ -1,8 +1,7 @@ # Resolving node types. struct ResolveContext - dialect::SQLDialect - tables::Dict{Symbol, SQLTable} + catalog::SQLCatalog path::Vector{SQLNode} row_type::RowType cte_types::Base.ImmutableDict{Symbol, Tuple{Int, RowType}} @@ -10,9 +9,8 @@ struct ResolveContext knot_type::Union{RowType, Nothing} implicit_knot::Bool - ResolveContext(dialect, tables) = - new(dialect, - tables, + ResolveContext(catalog) = + new(catalog, SQLNode[], EMPTY_ROW, Base.ImmutableDict{Symbol, Tuple{Int, RowType}}(), @@ -27,8 +25,7 @@ struct ResolveContext var_types = ctx.var_types, knot_type = ctx.knot_type, implicit_knot = ctx.implicit_knot) = - new(ctx.dialect, - ctx.tables, + new(ctx.catalog, ctx.path, row_type, cte_types, @@ -56,9 +53,9 @@ function type(n::SQLNode) end function resolve(n::SQLNode) - @dissect(n, WithContext(over = n′, dialect = dialect, tables = tables)) || throw(IllFormedError()) - ctx = ResolveContext(dialect, tables) - WithContext(over = resolve(n′, ctx), dialect = dialect) + @dissect(n, WithContext(over = n′, catalog = catalog)) || throw(IllFormedError()) + ctx = ResolveContext(catalog) + WithContext(over = resolve(n′, ctx), catalog = catalog) end function resolve(n::SQLNode, ctx) @@ -229,7 +226,7 @@ end function RowType(table::SQLTable) fields = FieldTypeMap() - for f in table.columns + for f in keys(table.columns) fields[f] = ScalarType() end RowType(fields) @@ -246,7 +243,7 @@ function resolve(n::FromNode, ctx) (depth, t) = v n′ = FromTableExpression(source, depth) else - table = get(ctx.tables, source, nothing) + table = get(ctx.catalog, source, nothing) if table === nothing throw( ReferenceError( diff --git a/src/serialize.jl b/src/serialize.jl index 4765ae3b..bebb4dda 100644 --- a/src/serialize.jl +++ b/src/serialize.jl @@ -12,11 +12,11 @@ mutable struct SerializeContext <: IO end function serialize(c::SQLClause) - @dissect(c, WITH_CONTEXT(over = c′, dialect = dialect)) || throw(IllFormedError()) + @dissect(c, WITH_CONTEXT(over = c′, dialect = dialect, columns = columns)) || throw(IllFormedError()) ctx = SerializeContext(dialect) serialize!(c′, ctx) raw = String(take!(ctx.io)) - SQLString(raw, vars = ctx.vars) + SQLString(raw, columns = columns, vars = ctx.vars) end Base.write(ctx::SerializeContext, octet::UInt8) = diff --git a/src/strings.jl b/src/strings.jl index 5f843ecc..122a8848 100644 --- a/src/strings.jl +++ b/src/strings.jl @@ -1,10 +1,12 @@ # Serialized SQL query with parameter mapping. """ - SQLString(raw, vars = Symbol[]) + SQLString(raw; columns = nothing, vars = Symbol[]) Serialized SQL query. +Parameter `columns` is a vector describing the output columns. + Parameter `vars` is a vector of query parameters (created with [`Var`](@ref)) in the order they are expected by the `DBInterface.execute()` function. @@ -20,7 +22,8 @@ SQLString(\""" SELECT "person_1"."person_id", "person_1"."year_of_birth" - FROM "person" AS "person_1\\"\""") + FROM "person" AS "person_1\\"\""", + columns = [SQLColumn(:person_id), SQLColumn(:year_of_birth)]) julia> q = From(person) |> Where(Fun.and(Get.year_of_birth .>= Var.YEAR, Get.year_of_birth .< Var.YEAR .+ 10)); @@ -34,6 +37,7 @@ SQLString(\""" WHERE (`person_1`.`year_of_birth` >= ?) AND (`person_1`.`year_of_birth` < (? + 10))\""", + columns = [SQLColumn(:person_id), SQLColumn(:year_of_birth)], vars = [:YEAR, :YEAR]) julia> render(q, dialect = :postgresql) @@ -45,15 +49,17 @@ SQLString(\""" WHERE ("person_1"."year_of_birth" >= \$1) AND ("person_1"."year_of_birth" < (\$1 + 10))\""", + columns = [SQLColumn(:person_id), SQLColumn(:year_of_birth)], vars = [:YEAR]) ``` """ struct SQLString <: AbstractString raw::String + columns::Union{Vector{SQLColumn}, Nothing} vars::Vector{Symbol} - SQLString(raw; vars = Symbol[]) = - new(raw, vars) + SQLString(raw; columns = nothing, vars = Symbol[]) = + new(raw, columns, vars) end Base.ncodeunits(sql::SQLString) = @@ -82,6 +88,9 @@ Base.write(io::IO, sql::SQLString) = function PrettyPrinting.quoteof(sql::SQLString) ex = Expr(:call, nameof(SQLString), sql.raw) + if sql.columns !== nothing + push!(ex.args, Expr(:kw, :columns, Expr(:vect, Any[quoteof(col) for col in sql.columns]...))) + end if !isempty(sql.vars) push!(ex.args, Expr(:kw, :vars, quoteof(sql.vars))) end @@ -91,6 +100,11 @@ end function Base.show(io::IO, sql::SQLString) print(io, "SQLString(") show(io, sql.raw) + if sql.columns !== nothing + print(io, ", columns = ") + l = length(sql.columns) + print(io, l == 0 ? "[]" : l == 1 ? "[…1 column…]" : "[…$l columns…]") + end if !isempty(sql.vars) print(io, ", vars = ") show(io, sql.vars) diff --git a/src/translate.jl b/src/translate.jl index fcdfbf14..cfc63187 100644 --- a/src/translate.jl +++ b/src/translate.jl @@ -137,7 +137,7 @@ end # Translating context. struct TranslateContext - dialect::SQLDialect + catalog::SQLCatalog defs::Vector{SQLNode} aliases::Dict{Symbol, Int} recursive::Ref{Bool} @@ -148,8 +148,8 @@ struct TranslateContext vars::Base.ImmutableDict{Tuple{Symbol, Int}, SQLClause} subs::Dict{SQLNode, SQLClause} - TranslateContext(; dialect, defs) = - new(dialect, + TranslateContext(; catalog, defs) = + new(catalog, defs, Dict{Symbol, Int}(), Ref(false), @@ -161,7 +161,7 @@ struct TranslateContext Dict{Int, SQLClause}()) function TranslateContext(ctx::TranslateContext; cte_map = ctx.cte_map, knot = ctx.knot, refs = ctx.refs, vars = ctx.vars, subs = ctx.subs) - new(ctx.dialect, + new(ctx.catalog, ctx.defs, ctx.aliases, ctx.recursive, @@ -184,9 +184,14 @@ function allocate_alias(ctx::TranslateContext, alias::Symbol) end function translate(n::SQLNode) - @dissect(n, WithContext(over = n′, dialect = dialect, defs = defs)) || throw(IllFormedError()) - ctx = TranslateContext(dialect = dialect, defs = defs) - c = translate(n′, ctx) + @dissect(n, WithContext(over = Linked(over = n′, refs = refs), catalog = catalog, defs = defs)) || throw(IllFormedError()) + ctx = TranslateContext(catalog = catalog, defs = defs) + base = assemble(n′, TranslateContext(ctx, refs = refs)) + columns = nothing + if !isempty(base.cols) + columns = [SQLColumn(col) for col in keys(base.cols)] + end + c = complete(base) with_args = SQLClause[] for cte_a in ctx.ctes !cte_a.external || continue @@ -205,7 +210,7 @@ function translate(n::SQLNode) if !isempty(with_args) c = WITH(over = c, args = with_args, recursive = ctx.recursive[]) end - WITH_CONTEXT(over = c, dialect = ctx.dialect) + WITH_CONTEXT(over = c, dialect = ctx.catalog.dialect, columns = columns) end function translate(n::SQLNode, ctx) @@ -516,7 +521,7 @@ end function assemble(n::FromTableNode, ctx) seen = Set{Symbol}() for ref in ctx.refs - @dissect(ref, nothing |> Get(name = name)) && name in n.table.column_set || error() + @dissect(ref, nothing |> Get(name = name)) && name in keys(n.table.columns) || error() if !(name in seen) push!(seen, name) end @@ -525,9 +530,9 @@ function assemble(n::FromTableNode, ctx) tbl = ID(n.table.qualifiers, n.table.name) c = FROM(AS(over = tbl, name = alias)) cols = OrderedDict{Symbol, SQLClause}() - for col in n.table.columns - col in seen || continue - cols[col] = ID(over = alias, name = col) + for (name, col) in n.table.columns + name in seen || continue + cols[name] = ID(over = alias, name = col.name) end repl = Dict{SQLNode, Symbol}() for ref in ctx.refs @@ -566,15 +571,15 @@ function assemble(n::FromValuesNode, ctx) col in seen || continue cols[col] = LIT(missing) end - elseif ctx.dialect.has_as_columns + elseif ctx.catalog.dialect.has_as_columns c = FROM(AS(alias, columns = column_aliases, over = VALUES(rows))) for col in columns col in seen || continue cols[col] = ID(over = alias, name = col) end else - column_prefix = ctx.dialect.values_column_prefix - column_index = ctx.dialect.values_column_index + column_prefix = ctx.catalog.dialect.values_column_prefix + column_index = ctx.catalog.dialect.values_column_index column_prefix !== nothing || error() c = FROM(AS(alias, over = VALUES(rows))) for col in columns @@ -841,7 +846,7 @@ function assemble(n::RoutedJoinNode, ctx) for (ref, name) in right.repl subs[ref] = right.cols[name] end - if ctx.dialect.has_implicit_lateral + if ctx.catalog.dialect.has_implicit_lateral lateral = false end else diff --git a/test/Project.toml b/test/Project.toml index e91d85c3..d8de25a1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"