Skip to content

Commit

Permalink
Ensure that only one upgrade can exist at a time
Browse files Browse the repository at this point in the history
It doesn't really make sense to have multiple upgrades running at a same
time. So we may as well explictly enforce this to simplify other
operations

This meant changing the ActiveUpgrades to ActiveUpgrade, and changing
it's signature to return a single string, erroring with "NotFound" if
none exist

CreateUpgrade now checks if an upgrade exists before inserting a new one
  • Loading branch information
jack-w-shaw committed Jul 21, 2023
1 parent 1be8b50 commit fac5037
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 66 deletions.
14 changes: 7 additions & 7 deletions domain/upgrade/service/package_mock_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 9 additions & 4 deletions domain/upgrade/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package service

import (
"context"
"database/sql"

"github.com/juju/errors"
"github.com/juju/version/v2"
Expand All @@ -21,7 +22,7 @@ type State interface {
AllProvisionedControllersReady(context.Context, string) (bool, error)
StartUpgrade(context.Context, string) error
SetControllerDone(context.Context, string, string) error
ActiveUpgrades(context.Context) ([]string, error)
ActiveUpgrade(context.Context) (string, error)
}

// Service provides the API for working with upgrade info
Expand Down Expand Up @@ -79,8 +80,12 @@ func (s *Service) SetControllerDone(ctx context.Context, upgradeUUID, controller
return domain.CoerceError(s.st.SetControllerDone(ctx, upgradeUUID, controllerID))
}

// IsUpgrading returns true if an upgrade is currently in progress.
func (s *Service) ActiveUpgrades(ctx context.Context) ([]string, error) {
activeUpgrades, err := s.st.ActiveUpgrades(ctx)
// ActiveUpgrade returns the uuid of the current active upgrade.
// If there are no active upgrades, return a NotFound error
func (s *Service) ActiveUpgrade(ctx context.Context) (string, error) {
activeUpgrades, err := s.st.ActiveUpgrade(ctx)
if errors.Is(err, sql.ErrNoRows) {
return "", errors.NotFoundf("active upgrade")
}
return activeUpgrades, domain.CoerceError(err)
}
8 changes: 4 additions & 4 deletions domain/upgrade/service/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ func (s *serviceSuite) TestStartUpgradeBeforeCreated(c *gc.C) {
c.Assert(errors.IsNotFound(err), jc.IsTrue)
}

func (s *serviceSuite) TestActiveUpgrades(c *gc.C) {
func (s *serviceSuite) TestActiveUpgrade(c *gc.C) {
defer s.setupMocks(c).Finish()

s.state.EXPECT().ActiveUpgrades(gomock.Any()).Return([]string{testUUID1, testUUID2}, nil)
s.state.EXPECT().ActiveUpgrade(gomock.Any()).Return(testUUID1, nil)

activeUpgrades, err := NewService(s.state).ActiveUpgrades(context.Background())
activeUpgrade, err := NewService(s.state).ActiveUpgrade(context.Background())
c.Assert(err, jc.ErrorIsNil)
c.Assert(activeUpgrades, gc.DeepEquals, []string{testUUID1, testUUID2})
c.Assert(activeUpgrade, gc.Equals, testUUID1)
}
49 changes: 29 additions & 20 deletions domain/upgrade/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ func NewState(factory database.TxnRunnerFactory) *State {
}

// CreateUpgrade creates an active upgrade to and from specified versions
// and returns the upgrade's UUID
// and returns the upgrade's UUID. If an active upgrade already exists,
// return an AlreadyExists error
func (st *State) CreateUpgrade(ctx context.Context, previousVersion, targetVersion version.Number) (string, error) {
db, err := st.DB()
if err != nil {
Expand All @@ -40,10 +41,22 @@ func (st *State) CreateUpgrade(ctx context.Context, previousVersion, targetVersi
if err != nil {
return "", errors.Trace(err)
}
q := "INSERT INTO upgrade_info (uuid, previous_version, target_version, created_at) VALUES (?, ?, ?, DATETIME('now'))"
checkExistingQuery := "SELECT COUNT(*) FROM upgrade_info WHERE completed_at IS NULL"
createUpgradeQuery := "INSERT INTO upgrade_info (uuid, previous_version, target_version, created_at) VALUES (?, ?, ?, DATETIME('now'))"

err = db.StdTxn(ctx, func(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, q, upgradeUUID.String(), previousVersion.String(), targetVersion.String())
row := tx.QueryRowContext(ctx, checkExistingQuery)
if err := row.Err(); err != nil {
return errors.Trace(err)
}
var count int
if err := row.Scan(&count); err != nil {
return errors.Trace(err)
}
if count != 0 {
return errors.AlreadyExistsf("active upgrade")
}
_, err := tx.ExecContext(ctx, createUpgradeQuery, upgradeUUID.String(), previousVersion.String(), targetVersion.String())
return errors.Trace(err)
})

Expand Down Expand Up @@ -276,29 +289,25 @@ AND (
return nil
}

// ActiveUpgrades returns a slice of the uuids of all active upgrades
func (st *State) ActiveUpgrades(ctx context.Context) ([]string, error) {
// ActiveUpgrade returns the uuids of the active upgrades, returning
// a NotFound error if there are none
func (st *State) ActiveUpgrade(ctx context.Context) (string, error) {
db, err := st.DB()
if err != nil {
return nil, errors.Trace(err)
return "", errors.Trace(err)
}
var activeUpgrades []string
var activeUpgrade string
q := "SELECT (uuid) FROM upgrade_info WHERE completed_at IS NULL"

err = db.StdTxn(ctx, func(ctx context.Context, tx *sql.Tx) error {
q := "SELECT (uuid) FROM upgrade_info WHERE completed_at IS NULL"
rows, err := tx.QueryContext(ctx, q)
if err != nil && err != sql.ErrNoRows {
row := tx.QueryRowContext(ctx, q)
if err := row.Err(); err != nil {
return errors.Trace(err)
}
defer func() { _ = rows.Close() }()

for rows.Next() {
var uuid string
if err := rows.Scan(&uuid); err != nil {
return errors.Trace(err)
}
activeUpgrades = append(activeUpgrades, uuid)
if err := row.Scan(&activeUpgrade); err != nil {
return errors.Trace(err)
}
return rows.Err()
return nil
})
return activeUpgrades, errors.Trace(err)
return activeUpgrade, errors.Trace(err)
}
55 changes: 24 additions & 31 deletions domain/upgrade/state/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"database/sql"

"github.com/canonical/sqlair"
"github.com/juju/errors"
jc "github.com/juju/testing/checkers"
"github.com/juju/utils/v3"
"github.com/juju/version/v2"
Expand Down Expand Up @@ -90,6 +91,14 @@ func (s *stateSuite) TestCreateUpgrade(c *gc.C) {
c.Check(nodeInfos, gc.HasLen, 0)
}

func (s *stateSuite) TestCreateUpgradeAlreadyExists(c *gc.C) {
_, err := s.st.CreateUpgrade(context.Background(), version.MustParse("3.0.0"), version.MustParse("3.0.1"))
c.Assert(err, jc.ErrorIsNil)

_, err = s.st.CreateUpgrade(context.Background(), version.MustParse("4.0.0"), version.MustParse("4.0.1"))
c.Assert(errors.IsAlreadyExists(err), jc.IsTrue)
}

func (s *stateSuite) TestSetControllerReady(c *gc.C) {
uuid, err := s.st.CreateUpgrade(context.Background(), version.MustParse("3.0.0"), version.MustParse("3.0.1"))
c.Assert(err, jc.ErrorIsNil)
Expand Down Expand Up @@ -211,7 +220,7 @@ func (s *stateSuite) TestStartUpgradeIdempotent(c *gc.C) {
func (s *stateSuite) TestStartUpgradeBeforeCreated(c *gc.C) {
uuid := utils.MustNewUUID().String()
err := s.st.StartUpgrade(context.Background(), uuid)
c.Assert(err, gc.ErrorMatches, sql.ErrNoRows.Error())
c.Assert(errors.Is(err, sql.ErrNoRows), jc.IsTrue)
}

func (s *stateSuite) TestSetControllerDone(c *gc.C) {
Expand Down Expand Up @@ -251,9 +260,8 @@ func (s *stateSuite) TestSetControllerDoneCompleteUpgrade(c *gc.C) {
_, err := db.Exec("INSERT INTO controller_node (controller_id, dqlite_node_id) VALUES ('1', 1)")
c.Assert(err, jc.ErrorIsNil)

activeUpgrades, err := s.st.ActiveUpgrades(context.Background())
c.Assert(err, jc.ErrorIsNil)
c.Check(activeUpgrades, gc.HasLen, 0)
_, err = s.st.ActiveUpgrade(context.Background())
c.Assert(errors.Is(err, sql.ErrNoRows), jc.IsTrue)

uuid, err := s.st.CreateUpgrade(context.Background(), version.MustParse("3.0.0"), version.MustParse("3.0.1"))
c.Assert(err, jc.ErrorIsNil)
Expand All @@ -265,16 +273,15 @@ func (s *stateSuite) TestSetControllerDoneCompleteUpgrade(c *gc.C) {
err = s.st.SetControllerDone(context.Background(), uuid, "0")
c.Assert(err, jc.ErrorIsNil)

activeUpgrades, err = s.st.ActiveUpgrades(context.Background())
activeUpgrade, err := s.st.ActiveUpgrade(context.Background())
c.Assert(err, jc.ErrorIsNil)
c.Check(activeUpgrades, gc.HasLen, 1)
c.Check(activeUpgrade, gc.Not(gc.Equals), "")

err = s.st.SetControllerDone(context.Background(), uuid, "1")
c.Assert(err, jc.ErrorIsNil)

activeUpgrades, err = s.st.ActiveUpgrades(context.Background())
c.Assert(err, jc.ErrorIsNil)
c.Check(activeUpgrades, gc.HasLen, 0)
activeUpgrade, err = s.st.ActiveUpgrade(context.Background())
c.Assert(errors.Is(err, sql.ErrNoRows), jc.IsTrue)
}

func (s *stateSuite) TestSetControllerDoneCompleteUpgradeEmptyCompletedAt(c *gc.C) {
Expand All @@ -297,41 +304,27 @@ WHERE upgrade_info_uuid = ?
AND controller_node_id = 0`, uuid)
c.Assert(err, jc.ErrorIsNil)

activeUpgrades, err := s.st.ActiveUpgrades(context.Background())
activeUpgrade, err := s.st.ActiveUpgrade(context.Background())
c.Assert(err, jc.ErrorIsNil)
c.Check(activeUpgrades, gc.HasLen, 1)
c.Assert(activeUpgrade, gc.Not(gc.Equals), "")

err = s.st.SetControllerDone(context.Background(), uuid, "1")
c.Assert(err, jc.ErrorIsNil)

activeUpgrades, err = s.st.ActiveUpgrades(context.Background())
c.Assert(err, jc.ErrorIsNil)
c.Check(activeUpgrades, gc.HasLen, 0)
activeUpgrade, err = s.st.ActiveUpgrade(context.Background())
c.Assert(errors.Is(err, sql.ErrNoRows), jc.IsTrue)
}

func (s *stateSuite) TestActiveUpgradesNoUpgrades(c *gc.C) {
activeUpgrades, err := s.st.ActiveUpgrades(context.Background())
c.Assert(err, jc.ErrorIsNil)
c.Check(activeUpgrades, gc.HasLen, 0)
_, err := s.st.ActiveUpgrade(context.Background())
c.Assert(errors.Is(err, sql.ErrNoRows), jc.IsTrue)
}

func (s *stateSuite) TestActiveUpgradesSingular(c *gc.C) {
uuid, err := s.st.CreateUpgrade(context.Background(), version.MustParse("3.0.0"), version.MustParse("3.0.1"))
c.Assert(err, jc.ErrorIsNil)

activeUpgrades, err := s.st.ActiveUpgrades(context.Background())
c.Assert(err, jc.ErrorIsNil)
c.Check(activeUpgrades, gc.DeepEquals, []string{uuid})
}

func (s *stateSuite) TestActiveUpgradesMultiple(c *gc.C) {
_, err := s.st.CreateUpgrade(context.Background(), version.MustParse("3.0.0"), version.MustParse("3.0.1"))
c.Assert(err, jc.ErrorIsNil)

_, err = s.st.CreateUpgrade(context.Background(), version.MustParse("3.0.1"), version.MustParse("3.0.2"))
c.Assert(err, jc.ErrorIsNil)

activeUpgrades, err := s.st.ActiveUpgrades(context.Background())
activeUpgrade, err := s.st.ActiveUpgrade(context.Background())
c.Assert(err, jc.ErrorIsNil)
c.Check(activeUpgrades, gc.HasLen, 2)
c.Check(activeUpgrade, gc.Equals, uuid)
}

0 comments on commit fac5037

Please sign in to comment.