Skip to content

Commit

Permalink
v8.1.0
Browse files Browse the repository at this point in the history
* [ADDED] Multi table update support for `mysql` and `postgres` #60
  • Loading branch information
doug-martin committed Jul 25, 2019
1 parent f38b0e2 commit 395454a
Show file tree
Hide file tree
Showing 15 changed files with 702 additions and 483 deletions.
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## v8.0.1

* [ADDED] Multi table update support for `mysql` and `postgres` [#60](https://github.com/doug-martin/goqu/issues/60)

## v8.0.0

A major change the the API was made in `v8` to seperate concerns between the different SQL statement types.
Expand Down
2 changes: 2 additions & 0 deletions dialect/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ func DialectOptions() *goqu.SQLDialectOptions {
opts.SupportsWithCTE = false
opts.SupportsWithCTERecursive = false

opts.UseFromClauseForMultipleUpdateTables = false

opts.PlaceHolderRune = '?'
opts.IncludePlaceholderNum = false
opts.QuoteRune = '`'
Expand Down
10 changes: 10 additions & 0 deletions dialect/mysql/mysql_dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,17 @@ func (mds *mysqlDialectSuite) TestBooleanOperations() {
sql, _, err = ds.Where(col.NotILike(regexp.MustCompile("(a|b)"))).ToSQL()
assert.NoError(t, err)
assert.Equal(t, sql, "SELECT * FROM `test` WHERE (`a` NOT REGEXP '(a|b)')")
}

func (mds *mysqlDialectSuite) TestUpdateSQL() {
ds := mds.GetDs("test").Update()
sql, _, err := ds.
Set(goqu.Record{"foo": "bar"}).
From("test_2").
Where(goqu.I("test.id").Eq(goqu.I("test_2.test_id"))).
ToSQL()
mds.NoError(err)
mds.Equal("UPDATE `test`,`test_2` SET `foo`='bar' WHERE (`test`.`id` = `test_2`.`test_id`)", sql)
}

func TestDatasetAdapterSuite(t *testing.T) {
Expand Down
11 changes: 11 additions & 0 deletions dialect/postgres/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,17 @@ func (pt *postgresTest) TestUpdate() {
assert.Equal(t, id, e.ID)
}

func (pt *postgresTest) TestUpdateSQL_multipleTables() {
ds := pt.db.Update("test")
updateSQL, _, err := ds.
Set(goqu.Record{"foo": "bar"}).
From("test_2").
Where(goqu.I("test.id").Eq(goqu.I("test_2.test_id"))).
ToSQL()
pt.NoError(err)
pt.Equal(`UPDATE "test" SET "foo"='bar' FROM "test_2" WHERE ("test"."id" = "test_2"."test_id")`, updateSQL)
}

func (pt *postgresTest) TestDelete() {
t := pt.T()
ds := pt.db.From("entry")
Expand Down
1 change: 1 addition & 0 deletions dialect/sqlite3/sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ func DialectOptions() *goqu.SQLDialectOptions {
opts.SupportsConflictUpdateWhere = false
opts.SupportsInsertIgnoreSyntax = true
opts.SupportsConflictTarget = false
opts.SupportsMultipleUpdateTables = false
opts.WrapCompoundsInParens = false

opts.PlaceHolderRune = '?'
Expand Down
10 changes: 10 additions & 0 deletions dialect/sqlite3/sqlite3_dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ func (sds *sqlite3DialectSuite) TestIdentifiers() {
assert.Equal(t, sql, "SELECT `a`, `a`.`b`.`c`, `c`.`d`, `test` AS `test` FROM `test`")
}

func (sds *sqlite3DialectSuite) TestUpdateSQL_multipleTables() {
ds := sds.GetDs("test").Update()
_, _, err := ds.
Set(goqu.Record{"foo": "bar"}).
From("test_2").
Where(goqu.I("test.id").Eq(goqu.I("test_2.test_id"))).
ToSQL()
sds.EqualError(err, "goqu: sqlite3 dialect does not support multiple tables in UPDATE")
}

func (sds *sqlite3DialectSuite) TestCompoundExpressions() {
t := sds.T()
ds1 := sds.GetDs("test").Select("a")
Expand Down
45 changes: 45 additions & 0 deletions docs/updating.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* [Set with `goqu.Record`](#set-record)
* [Set with struct](#set-struct)
* [Set with map](#set-map)
* [Multi Table](#from)
* [Where](#where)
* [Order](#order)
* [Limit](#limit)
Expand Down Expand Up @@ -167,6 +168,50 @@ Output:
UPDATE "items" SET "address"='111 Test Addr',"name"='Test' []
```

<a name="from"></a>
**[From / Multi Table](https://godoc.org/github.com/doug-martin/goqu/#UpdateDataset.From)**

`goqu` allows joining multiple tables in a update clause through `From`.

**NOTE** The `sqlite3` adapter does not support a multi table syntax.

`Postgres` Example

```go
dialect := goqu.Dialect("postgres")

ds := dialect.Update("table_one").
Set(goqu.Record{"foo": goqu.I("table_two.bar")}).
From("table_two").
Where(goqu.Ex{"table_one.id": goqu.I("table_two.id")})

sql, _, _ := ds.ToSQL()
fmt.Println(sql)
```

Output:
```sql
UPDATE "table_one" SET "foo"="table_two"."bar" FROM "table_two" WHERE ("table_one"."id" = "table_two"."id")
```

`MySQL` Example

```go
dialect := goqu.Dialect("mysql")

ds := dialect.Update("table_one").
Set(goqu.Record{"foo": goqu.I("table_two.bar")}).
From("table_two").
Where(goqu.Ex{"table_one.id": goqu.I("table_two.id")})

sql, _, _ := ds.ToSQL()
fmt.Println(sql)
```
Output:
```sql
UPDATE `table_one`,`table_two` SET `foo`=`table_two`.`bar` WHERE (`table_one`.`id` = `table_two`.`id`)
```

<a name="where"></a>
**[Where](https://godoc.org/github.com/doug-martin/goqu/#UpdateDataset.Where)**

Expand Down
26 changes: 26 additions & 0 deletions exp/update_clauses.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@ type (
SetTable(table Expression) UpdateClauses

SetValues() interface{}
HasSetValues() bool
SetSetValues(values interface{}) UpdateClauses

From() ColumnListExpression
HasFrom() bool
SetFrom(tables ColumnListExpression) UpdateClauses

Where() ExpressionList
ClearWhere() UpdateClauses
WhereAppend(expressions ...Expression) UpdateClauses
Expand All @@ -38,6 +43,7 @@ type (
commonTables []CommonTableExpression
table Expression
setValues interface{}
from ColumnListExpression
where ExpressionList
order ColumnListExpression
limit interface{}
Expand All @@ -58,6 +64,7 @@ func (uc *updateClauses) clone() *updateClauses {
commonTables: uc.commonTables,
table: uc.table,
setValues: uc.setValues,
from: uc.from,
where: uc.where,
order: uc.order,
limit: uc.limit,
Expand Down Expand Up @@ -86,12 +93,31 @@ func (uc *updateClauses) SetTable(table Expression) UpdateClauses {
func (uc *updateClauses) SetValues() interface{} {
return uc.setValues
}

func (uc *updateClauses) HasSetValues() bool {
return uc.setValues != nil
}

func (uc *updateClauses) SetSetValues(values interface{}) UpdateClauses {
ret := uc.clone()
ret.setValues = values
return ret
}

func (uc *updateClauses) From() ColumnListExpression {
return uc.from
}

func (uc *updateClauses) HasFrom() bool {
return uc.from != nil && !uc.from.IsEmpty()
}

func (uc *updateClauses) SetFrom(from ColumnListExpression) UpdateClauses {
ret := uc.clone()
ret.from = from
return ret
}

func (uc *updateClauses) Where() ExpressionList {
return uc.where
}
Expand Down
23 changes: 23 additions & 0 deletions exp/update_clauses_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,29 @@ func (ucs *updateClausesSuite) TestSetSetValues() {
assert.Equal(t, r2, c2.SetValues())
}

func (ucs *updateClausesSuite) TestFrom() {
t := ucs.T()
c := NewUpdateClauses()
ce := NewColumnListExpression("a", "b")
c2 := c.SetFrom(ce)

assert.Nil(t, c.From())

assert.Equal(t, ce, c2.From())
}

func (ucs *updateClausesSuite) TestSetFrom() {
t := ucs.T()
ce1 := NewColumnListExpression("a", "b")
c := NewUpdateClauses().SetFrom(ce1)
ce2 := NewColumnListExpression("a", "b")
c2 := c.SetFrom(ce2)

assert.Equal(t, ce1, c.From())

assert.Equal(t, ce2, c2.From())
}

func (ucs *updateClausesSuite) TestWhere() {
t := ucs.T()
w := Ex{"a": 1}
Expand Down
50 changes: 44 additions & 6 deletions sql_dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ var (
errNoSourceForDelete = errors.New("no source found when generating delete sql")
errNoSourceForTruncate = errors.New("no source found when generating truncate sql")
errReturnNotSupported = errors.New("adapter does not support RETURNING clause")
errNoSetValuesForUpdate = errors.New("no set values found when generating UPDATE sql")
)

func notSupportedFragmentErr(sqlType string, f SQLFragmentType) error {
Expand Down Expand Up @@ -172,15 +173,22 @@ func (d *sqlDialect) ToSelectSQL(b sb.SQLBuilder, clauses exp.SelectClauses) {
}

func (d *sqlDialect) ToUpdateSQL(b sb.SQLBuilder, clauses exp.UpdateClauses) {
if !clauses.HasTable() {
b.SetError(errNoSourceForUpdate)
return
}
if !clauses.HasSetValues() {
b.SetError(errNoSetValuesForUpdate)
return
}
if !d.dialectOptions.SupportsMultipleUpdateTables && clauses.HasFrom() {
b.SetError(errors.New("%s dialect does not support multiple tables in UPDATE", d.dialect))
}
updates, err := exp.NewUpdateExpressions(clauses.SetValues())
if err != nil {
b.SetError(err)
return
}
if !clauses.HasTable() {
b.SetError(errNoSourceForUpdate)
return
}
for _, f := range d.dialectOptions.UpdateSQLOrder {
if b.Error() != nil {
return
Expand All @@ -191,10 +199,11 @@ func (d *sqlDialect) ToUpdateSQL(b sb.SQLBuilder, clauses exp.UpdateClauses) {
case UpdateBeginSQLFragment:
d.UpdateBeginSQL(b)
case SourcesSQLFragment:
b.WriteRunes(d.dialectOptions.SpaceRune)
d.Literal(b, clauses.Table())
d.updateTableSQL(b, clauses)
case UpdateSQLFragment:
d.UpdateExpressionsSQL(b, updates...)
case UpdateFromSQLFragment:
d.updateFromSQL(b, clauses.From())
case WhereSQLFragment:
d.WhereSQL(b, clauses.Where())
case OrderSQLFragment:
Expand Down Expand Up @@ -716,6 +725,23 @@ func (d *sqlDialect) onConflictSQL(b sb.SQLBuilder, o exp.ConflictExpression) {
}
}

func (d *sqlDialect) updateTableSQL(b sb.SQLBuilder, uc exp.UpdateClauses) {
if b.Error() != nil {
return
}
b.WriteRunes(d.dialectOptions.SpaceRune)
d.Literal(b, uc.Table())
if b.Error() != nil {
return
}
if uc.HasFrom() {
if !d.dialectOptions.UseFromClauseForMultipleUpdateTables {
b.WriteRunes(d.dialectOptions.CommaRune)
d.Literal(b, uc.From())
}
}
}

// Adds column setters in an update SET clause
func (d *sqlDialect) updateValuesSQL(b sb.SQLBuilder, updates ...exp.UpdateExpression) {
if len(updates) == 0 {
Expand All @@ -731,6 +757,18 @@ func (d *sqlDialect) updateValuesSQL(b sb.SQLBuilder, updates ...exp.UpdateExpre
}
}

func (d *sqlDialect) updateFromSQL(b sb.SQLBuilder, ce exp.ColumnListExpression) {
if b.Error() != nil {
return
}
if ce == nil || ce.IsEmpty() {
return
}
if d.dialectOptions.UseFromClauseForMultipleUpdateTables {
d.FromSQL(b, ce)
}
}

func (d *sqlDialect) onConflictDoUpdateSQL(b sb.SQLBuilder, o exp.ConflictUpdateExpression) {
b.Write(d.dialectOptions.ConflictDoUpdateFragment)
update := o.Update()
Expand Down
12 changes: 12 additions & 0 deletions sql_dialect_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,14 @@ type (
SupportsWithCTE bool
// Set to true if the dialect supports recursive Common Table Expressions (DEFAULT=true)
SupportsWithCTERecursive bool
// Set to true if multiple tables are supported in UPDATE statement. (DEFAULT=true)
SupportsMultipleUpdateTables bool
// Set to false if the dialect does not require expressions to be wrapped in parens (DEFAULT=true)
WrapCompoundsInParens bool

// Set to true if the dialect requires join tables in UPDATE to be in a FROM clause (DEFAULT=true).
UseFromClauseForMultipleUpdateTables bool

// The UPDATE fragment to use when generating sql. (DEFAULT=[]byte("UPDATE"))
UpdateClause []byte
// The INSERT fragment to use when generating sql. (DEFAULT=[]byte("INSERT INTO"))
Expand Down Expand Up @@ -291,6 +296,7 @@ const (
SourcesSQLFragment
IntoSQLFragment
UpdateSQLFragment
UpdateFromSQLFragment
ReturningSQLFragment
InsertBeingSQLFragment
InsertSQLFragment
Expand Down Expand Up @@ -333,6 +339,8 @@ func (sf SQLFragmentType) String() string {
return "IntoSQLFragment"
case UpdateSQLFragment:
return "UpdateSQLFragment"
case UpdateFromSQLFragment:
return "UpdateFromSQLFragment"
case ReturningSQLFragment:
return "ReturningSQLFragment"
case InsertBeingSQLFragment:
Expand All @@ -359,6 +367,9 @@ func DefaultDialectOptions() *SQLDialectOptions {
SupportsWithCTERecursive: true,
WrapCompoundsInParens: true,

SupportsMultipleUpdateTables: true,
UseFromClauseForMultipleUpdateTables: true,

UpdateClause: []byte("UPDATE"),
InsertClause: []byte("INSERT INTO"),
InsertIgnoreClause: []byte("INSERT IGNORE INTO"),
Expand Down Expand Up @@ -486,6 +497,7 @@ func DefaultDialectOptions() *SQLDialectOptions {
UpdateBeginSQLFragment,
SourcesSQLFragment,
UpdateSQLFragment,
UpdateFromSQLFragment,
WhereSQLFragment,
OrderSQLFragment,
LimitSQLFragment,
Expand Down
Loading

0 comments on commit 395454a

Please sign in to comment.