diff --git a/HISTORY.md b/HISTORY.md index 726937e3..b6b81597 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -2,6 +2,7 @@ * [ADDED] Support for ANY and ALL operators. [#196](https://github.com/doug-martin/goqu/issues/196) * [ADDED] Support for CASE statements [#193](https://github.com/doug-martin/goqu/issues/193) +* [ADDED] Support for getting column identifiers from AliasExpressions. [#203](https://github.com/doug-martin/goqu/issues/203) # v9.7.1 diff --git a/exp/alias.go b/exp/alias.go index 81b6f95c..b07b54d3 100644 --- a/exp/alias.go +++ b/exp/alias.go @@ -36,3 +36,24 @@ func (ae aliasExpression) Aliased() Expression { func (ae aliasExpression) GetAs() IdentifierExpression { return ae.alias } + +// Returns a new IdentifierExpression with the specified schema +func (ae aliasExpression) Schema(schema string) IdentifierExpression { + return ae.alias.Schema(schema) +} + +// Returns a new IdentifierExpression with the specified table +func (ae aliasExpression) Table(table string) IdentifierExpression { + return ae.alias.Table(table) +} + +// Returns a new IdentifierExpression with the specified column +func (ae aliasExpression) Col(col interface{}) IdentifierExpression { + return ae.alias.Col(col) +} + +// Returns a new IdentifierExpression with the column set to * +// I("my_table").As("t").All() //"t".* +func (ae aliasExpression) All() IdentifierExpression { + return ae.alias.All() +} diff --git a/exp/alias_test.go b/exp/alias_test.go new file mode 100644 index 00000000..c0698621 --- /dev/null +++ b/exp/alias_test.go @@ -0,0 +1,68 @@ +package exp + +import ( + "testing" + + "github.com/stretchr/testify/suite" +) + +type aliasExpressionSuite struct { + suite.Suite +} + +func TestAliasExpressionSuite(t *testing.T) { + suite.Run(t, &aliasExpressionSuite{}) +} + +func (aes *aliasExpressionSuite) TestClone() { + ae := aliased(NewIdentifierExpression("", "", "col"), "c") + aes.Equal(ae, ae.Clone()) +} + +func (aes *aliasExpressionSuite) TestExpression() { + ae := aliased(NewIdentifierExpression("", "", "col"), "c") + aes.Equal(ae, ae.Expression()) +} + +func (aes *aliasExpressionSuite) TestAliased() { + ident := NewIdentifierExpression("", "", "col") + ae := aliased(ident, "c") + aes.Equal(ident, ae.Aliased()) +} + +func (aes *aliasExpressionSuite) TestGetAs() { + ae := aliased(NewIdentifierExpression("", "", "col"), "c") + aes.Equal(NewIdentifierExpression("", "", "c"), ae.GetAs()) +} + +func (aes *aliasExpressionSuite) TestSchema() { + si := aliased( + NewIdentifierExpression("", "t", nil), + NewIdentifierExpression("", "t", nil), + ).Schema("s") + aes.Equal(NewIdentifierExpression("s", "t", nil), si) +} + +func (aes *aliasExpressionSuite) TestTable() { + si := aliased( + NewIdentifierExpression("schema", "", nil), + NewIdentifierExpression("s", "", nil), + ).Table("t") + aes.Equal(NewIdentifierExpression("s", "t", nil), si) +} + +func (aes *aliasExpressionSuite) TestCol() { + si := aliased( + NewIdentifierExpression("", "table", nil), + NewIdentifierExpression("", "t", nil), + ).Col("c") + aes.Equal(NewIdentifierExpression("", "t", "c"), si) +} + +func (aes *aliasExpressionSuite) TestAll() { + si := aliased( + NewIdentifierExpression("", "table", nil), + NewIdentifierExpression("", "t", nil), + ).All() + aes.Equal(NewIdentifierExpression("", "t", Star()), si) +} diff --git a/exp/exp.go b/exp/exp.go index 384d411a..90460786 100644 --- a/exp/exp.go +++ b/exp/exp.go @@ -172,6 +172,16 @@ type ( Aliased() Expression // Returns the alias value as an identiier expression GetAs() IdentifierExpression + + // Returns a new IdentifierExpression with the specified schema + Schema(string) IdentifierExpression + // Returns a new IdentifierExpression with the specified table + Table(string) IdentifierExpression + // Returns a new IdentifierExpression with the specified column + Col(interface{}) IdentifierExpression + // Returns a new IdentifierExpression with the column set to * + // I("my_table").All() //"my_table".* + All() IdentifierExpression } BooleanOperation int diff --git a/exp/ident.go b/exp/ident.go index 29f5a19b..98a7ab0e 100644 --- a/exp/ident.go +++ b/exp/ident.go @@ -1,6 +1,8 @@ package exp -import "strings" +import ( + "strings" +) type ( identifier struct { @@ -117,7 +119,21 @@ func (i identifier) GetCol() interface{} { return i.col } func (i identifier) Set(val interface{}) UpdateExpression { return set(i, val) } // Alias an identifier (e.g "my_col" AS "other_col") -func (i identifier) As(val interface{}) AliasedExpression { return aliased(i, val) } +func (i identifier) As(val interface{}) AliasedExpression { + if v, ok := val.(string); ok { + ident := ParseIdentifier(v) + if i.col != nil && i.col != "" { + return aliased(i, ident) + } + aliasCol := ident.GetCol() + if i.table != "" { + return aliased(i, NewIdentifierExpression("", aliasCol.(string), nil)) + } else if i.schema != "" { + return aliased(i, NewIdentifierExpression(aliasCol.(string), "", nil)) + } + } + return aliased(i, val) +} // Returns a BooleanExpression for equality (e.g "my_col" = 1) func (i identifier) Eq(val interface{}) BooleanExpression { return eq(i, val) } diff --git a/exp/ident_test.go b/exp/ident_test.go index 556a0c6c..9a2b19c1 100644 --- a/exp/ident_test.go +++ b/exp/ident_test.go @@ -165,6 +165,33 @@ func (ies *identifierExpressionSuite) TestIsEmpty() { } } +func (ies *identifierExpressionSuite) TestAs() { + cases := []struct { + Alias AliasedExpression + Expected Expression + }{ + { + Alias: NewIdentifierExpression("", "", "col").As("c"), + Expected: aliased(NewIdentifierExpression("", "", "col"), NewIdentifierExpression("", "", "c")), + }, + { + Alias: NewIdentifierExpression("", "table", nil).As("t"), + Expected: aliased(NewIdentifierExpression("", "table", nil), NewIdentifierExpression("", "t", nil)), + }, + { + Alias: NewIdentifierExpression("", "table", nil).As("s.t"), + Expected: aliased(NewIdentifierExpression("", "table", nil), NewIdentifierExpression("", "t", nil)), + }, + { + Alias: NewIdentifierExpression("schema", "", nil).As("s"), + Expected: aliased(NewIdentifierExpression("schema", "", nil), NewIdentifierExpression("s", "", nil)), + }, + } + for _, tc := range cases { + ies.Equal(tc.Expected, tc.Alias) + } +} + func (ies *identifierExpressionSuite) TestAllOthers() { ident := NewIdentifierExpression("", "", "a") rv := NewRangeVal(1, 2) diff --git a/issues_test.go b/issues_test.go index 78921c9d..44cb91e7 100644 --- a/issues_test.go +++ b/issues_test.go @@ -422,6 +422,34 @@ func (gis *githubIssuesSuite) TestIssue185() { gis.Equal([]int{1, 2, 3, 4}, i) } +// Test for https://github.com/doug-martin/goqu/issues/203 +func (gis *githubIssuesSuite) TestIssue203() { + // Schema definitions. + authSchema := goqu.S("company_auth") + + // Table definitions + usersTable := authSchema.Table("users") + + u := usersTable.As("u") + + ds := goqu.From(u).Select( + u.Col("id"), + u.Col("name"), + u.Col("created_at"), + u.Col("updated_at"), + ) + + sql, args, err := ds.ToSQL() + gis.NoError(err) + gis.Equal(`SELECT "u"."id", "u"."name", "u"."created_at", "u"."updated_at" FROM "company_auth"."users" AS "u"`, sql) + gis.Empty(args, []interface{}{}) + + sql, args, err = ds.Prepared(true).ToSQL() + gis.NoError(err) + gis.Equal(`SELECT "u"."id", "u"."name", "u"."created_at", "u"."updated_at" FROM "company_auth"."users" AS "u"`, sql) + gis.Empty(args, []interface{}{}) +} + func TestGithubIssuesSuite(t *testing.T) { suite.Run(t, new(githubIssuesSuite)) }