Skip to content

Commit

Permalink
chore: improve concurrency and context cancelation
Browse files Browse the repository at this point in the history
Addressing comments from @alberto-miranda #16 (comment)
  • Loading branch information
fntlnz committed Aug 16, 2024
1 parent 2d1e0eb commit cbb7f7f
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 65 deletions.
4 changes: 3 additions & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ var rootCmd = &cobra.Command{
PersistentPreRun: func(cmd *cobra.Command, args []string) {
ctx := cmd.Context()

log := logger.New(cmd.OutOrStderr(), rootCfg.LogInJSON)
log := logger.New(cmd.OutOrStderr(), rootCfg.LogInJSON, rootCfg.Verbose)

ctx = logger.Context(ctx, log)
cmd.SetContext(ctx)

Expand Down Expand Up @@ -72,4 +73,5 @@ func init() {
rootCmd.PersistentFlags().BoolVar(&rootCfg.SkipTLSVerify, "skip-tls-verify", false, "disable TLS certificate checks")
rootCmd.PersistentFlags().BoolVar(&rootCfg.TLSEnabled, "tls-enable", false, "enable TLS")
rootCmd.PersistentFlags().BoolVar(&rootCfg.LogInJSON, "json-logging", false, "log in JSON")
rootCmd.PersistentFlags().BoolVar(&rootCfg.Verbose, "verbose", false, "enable verbose logging")
}
39 changes: 23 additions & 16 deletions cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
package cmd

import (
"context"
"os"
"os/signal"
"syscall"
"time"

"log/slog"
Expand All @@ -26,6 +30,7 @@ import (
"github.com/seqeralabs/staticreg/pkg/server"
"github.com/seqeralabs/staticreg/pkg/server/staticreg"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"
)

var (
Expand All @@ -40,6 +45,9 @@ var serveCmd = &cobra.Command{
Short: "Serves a webserver with an HTML listing of all images and tags in a v2 registry",
Run: func(cmd *cobra.Command, args []string) {
ctx := cmd.Context()
ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
defer stop()

log := logger.FromContext(ctx)
log.Info("starting server",
slog.Duration("cache-duration", cacheDuration),
Expand All @@ -50,6 +58,7 @@ var serveCmd = &cobra.Command{

client := registry.New(rootCfg)
asyncClient := async.New(client, refreshInterval)
defer asyncClient.Stop(context.Background())

filler := filler.New(asyncClient, rootCfg.RegistryHostname, "/")

Expand All @@ -60,24 +69,22 @@ var serveCmd = &cobra.Command{
return
}

errCh := make(chan error, 1)
go func() {
errCh <- srv.Start()
}()
g, ctx := errgroup.WithContext(ctx)

go func() {
errCh <- asyncClient.Start(ctx)
}()
g.Go(func() error {
return srv.Start(ctx)
})

select {
case <-ctx.Done():
return
case err := <-errCh:
if err == nil {
slog.Error("operations exited unexpectedly")
return
g.Go(func() error {
return asyncClient.Start(ctx)
})

if err := g.Wait(); err != nil {
if ctx.Err() != nil {
log.Info("context cancelled, shutting down")
} else {
slog.Error("unexpected error", logger.ErrAttr(err))
}
slog.Error("unexpected error", logger.ErrAttr(err))
return
}
},
Expand All @@ -87,6 +94,6 @@ func init() {
serveCmd.PersistentFlags().StringVar(&bindAddr, "bind-addr", "127.0.0.1:8093", "server bind address")
serveCmd.PersistentFlags().StringArrayVar(&ignoredUserAgents, "ignored-user-agent", []string{}, "user agents to ignore (reply with empty body and 200 OK). A user agent is ignored if it contains the one of the values passed to this flag")
serveCmd.PersistentFlags().DurationVar(&cacheDuration, "cache-duration", time.Minute*1, "how long to keep a generated page in cache before expiring it, 0 to never expire")
serveCmd.PersistentFlags().DurationVar(&cacheDuration, "refresh-interval", time.Minute*15, "how long to wait before trying to get fresh data from the target registry")
serveCmd.PersistentFlags().DurationVar(&refreshInterval, "refresh-interval", time.Minute*15, "how long to wait before trying to get fresh data from the target registry")
rootCmd.AddCommand(serveCmd)
}
1 change: 1 addition & 0 deletions pkg/cfg/staticreg.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ type Root struct {
SkipTLSVerify bool
TLSEnabled bool
LogInJSON bool
Verbose bool
}
30 changes: 12 additions & 18 deletions pkg/observability/logger/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,21 @@ type (
loggerKey struct{}
)

func newProduction(w io.Writer) *slog.Logger {
func New(w io.Writer, logInJSON bool, verbose bool) *slog.Logger {
level := slog.LevelInfo
handler := slog.NewJSONHandler(w, &slog.HandlerOptions{
Level: level,
})
return slog.New(handler)
}

func newDevelopment(w io.Writer) *slog.Logger {
level := slog.LevelDebug
handler := slog.NewTextHandler(w, &slog.HandlerOptions{
Level: level,
})
return slog.New(handler)
}
if verbose {
level = slog.LevelDebug
}

func New(w io.Writer, production bool) *slog.Logger {
if production {
return newProduction(w)
if logInJSON {
return slog.New(slog.NewJSONHandler(w, &slog.HandlerOptions{
Level: level,
}))
}
return newDevelopment(w)

return slog.New(slog.NewTextHandler(w, &slog.HandlerOptions{
Level: level,
}))
}

func Context(ctx context.Context, logger *slog.Logger) context.Context {
Expand Down
93 changes: 66 additions & 27 deletions pkg/registry/async/async_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ package async

import (
"context"
"fmt"
"errors"
"log/slog"
"time"

v1 "github.com/google/go-containerregistry/pkg/v1"
"github.com/puzpuzpuz/xsync/v3"
"golang.org/x/sync/errgroup"

"github.com/cenkalti/backoff/v4"
"github.com/seqeralabs/staticreg/pkg/observability/logger"
Expand All @@ -31,22 +32,35 @@ import (
const imageInfoRequestsBufSize = 10
const tagRequestBufferSize = 10

var (
ErrNoTagsFound = errors.New("no tags found")
ErrImageInfoNotFound = errors.New("image info not found")
)

// Async is a struct that wraps an underlying registry.Client
// to provide asynchronous methods for interacting with a container registry.
// It continuously syncs data from the registry in a separate goroutine.
type Async struct {
// underlying is the actual registry client that does the registry operations, remember this is just a wrapper!
underlying registry.Client

// refreshInterval represents the time to wait to synchronize repositories again after a successful synchronization
refreshInterval time.Duration

// repos is an in memory list of all the repository names in the registry
repos []string

// repositoryTags represents the list of tags for each repository
repositoryTags *xsync.MapOf[string, []string]

// imageInfo contains the image information indexed by repo name and tag
imageInfo *xsync.MapOf[imageInfoKey, imageInfo]

// repositoryRequestBuffer generates requests for the `handleRepositoryRequest`
// handler that is responsible for retrieving the tags for a given image and
// scheduling new jobs on `imageInfoRequestsBuffer`
repositoryRequestBuffer chan repositoryRequest
// imageInfoRequestsBuffer is responsible for feeding `handleImageInfoRequest`
// so that image info is retrieved for each <repo,tag> combination
imageInfoRequestsBuffer chan imageInfoRequest
}

Expand All @@ -68,52 +82,63 @@ type imageInfo struct {
reference string
}

func (c *Async) Stop(ctx context.Context) {
close(c.imageInfoRequestsBuffer)
close(c.repositoryRequestBuffer)
}

func (c *Async) Start(ctx context.Context) error {
// TODO(fntlnz): maybe instead of errCh use a backoff and retry ops
errCh := make(chan error, 1)
log := logger.FromContext(ctx)
g, ctx := errgroup.WithContext(ctx)

go func() {
g.Go(func() error {
for {
err := backoff.Retry(func() error {
return c.synchronizeRepositories(ctx)
err := c.synchronizeRepositories(ctx)
if err != nil {
log.Error("err", logger.ErrAttr(err))
}
return err
}, backoff.WithContext(newExponentialBackoff(), ctx))

if err != nil {
errCh <- err
return err
}

time.Sleep(c.refreshInterval)
wait := time.After(c.refreshInterval)

select {
case <-wait:
continue
case <-ctx.Done():
return ctx.Err()
}
}
}()
})

go func() {
g.Go(func() error {
for {
select {
case <-ctx.Done():
return
return ctx.Err()
case req := <-c.repositoryRequestBuffer:
c.handleRepositoryRequest(ctx, req)
}
}
}()
})

go func() {
g.Go(func() error {
for {
select {
case <-ctx.Done():
return
return ctx.Err()
case req := <-c.imageInfoRequestsBuffer:
c.handleImageInfoRequest(ctx, req)
}
}
}()
})

select {
case <-ctx.Done():
return nil
case err := <-errCh:
return err
}
return g.Wait()
}

func (c *Async) synchronizeRepositories(ctx context.Context) error {
Expand All @@ -126,7 +151,14 @@ func (c *Async) synchronizeRepositories(ctx context.Context) error {
c.repos = repos

for _, r := range repos {
c.repositoryRequestBuffer <- repositoryRequest{repo: r}
if err := ctx.Err(); err != nil {
return nil
}
select {
case c.repositoryRequestBuffer <- repositoryRequest{repo: r}:
default:
return nil
}
}

return nil
Expand All @@ -146,13 +178,19 @@ func (c *Async) handleRepositoryRequest(ctx context.Context, req repositoryReque
c.repositoryTags.Store(req.repo, tags)

for _, t := range tags {
c.imageInfoRequestsBuffer <- imageInfoRequest{
if ctx.Err() != nil {
return
}

select {
case c.imageInfoRequestsBuffer <- imageInfoRequest{
repo: req.repo,
tag: t,
}:
default:
return
}
}

return
}

func (c *Async) handleImageInfoRequest(ctx context.Context, req imageInfoRequest) {
Expand All @@ -176,10 +214,11 @@ func (c *Async) RepoList(ctx context.Context) ([]string, error) {
return c.repos, nil
}

// TagList contains
func (c *Async) TagList(ctx context.Context, repo string) ([]string, error) {
tags, ok := c.repositoryTags.Load(repo)
if !ok {
return nil, fmt.Errorf("no tags found") // TODO(fntlnz): make an error var
return nil, ErrNoTagsFound
}
return tags, nil
}
Expand All @@ -191,7 +230,7 @@ func (c *Async) ImageInfo(ctx context.Context, repo string, tag string) (image v
}
info, ok := c.imageInfo.Load(key)
if !ok {
return nil, "", fmt.Errorf("image info not found") // TODO(fntlnz): make an error var
return nil, "", ErrImageInfoNotFound
}
return info.image, info.reference, nil
}
Expand Down
14 changes: 11 additions & 3 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"context"
"log/slog"
"net/http"
"strings"
Expand All @@ -9,8 +10,8 @@ import (
cache "github.com/chenyahui/gin-cache"
"github.com/chenyahui/gin-cache/persist"
sloggin "github.com/samber/slog-gin"
"golang.org/x/sync/errgroup"

// "github.com/chenyahui/gin-cache/persist"
"github.com/gin-gonic/gin"
)

Expand Down Expand Up @@ -40,6 +41,7 @@ func New(
r := gin.New()

lmConfig := sloggin.Config{
DefaultLevel: slog.LevelDebug,
WithUserAgent: true,
WithRequestID: true,
WithRequestBody: false,
Expand Down Expand Up @@ -78,8 +80,14 @@ func New(
}, nil
}

func (s *Server) Start() error {
return s.server.ListenAndServe()
func (s *Server) Start(ctx context.Context) error {
g, ctx := errgroup.WithContext(ctx)
g.Go(s.server.ListenAndServe)
g.Go(func() error {
<-ctx.Done()
return s.server.Shutdown(context.Background())
})
return g.Wait()
}

func injectLoggerMiddleware(log *slog.Logger) gin.HandlerFunc {
Expand Down

0 comments on commit cbb7f7f

Please sign in to comment.