diff --git a/cmd/root.go b/cmd/root.go index b23ea58..a1269dd 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -33,7 +33,8 @@ var rootCmd = &cobra.Command{ PersistentPreRun: func(cmd *cobra.Command, args []string) { ctx := cmd.Context() - log := logger.New(cmd.OutOrStderr(), rootCfg.LogInJSON) + log := logger.New(cmd.OutOrStderr(), rootCfg.LogInJSON, rootCfg.Verbose) + ctx = logger.Context(ctx, log) cmd.SetContext(ctx) @@ -72,4 +73,5 @@ func init() { rootCmd.PersistentFlags().BoolVar(&rootCfg.SkipTLSVerify, "skip-tls-verify", false, "disable TLS certificate checks") rootCmd.PersistentFlags().BoolVar(&rootCfg.TLSEnabled, "tls-enable", false, "enable TLS") rootCmd.PersistentFlags().BoolVar(&rootCfg.LogInJSON, "json-logging", false, "log in JSON") + rootCmd.PersistentFlags().BoolVar(&rootCfg.Verbose, "verbose", false, "enable verbose logging") } diff --git a/cmd/serve.go b/cmd/serve.go index c94762f..6dfd72a 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -15,6 +15,10 @@ package cmd import ( + "context" + "os" + "os/signal" + "syscall" "time" "log/slog" @@ -26,6 +30,7 @@ import ( "github.com/seqeralabs/staticreg/pkg/server" "github.com/seqeralabs/staticreg/pkg/server/staticreg" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" ) var ( @@ -40,6 +45,9 @@ var serveCmd = &cobra.Command{ Short: "Serves a webserver with an HTML listing of all images and tags in a v2 registry", Run: func(cmd *cobra.Command, args []string) { ctx := cmd.Context() + ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM) + defer stop() + log := logger.FromContext(ctx) log.Info("starting server", slog.Duration("cache-duration", cacheDuration), @@ -50,6 +58,7 @@ var serveCmd = &cobra.Command{ client := registry.New(rootCfg) asyncClient := async.New(client, refreshInterval) + defer asyncClient.Stop(context.Background()) filler := filler.New(asyncClient, rootCfg.RegistryHostname, "/") @@ -60,24 +69,22 @@ var serveCmd = &cobra.Command{ return } - errCh := make(chan error, 1) - go func() { - errCh <- srv.Start() - }() + g, ctx := errgroup.WithContext(ctx) - go func() { - errCh <- asyncClient.Start(ctx) - }() + g.Go(func() error { + return srv.Start(ctx) + }) - select { - case <-ctx.Done(): - return - case err := <-errCh: - if err == nil { - slog.Error("operations exited unexpectedly") - return + g.Go(func() error { + return asyncClient.Start(ctx) + }) + + if err := g.Wait(); err != nil { + if ctx.Err() != nil { + log.Info("context cancelled, shutting down") + } else { + slog.Error("unexpected error", logger.ErrAttr(err)) } - slog.Error("unexpected error", logger.ErrAttr(err)) return } }, @@ -87,6 +94,6 @@ func init() { serveCmd.PersistentFlags().StringVar(&bindAddr, "bind-addr", "127.0.0.1:8093", "server bind address") serveCmd.PersistentFlags().StringArrayVar(&ignoredUserAgents, "ignored-user-agent", []string{}, "user agents to ignore (reply with empty body and 200 OK). A user agent is ignored if it contains the one of the values passed to this flag") serveCmd.PersistentFlags().DurationVar(&cacheDuration, "cache-duration", time.Minute*1, "how long to keep a generated page in cache before expiring it, 0 to never expire") - serveCmd.PersistentFlags().DurationVar(&cacheDuration, "refresh-interval", time.Minute*15, "how long to wait before trying to get fresh data from the target registry") + serveCmd.PersistentFlags().DurationVar(&refreshInterval, "refresh-interval", time.Minute*15, "how long to wait before trying to get fresh data from the target registry") rootCmd.AddCommand(serveCmd) } diff --git a/pkg/cfg/staticreg.go b/pkg/cfg/staticreg.go index daa8d6a..00192c1 100644 --- a/pkg/cfg/staticreg.go +++ b/pkg/cfg/staticreg.go @@ -21,4 +21,5 @@ type Root struct { SkipTLSVerify bool TLSEnabled bool LogInJSON bool + Verbose bool } diff --git a/pkg/observability/logger/logger.go b/pkg/observability/logger/logger.go index bd6bd53..3204214 100644 --- a/pkg/observability/logger/logger.go +++ b/pkg/observability/logger/logger.go @@ -24,27 +24,21 @@ type ( loggerKey struct{} ) -func newProduction(w io.Writer) *slog.Logger { +func New(w io.Writer, logInJSON bool, verbose bool) *slog.Logger { level := slog.LevelInfo - handler := slog.NewJSONHandler(w, &slog.HandlerOptions{ - Level: level, - }) - return slog.New(handler) -} - -func newDevelopment(w io.Writer) *slog.Logger { - level := slog.LevelDebug - handler := slog.NewTextHandler(w, &slog.HandlerOptions{ - Level: level, - }) - return slog.New(handler) -} + if verbose { + level = slog.LevelDebug + } -func New(w io.Writer, production bool) *slog.Logger { - if production { - return newProduction(w) + if logInJSON { + return slog.New(slog.NewJSONHandler(w, &slog.HandlerOptions{ + Level: level, + })) } - return newDevelopment(w) + + return slog.New(slog.NewTextHandler(w, &slog.HandlerOptions{ + Level: level, + })) } func Context(ctx context.Context, logger *slog.Logger) context.Context { diff --git a/pkg/registry/async/async_registry.go b/pkg/registry/async/async_registry.go index 8f46a8b..6ebee3e 100644 --- a/pkg/registry/async/async_registry.go +++ b/pkg/registry/async/async_registry.go @@ -16,12 +16,13 @@ package async import ( "context" - "fmt" + "errors" "log/slog" "time" v1 "github.com/google/go-containerregistry/pkg/v1" "github.com/puzpuzpuz/xsync/v3" + "golang.org/x/sync/errgroup" "github.com/cenkalti/backoff/v4" "github.com/seqeralabs/staticreg/pkg/observability/logger" @@ -31,22 +32,35 @@ import ( const imageInfoRequestsBufSize = 10 const tagRequestBufferSize = 10 +var ( + ErrNoTagsFound = errors.New("no tags found") + ErrImageInfoNotFound = errors.New("image info not found") +) + // Async is a struct that wraps an underlying registry.Client // to provide asynchronous methods for interacting with a container registry. // It continuously syncs data from the registry in a separate goroutine. type Async struct { // underlying is the actual registry client that does the registry operations, remember this is just a wrapper! underlying registry.Client - + // refreshInterval represents the time to wait to synchronize repositories again after a successful synchronization refreshInterval time.Duration + // repos is an in memory list of all the repository names in the registry repos []string + // repositoryTags represents the list of tags for each repository repositoryTags *xsync.MapOf[string, []string] + // imageInfo contains the image information indexed by repo name and tag imageInfo *xsync.MapOf[imageInfoKey, imageInfo] + // repositoryRequestBuffer generates requests for the `handleRepositoryRequest` + // handler that is responsible for retrieving the tags for a given image and + // scheduling new jobs on `imageInfoRequestsBuffer` repositoryRequestBuffer chan repositoryRequest + // imageInfoRequestsBuffer is responsible for feeding `handleImageInfoRequest` + // so that image info is retrieved for each combination imageInfoRequestsBuffer chan imageInfoRequest } @@ -68,52 +82,63 @@ type imageInfo struct { reference string } +func (c *Async) Stop(ctx context.Context) { + close(c.imageInfoRequestsBuffer) + close(c.repositoryRequestBuffer) +} + func (c *Async) Start(ctx context.Context) error { - // TODO(fntlnz): maybe instead of errCh use a backoff and retry ops - errCh := make(chan error, 1) + log := logger.FromContext(ctx) + g, ctx := errgroup.WithContext(ctx) - go func() { + g.Go(func() error { for { err := backoff.Retry(func() error { - return c.synchronizeRepositories(ctx) + err := c.synchronizeRepositories(ctx) + if err != nil { + log.Error("err", logger.ErrAttr(err)) + } + return err }, backoff.WithContext(newExponentialBackoff(), ctx)) if err != nil { - errCh <- err + return err } - time.Sleep(c.refreshInterval) + wait := time.After(c.refreshInterval) + + select { + case <-wait: + continue + case <-ctx.Done(): + return ctx.Err() + } } - }() + }) - go func() { + g.Go(func() error { for { select { case <-ctx.Done(): - return + return ctx.Err() case req := <-c.repositoryRequestBuffer: c.handleRepositoryRequest(ctx, req) } } - }() + }) - go func() { + g.Go(func() error { for { select { case <-ctx.Done(): - return + return ctx.Err() case req := <-c.imageInfoRequestsBuffer: c.handleImageInfoRequest(ctx, req) } } - }() + }) - select { - case <-ctx.Done(): - return nil - case err := <-errCh: - return err - } + return g.Wait() } func (c *Async) synchronizeRepositories(ctx context.Context) error { @@ -126,7 +151,14 @@ func (c *Async) synchronizeRepositories(ctx context.Context) error { c.repos = repos for _, r := range repos { - c.repositoryRequestBuffer <- repositoryRequest{repo: r} + if err := ctx.Err(); err != nil { + return nil + } + select { + case c.repositoryRequestBuffer <- repositoryRequest{repo: r}: + default: + return nil + } } return nil @@ -146,13 +178,19 @@ func (c *Async) handleRepositoryRequest(ctx context.Context, req repositoryReque c.repositoryTags.Store(req.repo, tags) for _, t := range tags { - c.imageInfoRequestsBuffer <- imageInfoRequest{ + if ctx.Err() != nil { + return + } + + select { + case c.imageInfoRequestsBuffer <- imageInfoRequest{ repo: req.repo, tag: t, + }: + default: + return } } - - return } func (c *Async) handleImageInfoRequest(ctx context.Context, req imageInfoRequest) { @@ -176,10 +214,11 @@ func (c *Async) RepoList(ctx context.Context) ([]string, error) { return c.repos, nil } +// TagList contains func (c *Async) TagList(ctx context.Context, repo string) ([]string, error) { tags, ok := c.repositoryTags.Load(repo) if !ok { - return nil, fmt.Errorf("no tags found") // TODO(fntlnz): make an error var + return nil, ErrNoTagsFound } return tags, nil } @@ -191,7 +230,7 @@ func (c *Async) ImageInfo(ctx context.Context, repo string, tag string) (image v } info, ok := c.imageInfo.Load(key) if !ok { - return nil, "", fmt.Errorf("image info not found") // TODO(fntlnz): make an error var + return nil, "", ErrImageInfoNotFound } return info.image, info.reference, nil } diff --git a/pkg/server/server.go b/pkg/server/server.go index df1a889..1b2c1e5 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -1,6 +1,7 @@ package server import ( + "context" "log/slog" "net/http" "strings" @@ -9,8 +10,8 @@ import ( cache "github.com/chenyahui/gin-cache" "github.com/chenyahui/gin-cache/persist" sloggin "github.com/samber/slog-gin" + "golang.org/x/sync/errgroup" - // "github.com/chenyahui/gin-cache/persist" "github.com/gin-gonic/gin" ) @@ -40,6 +41,7 @@ func New( r := gin.New() lmConfig := sloggin.Config{ + DefaultLevel: slog.LevelDebug, WithUserAgent: true, WithRequestID: true, WithRequestBody: false, @@ -78,8 +80,14 @@ func New( }, nil } -func (s *Server) Start() error { - return s.server.ListenAndServe() +func (s *Server) Start(ctx context.Context) error { + g, ctx := errgroup.WithContext(ctx) + g.Go(s.server.ListenAndServe) + g.Go(func() error { + <-ctx.Done() + return s.server.Shutdown(context.Background()) + }) + return g.Wait() } func injectLoggerMiddleware(log *slog.Logger) gin.HandlerFunc {