From a3912ea98cf09ad317d80351375243da6580b567 Mon Sep 17 00:00:00 2001 From: dave vader <48764154+plyr4@users.noreply.github.com> Date: Fri, 1 Sep 2023 09:45:17 -0500 Subject: [PATCH] enhance: add context to Hooks (#938) --- api/admin/hook.go | 5 +++- api/hook/create.go | 5 ++-- api/hook/delete.go | 5 ++-- api/hook/get.go | 3 +- api/hook/list.go | 3 +- api/hook/redeliver.go | 3 +- api/hook/update.go | 5 ++-- api/repo/create.go | 4 +-- api/repo/repair.go | 4 +-- api/repo/update.go | 2 +- api/scm/sync.go | 2 +- api/scm/sync_org.go | 2 +- api/webhook/post.go | 12 ++++---- database/hook/count.go | 4 ++- database/hook/count_repo.go | 4 ++- database/hook/count_repo_test.go | 7 +++-- database/hook/count_test.go | 7 +++-- database/hook/create.go | 4 ++- database/hook/create_test.go | 3 +- database/hook/delete.go | 4 ++- database/hook/delete_test.go | 5 ++-- database/hook/get.go | 4 ++- database/hook/get_repo.go | 4 ++- database/hook/get_repo_test.go | 5 ++-- database/hook/get_test.go | 5 ++-- database/hook/hook.go | 7 +++-- database/hook/index.go | 4 ++- database/hook/index_test.go | 3 +- database/hook/interface.go | 26 +++++++++-------- database/hook/last_repo.go | 3 +- database/hook/last_repo_test.go | 5 ++-- database/hook/list.go | 6 ++-- database/hook/list_repo.go | 6 ++-- database/hook/list_repo_test.go | 7 +++-- database/hook/list_test.go | 7 +++-- database/hook/opts.go | 11 +++++++ database/hook/opts_test.go | 50 ++++++++++++++++++++++++++++++++ database/hook/table.go | 4 ++- database/hook/table_test.go | 3 +- database/hook/update.go | 4 ++- database/hook/update_test.go | 5 ++-- database/integration_test.go | 20 ++++++------- database/resource.go | 1 + 43 files changed, 197 insertions(+), 86 deletions(-) diff --git a/api/admin/hook.go b/api/admin/hook.go index 789a74a25..7cdf2c5e4 100644 --- a/api/admin/hook.go +++ b/api/admin/hook.go @@ -51,6 +51,9 @@ import ( func UpdateHook(c *gin.Context) { logrus.Info("Admin: updating hook in database") + // capture middleware values + ctx := c.Request.Context() + // capture body from API request input := new(library.Hook) @@ -64,7 +67,7 @@ func UpdateHook(c *gin.Context) { } // send API call to update the hook - h, err := database.FromContext(c).UpdateHook(input) + h, err := database.FromContext(c).UpdateHook(ctx, input) if err != nil { retErr := fmt.Errorf("unable to update hook %d: %w", input.GetID(), err) diff --git a/api/hook/create.go b/api/hook/create.go index b33d9fdfb..234e32eae 100644 --- a/api/hook/create.go +++ b/api/hook/create.go @@ -66,6 +66,7 @@ func CreateHook(c *gin.Context) { o := org.Retrieve(c) r := repo.Retrieve(c) u := user.Retrieve(c) + ctx := c.Request.Context() // update engine logger with API metadata // @@ -89,7 +90,7 @@ func CreateHook(c *gin.Context) { } // send API call to capture the last hook for the repo - lastHook, err := database.FromContext(c).LastHookForRepo(r) + lastHook, err := database.FromContext(c).LastHookForRepo(ctx, r) if err != nil { retErr := fmt.Errorf("unable to get last hook for repo %s: %w", r.GetFullName(), err) @@ -113,7 +114,7 @@ func CreateHook(c *gin.Context) { } // send API call to create the webhook - h, err := database.FromContext(c).CreateHook(input) + h, err := database.FromContext(c).CreateHook(ctx, input) if err != nil { retErr := fmt.Errorf("unable to create hook for repo %s: %w", r.GetFullName(), err) diff --git a/api/hook/delete.go b/api/hook/delete.go index 7ada1f0fb..3087e928a 100644 --- a/api/hook/delete.go +++ b/api/hook/delete.go @@ -69,6 +69,7 @@ func DeleteHook(c *gin.Context) { r := repo.Retrieve(c) u := user.Retrieve(c) hook := util.PathParameter(c, "hook") + ctx := c.Request.Context() entry := fmt.Sprintf("%s/%s", r.GetFullName(), hook) @@ -92,7 +93,7 @@ func DeleteHook(c *gin.Context) { } // send API call to capture the webhook - h, err := database.FromContext(c).GetHookForRepo(r, number) + h, err := database.FromContext(c).GetHookForRepo(ctx, r, number) if err != nil { retErr := fmt.Errorf("unable to get hook %s: %w", hook, err) @@ -102,7 +103,7 @@ func DeleteHook(c *gin.Context) { } // send API call to remove the webhook - err = database.FromContext(c).DeleteHook(h) + err = database.FromContext(c).DeleteHook(ctx, h) if err != nil { retErr := fmt.Errorf("unable to delete hook %s: %w", hook, err) diff --git a/api/hook/get.go b/api/hook/get.go index 458183dcb..49ddca5de 100644 --- a/api/hook/get.go +++ b/api/hook/get.go @@ -65,6 +65,7 @@ func GetHook(c *gin.Context) { r := repo.Retrieve(c) u := user.Retrieve(c) hook := util.PathParameter(c, "hook") + ctx := c.Request.Context() entry := fmt.Sprintf("%s/%s", r.GetFullName(), hook) @@ -88,7 +89,7 @@ func GetHook(c *gin.Context) { } // send API call to capture the webhook - h, err := database.FromContext(c).GetHookForRepo(r, number) + h, err := database.FromContext(c).GetHookForRepo(ctx, r, number) if err != nil { retErr := fmt.Errorf("unable to get hook %s: %w", entry, err) diff --git a/api/hook/list.go b/api/hook/list.go index 2cb6d05fe..bd47e019e 100644 --- a/api/hook/list.go +++ b/api/hook/list.go @@ -80,6 +80,7 @@ func ListHooks(c *gin.Context) { o := org.Retrieve(c) r := repo.Retrieve(c) u := user.Retrieve(c) + ctx := c.Request.Context() // update engine logger with API metadata // @@ -114,7 +115,7 @@ func ListHooks(c *gin.Context) { perPage = util.MaxInt(1, util.MinInt(100, perPage)) // send API call to capture the list of steps for the build - h, t, err := database.FromContext(c).ListHooksForRepo(r, page, perPage) + h, t, err := database.FromContext(c).ListHooksForRepo(ctx, r, page, perPage) if err != nil { retErr := fmt.Errorf("unable to get hooks for repo %s: %w", r.GetFullName(), err) diff --git a/api/hook/redeliver.go b/api/hook/redeliver.go index 4afcb8a79..63ab91a9a 100644 --- a/api/hook/redeliver.go +++ b/api/hook/redeliver.go @@ -70,6 +70,7 @@ func RedeliverHook(c *gin.Context) { r := repo.Retrieve(c) u := user.Retrieve(c) hook := util.PathParameter(c, "hook") + ctx := c.Request.Context() entry := fmt.Sprintf("%s/%s", r.GetFullName(), hook) @@ -93,7 +94,7 @@ func RedeliverHook(c *gin.Context) { } // send API call to capture the webhook - h, err := database.FromContext(c).GetHookForRepo(r, number) + h, err := database.FromContext(c).GetHookForRepo(ctx, r, number) if err != nil { retErr := fmt.Errorf("unable to get hook %s: %w", entry, err) diff --git a/api/hook/update.go b/api/hook/update.go index 2c5cdcfb7..ab2fb5955 100644 --- a/api/hook/update.go +++ b/api/hook/update.go @@ -76,6 +76,7 @@ func UpdateHook(c *gin.Context) { r := repo.Retrieve(c) u := user.Retrieve(c) hook := util.PathParameter(c, "hook") + ctx := c.Request.Context() entry := fmt.Sprintf("%s/%s", r.GetFullName(), hook) @@ -111,7 +112,7 @@ func UpdateHook(c *gin.Context) { } // send API call to capture the webhook - h, err := database.FromContext(c).GetHookForRepo(r, number) + h, err := database.FromContext(c).GetHookForRepo(ctx, r, number) if err != nil { retErr := fmt.Errorf("unable to get hook %s: %w", entry, err) @@ -157,7 +158,7 @@ func UpdateHook(c *gin.Context) { } // send API call to update the webhook - h, err = database.FromContext(c).UpdateHook(h) + h, err = database.FromContext(c).UpdateHook(ctx, h) if err != nil { retErr := fmt.Errorf("unable to update hook %s: %w", entry, err) diff --git a/api/repo/create.go b/api/repo/create.go index d76aae325..50d2383a0 100644 --- a/api/repo/create.go +++ b/api/repo/create.go @@ -241,7 +241,7 @@ func CreateRepo(c *gin.Context) { // err being nil means we have a record of this repo (dbRepo) if err == nil { - h, _ = database.FromContext(c).LastHookForRepo(dbRepo) + h, _ = database.FromContext(c).LastHookForRepo(ctx, dbRepo) // make sure our record of the repo allowed events matches what we send to SCM // what the dbRepo has should override default events on enable @@ -309,7 +309,7 @@ func CreateRepo(c *gin.Context) { // update initialization hook h.SetRepoID(r.GetID()) // create first hook for repo in the database - _, err = database.FromContext(c).CreateHook(h) + _, err = database.FromContext(c).CreateHook(ctx, h) if err != nil { retErr := fmt.Errorf("unable to create initialization webhook for %s: %w", r.GetFullName(), err) diff --git a/api/repo/repair.go b/api/repo/repair.go index 337f52d91..1f8ffeea8 100644 --- a/api/repo/repair.go +++ b/api/repo/repair.go @@ -78,7 +78,7 @@ func RepairRepo(c *gin.Context) { return } - hook, err := database.FromContext(c).LastHookForRepo(r) + hook, err := database.FromContext(c).LastHookForRepo(ctx, r) if err != nil { retErr := fmt.Errorf("unable to get last hook for %s: %w", r.GetFullName(), err) @@ -108,7 +108,7 @@ func RepairRepo(c *gin.Context) { hook.SetRepoID(r.GetID()) - _, err = database.FromContext(c).CreateHook(hook) + _, err = database.FromContext(c).CreateHook(ctx, hook) if err != nil { retErr := fmt.Errorf("unable to create initialization webhook for %s: %w", r.GetFullName(), err) diff --git a/api/repo/update.go b/api/repo/update.go index a7684c393..6ed328490 100644 --- a/api/repo/update.go +++ b/api/repo/update.go @@ -255,7 +255,7 @@ func UpdateRepo(c *gin.Context) { // if webhook validation is not set or events didn't change, skip webhook update if c.Value("webhookvalidation").(bool) && eventsChanged { // grab last hook from repo to fetch the webhook ID - lastHook, err := database.FromContext(c).LastHookForRepo(r) + lastHook, err := database.FromContext(c).LastHookForRepo(ctx, r) if err != nil { retErr := fmt.Errorf("unable to retrieve last hook for repo %s: %w", r.GetFullName(), err) diff --git a/api/scm/sync.go b/api/scm/sync.go index dc264f264..a02981acd 100644 --- a/api/scm/sync.go +++ b/api/scm/sync.go @@ -114,7 +114,7 @@ func SyncRepo(c *gin.Context) { // if we have webhook validation, update the repo hook in the SCM if c.Value("webhookvalidation").(bool) { // grab last hook from repo to fetch the webhook ID - lastHook, err := database.FromContext(c).LastHookForRepo(r) + lastHook, err := database.FromContext(c).LastHookForRepo(ctx, r) if err != nil { retErr := fmt.Errorf("unable to retrieve last hook for repo %s: %w", r.GetFullName(), err) diff --git a/api/scm/sync_org.go b/api/scm/sync_org.go index 944a33adc..fb809c74e 100644 --- a/api/scm/sync_org.go +++ b/api/scm/sync_org.go @@ -127,7 +127,7 @@ func SyncReposForOrg(c *gin.Context) { // if we have webhook validation, update the repo hook in the SCM if c.Value("webhookvalidation").(bool) { // grab last hook from repo to fetch the webhook ID - lastHook, err := database.FromContext(c).LastHookForRepo(repo) + lastHook, err := database.FromContext(c).LastHookForRepo(ctx, repo) if err != nil { retErr := fmt.Errorf("unable to retrieve last hook for repo %s: %w", repo.GetFullName(), err) diff --git a/api/webhook/post.go b/api/webhook/post.go index e0ef95223..a413a0f0e 100644 --- a/api/webhook/post.go +++ b/api/webhook/post.go @@ -179,7 +179,7 @@ func PostWebhook(c *gin.Context) { defer func() { // send API call to update the webhook - _, err = database.FromContext(c).UpdateHook(h) + _, err = database.FromContext(c).UpdateHook(ctx, h) if err != nil { logrus.Errorf("unable to update webhook %s/%d: %v", r.GetFullName(), h.GetNumber(), err) } @@ -202,7 +202,7 @@ func PostWebhook(c *gin.Context) { h.SetRepoID(repo.GetID()) // send API call to capture the last hook for the repo - lastHook, err := database.FromContext(c).LastHookForRepo(repo) + lastHook, err := database.FromContext(c).LastHookForRepo(ctx, repo) if err != nil { retErr := fmt.Errorf("unable to get last hook for repo %s: %w", repo.GetFullName(), err) util.HandleError(c, http.StatusInternalServerError, retErr) @@ -221,7 +221,7 @@ func PostWebhook(c *gin.Context) { } // send API call to create the webhook - h, err = database.FromContext(c).CreateHook(h) + h, err = database.FromContext(c).CreateHook(ctx, h) if err != nil { retErr := fmt.Errorf("unable to create webhook %s/%d: %w", repo.GetFullName(), h.GetNumber(), err) util.HandleError(c, http.StatusInternalServerError, retErr) @@ -688,7 +688,7 @@ func handleRepositoryEvent(ctx context.Context, c *gin.Context, m *types.Metadat defer func() { // send API call to update the webhook - _, err := database.FromContext(c).CreateHook(h) + _, err := database.FromContext(c).CreateHook(ctx, h) if err != nil { logrus.Errorf("unable to create webhook %s/%d: %v", r.GetFullName(), h.GetNumber(), err) } @@ -721,7 +721,7 @@ func handleRepositoryEvent(ctx context.Context, c *gin.Context, m *types.Metadat } // send API call to capture the last hook for the repo - lastHook, err := database.FromContext(c).LastHookForRepo(dbRepo) + lastHook, err := database.FromContext(c).LastHookForRepo(ctx, dbRepo) if err != nil { retErr := fmt.Errorf("unable to get last hook for repo %s: %w", r.GetFullName(), err) @@ -797,7 +797,7 @@ func renameRepository(ctx context.Context, h *library.Hook, r *library.Repo, c * h.SetRepoID(r.GetID()) // send API call to capture the last hook for the repo - lastHook, err := database.FromContext(c).LastHookForRepo(dbR) + lastHook, err := database.FromContext(c).LastHookForRepo(ctx, dbR) if err != nil { retErr := fmt.Errorf("unable to get last hook for repo %s: %w", r.GetFullName(), err) util.HandleError(c, http.StatusInternalServerError, retErr) diff --git a/database/hook/count.go b/database/hook/count.go index a1a9689f8..47e9db7be 100644 --- a/database/hook/count.go +++ b/database/hook/count.go @@ -5,11 +5,13 @@ package hook import ( + "context" + "github.com/go-vela/types/constants" ) // CountHooks gets the count of all hooks from the database. -func (e *engine) CountHooks() (int64, error) { +func (e *engine) CountHooks(ctx context.Context) (int64, error) { e.logger.Tracef("getting count of all hooks from the database") // variable to store query results diff --git a/database/hook/count_repo.go b/database/hook/count_repo.go index eb6d6763b..2035ffef5 100644 --- a/database/hook/count_repo.go +++ b/database/hook/count_repo.go @@ -5,13 +5,15 @@ package hook import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/library" "github.com/sirupsen/logrus" ) // CountHooksForRepo gets the count of hooks by repo ID from the database. -func (e *engine) CountHooksForRepo(r *library.Repo) (int64, error) { +func (e *engine) CountHooksForRepo(ctx context.Context, r *library.Repo) (int64, error) { e.logger.WithFields(logrus.Fields{ "org": r.GetOrg(), "repo": r.GetName(), diff --git a/database/hook/count_repo_test.go b/database/hook/count_repo_test.go index 5814764e6..7f3d20b5d 100644 --- a/database/hook/count_repo_test.go +++ b/database/hook/count_repo_test.go @@ -5,6 +5,7 @@ package hook import ( + "context" "reflect" "testing" @@ -48,12 +49,12 @@ func TestHook_Engine_CountHooksForRepo(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreateHook(_hookOne) + _, err := _sqlite.CreateHook(context.TODO(), _hookOne) if err != nil { t.Errorf("unable to create test repo for sqlite: %v", err) } - _, err = _sqlite.CreateHook(_hookTwo) + _, err = _sqlite.CreateHook(context.TODO(), _hookTwo) if err != nil { t.Errorf("unable to create test hook for sqlite: %v", err) } @@ -82,7 +83,7 @@ func TestHook_Engine_CountHooksForRepo(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.CountHooksForRepo(_repo) + got, err := test.database.CountHooksForRepo(context.TODO(), _repo) if test.failure { if err == nil { diff --git a/database/hook/count_test.go b/database/hook/count_test.go index dce97587f..a9b87c830 100644 --- a/database/hook/count_test.go +++ b/database/hook/count_test.go @@ -5,6 +5,7 @@ package hook import ( + "context" "reflect" "testing" @@ -41,12 +42,12 @@ func TestHook_Engine_CountHooks(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreateHook(_hookOne) + _, err := _sqlite.CreateHook(context.TODO(), _hookOne) if err != nil { t.Errorf("unable to create test hook for sqlite: %v", err) } - _, err = _sqlite.CreateHook(_hookTwo) + _, err = _sqlite.CreateHook(context.TODO(), _hookTwo) if err != nil { t.Errorf("unable to create test hook for sqlite: %v", err) } @@ -75,7 +76,7 @@ func TestHook_Engine_CountHooks(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.CountHooks() + got, err := test.database.CountHooks(context.TODO()) if test.failure { if err == nil { diff --git a/database/hook/create.go b/database/hook/create.go index e2bf5489d..cce76611f 100644 --- a/database/hook/create.go +++ b/database/hook/create.go @@ -5,6 +5,8 @@ package hook import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" @@ -12,7 +14,7 @@ import ( ) // CreateHook creates a new hook in the database. -func (e *engine) CreateHook(h *library.Hook) (*library.Hook, error) { +func (e *engine) CreateHook(ctx context.Context, h *library.Hook) (*library.Hook, error) { e.logger.WithFields(logrus.Fields{ "hook": h.GetNumber(), }).Tracef("creating hook %d in the database", h.GetNumber()) diff --git a/database/hook/create_test.go b/database/hook/create_test.go index f05bd8a9d..0a4d4348c 100644 --- a/database/hook/create_test.go +++ b/database/hook/create_test.go @@ -5,6 +5,7 @@ package hook import ( + "context" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -57,7 +58,7 @@ VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14) RETURNING "id"`). // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - _, err := test.database.CreateHook(_hook) + _, err := test.database.CreateHook(context.TODO(), _hook) if test.failure { if err == nil { diff --git a/database/hook/delete.go b/database/hook/delete.go index d4e688f1c..3342a0580 100644 --- a/database/hook/delete.go +++ b/database/hook/delete.go @@ -5,6 +5,8 @@ package hook import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" @@ -12,7 +14,7 @@ import ( ) // DeleteHook deletes an existing hook from the database. -func (e *engine) DeleteHook(h *library.Hook) error { +func (e *engine) DeleteHook(ctx context.Context, h *library.Hook) error { e.logger.WithFields(logrus.Fields{ "hook": h.GetNumber(), }).Tracef("deleting hook %d in the database", h.GetNumber()) diff --git a/database/hook/delete_test.go b/database/hook/delete_test.go index 2fb91d567..c900d176f 100644 --- a/database/hook/delete_test.go +++ b/database/hook/delete_test.go @@ -5,6 +5,7 @@ package hook import ( + "context" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -31,7 +32,7 @@ func TestHook_Engine_DeleteHook(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreateHook(_hook) + _, err := _sqlite.CreateHook(context.TODO(), _hook) if err != nil { t.Errorf("unable to create test hook for sqlite: %v", err) } @@ -57,7 +58,7 @@ func TestHook_Engine_DeleteHook(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - err = test.database.DeleteHook(_hook) + err = test.database.DeleteHook(context.TODO(), _hook) if test.failure { if err == nil { diff --git a/database/hook/get.go b/database/hook/get.go index 13547669c..39168c52c 100644 --- a/database/hook/get.go +++ b/database/hook/get.go @@ -5,13 +5,15 @@ package hook import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" ) // GetHook gets a hook by ID from the database. -func (e *engine) GetHook(id int64) (*library.Hook, error) { +func (e *engine) GetHook(ctx context.Context, id int64) (*library.Hook, error) { e.logger.Tracef("getting hook %d from the database", id) // variable to store query results diff --git a/database/hook/get_repo.go b/database/hook/get_repo.go index 4c7e3b857..7d7df0f7d 100644 --- a/database/hook/get_repo.go +++ b/database/hook/get_repo.go @@ -5,6 +5,8 @@ package hook import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" @@ -12,7 +14,7 @@ import ( ) // GetHookForRepo gets a hook by repo ID and number from the database. -func (e *engine) GetHookForRepo(r *library.Repo, number int) (*library.Hook, error) { +func (e *engine) GetHookForRepo(ctx context.Context, r *library.Repo, number int) (*library.Hook, error) { e.logger.WithFields(logrus.Fields{ "hook": number, "org": r.GetOrg(), diff --git a/database/hook/get_repo_test.go b/database/hook/get_repo_test.go index 32fb9a813..bdeaf8d96 100644 --- a/database/hook/get_repo_test.go +++ b/database/hook/get_repo_test.go @@ -5,6 +5,7 @@ package hook import ( + "context" "reflect" "testing" @@ -43,7 +44,7 @@ func TestHook_Engine_GetHookForRepo(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreateHook(_hook) + _, err := _sqlite.CreateHook(context.TODO(), _hook) if err != nil { t.Errorf("unable to create test hook for sqlite: %v", err) } @@ -72,7 +73,7 @@ func TestHook_Engine_GetHookForRepo(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.GetHookForRepo(_repo, 1) + got, err := test.database.GetHookForRepo(context.TODO(), _repo, 1) if test.failure { if err == nil { diff --git a/database/hook/get_test.go b/database/hook/get_test.go index b84fe7b13..65221504d 100644 --- a/database/hook/get_test.go +++ b/database/hook/get_test.go @@ -5,6 +5,7 @@ package hook import ( + "context" "reflect" "testing" @@ -36,7 +37,7 @@ func TestHook_Engine_GetHook(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreateHook(_hook) + _, err := _sqlite.CreateHook(context.TODO(), _hook) if err != nil { t.Errorf("unable to create test hook for sqlite: %v", err) } @@ -65,7 +66,7 @@ func TestHook_Engine_GetHook(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.GetHook(1) + got, err := test.database.GetHook(context.TODO(), 1) if test.failure { if err == nil { diff --git a/database/hook/hook.go b/database/hook/hook.go index 5225f901a..e8bbe5e54 100644 --- a/database/hook/hook.go +++ b/database/hook/hook.go @@ -5,6 +5,7 @@ package hook import ( + "context" "fmt" "github.com/go-vela/types/constants" @@ -25,6 +26,8 @@ type ( // engine configuration settings used in hook functions config *config + ctx context.Context + // gorm.io/gorm database client used in hook functions // // https://pkg.go.dev/gorm.io/gorm#DB @@ -65,13 +68,13 @@ func New(opts ...EngineOpt) (*engine, error) { } // create the hooks table - err := e.CreateHookTable(e.client.Config.Dialector.Name()) + err := e.CreateHookTable(e.ctx, e.client.Config.Dialector.Name()) if err != nil { return nil, fmt.Errorf("unable to create %s table: %w", constants.TableHook, err) } // create the indexes for the hooks table - err = e.CreateHookIndexes() + err = e.CreateHookIndexes(e.ctx) if err != nil { return nil, fmt.Errorf("unable to create indexes for %s table: %w", constants.TableHook, err) } diff --git a/database/hook/index.go b/database/hook/index.go index a2061eaf9..c842fe1aa 100644 --- a/database/hook/index.go +++ b/database/hook/index.go @@ -4,6 +4,8 @@ package hook +import "context" + const ( // CreateRepoIDIndex represents a query to create an // index on the hooks table for the repo_id column. @@ -16,7 +18,7 @@ ON hooks (repo_id); ) // CreateHookIndexes creates the indexes for the hooks table in the database. -func (e *engine) CreateHookIndexes() error { +func (e *engine) CreateHookIndexes(ctx context.Context) error { e.logger.Tracef("creating indexes for hooks table in the database") // create the repo_id column index for the hooks table diff --git a/database/hook/index_test.go b/database/hook/index_test.go index 06ab40a95..905c7a354 100644 --- a/database/hook/index_test.go +++ b/database/hook/index_test.go @@ -5,6 +5,7 @@ package hook import ( + "context" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -41,7 +42,7 @@ func TestHook_Engine_CreateHookIndexes(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - err := test.database.CreateHookIndexes() + err := test.database.CreateHookIndexes(context.TODO()) if test.failure { if err == nil { diff --git a/database/hook/interface.go b/database/hook/interface.go index 4ed44fd84..efd0f213a 100644 --- a/database/hook/interface.go +++ b/database/hook/interface.go @@ -5,6 +5,8 @@ package hook import ( + "context" + "github.com/go-vela/types/library" ) @@ -18,32 +20,32 @@ type HookInterface interface { // https://en.wikipedia.org/wiki/Data_definition_language // CreateHookIndexes defines a function that creates the indexes for the hooks table. - CreateHookIndexes() error + CreateHookIndexes(context.Context) error // CreateHookTable defines a function that creates the hooks table. - CreateHookTable(string) error + CreateHookTable(context.Context, string) error // Hook Data Manipulation Language Functions // // https://en.wikipedia.org/wiki/Data_manipulation_language // CountHooks defines a function that gets the count of all hooks. - CountHooks() (int64, error) + CountHooks(context.Context) (int64, error) // CountHooksForRepo defines a function that gets the count of hooks by repo ID. - CountHooksForRepo(*library.Repo) (int64, error) + CountHooksForRepo(context.Context, *library.Repo) (int64, error) // CreateHook defines a function that creates a new hook. - CreateHook(*library.Hook) (*library.Hook, error) + CreateHook(context.Context, *library.Hook) (*library.Hook, error) // DeleteHook defines a function that deletes an existing hook. - DeleteHook(*library.Hook) error + DeleteHook(context.Context, *library.Hook) error // GetHook defines a function that gets a hook by ID. - GetHook(int64) (*library.Hook, error) + GetHook(context.Context, int64) (*library.Hook, error) // GetHookForRepo defines a function that gets a hook by repo ID and number. - GetHookForRepo(*library.Repo, int) (*library.Hook, error) + GetHookForRepo(context.Context, *library.Repo, int) (*library.Hook, error) // LastHookForRepo defines a function that gets the last hook by repo ID. - LastHookForRepo(*library.Repo) (*library.Hook, error) + LastHookForRepo(context.Context, *library.Repo) (*library.Hook, error) // ListHooks defines a function that gets a list of all hooks. - ListHooks() ([]*library.Hook, error) + ListHooks(context.Context) ([]*library.Hook, error) // ListHooksForRepo defines a function that gets a list of hooks by repo ID. - ListHooksForRepo(*library.Repo, int, int) ([]*library.Hook, int64, error) + ListHooksForRepo(context.Context, *library.Repo, int, int) ([]*library.Hook, int64, error) // UpdateHook defines a function that updates an existing hook. - UpdateHook(*library.Hook) (*library.Hook, error) + UpdateHook(context.Context, *library.Hook) (*library.Hook, error) } diff --git a/database/hook/last_repo.go b/database/hook/last_repo.go index 22c3ef992..b2be0c916 100644 --- a/database/hook/last_repo.go +++ b/database/hook/last_repo.go @@ -5,6 +5,7 @@ package hook import ( + "context" "errors" "github.com/go-vela/types/constants" @@ -16,7 +17,7 @@ import ( ) // LastHookForRepo gets the last hook by repo ID from the database. -func (e *engine) LastHookForRepo(r *library.Repo) (*library.Hook, error) { +func (e *engine) LastHookForRepo(ctx context.Context, r *library.Repo) (*library.Hook, error) { e.logger.WithFields(logrus.Fields{ "org": r.GetOrg(), "repo": r.GetName(), diff --git a/database/hook/last_repo_test.go b/database/hook/last_repo_test.go index ecd4a3eea..5802a9486 100644 --- a/database/hook/last_repo_test.go +++ b/database/hook/last_repo_test.go @@ -5,6 +5,7 @@ package hook import ( + "context" "reflect" "testing" @@ -43,7 +44,7 @@ func TestHook_Engine_LastHookForRepo(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreateHook(_hook) + _, err := _sqlite.CreateHook(context.TODO(), _hook) if err != nil { t.Errorf("unable to create test hook for sqlite: %v", err) } @@ -72,7 +73,7 @@ func TestHook_Engine_LastHookForRepo(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.LastHookForRepo(_repo) + got, err := test.database.LastHookForRepo(context.TODO(), _repo) if test.failure { if err == nil { diff --git a/database/hook/list.go b/database/hook/list.go index 1152738e6..ebc1b49b3 100644 --- a/database/hook/list.go +++ b/database/hook/list.go @@ -5,13 +5,15 @@ package hook import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" ) // ListHooks gets a list of all hooks from the database. -func (e *engine) ListHooks() ([]*library.Hook, error) { +func (e *engine) ListHooks(ctx context.Context) ([]*library.Hook, error) { e.logger.Trace("listing all hooks from the database") // variables to store query results and return value @@ -20,7 +22,7 @@ func (e *engine) ListHooks() ([]*library.Hook, error) { hooks := []*library.Hook{} // count the results - count, err := e.CountHooks() + count, err := e.CountHooks(ctx) if err != nil { return nil, err } diff --git a/database/hook/list_repo.go b/database/hook/list_repo.go index 1060f5863..dc5535363 100644 --- a/database/hook/list_repo.go +++ b/database/hook/list_repo.go @@ -5,6 +5,8 @@ package hook import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" @@ -12,7 +14,7 @@ import ( ) // ListHooksForRepo gets a list of hooks by repo ID from the database. -func (e *engine) ListHooksForRepo(r *library.Repo, page, perPage int) ([]*library.Hook, int64, error) { +func (e *engine) ListHooksForRepo(ctx context.Context, r *library.Repo, page, perPage int) ([]*library.Hook, int64, error) { e.logger.WithFields(logrus.Fields{ "org": r.GetOrg(), "repo": r.GetName(), @@ -24,7 +26,7 @@ func (e *engine) ListHooksForRepo(r *library.Repo, page, perPage int) ([]*librar hooks := []*library.Hook{} // count the results - count, err := e.CountHooksForRepo(r) + count, err := e.CountHooksForRepo(ctx, r) if err != nil { return nil, 0, err } diff --git a/database/hook/list_repo_test.go b/database/hook/list_repo_test.go index c229f5674..c055b20dd 100644 --- a/database/hook/list_repo_test.go +++ b/database/hook/list_repo_test.go @@ -5,6 +5,7 @@ package hook import ( + "context" "reflect" "testing" @@ -58,12 +59,12 @@ func TestHook_Engine_ListHooksForRepo(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreateHook(_hookOne) + _, err := _sqlite.CreateHook(context.TODO(), _hookOne) if err != nil { t.Errorf("unable to create test hook for sqlite: %v", err) } - _, err = _sqlite.CreateHook(_hookTwo) + _, err = _sqlite.CreateHook(context.TODO(), _hookTwo) if err != nil { t.Errorf("unable to create test hook for sqlite: %v", err) } @@ -92,7 +93,7 @@ func TestHook_Engine_ListHooksForRepo(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, _, err := test.database.ListHooksForRepo(_repo, 1, 10) + got, _, err := test.database.ListHooksForRepo(context.TODO(), _repo, 1, 10) if test.failure { if err == nil { diff --git a/database/hook/list_test.go b/database/hook/list_test.go index e2fb33772..fb5732f59 100644 --- a/database/hook/list_test.go +++ b/database/hook/list_test.go @@ -5,6 +5,7 @@ package hook import ( + "context" "reflect" "testing" @@ -51,12 +52,12 @@ func TestHook_Engine_ListHooks(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreateHook(_hookOne) + _, err := _sqlite.CreateHook(context.TODO(), _hookOne) if err != nil { t.Errorf("unable to create test hook for sqlite: %v", err) } - _, err = _sqlite.CreateHook(_hookTwo) + _, err = _sqlite.CreateHook(context.TODO(), _hookTwo) if err != nil { t.Errorf("unable to create test hook for sqlite: %v", err) } @@ -85,7 +86,7 @@ func TestHook_Engine_ListHooks(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.ListHooks() + got, err := test.database.ListHooks(context.TODO()) if test.failure { if err == nil { diff --git a/database/hook/opts.go b/database/hook/opts.go index e88e92da7..a6852a91a 100644 --- a/database/hook/opts.go +++ b/database/hook/opts.go @@ -5,6 +5,8 @@ package hook import ( + "context" + "github.com/sirupsen/logrus" "gorm.io/gorm" @@ -42,3 +44,12 @@ func WithSkipCreation(skipCreation bool) EngineOpt { return nil } } + +// WithContext sets the context in the database engine for Repos. +func WithContext(ctx context.Context) EngineOpt { + return func(e *engine) error { + e.ctx = ctx + + return nil + } +} diff --git a/database/hook/opts_test.go b/database/hook/opts_test.go index 81946c8f4..4779e7e2a 100644 --- a/database/hook/opts_test.go +++ b/database/hook/opts_test.go @@ -5,6 +5,7 @@ package hook import ( + "context" "reflect" "testing" @@ -159,3 +160,52 @@ func TestHook_EngineOpt_WithSkipCreation(t *testing.T) { }) } } + +func TestHook_EngineOpt_WithContext(t *testing.T) { + // setup types + e := &engine{config: new(config)} + + // setup tests + tests := []struct { + failure bool + name string + ctx context.Context + want context.Context + }{ + { + failure: false, + name: "context set to TODO", + ctx: context.TODO(), + want: context.TODO(), + }, + { + failure: false, + name: "context set to nil", + ctx: nil, + want: nil, + }, + } + + // run tests + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := WithContext(test.ctx)(e) + + if test.failure { + if err == nil { + t.Errorf("WithContext for %s should have returned err", test.name) + } + + return + } + + if err != nil { + t.Errorf("WithContext returned err: %v", err) + } + + if !reflect.DeepEqual(e.ctx, test.want) { + t.Errorf("WithContext is %v, want %v", e.ctx, test.want) + } + }) + } +} diff --git a/database/hook/table.go b/database/hook/table.go index 90419508f..651533444 100644 --- a/database/hook/table.go +++ b/database/hook/table.go @@ -5,6 +5,8 @@ package hook import ( + "context" + "github.com/go-vela/types/constants" ) @@ -57,7 +59,7 @@ hooks ( ) // CreateHookTable creates the hooks table in the database. -func (e *engine) CreateHookTable(driver string) error { +func (e *engine) CreateHookTable(ctx context.Context, driver string) error { e.logger.Tracef("creating hooks table in the database") // handle the driver provided to create the table diff --git a/database/hook/table_test.go b/database/hook/table_test.go index 915621718..a1a847f95 100644 --- a/database/hook/table_test.go +++ b/database/hook/table_test.go @@ -5,6 +5,7 @@ package hook import ( + "context" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -41,7 +42,7 @@ func TestHook_Engine_CreateHookTable(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - err := test.database.CreateHookTable(test.name) + err := test.database.CreateHookTable(context.TODO(), test.name) if test.failure { if err == nil { diff --git a/database/hook/update.go b/database/hook/update.go index d7741f249..b349494ee 100644 --- a/database/hook/update.go +++ b/database/hook/update.go @@ -5,6 +5,8 @@ package hook import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" @@ -12,7 +14,7 @@ import ( ) // UpdateHook updates an existing hook in the database. -func (e *engine) UpdateHook(h *library.Hook) (*library.Hook, error) { +func (e *engine) UpdateHook(ctx context.Context, h *library.Hook) (*library.Hook, error) { e.logger.WithFields(logrus.Fields{ "hook": h.GetNumber(), }).Tracef("updating hook %d in the database", h.GetNumber()) diff --git a/database/hook/update_test.go b/database/hook/update_test.go index cbb309897..9722de7c8 100644 --- a/database/hook/update_test.go +++ b/database/hook/update_test.go @@ -5,6 +5,7 @@ package hook import ( + "context" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -33,7 +34,7 @@ WHERE "id" = $14`). _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreateHook(_hook) + _, err := _sqlite.CreateHook(context.TODO(), _hook) if err != nil { t.Errorf("unable to create test hook for sqlite: %v", err) } @@ -59,7 +60,7 @@ WHERE "id" = $14`). // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - _, err = test.database.UpdateHook(_hook) + _, err = test.database.UpdateHook(context.TODO(), _hook) if test.failure { if err == nil { diff --git a/database/integration_test.go b/database/integration_test.go index a43106c57..90efd3de2 100644 --- a/database/integration_test.go +++ b/database/integration_test.go @@ -452,7 +452,7 @@ func testHooks(t *testing.T, db Interface, resources *Resources) { // create the hooks for _, hook := range resources.Hooks { - _, err := db.CreateHook(hook) + _, err := db.CreateHook(context.TODO(), hook) if err != nil { t.Errorf("unable to create hook %d: %v", hook.GetID(), err) } @@ -460,7 +460,7 @@ func testHooks(t *testing.T, db Interface, resources *Resources) { methods["CreateHook"] = true // count the hooks - count, err := db.CountHooks() + count, err := db.CountHooks(context.TODO()) if err != nil { t.Errorf("unable to count hooks: %v", err) } @@ -470,7 +470,7 @@ func testHooks(t *testing.T, db Interface, resources *Resources) { methods["CountHooks"] = true // count the hooks for a repo - count, err = db.CountHooksForRepo(resources.Repos[0]) + count, err = db.CountHooksForRepo(context.TODO(), resources.Repos[0]) if err != nil { t.Errorf("unable to count hooks for repo %d: %v", resources.Repos[0].GetID(), err) } @@ -480,7 +480,7 @@ func testHooks(t *testing.T, db Interface, resources *Resources) { methods["CountHooksForRepo"] = true // list the hooks - list, err := db.ListHooks() + list, err := db.ListHooks(context.TODO()) if err != nil { t.Errorf("unable to list hooks: %v", err) } @@ -490,7 +490,7 @@ func testHooks(t *testing.T, db Interface, resources *Resources) { methods["ListHooks"] = true // list the hooks for a repo - list, count, err = db.ListHooksForRepo(resources.Repos[0], 1, 10) + list, count, err = db.ListHooksForRepo(context.TODO(), resources.Repos[0], 1, 10) if err != nil { t.Errorf("unable to list hooks for repo %d: %v", resources.Repos[0].GetID(), err) } @@ -503,7 +503,7 @@ func testHooks(t *testing.T, db Interface, resources *Resources) { methods["ListHooksForRepo"] = true // lookup the last build by repo - got, err := db.LastHookForRepo(resources.Repos[0]) + got, err := db.LastHookForRepo(context.TODO(), resources.Repos[0]) if err != nil { t.Errorf("unable to get last hook for repo %d: %v", resources.Repos[0].GetID(), err) } @@ -515,7 +515,7 @@ func testHooks(t *testing.T, db Interface, resources *Resources) { // lookup the hooks by name for _, hook := range resources.Hooks { repo := resources.Repos[hook.GetRepoID()-1] - got, err = db.GetHookForRepo(repo, hook.GetNumber()) + got, err = db.GetHookForRepo(context.TODO(), repo, hook.GetNumber()) if err != nil { t.Errorf("unable to get hook %d for repo %d: %v", hook.GetID(), repo.GetID(), err) } @@ -528,13 +528,13 @@ func testHooks(t *testing.T, db Interface, resources *Resources) { // update the hooks for _, hook := range resources.Hooks { hook.SetStatus("success") - _, err = db.UpdateHook(hook) + _, err = db.UpdateHook(context.TODO(), hook) if err != nil { t.Errorf("unable to update hook %d: %v", hook.GetID(), err) } // lookup the hook by ID - got, err = db.GetHook(hook.GetID()) + got, err = db.GetHook(context.TODO(), hook.GetID()) if err != nil { t.Errorf("unable to get hook %d by ID: %v", hook.GetID(), err) } @@ -547,7 +547,7 @@ func testHooks(t *testing.T, db Interface, resources *Resources) { // delete the hooks for _, hook := range resources.Hooks { - err = db.DeleteHook(hook) + err = db.DeleteHook(context.TODO(), hook) if err != nil { t.Errorf("unable to delete hook %d: %v", hook.GetID(), err) } diff --git a/database/resource.go b/database/resource.go index 74a4b8fe5..3e69a46d8 100644 --- a/database/resource.go +++ b/database/resource.go @@ -51,6 +51,7 @@ func (e *engine) NewResources(ctx context.Context) error { // create the database agnostic engine for hooks e.HookInterface, err = hook.New( + hook.WithContext(e.ctx), hook.WithClient(e.client), hook.WithLogger(e.logger), hook.WithSkipCreation(e.config.SkipCreation),