diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 5b65e1d1..01a2b792 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -7,20 +7,31 @@ on: pull_request: # Specify a second event with pattern matching jobs: test: - name: Test go ${{ matrix.go_version }} + name: Test go - ${{ matrix.go_version }} mysql - ${{ matrix.db_versions.mysql_version}} postgres - ${{ matrix.db_versions.postgres_version}} runs-on: ubuntu-latest strategy: matrix: go_version: ["1.10", "1.11", "latest"] + db_versions: + - mysql_version: 5 + postgres_version: 9.6 + - mysql_version: 5 + postgres_version: "10.10" + - mysql_version: 8 + postgres_version: 11.5 steps: - name: checkout uses: actions/checkout@v1 - name: Test env: GO_VERSION: ${{ matrix.go_version }} + MYSQL_VERSION: ${{ matrix.db_versions.mysql_version }} + POSTGRES_VERSION: ${{ matrix.db_versions.postgres_version }} run: docker-compose run goqu-coverage - name: Upload coverage to Codecov - run: bash <(curl -s https://codecov.io/bash) -n $GO_VERSION -e GO_VERSION,GITHUB_WORKFLOW,GITHUB_ACTION + run: bash <(curl -s https://codecov.io/bash) -n $GO_VERSION -e GO_VERSION,MYSQL_VERSION,POSTGRES_VERSION,GITHUB_WORKFLOW,GITHUB_ACTION env: CODECOV_TOKEN: ${{secrets.CODECOV_TOKEN}} GO_VERSION: ${{ matrix.go_version }} + MYSQL_VERSION: ${{ matrix.db_versions.mysql_version }} + POSTGRES_VERSION: ${{ matrix.db_versions.postgres_version }} diff --git a/HISTORY.md b/HISTORY.md index 49b6f1cc..12dba341 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,7 @@ +## 8.5.0 + +* [ADDED] Window Function support [#128](https://github.com/doug-martin/goqu/issues/128) - [@Xuyuanp](https://github.com/Xuyuanp) + ## 8.4.1 * [FIXED] Returning func be able to handle nil [#140](https://github.com/doug-martin/goqu/issues/140) diff --git a/database_test.go b/database_test.go index 3574f61e..92980bc6 100644 --- a/database_test.go +++ b/database_test.go @@ -6,8 +6,6 @@ import ( "sync" "testing" - "github.com/stretchr/testify/assert" - "github.com/DATA-DOG/go-sqlmock" "github.com/doug-martin/goqu/v8/internal/errors" "github.com/stretchr/testify/suite" @@ -317,9 +315,8 @@ func (ds *databaseSuite) TestWithTx() { } func (ds *databaseSuite) TestDataRace() { - t := ds.T() mDb, mock, err := sqlmock.New() - assert.NoError(t, err) + ds.NoError(err) db := newDatabase("mock", mDb) const concurrency = 10 @@ -340,10 +337,10 @@ func (ds *databaseSuite) TestDataRace() { sql := db.From("items").Limit(1) var item testActionItem found, err := sql.ScanStruct(&item) - assert.NoError(t, err) - assert.True(t, found) - assert.Equal(t, item.Address, "111 Test Addr") - assert.Equal(t, item.Name, "Test1") + ds.NoError(err) + ds.True(found) + ds.Equal(item.Address, "111 Test Addr") + ds.Equal(item.Name, "Test1") }() } @@ -661,13 +658,12 @@ func (tds *txdatabaseSuite) TestWrap() { } func (tds *txdatabaseSuite) TestDataRace() { - t := tds.T() mDb, mock, err := sqlmock.New() - assert.NoError(t, err) + tds.NoError(err) mock.ExpectBegin() db := newDatabase("mock", mDb) tx, err := db.Begin() - assert.NoError(t, err) + tds.NoError(err) const concurrency = 10 @@ -687,16 +683,16 @@ func (tds *txdatabaseSuite) TestDataRace() { sql := tx.From("items").Limit(1) var item testActionItem found, err := sql.ScanStruct(&item) - assert.NoError(t, err) - assert.True(t, found) - assert.Equal(t, item.Address, "111 Test Addr") - assert.Equal(t, item.Name, "Test1") + tds.NoError(err) + tds.True(found) + tds.Equal(item.Address, "111 Test Addr") + tds.Equal(item.Name, "Test1") }() } wg.Wait() mock.ExpectCommit() - assert.NoError(t, tx.Commit()) + tds.NoError(tx.Commit()) } func TestTxDatabaseSuite(t *testing.T) { diff --git a/dialect/mysql/mysql.go b/dialect/mysql/mysql.go index eab00cbe..f0e99551 100644 --- a/dialect/mysql/mysql.go +++ b/dialect/mysql/mysql.go @@ -19,6 +19,7 @@ func DialectOptions() *goqu.SQLDialectOptions { opts.SupportsWithCTE = false opts.SupportsWithCTERecursive = false opts.SupportsDistinctOn = false + opts.SupportsWindowFunction = false opts.UseFromClauseForMultipleUpdateTables = false @@ -65,6 +66,13 @@ func DialectOptions() *goqu.SQLDialectOptions { return opts } +func DialectOptionsV8() *goqu.SQLDialectOptions { + opts := DialectOptions() + opts.SupportsWindowFunction = true + return opts +} + func init() { goqu.RegisterDialect("mysql", DialectOptions()) + goqu.RegisterDialect("mysql8", DialectOptionsV8()) } diff --git a/dialect/mysql/mysql_test.go b/dialect/mysql/mysql_test.go index 876f95d7..3a34f9b9 100644 --- a/dialect/mysql/mysql_test.go +++ b/dialect/mysql/mysql_test.go @@ -4,6 +4,8 @@ import ( "database/sql" "fmt" "os" + "strconv" + "strings" "testing" "time" @@ -393,6 +395,44 @@ func (mt *mysqlTest) TestInsert_OnConflict() { mt.EqualError(err, "goqu: dialect does not support upsert with where clause [dialect=mysql]") } +func (mt *mysqlTest) TestWindowFunction() { + var version string + ok, err := mt.db.Select(goqu.Func("version")).ScanVal(&version) + mt.NoError(err) + mt.True(ok) + + fields := strings.Split(version, ".") + mt.True(len(fields) > 0) + major, err := strconv.Atoi(fields[0]) + mt.NoError(err) + if major < 8 { + fmt.Printf("SKIPPING MYSQL WINDOW FUNCTION TEST BECAUSE VERSION IS < 8 [mysql_version:=%d]\n", major) + return + } + + ds := mt.db.From("entry"). + Select("int", goqu.ROW_NUMBER().OverName(goqu.I("w")).As("id")). + Window(goqu.W("w").OrderBy(goqu.I("int").Desc())) + + var entries []entry + mt.NoError(ds.WithDialect("mysql8").ScanStructs(&entries)) + + mt.Equal([]entry{ + {Int: 9, ID: 1}, + {Int: 8, ID: 2}, + {Int: 7, ID: 3}, + {Int: 6, ID: 4}, + {Int: 5, ID: 5}, + {Int: 4, ID: 6}, + {Int: 3, ID: 7}, + {Int: 2, ID: 8}, + {Int: 1, ID: 9}, + {Int: 0, ID: 10}, + }, entries) + + mt.Error(ds.WithDialect("mysql").ScanStructs(&entries), "goqu: adapter does not support window function clause") +} + func TestMysqlSuite(t *testing.T) { suite.Run(t, new(mysqlTest)) } diff --git a/dialect/postgres/postgres_test.go b/dialect/postgres/postgres_test.go index b1d662de..424f5106 100644 --- a/dialect/postgres/postgres_test.go +++ b/dialect/postgres/postgres_test.go @@ -414,6 +414,28 @@ func (pt *postgresTest) TestInsert_OnConflict() { pt.Equal("upsert", entry9.String) } +func (pt *postgresTest) TestWindowFunction() { + ds := pt.db.From("entry"). + Select("int", goqu.ROW_NUMBER().OverName(goqu.I("w")).As("id")). + Window(goqu.W("w").OrderBy(goqu.I("int").Desc())) + + var entries []entry + pt.NoError(ds.ScanStructs(&entries)) + + pt.Equal([]entry{ + {Int: 9, ID: 1}, + {Int: 8, ID: 2}, + {Int: 7, ID: 3}, + {Int: 6, ID: 4}, + {Int: 5, ID: 5}, + {Int: 4, ID: 6}, + {Int: 3, ID: 7}, + {Int: 2, ID: 8}, + {Int: 1, ID: 9}, + {Int: 0, ID: 10}, + }, entries) +} + func TestPostgresSuite(t *testing.T) { suite.Run(t, new(postgresTest)) } diff --git a/dialect/sqlite3/sqlite3.go b/dialect/sqlite3/sqlite3.go index 32f3b1f2..5318645f 100644 --- a/dialect/sqlite3/sqlite3.go +++ b/dialect/sqlite3/sqlite3.go @@ -19,6 +19,7 @@ func DialectOptions() *goqu.SQLDialectOptions { opts.SupportsMultipleUpdateTables = false opts.WrapCompoundsInParens = false opts.SupportsDistinctOn = false + opts.SupportsWindowFunction = false opts.PlaceHolderRune = '?' opts.IncludePlaceholderNum = false diff --git a/docker-compose.yml b/docker-compose.yml index bace1735..14c04cf6 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,7 +2,7 @@ version: "2" services: postgres: - image: postgres:9.6 + image: "postgres:${POSTGRES_VERSION}" environment: - "POSTGRES_USER=postgres" - "POSTGRES_DB=goqupostgres" @@ -10,7 +10,7 @@ services: - "5432" mysql: - image: mysql:5 + image: "mysql:${MYSQL_VERSION}" environment: - "MYSQL_DATABASE=goqumysql" - "MYSQL_ALLOW_EMPTY_PASSWORD=yes" diff --git a/docs/selecting.md b/docs/selecting.md index e6941297..116104b5 100644 --- a/docs/selecting.md +++ b/docs/selecting.md @@ -11,6 +11,7 @@ * [`Offset`](#offset) * [`GroupBy`](#group_by) * [`Having`](#having) + * [`Window`](#window) * Executing Queries * [`ScanStructs`](#scan-structs) - Scans rows into a slice of structs * [`ScanStruct`](#scan-struct) - Scans a row into a slice a struct, returns false if a row wasnt found @@ -610,6 +611,42 @@ Output: SELECT * FROM "test" GROUP BY "age" HAVING (SUM("income") > 1000) ``` + + +**[`Window Function`](https://godoc.org/github.com/doug-martin/goqu/#SelectDataset.Window)** + +**NOTE** currently only the `postgres`, `mysql8` (NOT `mysql`) and the default dialect support `Window Function` + +To use windowing in select you can use the `Over` method on an `SQLFunction` + +```go +sql, _, _ := goqu.From("test").Select( + goqu.ROW_NUMBER().Over(goqu.W().PartitionBy("a").OrderBy(goqu.I("b").Asc())), +) +fmt.Println(sql) +``` + +Output: + +``` +SELECT ROW_NUMBER() OVER (PARTITION BY "a" ORDER BY "b") FROM "test" +``` + +`goqu` also supports the `WINDOW` clause. + +```go +sql, _, _ := goqu.From("test"). + Select(goqu.ROW_NUMBER().OverName(goqu.I("w"))). + Window(goqu.W("w").PartitionBy("a").OrderBy(goqu.I("b").Asc())) +fmt.Println(sql) +``` + +Output: + +``` +SELECT ROW_NUMBER() OVER "w" FROM "test" WINDOW "w" AS (PARTITION BY "a" ORDER BY "b") +``` + ## Executing Queries To execute your query use [`goqu.Database#From`](https://godoc.org/github.com/doug-martin/goqu/#Database.From) to create your dataset @@ -748,4 +785,5 @@ if err := db.From("user").Pluck(&ids, "id"); err != nil{ return } fmt.Printf("\nIds := %+v", ids) -``` \ No newline at end of file +``` + diff --git a/exp/col.go b/exp/col.go index 55def0e1..7387bd48 100644 --- a/exp/col.go +++ b/exp/col.go @@ -12,7 +12,7 @@ type columnList struct { } func NewColumnListExpression(vals ...interface{}) ColumnListExpression { - var cols []Expression + cols := []Expression{} for _, val := range vals { switch t := val.(type) { case nil: // do nothing diff --git a/exp/exp.go b/exp/exp.go index b08fe730..d225f48f 100644 --- a/exp/exp.go +++ b/exp/exp.go @@ -341,6 +341,11 @@ type ( End() interface{} } + Windowable interface { + Over(WindowExpression) SQLWindowFunctionExpression + OverName(IdentifierExpression) SQLWindowFunctionExpression + } + // Expression for representing a SQLFunction(e.g. COUNT, SUM, MIN, MAX...) SQLFunctionExpression interface { Expression @@ -350,6 +355,7 @@ type ( Isable Inable Likeable + Windowable // The function name Name() string // Arguments to be passed to the function @@ -360,6 +366,41 @@ type ( Col() IdentifierExpression Val() interface{} } + + SQLWindowFunctionExpression interface { + Expression + Aliaseable + Rangeable + Comparable + Isable + Inable + Likeable + Func() SQLFunctionExpression + + Window() WindowExpression + WindowName() IdentifierExpression + + HasWindow() bool + HasWindowName() bool + } + + WindowExpression interface { + Expression + + Name() IdentifierExpression + HasName() bool + + Parent() IdentifierExpression + HasParent() bool + PartitionCols() ColumnListExpression + HasPartitionBy() bool + OrderCols() ColumnListExpression + HasOrder() bool + + Inherit(parent string) WindowExpression + PartitionBy(cols ...interface{}) WindowExpression + OrderBy(cols ...interface{}) WindowExpression + } ) const ( diff --git a/exp/func.go b/exp/func.go index df60df9f..67b62a39 100644 --- a/exp/func.go +++ b/exp/func.go @@ -46,3 +46,11 @@ func (sfe sqlFunctionExpression) IsTrue() BooleanExpression { retu func (sfe sqlFunctionExpression) IsNotTrue() BooleanExpression { return isNot(sfe, true) } func (sfe sqlFunctionExpression) IsFalse() BooleanExpression { return is(sfe, false) } func (sfe sqlFunctionExpression) IsNotFalse() BooleanExpression { return isNot(sfe, false) } + +func (sfe sqlFunctionExpression) Over(we WindowExpression) SQLWindowFunctionExpression { + return NewSQLWindowFunctionExpression(sfe, nil, we) +} + +func (sfe sqlFunctionExpression) OverName(windowName IdentifierExpression) SQLWindowFunctionExpression { + return NewSQLWindowFunctionExpression(sfe, windowName, nil) +} diff --git a/exp/select_clauses.go b/exp/select_clauses.go index 400e3544..f5443caf 100644 --- a/exp/select_clauses.go +++ b/exp/select_clauses.go @@ -58,6 +58,11 @@ type ( CommonTables() []CommonTableExpression CommonTablesAppend(cte CommonTableExpression) SelectClauses + + Windows() []WindowExpression + SetWindows(ws []WindowExpression) SelectClauses + WindowsAppend(ws ...WindowExpression) SelectClauses + ClearWindows() SelectClauses } selectClauses struct { commonTables []CommonTableExpression @@ -74,6 +79,7 @@ type ( offset uint compounds []CompoundExpression lock Lock + windows []WindowExpression } ) @@ -116,6 +122,7 @@ func (c *selectClauses) clone() *selectClauses { offset: c.offset, compounds: c.compounds, lock: c.lock, + windows: c.windows, } } @@ -331,3 +338,25 @@ func (c *selectClauses) CompoundsAppend(ce CompoundExpression) SelectClauses { ret.compounds = append(ret.compounds, ce) return ret } + +func (c *selectClauses) Windows() []WindowExpression { + return c.windows +} + +func (c *selectClauses) SetWindows(ws []WindowExpression) SelectClauses { + ret := c.clone() + ret.windows = ws + return ret +} + +func (c *selectClauses) WindowsAppend(ws ...WindowExpression) SelectClauses { + ret := c.clone() + ret.windows = append(ret.windows, ws...) + return ret +} + +func (c *selectClauses) ClearWindows() SelectClauses { + ret := c.clone() + ret.windows = nil + return ret +} diff --git a/exp/select_clauses_test.go b/exp/select_clauses_test.go index f6a52d47..dfe06056 100644 --- a/exp/select_clauses_test.go +++ b/exp/select_clauses_test.go @@ -255,6 +255,48 @@ func (scs *selectClausesSuite) TestHavingAppend() { scs.Equal(NewExpressionList(AndType, w, w2), c4.Having()) } +func (scs *selectClausesSuite) TestWindows() { + w := NewWindowExpression(NewIdentifierExpression("", "", "w"), nil, nil, nil) + + c := NewSelectClauses() + c2 := c.WindowsAppend(w) + + scs.Nil(c.Windows()) + + scs.Equal([]WindowExpression{w}, c2.Windows()) +} + +func (scs *selectClausesSuite) TestSetWindows() { + w := NewWindowExpression(NewIdentifierExpression("", "", "w"), nil, nil, nil) + + c := NewSelectClauses() + c2 := c.SetWindows([]WindowExpression{w}) + + scs.Nil(c.Windows()) + + scs.Equal([]WindowExpression{w}, c2.Windows()) +} + +func (scs *selectClausesSuite) TestWindowsAppend() { + w1 := NewWindowExpression(NewIdentifierExpression("", "", "w1"), nil, nil, nil) + w2 := NewWindowExpression(NewIdentifierExpression("", "", "w2"), nil, nil, nil) + + c := NewSelectClauses() + c2 := c.WindowsAppend(w1).WindowsAppend(w2) + + scs.Nil(c.Windows()) + + scs.Equal([]WindowExpression{w1, w2}, c2.Windows()) +} + +func (scs *selectClausesSuite) TestClearWindows() { + w := NewWindowExpression(NewIdentifierExpression("", "", "w"), nil, nil, nil) + + c := NewSelectClauses().SetWindows([]WindowExpression{w}) + scs.Nil(c.ClearWindows().Windows()) + scs.Equal([]WindowExpression{w}, c.Windows()) +} + func (scs *selectClausesSuite) TestOrder() { oe := NewIdentifierExpression("", "", "a").Desc() diff --git a/exp/window.go b/exp/window.go new file mode 100644 index 00000000..a33178a0 --- /dev/null +++ b/exp/window.go @@ -0,0 +1,90 @@ +package exp + +type sqlWindowExpression struct { + name IdentifierExpression + parent IdentifierExpression + partitionCols ColumnListExpression + orderCols ColumnListExpression +} + +func NewWindowExpression(window, parent IdentifierExpression, partitionCols, orderCols ColumnListExpression) WindowExpression { + if partitionCols == nil { + partitionCols = NewColumnListExpression() + } + if orderCols == nil { + orderCols = NewColumnListExpression() + } + return sqlWindowExpression{ + name: window, + parent: parent, + partitionCols: partitionCols, + orderCols: orderCols, + } +} + +func (we sqlWindowExpression) clone() sqlWindowExpression { + return sqlWindowExpression{ + name: we.name, + parent: we.parent, + partitionCols: we.partitionCols.Clone().(ColumnListExpression), + orderCols: we.orderCols.Clone().(ColumnListExpression), + } +} + +func (we sqlWindowExpression) Clone() Expression { + return we.clone() +} + +func (we sqlWindowExpression) Expression() Expression { + return we +} + +func (we sqlWindowExpression) Name() IdentifierExpression { + return we.name +} + +func (we sqlWindowExpression) HasName() bool { + return we.name != nil +} + +func (we sqlWindowExpression) Parent() IdentifierExpression { + return we.parent +} + +func (we sqlWindowExpression) HasParent() bool { + return we.parent != nil +} + +func (we sqlWindowExpression) PartitionCols() ColumnListExpression { + return we.partitionCols +} + +func (we sqlWindowExpression) HasPartitionBy() bool { + return we.partitionCols != nil && !we.partitionCols.IsEmpty() +} + +func (we sqlWindowExpression) OrderCols() ColumnListExpression { + return we.orderCols +} + +func (we sqlWindowExpression) HasOrder() bool { + return we.orderCols != nil && !we.orderCols.IsEmpty() +} + +func (we sqlWindowExpression) PartitionBy(cols ...interface{}) WindowExpression { + ret := we.clone() + ret.partitionCols = NewColumnListExpression(cols...) + return ret +} + +func (we sqlWindowExpression) OrderBy(cols ...interface{}) WindowExpression { + ret := we.clone() + ret.orderCols = NewColumnListExpression(cols...) + return ret +} + +func (we sqlWindowExpression) Inherit(parent string) WindowExpression { + ret := we.clone() + ret.parent = ParseIdentifier(parent) + return ret +} diff --git a/exp/window_func.go b/exp/window_func.go new file mode 100644 index 00000000..8580f957 --- /dev/null +++ b/exp/window_func.go @@ -0,0 +1,96 @@ +package exp + +type sqlWindowFunctionExpression struct { + fn SQLFunctionExpression + windowName IdentifierExpression + window WindowExpression +} + +func NewSQLWindowFunctionExpression( + fn SQLFunctionExpression, + windowName IdentifierExpression, + window WindowExpression) SQLWindowFunctionExpression { + return sqlWindowFunctionExpression{ + fn: fn, + windowName: windowName, + window: window, + } +} + +func (swfe sqlWindowFunctionExpression) clone() sqlWindowFunctionExpression { + return sqlWindowFunctionExpression{ + fn: swfe.fn.Clone().(SQLFunctionExpression), + windowName: swfe.windowName, + window: swfe.window, + } +} + +func (swfe sqlWindowFunctionExpression) Clone() Expression { + return swfe.clone() +} +func (swfe sqlWindowFunctionExpression) Expression() Expression { + return swfe +} +func (swfe sqlWindowFunctionExpression) As(val interface{}) AliasedExpression { + return aliased(swfe, val) +} +func (swfe sqlWindowFunctionExpression) Eq(val interface{}) BooleanExpression { return eq(swfe, val) } +func (swfe sqlWindowFunctionExpression) Neq(val interface{}) BooleanExpression { return neq(swfe, val) } +func (swfe sqlWindowFunctionExpression) Gt(val interface{}) BooleanExpression { return gt(swfe, val) } +func (swfe sqlWindowFunctionExpression) Gte(val interface{}) BooleanExpression { return gte(swfe, val) } +func (swfe sqlWindowFunctionExpression) Lt(val interface{}) BooleanExpression { return lt(swfe, val) } +func (swfe sqlWindowFunctionExpression) Lte(val interface{}) BooleanExpression { return lte(swfe, val) } +func (swfe sqlWindowFunctionExpression) Between(val RangeVal) RangeExpression { + return between(swfe, val) +} +func (swfe sqlWindowFunctionExpression) NotBetween(val RangeVal) RangeExpression { + return notBetween(swfe, val) +} +func (swfe sqlWindowFunctionExpression) Like(val interface{}) BooleanExpression { + return like(swfe, val) +} +func (swfe sqlWindowFunctionExpression) NotLike(val interface{}) BooleanExpression { + return notLike(swfe, val) +} +func (swfe sqlWindowFunctionExpression) ILike(val interface{}) BooleanExpression { + return iLike(swfe, val) +} +func (swfe sqlWindowFunctionExpression) NotILike(val interface{}) BooleanExpression { + return notILike(swfe, val) +} +func (swfe sqlWindowFunctionExpression) In(vals ...interface{}) BooleanExpression { + return in(swfe, vals...) +} +func (swfe sqlWindowFunctionExpression) NotIn(vals ...interface{}) BooleanExpression { + return notIn(swfe, vals...) +} +func (swfe sqlWindowFunctionExpression) Is(val interface{}) BooleanExpression { return is(swfe, val) } +func (swfe sqlWindowFunctionExpression) IsNot(val interface{}) BooleanExpression { + return isNot(swfe, val) +} +func (swfe sqlWindowFunctionExpression) IsNull() BooleanExpression { return is(swfe, nil) } +func (swfe sqlWindowFunctionExpression) IsNotNull() BooleanExpression { return isNot(swfe, nil) } +func (swfe sqlWindowFunctionExpression) IsTrue() BooleanExpression { return is(swfe, true) } +func (swfe sqlWindowFunctionExpression) IsNotTrue() BooleanExpression { return isNot(swfe, true) } +func (swfe sqlWindowFunctionExpression) IsFalse() BooleanExpression { return is(swfe, false) } +func (swfe sqlWindowFunctionExpression) IsNotFalse() BooleanExpression { return isNot(swfe, false) } + +func (swfe sqlWindowFunctionExpression) Func() SQLFunctionExpression { + return swfe.fn +} + +func (swfe sqlWindowFunctionExpression) Window() WindowExpression { + return swfe.window +} + +func (swfe sqlWindowFunctionExpression) WindowName() IdentifierExpression { + return swfe.windowName +} + +func (swfe sqlWindowFunctionExpression) HasWindow() bool { + return swfe.window != nil +} + +func (swfe sqlWindowFunctionExpression) HasWindowName() bool { + return swfe.windowName != nil +} diff --git a/exp/window_func_test.go b/exp/window_func_test.go new file mode 100644 index 00000000..3ee6f772 --- /dev/null +++ b/exp/window_func_test.go @@ -0,0 +1,181 @@ +package exp + +import ( + "testing" + + "github.com/stretchr/testify/suite" +) + +type sqlWindowFunctionExpressionTest struct { + suite.Suite + fn SQLFunctionExpression +} + +func TestSQLWindowFunctionExpressionSuite(t *testing.T) { + suite.Run(t, &sqlWindowFunctionExpressionTest{ + fn: NewSQLFunctionExpression("COUNT", Star()), + }) +} + +func (swfet *sqlWindowFunctionExpressionTest) TestClone() { + wf := NewSQLWindowFunctionExpression(swfet.fn, NewIdentifierExpression("", "", "a"), nil) + wf2 := wf.Clone() + swfet.Equal(wf, wf2) +} + +func (swfet *sqlWindowFunctionExpressionTest) TestExpression() { + wf := NewSQLWindowFunctionExpression(swfet.fn, NewIdentifierExpression("", "", "a"), nil) + wf2 := wf.Expression() + swfet.Equal(wf, wf2) +} + +func (swfet *sqlWindowFunctionExpressionTest) TestFunc() { + wf := NewSQLWindowFunctionExpression(swfet.fn, NewIdentifierExpression("", "", "a"), nil) + swfet.Equal(swfet.fn, wf.Func()) +} + +func (swfet *sqlWindowFunctionExpressionTest) TestWindow() { + w := NewWindowExpression( + NewIdentifierExpression("", "", "w"), + nil, + nil, + nil, + ) + wf := NewSQLWindowFunctionExpression(swfet.fn, NewIdentifierExpression("", "", "a"), nil) + swfet.False(wf.HasWindow()) + + wf = swfet.fn.Over(w) + swfet.True(wf.HasWindow()) + swfet.Equal(wf.Window(), w) +} + +func (swfet *sqlWindowFunctionExpressionTest) TestWindowName() { + windowName := NewIdentifierExpression("", "", "a") + wf := NewSQLWindowFunctionExpression(swfet.fn, nil, nil) + swfet.False(wf.HasWindowName()) + + wf = swfet.fn.OverName(windowName) + swfet.True(wf.HasWindowName()) + swfet.Equal(wf.WindowName(), windowName) +} + +func (swfet *sqlWindowFunctionExpressionTest) TestAllOthers() { + wf := NewSQLWindowFunctionExpression(swfet.fn, nil, nil) + + expAs := wf.As("a") + swfet.Equal(expAs.Aliased(), wf) + + expEq := wf.Eq(1) + swfet.Equal(expEq.LHS(), wf) + swfet.Equal(expEq.Op(), EqOp) + swfet.Equal(expEq.RHS(), 1) + + expNeq := wf.Neq(1) + swfet.Equal(expNeq.LHS(), wf) + swfet.Equal(expNeq.Op(), NeqOp) + swfet.Equal(expNeq.RHS(), 1) + + expGt := wf.Gt(1) + swfet.Equal(expGt.LHS(), wf) + swfet.Equal(expGt.Op(), GtOp) + swfet.Equal(expGt.RHS(), 1) + + expGte := wf.Gte(1) + swfet.Equal(expGte.LHS(), wf) + swfet.Equal(expGte.Op(), GteOp) + swfet.Equal(expGte.RHS(), 1) + + expLt := wf.Lt(1) + swfet.Equal(expLt.LHS(), wf) + swfet.Equal(expLt.Op(), LtOp) + swfet.Equal(expLt.RHS(), 1) + + expLte := wf.Lte(1) + swfet.Equal(expLte.LHS(), wf) + swfet.Equal(expLte.Op(), LteOp) + swfet.Equal(expLte.RHS(), 1) + + rv := NewRangeVal(1, 2) + expBetween := wf.Between(rv) + swfet.Equal(expBetween.LHS(), wf) + swfet.Equal(expBetween.Op(), BetweenOp) + swfet.Equal(expBetween.RHS(), rv) + + expNotBetween := wf.NotBetween(rv) + swfet.Equal(expNotBetween.LHS(), wf) + swfet.Equal(expNotBetween.Op(), NotBetweenOp) + swfet.Equal(expNotBetween.RHS(), rv) + + pattern := "a%" + expLike := wf.Like(pattern) + swfet.Equal(expLike.LHS(), wf) + swfet.Equal(expLike.Op(), LikeOp) + swfet.Equal(expLike.RHS(), pattern) + + expNotLike := wf.NotLike(pattern) + swfet.Equal(expNotLike.LHS(), wf) + swfet.Equal(expNotLike.Op(), NotLikeOp) + swfet.Equal(expNotLike.RHS(), pattern) + + expILike := wf.ILike(pattern) + swfet.Equal(expILike.LHS(), wf) + swfet.Equal(expILike.Op(), ILikeOp) + swfet.Equal(expILike.RHS(), pattern) + + expNotILike := wf.NotILike(pattern) + swfet.Equal(expNotILike.LHS(), wf) + swfet.Equal(expNotILike.Op(), NotILikeOp) + swfet.Equal(expNotILike.RHS(), pattern) + + vals := []interface{}{1, 2} + expIn := wf.In(vals) + swfet.Equal(expIn.LHS(), wf) + swfet.Equal(expIn.Op(), InOp) + swfet.Equal(expIn.RHS(), vals) + + expNotIn := wf.NotIn(vals) + swfet.Equal(expNotIn.LHS(), wf) + swfet.Equal(expNotIn.Op(), NotInOp) + swfet.Equal(expNotIn.RHS(), vals) + + obj := 1 + expIs := wf.Is(obj) + swfet.Equal(expIs.LHS(), wf) + swfet.Equal(expIs.Op(), IsOp) + swfet.Equal(expIs.RHS(), obj) + + expIsNot := wf.IsNot(obj) + swfet.Equal(expIsNot.LHS(), wf) + swfet.Equal(expIsNot.Op(), IsNotOp) + swfet.Equal(expIsNot.RHS(), obj) + + expIsNull := wf.IsNull() + swfet.Equal(expIsNull.LHS(), wf) + swfet.Equal(expIsNull.Op(), IsOp) + swfet.Nil(expIsNull.RHS()) + + expIsNotNull := wf.IsNotNull() + swfet.Equal(expIsNotNull.LHS(), wf) + swfet.Equal(expIsNotNull.Op(), IsNotOp) + swfet.Nil(expIsNotNull.RHS()) + + expIsTrue := wf.IsTrue() + swfet.Equal(expIsTrue.LHS(), wf) + swfet.Equal(expIsTrue.Op(), IsOp) + swfet.Equal(expIsTrue.RHS(), true) + + expIsNotTrue := wf.IsNotTrue() + swfet.Equal(expIsNotTrue.LHS(), wf) + swfet.Equal(expIsNotTrue.Op(), IsNotOp) + swfet.Equal(expIsNotTrue.RHS(), true) + + expIsFalse := wf.IsFalse() + swfet.Equal(expIsFalse.LHS(), wf) + swfet.Equal(expIsFalse.Op(), IsOp) + swfet.Equal(expIsFalse.RHS(), false) + + expIsNotFalse := wf.IsNotFalse() + swfet.Equal(expIsNotFalse.LHS(), wf) + swfet.Equal(expIsNotFalse.Op(), IsNotOp) + swfet.Equal(expIsNotFalse.RHS(), false) +} diff --git a/exp/window_test.go b/exp/window_test.go new file mode 100644 index 00000000..e643e680 --- /dev/null +++ b/exp/window_test.go @@ -0,0 +1,83 @@ +package exp + +import ( + "testing" + + "github.com/stretchr/testify/suite" +) + +type windowExpressionTest struct { + suite.Suite +} + +func TestWindowExpressionSuite(t *testing.T) { + suite.Run(t, new(windowExpressionTest)) +} + +func (wet *windowExpressionTest) TestClone() { + w := NewWindowExpression(NewIdentifierExpression("", "", "w"), nil, nil, nil) + w2 := w.Clone() + + wet.Equal(w, w2) +} + +func (wet *windowExpressionTest) TestExpression() { + w := NewWindowExpression(NewIdentifierExpression("", "", "w"), nil, nil, nil) + w2 := w.Expression() + + wet.Equal(w, w2) +} + +func (wet *windowExpressionTest) TestName() { + name := NewIdentifierExpression("", "", "w") + w := NewWindowExpression(NewIdentifierExpression("", "", "w"), nil, nil, nil) + + wet.Equal(name, w.Name()) +} + +func (wet *windowExpressionTest) TestPartitionCols() { + cols := NewColumnListExpression("a", "b") + w := NewWindowExpression(NewIdentifierExpression("", "", "w"), nil, cols, nil) + + wet.Equal(cols, w.PartitionCols()) + wet.Equal(cols, w.Clone().(WindowExpression).PartitionCols()) +} + +func (wet *windowExpressionTest) TestOrderCols() { + cols := NewColumnListExpression("a", "b") + w := NewWindowExpression(NewIdentifierExpression("", "", "w"), nil, nil, cols) + + wet.Equal(cols, w.OrderCols()) + wet.Equal(cols, w.Clone().(WindowExpression).OrderCols()) +} + +func (wet *windowExpressionTest) TestPartitionBy() { + cols := NewColumnListExpression("a", "b") + w := NewWindowExpression(NewIdentifierExpression("", "", "w"), nil, nil, nil).PartitionBy("a", "b") + + wet.Equal(cols, w.PartitionCols()) +} + +func (wet *windowExpressionTest) TestOrderBy() { + cols := NewColumnListExpression("a", "b") + w := NewWindowExpression(NewIdentifierExpression("", "", "w"), nil, nil, nil).OrderBy("a", "b") + + wet.Equal(cols, w.OrderCols()) +} + +func (wet *windowExpressionTest) TestParent() { + parent := NewIdentifierExpression("", "", "w1") + w := NewWindowExpression(NewIdentifierExpression("", "", "w"), parent, nil, nil) + + wet.Equal(parent, w.Parent()) +} + +func (wet *windowExpressionTest) TestInherit() { + parent := NewIdentifierExpression("", "", "w1") + w := NewWindowExpression(NewIdentifierExpression("", "", "w"), parent, nil, nil) + + wet.Equal(parent, w.Parent()) + + w = w.Inherit("w2") + wet.Equal(NewIdentifierExpression("", "", "w2"), w.Parent()) +} diff --git a/expressions.go b/expressions.go index 62dc1ac7..b7281358 100644 --- a/expressions.go +++ b/expressions.go @@ -15,6 +15,9 @@ type ( TruncateOptions = exp.TruncateOptions ) +// emptyWindow is an empty WINDOW clause without name +var emptyWindow = exp.NewWindowExpression(nil, nil, nil, nil) + const ( Wait = exp.Wait NoWait = exp.NoWait @@ -114,7 +117,53 @@ func SUM(col interface{}) exp.SQLFunctionExpression { return newIdentifierFunc(" // COALESCE(I("a"), "a") -> COALESCE("a", 'a') // COALESCE(I("a"), I("b"), nil) -> COALESCE("a", "b", NULL) func COALESCE(vals ...interface{}) exp.SQLFunctionExpression { - return exp.NewSQLFunctionExpression("COALESCE", vals...) + return Func("COALESCE", vals...) +} + +// nolint: golint +func ROW_NUMBER() exp.SQLFunctionExpression { + return Func("ROW_NUMBER") +} + +func RANK() exp.SQLFunctionExpression { + return Func("RANK") +} + +// nolint: golint +func DENSE_RANK() exp.SQLFunctionExpression { + return Func("DENSE_RANK") +} + +// nolint: golint +func PERCENT_RANK() exp.SQLFunctionExpression { + return Func("PERCENT_RANK") +} + +// nolint: golint +func CUME_DIST() exp.SQLFunctionExpression { + return Func("CUME_DIST") +} + +func NTILE(n int) exp.SQLFunctionExpression { + return Func("NTILE", n) +} + +// nolint: golint +func FIRST_VALUE(val interface{}) exp.SQLFunctionExpression { + return newIdentifierFunc("FIRST_VALUE", val) +} + +// nolint: golint +func LAST_VALUE(val interface{}) exp.SQLFunctionExpression { + return newIdentifierFunc("LAST_VALUE", val) +} + +// nolint: golint +func NTH_VALUE(val interface{}, nth int) exp.SQLFunctionExpression { + if s, ok := val.(string); ok { + val = I(s) + } + return Func("NTH_VALUE", val, nth) } // Creates a new Identifier, the generated sql will use adapter specific quoting or '"' by default, this ensures case @@ -163,6 +212,29 @@ func T(table string) exp.IdentifierExpression { return exp.NewIdentifierExpression("", table, "") } +// Create a new WINDOW clause +// W() -> () +// W().PartitionBy("a") -> (PARTITION BY "a") +// W().PartitionBy("a").OrderBy("b") -> (PARTITION BY "a" ORDER BY "b") +// W().PartitionBy("a").OrderBy("b").Inherit("w1") -> ("w1" PARTITION BY "a" ORDER BY "b") +// W().PartitionBy("a").OrderBy(I("b").Desc()).Inherit("w1") -> ("w1" PARTITION BY "a" ORDER BY "b" DESC) +// W("w") -> "w" AS () +// W("w", "w1") -> "w" AS ("w1") +// W("w").Inherit("w1") -> "w" AS ("w1") +// W("w").PartitionBy("a") -> "w" AS (PARTITION BY "a") +// W("w", "w1").PartitionBy("a") -> "w" AS ("w1" PARTITION BY "a") +// W("w", "w1").PartitionBy("a").OrderBy("b") -> "w" AS ("w1" PARTITION BY "a" ORDER BY "b") +func W(ws ...string) exp.WindowExpression { + switch len(ws) { + case 0: + return emptyWindow + case 1: + return exp.NewWindowExpression(I(ws[0]), nil, nil, nil) + default: + return exp.NewWindowExpression(I(ws[0]), I(ws[1]), nil, nil) + } +} + // Creates a new ON clause to be used within a join // ds.Join(goqu.T("my_table"), goqu.On( // goqu.I("my_table.fkey").Eq(goqu.I("other_table.id")), @@ -184,12 +256,12 @@ func Using(columns ...interface{}) exp.JoinCondition { // Literals can also contain placeholders for other expressions // L("(? AND ?) OR (?)", I("a").Eq(1), I("b").Eq("b"), I("c").In([]string{"a", "b", "c"})) func L(sql string, args ...interface{}) exp.LiteralExpression { - return exp.NewLiteralExpression(sql, args...) + return Literal(sql, args...) } // Alias for goqu.L func Literal(sql string, args ...interface{}) exp.LiteralExpression { - return L(sql, args...) + return exp.NewLiteralExpression(sql, args...) } // Create a new SQL value ( alias for goqu.L("?", val) ). The prrimary use case for this would be in selects. diff --git a/expressions_example_test.go b/expressions_example_test.go index 12449574..126fd6c0 100644 --- a/expressions_example_test.go +++ b/expressions_example_test.go @@ -1710,3 +1710,36 @@ func ExampleVals() { // Output: // INSERT INTO "user" ("first_name", "last_name", "is_verified") VALUES ('Greg', 'Farley', TRUE), ('Jimmy', 'Stewart', TRUE), ('Jeff', 'Jeffers', FALSE) [] } + +func ExampleW() { + ds := goqu.From("test"). + Select(goqu.ROW_NUMBER().Over(goqu.W().PartitionBy("a").OrderBy(goqu.I("b").Asc()))) + query, args, _ := ds.ToSQL() + fmt.Println(query, args) + + ds = goqu.From("test"). + Select(goqu.ROW_NUMBER().OverName(goqu.I("w"))). + Window(goqu.W("w").PartitionBy("a").OrderBy(goqu.I("b").Asc())) + query, args, _ = ds.ToSQL() + fmt.Println(query, args) + + ds = goqu.From("test"). + Select(goqu.ROW_NUMBER().OverName(goqu.I("w1"))). + Window( + goqu.W("w1").PartitionBy("a"), + goqu.W("w").Inherit("w1").OrderBy(goqu.I("b").Asc()), + ) + query, args, _ = ds.ToSQL() + fmt.Println(query, args) + + ds = goqu.From("test"). + Select(goqu.ROW_NUMBER().Over(goqu.W().Inherit("w").OrderBy("b"))). + Window(goqu.W("w").PartitionBy("a")) + query, args, _ = ds.ToSQL() + fmt.Println(query, args) + // Output + // SELECT ROW_NUMBER() OVER (PARTITION BY "a" ORDER BY "b" ASC) FROM "test" [] + // SELECT ROW_NUMBER() OVER "w" FROM "test" WINDOW "w" AS (PARTITION BY "a" ORDER BY "b" ASC) [] + // SELECT ROW_NUMBER() OVER "w" FROM "test" WINDOW "w1" AS (PARTITION BY "a"), "w" AS ("w1" ORDER BY "b" ASC) [] + // SELECT ROW_NUMBER() OVER ("w" ORDER BY "b") FROM "test" WINDOW "w" AS (PARTITION BY "a") [] +} diff --git a/expressions_test.go b/expressions_test.go new file mode 100644 index 00000000..3b837af0 --- /dev/null +++ b/expressions_test.go @@ -0,0 +1,174 @@ +package goqu + +import ( + "testing" + + "github.com/doug-martin/goqu/v8/exp" + "github.com/stretchr/testify/suite" +) + +type ( + goquExpressionsSuite struct { + suite.Suite + } +) + +func (ges *goquExpressionsSuite) TestCast() { + ges.Equal(exp.NewCastExpression(C("test"), "string"), Cast(C("test"), "string")) +} + +func (ges *goquExpressionsSuite) TestDoNothing() { + ges.Equal(exp.NewDoNothingConflictExpression(), DoNothing()) +} + +func (ges *goquExpressionsSuite) TestDoUpdate() { + ges.Equal(exp.NewDoUpdateConflictExpression("test", Record{"a": "b"}), DoUpdate("test", Record{"a": "b"})) +} + +func (ges *goquExpressionsSuite) TestOr() { + e1 := C("a").Eq("b") + e2 := C("b").Eq(2) + ges.Equal(exp.NewExpressionList(exp.OrType, e1, e2), Or(e1, e2)) +} + +func (ges *goquExpressionsSuite) TestAnd() { + e1 := C("a").Eq("b") + e2 := C("b").Eq(2) + ges.Equal(exp.NewExpressionList(exp.AndType, e1, e2), And(e1, e2)) +} + +func (ges *goquExpressionsSuite) TestFunc() { + ges.Equal(exp.NewSQLFunctionExpression("count", L("*")), Func("count", L("*"))) +} + +func (ges *goquExpressionsSuite) TestDISTINCT() { + ges.Equal(exp.NewSQLFunctionExpression("DISTINCT", I("col")), DISTINCT("col")) +} + +func (ges *goquExpressionsSuite) TestCOUNT() { + ges.Equal(exp.NewSQLFunctionExpression("COUNT", I("col")), COUNT("col")) +} + +func (ges *goquExpressionsSuite) TestMIN() { + ges.Equal(exp.NewSQLFunctionExpression("MIN", I("col")), MIN("col")) +} + +func (ges *goquExpressionsSuite) TestMAX() { + ges.Equal(exp.NewSQLFunctionExpression("MAX", I("col")), MAX("col")) +} + +func (ges *goquExpressionsSuite) TestAVG() { + ges.Equal(exp.NewSQLFunctionExpression("AVG", I("col")), AVG("col")) +} + +func (ges *goquExpressionsSuite) TestFIRST() { + ges.Equal(exp.NewSQLFunctionExpression("FIRST", I("col")), FIRST("col")) +} + +func (ges *goquExpressionsSuite) TestLAST() { + ges.Equal(exp.NewSQLFunctionExpression("LAST", I("col")), LAST("col")) +} + +func (ges *goquExpressionsSuite) TestSUM() { + ges.Equal(exp.NewSQLFunctionExpression("SUM", I("col")), SUM("col")) +} + +func (ges *goquExpressionsSuite) TestCOALESCE() { + ges.Equal(exp.NewSQLFunctionExpression("COALESCE", I("col"), nil), COALESCE(I("col"), nil)) +} + +func (ges *goquExpressionsSuite) TestROW_NUMBER() { + ges.Equal(exp.NewSQLFunctionExpression("ROW_NUMBER"), ROW_NUMBER()) +} + +func (ges *goquExpressionsSuite) TestRANK() { + ges.Equal(exp.NewSQLFunctionExpression("RANK"), RANK()) +} + +func (ges *goquExpressionsSuite) TestDENSE_RANK() { + ges.Equal(exp.NewSQLFunctionExpression("DENSE_RANK"), DENSE_RANK()) +} + +func (ges *goquExpressionsSuite) TestPERCENT_RANK() { + ges.Equal(exp.NewSQLFunctionExpression("PERCENT_RANK"), PERCENT_RANK()) +} + +func (ges *goquExpressionsSuite) TestCUME_DIST() { + ges.Equal(exp.NewSQLFunctionExpression("CUME_DIST"), CUME_DIST()) +} + +func (ges *goquExpressionsSuite) TestNTILE() { + ges.Equal(exp.NewSQLFunctionExpression("NTILE", 1), NTILE(1)) +} + +func (ges *goquExpressionsSuite) TestFIRST_VALUE() { + ges.Equal(exp.NewSQLFunctionExpression("FIRST_VALUE", I("col")), FIRST_VALUE("col")) +} + +func (ges *goquExpressionsSuite) TestLAST_VALUE() { + ges.Equal(exp.NewSQLFunctionExpression("LAST_VALUE", I("col")), LAST_VALUE("col")) +} + +func (ges *goquExpressionsSuite) TestNTH_VALUE() { + ges.Equal(exp.NewSQLFunctionExpression("NTH_VALUE", I("col"), 1), NTH_VALUE("col", 1)) + ges.Equal(exp.NewSQLFunctionExpression("NTH_VALUE", I("col"), 1), NTH_VALUE(C("col"), 1)) +} + +func (ges *goquExpressionsSuite) TestI() { + ges.Equal(exp.NewIdentifierExpression("s", "t", "c"), I("s.t.c")) +} + +func (ges *goquExpressionsSuite) TestC() { + ges.Equal(exp.NewIdentifierExpression("", "", "c"), C("c")) +} + +func (ges *goquExpressionsSuite) TestS() { + ges.Equal(exp.NewIdentifierExpression("s", "", ""), S("s")) +} + +func (ges *goquExpressionsSuite) TestT() { + ges.Equal(exp.NewIdentifierExpression("", "t", ""), T("t")) +} + +func (ges *goquExpressionsSuite) TestW() { + ges.Equal(emptyWindow, W()) + ges.Equal(exp.NewWindowExpression(I("a"), nil, nil, nil), W("a")) + ges.Equal(exp.NewWindowExpression(I("a"), I("b"), nil, nil), W("a", "b")) + ges.Equal(exp.NewWindowExpression(I("a"), I("b"), nil, nil), W("a", "b", "c")) +} + +func (ges *goquExpressionsSuite) TestOn() { + ges.Equal(exp.NewJoinOnCondition(Ex{"a": "b"}), On(Ex{"a": "b"})) +} + +func (ges *goquExpressionsSuite) TestUsing() { + ges.Equal(exp.NewJoinUsingCondition("a", "b"), Using("a", "b")) +} + +func (ges *goquExpressionsSuite) TestL() { + ges.Equal(exp.NewLiteralExpression("? + ?", 1, 2), L("? + ?", 1, 2)) +} + +func (ges *goquExpressionsSuite) TestLiteral() { + ges.Equal(exp.NewLiteralExpression("? + ?", 1, 2), Literal("? + ?", 1, 2)) +} + +func (ges *goquExpressionsSuite) TestV() { + ges.Equal(exp.NewLiteralExpression("?", "a"), V("a")) +} + +func (ges *goquExpressionsSuite) TestRange() { + ges.Equal(exp.NewRangeVal("a", "b"), Range("a", "b")) +} + +func (ges *goquExpressionsSuite) TestStar() { + ges.Equal(exp.NewLiteralExpression("*"), Star()) +} + +func (ges *goquExpressionsSuite) TestDefault() { + ges.Equal(exp.Default(), Default()) +} + +func TestGoquExpressions(t *testing.T) { + suite.Run(t, new(goquExpressionsSuite)) +} diff --git a/select_dataset.go b/select_dataset.go index 2852af42..90167bc6 100644 --- a/select_dataset.go +++ b/select_dataset.go @@ -496,6 +496,21 @@ func (sd *SelectDataset) As(alias string) *SelectDataset { return sd.copy(sd.clauses.SetAlias(T(alias))) } +// Sets the WINDOW clauses +func (sd *SelectDataset) Window(ws ...exp.WindowExpression) *SelectDataset { + return sd.copy(sd.clauses.SetWindows(ws)) +} + +// Sets the WINDOW clauses +func (sd *SelectDataset) WindowAppend(ws ...exp.WindowExpression) *SelectDataset { + return sd.copy(sd.clauses.WindowsAppend(ws...)) +} + +// Sets the WINDOW clauses +func (sd *SelectDataset) ClearWindow() *SelectDataset { + return sd.copy(sd.clauses.ClearWindows()) +} + // Generates a SELECT sql statement, if Prepared has been called with true then the parameters will not be interpolated. // See examples. // diff --git a/select_dataset_example_test.go b/select_dataset_example_test.go index 39e891ca..08f1e39d 100644 --- a/select_dataset_example_test.go +++ b/select_dataset_example_test.go @@ -329,6 +329,39 @@ func ExampleSelectDataset_Having() { // SELECT * FROM "test" GROUP BY "age" HAVING (SUM("income") > 1000) } +func ExampleSelectDataset_Window() { + ds := goqu.From("test"). + Select(goqu.ROW_NUMBER().Over(goqu.W().PartitionBy("a").OrderBy(goqu.I("b").Asc()))) + query, args, _ := ds.ToSQL() + fmt.Println(query, args) + + ds = goqu.From("test"). + Select(goqu.ROW_NUMBER().OverName(goqu.I("w"))). + Window(goqu.W("w").PartitionBy("a").OrderBy(goqu.I("b").Asc())) + query, args, _ = ds.ToSQL() + fmt.Println(query, args) + + ds = goqu.From("test"). + Select(goqu.ROW_NUMBER().OverName(goqu.I("w1"))). + Window( + goqu.W("w1").PartitionBy("a"), + goqu.W("w").Inherit("w1").OrderBy(goqu.I("b").Asc()), + ) + query, args, _ = ds.ToSQL() + fmt.Println(query, args) + + ds = goqu.From("test"). + Select(goqu.ROW_NUMBER().Over(goqu.W().Inherit("w").OrderBy("b"))). + Window(goqu.W("w").PartitionBy("a")) + query, args, _ = ds.ToSQL() + fmt.Println(query, args) + // Output + // SELECT ROW_NUMBER() OVER (PARTITION BY "a" ORDER BY "b" ASC) FROM "test" [] + // SELECT ROW_NUMBER() OVER "w" FROM "test" WINDOW "w" AS (PARTITION BY "a" ORDER BY "b" ASC) [] + // SELECT ROW_NUMBER() OVER "w" FROM "test" WINDOW "w1" AS (PARTITION BY "a"), "w" AS ("w1" ORDER BY "b" ASC) [] + // SELECT ROW_NUMBER() OVER ("w" ORDER BY "b") FROM "test" WINDOW "w" AS (PARTITION BY "a") [] +} + func ExampleSelectDataset_Where() { // By default everything is anded together sql, _, _ := goqu.From("test").Where(goqu.Ex{ diff --git a/select_dataset_test.go b/select_dataset_test.go index f6ae286b..a33cc414 100644 --- a/select_dataset_test.go +++ b/select_dataset_test.go @@ -735,14 +735,77 @@ func (sds *selectDatasetSuite) TestGroupBy() { ) } -func (sds *selectDatasetSuite) TestHaving() { - ds := From("test") - dsc := ds.GetClauses() - h := C("a").Gt(1) - ec := dsc.HavingAppend(h) - sds.Equal(ec, ds.Having(h).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) +func (sds *selectDatasetSuite) TestWindow() { + w1 := W("w1").PartitionBy("a").OrderBy("b") + w2 := W("w2").PartitionBy("a").OrderBy("b") + + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.Window(w1), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + WindowsAppend(w1), + }, + selectTestCase{ + ds: bd.Window(w1).Window(w2), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + WindowsAppend(w2), + }, + selectTestCase{ + ds: bd.Window(w1, w2), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + WindowsAppend(w1, w2), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, + ) +} +func (sds *selectDatasetSuite) TestWindowAppend() { + w1 := W("w1").PartitionBy("a").OrderBy("b") + w2 := W("w2").PartitionBy("a").OrderBy("b") + + bd := From("test").Window(w1) + sds.assertCases( + selectTestCase{ + ds: bd.WindowAppend(w2), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + WindowsAppend(w1, w2), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + WindowsAppend(w1), + }, + ) +} + +func (sds *selectDatasetSuite) TestClearWindow() { + w1 := W("w1").PartitionBy("a").OrderBy("b") + + bd := From("test").Window(w1) + sds.assertCases( + selectTestCase{ + ds: bd.ClearWindow(), + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + WindowsAppend(w1), + }, + ) +} + +func (sds *selectDatasetSuite) TestHaving() { bd := From("test") sds.assertCases( selectTestCase{ diff --git a/sqlgen/expression_sql_generator.go b/sqlgen/expression_sql_generator.go index f7afd2b6..275df752 100644 --- a/sqlgen/expression_sql_generator.go +++ b/sqlgen/expression_sql_generator.go @@ -34,7 +34,8 @@ var ( TrueLiteral = exp.NewLiteralExpression("TRUE") FalseLiteral = exp.NewLiteralExpression("FALSE") - errEmptyIdentifier = errors.New(`a empty identifier was encountered, please specify a "schema", "table" or "column"`) + errEmptyIdentifier = errors.New(`a empty identifier was encountered, please specify a "schema", "table" or "column"`) + errUnexpectedNamedWindow = errors.New(`unexpected named window function`) ) func errUnsupportedExpressionType(e exp.Expression) error { @@ -158,6 +159,10 @@ func (esg *expressionSQLGenerator) expressionSQL(b sb.SQLBuilder, expression exp esg.updateExpressionSQL(b, e) case exp.SQLFunctionExpression: esg.sqlFunctionExpressionSQL(b, e) + case exp.SQLWindowFunctionExpression: + esg.sqlWindowFunctionExpression(b, e) + case exp.WindowExpression: + esg.windowExpressionSQL(b, e) case exp.CastExpression: esg.castExpressionSQL(b, e) case exp.AppendableExpression: @@ -481,6 +486,63 @@ func (esg *expressionSQLGenerator) sqlFunctionExpressionSQL(b sb.SQLBuilder, sql esg.Generate(b, sqlFunc.Args()) } +func (esg *expressionSQLGenerator) sqlWindowFunctionExpression(b sb.SQLBuilder, sqlWinFunc exp.SQLWindowFunctionExpression) { + if !esg.dialectOptions.SupportsWindowFunction { + b.SetError(errWindowNotSupported(esg.dialect)) + return + } + esg.Generate(b, sqlWinFunc.Func()) + b.Write(esg.dialectOptions.WindowOverFragment) + switch { + case sqlWinFunc.HasWindowName(): + esg.Generate(b, sqlWinFunc.WindowName()) + case sqlWinFunc.HasWindow(): + if sqlWinFunc.Window().HasName() { + b.SetError(errUnexpectedNamedWindow) + return + } + esg.Generate(b, sqlWinFunc.Window()) + default: + esg.Generate(b, exp.NewWindowExpression(nil, nil, nil, nil)) + } +} + +func (esg *expressionSQLGenerator) windowExpressionSQL(b sb.SQLBuilder, we exp.WindowExpression) { + if !esg.dialectOptions.SupportsWindowFunction { + b.SetError(errWindowNotSupported(esg.dialect)) + return + } + if we.HasName() { + esg.Generate(b, we.Name()) + b.Write(esg.dialectOptions.AsFragment) + } + b.WriteRunes(esg.dialectOptions.LeftParenRune) + + hasPartition := we.HasPartitionBy() + hasOrder := we.HasOrder() + + if we.HasParent() { + esg.Generate(b, we.Parent()) + if hasPartition || hasOrder { + b.WriteRunes(esg.dialectOptions.SpaceRune) + } + } + + if hasPartition { + b.Write(esg.dialectOptions.WindowPartitionByFragment) + esg.Generate(b, we.PartitionCols()) + if hasOrder { + b.WriteRunes(esg.dialectOptions.SpaceRune) + } + } + if hasOrder { + b.Write(esg.dialectOptions.WindowOrderByFragment) + esg.Generate(b, we.OrderCols()) + } + + b.WriteRunes(esg.dialectOptions.RightParenRune) +} + // Generates SQL for a CastExpression // I("a").Cast("NUMERIC") -> CAST("a" AS NUMERIC) func (esg *expressionSQLGenerator) castExpressionSQL(b sb.SQLBuilder, cast exp.CastExpression) { diff --git a/sqlgen/expression_sql_generator_test.go b/sqlgen/expression_sql_generator_test.go index 5b28e07c..30129da1 100644 --- a/sqlgen/expression_sql_generator_test.go +++ b/sqlgen/expression_sql_generator_test.go @@ -706,6 +706,130 @@ func (esgs *expressionSQLGeneratorSuite) TestGenerate_SQLFunctionExpression() { ) } +func (esgs *expressionSQLGeneratorSuite) TestGenerate_SQLWindowFunctionExpression() { + sqlWinFunc := exp.NewSQLWindowFunctionExpression( + exp.NewSQLFunctionExpression("some_func"), + nil, + exp.NewWindowExpression( + nil, + exp.NewIdentifierExpression("", "", "win"), + nil, + nil, + ), + ) + sqlWinFuncFromWindow := exp.NewSQLWindowFunctionExpression( + exp.NewSQLFunctionExpression("some_func"), + exp.NewIdentifierExpression("", "", "win"), + nil, + ) + + emptyWinFunc := exp.NewSQLWindowFunctionExpression( + exp.NewSQLFunctionExpression("some_func"), + nil, + nil, + ) + badNamedSQLWinFuncInherit := exp.NewSQLWindowFunctionExpression( + exp.NewSQLFunctionExpression("some_func"), + nil, + exp.NewWindowExpression( + exp.NewIdentifierExpression("", "", "w"), + nil, + nil, + nil, + ), + ) + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: sqlWinFunc, sql: `some_func() OVER ("win")`}, + expressionTestCase{val: sqlWinFunc, sql: `some_func() OVER ("win")`, isPrepared: true}, + + expressionTestCase{val: sqlWinFuncFromWindow, sql: `some_func() OVER "win"`}, + expressionTestCase{val: sqlWinFuncFromWindow, sql: `some_func() OVER "win"`, isPrepared: true}, + + expressionTestCase{val: emptyWinFunc, sql: `some_func() OVER ()`}, + expressionTestCase{val: emptyWinFunc, sql: `some_func() OVER ()`, isPrepared: true}, + + expressionTestCase{val: badNamedSQLWinFuncInherit, err: errUnexpectedNamedWindow.Error()}, + expressionTestCase{val: badNamedSQLWinFuncInherit, err: errUnexpectedNamedWindow.Error(), isPrepared: true}, + ) + opts := DefaultDialectOptions() + opts.SupportsWindowFunction = false + esgs.assertCases( + NewExpressionSQLGenerator("test", opts), + expressionTestCase{val: sqlWinFunc, err: errWindowNotSupported("test").Error()}, + expressionTestCase{val: sqlWinFunc, err: errWindowNotSupported("test").Error(), isPrepared: true}, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_WindowExpression() { + opts := DefaultDialectOptions() + opts.WindowPartitionByFragment = []byte("partition by ") + opts.WindowOrderByFragment = []byte("order by ") + + emptySQLWinFunc := exp.NewWindowExpression(nil, nil, nil, nil) + namedSQLWinFunc := exp.NewWindowExpression( + exp.NewIdentifierExpression("", "", "w"), nil, nil, nil, + ) + inheritSQLWinFunc := exp.NewWindowExpression( + nil, exp.NewIdentifierExpression("", "", "w"), nil, nil, + ) + partitionBySQLWinFunc := exp.NewWindowExpression( + nil, nil, exp.NewColumnListExpression("a", "b"), nil, + ) + orderBySQLWinFunc := exp.NewWindowExpression( + nil, nil, nil, exp.NewOrderedColumnList( + exp.NewIdentifierExpression("", "", "a").Asc(), + exp.NewIdentifierExpression("", "", "b").Desc(), + ), + ) + + namedInheritPartitionOrderSQLWinFunc := exp.NewWindowExpression( + exp.NewIdentifierExpression("", "", "w1"), + exp.NewIdentifierExpression("", "", "w2"), + exp.NewColumnListExpression("a", "b"), + exp.NewOrderedColumnList( + exp.NewIdentifierExpression("", "", "a").Asc(), + exp.NewIdentifierExpression("", "", "b").Desc(), + ), + ) + + esgs.assertCases( + NewExpressionSQLGenerator("test", opts), + expressionTestCase{val: emptySQLWinFunc, sql: `()`}, + expressionTestCase{val: emptySQLWinFunc, sql: `()`, isPrepared: true}, + + expressionTestCase{val: namedSQLWinFunc, sql: `"w" AS ()`}, + expressionTestCase{val: namedSQLWinFunc, sql: `"w" AS ()`, isPrepared: true}, + + expressionTestCase{val: inheritSQLWinFunc, sql: `("w")`}, + expressionTestCase{val: inheritSQLWinFunc, sql: `("w")`, isPrepared: true}, + + expressionTestCase{val: partitionBySQLWinFunc, sql: `(partition by "a", "b")`}, + expressionTestCase{val: partitionBySQLWinFunc, sql: `(partition by "a", "b")`, isPrepared: true}, + + expressionTestCase{val: orderBySQLWinFunc, sql: `(order by "a" ASC, "b" DESC)`}, + expressionTestCase{val: orderBySQLWinFunc, sql: `(order by "a" ASC, "b" DESC)`, isPrepared: true}, + + expressionTestCase{ + val: namedInheritPartitionOrderSQLWinFunc, + sql: `"w1" AS ("w2" partition by "a", "b" order by "a" ASC, "b" DESC)`, + }, + expressionTestCase{ + val: namedInheritPartitionOrderSQLWinFunc, + sql: `"w1" AS ("w2" partition by "a", "b" order by "a" ASC, "b" DESC)`, + isPrepared: true, + }, + ) + + opts = DefaultDialectOptions() + opts.SupportsWindowFunction = false + esgs.assertCases( + NewExpressionSQLGenerator("test", opts), + expressionTestCase{val: emptySQLWinFunc, err: errWindowNotSupported("test").Error()}, + expressionTestCase{val: emptySQLWinFunc, err: errWindowNotSupported("test").Error(), isPrepared: true}, + ) +} + func (esgs *expressionSQLGeneratorSuite) TestGenerate_CastExpression() { cast := exp.NewIdentifierExpression("", "", "a").Cast("DATE") esgs.assertCases( diff --git a/sqlgen/select_sql_generator.go b/sqlgen/select_sql_generator.go index b9d358de..2ae28c2d 100644 --- a/sqlgen/select_sql_generator.go +++ b/sqlgen/select_sql_generator.go @@ -32,6 +32,12 @@ func errDistinctOnNotSupported(dialect string) error { return errors.New("dialect does not support DISTINCT ON clause [dialect=%s]", dialect) } +func errWindowNotSupported(dialect string) error { + return errors.New("dialect does not support WINDOW clause [dialect=%s]", dialect) +} + +var errNoWindowName = errors.New("window expresion has no valid name") + func NewSelectSQLGenerator(dialect string, do *SQLDialectOptions) SelectSQLGenerator { return &selectSQLGenerator{newCommonSQLGenerator(dialect, do)} } @@ -60,6 +66,8 @@ func (ssg *selectSQLGenerator) Generate(b sb.SQLBuilder, clauses exp.SelectClaus ssg.GroupBySQL(b, clauses.GroupBy()) case HavingSQLFragment: ssg.HavingSQL(b, clauses.Having()) + case WindowSQLFragment: + ssg.WindowSQL(b, clauses.Windows()) case CompoundsSQLFragment: ssg.CompoundsSQL(b, clauses.Compounds()) case OrderSQLFragment: @@ -184,6 +192,27 @@ func (ssg *selectSQLGenerator) ForSQL(b sb.SQLBuilder, lockingClause exp.Lock) { } } +func (ssg *selectSQLGenerator) WindowSQL(b sb.SQLBuilder, windows []exp.WindowExpression) { + weLen := len(windows) + if weLen == 0 { + return + } + if !ssg.dialectOptions.SupportsWindowFunction { + b.SetError(errWindowNotSupported(ssg.dialect)) + return + } + b.Write(ssg.dialectOptions.WindowFragment) + for i, we := range windows { + if !we.HasName() { + b.SetError(errNoWindowName) + } + ssg.esg.Generate(b, we) + if i < weLen-1 { + b.WriteRunes(ssg.dialectOptions.CommaRune, ssg.dialectOptions.SpaceRune) + } + } +} + func (ssg *selectSQLGenerator) joinConditionSQL(b sb.SQLBuilder, jc exp.JoinCondition) { switch t := jc.(type) { case exp.JoinOnCondition: diff --git a/sqlgen/select_sql_generator_test.go b/sqlgen/select_sql_generator_test.go index 838f69e4..12195822 100644 --- a/sqlgen/select_sql_generator_test.go +++ b/sqlgen/select_sql_generator_test.go @@ -89,6 +89,54 @@ func (ssgs *selectSQLGeneratorSuite) TestGenerate_WithErroredBuilder() { ssgs.assertErrorSQL(b, `goqu: test error`) } +func (ssgs *selectSQLGeneratorSuite) TestGenerate_withSelectedColumns() { + opts := DefaultDialectOptions() + // make sure the fragments are used + opts.SelectClause = []byte("select") + opts.StarRune = '#' + opts.SupportsDistinctOn = true + + sc := exp.NewSelectClauses() + scCols := sc.SetSelect(exp.NewColumnListExpression("a", "b")) + scFuncs := sc.SetSelect(exp.NewColumnListExpression( + exp.NewSQLFunctionExpression("COUNT", exp.Star()), + exp.NewSQLFunctionExpression("RANK"), + )) + + we := exp.NewWindowExpression( + nil, + nil, + exp.NewColumnListExpression("a", "b"), + exp.NewOrderedColumnList(exp.ParseIdentifier("c").Asc()), + ) + scFuncsPartition := sc.SetSelect(exp.NewColumnListExpression( + exp.NewSQLFunctionExpression("COUNT", exp.Star()).Over(we), + exp.NewSQLFunctionExpression("RANK").Over(we.Inherit("w")), + )) + + ssgs.assertCases( + NewSelectSQLGenerator("test", opts), + selectTestCase{clause: sc, sql: `select #`}, + selectTestCase{clause: sc, sql: `select #`, isPrepared: true}, + + selectTestCase{clause: scCols, sql: `select "a", "b"`}, + selectTestCase{clause: scCols, sql: `select "a", "b"`, isPrepared: true}, + + selectTestCase{clause: scFuncs, sql: `select COUNT(*), RANK()`}, + selectTestCase{clause: scFuncs, sql: `select COUNT(*), RANK()`, isPrepared: true}, + + selectTestCase{ + clause: scFuncsPartition, + sql: `select COUNT(*) OVER (PARTITION BY "a", "b" ORDER BY "c" ASC), RANK() OVER ("w" PARTITION BY "a", "b" ORDER BY "c" ASC)`, + }, + selectTestCase{ + clause: scFuncsPartition, + sql: `select COUNT(*) OVER (PARTITION BY "a", "b" ORDER BY "c" ASC), RANK() OVER ("w" PARTITION BY "a", "b" ORDER BY "c" ASC)`, + isPrepared: true, + }, + ) +} + func (ssgs *selectSQLGeneratorSuite) TestGenerate_withDistinct() { opts := DefaultDialectOptions() // make sure the fragments are used @@ -263,6 +311,131 @@ func (ssgs *selectSQLGeneratorSuite) TestGenerate_withHaving() { ) } +func (ssgs *selectSQLGeneratorSuite) TestGenerate_withWindow() { + opts := DefaultDialectOptions() + opts.WindowFragment = []byte(" window ") + opts.WindowPartitionByFragment = []byte("partition by ") + opts.WindowOrderByFragment = []byte("order by ") + + sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")) + we1 := exp.NewWindowExpression( + exp.NewIdentifierExpression("", "", "w"), + nil, + nil, + nil, + ) + wePartitionBy := we1.PartitionBy("a", "b") + weOrderBy := we1.OrderBy("a", "b") + + weOrderAndPartitionBy := we1.PartitionBy("a", "b").OrderBy("a", "b") + + weInherits := exp.NewWindowExpression( + exp.NewIdentifierExpression("", "", "w2"), + exp.NewIdentifierExpression("", "", "w"), + nil, + nil, + ) + weInheritsPartitionBy := weInherits.PartitionBy("c", "d") + weInheritsOrderBy := weInherits.OrderBy("c", "d") + + weInheritsOrderAndPartitionBy := weInherits.PartitionBy("c", "d").OrderBy("c", "d") + + scNoName := sc.WindowsAppend(exp.NewWindowExpression(nil, nil, nil, nil)) + + scWindow1 := sc.WindowsAppend(we1) + scWindow2 := sc.WindowsAppend(wePartitionBy) + scWindow3 := sc.WindowsAppend(weOrderBy) + scWindow4 := sc.WindowsAppend(weOrderAndPartitionBy) + + scWindow5 := sc.WindowsAppend(we1, weInherits) + scWindow6 := sc.WindowsAppend(we1, weInheritsPartitionBy) + scWindow7 := sc.WindowsAppend(we1, weInheritsOrderBy) + scWindow8 := sc.WindowsAppend(we1, weInheritsOrderAndPartitionBy) + + ssgs.assertCases( + NewSelectSQLGenerator("test", opts), + + selectTestCase{clause: scNoName, err: errNoWindowName.Error()}, + selectTestCase{clause: scNoName, err: errNoWindowName.Error(), isPrepared: true}, + + selectTestCase{clause: scWindow1, sql: `SELECT * FROM "test" window "w" AS ()`}, + selectTestCase{clause: scWindow1, sql: `SELECT * FROM "test" window "w" AS ()`, isPrepared: true}, + + selectTestCase{clause: scWindow2, sql: `SELECT * FROM "test" window "w" AS (partition by "a", "b")`}, + selectTestCase{ + clause: scWindow2, + sql: `SELECT * FROM "test" window "w" AS (partition by "a", "b")`, + isPrepared: true, + }, + + selectTestCase{clause: scWindow3, sql: `SELECT * FROM "test" window "w" AS (order by "a", "b")`}, + selectTestCase{ + clause: scWindow3, + sql: `SELECT * FROM "test" window "w" AS (order by "a", "b")`, + isPrepared: true, + }, + + selectTestCase{ + clause: scWindow4, + sql: `SELECT * FROM "test" window "w" AS (partition by "a", "b" order by "a", "b")`, + }, + selectTestCase{ + clause: scWindow4, + sql: `SELECT * FROM "test" window "w" AS (partition by "a", "b" order by "a", "b")`, + isPrepared: true, + }, + + selectTestCase{ + clause: scWindow5, + sql: `SELECT * FROM "test" window "w" AS (), "w2" AS ("w")`, + }, + selectTestCase{ + clause: scWindow5, + sql: `SELECT * FROM "test" window "w" AS (), "w2" AS ("w")`, + isPrepared: true, + }, + + selectTestCase{ + clause: scWindow6, + sql: `SELECT * FROM "test" window "w" AS (), "w2" AS ("w" partition by "c", "d")`, + }, + selectTestCase{ + clause: scWindow6, + sql: `SELECT * FROM "test" window "w" AS (), "w2" AS ("w" partition by "c", "d")`, + isPrepared: true, + }, + + selectTestCase{ + clause: scWindow7, + sql: `SELECT * FROM "test" window "w" AS (), "w2" AS ("w" order by "c", "d")`, + }, + selectTestCase{ + clause: scWindow7, + sql: `SELECT * FROM "test" window "w" AS (), "w2" AS ("w" order by "c", "d")`, + isPrepared: true, + }, + + selectTestCase{ + clause: scWindow8, + sql: `SELECT * FROM "test" window "w" AS (), "w2" AS ("w" partition by "c", "d" order by "c", "d")`, + }, + selectTestCase{ + clause: scWindow8, + sql: `SELECT * FROM "test" window "w" AS (), "w2" AS ("w" partition by "c", "d" order by "c", "d")`, + isPrepared: true, + }, + ) + + opts = DefaultDialectOptions() + opts.SupportsWindowFunction = false + ssgs.assertCases( + NewSelectSQLGenerator("test", opts), + + selectTestCase{clause: scWindow1, err: errWindowNotSupported("test").Error()}, + selectTestCase{clause: scWindow1, err: errWindowNotSupported("test").Error(), isPrepared: true}, + ) +} + func (ssgs *selectSQLGeneratorSuite) TestGenerate_withOrder() { sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")). SetOrder( diff --git a/sqlgen/sql_dialect_options.go b/sqlgen/sql_dialect_options.go index 174adf8f..b9bc0f20 100644 --- a/sqlgen/sql_dialect_options.go +++ b/sqlgen/sql_dialect_options.go @@ -37,6 +37,9 @@ type ( // Set to false if the dialect does not require expressions to be wrapped in parens (DEFAULT=true) WrapCompoundsInParens bool + // Set to true if window function are supported in SELECT statement. (DEFAULT=true) + SupportsWindowFunction bool + // Set to true if the dialect requires join tables in UPDATE to be in a FROM clause (DEFAULT=true). UseFromClauseForMultipleUpdateTables bool @@ -85,8 +88,16 @@ type ( WhereFragment []byte // The SQL GROUP BY clause fragment(DEFAULT=[]byte(" GROUP BY ")) GroupByFragment []byte - // The SQL HAVING clause fragment(DELiFAULT=[]byte(" HAVING ")) + // The SQL HAVING clause fragment(DEFAULT=[]byte(" HAVING ")) HavingFragment []byte + // The SQL WINDOW clause fragment(DEFAULT=[]byte(" WINDOW ")) + WindowFragment []byte + // The SQL WINDOW clause PARTITION BY fragment(DEFAULT=[]byte("PARTITION BY ")) + WindowPartitionByFragment []byte + // The SQL WINDOW clause ORDER BY fragment(DEFAULT=[]byte("ORDER BY ")) + WindowOrderByFragment []byte + // The SQL WINDOW clause OVER fragment(DEFAULT=[]byte(" OVER ")) + WindowOverFragment []byte // The SQL ORDER BY clause fragment(DEFAULT=[]byte(" ORDER BY ")) OrderByFragment []byte // The SQL LIMIT BY clause fragment(DEFAULT=[]byte(" LIMIT ")) @@ -304,6 +315,7 @@ const ( InsertSQLFragment DeleteBeginSQLFragment TruncateSQLFragment + WindowSQLFragment ) // nolint:gocyclo @@ -351,6 +363,8 @@ func (sf SQLFragmentType) String() string { return "DeleteBeginSQLFragment" case TruncateSQLFragment: return "TruncateSQLFragment" + case WindowSQLFragment: + return "WindowSQLFragment" } return fmt.Sprintf("%d", sf) } @@ -369,6 +383,7 @@ func DefaultDialectOptions() *SQLDialectOptions { SupportsWithCTERecursive: true, SupportsDistinctOn: true, WrapCompoundsInParens: true, + SupportsWindowFunction: true, SupportsMultipleUpdateTables: true, UseFromClauseForMultipleUpdateTables: true, @@ -395,6 +410,10 @@ func DefaultDialectOptions() *SQLDialectOptions { WhereFragment: []byte(" WHERE "), GroupByFragment: []byte(" GROUP BY "), HavingFragment: []byte(" HAVING "), + WindowFragment: []byte(" WINDOW "), + WindowPartitionByFragment: []byte("PARTITION BY "), + WindowOrderByFragment: []byte("ORDER BY "), + WindowOverFragment: []byte(" OVER "), OrderByFragment: []byte(" ORDER BY "), LimitFragment: []byte(" LIMIT "), OffsetFragment: []byte(" OFFSET "), @@ -489,6 +508,7 @@ func DefaultDialectOptions() *SQLDialectOptions { WhereSQLFragment, GroupBySQLFragment, HavingSQLFragment, + WindowSQLFragment, CompoundsSQLFragment, OrderSQLFragment, LimitSQLFragment, diff --git a/sqlgen/sql_dialect_options_test.go b/sqlgen/sql_dialect_options_test.go new file mode 100644 index 00000000..9a4f0f62 --- /dev/null +++ b/sqlgen/sql_dialect_options_test.go @@ -0,0 +1,48 @@ +package sqlgen + +import ( + "testing" + + "github.com/stretchr/testify/suite" +) + +type sqlFragmentTypeSuite struct { + suite.Suite +} + +func (sfts *sqlFragmentTypeSuite) TestOptions_SQLFragmentType() { + for _, tt := range []struct { + typ SQLFragmentType + expectedStr string + }{ + {typ: CommonTableSQLFragment, expectedStr: "CommonTableSQLFragment"}, + {typ: SelectSQLFragment, expectedStr: "SelectSQLFragment"}, + {typ: FromSQLFragment, expectedStr: "FromSQLFragment"}, + {typ: JoinSQLFragment, expectedStr: "JoinSQLFragment"}, + {typ: WhereSQLFragment, expectedStr: "WhereSQLFragment"}, + {typ: GroupBySQLFragment, expectedStr: "GroupBySQLFragment"}, + {typ: HavingSQLFragment, expectedStr: "HavingSQLFragment"}, + {typ: CompoundsSQLFragment, expectedStr: "CompoundsSQLFragment"}, + {typ: OrderSQLFragment, expectedStr: "OrderSQLFragment"}, + {typ: LimitSQLFragment, expectedStr: "LimitSQLFragment"}, + {typ: OffsetSQLFragment, expectedStr: "OffsetSQLFragment"}, + {typ: ForSQLFragment, expectedStr: "ForSQLFragment"}, + {typ: UpdateBeginSQLFragment, expectedStr: "UpdateBeginSQLFragment"}, + {typ: SourcesSQLFragment, expectedStr: "SourcesSQLFragment"}, + {typ: IntoSQLFragment, expectedStr: "IntoSQLFragment"}, + {typ: UpdateSQLFragment, expectedStr: "UpdateSQLFragment"}, + {typ: UpdateFromSQLFragment, expectedStr: "UpdateFromSQLFragment"}, + {typ: ReturningSQLFragment, expectedStr: "ReturningSQLFragment"}, + {typ: InsertBeingSQLFragment, expectedStr: "InsertBeingSQLFragment"}, + {typ: DeleteBeginSQLFragment, expectedStr: "DeleteBeginSQLFragment"}, + {typ: TruncateSQLFragment, expectedStr: "TruncateSQLFragment"}, + {typ: WindowSQLFragment, expectedStr: "WindowSQLFragment"}, + {typ: SQLFragmentType(10000), expectedStr: "10000"}, + } { + sfts.Equal(tt.expectedStr, tt.typ.String()) + } +} + +func TestSQLFragmentType(t *testing.T) { + suite.Run(t, new(sqlFragmentTypeSuite)) +}