Skip to content

Commit 3d58b00

Browse files
committed
chore: improve concurrency and context cancelation
Addressing comments from @alberto-miranda #16 (comment)
1 parent 2d1e0eb commit 3d58b00

File tree

3 files changed

+84
-43
lines changed

3 files changed

+84
-43
lines changed

cmd/serve.go

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
package cmd
1616

1717
import (
18+
"context"
19+
"os"
20+
"os/signal"
21+
"syscall"
1822
"time"
1923

2024
"log/slog"
@@ -26,6 +30,7 @@ import (
2630
"github.com/seqeralabs/staticreg/pkg/server"
2731
"github.com/seqeralabs/staticreg/pkg/server/staticreg"
2832
"github.com/spf13/cobra"
33+
"golang.org/x/sync/errgroup"
2934
)
3035

3136
var (
@@ -40,6 +45,9 @@ var serveCmd = &cobra.Command{
4045
Short: "Serves a webserver with an HTML listing of all images and tags in a v2 registry",
4146
Run: func(cmd *cobra.Command, args []string) {
4247
ctx := cmd.Context()
48+
ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
49+
defer stop()
50+
4351
log := logger.FromContext(ctx)
4452
log.Info("starting server",
4553
slog.Duration("cache-duration", cacheDuration),
@@ -50,6 +58,7 @@ var serveCmd = &cobra.Command{
5058

5159
client := registry.New(rootCfg)
5260
asyncClient := async.New(client, refreshInterval)
61+
defer asyncClient.Stop(context.Background())
5362

5463
filler := filler.New(asyncClient, rootCfg.RegistryHostname, "/")
5564

@@ -60,24 +69,22 @@ var serveCmd = &cobra.Command{
6069
return
6170
}
6271

63-
errCh := make(chan error, 1)
64-
go func() {
65-
errCh <- srv.Start()
66-
}()
72+
g, ctx := errgroup.WithContext(ctx)
6773

68-
go func() {
69-
errCh <- asyncClient.Start(ctx)
70-
}()
74+
g.Go(func() error {
75+
return srv.Start(ctx)
76+
})
7177

72-
select {
73-
case <-ctx.Done():
74-
return
75-
case err := <-errCh:
76-
if err == nil {
77-
slog.Error("operations exited unexpectedly")
78-
return
78+
g.Go(func() error {
79+
return asyncClient.Start(ctx)
80+
})
81+
82+
if err := g.Wait(); err != nil {
83+
if ctx.Err() != nil {
84+
log.Info("context cancelled, shutting down")
85+
} else {
86+
slog.Error("unexpected error", logger.ErrAttr(err))
7987
}
80-
slog.Error("unexpected error", logger.ErrAttr(err))
8188
return
8289
}
8390
},

pkg/registry/async/async_registry.go

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@ package async
1616

1717
import (
1818
"context"
19-
"fmt"
19+
"errors"
2020
"log/slog"
2121
"time"
2222

2323
v1 "github.com/google/go-containerregistry/pkg/v1"
2424
"github.com/puzpuzpuz/xsync/v3"
25+
"golang.org/x/sync/errgroup"
2526

2627
"github.com/cenkalti/backoff/v4"
2728
"github.com/seqeralabs/staticreg/pkg/observability/logger"
@@ -31,22 +32,35 @@ import (
3132
const imageInfoRequestsBufSize = 10
3233
const tagRequestBufferSize = 10
3334

35+
var (
36+
ErrNoTagsFound = errors.New("no tags found")
37+
ErrImageInfoNotFound = errors.New("image info not found")
38+
)
39+
3440
// Async is a struct that wraps an underlying registry.Client
3541
// to provide asynchronous methods for interacting with a container registry.
3642
// It continuously syncs data from the registry in a separate goroutine.
3743
type Async struct {
3844
// underlying is the actual registry client that does the registry operations, remember this is just a wrapper!
3945
underlying registry.Client
40-
46+
// refreshInterval represents the time to wait to synchronize repositories again after a successful synchronization
4147
refreshInterval time.Duration
4248

49+
// repos is an in memory list of all the repository names in the registry
4350
repos []string
4451

52+
// repositoryTags represents the list of tags for each repository
4553
repositoryTags *xsync.MapOf[string, []string]
4654

55+
// imageInfo contains the image information indexed by repo name and tag
4756
imageInfo *xsync.MapOf[imageInfoKey, imageInfo]
4857

58+
// repositoryRequestBuffer generates requests for the `handleRepositoryRequest`
59+
// handler that is responsible for retrieving the tags for a given image and
60+
// scheduling new jobs on `imageInfoRequestsBuffer`
4961
repositoryRequestBuffer chan repositoryRequest
62+
// imageInfoRequestsBuffer is responsible for feeding `handleImageInfoRequest`
63+
// so that image info is retrieved for each <repo,tag> combination
5064
imageInfoRequestsBuffer chan imageInfoRequest
5165
}
5266

@@ -68,52 +82,51 @@ type imageInfo struct {
6882
reference string
6983
}
7084

85+
func (c *Async) Stop(ctx context.Context) {
86+
close(c.imageInfoRequestsBuffer)
87+
close(c.repositoryRequestBuffer)
88+
}
89+
7190
func (c *Async) Start(ctx context.Context) error {
72-
// TODO(fntlnz): maybe instead of errCh use a backoff and retry ops
73-
errCh := make(chan error, 1)
91+
g, ctx := errgroup.WithContext(ctx)
7492

75-
go func() {
93+
g.Go(func() error {
7694
for {
7795
err := backoff.Retry(func() error {
7896
return c.synchronizeRepositories(ctx)
7997
}, backoff.WithContext(newExponentialBackoff(), ctx))
8098

8199
if err != nil {
82-
errCh <- err
100+
return err
83101
}
84102

85103
time.Sleep(c.refreshInterval)
86104
}
87-
}()
105+
})
88106

89-
go func() {
107+
g.Go(func() error {
90108
for {
91109
select {
92110
case <-ctx.Done():
93-
return
111+
return ctx.Err()
94112
case req := <-c.repositoryRequestBuffer:
95113
c.handleRepositoryRequest(ctx, req)
96114
}
97115
}
98-
}()
116+
})
99117

100-
go func() {
118+
g.Go(func() error {
101119
for {
102120
select {
103121
case <-ctx.Done():
104-
return
122+
return ctx.Err()
105123
case req := <-c.imageInfoRequestsBuffer:
106124
c.handleImageInfoRequest(ctx, req)
107125
}
108126
}
109-
}()
127+
})
110128

111-
select {
112-
case <-ctx.Done():
113-
return nil
114-
case err := <-errCh:
115-
return err
116-
}
129+
return g.Wait()
117130
}
118131

119132
func (c *Async) synchronizeRepositories(ctx context.Context) error {
@@ -126,7 +139,14 @@ func (c *Async) synchronizeRepositories(ctx context.Context) error {
126139
c.repos = repos
127140

128141
for _, r := range repos {
129-
c.repositoryRequestBuffer <- repositoryRequest{repo: r}
142+
if err := ctx.Err(); err != nil {
143+
return nil
144+
}
145+
select {
146+
case c.repositoryRequestBuffer <- repositoryRequest{repo: r}:
147+
default:
148+
return nil
149+
}
130150
}
131151

132152
return nil
@@ -146,13 +166,19 @@ func (c *Async) handleRepositoryRequest(ctx context.Context, req repositoryReque
146166
c.repositoryTags.Store(req.repo, tags)
147167

148168
for _, t := range tags {
149-
c.imageInfoRequestsBuffer <- imageInfoRequest{
169+
if ctx.Err() != nil {
170+
return
171+
}
172+
173+
select {
174+
case c.imageInfoRequestsBuffer <- imageInfoRequest{
150175
repo: req.repo,
151176
tag: t,
177+
}:
178+
default:
179+
return
152180
}
153181
}
154-
155-
return
156182
}
157183

158184
func (c *Async) handleImageInfoRequest(ctx context.Context, req imageInfoRequest) {
@@ -176,10 +202,11 @@ func (c *Async) RepoList(ctx context.Context) ([]string, error) {
176202
return c.repos, nil
177203
}
178204

205+
// TagList contains
179206
func (c *Async) TagList(ctx context.Context, repo string) ([]string, error) {
180207
tags, ok := c.repositoryTags.Load(repo)
181208
if !ok {
182-
return nil, fmt.Errorf("no tags found") // TODO(fntlnz): make an error var
209+
return nil, ErrNoTagsFound
183210
}
184211
return tags, nil
185212
}
@@ -191,7 +218,7 @@ func (c *Async) ImageInfo(ctx context.Context, repo string, tag string) (image v
191218
}
192219
info, ok := c.imageInfo.Load(key)
193220
if !ok {
194-
return nil, "", fmt.Errorf("image info not found") // TODO(fntlnz): make an error var
221+
return nil, "", ErrImageInfoNotFound
195222
}
196223
return info.image, info.reference, nil
197224
}

pkg/server/server.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package server
22

33
import (
4+
"context"
45
"log/slog"
56
"net/http"
67
"strings"
@@ -9,8 +10,8 @@ import (
910
cache "github.com/chenyahui/gin-cache"
1011
"github.com/chenyahui/gin-cache/persist"
1112
sloggin "github.com/samber/slog-gin"
13+
"golang.org/x/sync/errgroup"
1214

13-
// "github.com/chenyahui/gin-cache/persist"
1415
"github.com/gin-gonic/gin"
1516
)
1617

@@ -78,8 +79,14 @@ func New(
7879
}, nil
7980
}
8081

81-
func (s *Server) Start() error {
82-
return s.server.ListenAndServe()
82+
func (s *Server) Start(ctx context.Context) error {
83+
g, ctx := errgroup.WithContext(ctx)
84+
g.Go(s.server.ListenAndServe)
85+
g.Go(func() error {
86+
<-ctx.Done()
87+
return s.server.Shutdown(context.Background())
88+
})
89+
return g.Wait()
8390
}
8491

8592
func injectLoggerMiddleware(log *slog.Logger) gin.HandlerFunc {

0 commit comments

Comments
 (0)