Skip to content

Commit

Permalink
Eliminate some duplicate code
Browse files Browse the repository at this point in the history
  • Loading branch information
xitology committed Feb 15, 2024
1 parent e210b53 commit bbcb49b
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 89 deletions.
29 changes: 5 additions & 24 deletions src/link.jl
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ function _cte_depth(dict, name)
0
end

function link(n::WithNode, ctx)
function link(n::Union{WithNode, WithExternalNode}, ctx)
cte_refs′ = ctx.cte_refs
refs_map = Vector{SQLNode}[]
for name in keys(n.label_map)
Expand All @@ -484,30 +484,11 @@ function link(n::WithNode, ctx)
push!(args′, arg′)
label_map′[f] = lastindex(args′)
end
With(over = over′, args = args′, materialized = n.materialized, label_map = label_map′)
end

function link(n::WithExternalNode, ctx)
cte_refs′ = ctx.cte_refs
refs_map = Vector{SQLNode}[]
for name in keys(n.label_map)
depth = _cte_depth(ctx.cte_refs, name) + 1
refs = SQLNode[]
cte_refs′ = Base.ImmutableDict(cte_refs′, (name, depth) => refs)
push!(refs_map, refs)
end
ctx′ = LinkContext(ctx, cte_refs = cte_refs′)
over′ = Linked(ctx′.refs, over = link(n.over, ctx′))
args′ = SQLNode[]
label_map′ = OrderedDict{Symbol, Int}()
for (f, i) in n.label_map
arg = n.args[i]
refs = refs_map[i]
arg′ = Linked(refs, over = link(arg, ctx, refs))
push!(args′, arg′)
label_map′[f] = lastindex(args′)
if n isa WithNode
With(over = over′, args = args′, materialized = n.materialized, label_map = label_map′)
else
WithExternal(over = over′, args = args′, qualifiers = n.qualifiers, handler = n.handler, label_map = label_map′)
end
WithExternal(over = over′, args = args′, qualifiers = n.qualifiers, handler = n.handler, label_map = label_map′)
end

function gather!(n::SQLNode, ctx)
Expand Down
76 changes: 17 additions & 59 deletions src/resolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ function row_type(n::SQLNode)
type
end

function scalar_type(n::SQLNode)
@dissect(n, Resolved(type = type::ScalarType)) || throw(IllFormedError())
type
end

function type(n::SQLNode)
@dissect(n, Resolved(type = t)) || throw(IllFormedError())
t
Expand Down Expand Up @@ -169,50 +174,24 @@ function resolve_scalar(n::AsNode, ctx)
Resolved(type(over′), over = n′)
end

function resolve(n::BindNode, ctx)
function resolve(n::BindNode, ctx, scalar = false)
args′ = resolve_scalar(n.args, ctx)
var_types′ = ctx.var_types
for (name, i) in n.label_map
v = get(ctx.var_types, name, nothing)
depth = 1 + (v !== nothing ? v[1] : 0)
t = type(args′[i])
if !(t isa ScalarType)
throw(
ReferenceError(
REFERENCE_ERROR_TYPE.UNEXPECTED_ROW_TYPE,
name = name,
path = get_path(ctx)))

end
var_types′ = Base.ImmutableDict(var_types′, name => (depth, t))
end
over′ = resolve(n.over, ResolveContext(ctx, var_types = var_types′))
n′ = Bind(over = over′, args = args′, label_map = n.label_map)
Resolved(row_type(over′), over = n′)
end

function resolve_scalar(n::BindNode, ctx)
args′ = resolve_scalar(n.args, ctx)
var_types′ = ctx.var_types
for (name, i) in n.label_map
v = get(ctx.var_types, name, nothing)
depth = 1 + (v !== nothing ? v[1] : 0)
t = type(args′[i])
if !(t isa ScalarType)
throw(
ReferenceError(
REFERENCE_ERROR_TYPE.UNEXPECTED_ROW_TYPE,
name = name,
path = get_path(ctx)))

end
t = scalar_type(args′[i])
var_types′ = Base.ImmutableDict(var_types′, name => (depth, t))
end
over′ = resolve_scalar(n.over, ResolveContext(ctx, var_types = var_types′))
ctx′ = ResolveContext(ctx, var_types = var_types′)
over′ = !scalar ? resolve(n.over, ctx′) : resolve_scalar(n.over, ctx′)
n′ = Bind(over = over′, args = args′, label_map = n.label_map)
Resolved(type(over′), over = n′)
end

resolve_scalar(n::BindNode, ctx) =
resolve(n, ctx, true)

function resolve_scalar(n::NestedNode, ctx)
t = get(ctx.row_type.fields, n.name, EmptyType())
if !(t isa RowType)
Expand Down Expand Up @@ -481,7 +460,7 @@ function resolve(n::WhereNode, ctx)
Resolved(t, over = n′)
end

function resolve(n::WithNode, ctx)
function resolve(n::Union{WithNode, WithExternalNode}, ctx)
ctx′ = ResolveContext(ctx, knot_type = nothing, implicit_knot = false)
args′ = resolve(n.args, ctx′)
cte_types′ = ctx.cte_types
Expand All @@ -502,31 +481,10 @@ function resolve(n::WithNode, ctx)
end
ctx′ = ResolveContext(ctx, cte_types = cte_types′)
over′ = resolve(n.over, ctx′)
n′ = With(over = over′, args = args′, materialized = n.materialized, label_map = n.label_map)
Resolved(row_type(over′), over = n′)
end

function resolve(n::WithExternalNode, ctx)
ctx′ = ResolveContext(ctx, knot_type = nothing, implicit_knot = false)
args′ = resolve(n.args, ctx′)
cte_types′ = ctx.cte_types
for (name, i) in n.label_map
v = get(ctx.cte_types, name, nothing)
depth = 1 + (v !== nothing ? v[1] : 0)
t = row_type(args′[i])
cte_t = get(t.fields, name, EmptyType())
if !(cte_t isa RowType)
throw(
ReferenceError(
REFERENCE_ERROR_TYPE.INVALID_TABLE_REFERENCE,
name = name,
path = get_path(ctx)))

end
cte_types′ = Base.ImmutableDict(cte_types′, name => (depth, cte_t))
if n isa WithNode
n′ = With(over = over′, args = args′, materialized = n.materialized, label_map = n.label_map)
else
n′ = WithExternal(over = over′, args = args′, qualifiers = n.qualifiers, handler = n.handler, label_map = n.label_map)
end
ctx′ = ResolveContext(ctx, cte_types = cte_types′)
over′ = resolve(n.over, ctx′)
n′ = WithExternal(over = over′, args = args′, qualifiers = n.qualifiers, handler = n.handler, label_map = n.label_map)
Resolved(row_type(over′), over = n′)
end
12 changes: 6 additions & 6 deletions src/translate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,6 @@ function translate(n::BoundVariableNode, ctx)
ctx.vars[(n.name, n.depth)]
end

function translate(n::LinkedNode, ctx)
base = assemble(n.over, TranslateContext(ctx, refs = n.refs))
complete(base)
end

function translate(n::FunctionNode, ctx)
args = translate(n.args, ctx)
if n.name === :and
Expand Down Expand Up @@ -304,6 +299,11 @@ function translate(n::IsolatedNode, ctx)
complete(base)
end

function translate(n::LinkedNode, ctx)
base = assemble(n.over, TranslateContext(ctx, refs = n.refs))
complete(base)
end

function translate(n::LiteralNode, ctx)
LIT(n.val)
end
Expand Down Expand Up @@ -595,7 +595,7 @@ end

function assemble(n::GroupNode, ctx)
has_aggregates = any(ref -> @dissect(ref, Agg() || Agg() |> Nested()), ctx.refs)
if isempty(n.by) && !has_aggregates
if isempty(n.by) && !has_aggregates # NOOP: already processed in link()
return assemble(nothing, ctx)
end
base = assemble(n.over, ctx)
Expand Down

0 comments on commit bbcb49b

Please sign in to comment.