Skip to content

Commit

Permalink
store: use transactions for multiple writes or read+write
Browse files Browse the repository at this point in the history
  • Loading branch information
jfontan committed Jul 17, 2021
1 parent 9d1de89 commit 9a71e57
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 66 deletions.
165 changes: 99 additions & 66 deletions server/store/effects.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,92 +221,105 @@ UPDATE effects
)

func (s *Effects) AddEffect(e Effect) error {
effect := sqliteFromEffect(e)
_, err := s.db.NamedExec(sqlInsertEffectID, effect)
if err != nil {
return fmt.Errorf("could not insert effect: %w", err)
}
return s.transaction(func(tx *sqlx.Tx) error {
effect := sqliteFromEffect(e)
_, err := tx.NamedExec(sqlInsertEffectID, effect)
if err != nil {
return fmt.Errorf("could not insert effect: %w", err)
}

for i, v := range e.Versions {
version := sqliteFromversion(v)
version.Version = i
version.Effect = effect.ID
for i, v := range e.Versions {
version := sqliteFromversion(v)
version.Version = i
version.Effect = effect.ID

_, err := s.db.NamedExec(sqlInsertVersion, version)
if err != nil {
return fmt.Errorf("could not insert version: %w", err)
_, err := tx.NamedExec(sqlInsertVersion, version)
if err != nil {
return fmt.Errorf("could not insert version: %w", err)
}
}
}

return nil
return nil
})
}

func (s *Effects) Add(
parent int, parentVersion int, user string, version string,
) (int, error) {
t := time.Now()
e := sqliteEffect{
CreatedAt: t,
ModifiedAt: t,
Parent: parent,
ParentVersion: parentVersion,
User: user,
}
var lastID int
err := s.transaction(func(tx *sqlx.Tx) error {
t := time.Now()
e := sqliteEffect{
CreatedAt: t,
ModifiedAt: t,
Parent: parent,
ParentVersion: parentVersion,
User: user,
}

r, err := s.db.NamedExec(sqlInsertEffect, e)
if err != nil {
return -1, fmt.Errorf("could not insert effect: %w", err)
}
r, err := tx.NamedExec(sqlInsertEffect, e)
if err != nil {
return fmt.Errorf("could not insert effect: %w", err)
}

id, err := r.LastInsertId()
if err != nil {
return -1, fmt.Errorf("could not get effect id: %w", err)
}
id, err := r.LastInsertId()
if err != nil {
return fmt.Errorf("could not get effect id: %w", err)
}
lastID = int(id)

v := sqliteVersion{
Version: 0,
Effect: int(id),
CreatedAt: t,
Code: version,
}
_, err = s.db.NamedExec(sqlInsertVersion, v)
if err != nil {
return int(id), fmt.Errorf("could not insert version: %w", err)
}
v := sqliteVersion{
Version: 0,
Effect: int(id),
CreatedAt: t,
Code: version,
}
_, err = tx.NamedExec(sqlInsertVersion, v)
if err != nil {
return fmt.Errorf("could not insert version: %w", err)
}
return nil
})

return int(id), nil
return lastID, err
}

func (s *Effects) AddVersion(id int, code string) (int, error) {
t := time.Now()
var maxVersion *int
r := s.db.QueryRowx(sqlSelectMaxVersion, id)
err := r.Scan(&maxVersion)
if err != nil {
return -1, fmt.Errorf("could not get max version: %w", err)
}
var lastVersion int
err := s.transaction(func(tx *sqlx.Tx) error {
t := time.Now()
var maxVersion *int
r := tx.QueryRowx(sqlSelectMaxVersion, id)
err := r.Scan(&maxVersion)
if err != nil {
return fmt.Errorf("could not get max version: %w", err)
}

if maxVersion == nil {
return -1, ErrNotFound
}
if maxVersion == nil {
return ErrNotFound
}

version := sqliteVersion{
Version: *maxVersion + 1,
Effect: id,
CreatedAt: t,
Code: code,
}
_, err = s.db.NamedExec(sqlInsertVersion, version)
if err != nil {
return -1, fmt.Errorf("could not insert version: %w", err)
}
version := sqliteVersion{
Version: *maxVersion + 1,
Effect: id,
CreatedAt: t,
Code: code,
}
_, err = tx.NamedExec(sqlInsertVersion, version)
if err != nil {
return fmt.Errorf("could not insert version: %w", err)
}
lastVersion = version.Version

_, err = s.db.Exec(sqlUpdateEffectModification, t, id)
if err != nil {
return -1, fmt.Errorf("could not update effect: %w", err)
}
_, err = tx.Exec(sqlUpdateEffectModification, t, id)
if err != nil {
return fmt.Errorf("could not update effect: %w", err)
}

return version.Version, nil
return nil
})

return lastVersion, err
}

func (s *Effects) Page(num int, size int, hidden bool) ([]Effect, error) {
Expand Down Expand Up @@ -401,6 +414,26 @@ func (s *Effects) Hide(id int, hidden bool) error {
return nil
}

func (s *Effects) transaction(f func(*sqlx.Tx) error) error {
tx, err := s.db.Beginx()
if err != nil {
return fmt.Errorf("could not create transaction: %w", err)
}

err = f(tx)
if err != nil {
_ = tx.Rollback()
return err
}

err = tx.Commit()
if err != nil {
return fmt.Errorf("could not commit transaction: %w", err)
}

return nil
}

func sqliteToEffect(e sqliteEffect) Effect {
n := Effect{
ID: e.ID,
Expand Down
52 changes: 52 additions & 0 deletions server/store/users.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package store

import (
"database/sql"
"errors"
"fmt"
"time"

Expand Down Expand Up @@ -143,3 +145,53 @@ func (s *Users) Update(user User) error {
}
return nil
}

func (s *Users) UpdateFunc(name string, f func(User) User) error {
return s.transaction(func(tx *sqlx.Tx) error {
var u User
r := tx.QueryRowx(sqlSelectUser, name)
err := r.StructScan(&u)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return ErrNotFound
}
return fmt.Errorf("could not get user: %w", err)
}

u = f(u)

res, err := tx.NamedExec(sqlUpdateUser, u)
if err != nil {
return fmt.Errorf("could not update user: %w", err)
}

rows, err := res.RowsAffected()
if err != nil {
return fmt.Errorf("could not get affected rows: %w", err)
}
if rows == 0 {
return ErrNotFound
}
return nil
})
}

func (s *Users) transaction(f func(*sqlx.Tx) error) error {
tx, err := s.db.Beginx()
if err != nil {
return fmt.Errorf("could not create transaction: %w", err)
}

err = f(tx)
if err != nil {
_ = tx.Rollback()
return err
}

err = tx.Commit()
if err != nil {
return fmt.Errorf("could not commit transaction: %w", err)
}

return nil
}
39 changes: 39 additions & 0 deletions server/store/users_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,42 @@ func TestUserUpdate(t *testing.T) {
require.Error(t, err)
require.True(t, errors.Is(err, ErrNotFound))
}

func TestUserUpdateFunc(t *testing.T) {
db, err := sqlx.Connect("sqlite", testDatabase)
require.NoError(t, err)

users, err := NewUsers(db)
require.NoError(t, err)

err = users.Add(testUser)
require.NoError(t, err)

expected := User{
Name: "test",
Password: []byte("newpassword"),
Email: "newemail",
Role: RoleModerator,
Active: false,
CreatedAt: time.Now(),
}
err = users.UpdateFunc("test", func(u User) User {
return expected
})
require.NoError(t, err)

u, err := users.User("test")
require.NoError(t, err)
require.Equal(t, expected.Name, u.Name)
require.Equal(t, expected.Password, u.Password)
require.Equal(t, expected.Email, u.Email)
require.Equal(t, expected.Role, u.Role)
require.Equal(t, expected.Active, u.Active)
require.True(t, expected.CreatedAt.Equal(u.CreatedAt))

err = users.UpdateFunc("inexistent", func(u User) User {
return u
})
require.Error(t, err)
require.True(t, errors.Is(err, ErrNotFound))
}

0 comments on commit 9a71e57

Please sign in to comment.