diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index d34a5425..c8197beb 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -11,10 +11,10 @@ jobs: runs-on: ubuntu-latest steps: - - name: Set up Go 1.19 - uses: actions/setup-go@v5 + - name: Set up Go 1.20 + uses: actions/setup-go@v4 with: - go-version: 1.19 + go-version: "1.20" id: go - name: Check out code into the Go module directory diff --git a/ast/pos.go b/ast/pos.go index 98431a5d..c817c775 100644 --- a/ast/pos.go +++ b/ast/pos.go @@ -4,6 +4,51 @@ import ( "github.com/cloudspannerecosystem/memefish/token" ) +// ================================================================================ +// +// Helper functions for Pos(), End() +// These functions are intended for use within this file only. +// +// ================================================================================ + +// lastNode returns last element of Node slice. +// This function corresponds to NodeSliceVar[$] in ast.go. +func lastNode[T Node](s []T) T { + return s[len(s)-1] +} + +// firstValidEnd returns the first valid Pos() in argument. +// "valid" means the node is not nil and Pos().Invalid() is not true. +// This function corresponds to "(n0 ?? n1 ?? ...).End()" +func firstValidEnd(ns ...Node) token.Pos { + for _, n := range ns { + if n != nil && !n.End().Invalid() { + return n.End() + } + } + return token.InvalidPos +} + +// firstPos returns the Pos() of the first node. +// If argument is an empty slice, this function returns token.InvalidPos. +// This function corresponds to NodeSliceVar[0].pos in ast.go. +func firstPos[T Node](s []T) token.Pos { + if len(s) == 0 { + return token.InvalidPos + } + return s[0].Pos() +} + +// lastEnd returns the End() of the last node. +// If argument is an empty slice, this function returns token.InvalidPos. +// This function corresponds to NodeSliceVar[$].end in ast.go. +func lastEnd[T Node](s []T) token.Pos { + if len(s) == 0 { + return token.InvalidPos + } + return lastNode(s).End() +} + // ================================================================================ // // SELECT @@ -39,25 +84,7 @@ func (c *CTE) End() token.Pos { return c.Rparen + 1 } func (s *Select) Pos() token.Pos { return s.Select } func (s *Select) End() token.Pos { - if s.Limit != nil { - return s.Limit.End() - } - if s.OrderBy != nil { - return s.OrderBy.End() - } - if s.Having != nil { - return s.Having.End() - } - if s.GroupBy != nil { - return s.GroupBy.End() - } - if s.Where != nil { - return s.Where.End() - } - if s.From != nil { - return s.From.End() - } - return s.Results[len(s.Results)-1].End() + return firstValidEnd(s.Limit, s.OrderBy, s.Having, s.GroupBy, s.Where, s.From, lastNode(s.Results)) } func (c *CompoundQuery) Pos() token.Pos { @@ -376,8 +403,8 @@ func (p *Param) End() token.Pos { return p.Atmark + 1 + token.Pos(len(p.Name)) } func (i *Ident) Pos() token.Pos { return i.NamePos } func (i *Ident) End() token.Pos { return i.NameEnd } -func (p *Path) Pos() token.Pos { return p.Idents[0].Pos() } -func (p *Path) End() token.Pos { return p.Idents[len(p.Idents)-1].End() } +func (p *Path) Pos() token.Pos { return firstPos(p.Idents) } +func (p *Path) End() token.Pos { return lastEnd(p.Idents) } func (a *ArrayLiteral) Pos() token.Pos { if !a.Array.Invalid() { diff --git a/ast/sql.go b/ast/sql.go index 71889a1e..2cb2c744 100644 --- a/ast/sql.go +++ b/ast/sql.go @@ -2,8 +2,59 @@ package ast import ( "github.com/cloudspannerecosystem/memefish/token" + "strings" ) +// ================================================================================ +// +// Helper functions for SQL() +// These functions are intended for use within this file only. +// +// ================================================================================ + +// sqlOpt outputs: +// +// when node != nil: left + node.SQL() + right +// else : empty string +// +// This function corresponds to sqlOpt in ast.go +func sqlOpt[T interface { + Node + comparable +}](left string, node T, right string) string { + var zero T + if node == zero { + return "" + } + return left + node.SQL() + right +} + +// strOpt outputs: +// +// when pred == true: s +// else : empty string +// +// This function corresponds to {{if pred}}s{{end}} in ast.go +func strOpt(pred bool, s string) string { + if pred { + return s + } + return "" +} + +// sqlJoin outputs joined string of SQL() of all elems by sep. +// This function corresponds to sqlJoin in ast.go +func sqlJoin[T Node](elems []T, sep string) string { + var b strings.Builder + for i, r := range elems { + if i > 0 { + b.WriteString(sep) + } + b.WriteString(r.SQL()) + } + return b.String() +} + type prec int const ( @@ -116,36 +167,16 @@ func (c *CTE) SQL() string { } func (s *Select) SQL() string { - sql := "SELECT " - if s.Distinct { - sql += "DISTINCT " - } - if s.AsStruct { - sql += "AS STRUCT " - } - sql += s.Results[0].SQL() - for _, r := range s.Results[1:] { - sql += ", " + r.SQL() - } - if s.From != nil { - sql += " " + s.From.SQL() - } - if s.Where != nil { - sql += " " + s.Where.SQL() - } - if s.GroupBy != nil { - sql += " " + s.GroupBy.SQL() - } - if s.Having != nil { - sql += " " + s.Having.SQL() - } - if s.OrderBy != nil { - sql += " " + s.OrderBy.SQL() - } - if s.Limit != nil { - sql += " " + s.Limit.SQL() - } - return sql + return "SELECT " + + strOpt(s.Distinct, "DISTINCT ") + + strOpt(s.AsStruct, "AS STRUCT ") + + sqlJoin(s.Results, ", ") + + sqlOpt(" ", s.From, "") + + sqlOpt(" ", s.Where, "") + + sqlOpt(" ", s.GroupBy, "") + + sqlOpt(" ", s.Having, "") + + sqlOpt(" ", s.OrderBy, "") + + sqlOpt(" ", s.Limit, "") } func (c *CompoundQuery) SQL() string { @@ -464,27 +495,11 @@ func (i *IndexExpr) SQL() string { } func (c *CallExpr) SQL() string { - sql := c.Func.SQL() + "(" - if c.Distinct { - sql += "DISTINCT " - } - for i, a := range c.Args { - if i != 0 { - sql += ", " - } - sql += a.SQL() - } - if len(c.Args) > 0 && len(c.NamedArgs) > 0 { - sql += ", " - } - for i, v := range c.NamedArgs { - if i != 0 { - sql += ", " - } - sql += v.SQL() - } - sql += ")" - return sql + return c.Func.SQL() + "(" + strOpt(c.Distinct, "DISTINCT ") + + sqlJoin(c.Args, ", ") + + strOpt(len(c.Args) > 0 && len(c.NamedArgs) > 0, ", ") + + sqlJoin(c.NamedArgs, ", ") + + ")" } func (n *NamedArg) SQL() string { return n.Name.SQL() + " => " + n.Value.SQL() } @@ -595,11 +610,7 @@ func (i *Ident) SQL() string { } func (p *Path) SQL() string { - sql := p.Idents[0].SQL() - for _, id := range p.Idents[1:] { - sql += "." + id.SQL() - } - return sql + return sqlJoin(p.Idents, ".") } func (a *ArrayLiteral) SQL() string { diff --git a/go.mod b/go.mod index 0defff2e..e65f1d20 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/cloudspannerecosystem/memefish -go 1.19 +go 1.20 require ( github.com/MakeNowJust/heredoc/v2 v2.0.1