From 9a71e57d3cd8a3b82888d1ca8ddfd6ce742fdd7d Mon Sep 17 00:00:00 2001 From: Javi Fontan Date: Sat, 17 Jul 2021 20:10:59 +0200 Subject: [PATCH] store: use transactions for multiple writes or read+write --- server/store/effects.go | 165 ++++++++++++++++++++++--------------- server/store/users.go | 52 ++++++++++++ server/store/users_test.go | 39 +++++++++ 3 files changed, 190 insertions(+), 66 deletions(-) diff --git a/server/store/effects.go b/server/store/effects.go index 78590b9..0ec0c81 100644 --- a/server/store/effects.go +++ b/server/store/effects.go @@ -221,92 +221,105 @@ UPDATE effects ) func (s *Effects) AddEffect(e Effect) error { - effect := sqliteFromEffect(e) - _, err := s.db.NamedExec(sqlInsertEffectID, effect) - if err != nil { - return fmt.Errorf("could not insert effect: %w", err) - } + return s.transaction(func(tx *sqlx.Tx) error { + effect := sqliteFromEffect(e) + _, err := tx.NamedExec(sqlInsertEffectID, effect) + if err != nil { + return fmt.Errorf("could not insert effect: %w", err) + } - for i, v := range e.Versions { - version := sqliteFromversion(v) - version.Version = i - version.Effect = effect.ID + for i, v := range e.Versions { + version := sqliteFromversion(v) + version.Version = i + version.Effect = effect.ID - _, err := s.db.NamedExec(sqlInsertVersion, version) - if err != nil { - return fmt.Errorf("could not insert version: %w", err) + _, err := tx.NamedExec(sqlInsertVersion, version) + if err != nil { + return fmt.Errorf("could not insert version: %w", err) + } } - } - return nil + return nil + }) } func (s *Effects) Add( parent int, parentVersion int, user string, version string, ) (int, error) { - t := time.Now() - e := sqliteEffect{ - CreatedAt: t, - ModifiedAt: t, - Parent: parent, - ParentVersion: parentVersion, - User: user, - } + var lastID int + err := s.transaction(func(tx *sqlx.Tx) error { + t := time.Now() + e := sqliteEffect{ + CreatedAt: t, + ModifiedAt: t, + Parent: parent, + ParentVersion: parentVersion, + User: user, + } - r, err := s.db.NamedExec(sqlInsertEffect, e) - if err != nil { - return -1, fmt.Errorf("could not insert effect: %w", err) - } + r, err := tx.NamedExec(sqlInsertEffect, e) + if err != nil { + return fmt.Errorf("could not insert effect: %w", err) + } - id, err := r.LastInsertId() - if err != nil { - return -1, fmt.Errorf("could not get effect id: %w", err) - } + id, err := r.LastInsertId() + if err != nil { + return fmt.Errorf("could not get effect id: %w", err) + } + lastID = int(id) - v := sqliteVersion{ - Version: 0, - Effect: int(id), - CreatedAt: t, - Code: version, - } - _, err = s.db.NamedExec(sqlInsertVersion, v) - if err != nil { - return int(id), fmt.Errorf("could not insert version: %w", err) - } + v := sqliteVersion{ + Version: 0, + Effect: int(id), + CreatedAt: t, + Code: version, + } + _, err = tx.NamedExec(sqlInsertVersion, v) + if err != nil { + return fmt.Errorf("could not insert version: %w", err) + } + return nil + }) - return int(id), nil + return lastID, err } func (s *Effects) AddVersion(id int, code string) (int, error) { - t := time.Now() - var maxVersion *int - r := s.db.QueryRowx(sqlSelectMaxVersion, id) - err := r.Scan(&maxVersion) - if err != nil { - return -1, fmt.Errorf("could not get max version: %w", err) - } + var lastVersion int + err := s.transaction(func(tx *sqlx.Tx) error { + t := time.Now() + var maxVersion *int + r := tx.QueryRowx(sqlSelectMaxVersion, id) + err := r.Scan(&maxVersion) + if err != nil { + return fmt.Errorf("could not get max version: %w", err) + } - if maxVersion == nil { - return -1, ErrNotFound - } + if maxVersion == nil { + return ErrNotFound + } - version := sqliteVersion{ - Version: *maxVersion + 1, - Effect: id, - CreatedAt: t, - Code: code, - } - _, err = s.db.NamedExec(sqlInsertVersion, version) - if err != nil { - return -1, fmt.Errorf("could not insert version: %w", err) - } + version := sqliteVersion{ + Version: *maxVersion + 1, + Effect: id, + CreatedAt: t, + Code: code, + } + _, err = tx.NamedExec(sqlInsertVersion, version) + if err != nil { + return fmt.Errorf("could not insert version: %w", err) + } + lastVersion = version.Version - _, err = s.db.Exec(sqlUpdateEffectModification, t, id) - if err != nil { - return -1, fmt.Errorf("could not update effect: %w", err) - } + _, err = tx.Exec(sqlUpdateEffectModification, t, id) + if err != nil { + return fmt.Errorf("could not update effect: %w", err) + } - return version.Version, nil + return nil + }) + + return lastVersion, err } func (s *Effects) Page(num int, size int, hidden bool) ([]Effect, error) { @@ -401,6 +414,26 @@ func (s *Effects) Hide(id int, hidden bool) error { return nil } +func (s *Effects) transaction(f func(*sqlx.Tx) error) error { + tx, err := s.db.Beginx() + if err != nil { + return fmt.Errorf("could not create transaction: %w", err) + } + + err = f(tx) + if err != nil { + _ = tx.Rollback() + return err + } + + err = tx.Commit() + if err != nil { + return fmt.Errorf("could not commit transaction: %w", err) + } + + return nil +} + func sqliteToEffect(e sqliteEffect) Effect { n := Effect{ ID: e.ID, diff --git a/server/store/users.go b/server/store/users.go index d701afa..117dde7 100644 --- a/server/store/users.go +++ b/server/store/users.go @@ -1,6 +1,8 @@ package store import ( + "database/sql" + "errors" "fmt" "time" @@ -143,3 +145,53 @@ func (s *Users) Update(user User) error { } return nil } + +func (s *Users) UpdateFunc(name string, f func(User) User) error { + return s.transaction(func(tx *sqlx.Tx) error { + var u User + r := tx.QueryRowx(sqlSelectUser, name) + err := r.StructScan(&u) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return ErrNotFound + } + return fmt.Errorf("could not get user: %w", err) + } + + u = f(u) + + res, err := tx.NamedExec(sqlUpdateUser, u) + if err != nil { + return fmt.Errorf("could not update user: %w", err) + } + + rows, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("could not get affected rows: %w", err) + } + if rows == 0 { + return ErrNotFound + } + return nil + }) +} + +func (s *Users) transaction(f func(*sqlx.Tx) error) error { + tx, err := s.db.Beginx() + if err != nil { + return fmt.Errorf("could not create transaction: %w", err) + } + + err = f(tx) + if err != nil { + _ = tx.Rollback() + return err + } + + err = tx.Commit() + if err != nil { + return fmt.Errorf("could not commit transaction: %w", err) + } + + return nil +} diff --git a/server/store/users_test.go b/server/store/users_test.go index 28c5f26..831d2ac 100644 --- a/server/store/users_test.go +++ b/server/store/users_test.go @@ -80,3 +80,42 @@ func TestUserUpdate(t *testing.T) { require.Error(t, err) require.True(t, errors.Is(err, ErrNotFound)) } + +func TestUserUpdateFunc(t *testing.T) { + db, err := sqlx.Connect("sqlite", testDatabase) + require.NoError(t, err) + + users, err := NewUsers(db) + require.NoError(t, err) + + err = users.Add(testUser) + require.NoError(t, err) + + expected := User{ + Name: "test", + Password: []byte("newpassword"), + Email: "newemail", + Role: RoleModerator, + Active: false, + CreatedAt: time.Now(), + } + err = users.UpdateFunc("test", func(u User) User { + return expected + }) + require.NoError(t, err) + + u, err := users.User("test") + require.NoError(t, err) + require.Equal(t, expected.Name, u.Name) + require.Equal(t, expected.Password, u.Password) + require.Equal(t, expected.Email, u.Email) + require.Equal(t, expected.Role, u.Role) + require.Equal(t, expected.Active, u.Active) + require.True(t, expected.CreatedAt.Equal(u.CreatedAt)) + + err = users.UpdateFunc("inexistent", func(u User) User { + return u + }) + require.Error(t, err) + require.True(t, errors.Is(err, ErrNotFound)) +}